001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.clustering;
019
020import java.util.ArrayList;
021import java.util.Collection;
022import java.util.Collections;
023import java.util.List;
024
025import org.apache.commons.math3.exception.ConvergenceException;
026import org.apache.commons.math3.exception.MathIllegalArgumentException;
027import org.apache.commons.math3.exception.NumberIsTooSmallException;
028import org.apache.commons.math3.exception.util.LocalizedFormats;
029import org.apache.commons.math3.ml.distance.DistanceMeasure;
030import org.apache.commons.math3.ml.distance.EuclideanDistance;
031import org.apache.commons.math3.random.JDKRandomGenerator;
032import org.apache.commons.math3.random.RandomGenerator;
033import org.apache.commons.math3.stat.descriptive.moment.Variance;
034import org.apache.commons.math3.util.MathUtils;
035
036/**
037 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
038 * @param <T> type of the points to cluster
039 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
040 * @version $Id: KMeansPlusPlusClusterer.EmptyClusterStrategy.html 885258 2013-11-03 02:46:49Z tn $
041 * @since 3.2
042 */
043public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
044
045    /** Strategies to use for replacing an empty cluster. */
046    public static enum EmptyClusterStrategy {
047
048        /** Split the cluster with largest distance variance. */
049        LARGEST_VARIANCE,
050
051        /** Split the cluster with largest number of points. */
052        LARGEST_POINTS_NUMBER,
053
054        /** Create a cluster around the point farthest from its centroid. */
055        FARTHEST_POINT,
056
057        /** Generate an error. */
058        ERROR
059
060    }
061
062    /** The number of clusters. */
063    private final int k;
064
065    /** The maximum number of iterations. */
066    private final int maxIterations;
067
068    /** Random generator for choosing initial centers. */
069    private final RandomGenerator random;
070
071    /** Selected strategy for empty clusters. */
072    private final EmptyClusterStrategy emptyStrategy;
073
074    /** Build a clusterer.
075     * <p>
076     * The default strategy for handling empty clusters that may appear during
077     * algorithm iterations is to split the cluster with largest distance variance.
078     * <p>
079     * The euclidean distance will be used as default distance measure.
080     *
081     * @param k the number of clusters to split the data into
082     */
083    public KMeansPlusPlusClusterer(final int k) {
084        this(k, -1);
085    }
086
087    /** Build a clusterer.
088     * <p>
089     * The default strategy for handling empty clusters that may appear during
090     * algorithm iterations is to split the cluster with largest distance variance.
091     * <p>
092     * The euclidean distance will be used as default distance measure.
093     *
094     * @param k the number of clusters to split the data into
095     * @param maxIterations the maximum number of iterations to run the algorithm for.
096     *   If negative, no maximum will be used.
097     */
098    public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
099        this(k, maxIterations, new EuclideanDistance());
100    }
101
102    /** Build a clusterer.
103     * <p>
104     * The default strategy for handling empty clusters that may appear during
105     * algorithm iterations is to split the cluster with largest distance variance.
106     *
107     * @param k the number of clusters to split the data into
108     * @param maxIterations the maximum number of iterations to run the algorithm for.
109     *   If negative, no maximum will be used.
110     * @param measure the distance measure to use
111     */
112    public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
113        this(k, maxIterations, measure, new JDKRandomGenerator());
114    }
115
116    /** Build a clusterer.
117     * <p>
118     * The default strategy for handling empty clusters that may appear during
119     * algorithm iterations is to split the cluster with largest distance variance.
120     *
121     * @param k the number of clusters to split the data into
122     * @param maxIterations the maximum number of iterations to run the algorithm for.
123     *   If negative, no maximum will be used.
124     * @param measure the distance measure to use
125     * @param random random generator to use for choosing initial centers
126     */
127    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
128                                   final DistanceMeasure measure,
129                                   final RandomGenerator random) {
130        this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
131    }
132
133    /** Build a clusterer.
134     *
135     * @param k the number of clusters to split the data into
136     * @param maxIterations the maximum number of iterations to run the algorithm for.
137     *   If negative, no maximum will be used.
138     * @param measure the distance measure to use
139     * @param random random generator to use for choosing initial centers
140     * @param emptyStrategy strategy to use for handling empty clusters that
141     * may appear during algorithm iterations
142     */
143    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
144                                   final DistanceMeasure measure,
145                                   final RandomGenerator random,
146                                   final EmptyClusterStrategy emptyStrategy) {
147        super(measure);
148        this.k             = k;
149        this.maxIterations = maxIterations;
150        this.random        = random;
151        this.emptyStrategy = emptyStrategy;
152    }
153
154    /**
155     * Return the number of clusters this instance will use.
156     * @return the number of clusters
157     */
158    public int getK() {
159        return k;
160    }
161
162    /**
163     * Returns the maximum number of iterations this instance will use.
164     * @return the maximum number of iterations, or -1 if no maximum is set
165     */
166    public int getMaxIterations() {
167        return maxIterations;
168    }
169
170    /**
171     * Returns the random generator this instance will use.
172     * @return the random generator
173     */
174    public RandomGenerator getRandomGenerator() {
175        return random;
176    }
177
178    /**
179     * Returns the {@link EmptyClusterStrategy} used by this instance.
180     * @return the {@link EmptyClusterStrategy}
181     */
182    public EmptyClusterStrategy getEmptyClusterStrategy() {
183        return emptyStrategy;
184    }
185
186    /**
187     * Runs the K-means++ clustering algorithm.
188     *
189     * @param points the points to cluster
190     * @return a list of clusters containing the points
191     * @throws MathIllegalArgumentException if the data points are null or the number
192     *     of clusters is larger than the number of data points
193     * @throws ConvergenceException if an empty cluster is encountered and the
194     * {@link #emptyStrategy} is set to {@code ERROR}
195     */
196    @Override
197    public List<CentroidCluster<T>> cluster(final Collection<T> points)
198        throws MathIllegalArgumentException, ConvergenceException {
199
200        // sanity checks
201        MathUtils.checkNotNull(points);
202
203        // number of clusters has to be smaller or equal the number of data points
204        if (points.size() < k) {
205            throw new NumberIsTooSmallException(points.size(), k, false);
206        }
207
208        // create the initial clusters
209        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
210
211        // create an array containing the latest assignment of a point to a cluster
212        // no need to initialize the array, as it will be filled with the first assignment
213        int[] assignments = new int[points.size()];
214        assignPointsToClusters(clusters, points, assignments);
215
216        // iterate through updating the centers until we're done
217        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
218        for (int count = 0; count < max; count++) {
219            boolean emptyCluster = false;
220            List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>();
221            for (final CentroidCluster<T> cluster : clusters) {
222                final Clusterable newCenter;
223                if (cluster.getPoints().isEmpty()) {
224                    switch (emptyStrategy) {
225                        case LARGEST_VARIANCE :
226                            newCenter = getPointFromLargestVarianceCluster(clusters);
227                            break;
228                        case LARGEST_POINTS_NUMBER :
229                            newCenter = getPointFromLargestNumberCluster(clusters);
230                            break;
231                        case FARTHEST_POINT :
232                            newCenter = getFarthestPoint(clusters);
233                            break;
234                        default :
235                            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
236                    }
237                    emptyCluster = true;
238                } else {
239                    newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
240                }
241                newClusters.add(new CentroidCluster<T>(newCenter));
242            }
243            int changes = assignPointsToClusters(newClusters, points, assignments);
244            clusters = newClusters;
245
246            // if there were no more changes in the point-to-cluster assignment
247            // and there are no empty clusters left, return the current clusters
248            if (changes == 0 && !emptyCluster) {
249                return clusters;
250            }
251        }
252        return clusters;
253    }
254
255    /**
256     * Adds the given points to the closest {@link Cluster}.
257     *
258     * @param clusters the {@link Cluster}s to add the points to
259     * @param points the points to add to the given {@link Cluster}s
260     * @param assignments points assignments to clusters
261     * @return the number of points assigned to different clusters as the iteration before
262     */
263    private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
264                                       final Collection<T> points,
265                                       final int[] assignments) {
266        int assignedDifferently = 0;
267        int pointIndex = 0;
268        for (final T p : points) {
269            int clusterIndex = getNearestCluster(clusters, p);
270            if (clusterIndex != assignments[pointIndex]) {
271                assignedDifferently++;
272            }
273
274            CentroidCluster<T> cluster = clusters.get(clusterIndex);
275            cluster.addPoint(p);
276            assignments[pointIndex++] = clusterIndex;
277        }
278
279        return assignedDifferently;
280    }
281
282    /**
283     * Use K-means++ to choose the initial centers.
284     *
285     * @param points the points to choose the initial centers from
286     * @return the initial centers
287     */
288    private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
289
290        // Convert to list for indexed access. Make it unmodifiable, since removal of items
291        // would screw up the logic of this method.
292        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
293
294        // The number of points in the list.
295        final int numPoints = pointList.size();
296
297        // Set the corresponding element in this array to indicate when
298        // elements of pointList are no longer available.
299        final boolean[] taken = new boolean[numPoints];
300
301        // The resulting list of initial centers.
302        final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>();
303
304        // Choose one center uniformly at random from among the data points.
305        final int firstPointIndex = random.nextInt(numPoints);
306
307        final T firstPoint = pointList.get(firstPointIndex);
308
309        resultSet.add(new CentroidCluster<T>(firstPoint));
310
311        // Must mark it as taken
312        taken[firstPointIndex] = true;
313
314        // To keep track of the minimum distance squared of elements of
315        // pointList to elements of resultSet.
316        final double[] minDistSquared = new double[numPoints];
317
318        // Initialize the elements.  Since the only point in resultSet is firstPoint,
319        // this is very easy.
320        for (int i = 0; i < numPoints; i++) {
321            if (i != firstPointIndex) { // That point isn't considered
322                double d = distance(firstPoint, pointList.get(i));
323                minDistSquared[i] = d*d;
324            }
325        }
326
327        while (resultSet.size() < k) {
328
329            // Sum up the squared distances for the points in pointList not
330            // already taken.
331            double distSqSum = 0.0;
332
333            for (int i = 0; i < numPoints; i++) {
334                if (!taken[i]) {
335                    distSqSum += minDistSquared[i];
336                }
337            }
338
339            // Add one new data point as a center. Each point x is chosen with
340            // probability proportional to D(x)2
341            final double r = random.nextDouble() * distSqSum;
342
343            // The index of the next point to be added to the resultSet.
344            int nextPointIndex = -1;
345
346            // Sum through the squared min distances again, stopping when
347            // sum >= r.
348            double sum = 0.0;
349            for (int i = 0; i < numPoints; i++) {
350                if (!taken[i]) {
351                    sum += minDistSquared[i];
352                    if (sum >= r) {
353                        nextPointIndex = i;
354                        break;
355                    }
356                }
357            }
358
359            // If it's not set to >= 0, the point wasn't found in the previous
360            // for loop, probably because distances are extremely small.  Just pick
361            // the last available point.
362            if (nextPointIndex == -1) {
363                for (int i = numPoints - 1; i >= 0; i--) {
364                    if (!taken[i]) {
365                        nextPointIndex = i;
366                        break;
367                    }
368                }
369            }
370
371            // We found one.
372            if (nextPointIndex >= 0) {
373
374                final T p = pointList.get(nextPointIndex);
375
376                resultSet.add(new CentroidCluster<T> (p));
377
378                // Mark it as taken.
379                taken[nextPointIndex] = true;
380
381                if (resultSet.size() < k) {
382                    // Now update elements of minDistSquared.  We only have to compute
383                    // the distance to the new center to do this.
384                    for (int j = 0; j < numPoints; j++) {
385                        // Only have to worry about the points still not taken.
386                        if (!taken[j]) {
387                            double d = distance(p, pointList.get(j));
388                            double d2 = d * d;
389                            if (d2 < minDistSquared[j]) {
390                                minDistSquared[j] = d2;
391                            }
392                        }
393                    }
394                }
395
396            } else {
397                // None found --
398                // Break from the while loop to prevent
399                // an infinite loop.
400                break;
401            }
402        }
403
404        return resultSet;
405    }
406
407    /**
408     * Get a random point from the {@link Cluster} with the largest distance variance.
409     *
410     * @param clusters the {@link Cluster}s to search
411     * @return a random point from the selected cluster
412     * @throws ConvergenceException if clusters are all empty
413     */
414    private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
415            throws ConvergenceException {
416
417        double maxVariance = Double.NEGATIVE_INFINITY;
418        Cluster<T> selected = null;
419        for (final CentroidCluster<T> cluster : clusters) {
420            if (!cluster.getPoints().isEmpty()) {
421
422                // compute the distance variance of the current cluster
423                final Clusterable center = cluster.getCenter();
424                final Variance stat = new Variance();
425                for (final T point : cluster.getPoints()) {
426                    stat.increment(distance(point, center));
427                }
428                final double variance = stat.getResult();
429
430                // select the cluster with the largest variance
431                if (variance > maxVariance) {
432                    maxVariance = variance;
433                    selected = cluster;
434                }
435
436            }
437        }
438
439        // did we find at least one non-empty cluster ?
440        if (selected == null) {
441            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
442        }
443
444        // extract a random point from the cluster
445        final List<T> selectedPoints = selected.getPoints();
446        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
447
448    }
449
450    /**
451     * Get a random point from the {@link Cluster} with the largest number of points
452     *
453     * @param clusters the {@link Cluster}s to search
454     * @return a random point from the selected cluster
455     * @throws ConvergenceException if clusters are all empty
456     */
457    private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
458            throws ConvergenceException {
459
460        int maxNumber = 0;
461        Cluster<T> selected = null;
462        for (final Cluster<T> cluster : clusters) {
463
464            // get the number of points of the current cluster
465            final int number = cluster.getPoints().size();
466
467            // select the cluster with the largest number of points
468            if (number > maxNumber) {
469                maxNumber = number;
470                selected = cluster;
471            }
472
473        }
474
475        // did we find at least one non-empty cluster ?
476        if (selected == null) {
477            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
478        }
479
480        // extract a random point from the cluster
481        final List<T> selectedPoints = selected.getPoints();
482        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
483
484    }
485
486    /**
487     * Get the point farthest to its cluster center
488     *
489     * @param clusters the {@link Cluster}s to search
490     * @return point farthest to its cluster center
491     * @throws ConvergenceException if clusters are all empty
492     */
493    private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException {
494
495        double maxDistance = Double.NEGATIVE_INFINITY;
496        Cluster<T> selectedCluster = null;
497        int selectedPoint = -1;
498        for (final CentroidCluster<T> cluster : clusters) {
499
500            // get the farthest point
501            final Clusterable center = cluster.getCenter();
502            final List<T> points = cluster.getPoints();
503            for (int i = 0; i < points.size(); ++i) {
504                final double distance = distance(points.get(i), center);
505                if (distance > maxDistance) {
506                    maxDistance     = distance;
507                    selectedCluster = cluster;
508                    selectedPoint   = i;
509                }
510            }
511
512        }
513
514        // did we find at least one non-empty cluster ?
515        if (selectedCluster == null) {
516            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
517        }
518
519        return selectedCluster.getPoints().remove(selectedPoint);
520
521    }
522
523    /**
524     * Returns the nearest {@link Cluster} to the given point
525     *
526     * @param clusters the {@link Cluster}s to search
527     * @param point the point to find the nearest {@link Cluster} for
528     * @return the index of the nearest {@link Cluster} to the given point
529     */
530    private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
531        double minDistance = Double.MAX_VALUE;
532        int clusterIndex = 0;
533        int minCluster = 0;
534        for (final CentroidCluster<T> c : clusters) {
535            final double distance = distance(point, c.getCenter());
536            if (distance < minDistance) {
537                minDistance = distance;
538                minCluster = clusterIndex;
539            }
540            clusterIndex++;
541        }
542        return minCluster;
543    }
544
545    /**
546     * Computes the centroid for a set of points.
547     *
548     * @param points the set of points
549     * @param dimension the point dimension
550     * @return the computed centroid for the set of points
551     */
552    private Clusterable centroidOf(final Collection<T> points, final int dimension) {
553        final double[] centroid = new double[dimension];
554        for (final T p : points) {
555            final double[] point = p.getPoint();
556            for (int i = 0; i < centroid.length; i++) {
557                centroid[i] += point[i];
558            }
559        }
560        for (int i = 0; i < centroid.length; i++) {
561            centroid[i] /= points.size();
562        }
563        return new DoublePoint(centroid);
564    }
565
566}