001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.apache.commons.math3.stat.clustering; 019 020 import java.util.ArrayList; 021 import java.util.Collection; 022 import java.util.Collections; 023 import java.util.List; 024 import java.util.Random; 025 026 import org.apache.commons.math3.exception.ConvergenceException; 027 import org.apache.commons.math3.exception.MathIllegalArgumentException; 028 import org.apache.commons.math3.exception.NumberIsTooSmallException; 029 import org.apache.commons.math3.exception.util.LocalizedFormats; 030 import org.apache.commons.math3.stat.descriptive.moment.Variance; 031 import org.apache.commons.math3.util.MathUtils; 032 033 /** 034 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. 035 * @param <T> type of the points to cluster 036 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a> 037 * @version $Id: KMeansPlusPlusClusterer.java 1416643 2012-12-03 19:37:14Z tn $ 038 * @since 2.0 039 */ 040 public class KMeansPlusPlusClusterer<T extends Clusterable<T>> { 041 042 /** Strategies to use for replacing an empty cluster. */ 043 public static enum EmptyClusterStrategy { 044 045 /** Split the cluster with largest distance variance. */ 046 LARGEST_VARIANCE, 047 048 /** Split the cluster with largest number of points. */ 049 LARGEST_POINTS_NUMBER, 050 051 /** Create a cluster around the point farthest from its centroid. */ 052 FARTHEST_POINT, 053 054 /** Generate an error. */ 055 ERROR 056 057 } 058 059 /** Random generator for choosing initial centers. */ 060 private final Random random; 061 062 /** Selected strategy for empty clusters. */ 063 private final EmptyClusterStrategy emptyStrategy; 064 065 /** Build a clusterer. 066 * <p> 067 * The default strategy for handling empty clusters that may appear during 068 * algorithm iterations is to split the cluster with largest distance variance. 069 * </p> 070 * @param random random generator to use for choosing initial centers 071 */ 072 public KMeansPlusPlusClusterer(final Random random) { 073 this(random, EmptyClusterStrategy.LARGEST_VARIANCE); 074 } 075 076 /** Build a clusterer. 077 * @param random random generator to use for choosing initial centers 078 * @param emptyStrategy strategy to use for handling empty clusters that 079 * may appear during algorithm iterations 080 * @since 2.2 081 */ 082 public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) { 083 this.random = random; 084 this.emptyStrategy = emptyStrategy; 085 } 086 087 /** 088 * Runs the K-means++ clustering algorithm. 089 * 090 * @param points the points to cluster 091 * @param k the number of clusters to split the data into 092 * @param numTrials number of trial runs 093 * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm 094 * for at each trial run. If negative, no maximum will be used 095 * @return a list of clusters containing the points 096 * @throws MathIllegalArgumentException if the data points are null or the number 097 * of clusters is larger than the number of data points 098 * @throws ConvergenceException if an empty cluster is encountered and the 099 * {@link #emptyStrategy} is set to {@code ERROR} 100 */ 101 public List<Cluster<T>> cluster(final Collection<T> points, final int k, 102 int numTrials, int maxIterationsPerTrial) 103 throws MathIllegalArgumentException, ConvergenceException { 104 105 // at first, we have not found any clusters list yet 106 List<Cluster<T>> best = null; 107 double bestVarianceSum = Double.POSITIVE_INFINITY; 108 109 // do several clustering trials 110 for (int i = 0; i < numTrials; ++i) { 111 112 // compute a clusters list 113 List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial); 114 115 // compute the variance of the current list 116 double varianceSum = 0.0; 117 for (final Cluster<T> cluster : clusters) { 118 if (!cluster.getPoints().isEmpty()) { 119 120 // compute the distance variance of the current cluster 121 final T center = cluster.getCenter(); 122 final Variance stat = new Variance(); 123 for (final T point : cluster.getPoints()) { 124 stat.increment(point.distanceFrom(center)); 125 } 126 varianceSum += stat.getResult(); 127 128 } 129 } 130 131 if (varianceSum <= bestVarianceSum) { 132 // this one is the best we have found so far, remember it 133 best = clusters; 134 bestVarianceSum = varianceSum; 135 } 136 137 } 138 139 // return the best clusters list found 140 return best; 141 142 } 143 144 /** 145 * Runs the K-means++ clustering algorithm. 146 * 147 * @param points the points to cluster 148 * @param k the number of clusters to split the data into 149 * @param maxIterations the maximum number of iterations to run the algorithm 150 * for. If negative, no maximum will be used 151 * @return a list of clusters containing the points 152 * @throws MathIllegalArgumentException if the data points are null or the number 153 * of clusters is larger than the number of data points 154 * @throws ConvergenceException if an empty cluster is encountered and the 155 * {@link #emptyStrategy} is set to {@code ERROR} 156 */ 157 public List<Cluster<T>> cluster(final Collection<T> points, final int k, 158 final int maxIterations) 159 throws MathIllegalArgumentException, ConvergenceException { 160 161 // sanity checks 162 MathUtils.checkNotNull(points); 163 164 // number of clusters has to be smaller or equal the number of data points 165 if (points.size() < k) { 166 throw new NumberIsTooSmallException(points.size(), k, false); 167 } 168 169 // create the initial clusters 170 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random); 171 172 // create an array containing the latest assignment of a point to a cluster 173 // no need to initialize the array, as it will be filled with the first assignment 174 int[] assignments = new int[points.size()]; 175 assignPointsToClusters(clusters, points, assignments); 176 177 // iterate through updating the centers until we're done 178 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; 179 for (int count = 0; count < max; count++) { 180 boolean emptyCluster = false; 181 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>(); 182 for (final Cluster<T> cluster : clusters) { 183 final T newCenter; 184 if (cluster.getPoints().isEmpty()) { 185 switch (emptyStrategy) { 186 case LARGEST_VARIANCE : 187 newCenter = getPointFromLargestVarianceCluster(clusters); 188 break; 189 case LARGEST_POINTS_NUMBER : 190 newCenter = getPointFromLargestNumberCluster(clusters); 191 break; 192 case FARTHEST_POINT : 193 newCenter = getFarthestPoint(clusters); 194 break; 195 default : 196 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); 197 } 198 emptyCluster = true; 199 } else { 200 newCenter = cluster.getCenter().centroidOf(cluster.getPoints()); 201 } 202 newClusters.add(new Cluster<T>(newCenter)); 203 } 204 int changes = assignPointsToClusters(newClusters, points, assignments); 205 clusters = newClusters; 206 207 // if there were no more changes in the point-to-cluster assignment 208 // and there are no empty clusters left, return the current clusters 209 if (changes == 0 && !emptyCluster) { 210 return clusters; 211 } 212 } 213 return clusters; 214 } 215 216 /** 217 * Adds the given points to the closest {@link Cluster}. 218 * 219 * @param <T> type of the points to cluster 220 * @param clusters the {@link Cluster}s to add the points to 221 * @param points the points to add to the given {@link Cluster}s 222 * @param assignments points assignments to clusters 223 * @return the number of points assigned to different clusters as the iteration before 224 */ 225 private static <T extends Clusterable<T>> int 226 assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points, 227 final int[] assignments) { 228 int assignedDifferently = 0; 229 int pointIndex = 0; 230 for (final T p : points) { 231 int clusterIndex = getNearestCluster(clusters, p); 232 if (clusterIndex != assignments[pointIndex]) { 233 assignedDifferently++; 234 } 235 236 Cluster<T> cluster = clusters.get(clusterIndex); 237 cluster.addPoint(p); 238 assignments[pointIndex++] = clusterIndex; 239 } 240 241 return assignedDifferently; 242 } 243 244 /** 245 * Use K-means++ to choose the initial centers. 246 * 247 * @param <T> type of the points to cluster 248 * @param points the points to choose the initial centers from 249 * @param k the number of centers to choose 250 * @param random random generator to use 251 * @return the initial centers 252 */ 253 private static <T extends Clusterable<T>> List<Cluster<T>> 254 chooseInitialCenters(final Collection<T> points, final int k, final Random random) { 255 256 // Convert to list for indexed access. Make it unmodifiable, since removal of items 257 // would screw up the logic of this method. 258 final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points)); 259 260 // The number of points in the list. 261 final int numPoints = pointList.size(); 262 263 // Set the corresponding element in this array to indicate when 264 // elements of pointList are no longer available. 265 final boolean[] taken = new boolean[numPoints]; 266 267 // The resulting list of initial centers. 268 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>(); 269 270 // Choose one center uniformly at random from among the data points. 271 final int firstPointIndex = random.nextInt(numPoints); 272 273 final T firstPoint = pointList.get(firstPointIndex); 274 275 resultSet.add(new Cluster<T>(firstPoint)); 276 277 // Must mark it as taken 278 taken[firstPointIndex] = true; 279 280 // To keep track of the minimum distance squared of elements of 281 // pointList to elements of resultSet. 282 final double[] minDistSquared = new double[numPoints]; 283 284 // Initialize the elements. Since the only point in resultSet is firstPoint, 285 // this is very easy. 286 for (int i = 0; i < numPoints; i++) { 287 if (i != firstPointIndex) { // That point isn't considered 288 double d = firstPoint.distanceFrom(pointList.get(i)); 289 minDistSquared[i] = d*d; 290 } 291 } 292 293 while (resultSet.size() < k) { 294 295 // Sum up the squared distances for the points in pointList not 296 // already taken. 297 double distSqSum = 0.0; 298 299 for (int i = 0; i < numPoints; i++) { 300 if (!taken[i]) { 301 distSqSum += minDistSquared[i]; 302 } 303 } 304 305 // Add one new data point as a center. Each point x is chosen with 306 // probability proportional to D(x)2 307 final double r = random.nextDouble() * distSqSum; 308 309 // The index of the next point to be added to the resultSet. 310 int nextPointIndex = -1; 311 312 // Sum through the squared min distances again, stopping when 313 // sum >= r. 314 double sum = 0.0; 315 for (int i = 0; i < numPoints; i++) { 316 if (!taken[i]) { 317 sum += minDistSquared[i]; 318 if (sum >= r) { 319 nextPointIndex = i; 320 break; 321 } 322 } 323 } 324 325 // If it's not set to >= 0, the point wasn't found in the previous 326 // for loop, probably because distances are extremely small. Just pick 327 // the last available point. 328 if (nextPointIndex == -1) { 329 for (int i = numPoints - 1; i >= 0; i--) { 330 if (!taken[i]) { 331 nextPointIndex = i; 332 break; 333 } 334 } 335 } 336 337 // We found one. 338 if (nextPointIndex >= 0) { 339 340 final T p = pointList.get(nextPointIndex); 341 342 resultSet.add(new Cluster<T> (p)); 343 344 // Mark it as taken. 345 taken[nextPointIndex] = true; 346 347 if (resultSet.size() < k) { 348 // Now update elements of minDistSquared. We only have to compute 349 // the distance to the new center to do this. 350 for (int j = 0; j < numPoints; j++) { 351 // Only have to worry about the points still not taken. 352 if (!taken[j]) { 353 double d = p.distanceFrom(pointList.get(j)); 354 double d2 = d * d; 355 if (d2 < minDistSquared[j]) { 356 minDistSquared[j] = d2; 357 } 358 } 359 } 360 } 361 362 } else { 363 // None found -- 364 // Break from the while loop to prevent 365 // an infinite loop. 366 break; 367 } 368 } 369 370 return resultSet; 371 } 372 373 /** 374 * Get a random point from the {@link Cluster} with the largest distance variance. 375 * 376 * @param clusters the {@link Cluster}s to search 377 * @return a random point from the selected cluster 378 * @throws ConvergenceException if clusters are all empty 379 */ 380 private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) 381 throws ConvergenceException { 382 383 double maxVariance = Double.NEGATIVE_INFINITY; 384 Cluster<T> selected = null; 385 for (final Cluster<T> cluster : clusters) { 386 if (!cluster.getPoints().isEmpty()) { 387 388 // compute the distance variance of the current cluster 389 final T center = cluster.getCenter(); 390 final Variance stat = new Variance(); 391 for (final T point : cluster.getPoints()) { 392 stat.increment(point.distanceFrom(center)); 393 } 394 final double variance = stat.getResult(); 395 396 // select the cluster with the largest variance 397 if (variance > maxVariance) { 398 maxVariance = variance; 399 selected = cluster; 400 } 401 402 } 403 } 404 405 // did we find at least one non-empty cluster ? 406 if (selected == null) { 407 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); 408 } 409 410 // extract a random point from the cluster 411 final List<T> selectedPoints = selected.getPoints(); 412 return selectedPoints.remove(random.nextInt(selectedPoints.size())); 413 414 } 415 416 /** 417 * Get a random point from the {@link Cluster} with the largest number of points 418 * 419 * @param clusters the {@link Cluster}s to search 420 * @return a random point from the selected cluster 421 * @throws ConvergenceException if clusters are all empty 422 */ 423 private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException { 424 425 int maxNumber = 0; 426 Cluster<T> selected = null; 427 for (final Cluster<T> cluster : clusters) { 428 429 // get the number of points of the current cluster 430 final int number = cluster.getPoints().size(); 431 432 // select the cluster with the largest number of points 433 if (number > maxNumber) { 434 maxNumber = number; 435 selected = cluster; 436 } 437 438 } 439 440 // did we find at least one non-empty cluster ? 441 if (selected == null) { 442 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); 443 } 444 445 // extract a random point from the cluster 446 final List<T> selectedPoints = selected.getPoints(); 447 return selectedPoints.remove(random.nextInt(selectedPoints.size())); 448 449 } 450 451 /** 452 * Get the point farthest to its cluster center 453 * 454 * @param clusters the {@link Cluster}s to search 455 * @return point farthest to its cluster center 456 * @throws ConvergenceException if clusters are all empty 457 */ 458 private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException { 459 460 double maxDistance = Double.NEGATIVE_INFINITY; 461 Cluster<T> selectedCluster = null; 462 int selectedPoint = -1; 463 for (final Cluster<T> cluster : clusters) { 464 465 // get the farthest point 466 final T center = cluster.getCenter(); 467 final List<T> points = cluster.getPoints(); 468 for (int i = 0; i < points.size(); ++i) { 469 final double distance = points.get(i).distanceFrom(center); 470 if (distance > maxDistance) { 471 maxDistance = distance; 472 selectedCluster = cluster; 473 selectedPoint = i; 474 } 475 } 476 477 } 478 479 // did we find at least one non-empty cluster ? 480 if (selectedCluster == null) { 481 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); 482 } 483 484 return selectedCluster.getPoints().remove(selectedPoint); 485 486 } 487 488 /** 489 * Returns the nearest {@link Cluster} to the given point 490 * 491 * @param <T> type of the points to cluster 492 * @param clusters the {@link Cluster}s to search 493 * @param point the point to find the nearest {@link Cluster} for 494 * @return the index of the nearest {@link Cluster} to the given point 495 */ 496 private static <T extends Clusterable<T>> int 497 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) { 498 double minDistance = Double.MAX_VALUE; 499 int clusterIndex = 0; 500 int minCluster = 0; 501 for (final Cluster<T> c : clusters) { 502 final double distance = point.distanceFrom(c.getCenter()); 503 if (distance < minDistance) { 504 minDistance = distance; 505 minCluster = clusterIndex; 506 } 507 clusterIndex++; 508 } 509 return minCluster; 510 } 511 512 }