View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.ml.clustering;
19  
20  import org.apache.commons.math4.legacy.exception.NullArgumentException;
21  import org.apache.commons.math4.legacy.exception.ConvergenceException;
22  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
23  import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
24  import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
25  import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
26  import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
27  import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
28  import org.apache.commons.rng.UniformRandomProvider;
29  import org.apache.commons.rng.simple.RandomSource;
30  
31  import java.util.ArrayList;
32  import java.util.Collection;
33  import java.util.Collections;
34  import java.util.List;
35  
36  /**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitskii's k-means++ algorithm.
38   * @param <T> type of the points to cluster
39   * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
40   * @since 3.2
41   */
42  public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
43  
    /**
     * Strategies to use for replacing an empty cluster.
     * <p>
     * When a cluster holds no points at the time the centers are adjusted,
     * the selected strategy decides which point becomes the center of the
     * cluster that replaces it.
     */
    public enum EmptyClusterStrategy {

        /**
         * Split the cluster with largest distance variance: a randomly
         * chosen point of that cluster becomes the new center.
         */
        LARGEST_VARIANCE,

        /**
         * Split the cluster with largest number of points: a randomly
         * chosen point of that cluster becomes the new center.
         */
        LARGEST_POINTS_NUMBER,

        /** Create a cluster around the point farthest from its centroid. */
        FARTHEST_POINT,

        /** Generate an error: a {@link ConvergenceException} is thrown. */
        ERROR
    }
59  
    /** The number of clusters; checked to be strictly positive at construction. */
    private final int numberOfClusters;

    /** The maximum number of iterations; checked to be strictly positive at construction. */
    private final int maxIterations;

    /**
     * Random generator for choosing initial centers; it is also used to pick
     * the replacement point when an empty cluster is handled by splitting
     * another cluster.
     */
    private final UniformRandomProvider random;

    /** Selected strategy for empty clusters. */
    private final EmptyClusterStrategy emptyStrategy;
71  
    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * The euclidean distance will be used as default distance measure, and the
     * maximum number of iterations is set to {@link Integer#MAX_VALUE}.
     *
     * @param k the number of clusters to split the data into
     * @throws NotStrictlyPositiveException if {@code k <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k) {
        this(k, Integer.MAX_VALUE);
    }
84  
    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * The euclidean distance will be used as default distance measure.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   Must be strictly positive.
     * @throws NotStrictlyPositiveException if {@code k <= 0} or
     * {@code maxIterations <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
        this(k, maxIterations, new EuclideanDistance());
    }
99  
    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * A new random generator is created from {@code RandomSource.MT_64}.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   Must be strictly positive.
     * @param measure the distance measure to use
     * @throws NotStrictlyPositiveException if {@code k <= 0} or
     * {@code maxIterations <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
        this(k, maxIterations, measure, RandomSource.MT_64.create());
    }
113 
    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   Must be strictly positive.
     * @param measure the distance measure to use
     * @param random random generator to use for choosing initial centers
     * @throws NotStrictlyPositiveException if {@code k <= 0} or
     * {@code maxIterations <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
                                   final DistanceMeasure measure,
                                   final UniformRandomProvider random) {
        this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
    }
130 
131     /** Build a clusterer.
132      *
133      * @param k the number of clusters to split the data into
134      * @param maxIterations the maximum number of iterations to run the algorithm for.
135      * @param measure the distance measure to use
136      * @param random random generator to use for choosing initial centers
137      * @param emptyStrategy strategy to use for handling empty clusters that
138      * may appear during algorithm iterations
139      * @throws NotStrictlyPositiveException if {@code k <= 0} or
140      * {@code maxIterations <= 0}.
141      */
142     public KMeansPlusPlusClusterer(final int k,
143                                    final int maxIterations,
144                                    final DistanceMeasure measure,
145                                    final UniformRandomProvider random,
146                                    final EmptyClusterStrategy emptyStrategy) {
147         super(measure);
148 
149         if (k <= 0) {
150             throw new NotStrictlyPositiveException(k);
151         }
152         if (maxIterations <= 0) {
153             throw new NotStrictlyPositiveException(maxIterations);
154         }
155 
156         this.numberOfClusters = k;
157         this.maxIterations = maxIterations;
158         this.random = random;
159         this.emptyStrategy = emptyStrategy;
160     }
161 
    /**
     * Return the number of clusters this instance will use.
     * @return the number of clusters, as given at construction
     */
    public int getNumberOfClusters() {
        return numberOfClusters;
    }
169 
    /**
     * Returns the maximum number of iterations this instance will use.
     * @return the maximum number of iterations, as given at construction
     * (the constructor only accepts strictly positive values)
     */
    public int getMaxIterations() {
        return maxIterations;
    }
177 
178     /**
179      * Runs the K-means++ clustering algorithm.
180      *
181      * @param points the points to cluster
182      * @return a list of clusters containing the points
183      * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
184      * if the data points are null or the number of clusters is larger than the
185      * number of data points
186      * @throws ConvergenceException if an empty cluster is encountered and the
187      * empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR}
188      */
189     @Override
190     public List<CentroidCluster<T>> cluster(final Collection<T> points) {
191         // sanity checks
192         NullArgumentException.check(points);
193 
194         // number of clusters has to be smaller or equal the number of data points
195         if (points.size() < numberOfClusters) {
196             throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
197         }
198 
199         // create the initial clusters
200         List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
201 
202         // create an array containing the latest assignment of a point to a cluster
203         // no need to initialize the array, as it will be filled with the first assignment
204         int[] assignments = new int[points.size()];
205         assignPointsToClusters(clusters, points, assignments);
206 
207         // iterate through updating the centers until we're done
208         for (int count = 0; count < maxIterations; count++) {
209             boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
210             List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
211             int changes = assignPointsToClusters(newClusters, points, assignments);
212             clusters = newClusters;
213 
214             // if there were no more changes in the point-to-cluster assignment
215             // and there are no empty clusters left, return the current clusters
216             if (changes == 0 && !hasEmptyCluster) {
217                 return clusters;
218             }
219         }
220         return clusters;
221     }
222 
    /**
     * Returns the random generator given at construction.
     * @return the random generator
     */
    UniformRandomProvider getRandomGenerator() {
        return random;
    }
229 
    /**
     * Returns the strategy applied when an empty cluster is encountered.
     * @return the {@link EmptyClusterStrategy}
     */
    EmptyClusterStrategy getEmptyClusterStrategy() {
        return emptyStrategy;
    }
236 
237     /**
238      * Adjust the clusters's centers with means of points.
239      * @param clusters the origin clusters
240      * @return adjusted clusters with center points
241      */
242     List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) {
243         List<CentroidCluster<T>> newClusters = new ArrayList<>();
244         for (final CentroidCluster<T> cluster : clusters) {
245             final Clusterable newCenter;
246             if (cluster.getPoints().isEmpty()) {
247                 switch (emptyStrategy) {
248                     case LARGEST_VARIANCE :
249                         newCenter = getPointFromLargestVarianceCluster(clusters);
250                         break;
251                     case LARGEST_POINTS_NUMBER :
252                         newCenter = getPointFromLargestNumberCluster(clusters);
253                         break;
254                     case FARTHEST_POINT :
255                         newCenter = getFarthestPoint(clusters);
256                         break;
257                     default :
258                         throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
259                 }
260             } else {
261                 newCenter = cluster.centroid();
262             }
263             newClusters.add(new CentroidCluster<>(newCenter));
264         }
265         return newClusters;
266     }
267 
268     /**
269      * Adds the given points to the closest {@link Cluster}.
270      *
271      * @param clusters the {@link Cluster}s to add the points to
272      * @param points the points to add to the given {@link Cluster}s
273      * @param assignments points assignments to clusters
274      * @return the number of points assigned to different clusters as the iteration before
275      */
276     private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
277                                        final Collection<T> points,
278                                        final int[] assignments) {
279         int assignedDifferently = 0;
280         int pointIndex = 0;
281         for (final T p : points) {
282             int clusterIndex = getNearestCluster(clusters, p);
283             if (clusterIndex != assignments[pointIndex]) {
284                 assignedDifferently++;
285             }
286 
287             CentroidCluster<T> cluster = clusters.get(clusterIndex);
288             cluster.addPoint(p);
289             assignments[pointIndex++] = clusterIndex;
290         }
291 
292         return assignedDifferently;
293     }
294 
    /**
     * Use K-means++ to choose the initial centers.
     * <p>
     * The first center is drawn uniformly at random among the points. Every
     * further center is drawn among the points not yet chosen, with a
     * probability proportional to the squared distance between the point and
     * the nearest center chosen so far.
     * <p>
     * The outcome depends on the order in which values are drawn from the
     * random generator and on the order in which the squared distances are
     * summed.
     *
     * @param points the points to choose the initial centers from
     * @return the initial centers, each wrapped in a cluster that has no
     * points yet; at most {@code numberOfClusters} of them
     */
    List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

        // Convert to list for indexed access. Make it unmodifiable, since removal of items
        // would break the index bookkeeping of this method.
        final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));

        // The number of points in the list.
        final int numPoints = pointList.size();

        // Set the corresponding element in this array to indicate when
        // elements of pointList are no longer available.
        final boolean[] taken = new boolean[numPoints];

        // The resulting list of initial centers.
        final List<CentroidCluster<T>> resultSet = new ArrayList<>();

        // Choose one center uniformly at random from among the data points.
        final int firstPointIndex = random.nextInt(numPoints);

        final T firstPoint = pointList.get(firstPointIndex);

        resultSet.add(new CentroidCluster<>(firstPoint));

        // Must mark it as taken
        taken[firstPointIndex] = true;

        // To keep track of the minimum distance squared of elements of
        // pointList to elements of resultSet.
        final double[] minDistSquared = new double[numPoints];

        // Initialize the elements.  Since the only point in resultSet is firstPoint,
        // this is very easy.
        for (int i = 0; i < numPoints; i++) {
            if (i != firstPointIndex) { // That point isn't considered
                double d = distance(firstPoint, pointList.get(i));
                minDistSquared[i] = d*d;
            }
        }

        while (resultSet.size() < numberOfClusters) {

            // Sum up the squared distances for the points in pointList not
            // already taken.
            double distSqSum = 0.0;

            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    distSqSum += minDistSquared[i];
                }
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2, where D(x) is the distance
            // from x to the nearest center already chosen.
            final double r = random.nextDouble() * distSqSum;

            // The index of the next point to be added to the resultSet.
            int nextPointIndex = -1;

            // Sum through the squared min distances again, stopping when
            // sum >= r.
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // If it's not set to >= 0, the point wasn't found in the previous
            // for loop, probably because distances are extremely small.  Just pick
            // the last available point.
            if (nextPointIndex == -1) {
                for (int i = numPoints - 1; i >= 0; i--) {
                    if (!taken[i]) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // We found one.
            if (nextPointIndex >= 0) {

                final T p = pointList.get(nextPointIndex);

                resultSet.add(new CentroidCluster<T> (p));

                // Mark it as taken.
                taken[nextPointIndex] = true;

                // The update is skipped once the last center has been chosen,
                // since minDistSquared is not read again after that.
                if (resultSet.size() < numberOfClusters) {
                    // Now update elements of minDistSquared.  We only have to compute
                    // the distance to the new center to do this.
                    for (int j = 0; j < numPoints; j++) {
                        // Only have to worry about the points still not taken.
                        if (!taken[j]) {
                            double d = distance(p, pointList.get(j));
                            double d2 = d * d;
                            if (d2 < minDistSquared[j]) {
                                minDistSquared[j] = d2;
                            }
                        }
                    }
                }
            } else {
                // None found: every point is already taken.
                // Break from the while loop to prevent
                // an infinite loop.
                break;
            }
        }

        return resultSet;
    }
418 
419     /**
420      * Get a random point from the {@link Cluster} with the largest distance variance.
421      *
422      * @param clusters the {@link Cluster}s to search
423      * @return a random point from the selected cluster
424      * @throws ConvergenceException if clusters are all empty
425      */
426     private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) {
427         double maxVariance = Double.NEGATIVE_INFINITY;
428         Cluster<T> selected = null;
429         for (final CentroidCluster<T> cluster : clusters) {
430             if (!cluster.getPoints().isEmpty()) {
431 
432                 // compute the distance variance of the current cluster
433                 final Clusterable center = cluster.getCenter();
434                 final Variance stat = new Variance();
435                 for (final T point : cluster.getPoints()) {
436                     stat.increment(distance(point, center));
437                 }
438                 final double variance = stat.getResult();
439 
440                 // select the cluster with the largest variance
441                 if (variance > maxVariance) {
442                     maxVariance = variance;
443                     selected = cluster;
444                 }
445             }
446         }
447 
448         // did we find at least one non-empty cluster ?
449         if (selected == null) {
450             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
451         }
452 
453         // extract a random point from the cluster
454         final List<T> selectedPoints = selected.getPoints();
455         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
456     }
457 
458     /**
459      * Get a random point from the {@link Cluster} with the largest number of points.
460      *
461      * @param clusters the {@link Cluster}s to search
462      * @return a random point from the selected cluster
463      * @throws ConvergenceException if clusters are all empty
464      */
465     private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) {
466         int maxNumber = 0;
467         Cluster<T> selected = null;
468         for (final Cluster<T> cluster : clusters) {
469 
470             // get the number of points of the current cluster
471             final int number = cluster.getPoints().size();
472 
473             // select the cluster with the largest number of points
474             if (number > maxNumber) {
475                 maxNumber = number;
476                 selected = cluster;
477             }
478         }
479 
480         // did we find at least one non-empty cluster ?
481         if (selected == null) {
482             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
483         }
484 
485         // extract a random point from the cluster
486         final List<T> selectedPoints = selected.getPoints();
487         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
488     }
489 
490     /**
491      * Get the point farthest to its cluster center.
492      *
493      * @param clusters the {@link Cluster}s to search
494      * @return point farthest to its cluster center
495      * @throws ConvergenceException if clusters are all empty
496      */
497     private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) {
498         double maxDistance = Double.NEGATIVE_INFINITY;
499         Cluster<T> selectedCluster = null;
500         int selectedPoint = -1;
501         for (final CentroidCluster<T> cluster : clusters) {
502 
503             // get the farthest point
504             final Clusterable center = cluster.getCenter();
505             final List<T> points = cluster.getPoints();
506             for (int i = 0; i < points.size(); ++i) {
507                 final double distance = distance(points.get(i), center);
508                 if (distance > maxDistance) {
509                     maxDistance     = distance;
510                     selectedCluster = cluster;
511                     selectedPoint   = i;
512                 }
513             }
514         }
515 
516         // did we find at least one non-empty cluster ?
517         if (selectedCluster == null) {
518             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
519         }
520 
521         return selectedCluster.getPoints().remove(selectedPoint);
522     }
523 
524     /**
525      * Returns the nearest {@link Cluster} to the given point.
526      *
527      * @param clusters the {@link Cluster}s to search
528      * @param point the point to find the nearest {@link Cluster} for
529      * @return the index of the nearest {@link Cluster} to the given point
530      */
531     private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
532         double minDistance = Double.MAX_VALUE;
533         int clusterIndex = 0;
534         int minCluster = 0;
535         for (final CentroidCluster<T> c : clusters) {
536             final double distance = distance(point, c.getCenter());
537             if (distance < minDistance) {
538                 minDistance = distance;
539                 minCluster = clusterIndex;
540             }
541             clusterIndex++;
542         }
543         return minCluster;
544     }
545 }