1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ml.clustering;
19
20 import org.apache.commons.math4.legacy.exception.NullArgumentException;
21 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
22 import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
23 import org.apache.commons.math4.legacy.core.Pair;
24 import org.apache.commons.rng.UniformRandomProvider;
25 import org.apache.commons.rng.sampling.ListSampler;
26
27 import java.util.ArrayList;
28 import java.util.Collection;
29 import java.util.List;
30
31
32
33
34
35
36
37 public class MiniBatchKMeansClusterer<T extends Clusterable>
38 extends KMeansPlusPlusClusterer<T> {
39
/** Number of points randomly sampled from the data set for each mini-batch iteration. */
private final int batchSize;

/** Number of candidate sets of initial centers that are generated and compared. */
private final int initIterations;

/**
 * Number of points randomly sampled when choosing and scoring initial centers;
 * the whole data set is used instead when it is not larger than this value.
 */
private final int initBatchSize;

/** Number of consecutive iterations without improvement after which the optimization stops. */
private final int maxNoImprovementTimes;
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67 public MiniBatchKMeansClusterer(final int k,
68 final int maxIterations,
69 final int batchSize,
70 final int initIterations,
71 final int initBatchSize,
72 final int maxNoImprovementTimes,
73 final DistanceMeasure measure,
74 final UniformRandomProvider random,
75 final EmptyClusterStrategy emptyStrategy) {
76 super(k, maxIterations, measure, random, emptyStrategy);
77
78 if (batchSize < 1) {
79 throw new NumberIsTooSmallException(batchSize, 1, true);
80 }
81 if (initIterations < 1) {
82 throw new NumberIsTooSmallException(initIterations, 1, true);
83 }
84 if (initBatchSize < 1) {
85 throw new NumberIsTooSmallException(initBatchSize, 1, true);
86 }
87 if (maxNoImprovementTimes < 1) {
88 throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
89 }
90
91 this.batchSize = batchSize;
92 this.initIterations = initIterations;
93 this.initBatchSize = initBatchSize;
94 this.maxNoImprovementTimes = maxNoImprovementTimes;
95 }
96
97
98
99
100
101
102
103
104
105 @Override
106 public List<CentroidCluster<T>> cluster(final Collection<T> points) {
107
108 NullArgumentException.check(points);
109 if (points.size() < getNumberOfClusters()) {
110 throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false);
111 }
112
113 final int pointSize = points.size();
114 final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
115 final int max = getMaxIterations() < 0 ?
116 Integer.MAX_VALUE :
117 getMaxIterations() * batchCount;
118
119 final List<T> pointList = new ArrayList<>(points);
120 List<CentroidCluster<T>> clusters = initialCenters(pointList);
121
122 final ImprovementEvaluator evaluator = new ImprovementEvaluator(batchSize,
123 maxNoImprovementTimes);
124 for (int i = 0; i < max; i++) {
125 clearClustersPoints(clusters);
126 final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize);
127
128 final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
129 final double squareDistance = pair.getFirst();
130 clusters = pair.getSecond();
131
132 if (evaluator.converge(squareDistance, pointSize)) {
133 break;
134 }
135 }
136
137
138 clearClustersPoints(clusters);
139 for (final T point : points) {
140 addToNearestCentroidCluster(point, clusters);
141 }
142
143 return clusters;
144 }
145
146
147
148
149
150
151 private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
152 for (CentroidCluster<T> cluster : clusters) {
153 cluster.getPoints().clear();
154 }
155 }
156
157
158
159
160
161
162
163
164 private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
165 final List<CentroidCluster<T>> clusters) {
166
167 for (final T point : batchPoints) {
168 addToNearestCentroidCluster(point, clusters);
169 }
170 final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
171
172 double squareDistance = 0.0;
173 for (T point : batchPoints) {
174 final double d = addToNearestCentroidCluster(point, newClusters);
175 squareDistance += d * d;
176 }
177
178 return new Pair<>(squareDistance, newClusters);
179 }
180
181
182
183
184
185
186
187 private List<CentroidCluster<T>> initialCenters(final List<T> points) {
188 final List<T> validPoints = initBatchSize < points.size() ?
189 ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
190 new ArrayList<>(points);
191 double nearestSquareDistance = Double.POSITIVE_INFINITY;
192 List<CentroidCluster<T>> bestCenters = null;
193
194 for (int i = 0; i < initIterations; i++) {
195 final List<T> initialPoints = (initBatchSize < points.size()) ?
196 ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
197 new ArrayList<>(points);
198 final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints);
199 final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
200 final double squareDistance = pair.getFirst();
201 final List<CentroidCluster<T>> newClusters = pair.getSecond();
202
203 if (squareDistance < nearestSquareDistance) {
204 nearestSquareDistance = squareDistance;
205 bestCenters = newClusters;
206 }
207 }
208 return bestCenters;
209 }
210
211
212
213
214
215
216
217
218 private double addToNearestCentroidCluster(final T point,
219 final List<CentroidCluster<T>> clusters) {
220 double minDistance = Double.POSITIVE_INFINITY;
221 CentroidCluster<T> closestCentroidCluster = null;
222
223
224 for (CentroidCluster<T> centroidCluster : clusters) {
225 final double distance = distance(point, centroidCluster.getCenter());
226 if (distance < minDistance) {
227 minDistance = distance;
228 closestCentroidCluster = centroidCluster;
229 }
230 }
231 NullArgumentException.check(closestCentroidCluster);
232 closestCentroidCluster.addPoint(point);
233
234 return minDistance;
235 }
236
237
238
239
240
241
242 private static final class ImprovementEvaluator {
243
244 private final int batchSize;
245
246 private final int maxNoImprovementTimes;
247
248
249
250
251
252
253 private double ewaInertia = Double.NaN;
254
255 private double ewaInertiaMin = Double.POSITIVE_INFINITY;
256
257 private int noImprovementTimes;
258
259
260
261
262
263
264 private ImprovementEvaluator(int batchSize,
265 int maxNoImprovementTimes) {
266 this.batchSize = batchSize;
267 this.maxNoImprovementTimes = maxNoImprovementTimes;
268 }
269
270
271
272
273
274
275
276
277
278
279 public boolean converge(final double squareDistance,
280 final int pointSize) {
281 final double batchInertia = squareDistance / batchSize;
282 if (Double.isNaN(ewaInertia)) {
283 ewaInertia = batchInertia;
284 } else {
285 final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
286 ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
287 }
288
289 if (ewaInertia < ewaInertiaMin) {
290
291 noImprovementTimes = 0;
292 ewaInertiaMin = ewaInertia;
293 } else {
294
295 ++noImprovementTimes;
296 }
297
298 return noImprovementTimes >= maxNoImprovementTimes;
299 }
300 }
301 }