1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.math4.legacy.ml.clustering;
19
20 import org.apache.commons.math4.legacy.exception.NullArgumentException;
21 import org.apache.commons.math4.legacy.exception.ConvergenceException;
22 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
23 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
24 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
25 import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
26 import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
27 import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
28 import org.apache.commons.rng.UniformRandomProvider;
29 import org.apache.commons.rng.simple.RandomSource;
30
31 import java.util.ArrayList;
32 import java.util.Collection;
33 import java.util.Collections;
34 import java.util.List;
35
36 /**
37 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
38 * @param <T> type of the points to cluster
39 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
40 * @since 3.2
41 */
42 public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
43
44 /** Strategies to use for replacing an empty cluster. */
45 public enum EmptyClusterStrategy {
46
47 /** Split the cluster with largest distance variance. */
48 LARGEST_VARIANCE,
49
50 /** Split the cluster with largest number of points. */
51 LARGEST_POINTS_NUMBER,
52
53 /** Create a cluster around the point farthest from its centroid. */
54 FARTHEST_POINT,
55
56 /** Generate an error. */
57 ERROR
58 }
59
60 /** The number of clusters. */
61 private final int numberOfClusters;
62
63 /** The maximum number of iterations. */
64 private final int maxIterations;
65
66 /** Random generator for choosing initial centers. */
67 private final UniformRandomProvider random;
68
69 /** Selected strategy for empty clusters. */
70 private final EmptyClusterStrategy emptyStrategy;
71
72 /** Build a clusterer.
73 * <p>
74 * The default strategy for handling empty clusters that may appear during
75 * algorithm iterations is to split the cluster with largest distance variance.
76 * <p>
77 * The euclidean distance will be used as default distance measure.
78 *
79 * @param k the number of clusters to split the data into
80 */
81 public KMeansPlusPlusClusterer(final int k) {
82 this(k, Integer.MAX_VALUE);
83 }
84
85 /** Build a clusterer.
86 * <p>
87 * The default strategy for handling empty clusters that may appear during
88 * algorithm iterations is to split the cluster with largest distance variance.
89 * <p>
90 * The euclidean distance will be used as default distance measure.
91 *
92 * @param k the number of clusters to split the data into
93 * @param maxIterations the maximum number of iterations to run the algorithm for.
94 * If negative, no maximum will be used.
95 */
96 public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
97 this(k, maxIterations, new EuclideanDistance());
98 }
99
100 /** Build a clusterer.
101 * <p>
102 * The default strategy for handling empty clusters that may appear during
103 * algorithm iterations is to split the cluster with largest distance variance.
104 *
105 * @param k the number of clusters to split the data into
106 * @param maxIterations the maximum number of iterations to run the algorithm for.
107 * @param measure the distance measure to use
108 * @throws NotStrictlyPositiveException if {@code k <= 0}.
109 */
110 public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
111 this(k, maxIterations, measure, RandomSource.MT_64.create());
112 }
113
114 /** Build a clusterer.
115 * <p>
116 * The default strategy for handling empty clusters that may appear during
117 * algorithm iterations is to split the cluster with largest distance variance.
118 *
119 * @param k the number of clusters to split the data into
120 * @param maxIterations the maximum number of iterations to run the algorithm for.
121 * If negative, no maximum will be used.
122 * @param measure the distance measure to use
123 * @param random random generator to use for choosing initial centers
124 */
125 public KMeansPlusPlusClusterer(final int k, final int maxIterations,
126 final DistanceMeasure measure,
127 final UniformRandomProvider random) {
128 this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
129 }
130
131 /** Build a clusterer.
132 *
133 * @param k the number of clusters to split the data into
134 * @param maxIterations the maximum number of iterations to run the algorithm for.
135 * @param measure the distance measure to use
136 * @param random random generator to use for choosing initial centers
137 * @param emptyStrategy strategy to use for handling empty clusters that
138 * may appear during algorithm iterations
139 * @throws NotStrictlyPositiveException if {@code k <= 0} or
140 * {@code maxIterations <= 0}.
141 */
142 public KMeansPlusPlusClusterer(final int k,
143 final int maxIterations,
144 final DistanceMeasure measure,
145 final UniformRandomProvider random,
146 final EmptyClusterStrategy emptyStrategy) {
147 super(measure);
148
149 if (k <= 0) {
150 throw new NotStrictlyPositiveException(k);
151 }
152 if (maxIterations <= 0) {
153 throw new NotStrictlyPositiveException(maxIterations);
154 }
155
156 this.numberOfClusters = k;
157 this.maxIterations = maxIterations;
158 this.random = random;
159 this.emptyStrategy = emptyStrategy;
160 }
161
162 /**
163 * Return the number of clusters this instance will use.
164 * @return the number of clusters
165 */
166 public int getNumberOfClusters() {
167 return numberOfClusters;
168 }
169
170 /**
171 * Returns the maximum number of iterations this instance will use.
172 * @return the maximum number of iterations, or -1 if no maximum is set
173 */
174 public int getMaxIterations() {
175 return maxIterations;
176 }
177
178 /**
179 * Runs the K-means++ clustering algorithm.
180 *
181 * @param points the points to cluster
182 * @return a list of clusters containing the points
183 * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
184 * if the data points are null or the number of clusters is larger than the
185 * number of data points
186 * @throws ConvergenceException if an empty cluster is encountered and the
187 * empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR}
188 */
189 @Override
190 public List<CentroidCluster<T>> cluster(final Collection<T> points) {
191 // sanity checks
192 NullArgumentException.check(points);
193
194 // number of clusters has to be smaller or equal the number of data points
195 if (points.size() < numberOfClusters) {
196 throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
197 }
198
199 // create the initial clusters
200 List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
201
202 // create an array containing the latest assignment of a point to a cluster
203 // no need to initialize the array, as it will be filled with the first assignment
204 int[] assignments = new int[points.size()];
205 assignPointsToClusters(clusters, points, assignments);
206
207 // iterate through updating the centers until we're done
208 for (int count = 0; count < maxIterations; count++) {
209 boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
210 List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
211 int changes = assignPointsToClusters(newClusters, points, assignments);
212 clusters = newClusters;
213
214 // if there were no more changes in the point-to-cluster assignment
215 // and there are no empty clusters left, return the current clusters
216 if (changes == 0 && !hasEmptyCluster) {
217 return clusters;
218 }
219 }
220 return clusters;
221 }
222
223 /**
224 * @return the random generator
225 */
226 UniformRandomProvider getRandomGenerator() {
227 return random;
228 }
229
230 /**
231 * @return the {@link EmptyClusterStrategy}
232 */
233 EmptyClusterStrategy getEmptyClusterStrategy() {
234 return emptyStrategy;
235 }
236
237 /**
238 * Adjust the clusters's centers with means of points.
239 * @param clusters the origin clusters
240 * @return adjusted clusters with center points
241 */
242 List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) {
243 List<CentroidCluster<T>> newClusters = new ArrayList<>();
244 for (final CentroidCluster<T> cluster : clusters) {
245 final Clusterable newCenter;
246 if (cluster.getPoints().isEmpty()) {
247 switch (emptyStrategy) {
248 case LARGEST_VARIANCE :
249 newCenter = getPointFromLargestVarianceCluster(clusters);
250 break;
251 case LARGEST_POINTS_NUMBER :
252 newCenter = getPointFromLargestNumberCluster(clusters);
253 break;
254 case FARTHEST_POINT :
255 newCenter = getFarthestPoint(clusters);
256 break;
257 default :
258 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
259 }
260 } else {
261 newCenter = cluster.centroid();
262 }
263 newClusters.add(new CentroidCluster<>(newCenter));
264 }
265 return newClusters;
266 }
267
268 /**
269 * Adds the given points to the closest {@link Cluster}.
270 *
271 * @param clusters the {@link Cluster}s to add the points to
272 * @param points the points to add to the given {@link Cluster}s
273 * @param assignments points assignments to clusters
274 * @return the number of points assigned to different clusters as the iteration before
275 */
276 private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
277 final Collection<T> points,
278 final int[] assignments) {
279 int assignedDifferently = 0;
280 int pointIndex = 0;
281 for (final T p : points) {
282 int clusterIndex = getNearestCluster(clusters, p);
283 if (clusterIndex != assignments[pointIndex]) {
284 assignedDifferently++;
285 }
286
287 CentroidCluster<T> cluster = clusters.get(clusterIndex);
288 cluster.addPoint(p);
289 assignments[pointIndex++] = clusterIndex;
290 }
291
292 return assignedDifferently;
293 }
294
295 /**
296 * Use K-means++ to choose the initial centers.
297 *
298 * @param points the points to choose the initial centers from
299 * @return the initial centers
300 */
301 List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
302
303 // Convert to list for indexed access. Make it unmodifiable, since removal of items
304 // would screw up the logic of this method.
305 final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));
306
307 // The number of points in the list.
308 final int numPoints = pointList.size();
309
310 // Set the corresponding element in this array to indicate when
311 // elements of pointList are no longer available.
312 final boolean[] taken = new boolean[numPoints];
313
314 // The resulting list of initial centers.
315 final List<CentroidCluster<T>> resultSet = new ArrayList<>();
316
317 // Choose one center uniformly at random from among the data points.
318 final int firstPointIndex = random.nextInt(numPoints);
319
320 final T firstPoint = pointList.get(firstPointIndex);
321
322 resultSet.add(new CentroidCluster<>(firstPoint));
323
324 // Must mark it as taken
325 taken[firstPointIndex] = true;
326
327 // To keep track of the minimum distance squared of elements of
328 // pointList to elements of resultSet.
329 final double[] minDistSquared = new double[numPoints];
330
331 // Initialize the elements. Since the only point in resultSet is firstPoint,
332 // this is very easy.
333 for (int i = 0; i < numPoints; i++) {
334 if (i != firstPointIndex) { // That point isn't considered
335 double d = distance(firstPoint, pointList.get(i));
336 minDistSquared[i] = d*d;
337 }
338 }
339
340 while (resultSet.size() < numberOfClusters) {
341
342 // Sum up the squared distances for the points in pointList not
343 // already taken.
344 double distSqSum = 0.0;
345
346 for (int i = 0; i < numPoints; i++) {
347 if (!taken[i]) {
348 distSqSum += minDistSquared[i];
349 }
350 }
351
352 // Add one new data point as a center. Each point x is chosen with
353 // probability proportional to D(x)2
354 final double r = random.nextDouble() * distSqSum;
355
356 // The index of the next point to be added to the resultSet.
357 int nextPointIndex = -1;
358
359 // Sum through the squared min distances again, stopping when
360 // sum >= r.
361 double sum = 0.0;
362 for (int i = 0; i < numPoints; i++) {
363 if (!taken[i]) {
364 sum += minDistSquared[i];
365 if (sum >= r) {
366 nextPointIndex = i;
367 break;
368 }
369 }
370 }
371
372 // If it's not set to >= 0, the point wasn't found in the previous
373 // for loop, probably because distances are extremely small. Just pick
374 // the last available point.
375 if (nextPointIndex == -1) {
376 for (int i = numPoints - 1; i >= 0; i--) {
377 if (!taken[i]) {
378 nextPointIndex = i;
379 break;
380 }
381 }
382 }
383
384 // We found one.
385 if (nextPointIndex >= 0) {
386
387 final T p = pointList.get(nextPointIndex);
388
389 resultSet.add(new CentroidCluster<T> (p));
390
391 // Mark it as taken.
392 taken[nextPointIndex] = true;
393
394 if (resultSet.size() < numberOfClusters) {
395 // Now update elements of minDistSquared. We only have to compute
396 // the distance to the new center to do this.
397 for (int j = 0; j < numPoints; j++) {
398 // Only have to worry about the points still not taken.
399 if (!taken[j]) {
400 double d = distance(p, pointList.get(j));
401 double d2 = d * d;
402 if (d2 < minDistSquared[j]) {
403 minDistSquared[j] = d2;
404 }
405 }
406 }
407 }
408 } else {
409 // None found --
410 // Break from the while loop to prevent
411 // an infinite loop.
412 break;
413 }
414 }
415
416 return resultSet;
417 }
418
419 /**
420 * Get a random point from the {@link Cluster} with the largest distance variance.
421 *
422 * @param clusters the {@link Cluster}s to search
423 * @return a random point from the selected cluster
424 * @throws ConvergenceException if clusters are all empty
425 */
426 private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) {
427 double maxVariance = Double.NEGATIVE_INFINITY;
428 Cluster<T> selected = null;
429 for (final CentroidCluster<T> cluster : clusters) {
430 if (!cluster.getPoints().isEmpty()) {
431
432 // compute the distance variance of the current cluster
433 final Clusterable center = cluster.getCenter();
434 final Variance stat = new Variance();
435 for (final T point : cluster.getPoints()) {
436 stat.increment(distance(point, center));
437 }
438 final double variance = stat.getResult();
439
440 // select the cluster with the largest variance
441 if (variance > maxVariance) {
442 maxVariance = variance;
443 selected = cluster;
444 }
445 }
446 }
447
448 // did we find at least one non-empty cluster ?
449 if (selected == null) {
450 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
451 }
452
453 // extract a random point from the cluster
454 final List<T> selectedPoints = selected.getPoints();
455 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
456 }
457
458 /**
459 * Get a random point from the {@link Cluster} with the largest number of points.
460 *
461 * @param clusters the {@link Cluster}s to search
462 * @return a random point from the selected cluster
463 * @throws ConvergenceException if clusters are all empty
464 */
465 private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) {
466 int maxNumber = 0;
467 Cluster<T> selected = null;
468 for (final Cluster<T> cluster : clusters) {
469
470 // get the number of points of the current cluster
471 final int number = cluster.getPoints().size();
472
473 // select the cluster with the largest number of points
474 if (number > maxNumber) {
475 maxNumber = number;
476 selected = cluster;
477 }
478 }
479
480 // did we find at least one non-empty cluster ?
481 if (selected == null) {
482 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
483 }
484
485 // extract a random point from the cluster
486 final List<T> selectedPoints = selected.getPoints();
487 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
488 }
489
490 /**
491 * Get the point farthest to its cluster center.
492 *
493 * @param clusters the {@link Cluster}s to search
494 * @return point farthest to its cluster center
495 * @throws ConvergenceException if clusters are all empty
496 */
497 private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) {
498 double maxDistance = Double.NEGATIVE_INFINITY;
499 Cluster<T> selectedCluster = null;
500 int selectedPoint = -1;
501 for (final CentroidCluster<T> cluster : clusters) {
502
503 // get the farthest point
504 final Clusterable center = cluster.getCenter();
505 final List<T> points = cluster.getPoints();
506 for (int i = 0; i < points.size(); ++i) {
507 final double distance = distance(points.get(i), center);
508 if (distance > maxDistance) {
509 maxDistance = distance;
510 selectedCluster = cluster;
511 selectedPoint = i;
512 }
513 }
514 }
515
516 // did we find at least one non-empty cluster ?
517 if (selectedCluster == null) {
518 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
519 }
520
521 return selectedCluster.getPoints().remove(selectedPoint);
522 }
523
524 /**
525 * Returns the nearest {@link Cluster} to the given point.
526 *
527 * @param clusters the {@link Cluster}s to search
528 * @param point the point to find the nearest {@link Cluster} for
529 * @return the index of the nearest {@link Cluster} to the given point
530 */
531 private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
532 double minDistance = Double.MAX_VALUE;
533 int clusterIndex = 0;
534 int minCluster = 0;
535 for (final CentroidCluster<T> c : clusters) {
536 final double distance = distance(point, c.getCenter());
537 if (distance < minDistance) {
538 minDistance = distance;
539 minCluster = clusterIndex;
540 }
541 clusterIndex++;
542 }
543 return minCluster;
544 }
545 }