1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.math4.legacy.ml.clustering;
19
20 import org.apache.commons.math4.legacy.exception.NullArgumentException;
21 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
22 import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
23 import org.apache.commons.math4.legacy.core.Pair;
24 import org.apache.commons.rng.UniformRandomProvider;
25 import org.apache.commons.rng.sampling.ListSampler;
26
27 import java.util.ArrayList;
28 import java.util.Collection;
29 import java.util.List;
30
31 /**
32 * Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf">
33 * based on KMeans</a>.
34 *
35 * @param <T> Type of the points to cluster.
36 */
37 public class MiniBatchKMeansClusterer<T extends Clusterable>
38 extends KMeansPlusPlusClusterer<T> {
39 /** Batch data size in iteration. */
40 private final int batchSize;
41 /** Iteration count of initialize the centers. */
42 private final int initIterations;
43 /** Data size of batch to initialize the centers. */
44 private final int initBatchSize;
45 /** Maximum number of iterations during which no improvement is occuring. */
46 private final int maxNoImprovementTimes;
47
48
49 /**
50 * Build a clusterer.
51 *
52 * @param k Number of clusters to split the data into.
53 * @param maxIterations Maximum number of iterations to run the algorithm for all the points,
54 * The actual number of iterationswill be smaller than {@code maxIterations * size / batchSize},
55 * where {@code size} is the number of points to cluster.
56 * Disabled if negative.
57 * @param batchSize Batch size for training iterations.
58 * @param initIterations Number of iterations allowed in order to find out the best initial centers.
59 * @param initBatchSize Batch size for initializing the clusters centers.
60 * A value of {@code 3 * batchSize} should be suitable in most cases.
61 * @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occuring.
62 * A value of 10 is suitable in most cases.
63 * @param measure Distance measure.
64 * @param random Random generator.
65 * @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
66 */
67 public MiniBatchKMeansClusterer(final int k,
68 final int maxIterations,
69 final int batchSize,
70 final int initIterations,
71 final int initBatchSize,
72 final int maxNoImprovementTimes,
73 final DistanceMeasure measure,
74 final UniformRandomProvider random,
75 final EmptyClusterStrategy emptyStrategy) {
76 super(k, maxIterations, measure, random, emptyStrategy);
77
78 if (batchSize < 1) {
79 throw new NumberIsTooSmallException(batchSize, 1, true);
80 }
81 if (initIterations < 1) {
82 throw new NumberIsTooSmallException(initIterations, 1, true);
83 }
84 if (initBatchSize < 1) {
85 throw new NumberIsTooSmallException(initBatchSize, 1, true);
86 }
87 if (maxNoImprovementTimes < 1) {
88 throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
89 }
90
91 this.batchSize = batchSize;
92 this.initIterations = initIterations;
93 this.initBatchSize = initBatchSize;
94 this.maxNoImprovementTimes = maxNoImprovementTimes;
95 }
96
97 /**
98 * Runs the MiniBatch K-means clustering algorithm.
99 *
100 * @param points Points to cluster (cannot be {@code null}).
101 * @return the clusters.
102 * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
103 * if the number of points is smaller than the number of clusters.
104 */
105 @Override
106 public List<CentroidCluster<T>> cluster(final Collection<T> points) {
107 // Sanity check.
108 NullArgumentException.check(points);
109 if (points.size() < getNumberOfClusters()) {
110 throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false);
111 }
112
113 final int pointSize = points.size();
114 final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
115 final int max = getMaxIterations() < 0 ?
116 Integer.MAX_VALUE :
117 getMaxIterations() * batchCount;
118
119 final List<T> pointList = new ArrayList<>(points);
120 List<CentroidCluster<T>> clusters = initialCenters(pointList);
121
122 final ImprovementEvaluator evaluator = new ImprovementEvaluator(batchSize,
123 maxNoImprovementTimes);
124 for (int i = 0; i < max; i++) {
125 clearClustersPoints(clusters);
126 final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize);
127 // Training step.
128 final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
129 final double squareDistance = pair.getFirst();
130 clusters = pair.getSecond();
131 // Check whether the training can finished early.
132 if (evaluator.converge(squareDistance, pointSize)) {
133 break;
134 }
135 }
136
137 // Add every mini batch points to their nearest cluster.
138 clearClustersPoints(clusters);
139 for (final T point : points) {
140 addToNearestCentroidCluster(point, clusters);
141 }
142
143 return clusters;
144 }
145
146 /**
147 * Helper method.
148 *
149 * @param clusters Clusters to clear.
150 */
151 private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
152 for (CentroidCluster<T> cluster : clusters) {
153 cluster.getPoints().clear();
154 }
155 }
156
157 /**
158 * Mini batch iteration step.
159 *
160 * @param batchPoints Points selected for this batch.
161 * @param clusters Centers of the clusters.
162 * @return the squared distance of all the batch points to the nearest center.
163 */
164 private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
165 final List<CentroidCluster<T>> clusters) {
166 // Add every mini batch points to their nearest cluster.
167 for (final T point : batchPoints) {
168 addToNearestCentroidCluster(point, clusters);
169 }
170 final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
171 // Add every mini batch points to their nearest cluster again.
172 double squareDistance = 0.0;
173 for (T point : batchPoints) {
174 final double d = addToNearestCentroidCluster(point, newClusters);
175 squareDistance += d * d;
176 }
177
178 return new Pair<>(squareDistance, newClusters);
179 }
180
181 /**
182 * Initializes the clusters centers.
183 *
184 * @param points Points used to initialize the centers.
185 * @return clusters with their center initialized.
186 */
187 private List<CentroidCluster<T>> initialCenters(final List<T> points) {
188 final List<T> validPoints = initBatchSize < points.size() ?
189 ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
190 new ArrayList<>(points);
191 double nearestSquareDistance = Double.POSITIVE_INFINITY;
192 List<CentroidCluster<T>> bestCenters = null;
193
194 for (int i = 0; i < initIterations; i++) {
195 final List<T> initialPoints = (initBatchSize < points.size()) ?
196 ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
197 new ArrayList<>(points);
198 final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints);
199 final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
200 final double squareDistance = pair.getFirst();
201 final List<CentroidCluster<T>> newClusters = pair.getSecond();
202 //Find out a best centers that has the nearest total square distance.
203 if (squareDistance < nearestSquareDistance) {
204 nearestSquareDistance = squareDistance;
205 bestCenters = newClusters;
206 }
207 }
208 return bestCenters;
209 }
210
211 /**
212 * Adds a point to the cluster whose center is closest.
213 *
214 * @param point Point to add.
215 * @param clusters Clusters.
216 * @return the distance between point and the closest center.
217 */
218 private double addToNearestCentroidCluster(final T point,
219 final List<CentroidCluster<T>> clusters) {
220 double minDistance = Double.POSITIVE_INFINITY;
221 CentroidCluster<T> closestCentroidCluster = null;
222
223 // Find cluster closest to the point.
224 for (CentroidCluster<T> centroidCluster : clusters) {
225 final double distance = distance(point, centroidCluster.getCenter());
226 if (distance < minDistance) {
227 minDistance = distance;
228 closestCentroidCluster = centroidCluster;
229 }
230 }
231 NullArgumentException.check(closestCentroidCluster);
232 closestCentroidCluster.addPoint(point);
233
234 return minDistance;
235 }
236
237 /**
238 * Stopping criterion.
239 * The evaluator checks whether improvement occurred during the
240 * {@link #maxNoImprovementTimes allowed number of successive iterations}.
241 */
242 private static final class ImprovementEvaluator {
243 /** Batch size. */
244 private final int batchSize;
245 /** Maximum number of iterations during which no improvement is occuring. */
246 private final int maxNoImprovementTimes;
247 /**
248 * <a href="https://en.wikipedia.org/wiki/Moving_average">
249 * Exponentially Weighted Average</a> of the squared
250 * diff to monitor the convergence while discarding
251 * minibatch-local stochastic variability.
252 */
253 private double ewaInertia = Double.NaN;
254 /** Minimum value of {@link #ewaInertia} during iteration. */
255 private double ewaInertiaMin = Double.POSITIVE_INFINITY;
256 /** Number of iteration during which {@link #ewaInertia} did not improve. */
257 private int noImprovementTimes;
258
259 /**
260 * @param batchSize Number of elements for each batch iteration.
261 * @param maxNoImprovementTimes Maximum number of iterations during
262 * which no improvement is occuring.
263 */
264 private ImprovementEvaluator(int batchSize,
265 int maxNoImprovementTimes) {
266 this.batchSize = batchSize;
267 this.maxNoImprovementTimes = maxNoImprovementTimes;
268 }
269
270 /**
271 * Stopping criterion.
272 *
273 * @param squareDistance Total square distance from the batch points
274 * to their nearest center.
275 * @param pointSize Number of data points.
276 * @return {@code true} if no improvement was made after the allowed
277 * number of iterations, {@code false} otherwise.
278 */
279 public boolean converge(final double squareDistance,
280 final int pointSize) {
281 final double batchInertia = squareDistance / batchSize;
282 if (Double.isNaN(ewaInertia)) {
283 ewaInertia = batchInertia;
284 } else {
285 final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
286 ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
287 }
288
289 if (ewaInertia < ewaInertiaMin) {
290 // Improved.
291 noImprovementTimes = 0;
292 ewaInertiaMin = ewaInertia;
293 } else {
294 // No improvement.
295 ++noImprovementTimes;
296 }
297
298 return noImprovementTimes >= maxNoImprovementTimes;
299 }
300 }
301 }