001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.stat.clustering;
019
020import java.util.ArrayList;
021import java.util.Collection;
022import java.util.Collections;
023import java.util.List;
024import java.util.Random;
025
026import org.apache.commons.math3.exception.ConvergenceException;
027import org.apache.commons.math3.exception.MathIllegalArgumentException;
028import org.apache.commons.math3.exception.NumberIsTooSmallException;
029import org.apache.commons.math3.exception.util.LocalizedFormats;
030import org.apache.commons.math3.stat.descriptive.moment.Variance;
031import org.apache.commons.math3.util.MathUtils;
032
033/**
034 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
035 * @param <T> type of the points to cluster
036 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
037 * @since 2.0
038 * @deprecated As of 3.2 (to be removed in 4.0),
039 * use {@link org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer} instead
040 */
041@Deprecated
042public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
043
044    /** Strategies to use for replacing an empty cluster. */
045    public enum EmptyClusterStrategy {
046
047        /** Split the cluster with largest distance variance. */
048        LARGEST_VARIANCE,
049
050        /** Split the cluster with largest number of points. */
051        LARGEST_POINTS_NUMBER,
052
053        /** Create a cluster around the point farthest from its centroid. */
054        FARTHEST_POINT,
055
056        /** Generate an error. */
057        ERROR
058
059    }
060
061    /** Random generator for choosing initial centers. */
062    private final Random random;
063
064    /** Selected strategy for empty clusters. */
065    private final EmptyClusterStrategy emptyStrategy;
066
067    /** Build a clusterer.
068     * <p>
069     * The default strategy for handling empty clusters that may appear during
070     * algorithm iterations is to split the cluster with largest distance variance.
071     * </p>
072     * @param random random generator to use for choosing initial centers
073     */
074    public KMeansPlusPlusClusterer(final Random random) {
075        this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
076    }
077
078    /** Build a clusterer.
079     * @param random random generator to use for choosing initial centers
080     * @param emptyStrategy strategy to use for handling empty clusters that
081     * may appear during algorithm iterations
082     * @since 2.2
083     */
084    public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
085        this.random        = random;
086        this.emptyStrategy = emptyStrategy;
087    }
088
089    /**
090     * Runs the K-means++ clustering algorithm.
091     *
092     * @param points the points to cluster
093     * @param k the number of clusters to split the data into
094     * @param numTrials number of trial runs
095     * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
096     *     for at each trial run.  If negative, no maximum will be used
097     * @return a list of clusters containing the points
098     * @throws MathIllegalArgumentException if the data points are null or the number
099     *     of clusters is larger than the number of data points
100     * @throws ConvergenceException if an empty cluster is encountered and the
101     * {@link #emptyStrategy} is set to {@code ERROR}
102     */
103    public List<Cluster<T>> cluster(final Collection<T> points, final int k,
104                                    int numTrials, int maxIterationsPerTrial)
105        throws MathIllegalArgumentException, ConvergenceException {
106
107        // at first, we have not found any clusters list yet
108        List<Cluster<T>> best = null;
109        double bestVarianceSum = Double.POSITIVE_INFINITY;
110
111        // do several clustering trials
112        for (int i = 0; i < numTrials; ++i) {
113
114            // compute a clusters list
115            List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
116
117            // compute the variance of the current list
118            double varianceSum = 0.0;
119            for (final Cluster<T> cluster : clusters) {
120                if (!cluster.getPoints().isEmpty()) {
121
122                    // compute the distance variance of the current cluster
123                    final T center = cluster.getCenter();
124                    final Variance stat = new Variance();
125                    for (final T point : cluster.getPoints()) {
126                        stat.increment(point.distanceFrom(center));
127                    }
128                    varianceSum += stat.getResult();
129
130                }
131            }
132
133            if (varianceSum <= bestVarianceSum) {
134                // this one is the best we have found so far, remember it
135                best            = clusters;
136                bestVarianceSum = varianceSum;
137            }
138
139        }
140
141        // return the best clusters list found
142        return best;
143
144    }
145
146    /**
147     * Runs the K-means++ clustering algorithm.
148     *
149     * @param points the points to cluster
150     * @param k the number of clusters to split the data into
151     * @param maxIterations the maximum number of iterations to run the algorithm
152     *     for.  If negative, no maximum will be used
153     * @return a list of clusters containing the points
154     * @throws MathIllegalArgumentException if the data points are null or the number
155     *     of clusters is larger than the number of data points
156     * @throws ConvergenceException if an empty cluster is encountered and the
157     * {@link #emptyStrategy} is set to {@code ERROR}
158     */
159    public List<Cluster<T>> cluster(final Collection<T> points, final int k,
160                                    final int maxIterations)
161        throws MathIllegalArgumentException, ConvergenceException {
162
163        // sanity checks
164        MathUtils.checkNotNull(points);
165
166        // number of clusters has to be smaller or equal the number of data points
167        if (points.size() < k) {
168            throw new NumberIsTooSmallException(points.size(), k, false);
169        }
170
171        // create the initial clusters
172        List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
173
174        // create an array containing the latest assignment of a point to a cluster
175        // no need to initialize the array, as it will be filled with the first assignment
176        int[] assignments = new int[points.size()];
177        assignPointsToClusters(clusters, points, assignments);
178
179        // iterate through updating the centers until we're done
180        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
181        for (int count = 0; count < max; count++) {
182            boolean emptyCluster = false;
183            List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
184            for (final Cluster<T> cluster : clusters) {
185                final T newCenter;
186                if (cluster.getPoints().isEmpty()) {
187                    switch (emptyStrategy) {
188                        case LARGEST_VARIANCE :
189                            newCenter = getPointFromLargestVarianceCluster(clusters);
190                            break;
191                        case LARGEST_POINTS_NUMBER :
192                            newCenter = getPointFromLargestNumberCluster(clusters);
193                            break;
194                        case FARTHEST_POINT :
195                            newCenter = getFarthestPoint(clusters);
196                            break;
197                        default :
198                            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
199                    }
200                    emptyCluster = true;
201                } else {
202                    newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
203                }
204                newClusters.add(new Cluster<T>(newCenter));
205            }
206            int changes = assignPointsToClusters(newClusters, points, assignments);
207            clusters = newClusters;
208
209            // if there were no more changes in the point-to-cluster assignment
210            // and there are no empty clusters left, return the current clusters
211            if (changes == 0 && !emptyCluster) {
212                return clusters;
213            }
214        }
215        return clusters;
216    }
217
218    /**
219     * Adds the given points to the closest {@link Cluster}.
220     *
221     * @param <T> type of the points to cluster
222     * @param clusters the {@link Cluster}s to add the points to
223     * @param points the points to add to the given {@link Cluster}s
224     * @param assignments points assignments to clusters
225     * @return the number of points assigned to different clusters as the iteration before
226     */
227    private static <T extends Clusterable<T>> int
228        assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
229                               final int[] assignments) {
230        int assignedDifferently = 0;
231        int pointIndex = 0;
232        for (final T p : points) {
233            int clusterIndex = getNearestCluster(clusters, p);
234            if (clusterIndex != assignments[pointIndex]) {
235                assignedDifferently++;
236            }
237
238            Cluster<T> cluster = clusters.get(clusterIndex);
239            cluster.addPoint(p);
240            assignments[pointIndex++] = clusterIndex;
241        }
242
243        return assignedDifferently;
244    }
245
246    /**
247     * Use K-means++ to choose the initial centers.
248     *
249     * @param <T> type of the points to cluster
250     * @param points the points to choose the initial centers from
251     * @param k the number of centers to choose
252     * @param random random generator to use
253     * @return the initial centers
254     */
255    private static <T extends Clusterable<T>> List<Cluster<T>>
256        chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
257
258        // Convert to list for indexed access. Make it unmodifiable, since removal of items
259        // would screw up the logic of this method.
260        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
261
262        // The number of points in the list.
263        final int numPoints = pointList.size();
264
265        // Set the corresponding element in this array to indicate when
266        // elements of pointList are no longer available.
267        final boolean[] taken = new boolean[numPoints];
268
269        // The resulting list of initial centers.
270        final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
271
272        // Choose one center uniformly at random from among the data points.
273        final int firstPointIndex = random.nextInt(numPoints);
274
275        final T firstPoint = pointList.get(firstPointIndex);
276
277        resultSet.add(new Cluster<T>(firstPoint));
278
279        // Must mark it as taken
280        taken[firstPointIndex] = true;
281
282        // To keep track of the minimum distance squared of elements of
283        // pointList to elements of resultSet.
284        final double[] minDistSquared = new double[numPoints];
285
286        // Initialize the elements.  Since the only point in resultSet is firstPoint,
287        // this is very easy.
288        for (int i = 0; i < numPoints; i++) {
289            if (i != firstPointIndex) { // That point isn't considered
290                double d = firstPoint.distanceFrom(pointList.get(i));
291                minDistSquared[i] = d*d;
292            }
293        }
294
295        while (resultSet.size() < k) {
296
297            // Sum up the squared distances for the points in pointList not
298            // already taken.
299            double distSqSum = 0.0;
300
301            for (int i = 0; i < numPoints; i++) {
302                if (!taken[i]) {
303                    distSqSum += minDistSquared[i];
304                }
305            }
306
307            // Add one new data point as a center. Each point x is chosen with
308            // probability proportional to D(x)2
309            final double r = random.nextDouble() * distSqSum;
310
311            // The index of the next point to be added to the resultSet.
312            int nextPointIndex = -1;
313
314            // Sum through the squared min distances again, stopping when
315            // sum >= r.
316            double sum = 0.0;
317            for (int i = 0; i < numPoints; i++) {
318                if (!taken[i]) {
319                    sum += minDistSquared[i];
320                    if (sum >= r) {
321                        nextPointIndex = i;
322                        break;
323                    }
324                }
325            }
326
327            // If it's not set to >= 0, the point wasn't found in the previous
328            // for loop, probably because distances are extremely small.  Just pick
329            // the last available point.
330            if (nextPointIndex == -1) {
331                for (int i = numPoints - 1; i >= 0; i--) {
332                    if (!taken[i]) {
333                        nextPointIndex = i;
334                        break;
335                    }
336                }
337            }
338
339            // We found one.
340            if (nextPointIndex >= 0) {
341
342                final T p = pointList.get(nextPointIndex);
343
344                resultSet.add(new Cluster<T> (p));
345
346                // Mark it as taken.
347                taken[nextPointIndex] = true;
348
349                if (resultSet.size() < k) {
350                    // Now update elements of minDistSquared.  We only have to compute
351                    // the distance to the new center to do this.
352                    for (int j = 0; j < numPoints; j++) {
353                        // Only have to worry about the points still not taken.
354                        if (!taken[j]) {
355                            double d = p.distanceFrom(pointList.get(j));
356                            double d2 = d * d;
357                            if (d2 < minDistSquared[j]) {
358                                minDistSquared[j] = d2;
359                            }
360                        }
361                    }
362                }
363
364            } else {
365                // None found --
366                // Break from the while loop to prevent
367                // an infinite loop.
368                break;
369            }
370        }
371
372        return resultSet;
373    }
374
375    /**
376     * Get a random point from the {@link Cluster} with the largest distance variance.
377     *
378     * @param clusters the {@link Cluster}s to search
379     * @return a random point from the selected cluster
380     * @throws ConvergenceException if clusters are all empty
381     */
382    private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
383    throws ConvergenceException {
384
385        double maxVariance = Double.NEGATIVE_INFINITY;
386        Cluster<T> selected = null;
387        for (final Cluster<T> cluster : clusters) {
388            if (!cluster.getPoints().isEmpty()) {
389
390                // compute the distance variance of the current cluster
391                final T center = cluster.getCenter();
392                final Variance stat = new Variance();
393                for (final T point : cluster.getPoints()) {
394                    stat.increment(point.distanceFrom(center));
395                }
396                final double variance = stat.getResult();
397
398                // select the cluster with the largest variance
399                if (variance > maxVariance) {
400                    maxVariance = variance;
401                    selected = cluster;
402                }
403
404            }
405        }
406
407        // did we find at least one non-empty cluster ?
408        if (selected == null) {
409            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
410        }
411
412        // extract a random point from the cluster
413        final List<T> selectedPoints = selected.getPoints();
414        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
415
416    }
417
418    /**
419     * Get a random point from the {@link Cluster} with the largest number of points
420     *
421     * @param clusters the {@link Cluster}s to search
422     * @return a random point from the selected cluster
423     * @throws ConvergenceException if clusters are all empty
424     */
425    private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {
426
427        int maxNumber = 0;
428        Cluster<T> selected = null;
429        for (final Cluster<T> cluster : clusters) {
430
431            // get the number of points of the current cluster
432            final int number = cluster.getPoints().size();
433
434            // select the cluster with the largest number of points
435            if (number > maxNumber) {
436                maxNumber = number;
437                selected = cluster;
438            }
439
440        }
441
442        // did we find at least one non-empty cluster ?
443        if (selected == null) {
444            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
445        }
446
447        // extract a random point from the cluster
448        final List<T> selectedPoints = selected.getPoints();
449        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
450
451    }
452
453    /**
454     * Get the point farthest to its cluster center
455     *
456     * @param clusters the {@link Cluster}s to search
457     * @return point farthest to its cluster center
458     * @throws ConvergenceException if clusters are all empty
459     */
460    private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {
461
462        double maxDistance = Double.NEGATIVE_INFINITY;
463        Cluster<T> selectedCluster = null;
464        int selectedPoint = -1;
465        for (final Cluster<T> cluster : clusters) {
466
467            // get the farthest point
468            final T center = cluster.getCenter();
469            final List<T> points = cluster.getPoints();
470            for (int i = 0; i < points.size(); ++i) {
471                final double distance = points.get(i).distanceFrom(center);
472                if (distance > maxDistance) {
473                    maxDistance     = distance;
474                    selectedCluster = cluster;
475                    selectedPoint   = i;
476                }
477            }
478
479        }
480
481        // did we find at least one non-empty cluster ?
482        if (selectedCluster == null) {
483            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
484        }
485
486        return selectedCluster.getPoints().remove(selectedPoint);
487
488    }
489
490    /**
491     * Returns the nearest {@link Cluster} to the given point
492     *
493     * @param <T> type of the points to cluster
494     * @param clusters the {@link Cluster}s to search
495     * @param point the point to find the nearest {@link Cluster} for
496     * @return the index of the nearest {@link Cluster} to the given point
497     */
498    private static <T extends Clusterable<T>> int
499        getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
500        double minDistance = Double.MAX_VALUE;
501        int clusterIndex = 0;
502        int minCluster = 0;
503        for (final Cluster<T> c : clusters) {
504            final double distance = point.distanceFrom(c.getCenter());
505            if (distance < minDistance) {
506                minDistance = distance;
507                minCluster = clusterIndex;
508            }
509            clusterIndex++;
510        }
511        return minCluster;
512    }
513
514}