001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.ml.clustering; 018 019import java.util.ArrayList; 020import java.util.Collection; 021import java.util.Collections; 022import java.util.List; 023 024import org.apache.commons.math4.legacy.exception.NullArgumentException; 025import org.apache.commons.math4.legacy.exception.MathIllegalStateException; 026import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; 027import org.apache.commons.math4.legacy.linear.MatrixUtils; 028import org.apache.commons.math4.legacy.linear.RealMatrix; 029import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure; 030import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance; 031import org.apache.commons.rng.simple.RandomSource; 032import org.apache.commons.rng.UniformRandomProvider; 033import org.apache.commons.math4.core.jdkmath.JdkMath; 034import org.apache.commons.math4.legacy.core.MathArrays; 035 036/** 037 * Fuzzy K-Means clustering algorithm. 038 * <p> 039 * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the 040 * major difference that a single data point is not uniquely assigned to a single cluster. 041 * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership 042 * to the cluster j. 043 * <p> 044 * The algorithm then tries to minimize the objective function: 045 * <div style="white-space: pre"><code> 046 * J = ∑<sub>i=1..C</sub>∑<sub>k=1..N</sub> u<sub>ik</sub><sup>m</sup>d<sub>ik</sub><sup>2</sup> 047 * </code></div> 048 * with d<sub>ik</sub> being the distance between data point i and the cluster center k. 049 * <p> 050 * The algorithm requires two parameters: 051 * <ul> 052 * <li>k: the number of clusters 053 * <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters 054 * </ul> 055 * Additional, optional parameters: 056 * <ul> 057 * <li>maxIterations: the maximum number of iterations 058 * <li>epsilon: the convergence criteria, default is 1e-3 059 * </ul> 060 * <p> 061 * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection 062 * of the initial cluster centers. 063 * 064 * @param <T> type of the points to cluster 065 * @since 3.3 066 */ 067public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> { 068 069 /** The default value for the convergence criteria. */ 070 private static final double DEFAULT_EPSILON = 1e-3; 071 072 /** The number of clusters. */ 073 private final int k; 074 075 /** The maximum number of iterations. */ 076 private final int maxIterations; 077 078 /** The fuzziness factor. */ 079 private final double fuzziness; 080 081 /** The convergence criteria. */ 082 private final double epsilon; 083 084 /** Random generator for choosing initial centers. */ 085 private final UniformRandomProvider random; 086 087 /** The membership matrix. */ 088 private double[][] membershipMatrix; 089 090 /** The list of points used in the last call to {@link #cluster(Collection)}. */ 091 private List<T> points; 092 093 /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */ 094 private List<CentroidCluster<T>> clusters; 095 096 /** 097 * Creates a new instance of a FuzzyKMeansClusterer. 098 * <p> 099 * The euclidean distance will be used as default distance measure. 100 * 101 * @param k the number of clusters to split the data into 102 * @param fuzziness the fuzziness factor, must be > 1.0 103 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} 104 */ 105 public FuzzyKMeansClusterer(final int k, final double fuzziness) { 106 this(k, fuzziness, -1, new EuclideanDistance()); 107 } 108 109 /** 110 * Creates a new instance of a FuzzyKMeansClusterer. 111 * 112 * @param k the number of clusters to split the data into 113 * @param fuzziness the fuzziness factor, must be > 1.0 114 * @param maxIterations the maximum number of iterations to run the algorithm for. 115 * If negative, no maximum will be used. 116 * @param measure the distance measure to use 117 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} 118 */ 119 public FuzzyKMeansClusterer(final int k, final double fuzziness, 120 final int maxIterations, final DistanceMeasure measure) { 121 this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, RandomSource.MT_64.create()); 122 } 123 124 /** 125 * Creates a new instance of a FuzzyKMeansClusterer. 126 * 127 * @param k the number of clusters to split the data into 128 * @param fuzziness the fuzziness factor, must be > 1.0 129 * @param maxIterations the maximum number of iterations to run the algorithm for. 130 * If negative, no maximum will be used. 131 * @param measure the distance measure to use 132 * @param epsilon the convergence criteria (default is 1e-3) 133 * @param random random generator to use for choosing initial centers 134 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0} 135 */ 136 public FuzzyKMeansClusterer(final int k, final double fuzziness, 137 final int maxIterations, final DistanceMeasure measure, 138 final double epsilon, final UniformRandomProvider random) { 139 super(measure); 140 141 if (fuzziness <= 1.0d) { 142 throw new NumberIsTooSmallException(fuzziness, 1.0, false); 143 } 144 this.k = k; 145 this.fuzziness = fuzziness; 146 this.maxIterations = maxIterations; 147 this.epsilon = epsilon; 148 this.random = random; 149 150 this.membershipMatrix = null; 151 this.points = null; 152 this.clusters = null; 153 } 154 155 /** 156 * Return the number of clusters this instance will use. 157 * @return the number of clusters 158 */ 159 public int getK() { 160 return k; 161 } 162 163 /** 164 * Returns the fuzziness factor used by this instance. 165 * @return the fuzziness factor 166 */ 167 public double getFuzziness() { 168 return fuzziness; 169 } 170 171 /** 172 * Returns the maximum number of iterations this instance will use. 173 * @return the maximum number of iterations, or -1 if no maximum is set 174 */ 175 public int getMaxIterations() { 176 return maxIterations; 177 } 178 179 /** 180 * Returns the convergence criteria used by this instance. 181 * @return the convergence criteria 182 */ 183 public double getEpsilon() { 184 return epsilon; 185 } 186 187 /** 188 * Returns the random generator this instance will use. 189 * @return the random generator 190 */ 191 public UniformRandomProvider getRandomGenerator() { 192 return random; 193 } 194 195 /** 196 * Returns the {@code nxk} membership matrix, where {@code n} is the number 197 * of data points and {@code k} the number of clusters. 198 * <p> 199 * The element U<sub>i,j</sub> represents the membership value for data point {@code i} 200 * to cluster {@code j}. 201 * 202 * @return the membership matrix 203 * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before 204 */ 205 public RealMatrix getMembershipMatrix() { 206 if (membershipMatrix == null) { 207 throw new MathIllegalStateException(); 208 } 209 return MatrixUtils.createRealMatrix(membershipMatrix); 210 } 211 212 /** 213 * Returns an unmodifiable list of the data points used in the last 214 * call to {@link #cluster(Collection)}. 215 * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has 216 * not been called before. 217 */ 218 public List<T> getDataPoints() { 219 return points; 220 } 221 222 /** 223 * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}. 224 * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has 225 * not been called before. 226 */ 227 public List<CentroidCluster<T>> getClusters() { 228 return clusters; 229 } 230 231 /** 232 * Get the value of the objective function. 233 * @return the objective function evaluation as double value 234 * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before 235 */ 236 public double getObjectiveFunctionValue() { 237 if (points == null || clusters == null) { 238 throw new MathIllegalStateException(); 239 } 240 241 int i = 0; 242 double objFunction = 0.0; 243 for (final T point : points) { 244 int j = 0; 245 for (final CentroidCluster<T> cluster : clusters) { 246 final double dist = distance(point, cluster.getCenter()); 247 objFunction += (dist * dist) * JdkMath.pow(membershipMatrix[i][j], fuzziness); 248 j++; 249 } 250 i++; 251 } 252 return objFunction; 253 } 254 255 /** 256 * Performs Fuzzy K-Means cluster analysis. 257 * 258 * @param dataPoints the points to cluster 259 * @return the list of clusters 260 * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException if 261 * the data points are null or the number of clusters is larger than the number 262 * of data points 263 */ 264 @Override 265 public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints) { 266 // sanity checks 267 NullArgumentException.check(dataPoints); 268 269 final int size = dataPoints.size(); 270 271 // number of clusters has to be smaller or equal the number of data points 272 if (size < k) { 273 throw new NumberIsTooSmallException(size, k, false); 274 } 275 276 // copy the input collection to an unmodifiable list with indexed access 277 points = Collections.unmodifiableList(new ArrayList<>(dataPoints)); 278 clusters = new ArrayList<>(); 279 membershipMatrix = new double[size][k]; 280 final double[][] oldMatrix = new double[size][k]; 281 282 // if no points are provided, return an empty list of clusters 283 if (size == 0) { 284 return clusters; 285 } 286 287 initializeMembershipMatrix(); 288 289 // there is at least one point 290 final int pointDimension = points.get(0).getPoint().length; 291 for (int i = 0; i < k; i++) { 292 clusters.add(new CentroidCluster<>(new DoublePoint(new double[pointDimension]))); 293 } 294 295 int iteration = 0; 296 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; 297 double difference = 0.0; 298 299 do { 300 saveMembershipMatrix(oldMatrix); 301 updateClusterCenters(); 302 updateMembershipMatrix(); 303 difference = calculateMaxMembershipChange(oldMatrix); 304 } while (difference > epsilon && ++iteration < max); 305 306 return clusters; 307 } 308 309 /** 310 * Update the cluster centers. 311 */ 312 private void updateClusterCenters() { 313 int j = 0; 314 final List<CentroidCluster<T>> newClusters = new ArrayList<>(k); 315 for (final CentroidCluster<T> cluster : clusters) { 316 final Clusterable center = cluster.getCenter(); 317 int i = 0; 318 double[] arr = new double[center.getPoint().length]; 319 double sum = 0.0; 320 for (final T point : points) { 321 final double u = JdkMath.pow(membershipMatrix[i][j], fuzziness); 322 final double[] pointArr = point.getPoint(); 323 for (int idx = 0; idx < arr.length; idx++) { 324 arr[idx] += u * pointArr[idx]; 325 } 326 sum += u; 327 i++; 328 } 329 MathArrays.scaleInPlace(1.0 / sum, arr); 330 newClusters.add(new CentroidCluster<>(new DoublePoint(arr))); 331 j++; 332 } 333 clusters.clear(); 334 clusters = newClusters; 335 } 336 337 /** 338 * Updates the membership matrix and assigns the points to the cluster with 339 * the highest membership. 340 */ 341 private void updateMembershipMatrix() { 342 for (int i = 0; i < points.size(); i++) { 343 final T point = points.get(i); 344 double maxMembership = Double.MIN_VALUE; 345 int newCluster = -1; 346 for (int j = 0; j < clusters.size(); j++) { 347 double sum = 0.0; 348 final double distA = JdkMath.abs(distance(point, clusters.get(j).getCenter())); 349 350 if (distA != 0.0) { 351 for (final CentroidCluster<T> c : clusters) { 352 final double distB = JdkMath.abs(distance(point, c.getCenter())); 353 if (distB == 0.0) { 354 sum = Double.POSITIVE_INFINITY; 355 break; 356 } 357 sum += JdkMath.pow(distA / distB, 2.0 / (fuzziness - 1.0)); 358 } 359 } 360 361 double membership; 362 if (sum == 0.0) { 363 membership = 1.0; 364 } else if (sum == Double.POSITIVE_INFINITY) { 365 membership = 0.0; 366 } else { 367 membership = 1.0 / sum; 368 } 369 membershipMatrix[i][j] = membership; 370 371 if (membershipMatrix[i][j] > maxMembership) { 372 maxMembership = membershipMatrix[i][j]; 373 newCluster = j; 374 } 375 } 376 clusters.get(newCluster).addPoint(point); 377 } 378 } 379 380 /** 381 * Initialize the membership matrix with random values. 382 */ 383 private void initializeMembershipMatrix() { 384 for (int i = 0; i < points.size(); i++) { 385 for (int j = 0; j < k; j++) { 386 membershipMatrix[i][j] = random.nextDouble(); 387 } 388 membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0); 389 } 390 } 391 392 /** 393 * Calculate the maximum element-by-element change of the membership matrix 394 * for the current iteration. 395 * 396 * @param matrix the membership matrix of the previous iteration 397 * @return the maximum membership matrix change 398 */ 399 private double calculateMaxMembershipChange(final double[][] matrix) { 400 double maxMembership = 0.0; 401 for (int i = 0; i < points.size(); i++) { 402 for (int j = 0; j < clusters.size(); j++) { 403 double v = JdkMath.abs(membershipMatrix[i][j] - matrix[i][j]); 404 maxMembership = JdkMath.max(v, maxMembership); 405 } 406 } 407 return maxMembership; 408 } 409 410 /** 411 * Copy the membership matrix into the provided matrix. 412 * 413 * @param matrix the place to store the membership matrix 414 */ 415 private void saveMembershipMatrix(final double[][] matrix) { 416 for (int i = 0; i < points.size(); i++) { 417 System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size()); 418 } 419 } 420}