001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.clustering;
019
020import java.util.Collection;
021import java.util.List;
022
023import org.apache.commons.math3.exception.ConvergenceException;
024import org.apache.commons.math3.exception.MathIllegalArgumentException;
025import org.apache.commons.math3.ml.clustering.evaluation.ClusterEvaluator;
026import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;
027
028/**
029 * A wrapper around a k-means++ clustering algorithm which performs multiple trials
030 * and returns the best solution.
031 * @param <T> type of the points to cluster
032 * @version $Id: MultiKMeansPlusPlusClusterer.html 908881 2014-05-15 07:10:28Z luc $
033 * @since 3.2
034 */
035public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
036
037    /** The underlying k-means clusterer. */
038    private final KMeansPlusPlusClusterer<T> clusterer;
039
040    /** The number of trial runs. */
041    private final int numTrials;
042
043    /** The cluster evaluator to use. */
044    private final ClusterEvaluator<T> evaluator;
045
046    /** Build a clusterer.
047     * @param clusterer the k-means clusterer to use
048     * @param numTrials number of trial runs
049     */
050    public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
051                                        final int numTrials) {
052        this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure()));
053    }
054
055    /** Build a clusterer.
056     * @param clusterer the k-means clusterer to use
057     * @param numTrials number of trial runs
058     * @param evaluator the cluster evaluator to use
059     * @since 3.3
060     */
061    public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
062                                        final int numTrials,
063                                        final ClusterEvaluator<T> evaluator) {
064        super(clusterer.getDistanceMeasure());
065        this.clusterer = clusterer;
066        this.numTrials = numTrials;
067        this.evaluator = evaluator;
068    }
069
070    /**
071     * Returns the embedded k-means clusterer used by this instance.
072     * @return the embedded clusterer
073     */
074    public KMeansPlusPlusClusterer<T> getClusterer() {
075        return clusterer;
076    }
077
078    /**
079     * Returns the number of trials this instance will do.
080     * @return the number of trials
081     */
082    public int getNumTrials() {
083        return numTrials;
084    }
085
086    /**
087     * Returns the {@link ClusterEvaluator} used to determine the "best" clustering.
088     * @return the used {@link ClusterEvaluator}
089     * @since 3.3
090     */
091    public ClusterEvaluator<T> getClusterEvaluator() {
092       return evaluator;
093    }
094
095    /**
096     * Runs the K-means++ clustering algorithm.
097     *
098     * @param points the points to cluster
099     * @return a list of clusters containing the points
100     * @throws MathIllegalArgumentException if the data points are null or the number
101     *   of clusters is larger than the number of data points
102     * @throws ConvergenceException if an empty cluster is encountered and the
103     *   underlying {@link KMeansPlusPlusClusterer} has its
104     *   {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}.
105     */
106    @Override
107    public List<CentroidCluster<T>> cluster(final Collection<T> points)
108        throws MathIllegalArgumentException, ConvergenceException {
109
110        // at first, we have not found any clusters list yet
111        List<CentroidCluster<T>> best = null;
112        double bestVarianceSum = Double.POSITIVE_INFINITY;
113
114        // do several clustering trials
115        for (int i = 0; i < numTrials; ++i) {
116
117            // compute a clusters list
118            List<CentroidCluster<T>> clusters = clusterer.cluster(points);
119
120            // compute the variance of the current list
121            final double varianceSum = evaluator.score(clusters);
122
123            if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) {
124                // this one is the best we have found so far, remember it
125                best            = clusters;
126                bestVarianceSum = varianceSum;
127            }
128
129        }
130
131        // return the best clusters list found
132        return best;
133
134    }
135
136}