001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math4.legacy.ml.clustering; 019 020import org.apache.commons.math4.legacy.exception.NullArgumentException; 021import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; 022import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure; 023import org.apache.commons.math4.legacy.core.Pair; 024import org.apache.commons.rng.UniformRandomProvider; 025import org.apache.commons.rng.sampling.ListSampler; 026 027import java.util.ArrayList; 028import java.util.Collection; 029import java.util.List; 030 031/** 032 * Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf"> 033 * based on KMeans</a>. 034 * 035 * @param <T> Type of the points to cluster. 036 */ 037public class MiniBatchKMeansClusterer<T extends Clusterable> 038 extends KMeansPlusPlusClusterer<T> { 039 /** Batch data size in iteration. */ 040 private final int batchSize; 041 /** Iteration count of initialize the centers. */ 042 private final int initIterations; 043 /** Data size of batch to initialize the centers. */ 044 private final int initBatchSize; 045 /** Maximum number of iterations during which no improvement is occuring. */ 046 private final int maxNoImprovementTimes; 047 048 049 /** 050 * Build a clusterer. 051 * 052 * @param k Number of clusters to split the data into. 053 * @param maxIterations Maximum number of iterations to run the algorithm for all the points, 054 * The actual number of iterationswill be smaller than {@code maxIterations * size / batchSize}, 055 * where {@code size} is the number of points to cluster. 056 * Disabled if negative. 057 * @param batchSize Batch size for training iterations. 058 * @param initIterations Number of iterations allowed in order to find out the best initial centers. 059 * @param initBatchSize Batch size for initializing the clusters centers. 060 * A value of {@code 3 * batchSize} should be suitable in most cases. 061 * @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occuring. 062 * A value of 10 is suitable in most cases. 063 * @param measure Distance measure. 064 * @param random Random generator. 065 * @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations. 066 */ 067 public MiniBatchKMeansClusterer(final int k, 068 final int maxIterations, 069 final int batchSize, 070 final int initIterations, 071 final int initBatchSize, 072 final int maxNoImprovementTimes, 073 final DistanceMeasure measure, 074 final UniformRandomProvider random, 075 final EmptyClusterStrategy emptyStrategy) { 076 super(k, maxIterations, measure, random, emptyStrategy); 077 078 if (batchSize < 1) { 079 throw new NumberIsTooSmallException(batchSize, 1, true); 080 } 081 if (initIterations < 1) { 082 throw new NumberIsTooSmallException(initIterations, 1, true); 083 } 084 if (initBatchSize < 1) { 085 throw new NumberIsTooSmallException(initBatchSize, 1, true); 086 } 087 if (maxNoImprovementTimes < 1) { 088 throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true); 089 } 090 091 this.batchSize = batchSize; 092 this.initIterations = initIterations; 093 this.initBatchSize = initBatchSize; 094 this.maxNoImprovementTimes = maxNoImprovementTimes; 095 } 096 097 /** 098 * Runs the MiniBatch K-means clustering algorithm. 099 * 100 * @param points Points to cluster (cannot be {@code null}). 101 * @return the clusters. 102 * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException 103 * if the number of points is smaller than the number of clusters. 104 */ 105 @Override 106 public List<CentroidCluster<T>> cluster(final Collection<T> points) { 107 // Sanity check. 108 NullArgumentException.check(points); 109 if (points.size() < getNumberOfClusters()) { 110 throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false); 111 } 112 113 final int pointSize = points.size(); 114 final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0); 115 final int max = getMaxIterations() < 0 ? 116 Integer.MAX_VALUE : 117 getMaxIterations() * batchCount; 118 119 final List<T> pointList = new ArrayList<>(points); 120 List<CentroidCluster<T>> clusters = initialCenters(pointList); 121 122 final ImprovementEvaluator evaluator = new ImprovementEvaluator(batchSize, 123 maxNoImprovementTimes); 124 for (int i = 0; i < max; i++) { 125 clearClustersPoints(clusters); 126 final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize); 127 // Training step. 128 final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters); 129 final double squareDistance = pair.getFirst(); 130 clusters = pair.getSecond(); 131 // Check whether the training can finished early. 132 if (evaluator.converge(squareDistance, pointSize)) { 133 break; 134 } 135 } 136 137 // Add every mini batch points to their nearest cluster. 138 clearClustersPoints(clusters); 139 for (final T point : points) { 140 addToNearestCentroidCluster(point, clusters); 141 } 142 143 return clusters; 144 } 145 146 /** 147 * Helper method. 148 * 149 * @param clusters Clusters to clear. 150 */ 151 private void clearClustersPoints(final List<CentroidCluster<T>> clusters) { 152 for (CentroidCluster<T> cluster : clusters) { 153 cluster.getPoints().clear(); 154 } 155 } 156 157 /** 158 * Mini batch iteration step. 159 * 160 * @param batchPoints Points selected for this batch. 161 * @param clusters Centers of the clusters. 162 * @return the squared distance of all the batch points to the nearest center. 163 */ 164 private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints, 165 final List<CentroidCluster<T>> clusters) { 166 // Add every mini batch points to their nearest cluster. 167 for (final T point : batchPoints) { 168 addToNearestCentroidCluster(point, clusters); 169 } 170 final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters); 171 // Add every mini batch points to their nearest cluster again. 172 double squareDistance = 0.0; 173 for (T point : batchPoints) { 174 final double d = addToNearestCentroidCluster(point, newClusters); 175 squareDistance += d * d; 176 } 177 178 return new Pair<>(squareDistance, newClusters); 179 } 180 181 /** 182 * Initializes the clusters centers. 183 * 184 * @param points Points used to initialize the centers. 185 * @return clusters with their center initialized. 186 */ 187 private List<CentroidCluster<T>> initialCenters(final List<T> points) { 188 final List<T> validPoints = initBatchSize < points.size() ? 189 ListSampler.sample(getRandomGenerator(), points, initBatchSize) : 190 new ArrayList<>(points); 191 double nearestSquareDistance = Double.POSITIVE_INFINITY; 192 List<CentroidCluster<T>> bestCenters = null; 193 194 for (int i = 0; i < initIterations; i++) { 195 final List<T> initialPoints = (initBatchSize < points.size()) ? 196 ListSampler.sample(getRandomGenerator(), points, initBatchSize) : 197 new ArrayList<>(points); 198 final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints); 199 final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters); 200 final double squareDistance = pair.getFirst(); 201 final List<CentroidCluster<T>> newClusters = pair.getSecond(); 202 //Find out a best centers that has the nearest total square distance. 203 if (squareDistance < nearestSquareDistance) { 204 nearestSquareDistance = squareDistance; 205 bestCenters = newClusters; 206 } 207 } 208 return bestCenters; 209 } 210 211 /** 212 * Adds a point to the cluster whose center is closest. 213 * 214 * @param point Point to add. 215 * @param clusters Clusters. 216 * @return the distance between point and the closest center. 217 */ 218 private double addToNearestCentroidCluster(final T point, 219 final List<CentroidCluster<T>> clusters) { 220 double minDistance = Double.POSITIVE_INFINITY; 221 CentroidCluster<T> closestCentroidCluster = null; 222 223 // Find cluster closest to the point. 224 for (CentroidCluster<T> centroidCluster : clusters) { 225 final double distance = distance(point, centroidCluster.getCenter()); 226 if (distance < minDistance) { 227 minDistance = distance; 228 closestCentroidCluster = centroidCluster; 229 } 230 } 231 NullArgumentException.check(closestCentroidCluster); 232 closestCentroidCluster.addPoint(point); 233 234 return minDistance; 235 } 236 237 /** 238 * Stopping criterion. 239 * The evaluator checks whether improvement occurred during the 240 * {@link #maxNoImprovementTimes allowed number of successive iterations}. 241 */ 242 private static final class ImprovementEvaluator { 243 /** Batch size. */ 244 private final int batchSize; 245 /** Maximum number of iterations during which no improvement is occuring. */ 246 private final int maxNoImprovementTimes; 247 /** 248 * <a href="https://en.wikipedia.org/wiki/Moving_average"> 249 * Exponentially Weighted Average</a> of the squared 250 * diff to monitor the convergence while discarding 251 * minibatch-local stochastic variability. 252 */ 253 private double ewaInertia = Double.NaN; 254 /** Minimum value of {@link #ewaInertia} during iteration. */ 255 private double ewaInertiaMin = Double.POSITIVE_INFINITY; 256 /** Number of iteration during which {@link #ewaInertia} did not improve. */ 257 private int noImprovementTimes; 258 259 /** 260 * @param batchSize Number of elements for each batch iteration. 261 * @param maxNoImprovementTimes Maximum number of iterations during 262 * which no improvement is occuring. 263 */ 264 private ImprovementEvaluator(int batchSize, 265 int maxNoImprovementTimes) { 266 this.batchSize = batchSize; 267 this.maxNoImprovementTimes = maxNoImprovementTimes; 268 } 269 270 /** 271 * Stopping criterion. 272 * 273 * @param squareDistance Total square distance from the batch points 274 * to their nearest center. 275 * @param pointSize Number of data points. 276 * @return {@code true} if no improvement was made after the allowed 277 * number of iterations, {@code false} otherwise. 278 */ 279 public boolean converge(final double squareDistance, 280 final int pointSize) { 281 final double batchInertia = squareDistance / batchSize; 282 if (Double.isNaN(ewaInertia)) { 283 ewaInertia = batchInertia; 284 } else { 285 final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1); 286 ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha; 287 } 288 289 if (ewaInertia < ewaInertiaMin) { 290 // Improved. 291 noImprovementTimes = 0; 292 ewaInertiaMin = ewaInertia; 293 } else { 294 // No improvement. 295 ++noImprovementTimes; 296 } 297 298 return noImprovementTimes >= maxNoImprovementTimes; 299 } 300 } 301}