MiniBatchKMeansClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.ml.clustering;

  18. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  19. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  20. import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
  21. import org.apache.commons.math4.legacy.core.Pair;
  22. import org.apache.commons.rng.UniformRandomProvider;
  23. import org.apache.commons.rng.sampling.ListSampler;

  24. import java.util.ArrayList;
  25. import java.util.Collection;
  26. import java.util.List;

  27. /**
  28.  * Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf">
  29.  * based on KMeans</a>.
  30.  *
  31.  * @param <T> Type of the points to cluster.
  32.  */
  33. public class MiniBatchKMeansClusterer<T extends Clusterable>
  34.     extends KMeansPlusPlusClusterer<T> {
  35.     /** Batch data size in iteration. */
  36.     private final int batchSize;
  37.     /** Iteration count of initialize the centers. */
  38.     private final int initIterations;
  39.     /** Data size of batch to initialize the centers. */
  40.     private final int initBatchSize;
  41.     /** Maximum number of iterations during which no improvement is occuring. */
  42.     private final int maxNoImprovementTimes;


  43.     /**
  44.      * Build a clusterer.
  45.      *
  46.      * @param k Number of clusters to split the data into.
  47.      * @param maxIterations Maximum number of iterations to run the algorithm for all the points,
  48.      * The actual number of iterationswill be smaller than {@code maxIterations * size / batchSize},
  49.      * where {@code size} is the number of points to cluster.
  50.      * Disabled if negative.
  51.      * @param batchSize Batch size for training iterations.
  52.      * @param initIterations Number of iterations allowed in order to find out the best initial centers.
  53.      * @param initBatchSize Batch size for initializing the clusters centers.
  54.      * A value of {@code 3 * batchSize} should be suitable in most cases.
  55.      * @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occuring.
  56.      * A value of 10 is suitable in most cases.
  57.      * @param measure Distance measure.
  58.      * @param random Random generator.
  59.      * @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
  60.      */
  61.     public MiniBatchKMeansClusterer(final int k,
  62.                                     final int maxIterations,
  63.                                     final int batchSize,
  64.                                     final int initIterations,
  65.                                     final int initBatchSize,
  66.                                     final int maxNoImprovementTimes,
  67.                                     final DistanceMeasure measure,
  68.                                     final UniformRandomProvider random,
  69.                                     final EmptyClusterStrategy emptyStrategy) {
  70.         super(k, maxIterations, measure, random, emptyStrategy);

  71.         if (batchSize < 1) {
  72.             throw new NumberIsTooSmallException(batchSize, 1, true);
  73.         }
  74.         if (initIterations < 1) {
  75.             throw new NumberIsTooSmallException(initIterations, 1, true);
  76.         }
  77.         if (initBatchSize < 1) {
  78.             throw new NumberIsTooSmallException(initBatchSize, 1, true);
  79.         }
  80.         if (maxNoImprovementTimes < 1) {
  81.             throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
  82.         }

  83.         this.batchSize = batchSize;
  84.         this.initIterations = initIterations;
  85.         this.initBatchSize = initBatchSize;
  86.         this.maxNoImprovementTimes = maxNoImprovementTimes;
  87.     }

  88.     /**
  89.      * Runs the MiniBatch K-means clustering algorithm.
  90.      *
  91.      * @param points Points to cluster (cannot be {@code null}).
  92.      * @return the clusters.
  93.      * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
  94.      * if the number of points is smaller than the number of clusters.
  95.      */
  96.     @Override
  97.     public List<CentroidCluster<T>> cluster(final Collection<T> points) {
  98.         // Sanity check.
  99.         NullArgumentException.check(points);
  100.         if (points.size() < getNumberOfClusters()) {
  101.             throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false);
  102.         }

  103.         final int pointSize = points.size();
  104.         final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
  105.         final int max = getMaxIterations() < 0 ?
  106.             Integer.MAX_VALUE :
  107.             getMaxIterations() * batchCount;

  108.         final List<T> pointList = new ArrayList<>(points);
  109.         List<CentroidCluster<T>> clusters = initialCenters(pointList);

  110.         final ImprovementEvaluator evaluator = new ImprovementEvaluator(batchSize,
  111.                                                                         maxNoImprovementTimes);
  112.         for (int i = 0; i < max; i++) {
  113.             clearClustersPoints(clusters);
  114.             final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize);
  115.             // Training step.
  116.             final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
  117.             final double squareDistance = pair.getFirst();
  118.             clusters = pair.getSecond();
  119.             // Check whether the training can finished early.
  120.             if (evaluator.converge(squareDistance, pointSize)) {
  121.                 break;
  122.             }
  123.         }

  124.         // Add every mini batch points to their nearest cluster.
  125.         clearClustersPoints(clusters);
  126.         for (final T point : points) {
  127.             addToNearestCentroidCluster(point, clusters);
  128.         }

  129.         return clusters;
  130.     }

  131.     /**
  132.      * Helper method.
  133.      *
  134.      * @param clusters Clusters to clear.
  135.      */
  136.     private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
  137.         for (CentroidCluster<T> cluster : clusters) {
  138.             cluster.getPoints().clear();
  139.         }
  140.     }

  141.     /**
  142.      * Mini batch iteration step.
  143.      *
  144.      * @param batchPoints Points selected for this batch.
  145.      * @param clusters Centers of the clusters.
  146.      * @return the squared distance of all the batch points to the nearest center.
  147.      */
  148.     private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
  149.                                                         final List<CentroidCluster<T>> clusters) {
  150.         // Add every mini batch points to their nearest cluster.
  151.         for (final T point : batchPoints) {
  152.             addToNearestCentroidCluster(point, clusters);
  153.         }
  154.         final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
  155.         // Add every mini batch points to their nearest cluster again.
  156.         double squareDistance = 0.0;
  157.         for (T point : batchPoints) {
  158.             final double d = addToNearestCentroidCluster(point, newClusters);
  159.             squareDistance += d * d;
  160.         }

  161.         return new Pair<>(squareDistance, newClusters);
  162.     }

  163.     /**
  164.      * Initializes the clusters centers.
  165.      *
  166.      * @param points Points used to initialize the centers.
  167.      * @return clusters with their center initialized.
  168.      */
  169.     private List<CentroidCluster<T>> initialCenters(final List<T> points) {
  170.         final List<T> validPoints = initBatchSize < points.size() ?
  171.             ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
  172.             new ArrayList<>(points);
  173.         double nearestSquareDistance = Double.POSITIVE_INFINITY;
  174.         List<CentroidCluster<T>> bestCenters = null;

  175.         for (int i = 0; i < initIterations; i++) {
  176.             final List<T> initialPoints = (initBatchSize < points.size()) ?
  177.                 ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
  178.                 new ArrayList<>(points);
  179.             final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints);
  180.             final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
  181.             final double squareDistance = pair.getFirst();
  182.             final List<CentroidCluster<T>> newClusters = pair.getSecond();
  183.             //Find out a best centers that has the nearest total square distance.
  184.             if (squareDistance < nearestSquareDistance) {
  185.                 nearestSquareDistance = squareDistance;
  186.                 bestCenters = newClusters;
  187.             }
  188.         }
  189.         return bestCenters;
  190.     }

  191.     /**
  192.      * Adds a point to the cluster whose center is closest.
  193.      *
  194.      * @param point Point to add.
  195.      * @param clusters Clusters.
  196.      * @return the distance between point and the closest center.
  197.      */
  198.     private double addToNearestCentroidCluster(final T point,
  199.                                                final List<CentroidCluster<T>> clusters) {
  200.         double minDistance = Double.POSITIVE_INFINITY;
  201.         CentroidCluster<T> closestCentroidCluster = null;

  202.         // Find cluster closest to the point.
  203.         for (CentroidCluster<T> centroidCluster : clusters) {
  204.             final double distance = distance(point, centroidCluster.getCenter());
  205.             if (distance < minDistance) {
  206.                 minDistance = distance;
  207.                 closestCentroidCluster = centroidCluster;
  208.             }
  209.         }
  210.         NullArgumentException.check(closestCentroidCluster);
  211.         closestCentroidCluster.addPoint(point);

  212.         return minDistance;
  213.     }

  214.     /**
  215.      * Stopping criterion.
  216.      * The evaluator checks whether improvement occurred during the
  217.      * {@link #maxNoImprovementTimes allowed number of successive iterations}.
  218.      */
  219.     private static final class ImprovementEvaluator {
  220.         /** Batch size. */
  221.         private final int batchSize;
  222.         /** Maximum number of iterations during which no improvement is occuring. */
  223.         private final int maxNoImprovementTimes;
  224.         /**
  225.          * <a href="https://en.wikipedia.org/wiki/Moving_average">
  226.          * Exponentially Weighted Average</a> of the squared
  227.          * diff to monitor the convergence while discarding
  228.          * minibatch-local stochastic variability.
  229.          */
  230.         private double ewaInertia = Double.NaN;
  231.         /** Minimum value of {@link #ewaInertia} during iteration. */
  232.         private double ewaInertiaMin = Double.POSITIVE_INFINITY;
  233.         /** Number of iteration during which {@link #ewaInertia} did not improve. */
  234.         private int noImprovementTimes;

  235.         /**
  236.          * @param batchSize Number of elements for each batch iteration.
  237.          * @param maxNoImprovementTimes Maximum number of iterations during
  238.          * which no improvement is occuring.
  239.          */
  240.         private ImprovementEvaluator(int batchSize,
  241.                                      int maxNoImprovementTimes) {
  242.             this.batchSize = batchSize;
  243.             this.maxNoImprovementTimes = maxNoImprovementTimes;
  244.         }

  245.         /**
  246.          * Stopping criterion.
  247.          *
  248.          * @param squareDistance Total square distance from the batch points
  249.          * to their nearest center.
  250.          * @param pointSize Number of data points.
  251.          * @return {@code true} if no improvement was made after the allowed
  252.          * number of iterations, {@code false} otherwise.
  253.          */
  254.         public boolean converge(final double squareDistance,
  255.                                 final int pointSize) {
  256.             final double batchInertia = squareDistance / batchSize;
  257.             if (Double.isNaN(ewaInertia)) {
  258.                 ewaInertia = batchInertia;
  259.             } else {
  260.                 final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
  261.                 ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
  262.             }

  263.             if (ewaInertia < ewaInertiaMin) {
  264.                 // Improved.
  265.                 noImprovementTimes = 0;
  266.                 ewaInertiaMin = ewaInertia;
  267.             } else {
  268.                 // No improvement.
  269.                 ++noImprovementTimes;
  270.             }

  271.             return noImprovementTimes >= maxNoImprovementTimes;
  272.         }
  273.     }
  274. }