KMeansPlusPlusClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.ml.clustering;

  18. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  19. import org.apache.commons.math4.legacy.exception.ConvergenceException;
  20. import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
  21. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  22. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  23. import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
  24. import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
  25. import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
  26. import org.apache.commons.rng.UniformRandomProvider;
  27. import org.apache.commons.rng.simple.RandomSource;

  28. import java.util.ArrayList;
  29. import java.util.Collection;
  30. import java.util.Collections;
  31. import java.util.List;

  32. /**
  33.  * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
  34.  * @param <T> type of the points to cluster
  35.  * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
  36.  * @since 3.2
  37.  */
  38. public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

  39.     /** Strategies to use for replacing an empty cluster. */
  40.     public enum EmptyClusterStrategy {

  41.         /** Split the cluster with largest distance variance. */
  42.         LARGEST_VARIANCE,

  43.         /** Split the cluster with largest number of points. */
  44.         LARGEST_POINTS_NUMBER,

  45.         /** Create a cluster around the point farthest from its centroid. */
  46.         FARTHEST_POINT,

  47.         /** Generate an error. */
  48.         ERROR
  49.     }

  50.     /** The number of clusters. */
  51.     private final int numberOfClusters;

  52.     /** The maximum number of iterations. */
  53.     private final int maxIterations;

  54.     /** Random generator for choosing initial centers. */
  55.     private final UniformRandomProvider random;

  56.     /** Selected strategy for empty clusters. */
  57.     private final EmptyClusterStrategy emptyStrategy;

  58.     /** Build a clusterer.
  59.      * <p>
  60.      * The default strategy for handling empty clusters that may appear during
  61.      * algorithm iterations is to split the cluster with largest distance variance.
  62.      * <p>
  63.      * The euclidean distance will be used as default distance measure.
  64.      *
  65.      * @param k the number of clusters to split the data into
  66.      */
  67.     public KMeansPlusPlusClusterer(final int k) {
  68.         this(k, Integer.MAX_VALUE);
  69.     }

  70.     /** Build a clusterer.
  71.      * <p>
  72.      * The default strategy for handling empty clusters that may appear during
  73.      * algorithm iterations is to split the cluster with largest distance variance.
  74.      * <p>
  75.      * The euclidean distance will be used as default distance measure.
  76.      *
  77.      * @param k the number of clusters to split the data into
  78.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  79.      *   If negative, no maximum will be used.
  80.      */
  81.     public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
  82.         this(k, maxIterations, new EuclideanDistance());
  83.     }

  84.     /** Build a clusterer.
  85.      * <p>
  86.      * The default strategy for handling empty clusters that may appear during
  87.      * algorithm iterations is to split the cluster with largest distance variance.
  88.      *
  89.      * @param k the number of clusters to split the data into
  90.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  91.      * @param measure the distance measure to use
  92.      * @throws NotStrictlyPositiveException if {@code k <= 0}.
  93.      */
  94.     public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
  95.         this(k, maxIterations, measure, RandomSource.MT_64.create());
  96.     }

  97.     /** Build a clusterer.
  98.      * <p>
  99.      * The default strategy for handling empty clusters that may appear during
  100.      * algorithm iterations is to split the cluster with largest distance variance.
  101.      *
  102.      * @param k the number of clusters to split the data into
  103.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  104.      *   If negative, no maximum will be used.
  105.      * @param measure the distance measure to use
  106.      * @param random random generator to use for choosing initial centers
  107.      */
  108.     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
  109.                                    final DistanceMeasure measure,
  110.                                    final UniformRandomProvider random) {
  111.         this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
  112.     }

  113.     /** Build a clusterer.
  114.      *
  115.      * @param k the number of clusters to split the data into
  116.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  117.      * @param measure the distance measure to use
  118.      * @param random random generator to use for choosing initial centers
  119.      * @param emptyStrategy strategy to use for handling empty clusters that
  120.      * may appear during algorithm iterations
  121.      * @throws NotStrictlyPositiveException if {@code k <= 0} or
  122.      * {@code maxIterations <= 0}.
  123.      */
  124.     public KMeansPlusPlusClusterer(final int k,
  125.                                    final int maxIterations,
  126.                                    final DistanceMeasure measure,
  127.                                    final UniformRandomProvider random,
  128.                                    final EmptyClusterStrategy emptyStrategy) {
  129.         super(measure);

  130.         if (k <= 0) {
  131.             throw new NotStrictlyPositiveException(k);
  132.         }
  133.         if (maxIterations <= 0) {
  134.             throw new NotStrictlyPositiveException(maxIterations);
  135.         }

  136.         this.numberOfClusters = k;
  137.         this.maxIterations = maxIterations;
  138.         this.random = random;
  139.         this.emptyStrategy = emptyStrategy;
  140.     }

  141.     /**
  142.      * Return the number of clusters this instance will use.
  143.      * @return the number of clusters
  144.      */
  145.     public int getNumberOfClusters() {
  146.         return numberOfClusters;
  147.     }

  148.     /**
  149.      * Returns the maximum number of iterations this instance will use.
  150.      * @return the maximum number of iterations, or -1 if no maximum is set
  151.      */
  152.     public int getMaxIterations() {
  153.         return maxIterations;
  154.     }

  155.     /**
  156.      * Runs the K-means++ clustering algorithm.
  157.      *
  158.      * @param points the points to cluster
  159.      * @return a list of clusters containing the points
  160.      * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
  161.      * if the data points are null or the number of clusters is larger than the
  162.      * number of data points
  163.      * @throws ConvergenceException if an empty cluster is encountered and the
  164.      * empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR}
  165.      */
  166.     @Override
  167.     public List<CentroidCluster<T>> cluster(final Collection<T> points) {
  168.         // sanity checks
  169.         NullArgumentException.check(points);

  170.         // number of clusters has to be smaller or equal the number of data points
  171.         if (points.size() < numberOfClusters) {
  172.             throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
  173.         }

  174.         // create the initial clusters
  175.         List<CentroidCluster<T>> clusters = chooseInitialCenters(points);

  176.         // create an array containing the latest assignment of a point to a cluster
  177.         // no need to initialize the array, as it will be filled with the first assignment
  178.         int[] assignments = new int[points.size()];
  179.         assignPointsToClusters(clusters, points, assignments);

  180.         // iterate through updating the centers until we're done
  181.         for (int count = 0; count < maxIterations; count++) {
  182.             boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
  183.             List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
  184.             int changes = assignPointsToClusters(newClusters, points, assignments);
  185.             clusters = newClusters;

  186.             // if there were no more changes in the point-to-cluster assignment
  187.             // and there are no empty clusters left, return the current clusters
  188.             if (changes == 0 && !hasEmptyCluster) {
  189.                 return clusters;
  190.             }
  191.         }
  192.         return clusters;
  193.     }

  194.     /**
  195.      * @return the random generator
  196.      */
  197.     UniformRandomProvider getRandomGenerator() {
  198.         return random;
  199.     }

  200.     /**
  201.      * @return the {@link EmptyClusterStrategy}
  202.      */
  203.     EmptyClusterStrategy getEmptyClusterStrategy() {
  204.         return emptyStrategy;
  205.     }

  206.     /**
  207.      * Adjust the clusters's centers with means of points.
  208.      * @param clusters the origin clusters
  209.      * @return adjusted clusters with center points
  210.      */
  211.     List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) {
  212.         List<CentroidCluster<T>> newClusters = new ArrayList<>();
  213.         for (final CentroidCluster<T> cluster : clusters) {
  214.             final Clusterable newCenter;
  215.             if (cluster.getPoints().isEmpty()) {
  216.                 switch (emptyStrategy) {
  217.                     case LARGEST_VARIANCE :
  218.                         newCenter = getPointFromLargestVarianceCluster(clusters);
  219.                         break;
  220.                     case LARGEST_POINTS_NUMBER :
  221.                         newCenter = getPointFromLargestNumberCluster(clusters);
  222.                         break;
  223.                     case FARTHEST_POINT :
  224.                         newCenter = getFarthestPoint(clusters);
  225.                         break;
  226.                     default :
  227.                         throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  228.                 }
  229.             } else {
  230.                 newCenter = cluster.centroid();
  231.             }
  232.             newClusters.add(new CentroidCluster<>(newCenter));
  233.         }
  234.         return newClusters;
  235.     }

  236.     /**
  237.      * Adds the given points to the closest {@link Cluster}.
  238.      *
  239.      * @param clusters the {@link Cluster}s to add the points to
  240.      * @param points the points to add to the given {@link Cluster}s
  241.      * @param assignments points assignments to clusters
  242.      * @return the number of points assigned to different clusters as the iteration before
  243.      */
  244.     private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
  245.                                        final Collection<T> points,
  246.                                        final int[] assignments) {
  247.         int assignedDifferently = 0;
  248.         int pointIndex = 0;
  249.         for (final T p : points) {
  250.             int clusterIndex = getNearestCluster(clusters, p);
  251.             if (clusterIndex != assignments[pointIndex]) {
  252.                 assignedDifferently++;
  253.             }

  254.             CentroidCluster<T> cluster = clusters.get(clusterIndex);
  255.             cluster.addPoint(p);
  256.             assignments[pointIndex++] = clusterIndex;
  257.         }

  258.         return assignedDifferently;
  259.     }

  260.     /**
  261.      * Use K-means++ to choose the initial centers.
  262.      *
  263.      * @param points the points to choose the initial centers from
  264.      * @return the initial centers
  265.      */
  266.     List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

  267.         // Convert to list for indexed access. Make it unmodifiable, since removal of items
  268.         // would screw up the logic of this method.
  269.         final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));

  270.         // The number of points in the list.
  271.         final int numPoints = pointList.size();

  272.         // Set the corresponding element in this array to indicate when
  273.         // elements of pointList are no longer available.
  274.         final boolean[] taken = new boolean[numPoints];

  275.         // The resulting list of initial centers.
  276.         final List<CentroidCluster<T>> resultSet = new ArrayList<>();

  277.         // Choose one center uniformly at random from among the data points.
  278.         final int firstPointIndex = random.nextInt(numPoints);

  279.         final T firstPoint = pointList.get(firstPointIndex);

  280.         resultSet.add(new CentroidCluster<>(firstPoint));

  281.         // Must mark it as taken
  282.         taken[firstPointIndex] = true;

  283.         // To keep track of the minimum distance squared of elements of
  284.         // pointList to elements of resultSet.
  285.         final double[] minDistSquared = new double[numPoints];

  286.         // Initialize the elements.  Since the only point in resultSet is firstPoint,
  287.         // this is very easy.
  288.         for (int i = 0; i < numPoints; i++) {
  289.             if (i != firstPointIndex) { // That point isn't considered
  290.                 double d = distance(firstPoint, pointList.get(i));
  291.                 minDistSquared[i] = d*d;
  292.             }
  293.         }

  294.         while (resultSet.size() < numberOfClusters) {

  295.             // Sum up the squared distances for the points in pointList not
  296.             // already taken.
  297.             double distSqSum = 0.0;

  298.             for (int i = 0; i < numPoints; i++) {
  299.                 if (!taken[i]) {
  300.                     distSqSum += minDistSquared[i];
  301.                 }
  302.             }

  303.             // Add one new data point as a center. Each point x is chosen with
  304.             // probability proportional to D(x)2
  305.             final double r = random.nextDouble() * distSqSum;

  306.             // The index of the next point to be added to the resultSet.
  307.             int nextPointIndex = -1;

  308.             // Sum through the squared min distances again, stopping when
  309.             // sum >= r.
  310.             double sum = 0.0;
  311.             for (int i = 0; i < numPoints; i++) {
  312.                 if (!taken[i]) {
  313.                     sum += minDistSquared[i];
  314.                     if (sum >= r) {
  315.                         nextPointIndex = i;
  316.                         break;
  317.                     }
  318.                 }
  319.             }

  320.             // If it's not set to >= 0, the point wasn't found in the previous
  321.             // for loop, probably because distances are extremely small.  Just pick
  322.             // the last available point.
  323.             if (nextPointIndex == -1) {
  324.                 for (int i = numPoints - 1; i >= 0; i--) {
  325.                     if (!taken[i]) {
  326.                         nextPointIndex = i;
  327.                         break;
  328.                     }
  329.                 }
  330.             }

  331.             // We found one.
  332.             if (nextPointIndex >= 0) {

  333.                 final T p = pointList.get(nextPointIndex);

  334.                 resultSet.add(new CentroidCluster<T> (p));

  335.                 // Mark it as taken.
  336.                 taken[nextPointIndex] = true;

  337.                 if (resultSet.size() < numberOfClusters) {
  338.                     // Now update elements of minDistSquared.  We only have to compute
  339.                     // the distance to the new center to do this.
  340.                     for (int j = 0; j < numPoints; j++) {
  341.                         // Only have to worry about the points still not taken.
  342.                         if (!taken[j]) {
  343.                             double d = distance(p, pointList.get(j));
  344.                             double d2 = d * d;
  345.                             if (d2 < minDistSquared[j]) {
  346.                                 minDistSquared[j] = d2;
  347.                             }
  348.                         }
  349.                     }
  350.                 }
  351.             } else {
  352.                 // None found --
  353.                 // Break from the while loop to prevent
  354.                 // an infinite loop.
  355.                 break;
  356.             }
  357.         }

  358.         return resultSet;
  359.     }

  360.     /**
  361.      * Get a random point from the {@link Cluster} with the largest distance variance.
  362.      *
  363.      * @param clusters the {@link Cluster}s to search
  364.      * @return a random point from the selected cluster
  365.      * @throws ConvergenceException if clusters are all empty
  366.      */
  367.     private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) {
  368.         double maxVariance = Double.NEGATIVE_INFINITY;
  369.         Cluster<T> selected = null;
  370.         for (final CentroidCluster<T> cluster : clusters) {
  371.             if (!cluster.getPoints().isEmpty()) {

  372.                 // compute the distance variance of the current cluster
  373.                 final Clusterable center = cluster.getCenter();
  374.                 final Variance stat = new Variance();
  375.                 for (final T point : cluster.getPoints()) {
  376.                     stat.increment(distance(point, center));
  377.                 }
  378.                 final double variance = stat.getResult();

  379.                 // select the cluster with the largest variance
  380.                 if (variance > maxVariance) {
  381.                     maxVariance = variance;
  382.                     selected = cluster;
  383.                 }
  384.             }
  385.         }

  386.         // did we find at least one non-empty cluster ?
  387.         if (selected == null) {
  388.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  389.         }

  390.         // extract a random point from the cluster
  391.         final List<T> selectedPoints = selected.getPoints();
  392.         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
  393.     }

  394.     /**
  395.      * Get a random point from the {@link Cluster} with the largest number of points.
  396.      *
  397.      * @param clusters the {@link Cluster}s to search
  398.      * @return a random point from the selected cluster
  399.      * @throws ConvergenceException if clusters are all empty
  400.      */
  401.     private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) {
  402.         int maxNumber = 0;
  403.         Cluster<T> selected = null;
  404.         for (final Cluster<T> cluster : clusters) {

  405.             // get the number of points of the current cluster
  406.             final int number = cluster.getPoints().size();

  407.             // select the cluster with the largest number of points
  408.             if (number > maxNumber) {
  409.                 maxNumber = number;
  410.                 selected = cluster;
  411.             }
  412.         }

  413.         // did we find at least one non-empty cluster ?
  414.         if (selected == null) {
  415.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  416.         }

  417.         // extract a random point from the cluster
  418.         final List<T> selectedPoints = selected.getPoints();
  419.         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
  420.     }

  421.     /**
  422.      * Get the point farthest to its cluster center.
  423.      *
  424.      * @param clusters the {@link Cluster}s to search
  425.      * @return point farthest to its cluster center
  426.      * @throws ConvergenceException if clusters are all empty
  427.      */
  428.     private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) {
  429.         double maxDistance = Double.NEGATIVE_INFINITY;
  430.         Cluster<T> selectedCluster = null;
  431.         int selectedPoint = -1;
  432.         for (final CentroidCluster<T> cluster : clusters) {

  433.             // get the farthest point
  434.             final Clusterable center = cluster.getCenter();
  435.             final List<T> points = cluster.getPoints();
  436.             for (int i = 0; i < points.size(); ++i) {
  437.                 final double distance = distance(points.get(i), center);
  438.                 if (distance > maxDistance) {
  439.                     maxDistance     = distance;
  440.                     selectedCluster = cluster;
  441.                     selectedPoint   = i;
  442.                 }
  443.             }
  444.         }

  445.         // did we find at least one non-empty cluster ?
  446.         if (selectedCluster == null) {
  447.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  448.         }

  449.         return selectedCluster.getPoints().remove(selectedPoint);
  450.     }

  451.     /**
  452.      * Returns the nearest {@link Cluster} to the given point.
  453.      *
  454.      * @param clusters the {@link Cluster}s to search
  455.      * @param point the point to find the nearest {@link Cluster} for
  456.      * @return the index of the nearest {@link Cluster} to the given point
  457.      */
  458.     private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
  459.         double minDistance = Double.MAX_VALUE;
  460.         int clusterIndex = 0;
  461.         int minCluster = 0;
  462.         for (final CentroidCluster<T> c : clusters) {
  463.             final double distance = distance(point, c.getCenter());
  464.             if (distance < minDistance) {
  465.                 minDistance = distance;
  466.                 minCluster = clusterIndex;
  467.             }
  468.             clusterIndex++;
  469.         }
  470.         return minCluster;
  471.     }
  472. }