/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.legacy.ml.clustering;

import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.ConvergenceException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitskii k-means++ algorithm.
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 * @since 3.2
 */
public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

    /** Strategies to use for replacing an empty cluster. */
    public enum EmptyClusterStrategy {

        /** Split the cluster with largest distance variance. */
        LARGEST_VARIANCE,

        /** Split the cluster with largest number of points. */
        LARGEST_POINTS_NUMBER,

        /** Create a cluster around the point farthest from its centroid. */
        FARTHEST_POINT,

        /** Generate an error. */
        ERROR
    }

    /** The number of clusters. */
    private final int numberOfClusters;

    /** The maximum number of iterations. */
    private final int maxIterations;

    /** Random generator for choosing initial centers. */
    private final UniformRandomProvider random;

    /** Selected strategy for empty clusters. */
    private final EmptyClusterStrategy emptyStrategy;

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * The euclidean distance will be used as default distance measure.
     * <p>
     * The maximum number of iterations is set to {@code Integer.MAX_VALUE}.
     *
     * @param k the number of clusters to split the data into
     * @throws NotStrictlyPositiveException if {@code k <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k) {
        this(k, Integer.MAX_VALUE);
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * The euclidean distance will be used as default distance measure.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   Must be strictly positive.
     * @throws NotStrictlyPositiveException if {@code k <= 0} or
     *   {@code maxIterations <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
        this(k, maxIterations, new EuclideanDistance());
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * A generator created from {@code RandomSource.MT_64} is used for choosing
     * the initial centers.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   Must be strictly positive.
     * @param measure the distance measure to use
     * @throws NotStrictlyPositiveException if {@code k <= 0} or
     *   {@code maxIterations <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
        this(k, maxIterations, measure, RandomSource.MT_64.create());
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   Must be strictly positive.
     * @param measure the distance measure to use
     * @param random random generator to use for choosing initial centers
     * @throws NotStrictlyPositiveException if {@code k <= 0} or
     *   {@code maxIterations <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
                                   final DistanceMeasure measure,
                                   final UniformRandomProvider random) {
        this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
    }

    /** Build a clusterer.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   Must be strictly positive.
     * @param measure the distance measure to use
     * @param random random generator to use for choosing initial centers
     * @param emptyStrategy strategy to use for handling empty clusters that
     * may appear during algorithm iterations
     * @throws NotStrictlyPositiveException if {@code k <= 0} or
     * {@code maxIterations <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k,
                                   final int maxIterations,
                                   final DistanceMeasure measure,
                                   final UniformRandomProvider random,
                                   final EmptyClusterStrategy emptyStrategy) {
        super(measure);

        if (k <= 0) {
            throw new NotStrictlyPositiveException(k);
        }
        if (maxIterations <= 0) {
            throw new NotStrictlyPositiveException(maxIterations);
        }

        this.numberOfClusters = k;
        this.maxIterations = maxIterations;
        this.random = random;
        this.emptyStrategy = emptyStrategy;
    }

    /**
     * Return the number of clusters this instance will use.
     * @return the number of clusters
     */
    public int getNumberOfClusters() {
        return numberOfClusters;
    }

    /**
     * Returns the maximum number of iterations this instance will use.
     * @return the maximum number of iterations; always strictly positive
     * ({@code Integer.MAX_VALUE} when the single-argument constructor was used)
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Runs the K-means++ clustering algorithm.
     * <p>
     * If the maximum number of iterations is reached before the assignments
     * stabilize, the clusters of the last iteration are returned; no error
     * is signalled in that case.
     *
     * @param points the points to cluster
     * @return a list of clusters containing the points
     * @throws NullArgumentException if {@code points} is rejected by
     *   {@link NullArgumentException#check(Object)}
     * @throws NumberIsTooSmallException if the number of clusters is larger
     *   than the number of data points
     * @throws ConvergenceException if an empty cluster is encountered and the
     *   empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR}
     */
    @Override
    public List<CentroidCluster<T>> cluster(final Collection<T> points) {
        // sanity checks
        NullArgumentException.check(points);

        // number of clusters has to be smaller or equal the number of data points
        if (points.size() < numberOfClusters) {
            throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
        }

        // create the initial clusters
        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);

        // create an array containing the latest assignment of a point to a cluster
        // no need to initialize the array, as it will be filled with the first assignment
        // (the change count returned by this first call is deliberately ignored)
        int[] assignments = new int[points.size()];
        assignPointsToClusters(clusters, points, assignments);

        // iterate through updating the centers until we're done
        for (int count = 0; count < maxIterations; count++) {
            // Evaluated before the centers are adjusted: an empty cluster at this
            // stage gets a replacement center, so one more pass is required.
            boolean hasEmptyCluster = clusters.stream().anyMatch(cluster -> cluster.getPoints().isEmpty());
            List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
            int changes = assignPointsToClusters(newClusters, points, assignments);
            clusters = newClusters;

            // if there were no more changes in the point-to-cluster assignment
            // and there are no empty clusters left, return the current clusters
            if (changes == 0 && !hasEmptyCluster) {
                return clusters;
            }
        }
        return clusters;
    }

    /**
     * @return the random generator
     */
    UniformRandomProvider getRandomGenerator() {
        return random;
    }

    /**
     * @return the {@link EmptyClusterStrategy}
     */
    EmptyClusterStrategy getEmptyClusterStrategy() {
        return emptyStrategy;
    }

    /**
     * Adjust the clusters's centers with means of points.
     * <p>
     * A non-empty cluster gets its centroid as new center. An empty cluster
     * gets a center chosen according to the configured
     * {@link EmptyClusterStrategy}. The returned clusters contain no points.
     *
     * @param clusters the origin clusters
     * @return adjusted clusters with center points
     * @throws ConvergenceException if an empty cluster is encountered and the
     *   strategy is {@link EmptyClusterStrategy#ERROR}, or if the selected
     *   strategy finds no cluster to take a point from
     */
    List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) {
        final List<CentroidCluster<T>> newClusters = new ArrayList<>(clusters.size());
        for (final CentroidCluster<T> cluster : clusters) {
            final Clusterable newCenter;
            if (cluster.getPoints().isEmpty()) {
                switch (emptyStrategy) {
                    case LARGEST_VARIANCE :
                        newCenter = getPointFromLargestVarianceCluster(clusters);
                        break;
                    case LARGEST_POINTS_NUMBER :
                        newCenter = getPointFromLargestNumberCluster(clusters);
                        break;
                    case FARTHEST_POINT :
                        newCenter = getFarthestPoint(clusters);
                        break;
                    default :
                        throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
                }
            } else {
                newCenter = cluster.centroid();
            }
            newClusters.add(new CentroidCluster<>(newCenter));
        }
        return newClusters;
    }

    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     * @param assignments points assignments to clusters; read to detect changes
     *   and overwritten with the new assignment of each point
     * @return the number of points assigned to different clusters as the iteration before
     */
    private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
                                       final Collection<T> points,
                                       final int[] assignments) {
        int assignedDifferently = 0;
        int pointIndex = 0;
        for (final T p : points) {
            final int clusterIndex = getNearestCluster(clusters, p);
            if (clusterIndex != assignments[pointIndex]) {
                assignedDifferently++;
            }

            final CentroidCluster<T> cluster = clusters.get(clusterIndex);
            cluster.addPoint(p);
            assignments[pointIndex++] = clusterIndex;
        }

        return assignedDifferently;
    }

    /**
     * Use K-means++ to choose the initial centers.
     *
     * @param points the points to choose the initial centers from
     * @return the initial centers
     */
    List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

        // Convert to list for indexed access. Make it unmodifiable, since removal of items
        // would break the index bookkeeping of this method.
        final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));

        // The number of points in the list.
        final int numPoints = pointList.size();

        // Set the corresponding element in this array to indicate when
        // elements of pointList are no longer available.
        final boolean[] taken = new boolean[numPoints];

        // The resulting list of initial centers.
        final List<CentroidCluster<T>> resultSet = new ArrayList<>();

        // Choose one center uniformly at random from among the data points.
        final int firstPointIndex = random.nextInt(numPoints);

        final T firstPoint = pointList.get(firstPointIndex);

        resultSet.add(new CentroidCluster<>(firstPoint));

        // Must mark it as taken
        taken[firstPointIndex] = true;

        // To keep track of the minimum distance squared of elements of
        // pointList to elements of resultSet.
        final double[] minDistSquared = new double[numPoints];

        // Initialize the elements. Since the only point in resultSet is firstPoint,
        // this is very easy.
        for (int i = 0; i < numPoints; i++) {
            if (i != firstPointIndex) { // That point isn't considered
                final double d = distance(firstPoint, pointList.get(i));
                minDistSquared[i] = d * d;
            }
        }

        while (resultSet.size() < numberOfClusters) {

            // Sum up the squared distances for the points in pointList not
            // already taken.
            double distSqSum = 0.0;

            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    distSqSum += minDistSquared[i];
                }
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2
            final double r = random.nextDouble() * distSqSum;

            // The index of the next point to be added to the resultSet.
            int nextPointIndex = -1;

            // Sum through the squared min distances again, stopping when
            // sum >= r.
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // If it's not set to >= 0, the point wasn't found in the previous
            // for loop, probably because distances are extremely small. Just pick
            // the last available point.
            if (nextPointIndex == -1) {
                for (int i = numPoints - 1; i >= 0; i--) {
                    if (!taken[i]) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // We found one.
            if (nextPointIndex >= 0) {

                final T p = pointList.get(nextPointIndex);

                resultSet.add(new CentroidCluster<>(p));

                // Mark it as taken.
                taken[nextPointIndex] = true;

                if (resultSet.size() < numberOfClusters) {
                    // Now update elements of minDistSquared. We only have to compute
                    // the distance to the new center to do this.
                    for (int j = 0; j < numPoints; j++) {
                        // Only have to worry about the points still not taken.
                        if (!taken[j]) {
                            final double d = distance(p, pointList.get(j));
                            final double d2 = d * d;
                            if (d2 < minDistSquared[j]) {
                                minDistSquared[j] = d2;
                            }
                        }
                    }
                }
            } else {
                // None found --
                // Break from the while loop to prevent
                // an infinite loop.
                break;
            }
        }

        return resultSet;
    }

    /**
     * Get a random point from the {@link Cluster} with the largest distance variance.
     * <p>
     * The chosen point is removed from the list returned by the selected
     * cluster's {@code getPoints()}.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) {
        double maxVariance = Double.NEGATIVE_INFINITY;
        Cluster<T> selected = null;
        for (final CentroidCluster<T> cluster : clusters) {
            if (!cluster.getPoints().isEmpty()) {

                // compute the distance variance of the current cluster
                final Clusterable center = cluster.getCenter();
                final Variance stat = new Variance();
                for (final T point : cluster.getPoints()) {
                    stat.increment(distance(point, center));
                }
                final double variance = stat.getResult();

                // select the cluster with the largest variance
                if (variance > maxVariance) {
                    maxVariance = variance;
                    selected = cluster;
                }
            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    }

    /**
     * Get a random point from the {@link Cluster} with the largest number of points.
     * <p>
     * The chosen point is removed from the list returned by the selected
     * cluster's {@code getPoints()}.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) {
        int maxNumber = 0;
        Cluster<T> selected = null;
        for (final Cluster<T> cluster : clusters) {

            // get the number of points of the current cluster
            final int number = cluster.getPoints().size();

            // select the cluster with the largest number of points
            if (number > maxNumber) {
                maxNumber = number;
                selected = cluster;
            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    }

    /**
     * Get the point farthest to its cluster center.
     * <p>
     * The chosen point is removed from the list returned by the selected
     * cluster's {@code getPoints()}.
     *
     * @param clusters the {@link Cluster}s to search
     * @return point farthest to its cluster center
     * @throws ConvergenceException if clusters are all empty
     */
    private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) {
        double maxDistance = Double.NEGATIVE_INFINITY;
        Cluster<T> selectedCluster = null;
        int selectedPoint = -1;
        for (final CentroidCluster<T> cluster : clusters) {

            // get the farthest point
            final Clusterable center = cluster.getCenter();
            final List<T> points = cluster.getPoints();
            for (int i = 0; i < points.size(); ++i) {
                final double distance = distance(points.get(i), center);
                if (distance > maxDistance) {
                    maxDistance = distance;
                    selectedCluster = cluster;
                    selectedPoint = i;
                }
            }
        }

        // did we find at least one non-empty cluster ?
        if (selectedCluster == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        return selectedCluster.getPoints().remove(selectedPoint);
    }

    /**
     * Returns the nearest {@link Cluster} to the given point.
     * <p>
     * Ties are resolved in favour of the cluster encountered first. Index 0 is
     * returned when no center is at a distance strictly smaller than
     * {@code Double.MAX_VALUE}.
     *
     * @param clusters the {@link Cluster}s to search
     * @param point the point to find the nearest {@link Cluster} for
     * @return the index of the nearest {@link Cluster} to the given point
     */
    private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
        double minDistance = Double.MAX_VALUE;
        int clusterIndex = 0;
        int minCluster = 0;
        for (final CentroidCluster<T> c : clusters) {
            final double distance = distance(point, c.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                minCluster = clusterIndex;
            }
            clusterIndex++;
        }
        return minCluster;
    }
}