001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.clustering;
019
020import java.util.ArrayList;
021import java.util.Collection;
022import java.util.Collections;
023import java.util.List;
024
025import org.apache.commons.math3.exception.ConvergenceException;
026import org.apache.commons.math3.exception.MathIllegalArgumentException;
027import org.apache.commons.math3.exception.NumberIsTooSmallException;
028import org.apache.commons.math3.exception.util.LocalizedFormats;
029import org.apache.commons.math3.ml.distance.DistanceMeasure;
030import org.apache.commons.math3.ml.distance.EuclideanDistance;
031import org.apache.commons.math3.random.JDKRandomGenerator;
032import org.apache.commons.math3.random.RandomGenerator;
033import org.apache.commons.math3.stat.descriptive.moment.Variance;
034import org.apache.commons.math3.util.MathUtils;
035
036/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitskii k-means++ algorithm.
038 * @param <T> type of the points to cluster
039 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
040 * @since 3.2
041 */
042public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
043
044    /** Strategies to use for replacing an empty cluster. */
045    public enum EmptyClusterStrategy {
046
047        /** Split the cluster with largest distance variance. */
048        LARGEST_VARIANCE,
049
050        /** Split the cluster with largest number of points. */
051        LARGEST_POINTS_NUMBER,
052
053        /** Create a cluster around the point farthest from its centroid. */
054        FARTHEST_POINT,
055
056        /** Generate an error. */
057        ERROR
058
059    }
060
061    /** The number of clusters. */
062    private final int k;
063
064    /** The maximum number of iterations. */
065    private final int maxIterations;
066
067    /** Random generator for choosing initial centers. */
068    private final RandomGenerator random;
069
070    /** Selected strategy for empty clusters. */
071    private final EmptyClusterStrategy emptyStrategy;
072
073    /** Build a clusterer.
074     * <p>
075     * The default strategy for handling empty clusters that may appear during
076     * algorithm iterations is to split the cluster with largest distance variance.
077     * <p>
078     * The euclidean distance will be used as default distance measure.
079     *
080     * @param k the number of clusters to split the data into
081     */
082    public KMeansPlusPlusClusterer(final int k) {
083        this(k, -1);
084    }
085
086    /** Build a clusterer.
087     * <p>
088     * The default strategy for handling empty clusters that may appear during
089     * algorithm iterations is to split the cluster with largest distance variance.
090     * <p>
091     * The euclidean distance will be used as default distance measure.
092     *
093     * @param k the number of clusters to split the data into
094     * @param maxIterations the maximum number of iterations to run the algorithm for.
095     *   If negative, no maximum will be used.
096     */
097    public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
098        this(k, maxIterations, new EuclideanDistance());
099    }
100
101    /** Build a clusterer.
102     * <p>
103     * The default strategy for handling empty clusters that may appear during
104     * algorithm iterations is to split the cluster with largest distance variance.
105     *
106     * @param k the number of clusters to split the data into
107     * @param maxIterations the maximum number of iterations to run the algorithm for.
108     *   If negative, no maximum will be used.
109     * @param measure the distance measure to use
110     */
111    public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
112        this(k, maxIterations, measure, new JDKRandomGenerator());
113    }
114
115    /** Build a clusterer.
116     * <p>
117     * The default strategy for handling empty clusters that may appear during
118     * algorithm iterations is to split the cluster with largest distance variance.
119     *
120     * @param k the number of clusters to split the data into
121     * @param maxIterations the maximum number of iterations to run the algorithm for.
122     *   If negative, no maximum will be used.
123     * @param measure the distance measure to use
124     * @param random random generator to use for choosing initial centers
125     */
126    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
127                                   final DistanceMeasure measure,
128                                   final RandomGenerator random) {
129        this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
130    }
131
132    /** Build a clusterer.
133     *
134     * @param k the number of clusters to split the data into
135     * @param maxIterations the maximum number of iterations to run the algorithm for.
136     *   If negative, no maximum will be used.
137     * @param measure the distance measure to use
138     * @param random random generator to use for choosing initial centers
139     * @param emptyStrategy strategy to use for handling empty clusters that
140     * may appear during algorithm iterations
141     */
142    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
143                                   final DistanceMeasure measure,
144                                   final RandomGenerator random,
145                                   final EmptyClusterStrategy emptyStrategy) {
146        super(measure);
147        this.k             = k;
148        this.maxIterations = maxIterations;
149        this.random        = random;
150        this.emptyStrategy = emptyStrategy;
151    }
152
153    /**
154     * Return the number of clusters this instance will use.
155     * @return the number of clusters
156     */
157    public int getK() {
158        return k;
159    }
160
161    /**
162     * Returns the maximum number of iterations this instance will use.
163     * @return the maximum number of iterations, or -1 if no maximum is set
164     */
165    public int getMaxIterations() {
166        return maxIterations;
167    }
168
169    /**
170     * Returns the random generator this instance will use.
171     * @return the random generator
172     */
173    public RandomGenerator getRandomGenerator() {
174        return random;
175    }
176
177    /**
178     * Returns the {@link EmptyClusterStrategy} used by this instance.
179     * @return the {@link EmptyClusterStrategy}
180     */
181    public EmptyClusterStrategy getEmptyClusterStrategy() {
182        return emptyStrategy;
183    }
184
185    /**
186     * Runs the K-means++ clustering algorithm.
187     *
188     * @param points the points to cluster
189     * @return a list of clusters containing the points
190     * @throws MathIllegalArgumentException if the data points are null or the number
191     *     of clusters is larger than the number of data points
192     * @throws ConvergenceException if an empty cluster is encountered and the
193     * {@link #emptyStrategy} is set to {@code ERROR}
194     */
195    @Override
196    public List<CentroidCluster<T>> cluster(final Collection<T> points)
197        throws MathIllegalArgumentException, ConvergenceException {
198
199        // sanity checks
200        MathUtils.checkNotNull(points);
201
202        // number of clusters has to be smaller or equal the number of data points
203        if (points.size() < k) {
204            throw new NumberIsTooSmallException(points.size(), k, false);
205        }
206
207        // create the initial clusters
208        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
209
210        // create an array containing the latest assignment of a point to a cluster
211        // no need to initialize the array, as it will be filled with the first assignment
212        int[] assignments = new int[points.size()];
213        assignPointsToClusters(clusters, points, assignments);
214
215        // iterate through updating the centers until we're done
216        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
217        for (int count = 0; count < max; count++) {
218            boolean emptyCluster = false;
219            List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>();
220            for (final CentroidCluster<T> cluster : clusters) {
221                final Clusterable newCenter;
222                if (cluster.getPoints().isEmpty()) {
223                    switch (emptyStrategy) {
224                        case LARGEST_VARIANCE :
225                            newCenter = getPointFromLargestVarianceCluster(clusters);
226                            break;
227                        case LARGEST_POINTS_NUMBER :
228                            newCenter = getPointFromLargestNumberCluster(clusters);
229                            break;
230                        case FARTHEST_POINT :
231                            newCenter = getFarthestPoint(clusters);
232                            break;
233                        default :
234                            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
235                    }
236                    emptyCluster = true;
237                } else {
238                    newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
239                }
240                newClusters.add(new CentroidCluster<T>(newCenter));
241            }
242            int changes = assignPointsToClusters(newClusters, points, assignments);
243            clusters = newClusters;
244
245            // if there were no more changes in the point-to-cluster assignment
246            // and there are no empty clusters left, return the current clusters
247            if (changes == 0 && !emptyCluster) {
248                return clusters;
249            }
250        }
251        return clusters;
252    }
253
254    /**
255     * Adds the given points to the closest {@link Cluster}.
256     *
257     * @param clusters the {@link Cluster}s to add the points to
258     * @param points the points to add to the given {@link Cluster}s
259     * @param assignments points assignments to clusters
260     * @return the number of points assigned to different clusters as the iteration before
261     */
262    private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
263                                       final Collection<T> points,
264                                       final int[] assignments) {
265        int assignedDifferently = 0;
266        int pointIndex = 0;
267        for (final T p : points) {
268            int clusterIndex = getNearestCluster(clusters, p);
269            if (clusterIndex != assignments[pointIndex]) {
270                assignedDifferently++;
271            }
272
273            CentroidCluster<T> cluster = clusters.get(clusterIndex);
274            cluster.addPoint(p);
275            assignments[pointIndex++] = clusterIndex;
276        }
277
278        return assignedDifferently;
279    }
280
281    /**
282     * Use K-means++ to choose the initial centers.
283     *
284     * @param points the points to choose the initial centers from
285     * @return the initial centers
286     */
287    private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
288
289        // Convert to list for indexed access. Make it unmodifiable, since removal of items
290        // would screw up the logic of this method.
291        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
292
293        // The number of points in the list.
294        final int numPoints = pointList.size();
295
296        // Set the corresponding element in this array to indicate when
297        // elements of pointList are no longer available.
298        final boolean[] taken = new boolean[numPoints];
299
300        // The resulting list of initial centers.
301        final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>();
302
303        // Choose one center uniformly at random from among the data points.
304        final int firstPointIndex = random.nextInt(numPoints);
305
306        final T firstPoint = pointList.get(firstPointIndex);
307
308        resultSet.add(new CentroidCluster<T>(firstPoint));
309
310        // Must mark it as taken
311        taken[firstPointIndex] = true;
312
313        // To keep track of the minimum distance squared of elements of
314        // pointList to elements of resultSet.
315        final double[] minDistSquared = new double[numPoints];
316
317        // Initialize the elements.  Since the only point in resultSet is firstPoint,
318        // this is very easy.
319        for (int i = 0; i < numPoints; i++) {
320            if (i != firstPointIndex) { // That point isn't considered
321                double d = distance(firstPoint, pointList.get(i));
322                minDistSquared[i] = d*d;
323            }
324        }
325
326        while (resultSet.size() < k) {
327
328            // Sum up the squared distances for the points in pointList not
329            // already taken.
330            double distSqSum = 0.0;
331
332            for (int i = 0; i < numPoints; i++) {
333                if (!taken[i]) {
334                    distSqSum += minDistSquared[i];
335                }
336            }
337
338            // Add one new data point as a center. Each point x is chosen with
339            // probability proportional to D(x)2
340            final double r = random.nextDouble() * distSqSum;
341
342            // The index of the next point to be added to the resultSet.
343            int nextPointIndex = -1;
344
345            // Sum through the squared min distances again, stopping when
346            // sum >= r.
347            double sum = 0.0;
348            for (int i = 0; i < numPoints; i++) {
349                if (!taken[i]) {
350                    sum += minDistSquared[i];
351                    if (sum >= r) {
352                        nextPointIndex = i;
353                        break;
354                    }
355                }
356            }
357
358            // If it's not set to >= 0, the point wasn't found in the previous
359            // for loop, probably because distances are extremely small.  Just pick
360            // the last available point.
361            if (nextPointIndex == -1) {
362                for (int i = numPoints - 1; i >= 0; i--) {
363                    if (!taken[i]) {
364                        nextPointIndex = i;
365                        break;
366                    }
367                }
368            }
369
370            // We found one.
371            if (nextPointIndex >= 0) {
372
373                final T p = pointList.get(nextPointIndex);
374
375                resultSet.add(new CentroidCluster<T> (p));
376
377                // Mark it as taken.
378                taken[nextPointIndex] = true;
379
380                if (resultSet.size() < k) {
381                    // Now update elements of minDistSquared.  We only have to compute
382                    // the distance to the new center to do this.
383                    for (int j = 0; j < numPoints; j++) {
384                        // Only have to worry about the points still not taken.
385                        if (!taken[j]) {
386                            double d = distance(p, pointList.get(j));
387                            double d2 = d * d;
388                            if (d2 < minDistSquared[j]) {
389                                minDistSquared[j] = d2;
390                            }
391                        }
392                    }
393                }
394
395            } else {
396                // None found --
397                // Break from the while loop to prevent
398                // an infinite loop.
399                break;
400            }
401        }
402
403        return resultSet;
404    }
405
406    /**
407     * Get a random point from the {@link Cluster} with the largest distance variance.
408     *
409     * @param clusters the {@link Cluster}s to search
410     * @return a random point from the selected cluster
411     * @throws ConvergenceException if clusters are all empty
412     */
413    private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
414            throws ConvergenceException {
415
416        double maxVariance = Double.NEGATIVE_INFINITY;
417        Cluster<T> selected = null;
418        for (final CentroidCluster<T> cluster : clusters) {
419            if (!cluster.getPoints().isEmpty()) {
420
421                // compute the distance variance of the current cluster
422                final Clusterable center = cluster.getCenter();
423                final Variance stat = new Variance();
424                for (final T point : cluster.getPoints()) {
425                    stat.increment(distance(point, center));
426                }
427                final double variance = stat.getResult();
428
429                // select the cluster with the largest variance
430                if (variance > maxVariance) {
431                    maxVariance = variance;
432                    selected = cluster;
433                }
434
435            }
436        }
437
438        // did we find at least one non-empty cluster ?
439        if (selected == null) {
440            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
441        }
442
443        // extract a random point from the cluster
444        final List<T> selectedPoints = selected.getPoints();
445        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
446
447    }
448
449    /**
450     * Get a random point from the {@link Cluster} with the largest number of points
451     *
452     * @param clusters the {@link Cluster}s to search
453     * @return a random point from the selected cluster
454     * @throws ConvergenceException if clusters are all empty
455     */
456    private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
457            throws ConvergenceException {
458
459        int maxNumber = 0;
460        Cluster<T> selected = null;
461        for (final Cluster<T> cluster : clusters) {
462
463            // get the number of points of the current cluster
464            final int number = cluster.getPoints().size();
465
466            // select the cluster with the largest number of points
467            if (number > maxNumber) {
468                maxNumber = number;
469                selected = cluster;
470            }
471
472        }
473
474        // did we find at least one non-empty cluster ?
475        if (selected == null) {
476            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
477        }
478
479        // extract a random point from the cluster
480        final List<T> selectedPoints = selected.getPoints();
481        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
482
483    }
484
485    /**
486     * Get the point farthest to its cluster center
487     *
488     * @param clusters the {@link Cluster}s to search
489     * @return point farthest to its cluster center
490     * @throws ConvergenceException if clusters are all empty
491     */
492    private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException {
493
494        double maxDistance = Double.NEGATIVE_INFINITY;
495        Cluster<T> selectedCluster = null;
496        int selectedPoint = -1;
497        for (final CentroidCluster<T> cluster : clusters) {
498
499            // get the farthest point
500            final Clusterable center = cluster.getCenter();
501            final List<T> points = cluster.getPoints();
502            for (int i = 0; i < points.size(); ++i) {
503                final double distance = distance(points.get(i), center);
504                if (distance > maxDistance) {
505                    maxDistance     = distance;
506                    selectedCluster = cluster;
507                    selectedPoint   = i;
508                }
509            }
510
511        }
512
513        // did we find at least one non-empty cluster ?
514        if (selectedCluster == null) {
515            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
516        }
517
518        return selectedCluster.getPoints().remove(selectedPoint);
519
520    }
521
522    /**
523     * Returns the nearest {@link Cluster} to the given point
524     *
525     * @param clusters the {@link Cluster}s to search
526     * @param point the point to find the nearest {@link Cluster} for
527     * @return the index of the nearest {@link Cluster} to the given point
528     */
529    private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
530        double minDistance = Double.MAX_VALUE;
531        int clusterIndex = 0;
532        int minCluster = 0;
533        for (final CentroidCluster<T> c : clusters) {
534            final double distance = distance(point, c.getCenter());
535            if (distance < minDistance) {
536                minDistance = distance;
537                minCluster = clusterIndex;
538            }
539            clusterIndex++;
540        }
541        return minCluster;
542    }
543
544    /**
545     * Computes the centroid for a set of points.
546     *
547     * @param points the set of points
548     * @param dimension the point dimension
549     * @return the computed centroid for the set of points
550     */
551    private Clusterable centroidOf(final Collection<T> points, final int dimension) {
552        final double[] centroid = new double[dimension];
553        for (final T p : points) {
554            final double[] point = p.getPoint();
555            for (int i = 0; i < centroid.length; i++) {
556                centroid[i] += point[i];
557            }
558        }
559        for (int i = 0; i < centroid.length; i++) {
560            centroid[i] /= points.size();
561        }
562        return new DoublePoint(centroid);
563    }
564
565}