001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math3.stat.clustering;
019
020 import java.util.ArrayList;
021 import java.util.Collection;
022 import java.util.Collections;
023 import java.util.List;
024 import java.util.Random;
025
026 import org.apache.commons.math3.exception.ConvergenceException;
027 import org.apache.commons.math3.exception.MathIllegalArgumentException;
028 import org.apache.commons.math3.exception.NumberIsTooSmallException;
029 import org.apache.commons.math3.exception.util.LocalizedFormats;
030 import org.apache.commons.math3.stat.descriptive.moment.Variance;
031 import org.apache.commons.math3.util.MathUtils;
032
033 /**
034 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
035 * @param <T> type of the points to cluster
036 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
037 * @version $Id: KMeansPlusPlusClusterer.java 1461871 2013-03-27 22:01:25Z tn $
038 * @since 2.0
039 * @deprecated As of 3.2 (to be removed in 4.0),
040 * use {@link org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer} instead
041 */
042 @Deprecated
043 public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
044
045 /** Strategies to use for replacing an empty cluster. */
046 public static enum EmptyClusterStrategy {
047
048 /** Split the cluster with largest distance variance. */
049 LARGEST_VARIANCE,
050
051 /** Split the cluster with largest number of points. */
052 LARGEST_POINTS_NUMBER,
053
054 /** Create a cluster around the point farthest from its centroid. */
055 FARTHEST_POINT,
056
057 /** Generate an error. */
058 ERROR
059
060 }
061
062 /** Random generator for choosing initial centers. */
063 private final Random random;
064
065 /** Selected strategy for empty clusters. */
066 private final EmptyClusterStrategy emptyStrategy;
067
068 /** Build a clusterer.
069 * <p>
070 * The default strategy for handling empty clusters that may appear during
071 * algorithm iterations is to split the cluster with largest distance variance.
072 * </p>
073 * @param random random generator to use for choosing initial centers
074 */
075 public KMeansPlusPlusClusterer(final Random random) {
076 this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
077 }
078
079 /** Build a clusterer.
080 * @param random random generator to use for choosing initial centers
081 * @param emptyStrategy strategy to use for handling empty clusters that
082 * may appear during algorithm iterations
083 * @since 2.2
084 */
085 public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
086 this.random = random;
087 this.emptyStrategy = emptyStrategy;
088 }
089
090 /**
091 * Runs the K-means++ clustering algorithm.
092 *
093 * @param points the points to cluster
094 * @param k the number of clusters to split the data into
095 * @param numTrials number of trial runs
096 * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
097 * for at each trial run. If negative, no maximum will be used
098 * @return a list of clusters containing the points
099 * @throws MathIllegalArgumentException if the data points are null or the number
100 * of clusters is larger than the number of data points
101 * @throws ConvergenceException if an empty cluster is encountered and the
102 * {@link #emptyStrategy} is set to {@code ERROR}
103 */
104 public List<Cluster<T>> cluster(final Collection<T> points, final int k,
105 int numTrials, int maxIterationsPerTrial)
106 throws MathIllegalArgumentException, ConvergenceException {
107
108 // at first, we have not found any clusters list yet
109 List<Cluster<T>> best = null;
110 double bestVarianceSum = Double.POSITIVE_INFINITY;
111
112 // do several clustering trials
113 for (int i = 0; i < numTrials; ++i) {
114
115 // compute a clusters list
116 List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
117
118 // compute the variance of the current list
119 double varianceSum = 0.0;
120 for (final Cluster<T> cluster : clusters) {
121 if (!cluster.getPoints().isEmpty()) {
122
123 // compute the distance variance of the current cluster
124 final T center = cluster.getCenter();
125 final Variance stat = new Variance();
126 for (final T point : cluster.getPoints()) {
127 stat.increment(point.distanceFrom(center));
128 }
129 varianceSum += stat.getResult();
130
131 }
132 }
133
134 if (varianceSum <= bestVarianceSum) {
135 // this one is the best we have found so far, remember it
136 best = clusters;
137 bestVarianceSum = varianceSum;
138 }
139
140 }
141
142 // return the best clusters list found
143 return best;
144
145 }
146
147 /**
148 * Runs the K-means++ clustering algorithm.
149 *
150 * @param points the points to cluster
151 * @param k the number of clusters to split the data into
152 * @param maxIterations the maximum number of iterations to run the algorithm
153 * for. If negative, no maximum will be used
154 * @return a list of clusters containing the points
155 * @throws MathIllegalArgumentException if the data points are null or the number
156 * of clusters is larger than the number of data points
157 * @throws ConvergenceException if an empty cluster is encountered and the
158 * {@link #emptyStrategy} is set to {@code ERROR}
159 */
160 public List<Cluster<T>> cluster(final Collection<T> points, final int k,
161 final int maxIterations)
162 throws MathIllegalArgumentException, ConvergenceException {
163
164 // sanity checks
165 MathUtils.checkNotNull(points);
166
167 // number of clusters has to be smaller or equal the number of data points
168 if (points.size() < k) {
169 throw new NumberIsTooSmallException(points.size(), k, false);
170 }
171
172 // create the initial clusters
173 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
174
175 // create an array containing the latest assignment of a point to a cluster
176 // no need to initialize the array, as it will be filled with the first assignment
177 int[] assignments = new int[points.size()];
178 assignPointsToClusters(clusters, points, assignments);
179
180 // iterate through updating the centers until we're done
181 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
182 for (int count = 0; count < max; count++) {
183 boolean emptyCluster = false;
184 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
185 for (final Cluster<T> cluster : clusters) {
186 final T newCenter;
187 if (cluster.getPoints().isEmpty()) {
188 switch (emptyStrategy) {
189 case LARGEST_VARIANCE :
190 newCenter = getPointFromLargestVarianceCluster(clusters);
191 break;
192 case LARGEST_POINTS_NUMBER :
193 newCenter = getPointFromLargestNumberCluster(clusters);
194 break;
195 case FARTHEST_POINT :
196 newCenter = getFarthestPoint(clusters);
197 break;
198 default :
199 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
200 }
201 emptyCluster = true;
202 } else {
203 newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
204 }
205 newClusters.add(new Cluster<T>(newCenter));
206 }
207 int changes = assignPointsToClusters(newClusters, points, assignments);
208 clusters = newClusters;
209
210 // if there were no more changes in the point-to-cluster assignment
211 // and there are no empty clusters left, return the current clusters
212 if (changes == 0 && !emptyCluster) {
213 return clusters;
214 }
215 }
216 return clusters;
217 }
218
219 /**
220 * Adds the given points to the closest {@link Cluster}.
221 *
222 * @param <T> type of the points to cluster
223 * @param clusters the {@link Cluster}s to add the points to
224 * @param points the points to add to the given {@link Cluster}s
225 * @param assignments points assignments to clusters
226 * @return the number of points assigned to different clusters as the iteration before
227 */
228 private static <T extends Clusterable<T>> int
229 assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
230 final int[] assignments) {
231 int assignedDifferently = 0;
232 int pointIndex = 0;
233 for (final T p : points) {
234 int clusterIndex = getNearestCluster(clusters, p);
235 if (clusterIndex != assignments[pointIndex]) {
236 assignedDifferently++;
237 }
238
239 Cluster<T> cluster = clusters.get(clusterIndex);
240 cluster.addPoint(p);
241 assignments[pointIndex++] = clusterIndex;
242 }
243
244 return assignedDifferently;
245 }
246
247 /**
248 * Use K-means++ to choose the initial centers.
249 *
250 * @param <T> type of the points to cluster
251 * @param points the points to choose the initial centers from
252 * @param k the number of centers to choose
253 * @param random random generator to use
254 * @return the initial centers
255 */
256 private static <T extends Clusterable<T>> List<Cluster<T>>
257 chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
258
259 // Convert to list for indexed access. Make it unmodifiable, since removal of items
260 // would screw up the logic of this method.
261 final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
262
263 // The number of points in the list.
264 final int numPoints = pointList.size();
265
266 // Set the corresponding element in this array to indicate when
267 // elements of pointList are no longer available.
268 final boolean[] taken = new boolean[numPoints];
269
270 // The resulting list of initial centers.
271 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
272
273 // Choose one center uniformly at random from among the data points.
274 final int firstPointIndex = random.nextInt(numPoints);
275
276 final T firstPoint = pointList.get(firstPointIndex);
277
278 resultSet.add(new Cluster<T>(firstPoint));
279
280 // Must mark it as taken
281 taken[firstPointIndex] = true;
282
283 // To keep track of the minimum distance squared of elements of
284 // pointList to elements of resultSet.
285 final double[] minDistSquared = new double[numPoints];
286
287 // Initialize the elements. Since the only point in resultSet is firstPoint,
288 // this is very easy.
289 for (int i = 0; i < numPoints; i++) {
290 if (i != firstPointIndex) { // That point isn't considered
291 double d = firstPoint.distanceFrom(pointList.get(i));
292 minDistSquared[i] = d*d;
293 }
294 }
295
296 while (resultSet.size() < k) {
297
298 // Sum up the squared distances for the points in pointList not
299 // already taken.
300 double distSqSum = 0.0;
301
302 for (int i = 0; i < numPoints; i++) {
303 if (!taken[i]) {
304 distSqSum += minDistSquared[i];
305 }
306 }
307
308 // Add one new data point as a center. Each point x is chosen with
309 // probability proportional to D(x)2
310 final double r = random.nextDouble() * distSqSum;
311
312 // The index of the next point to be added to the resultSet.
313 int nextPointIndex = -1;
314
315 // Sum through the squared min distances again, stopping when
316 // sum >= r.
317 double sum = 0.0;
318 for (int i = 0; i < numPoints; i++) {
319 if (!taken[i]) {
320 sum += minDistSquared[i];
321 if (sum >= r) {
322 nextPointIndex = i;
323 break;
324 }
325 }
326 }
327
328 // If it's not set to >= 0, the point wasn't found in the previous
329 // for loop, probably because distances are extremely small. Just pick
330 // the last available point.
331 if (nextPointIndex == -1) {
332 for (int i = numPoints - 1; i >= 0; i--) {
333 if (!taken[i]) {
334 nextPointIndex = i;
335 break;
336 }
337 }
338 }
339
340 // We found one.
341 if (nextPointIndex >= 0) {
342
343 final T p = pointList.get(nextPointIndex);
344
345 resultSet.add(new Cluster<T> (p));
346
347 // Mark it as taken.
348 taken[nextPointIndex] = true;
349
350 if (resultSet.size() < k) {
351 // Now update elements of minDistSquared. We only have to compute
352 // the distance to the new center to do this.
353 for (int j = 0; j < numPoints; j++) {
354 // Only have to worry about the points still not taken.
355 if (!taken[j]) {
356 double d = p.distanceFrom(pointList.get(j));
357 double d2 = d * d;
358 if (d2 < minDistSquared[j]) {
359 minDistSquared[j] = d2;
360 }
361 }
362 }
363 }
364
365 } else {
366 // None found --
367 // Break from the while loop to prevent
368 // an infinite loop.
369 break;
370 }
371 }
372
373 return resultSet;
374 }
375
376 /**
377 * Get a random point from the {@link Cluster} with the largest distance variance.
378 *
379 * @param clusters the {@link Cluster}s to search
380 * @return a random point from the selected cluster
381 * @throws ConvergenceException if clusters are all empty
382 */
383 private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
384 throws ConvergenceException {
385
386 double maxVariance = Double.NEGATIVE_INFINITY;
387 Cluster<T> selected = null;
388 for (final Cluster<T> cluster : clusters) {
389 if (!cluster.getPoints().isEmpty()) {
390
391 // compute the distance variance of the current cluster
392 final T center = cluster.getCenter();
393 final Variance stat = new Variance();
394 for (final T point : cluster.getPoints()) {
395 stat.increment(point.distanceFrom(center));
396 }
397 final double variance = stat.getResult();
398
399 // select the cluster with the largest variance
400 if (variance > maxVariance) {
401 maxVariance = variance;
402 selected = cluster;
403 }
404
405 }
406 }
407
408 // did we find at least one non-empty cluster ?
409 if (selected == null) {
410 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
411 }
412
413 // extract a random point from the cluster
414 final List<T> selectedPoints = selected.getPoints();
415 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
416
417 }
418
419 /**
420 * Get a random point from the {@link Cluster} with the largest number of points
421 *
422 * @param clusters the {@link Cluster}s to search
423 * @return a random point from the selected cluster
424 * @throws ConvergenceException if clusters are all empty
425 */
426 private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {
427
428 int maxNumber = 0;
429 Cluster<T> selected = null;
430 for (final Cluster<T> cluster : clusters) {
431
432 // get the number of points of the current cluster
433 final int number = cluster.getPoints().size();
434
435 // select the cluster with the largest number of points
436 if (number > maxNumber) {
437 maxNumber = number;
438 selected = cluster;
439 }
440
441 }
442
443 // did we find at least one non-empty cluster ?
444 if (selected == null) {
445 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
446 }
447
448 // extract a random point from the cluster
449 final List<T> selectedPoints = selected.getPoints();
450 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
451
452 }
453
454 /**
455 * Get the point farthest to its cluster center
456 *
457 * @param clusters the {@link Cluster}s to search
458 * @return point farthest to its cluster center
459 * @throws ConvergenceException if clusters are all empty
460 */
461 private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {
462
463 double maxDistance = Double.NEGATIVE_INFINITY;
464 Cluster<T> selectedCluster = null;
465 int selectedPoint = -1;
466 for (final Cluster<T> cluster : clusters) {
467
468 // get the farthest point
469 final T center = cluster.getCenter();
470 final List<T> points = cluster.getPoints();
471 for (int i = 0; i < points.size(); ++i) {
472 final double distance = points.get(i).distanceFrom(center);
473 if (distance > maxDistance) {
474 maxDistance = distance;
475 selectedCluster = cluster;
476 selectedPoint = i;
477 }
478 }
479
480 }
481
482 // did we find at least one non-empty cluster ?
483 if (selectedCluster == null) {
484 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
485 }
486
487 return selectedCluster.getPoints().remove(selectedPoint);
488
489 }
490
491 /**
492 * Returns the nearest {@link Cluster} to the given point
493 *
494 * @param <T> type of the points to cluster
495 * @param clusters the {@link Cluster}s to search
496 * @param point the point to find the nearest {@link Cluster} for
497 * @return the index of the nearest {@link Cluster} to the given point
498 */
499 private static <T extends Clusterable<T>> int
500 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
501 double minDistance = Double.MAX_VALUE;
502 int clusterIndex = 0;
503 int minCluster = 0;
504 for (final Cluster<T> c : clusters) {
505 final double distance = point.distanceFrom(c.getCenter());
506 if (distance < minDistance) {
507 minDistance = distance;
508 minCluster = clusterIndex;
509 }
510 clusterIndex++;
511 }
512 return minCluster;
513 }
514
515 }