001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.stat.clustering;
019    
020    import java.util.ArrayList;
021    import java.util.Collection;
022    import java.util.Collections;
023    import java.util.List;
024    import java.util.Random;
025    
026    import org.apache.commons.math3.exception.ConvergenceException;
027    import org.apache.commons.math3.exception.MathIllegalArgumentException;
028    import org.apache.commons.math3.exception.NumberIsTooSmallException;
029    import org.apache.commons.math3.exception.util.LocalizedFormats;
030    import org.apache.commons.math3.stat.descriptive.moment.Variance;
031    import org.apache.commons.math3.util.MathUtils;
032    
033    /**
034     * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
035     * @param <T> type of the points to cluster
036     * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
037     * @version $Id: KMeansPlusPlusClusterer.java 1416643 2012-12-03 19:37:14Z tn $
038     * @since 2.0
039     */
040    public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
041    
042        /** Strategies to use for replacing an empty cluster. */
043        public static enum EmptyClusterStrategy {
044    
045            /** Split the cluster with largest distance variance. */
046            LARGEST_VARIANCE,
047    
048            /** Split the cluster with largest number of points. */
049            LARGEST_POINTS_NUMBER,
050    
051            /** Create a cluster around the point farthest from its centroid. */
052            FARTHEST_POINT,
053    
054            /** Generate an error. */
055            ERROR
056    
057        }
058    
059        /** Random generator for choosing initial centers. */
060        private final Random random;
061    
062        /** Selected strategy for empty clusters. */
063        private final EmptyClusterStrategy emptyStrategy;
064    
065        /** Build a clusterer.
066         * <p>
067         * The default strategy for handling empty clusters that may appear during
068         * algorithm iterations is to split the cluster with largest distance variance.
069         * </p>
070         * @param random random generator to use for choosing initial centers
071         */
072        public KMeansPlusPlusClusterer(final Random random) {
073            this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
074        }
075    
076        /** Build a clusterer.
077         * @param random random generator to use for choosing initial centers
078         * @param emptyStrategy strategy to use for handling empty clusters that
079         * may appear during algorithm iterations
080         * @since 2.2
081         */
082        public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
083            this.random        = random;
084            this.emptyStrategy = emptyStrategy;
085        }
086    
087        /**
088         * Runs the K-means++ clustering algorithm.
089         *
090         * @param points the points to cluster
091         * @param k the number of clusters to split the data into
092         * @param numTrials number of trial runs
093         * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
094         *     for at each trial run.  If negative, no maximum will be used
095         * @return a list of clusters containing the points
096         * @throws MathIllegalArgumentException if the data points are null or the number
097         *     of clusters is larger than the number of data points
098         * @throws ConvergenceException if an empty cluster is encountered and the
099         * {@link #emptyStrategy} is set to {@code ERROR}
100         */
101        public List<Cluster<T>> cluster(final Collection<T> points, final int k,
102                                        int numTrials, int maxIterationsPerTrial)
103            throws MathIllegalArgumentException, ConvergenceException {
104    
105            // at first, we have not found any clusters list yet
106            List<Cluster<T>> best = null;
107            double bestVarianceSum = Double.POSITIVE_INFINITY;
108    
109            // do several clustering trials
110            for (int i = 0; i < numTrials; ++i) {
111    
112                // compute a clusters list
113                List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
114    
115                // compute the variance of the current list
116                double varianceSum = 0.0;
117                for (final Cluster<T> cluster : clusters) {
118                    if (!cluster.getPoints().isEmpty()) {
119    
120                        // compute the distance variance of the current cluster
121                        final T center = cluster.getCenter();
122                        final Variance stat = new Variance();
123                        for (final T point : cluster.getPoints()) {
124                            stat.increment(point.distanceFrom(center));
125                        }
126                        varianceSum += stat.getResult();
127    
128                    }
129                }
130    
131                if (varianceSum <= bestVarianceSum) {
132                    // this one is the best we have found so far, remember it
133                    best            = clusters;
134                    bestVarianceSum = varianceSum;
135                }
136    
137            }
138    
139            // return the best clusters list found
140            return best;
141    
142        }
143    
144        /**
145         * Runs the K-means++ clustering algorithm.
146         *
147         * @param points the points to cluster
148         * @param k the number of clusters to split the data into
149         * @param maxIterations the maximum number of iterations to run the algorithm
150         *     for.  If negative, no maximum will be used
151         * @return a list of clusters containing the points
152         * @throws MathIllegalArgumentException if the data points are null or the number
153         *     of clusters is larger than the number of data points
154         * @throws ConvergenceException if an empty cluster is encountered and the
155         * {@link #emptyStrategy} is set to {@code ERROR}
156         */
157        public List<Cluster<T>> cluster(final Collection<T> points, final int k,
158                                        final int maxIterations)
159            throws MathIllegalArgumentException, ConvergenceException {
160    
161            // sanity checks
162            MathUtils.checkNotNull(points);
163    
164            // number of clusters has to be smaller or equal the number of data points
165            if (points.size() < k) {
166                throw new NumberIsTooSmallException(points.size(), k, false);
167            }
168    
169            // create the initial clusters
170            List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
171    
172            // create an array containing the latest assignment of a point to a cluster
173            // no need to initialize the array, as it will be filled with the first assignment
174            int[] assignments = new int[points.size()];
175            assignPointsToClusters(clusters, points, assignments);
176    
177            // iterate through updating the centers until we're done
178            final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
179            for (int count = 0; count < max; count++) {
180                boolean emptyCluster = false;
181                List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
182                for (final Cluster<T> cluster : clusters) {
183                    final T newCenter;
184                    if (cluster.getPoints().isEmpty()) {
185                        switch (emptyStrategy) {
186                            case LARGEST_VARIANCE :
187                                newCenter = getPointFromLargestVarianceCluster(clusters);
188                                break;
189                            case LARGEST_POINTS_NUMBER :
190                                newCenter = getPointFromLargestNumberCluster(clusters);
191                                break;
192                            case FARTHEST_POINT :
193                                newCenter = getFarthestPoint(clusters);
194                                break;
195                            default :
196                                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
197                        }
198                        emptyCluster = true;
199                    } else {
200                        newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
201                    }
202                    newClusters.add(new Cluster<T>(newCenter));
203                }
204                int changes = assignPointsToClusters(newClusters, points, assignments);
205                clusters = newClusters;
206    
207                // if there were no more changes in the point-to-cluster assignment
208                // and there are no empty clusters left, return the current clusters
209                if (changes == 0 && !emptyCluster) {
210                    return clusters;
211                }
212            }
213            return clusters;
214        }
215    
216        /**
217         * Adds the given points to the closest {@link Cluster}.
218         *
219         * @param <T> type of the points to cluster
220         * @param clusters the {@link Cluster}s to add the points to
221         * @param points the points to add to the given {@link Cluster}s
222         * @param assignments points assignments to clusters
223         * @return the number of points assigned to different clusters as the iteration before
224         */
225        private static <T extends Clusterable<T>> int
226            assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
227                                   final int[] assignments) {
228            int assignedDifferently = 0;
229            int pointIndex = 0;
230            for (final T p : points) {
231                int clusterIndex = getNearestCluster(clusters, p);
232                if (clusterIndex != assignments[pointIndex]) {
233                    assignedDifferently++;
234                }
235    
236                Cluster<T> cluster = clusters.get(clusterIndex);
237                cluster.addPoint(p);
238                assignments[pointIndex++] = clusterIndex;
239            }
240    
241            return assignedDifferently;
242        }
243    
244        /**
245         * Use K-means++ to choose the initial centers.
246         *
247         * @param <T> type of the points to cluster
248         * @param points the points to choose the initial centers from
249         * @param k the number of centers to choose
250         * @param random random generator to use
251         * @return the initial centers
252         */
253        private static <T extends Clusterable<T>> List<Cluster<T>>
254            chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
255    
256            // Convert to list for indexed access. Make it unmodifiable, since removal of items
257            // would screw up the logic of this method.
258            final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
259    
260            // The number of points in the list.
261            final int numPoints = pointList.size();
262    
263            // Set the corresponding element in this array to indicate when
264            // elements of pointList are no longer available.
265            final boolean[] taken = new boolean[numPoints];
266    
267            // The resulting list of initial centers.
268            final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
269    
270            // Choose one center uniformly at random from among the data points.
271            final int firstPointIndex = random.nextInt(numPoints);
272    
273            final T firstPoint = pointList.get(firstPointIndex);
274    
275            resultSet.add(new Cluster<T>(firstPoint));
276    
277            // Must mark it as taken
278            taken[firstPointIndex] = true;
279    
280            // To keep track of the minimum distance squared of elements of
281            // pointList to elements of resultSet.
282            final double[] minDistSquared = new double[numPoints];
283    
284            // Initialize the elements.  Since the only point in resultSet is firstPoint,
285            // this is very easy.
286            for (int i = 0; i < numPoints; i++) {
287                if (i != firstPointIndex) { // That point isn't considered
288                    double d = firstPoint.distanceFrom(pointList.get(i));
289                    minDistSquared[i] = d*d;
290                }
291            }
292    
293            while (resultSet.size() < k) {
294    
295                // Sum up the squared distances for the points in pointList not
296                // already taken.
297                double distSqSum = 0.0;
298    
299                for (int i = 0; i < numPoints; i++) {
300                    if (!taken[i]) {
301                        distSqSum += minDistSquared[i];
302                    }
303                }
304    
305                // Add one new data point as a center. Each point x is chosen with
306                // probability proportional to D(x)2
307                final double r = random.nextDouble() * distSqSum;
308    
309                // The index of the next point to be added to the resultSet.
310                int nextPointIndex = -1;
311    
312                // Sum through the squared min distances again, stopping when
313                // sum >= r.
314                double sum = 0.0;
315                for (int i = 0; i < numPoints; i++) {
316                    if (!taken[i]) {
317                        sum += minDistSquared[i];
318                        if (sum >= r) {
319                            nextPointIndex = i;
320                            break;
321                        }
322                    }
323                }
324    
325                // If it's not set to >= 0, the point wasn't found in the previous
326                // for loop, probably because distances are extremely small.  Just pick
327                // the last available point.
328                if (nextPointIndex == -1) {
329                    for (int i = numPoints - 1; i >= 0; i--) {
330                        if (!taken[i]) {
331                            nextPointIndex = i;
332                            break;
333                        }
334                    }
335                }
336    
337                // We found one.
338                if (nextPointIndex >= 0) {
339    
340                    final T p = pointList.get(nextPointIndex);
341    
342                    resultSet.add(new Cluster<T> (p));
343    
344                    // Mark it as taken.
345                    taken[nextPointIndex] = true;
346    
347                    if (resultSet.size() < k) {
348                        // Now update elements of minDistSquared.  We only have to compute
349                        // the distance to the new center to do this.
350                        for (int j = 0; j < numPoints; j++) {
351                            // Only have to worry about the points still not taken.
352                            if (!taken[j]) {
353                                double d = p.distanceFrom(pointList.get(j));
354                                double d2 = d * d;
355                                if (d2 < minDistSquared[j]) {
356                                    minDistSquared[j] = d2;
357                                }
358                            }
359                        }
360                    }
361    
362                } else {
363                    // None found --
364                    // Break from the while loop to prevent
365                    // an infinite loop.
366                    break;
367                }
368            }
369    
370            return resultSet;
371        }
372    
373        /**
374         * Get a random point from the {@link Cluster} with the largest distance variance.
375         *
376         * @param clusters the {@link Cluster}s to search
377         * @return a random point from the selected cluster
378         * @throws ConvergenceException if clusters are all empty
379         */
380        private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
381        throws ConvergenceException {
382    
383            double maxVariance = Double.NEGATIVE_INFINITY;
384            Cluster<T> selected = null;
385            for (final Cluster<T> cluster : clusters) {
386                if (!cluster.getPoints().isEmpty()) {
387    
388                    // compute the distance variance of the current cluster
389                    final T center = cluster.getCenter();
390                    final Variance stat = new Variance();
391                    for (final T point : cluster.getPoints()) {
392                        stat.increment(point.distanceFrom(center));
393                    }
394                    final double variance = stat.getResult();
395    
396                    // select the cluster with the largest variance
397                    if (variance > maxVariance) {
398                        maxVariance = variance;
399                        selected = cluster;
400                    }
401    
402                }
403            }
404    
405            // did we find at least one non-empty cluster ?
406            if (selected == null) {
407                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
408            }
409    
410            // extract a random point from the cluster
411            final List<T> selectedPoints = selected.getPoints();
412            return selectedPoints.remove(random.nextInt(selectedPoints.size()));
413    
414        }
415    
416        /**
417         * Get a random point from the {@link Cluster} with the largest number of points
418         *
419         * @param clusters the {@link Cluster}s to search
420         * @return a random point from the selected cluster
421         * @throws ConvergenceException if clusters are all empty
422         */
423        private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {
424    
425            int maxNumber = 0;
426            Cluster<T> selected = null;
427            for (final Cluster<T> cluster : clusters) {
428    
429                // get the number of points of the current cluster
430                final int number = cluster.getPoints().size();
431    
432                // select the cluster with the largest number of points
433                if (number > maxNumber) {
434                    maxNumber = number;
435                    selected = cluster;
436                }
437    
438            }
439    
440            // did we find at least one non-empty cluster ?
441            if (selected == null) {
442                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
443            }
444    
445            // extract a random point from the cluster
446            final List<T> selectedPoints = selected.getPoints();
447            return selectedPoints.remove(random.nextInt(selectedPoints.size()));
448    
449        }
450    
451        /**
452         * Get the point farthest to its cluster center
453         *
454         * @param clusters the {@link Cluster}s to search
455         * @return point farthest to its cluster center
456         * @throws ConvergenceException if clusters are all empty
457         */
458        private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {
459    
460            double maxDistance = Double.NEGATIVE_INFINITY;
461            Cluster<T> selectedCluster = null;
462            int selectedPoint = -1;
463            for (final Cluster<T> cluster : clusters) {
464    
465                // get the farthest point
466                final T center = cluster.getCenter();
467                final List<T> points = cluster.getPoints();
468                for (int i = 0; i < points.size(); ++i) {
469                    final double distance = points.get(i).distanceFrom(center);
470                    if (distance > maxDistance) {
471                        maxDistance     = distance;
472                        selectedCluster = cluster;
473                        selectedPoint   = i;
474                    }
475                }
476    
477            }
478    
479            // did we find at least one non-empty cluster ?
480            if (selectedCluster == null) {
481                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
482            }
483    
484            return selectedCluster.getPoints().remove(selectedPoint);
485    
486        }
487    
488        /**
489         * Returns the nearest {@link Cluster} to the given point
490         *
491         * @param <T> type of the points to cluster
492         * @param clusters the {@link Cluster}s to search
493         * @param point the point to find the nearest {@link Cluster} for
494         * @return the index of the nearest {@link Cluster} to the given point
495         */
496        private static <T extends Clusterable<T>> int
497            getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
498            double minDistance = Double.MAX_VALUE;
499            int clusterIndex = 0;
500            int minCluster = 0;
501            for (final Cluster<T> c : clusters) {
502                final double distance = point.distanceFrom(c.getCenter());
503                if (distance < minDistance) {
504                    minDistance = distance;
505                    minCluster = clusterIndex;
506                }
507                clusterIndex++;
508            }
509            return minCluster;
510        }
511    
512    }