001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.ml.clustering; 018 019import java.util.ArrayList; 020import java.util.Collection; 021import java.util.Collections; 022import java.util.List; 023 024import org.apache.commons.math3.exception.MathIllegalArgumentException; 025import org.apache.commons.math3.exception.MathIllegalStateException; 026import org.apache.commons.math3.exception.NumberIsTooSmallException; 027import org.apache.commons.math3.linear.MatrixUtils; 028import org.apache.commons.math3.linear.RealMatrix; 029import org.apache.commons.math3.ml.distance.DistanceMeasure; 030import org.apache.commons.math3.ml.distance.EuclideanDistance; 031import org.apache.commons.math3.random.JDKRandomGenerator; 032import org.apache.commons.math3.random.RandomGenerator; 033import org.apache.commons.math3.util.FastMath; 034import org.apache.commons.math3.util.MathArrays; 035import org.apache.commons.math3.util.MathUtils; 036 037/** 038 * Fuzzy K-Means clustering algorithm. 039 * <p> 040 * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the 041 * major difference that a single data point is not uniquely assigned to a single cluster. 042 * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership 043 * to the cluster j. 044 * <p> 045 * The algorithm then tries to minimize the objective function: 046 * <pre> 047 * J = ∑<sub>i=1..C</sub>∑<sub>k=1..N</sub> u<sub>ik</sub><sup>m</sup>d<sub>ik</sub><sup>2</sup> 048 * </pre> 049 * with d<sub>ik</sub> being the distance between data point i and the cluster center k. 050 * <p> 051 * The algorithm requires two parameters: 052 * <ul> 053 * <li>k: the number of clusters 054 * <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters 055 * </ul> 056 * Additional, optional parameters: 057 * <ul> 058 * <li>maxIterations: the maximum number of iterations 059 * <li>epsilon: the convergence criteria, default is 1e-3 060 * </ul> 061 * <p> 062 * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection 063 * of the initial cluster centers. 064 * 065 * @param <T> type of the points to cluster 066 * @since 3.3 067 */ 068public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> { 069 070 /** The default value for the convergence criteria. */ 071 private static final double DEFAULT_EPSILON = 1e-3; 072 073 /** The number of clusters. */ 074 private final int k; 075 076 /** The maximum number of iterations. */ 077 private final int maxIterations; 078 079 /** The fuzziness factor. */ 080 private final double fuzziness; 081 082 /** The convergence criteria. */ 083 private final double epsilon; 084 085 /** Random generator for choosing initial centers. */ 086 private final RandomGenerator random; 087 088 /** The membership matrix. */ 089 private double[][] membershipMatrix; 090 091 /** The list of points used in the last call to {@link #cluster(Collection)}. */ 092 private List<T> points; 093 094 /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */ 095 private List<CentroidCluster<T>> clusters; 096 097 /** 098 * Creates a new instance of a FuzzyKMeansClusterer. 099 * <p> 100 * The euclidean distance will be used as default distance measure. 101 * 102 * @param k the number of clusters to split the data into 103 * @param fuzziness the fuzziness factor, must be > 1.0 104 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} 105 */ 106 public FuzzyKMeansClusterer(final int k, final double fuzziness) throws NumberIsTooSmallException { 107 this(k, fuzziness, -1, new EuclideanDistance()); 108 } 109 110 /** 111 * Creates a new instance of a FuzzyKMeansClusterer. 112 * 113 * @param k the number of clusters to split the data into 114 * @param fuzziness the fuzziness factor, must be > 1.0 115 * @param maxIterations the maximum number of iterations to run the algorithm for. 116 * If negative, no maximum will be used. 117 * @param measure the distance measure to use 118 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} 119 */ 120 public FuzzyKMeansClusterer(final int k, final double fuzziness, 121 final int maxIterations, final DistanceMeasure measure) 122 throws NumberIsTooSmallException { 123 this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, new JDKRandomGenerator()); 124 } 125 126 /** 127 * Creates a new instance of a FuzzyKMeansClusterer. 128 * 129 * @param k the number of clusters to split the data into 130 * @param fuzziness the fuzziness factor, must be > 1.0 131 * @param maxIterations the maximum number of iterations to run the algorithm for. 132 * If negative, no maximum will be used. 133 * @param measure the distance measure to use 134 * @param epsilon the convergence criteria (default is 1e-3) 135 * @param random random generator to use for choosing initial centers 136 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} 137 */ 138 public FuzzyKMeansClusterer(final int k, final double fuzziness, 139 final int maxIterations, final DistanceMeasure measure, 140 final double epsilon, final RandomGenerator random) 141 throws NumberIsTooSmallException { 142 143 super(measure); 144 145 if (fuzziness <= 1.0d) { 146 throw new NumberIsTooSmallException(fuzziness, 1.0, false); 147 } 148 this.k = k; 149 this.fuzziness = fuzziness; 150 this.maxIterations = maxIterations; 151 this.epsilon = epsilon; 152 this.random = random; 153 154 this.membershipMatrix = null; 155 this.points = null; 156 this.clusters = null; 157 } 158 159 /** 160 * Return the number of clusters this instance will use. 161 * @return the number of clusters 162 */ 163 public int getK() { 164 return k; 165 } 166 167 /** 168 * Returns the fuzziness factor used by this instance. 169 * @return the fuzziness factor 170 */ 171 public double getFuzziness() { 172 return fuzziness; 173 } 174 175 /** 176 * Returns the maximum number of iterations this instance will use. 177 * @return the maximum number of iterations, or -1 if no maximum is set 178 */ 179 public int getMaxIterations() { 180 return maxIterations; 181 } 182 183 /** 184 * Returns the convergence criteria used by this instance. 185 * @return the convergence criteria 186 */ 187 public double getEpsilon() { 188 return epsilon; 189 } 190 191 /** 192 * Returns the random generator this instance will use. 193 * @return the random generator 194 */ 195 public RandomGenerator getRandomGenerator() { 196 return random; 197 } 198 199 /** 200 * Returns the {@code nxk} membership matrix, where {@code n} is the number 201 * of data points and {@code k} the number of clusters. 202 * <p> 203 * The element U<sub>i,j</sub> represents the membership value for data point {@code i} 204 * to cluster {@code j}. 205 * 206 * @return the membership matrix 207 * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before 208 */ 209 public RealMatrix getMembershipMatrix() { 210 if (membershipMatrix == null) { 211 throw new MathIllegalStateException(); 212 } 213 return MatrixUtils.createRealMatrix(membershipMatrix); 214 } 215 216 /** 217 * Returns an unmodifiable list of the data points used in the last 218 * call to {@link #cluster(Collection)}. 219 * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has 220 * not been called before. 221 */ 222 public List<T> getDataPoints() { 223 return points; 224 } 225 226 /** 227 * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}. 228 * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has 229 * not been called before. 230 */ 231 public List<CentroidCluster<T>> getClusters() { 232 return clusters; 233 } 234 235 /** 236 * Get the value of the objective function. 237 * @return the objective function evaluation as double value 238 * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before 239 */ 240 public double getObjectiveFunctionValue() { 241 if (points == null || clusters == null) { 242 throw new MathIllegalStateException(); 243 } 244 245 int i = 0; 246 double objFunction = 0.0; 247 for (final T point : points) { 248 int j = 0; 249 for (final CentroidCluster<T> cluster : clusters) { 250 final double dist = distance(point, cluster.getCenter()); 251 objFunction += (dist * dist) * FastMath.pow(membershipMatrix[i][j], fuzziness); 252 j++; 253 } 254 i++; 255 } 256 return objFunction; 257 } 258 259 /** 260 * Performs Fuzzy K-Means cluster analysis. 261 * 262 * @param dataPoints the points to cluster 263 * @return the list of clusters 264 * @throws MathIllegalArgumentException if the data points are null or the number 265 * of clusters is larger than the number of data points 266 */ 267 @Override 268 public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints) 269 throws MathIllegalArgumentException { 270 271 // sanity checks 272 MathUtils.checkNotNull(dataPoints); 273 274 final int size = dataPoints.size(); 275 276 // number of clusters has to be smaller or equal the number of data points 277 if (size < k) { 278 throw new NumberIsTooSmallException(size, k, false); 279 } 280 281 // copy the input collection to an unmodifiable list with indexed access 282 points = Collections.unmodifiableList(new ArrayList<T>(dataPoints)); 283 clusters = new ArrayList<CentroidCluster<T>>(); 284 membershipMatrix = new double[size][k]; 285 final double[][] oldMatrix = new double[size][k]; 286 287 // if no points are provided, return an empty list of clusters 288 if (size == 0) { 289 return clusters; 290 } 291 292 initializeMembershipMatrix(); 293 294 // there is at least one point 295 final int pointDimension = points.get(0).getPoint().length; 296 for (int i = 0; i < k; i++) { 297 clusters.add(new CentroidCluster<T>(new DoublePoint(new double[pointDimension]))); 298 } 299 300 int iteration = 0; 301 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; 302 double difference = 0.0; 303 304 do { 305 saveMembershipMatrix(oldMatrix); 306 updateClusterCenters(); 307 updateMembershipMatrix(); 308 difference = calculateMaxMembershipChange(oldMatrix); 309 } while (difference > epsilon && ++iteration < max); 310 311 return clusters; 312 } 313 314 /** 315 * Update the cluster centers. 316 */ 317 private void updateClusterCenters() { 318 int j = 0; 319 final List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>(k); 320 for (final CentroidCluster<T> cluster : clusters) { 321 final Clusterable center = cluster.getCenter(); 322 int i = 0; 323 double[] arr = new double[center.getPoint().length]; 324 double sum = 0.0; 325 for (final T point : points) { 326 final double u = FastMath.pow(membershipMatrix[i][j], fuzziness); 327 final double[] pointArr = point.getPoint(); 328 for (int idx = 0; idx < arr.length; idx++) { 329 arr[idx] += u * pointArr[idx]; 330 } 331 sum += u; 332 i++; 333 } 334 MathArrays.scaleInPlace(1.0 / sum, arr); 335 newClusters.add(new CentroidCluster<T>(new DoublePoint(arr))); 336 j++; 337 } 338 clusters.clear(); 339 clusters = newClusters; 340 } 341 342 /** 343 * Updates the membership matrix and assigns the points to the cluster with 344 * the highest membership. 345 */ 346 private void updateMembershipMatrix() { 347 for (int i = 0; i < points.size(); i++) { 348 final T point = points.get(i); 349 double maxMembership = Double.MIN_VALUE; 350 int newCluster = -1; 351 for (int j = 0; j < clusters.size(); j++) { 352 double sum = 0.0; 353 final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter())); 354 355 if (distA != 0.0) { 356 for (final CentroidCluster<T> c : clusters) { 357 final double distB = FastMath.abs(distance(point, c.getCenter())); 358 if (distB == 0.0) { 359 sum = Double.POSITIVE_INFINITY; 360 break; 361 } 362 sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0)); 363 } 364 } 365 366 double membership; 367 if (sum == 0.0) { 368 membership = 1.0; 369 } else if (sum == Double.POSITIVE_INFINITY) { 370 membership = 0.0; 371 } else { 372 membership = 1.0 / sum; 373 } 374 membershipMatrix[i][j] = membership; 375 376 if (membershipMatrix[i][j] > maxMembership) { 377 maxMembership = membershipMatrix[i][j]; 378 newCluster = j; 379 } 380 } 381 clusters.get(newCluster).addPoint(point); 382 } 383 } 384 385 /** 386 * Initialize the membership matrix with random values. 387 */ 388 private void initializeMembershipMatrix() { 389 for (int i = 0; i < points.size(); i++) { 390 for (int j = 0; j < k; j++) { 391 membershipMatrix[i][j] = random.nextDouble(); 392 } 393 membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0); 394 } 395 } 396 397 /** 398 * Calculate the maximum element-by-element change of the membership matrix 399 * for the current iteration. 400 * 401 * @param matrix the membership matrix of the previous iteration 402 * @return the maximum membership matrix change 403 */ 404 private double calculateMaxMembershipChange(final double[][] matrix) { 405 double maxMembership = 0.0; 406 for (int i = 0; i < points.size(); i++) { 407 for (int j = 0; j < clusters.size(); j++) { 408 double v = FastMath.abs(membershipMatrix[i][j] - matrix[i][j]); 409 maxMembership = FastMath.max(v, maxMembership); 410 } 411 } 412 return maxMembership; 413 } 414 415 /** 416 * Copy the membership matrix into the provided matrix. 417 * 418 * @param matrix the place to store the membership matrix 419 */ 420 private void saveMembershipMatrix(final double[][] matrix) { 421 for (int i = 0; i < points.size(); i++) { 422 System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size()); 423 } 424 } 425 426}