1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ml.clustering;
19
20 import org.apache.commons.math4.legacy.exception.NullArgumentException;
21 import org.apache.commons.math4.legacy.exception.ConvergenceException;
22 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
23 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
24 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
25 import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
26 import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
27 import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
28 import org.apache.commons.rng.UniformRandomProvider;
29 import org.apache.commons.rng.simple.RandomSource;
30
31 import java.util.ArrayList;
32 import java.util.Collection;
33 import java.util.Collections;
34 import java.util.List;
35
36
37
38
39
40
41
42 public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
43
44
45 public enum EmptyClusterStrategy {
46
47
48 LARGEST_VARIANCE,
49
50
51 LARGEST_POINTS_NUMBER,
52
53
54 FARTHEST_POINT,
55
56
57 ERROR
58 }
59
60
61 private final int numberOfClusters;
62
63
64 private final int maxIterations;
65
66
67 private final UniformRandomProvider random;
68
69
70 private final EmptyClusterStrategy emptyStrategy;
71
72
73
74
75
76
77
78
79
80
81 public KMeansPlusPlusClusterer(final int k) {
82 this(k, Integer.MAX_VALUE);
83 }
84
85
86
87
88
89
90
91
92
93
94
95
96 public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
97 this(k, maxIterations, new EuclideanDistance());
98 }
99
100
101
102
103
104
105
106
107
108
109
110 public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
111 this(k, maxIterations, measure, RandomSource.MT_64.create());
112 }
113
114
115
116
117
118
119
120
121
122
123
124
125 public KMeansPlusPlusClusterer(final int k, final int maxIterations,
126 final DistanceMeasure measure,
127 final UniformRandomProvider random) {
128 this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
129 }
130
131
132
133
134
135
136
137
138
139
140
141
142 public KMeansPlusPlusClusterer(final int k,
143 final int maxIterations,
144 final DistanceMeasure measure,
145 final UniformRandomProvider random,
146 final EmptyClusterStrategy emptyStrategy) {
147 super(measure);
148
149 if (k <= 0) {
150 throw new NotStrictlyPositiveException(k);
151 }
152 if (maxIterations <= 0) {
153 throw new NotStrictlyPositiveException(maxIterations);
154 }
155
156 this.numberOfClusters = k;
157 this.maxIterations = maxIterations;
158 this.random = random;
159 this.emptyStrategy = emptyStrategy;
160 }
161
162
163
164
165
166 public int getNumberOfClusters() {
167 return numberOfClusters;
168 }
169
170
171
172
173
174 public int getMaxIterations() {
175 return maxIterations;
176 }
177
178
179
180
181
182
183
184
185
186
187
188
189 @Override
190 public List<CentroidCluster<T>> cluster(final Collection<T> points) {
191
192 NullArgumentException.check(points);
193
194
195 if (points.size() < numberOfClusters) {
196 throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
197 }
198
199
200 List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
201
202
203
204 int[] assignments = new int[points.size()];
205 assignPointsToClusters(clusters, points, assignments);
206
207
208 for (int count = 0; count < maxIterations; count++) {
209 boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
210 List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
211 int changes = assignPointsToClusters(newClusters, points, assignments);
212 clusters = newClusters;
213
214
215
216 if (changes == 0 && !hasEmptyCluster) {
217 return clusters;
218 }
219 }
220 return clusters;
221 }
222
223
224
225
226 UniformRandomProvider getRandomGenerator() {
227 return random;
228 }
229
230
231
232
233 EmptyClusterStrategy getEmptyClusterStrategy() {
234 return emptyStrategy;
235 }
236
237
238
239
240
241
242 List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) {
243 List<CentroidCluster<T>> newClusters = new ArrayList<>();
244 for (final CentroidCluster<T> cluster : clusters) {
245 final Clusterable newCenter;
246 if (cluster.getPoints().isEmpty()) {
247 switch (emptyStrategy) {
248 case LARGEST_VARIANCE :
249 newCenter = getPointFromLargestVarianceCluster(clusters);
250 break;
251 case LARGEST_POINTS_NUMBER :
252 newCenter = getPointFromLargestNumberCluster(clusters);
253 break;
254 case FARTHEST_POINT :
255 newCenter = getFarthestPoint(clusters);
256 break;
257 default :
258 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
259 }
260 } else {
261 newCenter = cluster.centroid();
262 }
263 newClusters.add(new CentroidCluster<>(newCenter));
264 }
265 return newClusters;
266 }
267
268
269
270
271
272
273
274
275
276 private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
277 final Collection<T> points,
278 final int[] assignments) {
279 int assignedDifferently = 0;
280 int pointIndex = 0;
281 for (final T p : points) {
282 int clusterIndex = getNearestCluster(clusters, p);
283 if (clusterIndex != assignments[pointIndex]) {
284 assignedDifferently++;
285 }
286
287 CentroidCluster<T> cluster = clusters.get(clusterIndex);
288 cluster.addPoint(p);
289 assignments[pointIndex++] = clusterIndex;
290 }
291
292 return assignedDifferently;
293 }
294
295
296
297
298
299
300
301 List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
302
303
304
305 final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));
306
307
308 final int numPoints = pointList.size();
309
310
311
312 final boolean[] taken = new boolean[numPoints];
313
314
315 final List<CentroidCluster<T>> resultSet = new ArrayList<>();
316
317
318 final int firstPointIndex = random.nextInt(numPoints);
319
320 final T firstPoint = pointList.get(firstPointIndex);
321
322 resultSet.add(new CentroidCluster<>(firstPoint));
323
324
325 taken[firstPointIndex] = true;
326
327
328
329 final double[] minDistSquared = new double[numPoints];
330
331
332
333 for (int i = 0; i < numPoints; i++) {
334 if (i != firstPointIndex) {
335 double d = distance(firstPoint, pointList.get(i));
336 minDistSquared[i] = d*d;
337 }
338 }
339
340 while (resultSet.size() < numberOfClusters) {
341
342
343
344 double distSqSum = 0.0;
345
346 for (int i = 0; i < numPoints; i++) {
347 if (!taken[i]) {
348 distSqSum += minDistSquared[i];
349 }
350 }
351
352
353
354 final double r = random.nextDouble() * distSqSum;
355
356
357 int nextPointIndex = -1;
358
359
360
361 double sum = 0.0;
362 for (int i = 0; i < numPoints; i++) {
363 if (!taken[i]) {
364 sum += minDistSquared[i];
365 if (sum >= r) {
366 nextPointIndex = i;
367 break;
368 }
369 }
370 }
371
372
373
374
375 if (nextPointIndex == -1) {
376 for (int i = numPoints - 1; i >= 0; i--) {
377 if (!taken[i]) {
378 nextPointIndex = i;
379 break;
380 }
381 }
382 }
383
384
385 if (nextPointIndex >= 0) {
386
387 final T p = pointList.get(nextPointIndex);
388
389 resultSet.add(new CentroidCluster<T> (p));
390
391
392 taken[nextPointIndex] = true;
393
394 if (resultSet.size() < numberOfClusters) {
395
396
397 for (int j = 0; j < numPoints; j++) {
398
399 if (!taken[j]) {
400 double d = distance(p, pointList.get(j));
401 double d2 = d * d;
402 if (d2 < minDistSquared[j]) {
403 minDistSquared[j] = d2;
404 }
405 }
406 }
407 }
408 } else {
409
410
411
412 break;
413 }
414 }
415
416 return resultSet;
417 }
418
419
420
421
422
423
424
425
426 private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) {
427 double maxVariance = Double.NEGATIVE_INFINITY;
428 Cluster<T> selected = null;
429 for (final CentroidCluster<T> cluster : clusters) {
430 if (!cluster.getPoints().isEmpty()) {
431
432
433 final Clusterable center = cluster.getCenter();
434 final Variance stat = new Variance();
435 for (final T point : cluster.getPoints()) {
436 stat.increment(distance(point, center));
437 }
438 final double variance = stat.getResult();
439
440
441 if (variance > maxVariance) {
442 maxVariance = variance;
443 selected = cluster;
444 }
445 }
446 }
447
448
449 if (selected == null) {
450 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
451 }
452
453
454 final List<T> selectedPoints = selected.getPoints();
455 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
456 }
457
458
459
460
461
462
463
464
465 private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) {
466 int maxNumber = 0;
467 Cluster<T> selected = null;
468 for (final Cluster<T> cluster : clusters) {
469
470
471 final int number = cluster.getPoints().size();
472
473
474 if (number > maxNumber) {
475 maxNumber = number;
476 selected = cluster;
477 }
478 }
479
480
481 if (selected == null) {
482 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
483 }
484
485
486 final List<T> selectedPoints = selected.getPoints();
487 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
488 }
489
490
491
492
493
494
495
496
497 private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) {
498 double maxDistance = Double.NEGATIVE_INFINITY;
499 Cluster<T> selectedCluster = null;
500 int selectedPoint = -1;
501 for (final CentroidCluster<T> cluster : clusters) {
502
503
504 final Clusterable center = cluster.getCenter();
505 final List<T> points = cluster.getPoints();
506 for (int i = 0; i < points.size(); ++i) {
507 final double distance = distance(points.get(i), center);
508 if (distance > maxDistance) {
509 maxDistance = distance;
510 selectedCluster = cluster;
511 selectedPoint = i;
512 }
513 }
514 }
515
516
517 if (selectedCluster == null) {
518 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
519 }
520
521 return selectedCluster.getPoints().remove(selectedPoint);
522 }
523
524
525
526
527
528
529
530
531 private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
532 double minDistance = Double.MAX_VALUE;
533 int clusterIndex = 0;
534 int minCluster = 0;
535 for (final CentroidCluster<T> c : clusters) {
536 final double distance = distance(point, c.getCenter());
537 if (distance < minDistance) {
538 minDistance = distance;
539 minCluster = clusterIndex;
540 }
541 clusterIndex++;
542 }
543 return minCluster;
544 }
545 }