ElkanKMeansPlusPlusClusterer.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

  17. package org.apache.commons.math4.legacy.ml.clustering;

  18. import java.util.ArrayList;
  19. import java.util.Arrays;
  20. import java.util.Collection;
  21. import java.util.List;

  22. import org.apache.commons.rng.UniformRandomProvider;
  23. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  24. import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
  25. import org.apache.commons.math4.legacy.stat.descriptive.moment.VectorialMean;

/**
 * Implementation of the k-means++ algorithm.
 * It is based on
 * <blockquote>
 *  Elkan, Charles.
 *  "Using the triangle inequality to accelerate k-means."
 *  ICML. Vol. 3. 2003.
 * </blockquote>
 *
 * <p>
 * The algorithm uses the triangle inequality to speed up computation by
 * reducing the number of distance calculations.  Towards the last iterations
 * of the algorithm, points that are already assigned to some cluster are
 * unlikely to move to a new cluster; updates of cluster centers are also
 * usually relatively small.
 * The triangle inequality is thus used to determine the cases where the
 * distance computation can be skipped since the centers moved only a little,
 * without affecting the partitioning of the points.
 *
 * <p>
 * For initial centers seeding, we apply the algorithm described in
 * <blockquote>
 *  Arthur, David, and Sergei Vassilvitskii.
 *  "k-means++: The advantages of careful seeding."
 *  Proceedings of the eighteenth annual ACM-SIAM symposium on Discrete algorithms.
 *  Society for Industrial and Applied Mathematics, 2007.
 * </blockquote>
 *
 * @param <T> Type of the points to cluster.
 */
  56. public class ElkanKMeansPlusPlusClusterer<T extends Clusterable>
  57.     extends KMeansPlusPlusClusterer<T> {

  58.     /**
  59.      * @param k Clustering parameter.
  60.      */
  61.     public ElkanKMeansPlusPlusClusterer(int k) {
  62.         super(k);
  63.     }

  64.     /**
  65.      * @param k Clustering parameter.
  66.      * @param maxIterations Allowed number of iterations.
  67.      * @param measure Distance measure.
  68.      * @param random Random generator.
  69.      */
  70.     public ElkanKMeansPlusPlusClusterer(int k,
  71.                                         int maxIterations,
  72.                                         DistanceMeasure measure,
  73.                                         UniformRandomProvider random) {
  74.         super(k, maxIterations, measure, random);
  75.     }

  76.     /**
  77.      * @param k Clustering parameter.
  78.      * @param maxIterations Allowed number of iterations.
  79.      * @param measure Distance measure.
  80.      * @param random Random generator.
  81.      * @param emptyStrategy Strategy for handling empty clusters that
  82.      * may appear during algorithm progress.
  83.      */
  84.     public ElkanKMeansPlusPlusClusterer(int k,
  85.                                         int maxIterations,
  86.                                         DistanceMeasure measure,
  87.                                         UniformRandomProvider random,
  88.                                         EmptyClusterStrategy emptyStrategy) {
  89.         super(k, maxIterations, measure, random, emptyStrategy);
  90.     }

  91.     /** {@inheritDoc} */
  92.     @Override
  93.     public List<CentroidCluster<T>> cluster(final Collection<T> points) {
  94.         final int k = getNumberOfClusters();

  95.         // Number of clusters has to be smaller or equal the number of data points.
  96.         if (points.size() < k) {
  97.             throw new NumberIsTooSmallException(points.size(), k, false);
  98.         }

  99.         final List<T> pointsList = new ArrayList<>(points);
  100.         final int n = points.size();
  101.         final int dim = pointsList.get(0).getPoint().length;

  102.         // Keep minimum intra cluster distance, e.g. for given cluster c s[c] is
  103.         // the distance to the closest cluster c' or s[c] = 1/2 * min_{c'!=c} dist(c', c)
  104.         final double[] s = new double[k];
  105.         Arrays.fill(s, Double.MAX_VALUE);
  106.         // Store the matrix of distances between all cluster centers, e.g. dcc[c1][c2] = distance(c1, c2)
  107.         final double[][] dcc = new double[k][k];

  108.         // For each point keeps the upper bound distance to the cluster center.
  109.         final double[] u = new double[n];
  110.         Arrays.fill(u, Double.MAX_VALUE);

  111.         // For each point and for each cluster keeps the lower bound for the distance between the point and cluster
  112.         final double[][] l = new double[n][k];

  113.         // Seed initial set of cluster centers.
  114.         final double[][] centers = seed(pointsList);

  115.         // Points partitioning induced by cluster centers, e.g. for point xi the value of partitions[xi] indicates
  116.         // the cluster or index of the cluster center which is closest to xi. partitions[xi] = min_{c} distance(xi, c).
  117.         final int[] partitions = partitionPoints(pointsList, centers, u, l);

  118.         final double[] deltas = new double[k];
  119.         VectorialMean[] means = new VectorialMean[k];
  120.         for (int it = 0, max = getMaxIterations();
  121.              it < max;
  122.              it++) {
  123.             int changes = 0;
  124.             // Step I.
  125.             // Compute inter-cluster distances.
  126.             updateIntraCentersDistances(centers, dcc, s);

  127.             for (int xi = 0; xi < n; xi++) {
  128.                 boolean r = true;

  129.                 // Step II.
  130.                 if (u[xi] <= s[partitions[xi]]) {
  131.                     continue;
  132.                 }

  133.                 for (int c = 0; c < k; c++) {
  134.                     // Check condition III.
  135.                     if (isSkipNext(partitions, u, l, dcc, xi, c)) {
  136.                         continue;
  137.                     }

  138.                     final double[] x = pointsList.get(xi).getPoint();

  139.                     // III(a)
  140.                     if (r) {
  141.                         u[xi] = distance(x, centers[partitions[xi]]);
  142.                         l[xi][partitions[xi]] = u[xi];
  143.                         r = false;
  144.                     }
  145.                     // III(b)
  146.                     if (u[xi] > l[xi][c] || u[xi] > dcc[partitions[xi]][c]) {
  147.                         l[xi][c] = distance(x, centers[c]);
  148.                         if (l[xi][c] < u[xi]) {
  149.                             partitions[xi] = c;
  150.                             u[xi] = l[xi][c];
  151.                             ++changes;
  152.                         }
  153.                     }
  154.                 }
  155.             }

  156.             // Stopping criterion.
  157.             if (changes == 0 &&
  158.                 it != 0) { // First iteration needed (to update bounds).
  159.                 break;
  160.             }

  161.             // Step IV.
  162.             Arrays.fill(means, null);
  163.             for (int i = 0; i < n; i++) {
  164.                 if (means[partitions[i]] == null) {
  165.                     means[partitions[i]] = new VectorialMean(dim);
  166.                 }
  167.                 means[partitions[i]].increment(pointsList.get(i).getPoint());
  168.             }

  169.             for (int i = 0; i < k; i++) {
  170.                 deltas[i] = distance(centers[i], means[i].getResult());
  171.                 centers[i] = means[i].getResult();
  172.             }

  173.             updateBounds(partitions, u, l, deltas);
  174.         }

  175.         return buildResults(pointsList, partitions, centers);
  176.     }

  177.     /**
  178.      * kmeans++ seeding which provides guarantee of resulting with log(k) approximation
  179.      * for final clustering results
  180.      * <p>
  181.      * Arthur, David, and Sergei Vassilvitskii. "k-means++: The advantages of careful seeding."
  182.      * Proceedings of the eighteenth annual ACM-SIAM symposium on Discrete algorithms.
  183.      * Society for Industrial and Applied Mathematics, 2007.
  184.      *
  185.      * @param points input data points
  186.      * @return an array of initial clusters centers
  187.      *
  188.      */
  189.     private double[][] seed(final List<T> points) {
  190.         final int k = getNumberOfClusters();
  191.         final UniformRandomProvider random = getRandomGenerator();

  192.         final double[][] result = new double[k][];
  193.         final int n = points.size();
  194.         final int pointIndex = random.nextInt(n);

  195.         final double[] minDistances = new double[n];

  196.         int idx = 0;
  197.         result[idx] = points.get(pointIndex).getPoint();

  198.         double sumSqDist = 0;

  199.         for (int i = 0; i < n; i++) {
  200.             final double d = distance(result[idx], points.get(i).getPoint());
  201.             minDistances[i] = d * d;
  202.             sumSqDist += minDistances[i];
  203.         }

  204.         while (++idx < k) {
  205.             final double p = sumSqDist * random.nextDouble();
  206.             int next = 0;
  207.             for (double cdf = 0; cdf < p; next++) {
  208.                 cdf += minDistances[next];
  209.             }

  210.             result[idx] = points.get(next - 1).getPoint();
  211.             for (int i = 0; i < n; i++) {
  212.                 final double d = distance(result[idx], points.get(i).getPoint());
  213.                 sumSqDist -= minDistances[i];
  214.                 minDistances[i] = Math.min(minDistances[i], d * d);
  215.                 sumSqDist += minDistances[i];
  216.             }
  217.         }

  218.         return result;
  219.     }


  220.     /**
  221.      * Once initial centers are chosen, we can actually go through data points and assign points to the
  222.      * cluster based on the distance between initial centers and points.
  223.      *
  224.      * @param pointsList data points list
  225.      * @param centers current clusters centers
  226.      * @param u points upper bounds
  227.      * @param l lower bounds for points to clusters centers
  228.      *
  229.      * @return initial assignment of points into clusters
  230.      */
  231.     private int[] partitionPoints(List<T> pointsList,
  232.                                   double[][] centers,
  233.                                   double[] u,
  234.                                   double[][] l) {
  235.         final int k = getNumberOfClusters();
  236.         final int n = pointsList.size();
  237.         // Points assignments vector.
  238.         final int[] assignments = new int[n];
  239.         Arrays.fill(assignments, -1);
  240.         // Need to assign points to the clusters for the first time and intitialize the lower bound l(x, c)
  241.         for (int i = 0; i < n; i++) {
  242.             final double[] x = pointsList.get(i).getPoint();
  243.             for (int j = 0; j < k; j++) {
  244.                 l[i][j] = distance(x, centers[j]); // l(x, c) = d(x, c)
  245.                 if (u[i] > l[i][j]) {
  246.                     u[i] = l[i][j]; // u(x) = min_c d(x, c)
  247.                     assignments[i] = j; // c(x) = argmin_c d(x, c)
  248.                 }
  249.             }
  250.         }
  251.         return assignments;
  252.     }

  253.     /**
  254.      * Updated distances between clusters centers and for each cluster
  255.      * pick the closest neighbour and keep distance to it.
  256.      *
  257.      * @param centers cluster centers
  258.      * @param dcc matrix of distance between clusters centers, e.g.
  259.      * {@code dcc[i][j] = distance(centers[i], centers[j])}
  260.      * @param s For a given cluster, {@code s[si]} holds distance value
  261.      * to the closest cluster center.
  262.      */
  263.     private void updateIntraCentersDistances(double[][] centers,
  264.                                              double[][] dcc,
  265.                                              double[] s) {
  266.         final int k = getNumberOfClusters();
  267.         for (int i = 0; i < k; i++) {
  268.             // Since distance(xi, xj) == distance(xj, xi), we need to update
  269.             // only upper or lower triangle of the distances matrix and mirror
  270.             // to the lower of upper triangle accordingly, trace has to be all
  271.             // zeros, since distance(xi, xi) == 0.
  272.             for (int j = i + 1; j < k; j++) {
  273.                 dcc[i][j] = 0.5 * distance(centers[i], centers[j]);
  274.                 dcc[j][i] = dcc[i][j];
  275.                 if (dcc[i][j] < s[i]) {
  276.                     s[i] = dcc[i][j];
  277.                 }
  278.                 if (dcc[j][i] < s[j]) {
  279.                     s[j] = dcc[j][i];
  280.                 }
  281.             }
  282.         }
  283.     }

  284.     /**
  285.      * For given points and and cluster, check condition (3) of Elkan algorithm.
  286.      *
  287.      * <ul>
  288.      *  <li>c is not the cluster xi assigned to</li>
  289.      *  <li>{@code u[xi] > l[xi][x]} upper bound for point xi is greater than
  290.      *   lower bound between xi and some cluster c</li>
  291.      *  <li>{@code u[xi] > 1/2 * d(c(xi), c)} upper bound is greater than
  292.      *   distance between center of xi's cluster and c</li>
  293.      * </ul>
  294.      *
  295.      * @param partitions current partition of points into clusters
  296.      * @param u upper bounds for points
  297.      * @param l lower bounds for distance between cluster centers and points
  298.      * @param dcc matrix of distance between clusters centers
  299.      * @param xi index of the point
  300.      * @param c index of the cluster
  301.      * @return true if conditions above satisfied false otherwise
  302.      */
  303.     private static boolean isSkipNext(int[] partitions,
  304.                                       double[] u,
  305.                                       double[][] l,
  306.                                       double[][] dcc,
  307.                                       int xi,
  308.                                       int c) {
  309.         return c == partitions[xi] ||
  310.                u[xi] <= l[xi][c] ||
  311.                u[xi] <= dcc[partitions[xi]][c];
  312.     }

  313.     /**
  314.      * Once kmeans iterations have been converged and no more movements, we can build up the final
  315.      * resulted list of cluster centroids ({@link CentroidCluster}) and assign input points based
  316.      * on the converged partitioning.
  317.      *
  318.      * @param pointsList list of data points
  319.      * @param partitions current partition of points into clusters
  320.      * @param centers cluster centers
  321.      * @return cluster partitioning
  322.      */
  323.     private List<CentroidCluster<T>> buildResults(List<T> pointsList,
  324.                                                   int[] partitions,
  325.                                                   double[][] centers) {
  326.         final int k = getNumberOfClusters();
  327.         final List<CentroidCluster<T>> result = new ArrayList<>();
  328.         for (int i = 0; i < k; i++) {
  329.             final CentroidCluster<T> cluster = new CentroidCluster<>(new DoublePoint(centers[i]));
  330.             result.add(cluster);
  331.         }
  332.         for (int i = 0; i < pointsList.size(); i++) {
  333.             result.get(partitions[i]).addPoint(pointsList.get(i));
  334.         }
  335.         return result;
  336.     }

  337.     /**
  338.      * Based on the distance that cluster center has moved we need to update our upper and lower bound.
  339.      * Worst case assumption, the center of the assigned to given cluster moves away from the point, while
  340.      * centers of over clusters become closer.
  341.      *
  342.      * @param partitions current points assiments to the clusters
  343.      * @param u points upper bounds
  344.      * @param l lower bounds for distances between point and corresponding cluster
  345.      * @param deltas the movement delta for each cluster center
  346.      */
  347.     private void updateBounds(int[] partitions,
  348.                               double[] u,
  349.                               double[][] l,
  350.                               double[] deltas) {
  351.         final int k = getNumberOfClusters();
  352.         for (int i = 0; i < partitions.length; i++) {
  353.             u[i] += deltas[partitions[i]];
  354.             for (int j = 0; j < k; j++) {
  355.                 l[i][j] = Math.max(0, l[i][j] - deltas[j]);
  356.             }
  357.         }
  358.     }

  359.     /**
  360.      * @param a Coordinates.
  361.      * @param b Coordinates.
  362.      * @return the distance between {@code a} and {@code b}.
  363.      */
  364.     private double distance(final double[] a,
  365.                             final double[] b) {
  366.         return getDistanceMeasure().compute(a, b);
  367.     }
  368. }