001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.legacy.ml.clustering;
019
020import org.apache.commons.math4.legacy.exception.NullArgumentException;
021import org.apache.commons.math4.legacy.exception.ConvergenceException;
022import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
023import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
024import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
025import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
026import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
027import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
028import org.apache.commons.rng.UniformRandomProvider;
029import org.apache.commons.rng.simple.RandomSource;
030
031import java.util.ArrayList;
032import java.util.Collection;
033import java.util.Collections;
034import java.util.List;
035
036/**
037 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
038 * @param <T> type of the points to cluster
039 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
040 * @since 3.2
041 */
042public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
043
044    /** Strategies to use for replacing an empty cluster. */
045    public enum EmptyClusterStrategy {
046
047        /** Split the cluster with largest distance variance. */
048        LARGEST_VARIANCE,
049
050        /** Split the cluster with largest number of points. */
051        LARGEST_POINTS_NUMBER,
052
053        /** Create a cluster around the point farthest from its centroid. */
054        FARTHEST_POINT,
055
056        /** Generate an error. */
057        ERROR
058    }
059
060    /** The number of clusters. */
061    private final int numberOfClusters;
062
063    /** The maximum number of iterations. */
064    private final int maxIterations;
065
066    /** Random generator for choosing initial centers. */
067    private final UniformRandomProvider random;
068
069    /** Selected strategy for empty clusters. */
070    private final EmptyClusterStrategy emptyStrategy;
071
072    /** Build a clusterer.
073     * <p>
074     * The default strategy for handling empty clusters that may appear during
075     * algorithm iterations is to split the cluster with largest distance variance.
076     * <p>
077     * The euclidean distance will be used as default distance measure.
078     *
079     * @param k the number of clusters to split the data into
080     */
081    public KMeansPlusPlusClusterer(final int k) {
082        this(k, Integer.MAX_VALUE);
083    }
084
085    /** Build a clusterer.
086     * <p>
087     * The default strategy for handling empty clusters that may appear during
088     * algorithm iterations is to split the cluster with largest distance variance.
089     * <p>
090     * The euclidean distance will be used as default distance measure.
091     *
092     * @param k the number of clusters to split the data into
093     * @param maxIterations the maximum number of iterations to run the algorithm for.
094     *   If negative, no maximum will be used.
095     */
096    public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
097        this(k, maxIterations, new EuclideanDistance());
098    }
099
100    /** Build a clusterer.
101     * <p>
102     * The default strategy for handling empty clusters that may appear during
103     * algorithm iterations is to split the cluster with largest distance variance.
104     *
105     * @param k the number of clusters to split the data into
106     * @param maxIterations the maximum number of iterations to run the algorithm for.
107     * @param measure the distance measure to use
108     * @throws NotStrictlyPositiveException if {@code k <= 0}.
109     */
110    public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
111        this(k, maxIterations, measure, RandomSource.MT_64.create());
112    }
113
114    /** Build a clusterer.
115     * <p>
116     * The default strategy for handling empty clusters that may appear during
117     * algorithm iterations is to split the cluster with largest distance variance.
118     *
119     * @param k the number of clusters to split the data into
120     * @param maxIterations the maximum number of iterations to run the algorithm for.
121     *   If negative, no maximum will be used.
122     * @param measure the distance measure to use
123     * @param random random generator to use for choosing initial centers
124     */
125    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
126                                   final DistanceMeasure measure,
127                                   final UniformRandomProvider random) {
128        this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
129    }
130
131    /** Build a clusterer.
132     *
133     * @param k the number of clusters to split the data into
134     * @param maxIterations the maximum number of iterations to run the algorithm for.
135     * @param measure the distance measure to use
136     * @param random random generator to use for choosing initial centers
137     * @param emptyStrategy strategy to use for handling empty clusters that
138     * may appear during algorithm iterations
139     * @throws NotStrictlyPositiveException if {@code k <= 0} or
140     * {@code maxIterations <= 0}.
141     */
142    public KMeansPlusPlusClusterer(final int k,
143                                   final int maxIterations,
144                                   final DistanceMeasure measure,
145                                   final UniformRandomProvider random,
146                                   final EmptyClusterStrategy emptyStrategy) {
147        super(measure);
148
149        if (k <= 0) {
150            throw new NotStrictlyPositiveException(k);
151        }
152        if (maxIterations <= 0) {
153            throw new NotStrictlyPositiveException(maxIterations);
154        }
155
156        this.numberOfClusters = k;
157        this.maxIterations = maxIterations;
158        this.random = random;
159        this.emptyStrategy = emptyStrategy;
160    }
161
162    /**
163     * Return the number of clusters this instance will use.
164     * @return the number of clusters
165     */
166    public int getNumberOfClusters() {
167        return numberOfClusters;
168    }
169
170    /**
171     * Returns the maximum number of iterations this instance will use.
172     * @return the maximum number of iterations, or -1 if no maximum is set
173     */
174    public int getMaxIterations() {
175        return maxIterations;
176    }
177
178    /**
179     * Runs the K-means++ clustering algorithm.
180     *
181     * @param points the points to cluster
182     * @return a list of clusters containing the points
183     * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
184     * if the data points are null or the number of clusters is larger than the
185     * number of data points
186     * @throws ConvergenceException if an empty cluster is encountered and the
187     * empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR}
188     */
189    @Override
190    public List<CentroidCluster<T>> cluster(final Collection<T> points) {
191        // sanity checks
192        NullArgumentException.check(points);
193
194        // number of clusters has to be smaller or equal the number of data points
195        if (points.size() < numberOfClusters) {
196            throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
197        }
198
199        // create the initial clusters
200        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
201
202        // create an array containing the latest assignment of a point to a cluster
203        // no need to initialize the array, as it will be filled with the first assignment
204        int[] assignments = new int[points.size()];
205        assignPointsToClusters(clusters, points, assignments);
206
207        // iterate through updating the centers until we're done
208        for (int count = 0; count < maxIterations; count++) {
209            boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
210            List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
211            int changes = assignPointsToClusters(newClusters, points, assignments);
212            clusters = newClusters;
213
214            // if there were no more changes in the point-to-cluster assignment
215            // and there are no empty clusters left, return the current clusters
216            if (changes == 0 && !hasEmptyCluster) {
217                return clusters;
218            }
219        }
220        return clusters;
221    }
222
223    /**
224     * @return the random generator
225     */
226    UniformRandomProvider getRandomGenerator() {
227        return random;
228    }
229
230    /**
231     * @return the {@link EmptyClusterStrategy}
232     */
233    EmptyClusterStrategy getEmptyClusterStrategy() {
234        return emptyStrategy;
235    }
236
237    /**
238     * Adjust the clusters's centers with means of points.
239     * @param clusters the origin clusters
240     * @return adjusted clusters with center points
241     */
242    List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) {
243        List<CentroidCluster<T>> newClusters = new ArrayList<>();
244        for (final CentroidCluster<T> cluster : clusters) {
245            final Clusterable newCenter;
246            if (cluster.getPoints().isEmpty()) {
247                switch (emptyStrategy) {
248                    case LARGEST_VARIANCE :
249                        newCenter = getPointFromLargestVarianceCluster(clusters);
250                        break;
251                    case LARGEST_POINTS_NUMBER :
252                        newCenter = getPointFromLargestNumberCluster(clusters);
253                        break;
254                    case FARTHEST_POINT :
255                        newCenter = getFarthestPoint(clusters);
256                        break;
257                    default :
258                        throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
259                }
260            } else {
261                newCenter = cluster.centroid();
262            }
263            newClusters.add(new CentroidCluster<>(newCenter));
264        }
265        return newClusters;
266    }
267
268    /**
269     * Adds the given points to the closest {@link Cluster}.
270     *
271     * @param clusters the {@link Cluster}s to add the points to
272     * @param points the points to add to the given {@link Cluster}s
273     * @param assignments points assignments to clusters
274     * @return the number of points assigned to different clusters as the iteration before
275     */
276    private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
277                                       final Collection<T> points,
278                                       final int[] assignments) {
279        int assignedDifferently = 0;
280        int pointIndex = 0;
281        for (final T p : points) {
282            int clusterIndex = getNearestCluster(clusters, p);
283            if (clusterIndex != assignments[pointIndex]) {
284                assignedDifferently++;
285            }
286
287            CentroidCluster<T> cluster = clusters.get(clusterIndex);
288            cluster.addPoint(p);
289            assignments[pointIndex++] = clusterIndex;
290        }
291
292        return assignedDifferently;
293    }
294
295    /**
296     * Use K-means++ to choose the initial centers.
297     *
298     * @param points the points to choose the initial centers from
299     * @return the initial centers
300     */
301    List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
302
303        // Convert to list for indexed access. Make it unmodifiable, since removal of items
304        // would screw up the logic of this method.
305        final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));
306
307        // The number of points in the list.
308        final int numPoints = pointList.size();
309
310        // Set the corresponding element in this array to indicate when
311        // elements of pointList are no longer available.
312        final boolean[] taken = new boolean[numPoints];
313
314        // The resulting list of initial centers.
315        final List<CentroidCluster<T>> resultSet = new ArrayList<>();
316
317        // Choose one center uniformly at random from among the data points.
318        final int firstPointIndex = random.nextInt(numPoints);
319
320        final T firstPoint = pointList.get(firstPointIndex);
321
322        resultSet.add(new CentroidCluster<>(firstPoint));
323
324        // Must mark it as taken
325        taken[firstPointIndex] = true;
326
327        // To keep track of the minimum distance squared of elements of
328        // pointList to elements of resultSet.
329        final double[] minDistSquared = new double[numPoints];
330
331        // Initialize the elements.  Since the only point in resultSet is firstPoint,
332        // this is very easy.
333        for (int i = 0; i < numPoints; i++) {
334            if (i != firstPointIndex) { // That point isn't considered
335                double d = distance(firstPoint, pointList.get(i));
336                minDistSquared[i] = d*d;
337            }
338        }
339
340        while (resultSet.size() < numberOfClusters) {
341
342            // Sum up the squared distances for the points in pointList not
343            // already taken.
344            double distSqSum = 0.0;
345
346            for (int i = 0; i < numPoints; i++) {
347                if (!taken[i]) {
348                    distSqSum += minDistSquared[i];
349                }
350            }
351
352            // Add one new data point as a center. Each point x is chosen with
353            // probability proportional to D(x)2
354            final double r = random.nextDouble() * distSqSum;
355
356            // The index of the next point to be added to the resultSet.
357            int nextPointIndex = -1;
358
359            // Sum through the squared min distances again, stopping when
360            // sum >= r.
361            double sum = 0.0;
362            for (int i = 0; i < numPoints; i++) {
363                if (!taken[i]) {
364                    sum += minDistSquared[i];
365                    if (sum >= r) {
366                        nextPointIndex = i;
367                        break;
368                    }
369                }
370            }
371
372            // If it's not set to >= 0, the point wasn't found in the previous
373            // for loop, probably because distances are extremely small.  Just pick
374            // the last available point.
375            if (nextPointIndex == -1) {
376                for (int i = numPoints - 1; i >= 0; i--) {
377                    if (!taken[i]) {
378                        nextPointIndex = i;
379                        break;
380                    }
381                }
382            }
383
384            // We found one.
385            if (nextPointIndex >= 0) {
386
387                final T p = pointList.get(nextPointIndex);
388
389                resultSet.add(new CentroidCluster<T> (p));
390
391                // Mark it as taken.
392                taken[nextPointIndex] = true;
393
394                if (resultSet.size() < numberOfClusters) {
395                    // Now update elements of minDistSquared.  We only have to compute
396                    // the distance to the new center to do this.
397                    for (int j = 0; j < numPoints; j++) {
398                        // Only have to worry about the points still not taken.
399                        if (!taken[j]) {
400                            double d = distance(p, pointList.get(j));
401                            double d2 = d * d;
402                            if (d2 < minDistSquared[j]) {
403                                minDistSquared[j] = d2;
404                            }
405                        }
406                    }
407                }
408            } else {
409                // None found --
410                // Break from the while loop to prevent
411                // an infinite loop.
412                break;
413            }
414        }
415
416        return resultSet;
417    }
418
419    /**
420     * Get a random point from the {@link Cluster} with the largest distance variance.
421     *
422     * @param clusters the {@link Cluster}s to search
423     * @return a random point from the selected cluster
424     * @throws ConvergenceException if clusters are all empty
425     */
426    private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) {
427        double maxVariance = Double.NEGATIVE_INFINITY;
428        Cluster<T> selected = null;
429        for (final CentroidCluster<T> cluster : clusters) {
430            if (!cluster.getPoints().isEmpty()) {
431
432                // compute the distance variance of the current cluster
433                final Clusterable center = cluster.getCenter();
434                final Variance stat = new Variance();
435                for (final T point : cluster.getPoints()) {
436                    stat.increment(distance(point, center));
437                }
438                final double variance = stat.getResult();
439
440                // select the cluster with the largest variance
441                if (variance > maxVariance) {
442                    maxVariance = variance;
443                    selected = cluster;
444                }
445            }
446        }
447
448        // did we find at least one non-empty cluster ?
449        if (selected == null) {
450            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
451        }
452
453        // extract a random point from the cluster
454        final List<T> selectedPoints = selected.getPoints();
455        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
456    }
457
458    /**
459     * Get a random point from the {@link Cluster} with the largest number of points.
460     *
461     * @param clusters the {@link Cluster}s to search
462     * @return a random point from the selected cluster
463     * @throws ConvergenceException if clusters are all empty
464     */
465    private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) {
466        int maxNumber = 0;
467        Cluster<T> selected = null;
468        for (final Cluster<T> cluster : clusters) {
469
470            // get the number of points of the current cluster
471            final int number = cluster.getPoints().size();
472
473            // select the cluster with the largest number of points
474            if (number > maxNumber) {
475                maxNumber = number;
476                selected = cluster;
477            }
478        }
479
480        // did we find at least one non-empty cluster ?
481        if (selected == null) {
482            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
483        }
484
485        // extract a random point from the cluster
486        final List<T> selectedPoints = selected.getPoints();
487        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
488    }
489
490    /**
491     * Get the point farthest to its cluster center.
492     *
493     * @param clusters the {@link Cluster}s to search
494     * @return point farthest to its cluster center
495     * @throws ConvergenceException if clusters are all empty
496     */
497    private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) {
498        double maxDistance = Double.NEGATIVE_INFINITY;
499        Cluster<T> selectedCluster = null;
500        int selectedPoint = -1;
501        for (final CentroidCluster<T> cluster : clusters) {
502
503            // get the farthest point
504            final Clusterable center = cluster.getCenter();
505            final List<T> points = cluster.getPoints();
506            for (int i = 0; i < points.size(); ++i) {
507                final double distance = distance(points.get(i), center);
508                if (distance > maxDistance) {
509                    maxDistance     = distance;
510                    selectedCluster = cluster;
511                    selectedPoint   = i;
512                }
513            }
514        }
515
516        // did we find at least one non-empty cluster ?
517        if (selectedCluster == null) {
518            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
519        }
520
521        return selectedCluster.getPoints().remove(selectedPoint);
522    }
523
524    /**
525     * Returns the nearest {@link Cluster} to the given point.
526     *
527     * @param clusters the {@link Cluster}s to search
528     * @param point the point to find the nearest {@link Cluster} for
529     * @return the index of the nearest {@link Cluster} to the given point
530     */
531    private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
532        double minDistance = Double.MAX_VALUE;
533        int clusterIndex = 0;
534        int minCluster = 0;
535        for (final CentroidCluster<T> c : clusters) {
536            final double distance = distance(point, c.getCenter());
537            if (distance < minDistance) {
538                minDistance = distance;
539                minCluster = clusterIndex;
540            }
541            clusterIndex++;
542        }
543        return minCluster;
544    }
545}