001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.stat.clustering;
019
020import java.util.ArrayList;
021import java.util.Collection;
022import java.util.Collections;
023import java.util.List;
024import java.util.Random;
025
026import org.apache.commons.math3.exception.ConvergenceException;
027import org.apache.commons.math3.exception.MathIllegalArgumentException;
028import org.apache.commons.math3.exception.NumberIsTooSmallException;
029import org.apache.commons.math3.exception.util.LocalizedFormats;
030import org.apache.commons.math3.stat.descriptive.moment.Variance;
031import org.apache.commons.math3.util.MathUtils;
032
033/**
034 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
035 * @param <T> type of the points to cluster
036 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
037 * @version $Id: KMeansPlusPlusClusterer.java 1461871 2013-03-27 22:01:25Z tn $
038 * @since 2.0
039 * @deprecated As of 3.2 (to be removed in 4.0),
040 * use {@link org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer} instead
041 */
042@Deprecated
043public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
044
045    /** Strategies to use for replacing an empty cluster. */
046    public static enum EmptyClusterStrategy {
047
048        /** Split the cluster with largest distance variance. */
049        LARGEST_VARIANCE,
050
051        /** Split the cluster with largest number of points. */
052        LARGEST_POINTS_NUMBER,
053
054        /** Create a cluster around the point farthest from its centroid. */
055        FARTHEST_POINT,
056
057        /** Generate an error. */
058        ERROR
059
060    }
061
062    /** Random generator for choosing initial centers. */
063    private final Random random;
064
065    /** Selected strategy for empty clusters. */
066    private final EmptyClusterStrategy emptyStrategy;
067
068    /** Build a clusterer.
069     * <p>
070     * The default strategy for handling empty clusters that may appear during
071     * algorithm iterations is to split the cluster with largest distance variance.
072     * </p>
073     * @param random random generator to use for choosing initial centers
074     */
075    public KMeansPlusPlusClusterer(final Random random) {
076        this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
077    }
078
079    /** Build a clusterer.
080     * @param random random generator to use for choosing initial centers
081     * @param emptyStrategy strategy to use for handling empty clusters that
082     * may appear during algorithm iterations
083     * @since 2.2
084     */
085    public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
086        this.random        = random;
087        this.emptyStrategy = emptyStrategy;
088    }
089
090    /**
091     * Runs the K-means++ clustering algorithm.
092     *
093     * @param points the points to cluster
094     * @param k the number of clusters to split the data into
095     * @param numTrials number of trial runs
096     * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
097     *     for at each trial run.  If negative, no maximum will be used
098     * @return a list of clusters containing the points
099     * @throws MathIllegalArgumentException if the data points are null or the number
100     *     of clusters is larger than the number of data points
101     * @throws ConvergenceException if an empty cluster is encountered and the
102     * {@link #emptyStrategy} is set to {@code ERROR}
103     */
104    public List<Cluster<T>> cluster(final Collection<T> points, final int k,
105                                    int numTrials, int maxIterationsPerTrial)
106        throws MathIllegalArgumentException, ConvergenceException {
107
108        // at first, we have not found any clusters list yet
109        List<Cluster<T>> best = null;
110        double bestVarianceSum = Double.POSITIVE_INFINITY;
111
112        // do several clustering trials
113        for (int i = 0; i < numTrials; ++i) {
114
115            // compute a clusters list
116            List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
117
118            // compute the variance of the current list
119            double varianceSum = 0.0;
120            for (final Cluster<T> cluster : clusters) {
121                if (!cluster.getPoints().isEmpty()) {
122
123                    // compute the distance variance of the current cluster
124                    final T center = cluster.getCenter();
125                    final Variance stat = new Variance();
126                    for (final T point : cluster.getPoints()) {
127                        stat.increment(point.distanceFrom(center));
128                    }
129                    varianceSum += stat.getResult();
130
131                }
132            }
133
134            if (varianceSum <= bestVarianceSum) {
135                // this one is the best we have found so far, remember it
136                best            = clusters;
137                bestVarianceSum = varianceSum;
138            }
139
140        }
141
142        // return the best clusters list found
143        return best;
144
145    }
146
147    /**
148     * Runs the K-means++ clustering algorithm.
149     *
150     * @param points the points to cluster
151     * @param k the number of clusters to split the data into
152     * @param maxIterations the maximum number of iterations to run the algorithm
153     *     for.  If negative, no maximum will be used
154     * @return a list of clusters containing the points
155     * @throws MathIllegalArgumentException if the data points are null or the number
156     *     of clusters is larger than the number of data points
157     * @throws ConvergenceException if an empty cluster is encountered and the
158     * {@link #emptyStrategy} is set to {@code ERROR}
159     */
160    public List<Cluster<T>> cluster(final Collection<T> points, final int k,
161                                    final int maxIterations)
162        throws MathIllegalArgumentException, ConvergenceException {
163
164        // sanity checks
165        MathUtils.checkNotNull(points);
166
167        // number of clusters has to be smaller or equal the number of data points
168        if (points.size() < k) {
169            throw new NumberIsTooSmallException(points.size(), k, false);
170        }
171
172        // create the initial clusters
173        List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
174
175        // create an array containing the latest assignment of a point to a cluster
176        // no need to initialize the array, as it will be filled with the first assignment
177        int[] assignments = new int[points.size()];
178        assignPointsToClusters(clusters, points, assignments);
179
180        // iterate through updating the centers until we're done
181        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
182        for (int count = 0; count < max; count++) {
183            boolean emptyCluster = false;
184            List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
185            for (final Cluster<T> cluster : clusters) {
186                final T newCenter;
187                if (cluster.getPoints().isEmpty()) {
188                    switch (emptyStrategy) {
189                        case LARGEST_VARIANCE :
190                            newCenter = getPointFromLargestVarianceCluster(clusters);
191                            break;
192                        case LARGEST_POINTS_NUMBER :
193                            newCenter = getPointFromLargestNumberCluster(clusters);
194                            break;
195                        case FARTHEST_POINT :
196                            newCenter = getFarthestPoint(clusters);
197                            break;
198                        default :
199                            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
200                    }
201                    emptyCluster = true;
202                } else {
203                    newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
204                }
205                newClusters.add(new Cluster<T>(newCenter));
206            }
207            int changes = assignPointsToClusters(newClusters, points, assignments);
208            clusters = newClusters;
209
210            // if there were no more changes in the point-to-cluster assignment
211            // and there are no empty clusters left, return the current clusters
212            if (changes == 0 && !emptyCluster) {
213                return clusters;
214            }
215        }
216        return clusters;
217    }
218
219    /**
220     * Adds the given points to the closest {@link Cluster}.
221     *
222     * @param <T> type of the points to cluster
223     * @param clusters the {@link Cluster}s to add the points to
224     * @param points the points to add to the given {@link Cluster}s
225     * @param assignments points assignments to clusters
226     * @return the number of points assigned to different clusters as the iteration before
227     */
228    private static <T extends Clusterable<T>> int
229        assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
230                               final int[] assignments) {
231        int assignedDifferently = 0;
232        int pointIndex = 0;
233        for (final T p : points) {
234            int clusterIndex = getNearestCluster(clusters, p);
235            if (clusterIndex != assignments[pointIndex]) {
236                assignedDifferently++;
237            }
238
239            Cluster<T> cluster = clusters.get(clusterIndex);
240            cluster.addPoint(p);
241            assignments[pointIndex++] = clusterIndex;
242        }
243
244        return assignedDifferently;
245    }
246
247    /**
248     * Use K-means++ to choose the initial centers.
249     *
250     * @param <T> type of the points to cluster
251     * @param points the points to choose the initial centers from
252     * @param k the number of centers to choose
253     * @param random random generator to use
254     * @return the initial centers
255     */
256    private static <T extends Clusterable<T>> List<Cluster<T>>
257        chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
258
259        // Convert to list for indexed access. Make it unmodifiable, since removal of items
260        // would screw up the logic of this method.
261        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
262
263        // The number of points in the list.
264        final int numPoints = pointList.size();
265
266        // Set the corresponding element in this array to indicate when
267        // elements of pointList are no longer available.
268        final boolean[] taken = new boolean[numPoints];
269
270        // The resulting list of initial centers.
271        final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
272
273        // Choose one center uniformly at random from among the data points.
274        final int firstPointIndex = random.nextInt(numPoints);
275
276        final T firstPoint = pointList.get(firstPointIndex);
277
278        resultSet.add(new Cluster<T>(firstPoint));
279
280        // Must mark it as taken
281        taken[firstPointIndex] = true;
282
283        // To keep track of the minimum distance squared of elements of
284        // pointList to elements of resultSet.
285        final double[] minDistSquared = new double[numPoints];
286
287        // Initialize the elements.  Since the only point in resultSet is firstPoint,
288        // this is very easy.
289        for (int i = 0; i < numPoints; i++) {
290            if (i != firstPointIndex) { // That point isn't considered
291                double d = firstPoint.distanceFrom(pointList.get(i));
292                minDistSquared[i] = d*d;
293            }
294        }
295
296        while (resultSet.size() < k) {
297
298            // Sum up the squared distances for the points in pointList not
299            // already taken.
300            double distSqSum = 0.0;
301
302            for (int i = 0; i < numPoints; i++) {
303                if (!taken[i]) {
304                    distSqSum += minDistSquared[i];
305                }
306            }
307
308            // Add one new data point as a center. Each point x is chosen with
309            // probability proportional to D(x)2
310            final double r = random.nextDouble() * distSqSum;
311
312            // The index of the next point to be added to the resultSet.
313            int nextPointIndex = -1;
314
315            // Sum through the squared min distances again, stopping when
316            // sum >= r.
317            double sum = 0.0;
318            for (int i = 0; i < numPoints; i++) {
319                if (!taken[i]) {
320                    sum += minDistSquared[i];
321                    if (sum >= r) {
322                        nextPointIndex = i;
323                        break;
324                    }
325                }
326            }
327
328            // If it's not set to >= 0, the point wasn't found in the previous
329            // for loop, probably because distances are extremely small.  Just pick
330            // the last available point.
331            if (nextPointIndex == -1) {
332                for (int i = numPoints - 1; i >= 0; i--) {
333                    if (!taken[i]) {
334                        nextPointIndex = i;
335                        break;
336                    }
337                }
338            }
339
340            // We found one.
341            if (nextPointIndex >= 0) {
342
343                final T p = pointList.get(nextPointIndex);
344
345                resultSet.add(new Cluster<T> (p));
346
347                // Mark it as taken.
348                taken[nextPointIndex] = true;
349
350                if (resultSet.size() < k) {
351                    // Now update elements of minDistSquared.  We only have to compute
352                    // the distance to the new center to do this.
353                    for (int j = 0; j < numPoints; j++) {
354                        // Only have to worry about the points still not taken.
355                        if (!taken[j]) {
356                            double d = p.distanceFrom(pointList.get(j));
357                            double d2 = d * d;
358                            if (d2 < minDistSquared[j]) {
359                                minDistSquared[j] = d2;
360                            }
361                        }
362                    }
363                }
364
365            } else {
366                // None found --
367                // Break from the while loop to prevent
368                // an infinite loop.
369                break;
370            }
371        }
372
373        return resultSet;
374    }
375
376    /**
377     * Get a random point from the {@link Cluster} with the largest distance variance.
378     *
379     * @param clusters the {@link Cluster}s to search
380     * @return a random point from the selected cluster
381     * @throws ConvergenceException if clusters are all empty
382     */
383    private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
384    throws ConvergenceException {
385
386        double maxVariance = Double.NEGATIVE_INFINITY;
387        Cluster<T> selected = null;
388        for (final Cluster<T> cluster : clusters) {
389            if (!cluster.getPoints().isEmpty()) {
390
391                // compute the distance variance of the current cluster
392                final T center = cluster.getCenter();
393                final Variance stat = new Variance();
394                for (final T point : cluster.getPoints()) {
395                    stat.increment(point.distanceFrom(center));
396                }
397                final double variance = stat.getResult();
398
399                // select the cluster with the largest variance
400                if (variance > maxVariance) {
401                    maxVariance = variance;
402                    selected = cluster;
403                }
404
405            }
406        }
407
408        // did we find at least one non-empty cluster ?
409        if (selected == null) {
410            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
411        }
412
413        // extract a random point from the cluster
414        final List<T> selectedPoints = selected.getPoints();
415        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
416
417    }
418
419    /**
420     * Get a random point from the {@link Cluster} with the largest number of points
421     *
422     * @param clusters the {@link Cluster}s to search
423     * @return a random point from the selected cluster
424     * @throws ConvergenceException if clusters are all empty
425     */
426    private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {
427
428        int maxNumber = 0;
429        Cluster<T> selected = null;
430        for (final Cluster<T> cluster : clusters) {
431
432            // get the number of points of the current cluster
433            final int number = cluster.getPoints().size();
434
435            // select the cluster with the largest number of points
436            if (number > maxNumber) {
437                maxNumber = number;
438                selected = cluster;
439            }
440
441        }
442
443        // did we find at least one non-empty cluster ?
444        if (selected == null) {
445            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
446        }
447
448        // extract a random point from the cluster
449        final List<T> selectedPoints = selected.getPoints();
450        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
451
452    }
453
454    /**
455     * Get the point farthest to its cluster center
456     *
457     * @param clusters the {@link Cluster}s to search
458     * @return point farthest to its cluster center
459     * @throws ConvergenceException if clusters are all empty
460     */
461    private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {
462
463        double maxDistance = Double.NEGATIVE_INFINITY;
464        Cluster<T> selectedCluster = null;
465        int selectedPoint = -1;
466        for (final Cluster<T> cluster : clusters) {
467
468            // get the farthest point
469            final T center = cluster.getCenter();
470            final List<T> points = cluster.getPoints();
471            for (int i = 0; i < points.size(); ++i) {
472                final double distance = points.get(i).distanceFrom(center);
473                if (distance > maxDistance) {
474                    maxDistance     = distance;
475                    selectedCluster = cluster;
476                    selectedPoint   = i;
477                }
478            }
479
480        }
481
482        // did we find at least one non-empty cluster ?
483        if (selectedCluster == null) {
484            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
485        }
486
487        return selectedCluster.getPoints().remove(selectedPoint);
488
489    }
490
491    /**
492     * Returns the nearest {@link Cluster} to the given point
493     *
494     * @param <T> type of the points to cluster
495     * @param clusters the {@link Cluster}s to search
496     * @param point the point to find the nearest {@link Cluster} for
497     * @return the index of the nearest {@link Cluster} to the given point
498     */
499    private static <T extends Clusterable<T>> int
500        getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
501        double minDistance = Double.MAX_VALUE;
502        int clusterIndex = 0;
503        int minCluster = 0;
504        for (final Cluster<T> c : clusters) {
505            final double distance = point.distanceFrom(c.getCenter());
506            if (distance < minDistance) {
507                minDistance = distance;
508                minCluster = clusterIndex;
509            }
510            clusterIndex++;
511        }
512        return minCluster;
513    }
514
515}