001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math.stat.clustering;
019
020 import java.util.ArrayList;
021 import java.util.Collection;
022 import java.util.Collections;
023 import java.util.List;
024 import java.util.Random;
025
026 import org.apache.commons.math.exception.ConvergenceException;
027 import org.apache.commons.math.exception.MathIllegalArgumentException;
028 import org.apache.commons.math.exception.NumberIsTooSmallException;
029 import org.apache.commons.math.exception.util.LocalizedFormats;
030 import org.apache.commons.math.stat.descriptive.moment.Variance;
031 import org.apache.commons.math.util.MathUtils;
032
033 /**
034 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
035 * @param <T> type of the points to cluster
036 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
037 * @version $Id: KMeansPlusPlusClusterer.java 1139915 2011-06-26 19:13:06Z luc $
038 * @since 2.0
039 */
040 public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
041
042 /** Strategies to use for replacing an empty cluster. */
043 public static enum EmptyClusterStrategy {
044
045 /** Split the cluster with largest distance variance. */
046 LARGEST_VARIANCE,
047
048 /** Split the cluster with largest number of points. */
049 LARGEST_POINTS_NUMBER,
050
051 /** Create a cluster around the point farthest from its centroid. */
052 FARTHEST_POINT,
053
054 /** Generate an error. */
055 ERROR
056
057 }
058
059 /** Random generator for choosing initial centers. */
060 private final Random random;
061
062 /** Selected strategy for empty clusters. */
063 private final EmptyClusterStrategy emptyStrategy;
064
065 /** Build a clusterer.
066 * <p>
067 * The default strategy for handling empty clusters that may appear during
068 * algorithm iterations is to split the cluster with largest distance variance.
069 * </p>
070 * @param random random generator to use for choosing initial centers
071 */
072 public KMeansPlusPlusClusterer(final Random random) {
073 this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
074 }
075
076 /** Build a clusterer.
077 * @param random random generator to use for choosing initial centers
078 * @param emptyStrategy strategy to use for handling empty clusters that
079 * may appear during algorithm iterations
080 * @since 2.2
081 */
082 public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
083 this.random = random;
084 this.emptyStrategy = emptyStrategy;
085 }
086
087 /**
088 * Runs the K-means++ clustering algorithm.
089 *
090 * @param points the points to cluster
091 * @param k the number of clusters to split the data into
092 * @param numTrials number of trial runs
093 * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
094 * for at each trial run. If negative, no maximum will be used
095 * @return a list of clusters containing the points
096 * @throws MathIllegalArgumentException if the data points are null or the number
097 * of clusters is larger than the number of data points
098 */
099 public List<Cluster<T>> cluster(final Collection<T> points, final int k,
100 int numTrials, int maxIterationsPerTrial)
101 throws MathIllegalArgumentException {
102
103 // at first, we have not found any clusters list yet
104 List<Cluster<T>> best = null;
105 double bestVarianceSum = Double.POSITIVE_INFINITY;
106
107 // do several clustering trials
108 for (int i = 0; i < numTrials; ++i) {
109
110 // compute a clusters list
111 List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
112
113 // compute the variance of the current list
114 double varianceSum = 0.0;
115 for (final Cluster<T> cluster : clusters) {
116 if (!cluster.getPoints().isEmpty()) {
117
118 // compute the distance variance of the current cluster
119 final T center = cluster.getCenter();
120 final Variance stat = new Variance();
121 for (final T point : cluster.getPoints()) {
122 stat.increment(point.distanceFrom(center));
123 }
124 varianceSum += stat.getResult();
125
126 }
127 }
128
129 if (varianceSum <= bestVarianceSum) {
130 // this one is the best we have found so far, remember it
131 best = clusters;
132 bestVarianceSum = varianceSum;
133 }
134
135 }
136
137 // return the best clusters list found
138 return best;
139
140 }
141
142 /**
143 * Runs the K-means++ clustering algorithm.
144 *
145 * @param points the points to cluster
146 * @param k the number of clusters to split the data into
147 * @param maxIterations the maximum number of iterations to run the algorithm
148 * for. If negative, no maximum will be used
149 * @return a list of clusters containing the points
150 * @throws MathIllegalArgumentException if the data points are null or the number
151 * of clusters is larger than the number of data points
152 */
153 public List<Cluster<T>> cluster(final Collection<T> points, final int k,
154 final int maxIterations)
155 throws MathIllegalArgumentException {
156
157 // sanity checks
158 MathUtils.checkNotNull(points);
159
160 // number of clusters has to be smaller or equal the number of data points
161 if (points.size() < k) {
162 throw new NumberIsTooSmallException(points.size(), k, false);
163 }
164
165 // create the initial clusters
166 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
167
168 // create an array containing the latest assignment of a point to a cluster
169 // no need to initialize the array, as it will be filled with the first assignment
170 int[] assignments = new int[points.size()];
171 assignPointsToClusters(clusters, points, assignments);
172
173 // iterate through updating the centers until we're done
174 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
175 for (int count = 0; count < max; count++) {
176 boolean emptyCluster = false;
177 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
178 for (final Cluster<T> cluster : clusters) {
179 final T newCenter;
180 if (cluster.getPoints().isEmpty()) {
181 switch (emptyStrategy) {
182 case LARGEST_VARIANCE :
183 newCenter = getPointFromLargestVarianceCluster(clusters);
184 break;
185 case LARGEST_POINTS_NUMBER :
186 newCenter = getPointFromLargestNumberCluster(clusters);
187 break;
188 case FARTHEST_POINT :
189 newCenter = getFarthestPoint(clusters);
190 break;
191 default :
192 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
193 }
194 emptyCluster = true;
195 } else {
196 newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
197 }
198 newClusters.add(new Cluster<T>(newCenter));
199 }
200 int changes = assignPointsToClusters(newClusters, points, assignments);
201 clusters = newClusters;
202
203 // if there were no more changes in the point-to-cluster assignment
204 // and there are no empty clusters left, return the current clusters
205 if (changes == 0 && !emptyCluster) {
206 return clusters;
207 }
208 }
209 return clusters;
210 }
211
212 /**
213 * Adds the given points to the closest {@link Cluster}.
214 *
215 * @param <T> type of the points to cluster
216 * @param clusters the {@link Cluster}s to add the points to
217 * @param points the points to add to the given {@link Cluster}s
218 * @param assignments points assignments to clusters
219 * @return the number of points assigned to different clusters as the iteration before
220 */
221 private static <T extends Clusterable<T>> int
222 assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
223 final int[] assignments) {
224 int assignedDifferently = 0;
225 int pointIndex = 0;
226 for (final T p : points) {
227 int clusterIndex = getNearestCluster(clusters, p);
228 if (clusterIndex != assignments[pointIndex]) {
229 assignedDifferently++;
230 }
231
232 Cluster<T> cluster = clusters.get(clusterIndex);
233 cluster.addPoint(p);
234 assignments[pointIndex++] = clusterIndex;
235 }
236
237 return assignedDifferently;
238 }
239
240 /**
241 * Use K-means++ to choose the initial centers.
242 *
243 * @param <T> type of the points to cluster
244 * @param points the points to choose the initial centers from
245 * @param k the number of centers to choose
246 * @param random random generator to use
247 * @return the initial centers
248 */
249 private static <T extends Clusterable<T>> List<Cluster<T>>
250 chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
251
252 // Convert to list for indexed access. Make it unmodifiable, since removal of items
253 // would screw up the logic of this method.
254 final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
255
256 // The number of points in the list.
257 final int numPoints = pointList.size();
258
259 // Set the corresponding element in this array to indicate when
260 // elements of pointList are no longer available.
261 final boolean[] taken = new boolean[numPoints];
262
263 // The resulting list of initial centers.
264 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
265
266 // Choose one center uniformly at random from among the data points.
267 final int firstPointIndex = random.nextInt(numPoints);
268
269 final T firstPoint = pointList.get(firstPointIndex);
270
271 resultSet.add(new Cluster<T>(firstPoint));
272
273 // Must mark it as taken
274 taken[firstPointIndex] = true;
275
276 // To keep track of the minimum distance squared of elements of
277 // pointList to elements of resultSet.
278 final double[] minDistSquared = new double[numPoints];
279
280 // Initialize the elements. Since the only point in resultSet is firstPoint,
281 // this is very easy.
282 for (int i = 0; i < numPoints; i++) {
283 if (i != firstPointIndex) { // That point isn't considered
284 double d = firstPoint.distanceFrom(pointList.get(i));
285 minDistSquared[i] = d*d;
286 }
287 }
288
289 while (resultSet.size() < k) {
290
291 // Sum up the squared distances for the points in pointList not
292 // already taken.
293 double distSqSum = 0.0;
294
295 for (int i = 0; i < numPoints; i++) {
296 if (!taken[i]) {
297 distSqSum += minDistSquared[i];
298 }
299 }
300
301 // Add one new data point as a center. Each point x is chosen with
302 // probability proportional to D(x)2
303 final double r = random.nextDouble() * distSqSum;
304
305 // The index of the next point to be added to the resultSet.
306 int nextPointIndex = -1;
307
308 // Sum through the squared min distances again, stopping when
309 // sum >= r.
310 double sum = 0.0;
311 for (int i = 0; i < numPoints; i++) {
312 if (!taken[i]) {
313 sum += minDistSquared[i];
314 if (sum >= r) {
315 nextPointIndex = i;
316 break;
317 }
318 }
319 }
320
321 // If it's not set to >= 0, the point wasn't found in the previous
322 // for loop, probably because distances are extremely small. Just pick
323 // the last available point.
324 if (nextPointIndex == -1) {
325 for (int i = numPoints - 1; i >= 0; i--) {
326 if (!taken[i]) {
327 nextPointIndex = i;
328 break;
329 }
330 }
331 }
332
333 // We found one.
334 if (nextPointIndex >= 0) {
335
336 final T p = pointList.get(nextPointIndex);
337
338 resultSet.add(new Cluster<T> (p));
339
340 // Mark it as taken.
341 taken[nextPointIndex] = true;
342
343 if (resultSet.size() < k) {
344 // Now update elements of minDistSquared. We only have to compute
345 // the distance to the new center to do this.
346 for (int j = 0; j < numPoints; j++) {
347 // Only have to worry about the points still not taken.
348 if (!taken[j]) {
349 double d = p.distanceFrom(pointList.get(j));
350 double d2 = d * d;
351 if (d2 < minDistSquared[j]) {
352 minDistSquared[j] = d2;
353 }
354 }
355 }
356 }
357
358 } else {
359 // None found --
360 // Break from the while loop to prevent
361 // an infinite loop.
362 break;
363 }
364 }
365
366 return resultSet;
367 }
368
369 /**
370 * Get a random point from the {@link Cluster} with the largest distance variance.
371 *
372 * @param clusters the {@link Cluster}s to search
373 * @return a random point from the selected cluster
374 */
375 private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
376
377 double maxVariance = Double.NEGATIVE_INFINITY;
378 Cluster<T> selected = null;
379 for (final Cluster<T> cluster : clusters) {
380 if (!cluster.getPoints().isEmpty()) {
381
382 // compute the distance variance of the current cluster
383 final T center = cluster.getCenter();
384 final Variance stat = new Variance();
385 for (final T point : cluster.getPoints()) {
386 stat.increment(point.distanceFrom(center));
387 }
388 final double variance = stat.getResult();
389
390 // select the cluster with the largest variance
391 if (variance > maxVariance) {
392 maxVariance = variance;
393 selected = cluster;
394 }
395
396 }
397 }
398
399 // did we find at least one non-empty cluster ?
400 if (selected == null) {
401 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
402 }
403
404 // extract a random point from the cluster
405 final List<T> selectedPoints = selected.getPoints();
406 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
407
408 }
409
410 /**
411 * Get a random point from the {@link Cluster} with the largest number of points
412 *
413 * @param clusters the {@link Cluster}s to search
414 * @return a random point from the selected cluster
415 */
416 private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
417
418 int maxNumber = 0;
419 Cluster<T> selected = null;
420 for (final Cluster<T> cluster : clusters) {
421
422 // get the number of points of the current cluster
423 final int number = cluster.getPoints().size();
424
425 // select the cluster with the largest number of points
426 if (number > maxNumber) {
427 maxNumber = number;
428 selected = cluster;
429 }
430
431 }
432
433 // did we find at least one non-empty cluster ?
434 if (selected == null) {
435 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
436 }
437
438 // extract a random point from the cluster
439 final List<T> selectedPoints = selected.getPoints();
440 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
441
442 }
443
444 /**
445 * Get the point farthest to its cluster center
446 *
447 * @param clusters the {@link Cluster}s to search
448 * @return point farthest to its cluster center
449 */
450 private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
451
452 double maxDistance = Double.NEGATIVE_INFINITY;
453 Cluster<T> selectedCluster = null;
454 int selectedPoint = -1;
455 for (final Cluster<T> cluster : clusters) {
456
457 // get the farthest point
458 final T center = cluster.getCenter();
459 final List<T> points = cluster.getPoints();
460 for (int i = 0; i < points.size(); ++i) {
461 final double distance = points.get(i).distanceFrom(center);
462 if (distance > maxDistance) {
463 maxDistance = distance;
464 selectedCluster = cluster;
465 selectedPoint = i;
466 }
467 }
468
469 }
470
471 // did we find at least one non-empty cluster ?
472 if (selectedCluster == null) {
473 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
474 }
475
476 return selectedCluster.getPoints().remove(selectedPoint);
477
478 }
479
480 /**
481 * Returns the nearest {@link Cluster} to the given point
482 *
483 * @param <T> type of the points to cluster
484 * @param clusters the {@link Cluster}s to search
485 * @param point the point to find the nearest {@link Cluster} for
486 * @return the index of the nearest {@link Cluster} to the given point
487 */
488 private static <T extends Clusterable<T>> int
489 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
490 double minDistance = Double.MAX_VALUE;
491 int clusterIndex = 0;
492 int minCluster = 0;
493 for (final Cluster<T> c : clusters) {
494 final double distance = point.distanceFrom(c.getCenter());
495 if (distance < minDistance) {
496 minDistance = distance;
497 minCluster = clusterIndex;
498 }
499 clusterIndex++;
500 }
501 return minCluster;
502 }
503
504 }