View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.stat.clustering;
19  
20  import java.util.ArrayList;
21  import java.util.Collection;
22  import java.util.Collections;
23  import java.util.List;
24  import java.util.Random;
25  
26  import org.apache.commons.math3.exception.ConvergenceException;
27  import org.apache.commons.math3.exception.MathIllegalArgumentException;
28  import org.apache.commons.math3.exception.NumberIsTooSmallException;
29  import org.apache.commons.math3.exception.util.LocalizedFormats;
30  import org.apache.commons.math3.stat.descriptive.moment.Variance;
31  import org.apache.commons.math3.util.MathUtils;
32  
33  /**
34   * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
35   * @param <T> type of the points to cluster
36   * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
37   * @version $Id: KMeansPlusPlusClusterer.java 1461871 2013-03-27 22:01:25Z tn $
38   * @since 2.0
39   * @deprecated As of 3.2 (to be removed in 4.0),
40   * use {@link org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer} instead
41   */
42  @Deprecated
43  public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
44  
45      /** Strategies to use for replacing an empty cluster. */
46      public static enum EmptyClusterStrategy {
47  
48          /** Split the cluster with largest distance variance. */
49          LARGEST_VARIANCE,
50  
51          /** Split the cluster with largest number of points. */
52          LARGEST_POINTS_NUMBER,
53  
54          /** Create a cluster around the point farthest from its centroid. */
55          FARTHEST_POINT,
56  
57          /** Generate an error. */
58          ERROR
59  
60      }
61  
62      /** Random generator for choosing initial centers. */
63      private final Random random;
64  
65      /** Selected strategy for empty clusters. */
66      private final EmptyClusterStrategy emptyStrategy;
67  
68      /** Build a clusterer.
69       * <p>
70       * The default strategy for handling empty clusters that may appear during
71       * algorithm iterations is to split the cluster with largest distance variance.
72       * </p>
73       * @param random random generator to use for choosing initial centers
74       */
75      public KMeansPlusPlusClusterer(final Random random) {
76          this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
77      }
78  
79      /** Build a clusterer.
80       * @param random random generator to use for choosing initial centers
81       * @param emptyStrategy strategy to use for handling empty clusters that
82       * may appear during algorithm iterations
83       * @since 2.2
84       */
85      public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
86          this.random        = random;
87          this.emptyStrategy = emptyStrategy;
88      }
89  
90      /**
91       * Runs the K-means++ clustering algorithm.
92       *
93       * @param points the points to cluster
94       * @param k the number of clusters to split the data into
95       * @param numTrials number of trial runs
96       * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
97       *     for at each trial run.  If negative, no maximum will be used
98       * @return a list of clusters containing the points
99       * @throws MathIllegalArgumentException if the data points are null or the number
100      *     of clusters is larger than the number of data points
101      * @throws ConvergenceException if an empty cluster is encountered and the
102      * {@link #emptyStrategy} is set to {@code ERROR}
103      */
104     public List<Cluster<T>> cluster(final Collection<T> points, final int k,
105                                     int numTrials, int maxIterationsPerTrial)
106         throws MathIllegalArgumentException, ConvergenceException {
107 
108         // at first, we have not found any clusters list yet
109         List<Cluster<T>> best = null;
110         double bestVarianceSum = Double.POSITIVE_INFINITY;
111 
112         // do several clustering trials
113         for (int i = 0; i < numTrials; ++i) {
114 
115             // compute a clusters list
116             List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
117 
118             // compute the variance of the current list
119             double varianceSum = 0.0;
120             for (final Cluster<T> cluster : clusters) {
121                 if (!cluster.getPoints().isEmpty()) {
122 
123                     // compute the distance variance of the current cluster
124                     final T center = cluster.getCenter();
125                     final Variance stat = new Variance();
126                     for (final T point : cluster.getPoints()) {
127                         stat.increment(point.distanceFrom(center));
128                     }
129                     varianceSum += stat.getResult();
130 
131                 }
132             }
133 
134             if (varianceSum <= bestVarianceSum) {
135                 // this one is the best we have found so far, remember it
136                 best            = clusters;
137                 bestVarianceSum = varianceSum;
138             }
139 
140         }
141 
142         // return the best clusters list found
143         return best;
144 
145     }
146 
147     /**
148      * Runs the K-means++ clustering algorithm.
149      *
150      * @param points the points to cluster
151      * @param k the number of clusters to split the data into
152      * @param maxIterations the maximum number of iterations to run the algorithm
153      *     for.  If negative, no maximum will be used
154      * @return a list of clusters containing the points
155      * @throws MathIllegalArgumentException if the data points are null or the number
156      *     of clusters is larger than the number of data points
157      * @throws ConvergenceException if an empty cluster is encountered and the
158      * {@link #emptyStrategy} is set to {@code ERROR}
159      */
160     public List<Cluster<T>> cluster(final Collection<T> points, final int k,
161                                     final int maxIterations)
162         throws MathIllegalArgumentException, ConvergenceException {
163 
164         // sanity checks
165         MathUtils.checkNotNull(points);
166 
167         // number of clusters has to be smaller or equal the number of data points
168         if (points.size() < k) {
169             throw new NumberIsTooSmallException(points.size(), k, false);
170         }
171 
172         // create the initial clusters
173         List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
174 
175         // create an array containing the latest assignment of a point to a cluster
176         // no need to initialize the array, as it will be filled with the first assignment
177         int[] assignments = new int[points.size()];
178         assignPointsToClusters(clusters, points, assignments);
179 
180         // iterate through updating the centers until we're done
181         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
182         for (int count = 0; count < max; count++) {
183             boolean emptyCluster = false;
184             List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
185             for (final Cluster<T> cluster : clusters) {
186                 final T newCenter;
187                 if (cluster.getPoints().isEmpty()) {
188                     switch (emptyStrategy) {
189                         case LARGEST_VARIANCE :
190                             newCenter = getPointFromLargestVarianceCluster(clusters);
191                             break;
192                         case LARGEST_POINTS_NUMBER :
193                             newCenter = getPointFromLargestNumberCluster(clusters);
194                             break;
195                         case FARTHEST_POINT :
196                             newCenter = getFarthestPoint(clusters);
197                             break;
198                         default :
199                             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
200                     }
201                     emptyCluster = true;
202                 } else {
203                     newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
204                 }
205                 newClusters.add(new Cluster<T>(newCenter));
206             }
207             int changes = assignPointsToClusters(newClusters, points, assignments);
208             clusters = newClusters;
209 
210             // if there were no more changes in the point-to-cluster assignment
211             // and there are no empty clusters left, return the current clusters
212             if (changes == 0 && !emptyCluster) {
213                 return clusters;
214             }
215         }
216         return clusters;
217     }
218 
219     /**
220      * Adds the given points to the closest {@link Cluster}.
221      *
222      * @param <T> type of the points to cluster
223      * @param clusters the {@link Cluster}s to add the points to
224      * @param points the points to add to the given {@link Cluster}s
225      * @param assignments points assignments to clusters
226      * @return the number of points assigned to different clusters as the iteration before
227      */
228     private static <T extends Clusterable<T>> int
229         assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
230                                final int[] assignments) {
231         int assignedDifferently = 0;
232         int pointIndex = 0;
233         for (final T p : points) {
234             int clusterIndex = getNearestCluster(clusters, p);
235             if (clusterIndex != assignments[pointIndex]) {
236                 assignedDifferently++;
237             }
238 
239             Cluster<T> cluster = clusters.get(clusterIndex);
240             cluster.addPoint(p);
241             assignments[pointIndex++] = clusterIndex;
242         }
243 
244         return assignedDifferently;
245     }
246 
247     /**
248      * Use K-means++ to choose the initial centers.
249      *
250      * @param <T> type of the points to cluster
251      * @param points the points to choose the initial centers from
252      * @param k the number of centers to choose
253      * @param random random generator to use
254      * @return the initial centers
255      */
256     private static <T extends Clusterable<T>> List<Cluster<T>>
257         chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
258 
259         // Convert to list for indexed access. Make it unmodifiable, since removal of items
260         // would screw up the logic of this method.
261         final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
262 
263         // The number of points in the list.
264         final int numPoints = pointList.size();
265 
266         // Set the corresponding element in this array to indicate when
267         // elements of pointList are no longer available.
268         final boolean[] taken = new boolean[numPoints];
269 
270         // The resulting list of initial centers.
271         final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
272 
273         // Choose one center uniformly at random from among the data points.
274         final int firstPointIndex = random.nextInt(numPoints);
275 
276         final T firstPoint = pointList.get(firstPointIndex);
277 
278         resultSet.add(new Cluster<T>(firstPoint));
279 
280         // Must mark it as taken
281         taken[firstPointIndex] = true;
282 
283         // To keep track of the minimum distance squared of elements of
284         // pointList to elements of resultSet.
285         final double[] minDistSquared = new double[numPoints];
286 
287         // Initialize the elements.  Since the only point in resultSet is firstPoint,
288         // this is very easy.
289         for (int i = 0; i < numPoints; i++) {
290             if (i != firstPointIndex) { // That point isn't considered
291                 double d = firstPoint.distanceFrom(pointList.get(i));
292                 minDistSquared[i] = d*d;
293             }
294         }
295 
296         while (resultSet.size() < k) {
297 
298             // Sum up the squared distances for the points in pointList not
299             // already taken.
300             double distSqSum = 0.0;
301 
302             for (int i = 0; i < numPoints; i++) {
303                 if (!taken[i]) {
304                     distSqSum += minDistSquared[i];
305                 }
306             }
307 
308             // Add one new data point as a center. Each point x is chosen with
309             // probability proportional to D(x)2
310             final double r = random.nextDouble() * distSqSum;
311 
312             // The index of the next point to be added to the resultSet.
313             int nextPointIndex = -1;
314 
315             // Sum through the squared min distances again, stopping when
316             // sum >= r.
317             double sum = 0.0;
318             for (int i = 0; i < numPoints; i++) {
319                 if (!taken[i]) {
320                     sum += minDistSquared[i];
321                     if (sum >= r) {
322                         nextPointIndex = i;
323                         break;
324                     }
325                 }
326             }
327 
328             // If it's not set to >= 0, the point wasn't found in the previous
329             // for loop, probably because distances are extremely small.  Just pick
330             // the last available point.
331             if (nextPointIndex == -1) {
332                 for (int i = numPoints - 1; i >= 0; i--) {
333                     if (!taken[i]) {
334                         nextPointIndex = i;
335                         break;
336                     }
337                 }
338             }
339 
340             // We found one.
341             if (nextPointIndex >= 0) {
342 
343                 final T p = pointList.get(nextPointIndex);
344 
345                 resultSet.add(new Cluster<T> (p));
346 
347                 // Mark it as taken.
348                 taken[nextPointIndex] = true;
349 
350                 if (resultSet.size() < k) {
351                     // Now update elements of minDistSquared.  We only have to compute
352                     // the distance to the new center to do this.
353                     for (int j = 0; j < numPoints; j++) {
354                         // Only have to worry about the points still not taken.
355                         if (!taken[j]) {
356                             double d = p.distanceFrom(pointList.get(j));
357                             double d2 = d * d;
358                             if (d2 < minDistSquared[j]) {
359                                 minDistSquared[j] = d2;
360                             }
361                         }
362                     }
363                 }
364 
365             } else {
366                 // None found --
367                 // Break from the while loop to prevent
368                 // an infinite loop.
369                 break;
370             }
371         }
372 
373         return resultSet;
374     }
375 
376     /**
377      * Get a random point from the {@link Cluster} with the largest distance variance.
378      *
379      * @param clusters the {@link Cluster}s to search
380      * @return a random point from the selected cluster
381      * @throws ConvergenceException if clusters are all empty
382      */
383     private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
384     throws ConvergenceException {
385 
386         double maxVariance = Double.NEGATIVE_INFINITY;
387         Cluster<T> selected = null;
388         for (final Cluster<T> cluster : clusters) {
389             if (!cluster.getPoints().isEmpty()) {
390 
391                 // compute the distance variance of the current cluster
392                 final T center = cluster.getCenter();
393                 final Variance stat = new Variance();
394                 for (final T point : cluster.getPoints()) {
395                     stat.increment(point.distanceFrom(center));
396                 }
397                 final double variance = stat.getResult();
398 
399                 // select the cluster with the largest variance
400                 if (variance > maxVariance) {
401                     maxVariance = variance;
402                     selected = cluster;
403                 }
404 
405             }
406         }
407 
408         // did we find at least one non-empty cluster ?
409         if (selected == null) {
410             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
411         }
412 
413         // extract a random point from the cluster
414         final List<T> selectedPoints = selected.getPoints();
415         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
416 
417     }
418 
419     /**
420      * Get a random point from the {@link Cluster} with the largest number of points
421      *
422      * @param clusters the {@link Cluster}s to search
423      * @return a random point from the selected cluster
424      * @throws ConvergenceException if clusters are all empty
425      */
426     private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {
427 
428         int maxNumber = 0;
429         Cluster<T> selected = null;
430         for (final Cluster<T> cluster : clusters) {
431 
432             // get the number of points of the current cluster
433             final int number = cluster.getPoints().size();
434 
435             // select the cluster with the largest number of points
436             if (number > maxNumber) {
437                 maxNumber = number;
438                 selected = cluster;
439             }
440 
441         }
442 
443         // did we find at least one non-empty cluster ?
444         if (selected == null) {
445             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
446         }
447 
448         // extract a random point from the cluster
449         final List<T> selectedPoints = selected.getPoints();
450         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
451 
452     }
453 
454     /**
455      * Get the point farthest to its cluster center
456      *
457      * @param clusters the {@link Cluster}s to search
458      * @return point farthest to its cluster center
459      * @throws ConvergenceException if clusters are all empty
460      */
461     private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {
462 
463         double maxDistance = Double.NEGATIVE_INFINITY;
464         Cluster<T> selectedCluster = null;
465         int selectedPoint = -1;
466         for (final Cluster<T> cluster : clusters) {
467 
468             // get the farthest point
469             final T center = cluster.getCenter();
470             final List<T> points = cluster.getPoints();
471             for (int i = 0; i < points.size(); ++i) {
472                 final double distance = points.get(i).distanceFrom(center);
473                 if (distance > maxDistance) {
474                     maxDistance     = distance;
475                     selectedCluster = cluster;
476                     selectedPoint   = i;
477                 }
478             }
479 
480         }
481 
482         // did we find at least one non-empty cluster ?
483         if (selectedCluster == null) {
484             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
485         }
486 
487         return selectedCluster.getPoints().remove(selectedPoint);
488 
489     }
490 
491     /**
492      * Returns the nearest {@link Cluster} to the given point
493      *
494      * @param <T> type of the points to cluster
495      * @param clusters the {@link Cluster}s to search
496      * @param point the point to find the nearest {@link Cluster} for
497      * @return the index of the nearest {@link Cluster} to the given point
498      */
499     private static <T extends Clusterable<T>> int
500         getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
501         double minDistance = Double.MAX_VALUE;
502         int clusterIndex = 0;
503         int minCluster = 0;
504         for (final Cluster<T> c : clusters) {
505             final double distance = point.distanceFrom(c.getCenter());
506             if (distance < minDistance) {
507                 minDistance = distance;
508                 minCluster = clusterIndex;
509             }
510             clusterIndex++;
511         }
512         return minCluster;
513     }
514 
515 }