001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.stat.clustering;
019    
020    import java.util.ArrayList;
021    import java.util.Collection;
022    import java.util.Collections;
023    import java.util.List;
024    import java.util.Random;
025    
026    import org.apache.commons.math.exception.ConvergenceException;
027    import org.apache.commons.math.exception.MathIllegalArgumentException;
028    import org.apache.commons.math.exception.NumberIsTooSmallException;
029    import org.apache.commons.math.exception.util.LocalizedFormats;
030    import org.apache.commons.math.stat.descriptive.moment.Variance;
031    import org.apache.commons.math.util.MathUtils;
032    
033    /**
034     * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
035     * @param <T> type of the points to cluster
036     * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
037     * @version $Id: KMeansPlusPlusClusterer.java 1139915 2011-06-26 19:13:06Z luc $
038     * @since 2.0
039     */
040    public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
041    
042        /** Strategies to use for replacing an empty cluster. */
043        public static enum EmptyClusterStrategy {
044    
045            /** Split the cluster with largest distance variance. */
046            LARGEST_VARIANCE,
047    
048            /** Split the cluster with largest number of points. */
049            LARGEST_POINTS_NUMBER,
050    
051            /** Create a cluster around the point farthest from its centroid. */
052            FARTHEST_POINT,
053    
054            /** Generate an error. */
055            ERROR
056    
057        }
058    
059        /** Random generator for choosing initial centers. */
060        private final Random random;
061    
062        /** Selected strategy for empty clusters. */
063        private final EmptyClusterStrategy emptyStrategy;
064    
065        /** Build a clusterer.
066         * <p>
067         * The default strategy for handling empty clusters that may appear during
068         * algorithm iterations is to split the cluster with largest distance variance.
069         * </p>
070         * @param random random generator to use for choosing initial centers
071         */
072        public KMeansPlusPlusClusterer(final Random random) {
073            this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
074        }
075    
076        /** Build a clusterer.
077         * @param random random generator to use for choosing initial centers
078         * @param emptyStrategy strategy to use for handling empty clusters that
079         * may appear during algorithm iterations
080         * @since 2.2
081         */
082        public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
083            this.random        = random;
084            this.emptyStrategy = emptyStrategy;
085        }
086    
087        /**
088         * Runs the K-means++ clustering algorithm.
089         *
090         * @param points the points to cluster
091         * @param k the number of clusters to split the data into
092         * @param numTrials number of trial runs
093         * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
094         *     for at each trial run.  If negative, no maximum will be used
095         * @return a list of clusters containing the points
096         * @throws MathIllegalArgumentException if the data points are null or the number
097         *     of clusters is larger than the number of data points
098         */
099        public List<Cluster<T>> cluster(final Collection<T> points, final int k,
100                                        int numTrials, int maxIterationsPerTrial)
101            throws MathIllegalArgumentException {
102    
103            // at first, we have not found any clusters list yet
104            List<Cluster<T>> best = null;
105            double bestVarianceSum = Double.POSITIVE_INFINITY;
106    
107            // do several clustering trials
108            for (int i = 0; i < numTrials; ++i) {
109    
110                // compute a clusters list
111                List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
112    
113                // compute the variance of the current list
114                double varianceSum = 0.0;
115                for (final Cluster<T> cluster : clusters) {
116                    if (!cluster.getPoints().isEmpty()) {
117    
118                        // compute the distance variance of the current cluster
119                        final T center = cluster.getCenter();
120                        final Variance stat = new Variance();
121                        for (final T point : cluster.getPoints()) {
122                            stat.increment(point.distanceFrom(center));
123                        }
124                        varianceSum += stat.getResult();
125    
126                    }
127                }
128    
129                if (varianceSum <= bestVarianceSum) {
130                    // this one is the best we have found so far, remember it
131                    best            = clusters;
132                    bestVarianceSum = varianceSum;
133                }
134    
135            }
136    
137            // return the best clusters list found
138            return best;
139    
140        }
141    
142        /**
143         * Runs the K-means++ clustering algorithm.
144         *
145         * @param points the points to cluster
146         * @param k the number of clusters to split the data into
147         * @param maxIterations the maximum number of iterations to run the algorithm
148         *     for.  If negative, no maximum will be used
149         * @return a list of clusters containing the points
150         * @throws MathIllegalArgumentException if the data points are null or the number
151         *     of clusters is larger than the number of data points
152         */
153        public List<Cluster<T>> cluster(final Collection<T> points, final int k,
154                                        final int maxIterations)
155            throws MathIllegalArgumentException {
156    
157            // sanity checks
158            MathUtils.checkNotNull(points);
159    
160            // number of clusters has to be smaller or equal the number of data points
161            if (points.size() < k) {
162                throw new NumberIsTooSmallException(points.size(), k, false);
163            }
164    
165            // create the initial clusters
166            List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
167    
168            // create an array containing the latest assignment of a point to a cluster
169            // no need to initialize the array, as it will be filled with the first assignment
170            int[] assignments = new int[points.size()];
171            assignPointsToClusters(clusters, points, assignments);
172    
173            // iterate through updating the centers until we're done
174            final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
175            for (int count = 0; count < max; count++) {
176                boolean emptyCluster = false;
177                List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
178                for (final Cluster<T> cluster : clusters) {
179                    final T newCenter;
180                    if (cluster.getPoints().isEmpty()) {
181                        switch (emptyStrategy) {
182                            case LARGEST_VARIANCE :
183                                newCenter = getPointFromLargestVarianceCluster(clusters);
184                                break;
185                            case LARGEST_POINTS_NUMBER :
186                                newCenter = getPointFromLargestNumberCluster(clusters);
187                                break;
188                            case FARTHEST_POINT :
189                                newCenter = getFarthestPoint(clusters);
190                                break;
191                            default :
192                                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
193                        }
194                        emptyCluster = true;
195                    } else {
196                        newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
197                    }
198                    newClusters.add(new Cluster<T>(newCenter));
199                }
200                int changes = assignPointsToClusters(newClusters, points, assignments);
201                clusters = newClusters;
202    
203                // if there were no more changes in the point-to-cluster assignment
204                // and there are no empty clusters left, return the current clusters
205                if (changes == 0 && !emptyCluster) {
206                    return clusters;
207                }
208            }
209            return clusters;
210        }
211    
212        /**
213         * Adds the given points to the closest {@link Cluster}.
214         *
215         * @param <T> type of the points to cluster
216         * @param clusters the {@link Cluster}s to add the points to
217         * @param points the points to add to the given {@link Cluster}s
218         * @param assignments points assignments to clusters
219         * @return the number of points assigned to different clusters as the iteration before
220         */
221        private static <T extends Clusterable<T>> int
222            assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
223                                   final int[] assignments) {
224            int assignedDifferently = 0;
225            int pointIndex = 0;
226            for (final T p : points) {
227                int clusterIndex = getNearestCluster(clusters, p);
228                if (clusterIndex != assignments[pointIndex]) {
229                    assignedDifferently++;
230                }
231    
232                Cluster<T> cluster = clusters.get(clusterIndex);
233                cluster.addPoint(p);
234                assignments[pointIndex++] = clusterIndex;
235            }
236    
237            return assignedDifferently;
238        }
239    
240        /**
241         * Use K-means++ to choose the initial centers.
242         *
243         * @param <T> type of the points to cluster
244         * @param points the points to choose the initial centers from
245         * @param k the number of centers to choose
246         * @param random random generator to use
247         * @return the initial centers
248         */
249        private static <T extends Clusterable<T>> List<Cluster<T>>
250            chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
251    
252            // Convert to list for indexed access. Make it unmodifiable, since removal of items
253            // would screw up the logic of this method.
254            final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
255    
256            // The number of points in the list.
257            final int numPoints = pointList.size();
258    
259            // Set the corresponding element in this array to indicate when
260            // elements of pointList are no longer available.
261            final boolean[] taken = new boolean[numPoints];
262    
263            // The resulting list of initial centers.
264            final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
265    
266            // Choose one center uniformly at random from among the data points.
267            final int firstPointIndex = random.nextInt(numPoints);
268    
269            final T firstPoint = pointList.get(firstPointIndex);
270    
271            resultSet.add(new Cluster<T>(firstPoint));
272    
273            // Must mark it as taken
274            taken[firstPointIndex] = true;
275    
276            // To keep track of the minimum distance squared of elements of
277            // pointList to elements of resultSet.
278            final double[] minDistSquared = new double[numPoints];
279    
280            // Initialize the elements.  Since the only point in resultSet is firstPoint,
281            // this is very easy.
282            for (int i = 0; i < numPoints; i++) {
283                if (i != firstPointIndex) { // That point isn't considered
284                    double d = firstPoint.distanceFrom(pointList.get(i));
285                    minDistSquared[i] = d*d;
286                }
287            }
288    
289            while (resultSet.size() < k) {
290    
291                // Sum up the squared distances for the points in pointList not
292                // already taken.
293                double distSqSum = 0.0;
294    
295                for (int i = 0; i < numPoints; i++) {
296                    if (!taken[i]) {
297                        distSqSum += minDistSquared[i];
298                    }
299                }
300    
301                // Add one new data point as a center. Each point x is chosen with
302                // probability proportional to D(x)2
303                final double r = random.nextDouble() * distSqSum;
304    
305                // The index of the next point to be added to the resultSet.
306                int nextPointIndex = -1;
307    
308                // Sum through the squared min distances again, stopping when
309                // sum >= r.
310                double sum = 0.0;
311                for (int i = 0; i < numPoints; i++) {
312                    if (!taken[i]) {
313                        sum += minDistSquared[i];
314                        if (sum >= r) {
315                            nextPointIndex = i;
316                            break;
317                        }
318                    }
319                }
320    
321                // If it's not set to >= 0, the point wasn't found in the previous
322                // for loop, probably because distances are extremely small.  Just pick
323                // the last available point.
324                if (nextPointIndex == -1) {
325                    for (int i = numPoints - 1; i >= 0; i--) {
326                        if (!taken[i]) {
327                            nextPointIndex = i;
328                            break;
329                        }
330                    }
331                }
332    
333                // We found one.
334                if (nextPointIndex >= 0) {
335    
336                    final T p = pointList.get(nextPointIndex);
337    
338                    resultSet.add(new Cluster<T> (p));
339    
340                    // Mark it as taken.
341                    taken[nextPointIndex] = true;
342    
343                    if (resultSet.size() < k) {
344                        // Now update elements of minDistSquared.  We only have to compute
345                        // the distance to the new center to do this.
346                        for (int j = 0; j < numPoints; j++) {
347                            // Only have to worry about the points still not taken.
348                            if (!taken[j]) {
349                                double d = p.distanceFrom(pointList.get(j));
350                                double d2 = d * d;
351                                if (d2 < minDistSquared[j]) {
352                                    minDistSquared[j] = d2;
353                                }
354                            }
355                        }
356                    }
357    
358                } else {
359                    // None found --
360                    // Break from the while loop to prevent
361                    // an infinite loop.
362                    break;
363                }
364            }
365    
366            return resultSet;
367        }
368    
369        /**
370         * Get a random point from the {@link Cluster} with the largest distance variance.
371         *
372         * @param clusters the {@link Cluster}s to search
373         * @return a random point from the selected cluster
374         */
375        private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
376    
377            double maxVariance = Double.NEGATIVE_INFINITY;
378            Cluster<T> selected = null;
379            for (final Cluster<T> cluster : clusters) {
380                if (!cluster.getPoints().isEmpty()) {
381    
382                    // compute the distance variance of the current cluster
383                    final T center = cluster.getCenter();
384                    final Variance stat = new Variance();
385                    for (final T point : cluster.getPoints()) {
386                        stat.increment(point.distanceFrom(center));
387                    }
388                    final double variance = stat.getResult();
389    
390                    // select the cluster with the largest variance
391                    if (variance > maxVariance) {
392                        maxVariance = variance;
393                        selected = cluster;
394                    }
395    
396                }
397            }
398    
399            // did we find at least one non-empty cluster ?
400            if (selected == null) {
401                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
402            }
403    
404            // extract a random point from the cluster
405            final List<T> selectedPoints = selected.getPoints();
406            return selectedPoints.remove(random.nextInt(selectedPoints.size()));
407    
408        }
409    
410        /**
411         * Get a random point from the {@link Cluster} with the largest number of points
412         *
413         * @param clusters the {@link Cluster}s to search
414         * @return a random point from the selected cluster
415         */
416        private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
417    
418            int maxNumber = 0;
419            Cluster<T> selected = null;
420            for (final Cluster<T> cluster : clusters) {
421    
422                // get the number of points of the current cluster
423                final int number = cluster.getPoints().size();
424    
425                // select the cluster with the largest number of points
426                if (number > maxNumber) {
427                    maxNumber = number;
428                    selected = cluster;
429                }
430    
431            }
432    
433            // did we find at least one non-empty cluster ?
434            if (selected == null) {
435                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
436            }
437    
438            // extract a random point from the cluster
439            final List<T> selectedPoints = selected.getPoints();
440            return selectedPoints.remove(random.nextInt(selectedPoints.size()));
441    
442        }
443    
444        /**
445         * Get the point farthest to its cluster center
446         *
447         * @param clusters the {@link Cluster}s to search
448         * @return point farthest to its cluster center
449         */
450        private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
451    
452            double maxDistance = Double.NEGATIVE_INFINITY;
453            Cluster<T> selectedCluster = null;
454            int selectedPoint = -1;
455            for (final Cluster<T> cluster : clusters) {
456    
457                // get the farthest point
458                final T center = cluster.getCenter();
459                final List<T> points = cluster.getPoints();
460                for (int i = 0; i < points.size(); ++i) {
461                    final double distance = points.get(i).distanceFrom(center);
462                    if (distance > maxDistance) {
463                        maxDistance     = distance;
464                        selectedCluster = cluster;
465                        selectedPoint   = i;
466                    }
467                }
468    
469            }
470    
471            // did we find at least one non-empty cluster ?
472            if (selectedCluster == null) {
473                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
474            }
475    
476            return selectedCluster.getPoints().remove(selectedPoint);
477    
478        }
479    
480        /**
481         * Returns the nearest {@link Cluster} to the given point
482         *
483         * @param <T> type of the points to cluster
484         * @param clusters the {@link Cluster}s to search
485         * @param point the point to find the nearest {@link Cluster} for
486         * @return the index of the nearest {@link Cluster} to the given point
487         */
488        private static <T extends Clusterable<T>> int
489            getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
490            double minDistance = Double.MAX_VALUE;
491            int clusterIndex = 0;
492            int minCluster = 0;
493            for (final Cluster<T> c : clusters) {
494                final double distance = point.distanceFrom(c.getCenter());
495                if (distance < minDistance) {
496                    minDistance = distance;
497                    minCluster = clusterIndex;
498                }
499                clusterIndex++;
500            }
501            return minCluster;
502        }
503    
504    }