001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.stat.clustering;
019    
020    import java.util.ArrayList;
021    import java.util.Collection;
022    import java.util.Collections;
023    import java.util.List;
024    import java.util.Random;
025    
026    import org.apache.commons.math3.exception.ConvergenceException;
027    import org.apache.commons.math3.exception.MathIllegalArgumentException;
028    import org.apache.commons.math3.exception.NumberIsTooSmallException;
029    import org.apache.commons.math3.exception.util.LocalizedFormats;
030    import org.apache.commons.math3.stat.descriptive.moment.Variance;
031    import org.apache.commons.math3.util.MathUtils;
032    
033    /**
034     * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
035     * @param <T> type of the points to cluster
036     * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
037     * @version $Id: KMeansPlusPlusClusterer.java 1461871 2013-03-27 22:01:25Z tn $
038     * @since 2.0
039     * @deprecated As of 3.2 (to be removed in 4.0),
040     * use {@link org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer} instead
041     */
042    @Deprecated
043    public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
044    
045        /** Strategies to use for replacing an empty cluster. */
046        public static enum EmptyClusterStrategy {
047    
048            /** Split the cluster with largest distance variance. */
049            LARGEST_VARIANCE,
050    
051            /** Split the cluster with largest number of points. */
052            LARGEST_POINTS_NUMBER,
053    
054            /** Create a cluster around the point farthest from its centroid. */
055            FARTHEST_POINT,
056    
057            /** Generate an error. */
058            ERROR
059    
060        }
061    
062        /** Random generator for choosing initial centers. */
063        private final Random random;
064    
065        /** Selected strategy for empty clusters. */
066        private final EmptyClusterStrategy emptyStrategy;
067    
068        /** Build a clusterer.
069         * <p>
070         * The default strategy for handling empty clusters that may appear during
071         * algorithm iterations is to split the cluster with largest distance variance.
072         * </p>
073         * @param random random generator to use for choosing initial centers
074         */
075        public KMeansPlusPlusClusterer(final Random random) {
076            this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
077        }
078    
079        /** Build a clusterer.
080         * @param random random generator to use for choosing initial centers
081         * @param emptyStrategy strategy to use for handling empty clusters that
082         * may appear during algorithm iterations
083         * @since 2.2
084         */
085        public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
086            this.random        = random;
087            this.emptyStrategy = emptyStrategy;
088        }
089    
090        /**
091         * Runs the K-means++ clustering algorithm.
092         *
093         * @param points the points to cluster
094         * @param k the number of clusters to split the data into
095         * @param numTrials number of trial runs
096         * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
097         *     for at each trial run.  If negative, no maximum will be used
098         * @return a list of clusters containing the points
099         * @throws MathIllegalArgumentException if the data points are null or the number
100         *     of clusters is larger than the number of data points
101         * @throws ConvergenceException if an empty cluster is encountered and the
102         * {@link #emptyStrategy} is set to {@code ERROR}
103         */
104        public List<Cluster<T>> cluster(final Collection<T> points, final int k,
105                                        int numTrials, int maxIterationsPerTrial)
106            throws MathIllegalArgumentException, ConvergenceException {
107    
108            // at first, we have not found any clusters list yet
109            List<Cluster<T>> best = null;
110            double bestVarianceSum = Double.POSITIVE_INFINITY;
111    
112            // do several clustering trials
113            for (int i = 0; i < numTrials; ++i) {
114    
115                // compute a clusters list
116                List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
117    
118                // compute the variance of the current list
119                double varianceSum = 0.0;
120                for (final Cluster<T> cluster : clusters) {
121                    if (!cluster.getPoints().isEmpty()) {
122    
123                        // compute the distance variance of the current cluster
124                        final T center = cluster.getCenter();
125                        final Variance stat = new Variance();
126                        for (final T point : cluster.getPoints()) {
127                            stat.increment(point.distanceFrom(center));
128                        }
129                        varianceSum += stat.getResult();
130    
131                    }
132                }
133    
134                if (varianceSum <= bestVarianceSum) {
135                    // this one is the best we have found so far, remember it
136                    best            = clusters;
137                    bestVarianceSum = varianceSum;
138                }
139    
140            }
141    
142            // return the best clusters list found
143            return best;
144    
145        }
146    
147        /**
148         * Runs the K-means++ clustering algorithm.
149         *
150         * @param points the points to cluster
151         * @param k the number of clusters to split the data into
152         * @param maxIterations the maximum number of iterations to run the algorithm
153         *     for.  If negative, no maximum will be used
154         * @return a list of clusters containing the points
155         * @throws MathIllegalArgumentException if the data points are null or the number
156         *     of clusters is larger than the number of data points
157         * @throws ConvergenceException if an empty cluster is encountered and the
158         * {@link #emptyStrategy} is set to {@code ERROR}
159         */
160        public List<Cluster<T>> cluster(final Collection<T> points, final int k,
161                                        final int maxIterations)
162            throws MathIllegalArgumentException, ConvergenceException {
163    
164            // sanity checks
165            MathUtils.checkNotNull(points);
166    
167            // number of clusters has to be smaller or equal the number of data points
168            if (points.size() < k) {
169                throw new NumberIsTooSmallException(points.size(), k, false);
170            }
171    
172            // create the initial clusters
173            List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
174    
175            // create an array containing the latest assignment of a point to a cluster
176            // no need to initialize the array, as it will be filled with the first assignment
177            int[] assignments = new int[points.size()];
178            assignPointsToClusters(clusters, points, assignments);
179    
180            // iterate through updating the centers until we're done
181            final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
182            for (int count = 0; count < max; count++) {
183                boolean emptyCluster = false;
184                List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
185                for (final Cluster<T> cluster : clusters) {
186                    final T newCenter;
187                    if (cluster.getPoints().isEmpty()) {
188                        switch (emptyStrategy) {
189                            case LARGEST_VARIANCE :
190                                newCenter = getPointFromLargestVarianceCluster(clusters);
191                                break;
192                            case LARGEST_POINTS_NUMBER :
193                                newCenter = getPointFromLargestNumberCluster(clusters);
194                                break;
195                            case FARTHEST_POINT :
196                                newCenter = getFarthestPoint(clusters);
197                                break;
198                            default :
199                                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
200                        }
201                        emptyCluster = true;
202                    } else {
203                        newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
204                    }
205                    newClusters.add(new Cluster<T>(newCenter));
206                }
207                int changes = assignPointsToClusters(newClusters, points, assignments);
208                clusters = newClusters;
209    
210                // if there were no more changes in the point-to-cluster assignment
211                // and there are no empty clusters left, return the current clusters
212                if (changes == 0 && !emptyCluster) {
213                    return clusters;
214                }
215            }
216            return clusters;
217        }
218    
219        /**
220         * Adds the given points to the closest {@link Cluster}.
221         *
222         * @param <T> type of the points to cluster
223         * @param clusters the {@link Cluster}s to add the points to
224         * @param points the points to add to the given {@link Cluster}s
225         * @param assignments points assignments to clusters
226         * @return the number of points assigned to different clusters as the iteration before
227         */
228        private static <T extends Clusterable<T>> int
229            assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
230                                   final int[] assignments) {
231            int assignedDifferently = 0;
232            int pointIndex = 0;
233            for (final T p : points) {
234                int clusterIndex = getNearestCluster(clusters, p);
235                if (clusterIndex != assignments[pointIndex]) {
236                    assignedDifferently++;
237                }
238    
239                Cluster<T> cluster = clusters.get(clusterIndex);
240                cluster.addPoint(p);
241                assignments[pointIndex++] = clusterIndex;
242            }
243    
244            return assignedDifferently;
245        }
246    
247        /**
248         * Use K-means++ to choose the initial centers.
249         *
250         * @param <T> type of the points to cluster
251         * @param points the points to choose the initial centers from
252         * @param k the number of centers to choose
253         * @param random random generator to use
254         * @return the initial centers
255         */
256        private static <T extends Clusterable<T>> List<Cluster<T>>
257            chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
258    
259            // Convert to list for indexed access. Make it unmodifiable, since removal of items
260            // would screw up the logic of this method.
261            final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
262    
263            // The number of points in the list.
264            final int numPoints = pointList.size();
265    
266            // Set the corresponding element in this array to indicate when
267            // elements of pointList are no longer available.
268            final boolean[] taken = new boolean[numPoints];
269    
270            // The resulting list of initial centers.
271            final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
272    
273            // Choose one center uniformly at random from among the data points.
274            final int firstPointIndex = random.nextInt(numPoints);
275    
276            final T firstPoint = pointList.get(firstPointIndex);
277    
278            resultSet.add(new Cluster<T>(firstPoint));
279    
280            // Must mark it as taken
281            taken[firstPointIndex] = true;
282    
283            // To keep track of the minimum distance squared of elements of
284            // pointList to elements of resultSet.
285            final double[] minDistSquared = new double[numPoints];
286    
287            // Initialize the elements.  Since the only point in resultSet is firstPoint,
288            // this is very easy.
289            for (int i = 0; i < numPoints; i++) {
290                if (i != firstPointIndex) { // That point isn't considered
291                    double d = firstPoint.distanceFrom(pointList.get(i));
292                    minDistSquared[i] = d*d;
293                }
294            }
295    
296            while (resultSet.size() < k) {
297    
298                // Sum up the squared distances for the points in pointList not
299                // already taken.
300                double distSqSum = 0.0;
301    
302                for (int i = 0; i < numPoints; i++) {
303                    if (!taken[i]) {
304                        distSqSum += minDistSquared[i];
305                    }
306                }
307    
308                // Add one new data point as a center. Each point x is chosen with
309                // probability proportional to D(x)2
310                final double r = random.nextDouble() * distSqSum;
311    
312                // The index of the next point to be added to the resultSet.
313                int nextPointIndex = -1;
314    
315                // Sum through the squared min distances again, stopping when
316                // sum >= r.
317                double sum = 0.0;
318                for (int i = 0; i < numPoints; i++) {
319                    if (!taken[i]) {
320                        sum += minDistSquared[i];
321                        if (sum >= r) {
322                            nextPointIndex = i;
323                            break;
324                        }
325                    }
326                }
327    
328                // If it's not set to >= 0, the point wasn't found in the previous
329                // for loop, probably because distances are extremely small.  Just pick
330                // the last available point.
331                if (nextPointIndex == -1) {
332                    for (int i = numPoints - 1; i >= 0; i--) {
333                        if (!taken[i]) {
334                            nextPointIndex = i;
335                            break;
336                        }
337                    }
338                }
339    
340                // We found one.
341                if (nextPointIndex >= 0) {
342    
343                    final T p = pointList.get(nextPointIndex);
344    
345                    resultSet.add(new Cluster<T> (p));
346    
347                    // Mark it as taken.
348                    taken[nextPointIndex] = true;
349    
350                    if (resultSet.size() < k) {
351                        // Now update elements of minDistSquared.  We only have to compute
352                        // the distance to the new center to do this.
353                        for (int j = 0; j < numPoints; j++) {
354                            // Only have to worry about the points still not taken.
355                            if (!taken[j]) {
356                                double d = p.distanceFrom(pointList.get(j));
357                                double d2 = d * d;
358                                if (d2 < minDistSquared[j]) {
359                                    minDistSquared[j] = d2;
360                                }
361                            }
362                        }
363                    }
364    
365                } else {
366                    // None found --
367                    // Break from the while loop to prevent
368                    // an infinite loop.
369                    break;
370                }
371            }
372    
373            return resultSet;
374        }
375    
376        /**
377         * Get a random point from the {@link Cluster} with the largest distance variance.
378         *
379         * @param clusters the {@link Cluster}s to search
380         * @return a random point from the selected cluster
381         * @throws ConvergenceException if clusters are all empty
382         */
383        private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
384        throws ConvergenceException {
385    
386            double maxVariance = Double.NEGATIVE_INFINITY;
387            Cluster<T> selected = null;
388            for (final Cluster<T> cluster : clusters) {
389                if (!cluster.getPoints().isEmpty()) {
390    
391                    // compute the distance variance of the current cluster
392                    final T center = cluster.getCenter();
393                    final Variance stat = new Variance();
394                    for (final T point : cluster.getPoints()) {
395                        stat.increment(point.distanceFrom(center));
396                    }
397                    final double variance = stat.getResult();
398    
399                    // select the cluster with the largest variance
400                    if (variance > maxVariance) {
401                        maxVariance = variance;
402                        selected = cluster;
403                    }
404    
405                }
406            }
407    
408            // did we find at least one non-empty cluster ?
409            if (selected == null) {
410                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
411            }
412    
413            // extract a random point from the cluster
414            final List<T> selectedPoints = selected.getPoints();
415            return selectedPoints.remove(random.nextInt(selectedPoints.size()));
416    
417        }
418    
419        /**
420         * Get a random point from the {@link Cluster} with the largest number of points
421         *
422         * @param clusters the {@link Cluster}s to search
423         * @return a random point from the selected cluster
424         * @throws ConvergenceException if clusters are all empty
425         */
426        private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {
427    
428            int maxNumber = 0;
429            Cluster<T> selected = null;
430            for (final Cluster<T> cluster : clusters) {
431    
432                // get the number of points of the current cluster
433                final int number = cluster.getPoints().size();
434    
435                // select the cluster with the largest number of points
436                if (number > maxNumber) {
437                    maxNumber = number;
438                    selected = cluster;
439                }
440    
441            }
442    
443            // did we find at least one non-empty cluster ?
444            if (selected == null) {
445                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
446            }
447    
448            // extract a random point from the cluster
449            final List<T> selectedPoints = selected.getPoints();
450            return selectedPoints.remove(random.nextInt(selectedPoints.size()));
451    
452        }
453    
454        /**
455         * Get the point farthest to its cluster center
456         *
457         * @param clusters the {@link Cluster}s to search
458         * @return point farthest to its cluster center
459         * @throws ConvergenceException if clusters are all empty
460         */
461        private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {
462    
463            double maxDistance = Double.NEGATIVE_INFINITY;
464            Cluster<T> selectedCluster = null;
465            int selectedPoint = -1;
466            for (final Cluster<T> cluster : clusters) {
467    
468                // get the farthest point
469                final T center = cluster.getCenter();
470                final List<T> points = cluster.getPoints();
471                for (int i = 0; i < points.size(); ++i) {
472                    final double distance = points.get(i).distanceFrom(center);
473                    if (distance > maxDistance) {
474                        maxDistance     = distance;
475                        selectedCluster = cluster;
476                        selectedPoint   = i;
477                    }
478                }
479    
480            }
481    
482            // did we find at least one non-empty cluster ?
483            if (selectedCluster == null) {
484                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
485            }
486    
487            return selectedCluster.getPoints().remove(selectedPoint);
488    
489        }
490    
491        /**
492         * Returns the nearest {@link Cluster} to the given point
493         *
494         * @param <T> type of the points to cluster
495         * @param clusters the {@link Cluster}s to search
496         * @param point the point to find the nearest {@link Cluster} for
497         * @return the index of the nearest {@link Cluster} to the given point
498         */
499        private static <T extends Clusterable<T>> int
500            getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
501            double minDistance = Double.MAX_VALUE;
502            int clusterIndex = 0;
503            int minCluster = 0;
504            for (final Cluster<T> c : clusters) {
505                final double distance = point.distanceFrom(c.getCenter());
506                if (distance < minDistance) {
507                    minDistance = distance;
508                    minCluster = clusterIndex;
509                }
510                clusterIndex++;
511            }
512            return minCluster;
513        }
514    
515    }