/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.stat.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.apache.commons.math3.util.MathUtils;

/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 * @since 2.0
 * @deprecated As of 3.2 (to be removed in 4.0),
 * use {@link org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer} instead
 */
@Deprecated
public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {

    /** Strategies to use for replacing an empty cluster. */
    public enum EmptyClusterStrategy {

        /** Split the cluster with largest distance variance. */
        LARGEST_VARIANCE,

        /** Split the cluster with largest number of points. */
        LARGEST_POINTS_NUMBER,

        /** Create a cluster around the point farthest from its centroid. */
        FARTHEST_POINT,

        /** Generate an error. */
        ERROR

    }

    /** Random generator for choosing initial centers. */
    private final Random random;

    /** Selected strategy for empty clusters. */
    private final EmptyClusterStrategy emptyStrategy;

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * </p>
     * @param random random generator to use for choosing initial centers
     */
    public KMeansPlusPlusClusterer(final Random random) {
        this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
    }

    /** Build a clusterer.
     * @param random random generator to use for choosing initial centers
     * @param emptyStrategy strategy to use for handling empty clusters that
     * may appear during algorithm iterations
     * @since 2.2
     */
    public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
        this.random        = random;
        this.emptyStrategy = emptyStrategy;
    }

    /**
     * Runs the K-means++ clustering algorithm several times and keeps the
     * clusters list whose summed per-cluster distance variance is the smallest.
     *
     * @param points the points to cluster
     * @param k the number of clusters to split the data into
     * @param numTrials number of trial runs
     * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
     *     for at each trial run.  If negative, no maximum will be used
     * @return a list of clusters containing the points, or {@code null} if
     *     {@code numTrials} is not strictly positive (no trial is run)
     * @throws MathIllegalArgumentException if the data points are null or the number
     *     of clusters is larger than the number of data points
     * @throws ConvergenceException if an empty cluster is encountered and the
     * {@link #emptyStrategy} is set to {@code ERROR}
     */
    public List<Cluster<T>> cluster(final Collection<T> points, final int k,
                                    int numTrials, int maxIterationsPerTrial)
        throws MathIllegalArgumentException, ConvergenceException {

        // at first, we have not found any clusters list yet
        List<Cluster<T>> best = null;
        double bestVarianceSum = Double.POSITIVE_INFINITY;

        // do several clustering trials
        for (int i = 0; i < numTrials; ++i) {

            // compute a clusters list
            List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);

            // compute the variance of the current list, ignoring empty clusters
            double varianceSum = 0.0;
            for (final Cluster<T> cluster : clusters) {
                if (!cluster.getPoints().isEmpty()) {
                    varianceSum += distanceVariance(cluster);
                }
            }

            // Ties favor the most recent trial. The first two conditions make sure
            // a completed trial is always retained even when a variance sum is NaN
            // (any comparison with NaN is false), so that null is only returned
            // when no trial was run at all.
            if (best == null || Double.isNaN(bestVarianceSum) || varianceSum <= bestVarianceSum) {
                // this one is the best we have found so far, remember it
                best            = clusters;
                bestVarianceSum = varianceSum;
            }

        }

        // return the best clusters list found
        return best;

    }

    /**
     * Runs the K-means++ clustering algorithm.
     *
     * @param points the points to cluster
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm
     *     for.  If negative, no maximum will be used
     * @return a list of clusters containing the points
     * @throws MathIllegalArgumentException if the data points are null or the number
     *     of clusters is larger than the number of data points
     * @throws ConvergenceException if an empty cluster is encountered and the
     * {@link #emptyStrategy} is set to {@code ERROR}
     */
    public List<Cluster<T>> cluster(final Collection<T> points, final int k,
                                    final int maxIterations)
        throws MathIllegalArgumentException, ConvergenceException {

        // sanity checks
        MathUtils.checkNotNull(points);

        // number of clusters has to be smaller or equal the number of data points;
        // equality is accepted, hence the bound (k) is flagged as allowed
        if (points.size() < k) {
            throw new NumberIsTooSmallException(points.size(), k, true);
        }

        // create the initial clusters
        List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);

        // create an array containing the latest assignment of a point to a cluster
        // no need to initialize the array, as it will be filled with the first assignment
        int[] assignments = new int[points.size()];
        assignPointsToClusters(clusters, points, assignments);

        // iterate through updating the centers until we're done
        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
        for (int count = 0; count < max; count++) {
            boolean emptyCluster = false;
            List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
            for (final Cluster<T> cluster : clusters) {
                final T newCenter;
                if (cluster.getPoints().isEmpty()) {
                    // an empty cluster has no centroid: borrow a point from
                    // another cluster according to the configured strategy
                    switch (emptyStrategy) {
                        case LARGEST_VARIANCE :
                            newCenter = getPointFromLargestVarianceCluster(clusters);
                            break;
                        case LARGEST_POINTS_NUMBER :
                            newCenter = getPointFromLargestNumberCluster(clusters);
                            break;
                        case FARTHEST_POINT :
                            newCenter = getFarthestPoint(clusters);
                            break;
                        default :
                            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
                    }
                    emptyCluster = true;
                } else {
                    newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
                }
                newClusters.add(new Cluster<T>(newCenter));
            }
            int changes = assignPointsToClusters(newClusters, points, assignments);
            clusters = newClusters;

            // if there were no more changes in the point-to-cluster assignment
            // and there are no empty clusters left, return the current clusters
            if (changes == 0 && !emptyCluster) {
                return clusters;
            }
        }
        return clusters;
    }

    /**
     * Computes the variance of the distances between the points of a cluster
     * and the center of that cluster.
     *
     * @param <T> type of the points to cluster
     * @param cluster the {@link Cluster} to analyze
     * @return the result of a {@link Variance} statistic fed with the distance
     * of each point of the cluster to the cluster center
     */
    private static <T extends Clusterable<T>> double distanceVariance(final Cluster<T> cluster) {
        final T center = cluster.getCenter();
        final Variance stat = new Variance();
        for (final T point : cluster.getPoints()) {
            stat.increment(point.distanceFrom(center));
        }
        return stat.getResult();
    }

    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     * @param assignments points assignments to clusters; read to detect changes
     * and overwritten with the new assignments
     * @return the number of points assigned to different clusters as the iteration before
     */
    private static <T extends Clusterable<T>> int
        assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
                               final int[] assignments) {
        int assignedDifferently = 0;
        int pointIndex = 0;
        for (final T p : points) {
            int clusterIndex = getNearestCluster(clusters, p);
            if (clusterIndex != assignments[pointIndex]) {
                assignedDifferently++;
            }

            Cluster<T> cluster = clusters.get(clusterIndex);
            cluster.addPoint(p);
            assignments[pointIndex++] = clusterIndex;
        }

        return assignedDifferently;
    }

    /**
     * Use K-means++ to choose the initial centers.
     *
     * @param <T> type of the points to cluster
     * @param points the points to choose the initial centers from
     * @param k the number of centers to choose
     * @param random random generator to use
     * @return the initial centers
     */
    private static <T extends Clusterable<T>> List<Cluster<T>>
        chooseInitialCenters(final Collection<T> points, final int k, final Random random) {

        // Convert to list for indexed access. Make it unmodifiable, since removal of items
        // would screw up the logic of this method.
        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));

        // The number of points in the list.
        final int numPoints = pointList.size();

        // Set the corresponding element in this array to indicate when
        // elements of pointList are no longer available.
        final boolean[] taken = new boolean[numPoints];

        // The resulting list of initial centers.
        final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();

        // Choose one center uniformly at random from among the data points.
        final int firstPointIndex = random.nextInt(numPoints);

        final T firstPoint = pointList.get(firstPointIndex);

        resultSet.add(new Cluster<T>(firstPoint));

        // Must mark it as taken
        taken[firstPointIndex] = true;

        // To keep track of the minimum distance squared of elements of
        // pointList to elements of resultSet.
        final double[] minDistSquared = new double[numPoints];

        // Initialize the elements.  Since the only point in resultSet is firstPoint,
        // this is very easy.
        for (int i = 0; i < numPoints; i++) {
            if (i != firstPointIndex) { // That point isn't considered
                double d = firstPoint.distanceFrom(pointList.get(i));
                minDistSquared[i] = d * d;
            }
        }

        while (resultSet.size() < k) {

            // Sum up the squared distances for the points in pointList not
            // already taken.
            double distSqSum = 0.0;

            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    distSqSum += minDistSquared[i];
                }
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2
            final double r = random.nextDouble() * distSqSum;

            // The index of the next point to be added to the resultSet.
            int nextPointIndex = -1;

            // Sum through the squared min distances again, stopping when
            // sum >= r.
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // If it's not set to >= 0, the point wasn't found in the previous
            // for loop, probably because distances are extremely small.  Just pick
            // the last available point.
            if (nextPointIndex == -1) {
                for (int i = numPoints - 1; i >= 0; i--) {
                    if (!taken[i]) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // We found one.
            if (nextPointIndex >= 0) {

                final T p = pointList.get(nextPointIndex);

                resultSet.add(new Cluster<T> (p));

                // Mark it as taken.
                taken[nextPointIndex] = true;

                if (resultSet.size() < k) {
                    // Now update elements of minDistSquared.  We only have to compute
                    // the distance to the new center to do this.
                    for (int j = 0; j < numPoints; j++) {
                        // Only have to worry about the points still not taken.
                        if (!taken[j]) {
                            double d = p.distanceFrom(pointList.get(j));
                            double d2 = d * d;
                            if (d2 < minDistSquared[j]) {
                                minDistSquared[j] = d2;
                            }
                        }
                    }
                }

            } else {
                // None found --
                // Break from the while loop to prevent
                // an infinite loop.
                break;
            }
        }

        return resultSet;
    }

    /**
     * Get a random point from the {@link Cluster} with the largest distance variance.
     * <p>
     * The returned point is removed from the points list of the selected cluster.
     * </p>
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
        throws ConvergenceException {

        double maxVariance = Double.NEGATIVE_INFINITY;
        Cluster<T> selected = null;
        for (final Cluster<T> cluster : clusters) {
            if (!cluster.getPoints().isEmpty()) {

                // compute the distance variance of the current cluster
                final double variance = distanceVariance(cluster);

                // select the cluster with the largest variance
                if (variance > maxVariance) {
                    maxVariance = variance;
                    selected = cluster;
                }

            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));

    }

    /**
     * Get a random point from the {@link Cluster} with the largest number of points.
     * <p>
     * The returned point is removed from the points list of the selected cluster.
     * </p>
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters)
        throws ConvergenceException {

        int maxNumber = 0;
        Cluster<T> selected = null;
        for (final Cluster<T> cluster : clusters) {

            // get the number of points of the current cluster
            final int number = cluster.getPoints().size();

            // select the cluster with the largest number of points
            if (number > maxNumber) {
                maxNumber = number;
                selected = cluster;
            }

        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));

    }

    /**
     * Get the point farthest to its cluster center.
     * <p>
     * The returned point is removed from the points list of its cluster.
     * </p>
     *
     * @param clusters the {@link Cluster}s to search
     * @return point farthest to its cluster center
     * @throws ConvergenceException if clusters are all empty
     */
    private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {

        double maxDistance = Double.NEGATIVE_INFINITY;
        Cluster<T> selectedCluster = null;
        int selectedPoint = -1;
        for (final Cluster<T> cluster : clusters) {

            // get the farthest point
            final T center = cluster.getCenter();
            final List<T> points = cluster.getPoints();
            for (int i = 0; i < points.size(); ++i) {
                final double distance = points.get(i).distanceFrom(center);
                if (distance > maxDistance) {
                    maxDistance     = distance;
                    selectedCluster = cluster;
                    selectedPoint   = i;
                }
            }

        }

        // did we find at least one non-empty cluster ?
        if (selectedCluster == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        return selectedCluster.getPoints().remove(selectedPoint);

    }

    /**
     * Returns the nearest {@link Cluster} to the given point.
     *
     * @param <T> type of the points to cluster
     * @param clusters the {@link Cluster}s to search
     * @param point the point to find the nearest {@link Cluster} for
     * @return the index of the nearest {@link Cluster} to the given point
     */
    private static <T extends Clusterable<T>> int
        getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
        double minDistance = Double.MAX_VALUE;
        int clusterIndex = 0;
        int minCluster = 0;
        for (final Cluster<T> c : clusters) {
            final double distance = point.distanceFrom(c.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                minCluster = clusterIndex;
            }
            clusterIndex++;
        }
        return minCluster;
    }

}