001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.legacy.ml.clustering;
019
020import org.apache.commons.math4.legacy.exception.NullArgumentException;
021import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
022import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
023import org.apache.commons.math4.legacy.core.Pair;
024import org.apache.commons.rng.UniformRandomProvider;
025import org.apache.commons.rng.sampling.ListSampler;
026
027import java.util.ArrayList;
028import java.util.Collection;
029import java.util.List;
030
031/**
032 * Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf">
033 * based on KMeans</a>.
034 *
035 * @param <T> Type of the points to cluster.
036 */
037public class MiniBatchKMeansClusterer<T extends Clusterable>
038    extends KMeansPlusPlusClusterer<T> {
039    /** Batch data size in iteration. */
040    private final int batchSize;
041    /** Iteration count of initialize the centers. */
042    private final int initIterations;
043    /** Data size of batch to initialize the centers. */
044    private final int initBatchSize;
045    /** Maximum number of iterations during which no improvement is occuring. */
046    private final int maxNoImprovementTimes;
047
048
049    /**
050     * Build a clusterer.
051     *
052     * @param k Number of clusters to split the data into.
053     * @param maxIterations Maximum number of iterations to run the algorithm for all the points,
054     * The actual number of iterationswill be smaller than {@code maxIterations * size / batchSize},
055     * where {@code size} is the number of points to cluster.
056     * Disabled if negative.
057     * @param batchSize Batch size for training iterations.
058     * @param initIterations Number of iterations allowed in order to find out the best initial centers.
059     * @param initBatchSize Batch size for initializing the clusters centers.
060     * A value of {@code 3 * batchSize} should be suitable in most cases.
061     * @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occuring.
062     * A value of 10 is suitable in most cases.
063     * @param measure Distance measure.
064     * @param random Random generator.
065     * @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
066     */
067    public MiniBatchKMeansClusterer(final int k,
068                                    final int maxIterations,
069                                    final int batchSize,
070                                    final int initIterations,
071                                    final int initBatchSize,
072                                    final int maxNoImprovementTimes,
073                                    final DistanceMeasure measure,
074                                    final UniformRandomProvider random,
075                                    final EmptyClusterStrategy emptyStrategy) {
076        super(k, maxIterations, measure, random, emptyStrategy);
077
078        if (batchSize < 1) {
079            throw new NumberIsTooSmallException(batchSize, 1, true);
080        }
081        if (initIterations < 1) {
082            throw new NumberIsTooSmallException(initIterations, 1, true);
083        }
084        if (initBatchSize < 1) {
085            throw new NumberIsTooSmallException(initBatchSize, 1, true);
086        }
087        if (maxNoImprovementTimes < 1) {
088            throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
089        }
090
091        this.batchSize = batchSize;
092        this.initIterations = initIterations;
093        this.initBatchSize = initBatchSize;
094        this.maxNoImprovementTimes = maxNoImprovementTimes;
095    }
096
097    /**
098     * Runs the MiniBatch K-means clustering algorithm.
099     *
100     * @param points Points to cluster (cannot be {@code null}).
101     * @return the clusters.
102     * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
103     * if the number of points is smaller than the number of clusters.
104     */
105    @Override
106    public List<CentroidCluster<T>> cluster(final Collection<T> points) {
107        // Sanity check.
108        NullArgumentException.check(points);
109        if (points.size() < getNumberOfClusters()) {
110            throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false);
111        }
112
113        final int pointSize = points.size();
114        final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
115        final int max = getMaxIterations() < 0 ?
116            Integer.MAX_VALUE :
117            getMaxIterations() * batchCount;
118
119        final List<T> pointList = new ArrayList<>(points);
120        List<CentroidCluster<T>> clusters = initialCenters(pointList);
121
122        final ImprovementEvaluator evaluator = new ImprovementEvaluator(batchSize,
123                                                                        maxNoImprovementTimes);
124        for (int i = 0; i < max; i++) {
125            clearClustersPoints(clusters);
126            final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize);
127            // Training step.
128            final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
129            final double squareDistance = pair.getFirst();
130            clusters = pair.getSecond();
131            // Check whether the training can finished early.
132            if (evaluator.converge(squareDistance, pointSize)) {
133                break;
134            }
135        }
136
137        // Add every mini batch points to their nearest cluster.
138        clearClustersPoints(clusters);
139        for (final T point : points) {
140            addToNearestCentroidCluster(point, clusters);
141        }
142
143        return clusters;
144    }
145
146    /**
147     * Helper method.
148     *
149     * @param clusters Clusters to clear.
150     */
151    private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
152        for (CentroidCluster<T> cluster : clusters) {
153            cluster.getPoints().clear();
154        }
155    }
156
157    /**
158     * Mini batch iteration step.
159     *
160     * @param batchPoints Points selected for this batch.
161     * @param clusters Centers of the clusters.
162     * @return the squared distance of all the batch points to the nearest center.
163     */
164    private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
165                                                        final List<CentroidCluster<T>> clusters) {
166        // Add every mini batch points to their nearest cluster.
167        for (final T point : batchPoints) {
168            addToNearestCentroidCluster(point, clusters);
169        }
170        final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
171        // Add every mini batch points to their nearest cluster again.
172        double squareDistance = 0.0;
173        for (T point : batchPoints) {
174            final double d = addToNearestCentroidCluster(point, newClusters);
175            squareDistance += d * d;
176        }
177
178        return new Pair<>(squareDistance, newClusters);
179    }
180
181    /**
182     * Initializes the clusters centers.
183     *
184     * @param points Points used to initialize the centers.
185     * @return clusters with their center initialized.
186     */
187    private List<CentroidCluster<T>> initialCenters(final List<T> points) {
188        final List<T> validPoints = initBatchSize < points.size() ?
189            ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
190            new ArrayList<>(points);
191        double nearestSquareDistance = Double.POSITIVE_INFINITY;
192        List<CentroidCluster<T>> bestCenters = null;
193
194        for (int i = 0; i < initIterations; i++) {
195            final List<T> initialPoints = (initBatchSize < points.size()) ?
196                ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
197                new ArrayList<>(points);
198            final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints);
199            final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
200            final double squareDistance = pair.getFirst();
201            final List<CentroidCluster<T>> newClusters = pair.getSecond();
202            //Find out a best centers that has the nearest total square distance.
203            if (squareDistance < nearestSquareDistance) {
204                nearestSquareDistance = squareDistance;
205                bestCenters = newClusters;
206            }
207        }
208        return bestCenters;
209    }
210
211    /**
212     * Adds a point to the cluster whose center is closest.
213     *
214     * @param point Point to add.
215     * @param clusters Clusters.
216     * @return the distance between point and the closest center.
217     */
218    private double addToNearestCentroidCluster(final T point,
219                                               final List<CentroidCluster<T>> clusters) {
220        double minDistance = Double.POSITIVE_INFINITY;
221        CentroidCluster<T> closestCentroidCluster = null;
222
223        // Find cluster closest to the point.
224        for (CentroidCluster<T> centroidCluster : clusters) {
225            final double distance = distance(point, centroidCluster.getCenter());
226            if (distance < minDistance) {
227                minDistance = distance;
228                closestCentroidCluster = centroidCluster;
229            }
230        }
231        NullArgumentException.check(closestCentroidCluster);
232        closestCentroidCluster.addPoint(point);
233
234        return minDistance;
235    }
236
237    /**
238     * Stopping criterion.
239     * The evaluator checks whether improvement occurred during the
240     * {@link #maxNoImprovementTimes allowed number of successive iterations}.
241     */
242    private static final class ImprovementEvaluator {
243        /** Batch size. */
244        private final int batchSize;
245        /** Maximum number of iterations during which no improvement is occuring. */
246        private final int maxNoImprovementTimes;
247        /**
248         * <a href="https://en.wikipedia.org/wiki/Moving_average">
249         * Exponentially Weighted Average</a> of the squared
250         * diff to monitor the convergence while discarding
251         * minibatch-local stochastic variability.
252         */
253        private double ewaInertia = Double.NaN;
254        /** Minimum value of {@link #ewaInertia} during iteration. */
255        private double ewaInertiaMin = Double.POSITIVE_INFINITY;
256        /** Number of iteration during which {@link #ewaInertia} did not improve. */
257        private int noImprovementTimes;
258
259        /**
260         * @param batchSize Number of elements for each batch iteration.
261         * @param maxNoImprovementTimes Maximum number of iterations during
262         * which no improvement is occuring.
263         */
264        private ImprovementEvaluator(int batchSize,
265                                     int maxNoImprovementTimes) {
266            this.batchSize = batchSize;
267            this.maxNoImprovementTimes = maxNoImprovementTimes;
268        }
269
270        /**
271         * Stopping criterion.
272         *
273         * @param squareDistance Total square distance from the batch points
274         * to their nearest center.
275         * @param pointSize Number of data points.
276         * @return {@code true} if no improvement was made after the allowed
277         * number of iterations, {@code false} otherwise.
278         */
279        public boolean converge(final double squareDistance,
280                                final int pointSize) {
281            final double batchInertia = squareDistance / batchSize;
282            if (Double.isNaN(ewaInertia)) {
283                ewaInertia = batchInertia;
284            } else {
285                final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
286                ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
287            }
288
289            if (ewaInertia < ewaInertiaMin) {
290                // Improved.
291                noImprovementTimes = 0;
292                ewaInertiaMin = ewaInertia;
293            } else {
294                // No improvement.
295                ++noImprovementTimes;
296            }
297
298            return noImprovementTimes >= maxNoImprovementTimes;
299        }
300    }
301}