001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.ml.clustering;
019
020import org.apache.commons.math4.exception.ConvergenceException;
021import org.apache.commons.math4.exception.NumberIsTooSmallException;
022import org.apache.commons.math4.exception.util.LocalizedFormats;
023import org.apache.commons.math4.ml.distance.DistanceMeasure;
024import org.apache.commons.math4.ml.distance.EuclideanDistance;
025import org.apache.commons.math4.stat.descriptive.moment.Variance;
026import org.apache.commons.math4.util.MathUtils;
027import org.apache.commons.rng.UniformRandomProvider;
028import org.apache.commons.rng.simple.RandomSource;
029
030import java.util.ArrayList;
031import java.util.Collection;
032import java.util.Collections;
033import java.util.List;
034
035/**
036 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
037 * @param <T> type of the points to cluster
038 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
039 * @since 3.2
040 */
041public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
042
043    /** Strategies to use for replacing an empty cluster. */
044    public enum EmptyClusterStrategy {
045
046        /** Split the cluster with largest distance variance. */
047        LARGEST_VARIANCE,
048
049        /** Split the cluster with largest number of points. */
050        LARGEST_POINTS_NUMBER,
051
052        /** Create a cluster around the point farthest from its centroid. */
053        FARTHEST_POINT,
054
055        /** Generate an error. */
056        ERROR
057
058    }
059
060    /** The number of clusters. */
061    private final int numberOfClusters;
062
063    /** The maximum number of iterations. */
064    private final int maxIterations;
065
066    /** Random generator for choosing initial centers. */
067    private final UniformRandomProvider random;
068
069    /** Selected strategy for empty clusters. */
070    private final EmptyClusterStrategy emptyStrategy;
071
072    /** Build a clusterer.
073     * <p>
074     * The default strategy for handling empty clusters that may appear during
075     * algorithm iterations is to split the cluster with largest distance variance.
076     * <p>
077     * The euclidean distance will be used as default distance measure.
078     *
079     * @param k the number of clusters to split the data into
080     */
081    public KMeansPlusPlusClusterer(final int k) {
082        this(k, -1);
083    }
084
085    /** Build a clusterer.
086     * <p>
087     * The default strategy for handling empty clusters that may appear during
088     * algorithm iterations is to split the cluster with largest distance variance.
089     * <p>
090     * The euclidean distance will be used as default distance measure.
091     *
092     * @param k the number of clusters to split the data into
093     * @param maxIterations the maximum number of iterations to run the algorithm for.
094     *   If negative, no maximum will be used.
095     */
096    public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
097        this(k, maxIterations, new EuclideanDistance());
098    }
099
100    /** Build a clusterer.
101     * <p>
102     * The default strategy for handling empty clusters that may appear during
103     * algorithm iterations is to split the cluster with largest distance variance.
104     *
105     * @param k the number of clusters to split the data into
106     * @param maxIterations the maximum number of iterations to run the algorithm for.
107     *   If negative, no maximum will be used.
108     * @param measure the distance measure to use
109     */
110    public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
111        this(k, maxIterations, measure, RandomSource.create(RandomSource.MT_64));
112    }
113
114    /** Build a clusterer.
115     * <p>
116     * The default strategy for handling empty clusters that may appear during
117     * algorithm iterations is to split the cluster with largest distance variance.
118     *
119     * @param k the number of clusters to split the data into
120     * @param maxIterations the maximum number of iterations to run the algorithm for.
121     *   If negative, no maximum will be used.
122     * @param measure the distance measure to use
123     * @param random random generator to use for choosing initial centers
124     */
125    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
126                                   final DistanceMeasure measure,
127                                   final UniformRandomProvider random) {
128        this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
129    }
130
131    /** Build a clusterer.
132     *
133     * @param k the number of clusters to split the data into
134     * @param maxIterations the maximum number of iterations to run the algorithm for.
135     *   If negative, no maximum will be used.
136     * @param measure the distance measure to use
137     * @param random random generator to use for choosing initial centers
138     * @param emptyStrategy strategy to use for handling empty clusters that
139     * may appear during algorithm iterations
140     */
141    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
142                                   final DistanceMeasure measure,
143                                   final UniformRandomProvider random,
144                                   final EmptyClusterStrategy emptyStrategy) {
145        super(measure);
146        this.numberOfClusters = k;
147        this.maxIterations = maxIterations;
148        this.random        = random;
149        this.emptyStrategy = emptyStrategy;
150    }
151
152    /**
153     * Return the number of clusters this instance will use.
154     * @return the number of clusters
155     */
156    public int getNumberOfClusters() {
157        return numberOfClusters;
158    }
159
160    /**
161     * Returns the maximum number of iterations this instance will use.
162     * @return the maximum number of iterations, or -1 if no maximum is set
163     */
164    public int getMaxIterations() {
165        return maxIterations;
166    }
167
168    /**
169     * Runs the K-means++ clustering algorithm.
170     *
171     * @param points the points to cluster
172     * @return a list of clusters containing the points
173     * @throws org.apache.commons.math4.exception.MathIllegalArgumentException
174     * if the data points are null or the number of clusters is larger than the
175     * number of data points
176     * @throws ConvergenceException if an empty cluster is encountered and the
177     * empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR}
178     */
179    @Override
180    public List<CentroidCluster<T>> cluster(final Collection<T> points) {
181        // sanity checks
182        MathUtils.checkNotNull(points);
183
184        // number of clusters has to be smaller or equal the number of data points
185        if (points.size() < numberOfClusters) {
186            throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
187        }
188
189        // create the initial clusters
190        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
191
192        // create an array containing the latest assignment of a point to a cluster
193        // no need to initialize the array, as it will be filled with the first assignment
194        int[] assignments = new int[points.size()];
195        assignPointsToClusters(clusters, points, assignments);
196
197        // iterate through updating the centers until we're done
198        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
199        for (int count = 0; count < max; count++) {
200            boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
201            List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
202            int changes = assignPointsToClusters(newClusters, points, assignments);
203            clusters = newClusters;
204
205            // if there were no more changes in the point-to-cluster assignment
206            // and there are no empty clusters left, return the current clusters
207            if (changes == 0 && !hasEmptyCluster) {
208                return clusters;
209            }
210        }
211        return clusters;
212    }
213
214    /**
215     * @return the random generator
216     */
217    UniformRandomProvider getRandomGenerator() {
218        return random;
219    }
220
221    /**
222     * @return the {@link EmptyClusterStrategy}
223     */
224    EmptyClusterStrategy getEmptyClusterStrategy() {
225        return emptyStrategy;
226    }
227
228    /**
229     * Adjust the clusters's centers with means of points
230     * @param clusters the origin clusters
231     * @return adjusted clusters with center points
232     */
233    List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) {
234        List<CentroidCluster<T>> newClusters = new ArrayList<>();
235        for (final CentroidCluster<T> cluster : clusters) {
236            final Clusterable newCenter;
237            if (cluster.getPoints().isEmpty()) {
238                switch (emptyStrategy) {
239                    case LARGEST_VARIANCE :
240                        newCenter = getPointFromLargestVarianceCluster(clusters);
241                        break;
242                    case LARGEST_POINTS_NUMBER :
243                        newCenter = getPointFromLargestNumberCluster(clusters);
244                        break;
245                    case FARTHEST_POINT :
246                        newCenter = getFarthestPoint(clusters);
247                        break;
248                    default :
249                        throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
250                }
251            } else {
252                newCenter = cluster.centroid();
253            }
254            newClusters.add(new CentroidCluster<>(newCenter));
255        }
256        return newClusters;
257    }
258
259    /**
260     * Adds the given points to the closest {@link Cluster}.
261     *
262     * @param clusters the {@link Cluster}s to add the points to
263     * @param points the points to add to the given {@link Cluster}s
264     * @param assignments points assignments to clusters
265     * @return the number of points assigned to different clusters as the iteration before
266     */
267    private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
268                                       final Collection<T> points,
269                                       final int[] assignments) {
270        int assignedDifferently = 0;
271        int pointIndex = 0;
272        for (final T p : points) {
273            int clusterIndex = getNearestCluster(clusters, p);
274            if (clusterIndex != assignments[pointIndex]) {
275                assignedDifferently++;
276            }
277
278            CentroidCluster<T> cluster = clusters.get(clusterIndex);
279            cluster.addPoint(p);
280            assignments[pointIndex++] = clusterIndex;
281        }
282
283        return assignedDifferently;
284    }
285
286    /**
287     * Use K-means++ to choose the initial centers.
288     *
289     * @param points the points to choose the initial centers from
290     * @return the initial centers
291     */
292    List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
293
294        // Convert to list for indexed access. Make it unmodifiable, since removal of items
295        // would screw up the logic of this method.
296        final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));
297
298        // The number of points in the list.
299        final int numPoints = pointList.size();
300
301        // Set the corresponding element in this array to indicate when
302        // elements of pointList are no longer available.
303        final boolean[] taken = new boolean[numPoints];
304
305        // The resulting list of initial centers.
306        final List<CentroidCluster<T>> resultSet = new ArrayList<>();
307
308        // Choose one center uniformly at random from among the data points.
309        final int firstPointIndex = random.nextInt(numPoints);
310
311        final T firstPoint = pointList.get(firstPointIndex);
312
313        resultSet.add(new CentroidCluster<T>(firstPoint));
314
315        // Must mark it as taken
316        taken[firstPointIndex] = true;
317
318        // To keep track of the minimum distance squared of elements of
319        // pointList to elements of resultSet.
320        final double[] minDistSquared = new double[numPoints];
321
322        // Initialize the elements.  Since the only point in resultSet is firstPoint,
323        // this is very easy.
324        for (int i = 0; i < numPoints; i++) {
325            if (i != firstPointIndex) { // That point isn't considered
326                double d = distance(firstPoint, pointList.get(i));
327                minDistSquared[i] = d*d;
328            }
329        }
330
331        while (resultSet.size() < numberOfClusters) {
332
333            // Sum up the squared distances for the points in pointList not
334            // already taken.
335            double distSqSum = 0.0;
336
337            for (int i = 0; i < numPoints; i++) {
338                if (!taken[i]) {
339                    distSqSum += minDistSquared[i];
340                }
341            }
342
343            // Add one new data point as a center. Each point x is chosen with
344            // probability proportional to D(x)2
345            final double r = random.nextDouble() * distSqSum;
346
347            // The index of the next point to be added to the resultSet.
348            int nextPointIndex = -1;
349
350            // Sum through the squared min distances again, stopping when
351            // sum >= r.
352            double sum = 0.0;
353            for (int i = 0; i < numPoints; i++) {
354                if (!taken[i]) {
355                    sum += minDistSquared[i];
356                    if (sum >= r) {
357                        nextPointIndex = i;
358                        break;
359                    }
360                }
361            }
362
363            // If it's not set to >= 0, the point wasn't found in the previous
364            // for loop, probably because distances are extremely small.  Just pick
365            // the last available point.
366            if (nextPointIndex == -1) {
367                for (int i = numPoints - 1; i >= 0; i--) {
368                    if (!taken[i]) {
369                        nextPointIndex = i;
370                        break;
371                    }
372                }
373            }
374
375            // We found one.
376            if (nextPointIndex >= 0) {
377
378                final T p = pointList.get(nextPointIndex);
379
380                resultSet.add(new CentroidCluster<T> (p));
381
382                // Mark it as taken.
383                taken[nextPointIndex] = true;
384
385                if (resultSet.size() < numberOfClusters) {
386                    // Now update elements of minDistSquared.  We only have to compute
387                    // the distance to the new center to do this.
388                    for (int j = 0; j < numPoints; j++) {
389                        // Only have to worry about the points still not taken.
390                        if (!taken[j]) {
391                            double d = distance(p, pointList.get(j));
392                            double d2 = d * d;
393                            if (d2 < minDistSquared[j]) {
394                                minDistSquared[j] = d2;
395                            }
396                        }
397                    }
398                }
399
400            } else {
401                // None found --
402                // Break from the while loop to prevent
403                // an infinite loop.
404                break;
405            }
406        }
407
408        return resultSet;
409    }
410
411    /**
412     * Get a random point from the {@link Cluster} with the largest distance variance.
413     *
414     * @param clusters the {@link Cluster}s to search
415     * @return a random point from the selected cluster
416     * @throws ConvergenceException if clusters are all empty
417     */
418    private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) {
419        double maxVariance = Double.NEGATIVE_INFINITY;
420        Cluster<T> selected = null;
421        for (final CentroidCluster<T> cluster : clusters) {
422            if (!cluster.getPoints().isEmpty()) {
423
424                // compute the distance variance of the current cluster
425                final Clusterable center = cluster.getCenter();
426                final Variance stat = new Variance();
427                for (final T point : cluster.getPoints()) {
428                    stat.increment(distance(point, center));
429                }
430                final double variance = stat.getResult();
431
432                // select the cluster with the largest variance
433                if (variance > maxVariance) {
434                    maxVariance = variance;
435                    selected = cluster;
436                }
437
438            }
439        }
440
441        // did we find at least one non-empty cluster ?
442        if (selected == null) {
443            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
444        }
445
446        // extract a random point from the cluster
447        final List<T> selectedPoints = selected.getPoints();
448        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
449
450    }
451
452    /**
453     * Get a random point from the {@link Cluster} with the largest number of points
454     *
455     * @param clusters the {@link Cluster}s to search
456     * @return a random point from the selected cluster
457     * @throws ConvergenceException if clusters are all empty
458     */
459    private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) {
460        int maxNumber = 0;
461        Cluster<T> selected = null;
462        for (final Cluster<T> cluster : clusters) {
463
464            // get the number of points of the current cluster
465            final int number = cluster.getPoints().size();
466
467            // select the cluster with the largest number of points
468            if (number > maxNumber) {
469                maxNumber = number;
470                selected = cluster;
471            }
472
473        }
474
475        // did we find at least one non-empty cluster ?
476        if (selected == null) {
477            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
478        }
479
480        // extract a random point from the cluster
481        final List<T> selectedPoints = selected.getPoints();
482        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
483
484    }
485
486    /**
487     * Get the point farthest to its cluster center
488     *
489     * @param clusters the {@link Cluster}s to search
490     * @return point farthest to its cluster center
491     * @throws ConvergenceException if clusters are all empty
492     */
493    private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) {
494        double maxDistance = Double.NEGATIVE_INFINITY;
495        Cluster<T> selectedCluster = null;
496        int selectedPoint = -1;
497        for (final CentroidCluster<T> cluster : clusters) {
498
499            // get the farthest point
500            final Clusterable center = cluster.getCenter();
501            final List<T> points = cluster.getPoints();
502            for (int i = 0; i < points.size(); ++i) {
503                final double distance = distance(points.get(i), center);
504                if (distance > maxDistance) {
505                    maxDistance     = distance;
506                    selectedCluster = cluster;
507                    selectedPoint   = i;
508                }
509            }
510
511        }
512
513        // did we find at least one non-empty cluster ?
514        if (selectedCluster == null) {
515            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
516        }
517
518        return selectedCluster.getPoints().remove(selectedPoint);
519
520    }
521
522    /**
523     * Returns the nearest {@link Cluster} to the given point
524     *
525     * @param clusters the {@link Cluster}s to search
526     * @param point the point to find the nearest {@link Cluster} for
527     * @return the index of the nearest {@link Cluster} to the given point
528     */
529    private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
530        double minDistance = Double.MAX_VALUE;
531        int clusterIndex = 0;
532        int minCluster = 0;
533        for (final CentroidCluster<T> c : clusters) {
534            final double distance = distance(point, c.getCenter());
535            if (distance < minDistance) {
536                minDistance = distance;
537                minCluster = clusterIndex;
538            }
539            clusterIndex++;
540        }
541        return minCluster;
542    }
543}