/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.ml.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.apache.commons.math3.util.MathUtils;

/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitskii k-means++ algorithm.
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 * @since 3.2
 */
public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

    /** Strategies to use for replacing an empty cluster. */
    public enum EmptyClusterStrategy {

        /** Split the cluster with largest distance variance. */
        LARGEST_VARIANCE,

        /** Split the cluster with largest number of points. */
        LARGEST_POINTS_NUMBER,

        /** Create a cluster around the point farthest from its centroid. */
        FARTHEST_POINT,

        /** Generate an error. */
        ERROR

    }

    /** The number of clusters. */
    private final int k;

    /** The maximum number of iterations (any negative value means "no maximum"). */
    private final int maxIterations;

    /** Random generator for choosing initial centers. */
    private final RandomGenerator random;

    /** Selected strategy for empty clusters. */
    private final EmptyClusterStrategy emptyStrategy;

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * The euclidean distance will be used as default distance measure, and no
     * maximum number of iterations is set.
     *
     * @param k the number of clusters to split the data into
     */
    public KMeansPlusPlusClusterer(final int k) {
        this(k, -1);
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * The euclidean distance will be used as default distance measure.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
        this(k, maxIterations, new EuclideanDistance());
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * A new {@link JDKRandomGenerator} is used for choosing the initial centers.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     * @param measure the distance measure to use
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
        this(k, maxIterations, measure, new JDKRandomGenerator());
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     * @param measure the distance measure to use
     * @param random random generator to use for choosing initial centers
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
                                   final DistanceMeasure measure,
                                   final RandomGenerator random) {
        this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
    }

    /** Build a clusterer.
     * <p>
     * No argument is validated here; the number of clusters is checked when
     * {@link #cluster(Collection)} is called.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     * @param measure the distance measure to use
     * @param random random generator to use for choosing initial centers
     * @param emptyStrategy strategy to use for handling empty clusters that
     * may appear during algorithm iterations
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
                                   final DistanceMeasure measure,
                                   final RandomGenerator random,
                                   final EmptyClusterStrategy emptyStrategy) {
        super(measure);
        this.k             = k;
        this.maxIterations = maxIterations;
        this.random        = random;
        this.emptyStrategy = emptyStrategy;
    }

    /**
     * Return the number of clusters this instance will use.
     * @return the number of clusters
     */
    public int getK() {
        return k;
    }

    /**
     * Returns the maximum number of iterations this instance will use.
     * @return the maximum number of iterations as given at construction time;
     * a negative value (-1 when the default is used) means that no maximum is set
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Returns the random generator this instance will use.
     * @return the random generator
     */
    public RandomGenerator getRandomGenerator() {
        return random;
    }

    /**
     * Returns the {@link EmptyClusterStrategy} used by this instance.
     * @return the {@link EmptyClusterStrategy}
     */
    public EmptyClusterStrategy getEmptyClusterStrategy() {
        return emptyStrategy;
    }

    /**
     * Runs the K-means++ clustering algorithm.
     * <p>
     * Initial centers are picked among the data points, then centers are
     * recomputed and points re-assigned until an iteration changes no
     * assignment and meets no empty cluster, or until the maximum number of
     * iterations has been reached.
     *
     * @param points the points to cluster
     * @return a list of clusters containing the points
     * @throws MathIllegalArgumentException if the data points are null, if the
     * number of clusters is not strictly positive, or if the number of clusters
     * is larger than the number of data points
     * @throws ConvergenceException if an empty cluster is encountered and the
     * empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR}, or if
     * no non-empty cluster is left to provide a replacement center
     */
    @Override
    public List<CentroidCluster<T>> cluster(final Collection<T> points)
        throws MathIllegalArgumentException, ConvergenceException {

        // sanity checks
        MathUtils.checkNotNull(points);

        // At least one cluster must be requested. Without this check the seeding
        // step always creates one center, so a non-positive k would silently yield
        // a single cluster, and an empty collection would fail inside the random
        // generator when drawing the first center index.
        if (k < 1) {
            throw new NumberIsTooSmallException(k, 1, true);
        }

        // number of clusters has to be smaller or equal the number of data points
        if (points.size() < k) {
            throw new NumberIsTooSmallException(points.size(), k, false);
        }

        // create the initial clusters
        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);

        // create an array containing the latest assignment of a point to a cluster
        // no need to initialize the array, as it will be filled with the first assignment
        final int[] assignments = new int[points.size()];
        assignPointsToClusters(clusters, points, assignments);

        // iterate through updating the centers until we're done
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
        for (int count = 0; count < max; count++) {
            boolean emptyCluster = false;
            final List<CentroidCluster<T>> newClusters =
                new ArrayList<CentroidCluster<T>>(clusters.size());
            for (final CentroidCluster<T> cluster : clusters) {
                final Clusterable newCenter;
                if (cluster.getPoints().isEmpty()) {
                    // an empty cluster has no centroid: take a point out of another cluster
                    newCenter = chooseReplacementCenter(clusters);
                    emptyCluster = true;
                } else {
                    newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
                }
                newClusters.add(new CentroidCluster<T>(newCenter));
            }
            final int changes = assignPointsToClusters(newClusters, points, assignments);
            clusters = newClusters;

            // if there were no more changes in the point-to-cluster assignment
            // and there are no empty clusters left, return the current clusters
            if (changes == 0 && !emptyCluster) {
                return clusters;
            }
        }
        return clusters;
    }

    /**
     * Selects the center replacing an empty cluster, according to the
     * configured {@link EmptyClusterStrategy}.
     * <p>
     * For every strategy but {@code ERROR}, the returned point is removed from
     * the cluster it was taken from.
     *
     * @param clusters the current {@link Cluster}s
     * @return the point to use as center of the replacement cluster
     * @throws ConvergenceException if the strategy is {@code ERROR} or if
     * clusters are all empty
     */
    private T chooseReplacementCenter(final Collection<CentroidCluster<T>> clusters)
        throws ConvergenceException {
        switch (emptyStrategy) {
            case LARGEST_VARIANCE :
                return getPointFromLargestVarianceCluster(clusters);
            case LARGEST_POINTS_NUMBER :
                return getPointFromLargestNumberCluster(clusters);
            case FARTHEST_POINT :
                return getFarthestPoint(clusters);
            default :
                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }
    }

    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     * @param assignments points assignments to clusters; read to detect changes
     * and overwritten with the new assignments
     * @return the number of points assigned to a different cluster than the one
     * recorded in {@code assignments} on entry
     */
    private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
                                       final Collection<T> points,
                                       final int[] assignments) {
        int assignedDifferently = 0;
        int pointIndex = 0;
        for (final T p : points) {
            final int clusterIndex = getNearestCluster(clusters, p);
            if (clusterIndex != assignments[pointIndex]) {
                assignedDifferently++;
            }

            final CentroidCluster<T> cluster = clusters.get(clusterIndex);
            cluster.addPoint(p);
            assignments[pointIndex++] = clusterIndex;
        }

        return assignedDifferently;
    }

    /**
     * Use K-means++ to choose the initial centers.
     * <p>
     * The collection must not be empty: the first center index is drawn with
     * {@code random.nextInt(points.size())}.
     *
     * @param points the points to choose the initial centers from
     * @return the initial centers
     */
    private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

        // Convert to list for indexed access. Make it unmodifiable, since removal of items
        // would screw up the logic of this method.
        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));

        // The number of points in the list.
        final int numPoints = pointList.size();

        // Set the corresponding element in this array to indicate when
        // elements of pointList are no longer available.
        final boolean[] taken = new boolean[numPoints];

        // The resulting list of initial centers.
        final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>();

        // Choose one center uniformly at random from among the data points.
        final int firstPointIndex = random.nextInt(numPoints);

        final T firstPoint = pointList.get(firstPointIndex);

        resultSet.add(new CentroidCluster<T>(firstPoint));

        // Must mark it as taken
        taken[firstPointIndex] = true;

        // To keep track of the minimum distance squared of elements of
        // pointList to elements of resultSet.
        final double[] minDistSquared = new double[numPoints];

        // Initialize the elements.  Since the only point in resultSet is firstPoint,
        // this is very easy.
        for (int i = 0; i < numPoints; i++) {
            if (i != firstPointIndex) { // That point isn't considered
                final double d = distance(firstPoint, pointList.get(i));
                minDistSquared[i] = d * d;
            }
        }

        while (resultSet.size() < k) {

            // Sum up the squared distances for the points in pointList not
            // already taken.
            double distSqSum = 0.0;

            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    distSqSum += minDistSquared[i];
                }
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2
            final double r = random.nextDouble() * distSqSum;

            // The index of the next point to be added to the resultSet.
            int nextPointIndex = -1;

            // Sum through the squared min distances again, stopping when
            // sum >= r.
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // If it's not set to >= 0, the point wasn't found in the previous
            // for loop, probably because distances are extremely small.  Just pick
            // the last available point.
            if (nextPointIndex == -1) {
                for (int i = numPoints - 1; i >= 0; i--) {
                    if (!taken[i]) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // We found one.
            if (nextPointIndex >= 0) {

                final T p = pointList.get(nextPointIndex);

                resultSet.add(new CentroidCluster<T> (p));

                // Mark it as taken.
                taken[nextPointIndex] = true;

                if (resultSet.size() < k) {
                    // Now update elements of minDistSquared.  We only have to compute
                    // the distance to the new center to do this.
                    for (int j = 0; j < numPoints; j++) {
                        // Only have to worry about the points still not taken.
                        if (!taken[j]) {
                            final double d = distance(p, pointList.get(j));
                            final double d2 = d * d;
                            if (d2 < minDistSquared[j]) {
                                minDistSquared[j] = d2;
                            }
                        }
                    }
                }

            } else {
                // None found --
                // Break from the while loop to prevent
                // an infinite loop.
                break;
            }
        }

        return resultSet;
    }

    /**
     * Get a random point from the {@link Cluster} with the largest distance variance.
     * <p>
     * The returned point is removed from the selected cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
        throws ConvergenceException {

        double maxVariance = Double.NEGATIVE_INFINITY;
        Cluster<T> selected = null;
        for (final CentroidCluster<T> cluster : clusters) {
            if (!cluster.getPoints().isEmpty()) {

                // compute the distance variance of the current cluster
                final Clusterable center = cluster.getCenter();
                final Variance stat = new Variance();
                for (final T point : cluster.getPoints()) {
                    stat.increment(distance(point, center));
                }
                final double variance = stat.getResult();

                // select the cluster with the largest variance
                // (strict comparison: the first cluster wins ties)
                if (variance > maxVariance) {
                    maxVariance = variance;
                    selected = cluster;
                }

            }
        }

        return removeRandomPoint(selected);

    }

    /**
     * Get a random point from the {@link Cluster} with the largest number of points.
     * <p>
     * The returned point is removed from the selected cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
        throws ConvergenceException {

        int maxNumber = 0;
        Cluster<T> selected = null;
        for (final Cluster<T> cluster : clusters) {

            // get the number of points of the current cluster
            final int number = cluster.getPoints().size();

            // select the cluster with the largest number of points
            // (starting from 0 with a strict comparison skips empty clusters)
            if (number > maxNumber) {
                maxNumber = number;
                selected = cluster;
            }

        }

        return removeRandomPoint(selected);

    }

    /**
     * Removes and returns a randomly chosen point of the given cluster.
     *
     * @param selected the cluster to take the point from, or {@code null} if
     * the search found no non-empty cluster
     * @return the point removed from the cluster
     * @throws ConvergenceException if {@code selected} is {@code null}
     */
    private T removeRandomPoint(final Cluster<T> selected) throws ConvergenceException {

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));

    }

    /**
     * Get the point farthest to its cluster center.
     * <p>
     * The returned point is removed from its cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return point farthest to its cluster center
     * @throws ConvergenceException if clusters are all empty
     */
    private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException {

        double maxDistance = Double.NEGATIVE_INFINITY;
        Cluster<T> selectedCluster = null;
        int selectedPoint = -1;
        for (final CentroidCluster<T> cluster : clusters) {

            // get the farthest point
            final Clusterable center = cluster.getCenter();
            final List<T> points = cluster.getPoints();
            for (int i = 0; i < points.size(); ++i) {
                final double distance = distance(points.get(i), center);
                if (distance > maxDistance) {
                    maxDistance     = distance;
                    selectedCluster = cluster;
                    selectedPoint   = i;
                }
            }

        }

        // did we find at least one non-empty cluster ?
        if (selectedCluster == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        return selectedCluster.getPoints().remove(selectedPoint);

    }

    /**
     * Returns the nearest {@link Cluster} to the given point.
     * <p>
     * Ties are resolved in favor of the cluster that comes first in iteration
     * order; index 0 is returned when no distance is below {@link Double#MAX_VALUE}.
     *
     * @param clusters the {@link Cluster}s to search
     * @param point the point to find the nearest {@link Cluster} for
     * @return the index of the nearest {@link Cluster} to the given point
     */
    private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
        double minDistance = Double.MAX_VALUE;
        int clusterIndex = 0;
        int minCluster = 0;
        for (final CentroidCluster<T> c : clusters) {
            final double distance = distance(point, c.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                minCluster = clusterIndex;
            }
            clusterIndex++;
        }
        return minCluster;
    }

    /**
     * Computes the centroid for a set of points.
     * <p>
     * The centroid is the component-wise arithmetic mean of the points; only
     * the first {@code dimension} components of each point are read.
     *
     * @param points the set of points
     * @param dimension the point dimension
     * @return the computed centroid for the set of points
     */
    private Clusterable centroidOf(final Collection<T> points, final int dimension) {
        final double[] centroid = new double[dimension];
        for (final T p : points) {
            final double[] point = p.getPoint();
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += point[i];
            }
        }
        // the collection size does not change while averaging: read it once
        final int count = points.size();
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= count;
        }
        return new DoublePoint(centroid);
    }

}