001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.clustering;
019
020import java.util.Collection;
021import java.util.List;
022
023import org.apache.commons.math3.exception.ConvergenceException;
024import org.apache.commons.math3.exception.MathIllegalArgumentException;
025import org.apache.commons.math3.ml.clustering.evaluation.ClusterEvaluator;
026import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;
027
028/**
029 * A wrapper around a k-means++ clustering algorithm which performs multiple trials
030 * and returns the best solution.
031 * @param <T> type of the points to cluster
032 * @since 3.2
033 */
034public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
035
036    /** The underlying k-means clusterer. */
037    private final KMeansPlusPlusClusterer<T> clusterer;
038
039    /** The number of trial runs. */
040    private final int numTrials;
041
042    /** The cluster evaluator to use. */
043    private final ClusterEvaluator<T> evaluator;
044
045    /** Build a clusterer.
046     * @param clusterer the k-means clusterer to use
047     * @param numTrials number of trial runs
048     */
049    public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
050                                        final int numTrials) {
051        this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure()));
052    }
053
054    /** Build a clusterer.
055     * @param clusterer the k-means clusterer to use
056     * @param numTrials number of trial runs
057     * @param evaluator the cluster evaluator to use
058     * @since 3.3
059     */
060    public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
061                                        final int numTrials,
062                                        final ClusterEvaluator<T> evaluator) {
063        super(clusterer.getDistanceMeasure());
064        this.clusterer = clusterer;
065        this.numTrials = numTrials;
066        this.evaluator = evaluator;
067    }
068
069    /**
070     * Returns the embedded k-means clusterer used by this instance.
071     * @return the embedded clusterer
072     */
073    public KMeansPlusPlusClusterer<T> getClusterer() {
074        return clusterer;
075    }
076
077    /**
078     * Returns the number of trials this instance will do.
079     * @return the number of trials
080     */
081    public int getNumTrials() {
082        return numTrials;
083    }
084
085    /**
086     * Returns the {@link ClusterEvaluator} used to determine the "best" clustering.
087     * @return the used {@link ClusterEvaluator}
088     * @since 3.3
089     */
090    public ClusterEvaluator<T> getClusterEvaluator() {
091       return evaluator;
092    }
093
094    /**
095     * Runs the K-means++ clustering algorithm.
096     *
097     * @param points the points to cluster
098     * @return a list of clusters containing the points
099     * @throws MathIllegalArgumentException if the data points are null or the number
100     *   of clusters is larger than the number of data points
101     * @throws ConvergenceException if an empty cluster is encountered and the
102     *   underlying {@link KMeansPlusPlusClusterer} has its
103     *   {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}.
104     */
105    @Override
106    public List<CentroidCluster<T>> cluster(final Collection<T> points)
107        throws MathIllegalArgumentException, ConvergenceException {
108
109        // at first, we have not found any clusters list yet
110        List<CentroidCluster<T>> best = null;
111        double bestVarianceSum = Double.POSITIVE_INFINITY;
112
113        // do several clustering trials
114        for (int i = 0; i < numTrials; ++i) {
115
116            // compute a clusters list
117            List<CentroidCluster<T>> clusters = clusterer.cluster(points);
118
119            // compute the variance of the current list
120            final double varianceSum = evaluator.score(clusters);
121
122            if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) {
123                // this one is the best we have found so far, remember it
124                best            = clusters;
125                bestVarianceSum = varianceSum;
126            }
127
128        }
129
130        // return the best clusters list found
131        return best;
132
133    }
134
135}