/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.ml.clustering;

import java.util.Collection;
import java.util.List;

import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.ml.clustering.evaluation.ClusterEvaluator;
import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;

/**
 * A wrapper around a k-means++ clustering algorithm which performs multiple trials
 * and returns the best solution.
 * @param <T> type of the points to cluster
 * @since 3.2
 */
public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

    /** The underlying k-means clusterer. */
    private final KMeansPlusPlusClusterer<T> clusterer;

    /** The number of trial runs. */
    private final int numTrials;

    /** The cluster evaluator to use. */
    private final ClusterEvaluator<T> evaluator;

    /** Build a clusterer.
     * @param clusterer the k-means clusterer to use
     * @param numTrials number of trial runs
     */
    public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
                                        final int numTrials) {
        this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure()));
    }

    /** Build a clusterer.
     * @param clusterer the k-means clusterer to use
     * @param numTrials number of trial runs
     * @param evaluator the cluster evaluator to use
     * @since 3.3
     */
    public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
                                        final int numTrials,
                                        final ClusterEvaluator<T> evaluator) {
        super(clusterer.getDistanceMeasure());
        this.clusterer = clusterer;
        this.numTrials = numTrials;
        this.evaluator = evaluator;
    }

    /**
     * Returns the embedded k-means clusterer used by this instance.
     * @return the embedded clusterer
     */
    public KMeansPlusPlusClusterer<T> getClusterer() {
        return clusterer;
    }

    /**
     * Returns the number of trials this instance will do.
     * @return the number of trials
     */
    public int getNumTrials() {
        return numTrials;
    }

    /**
     * Returns the {@link ClusterEvaluator} used to determine the "best" clustering.
     * @return the used {@link ClusterEvaluator}
     * @since 3.3
     */
    public ClusterEvaluator<T> getClusterEvaluator() {
        return evaluator;
    }

    /**
     * Runs the K-means++ clustering algorithm.
     *
     * @param points the points to cluster
     * @return a list of clusters containing the points
     * @throws MathIllegalArgumentException if the data points are null or the number
     *     of clusters is larger than the number of data points
     * @throws ConvergenceException if an empty cluster is encountered and the
     *     underlying {@link KMeansPlusPlusClusterer} has its
     *     {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} set to {@code ERROR}
     */
    @Override
    public List<CentroidCluster<T>> cluster(final Collection<T> points)
        throws MathIllegalArgumentException, ConvergenceException {

        // at first, we have not found any clusters list yet
        List<CentroidCluster<T>> best = null;
        double bestVarianceSum = Double.POSITIVE_INFINITY;

        // do several clustering trials
        for (int i = 0; i < numTrials; ++i) {

            // compute a clusters list
            List<CentroidCluster<T>> clusters = clusterer.cluster(points);

            // compute the variance of the current list
            final double varianceSum = evaluator.score(clusters);

            if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) {
                // this one is the best we have found so far, remember it
                best = clusters;
                bestVarianceSum = varianceSum;
            }

        }

        // return the best clusters list found
        return best;

    }

}
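
// Minimal usage sketch (illustrative only, not part of the class above). It assumes
// the DoublePoint implementation of Clusterable shipped with commons-math3 and the
// default distance measure of KMeansPlusPlusClusterer; the point coordinates, the
// cluster count (3), the iteration limit (100) and the trial count (10) are
// arbitrary example values chosen for this sketch.
//
//     List<DoublePoint> points = Arrays.asList(
//             new DoublePoint(new double[] { 1.0, 1.0 }),
//             new DoublePoint(new double[] { 1.5, 2.0 }),
//             new DoublePoint(new double[] { 8.0, 8.0 }),
//             new DoublePoint(new double[] { 8.5, 9.0 }));
//
//     KMeansPlusPlusClusterer<DoublePoint> kMeans =
//             new KMeansPlusPlusClusterer<DoublePoint>(3, 100);
//     MultiKMeansPlusPlusClusterer<DoublePoint> multiKMeans =
//             new MultiKMeansPlusPlusClusterer<DoublePoint>(kMeans, 10);
//
//     List<CentroidCluster<DoublePoint>> clusters = multiKMeans.cluster(points);
//
// Each of the 10 trials runs the wrapped k-means++ clusterer from scratch, and the
// clustering with the best evaluator score (by default the lowest sum of cluster
// variances) is returned.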