1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.ml.clustering;
18
19 import java.util.ArrayList;
20 import java.util.Collection;
21 import java.util.Collections;
22 import java.util.List;
23
24 import org.apache.commons.math4.legacy.exception.NullArgumentException;
25 import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
26 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
27 import org.apache.commons.math4.legacy.linear.MatrixUtils;
28 import org.apache.commons.math4.legacy.linear.RealMatrix;
29 import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
30 import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
31 import org.apache.commons.rng.simple.RandomSource;
32 import org.apache.commons.rng.UniformRandomProvider;
33 import org.apache.commons.math4.core.jdkmath.JdkMath;
34 import org.apache.commons.math4.legacy.core.MathArrays;
35
36 /**
37 * Fuzzy K-Means clustering algorithm.
38 * <p>
39 * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the
40 * major difference that a single data point is not uniquely assigned to a single cluster.
41 * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership
42 * to the cluster j.
43 * <p>
44 * The algorithm then tries to minimize the objective function:
45 * <div style="white-space: pre"><code>
46 * J = ∑<sub>i=1..C</sub>∑<sub>k=1..N</sub> u<sub>ik</sub><sup>m</sup>d<sub>ik</sub><sup>2</sup>
47 * </code></div>
48 * with d<sub>ik</sub> being the distance between data point i and the cluster center k.
49 * <p>
50 * The algorithm requires two parameters:
51 * <ul>
52 * <li>k: the number of clusters
53 * <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters
54 * </ul>
55 * Additional, optional parameters:
56 * <ul>
57 * <li>maxIterations: the maximum number of iterations
58 * <li>epsilon: the convergence criteria, default is 1e-3
59 * </ul>
60 * <p>
61 * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection
62 * of the initial cluster centers.
63 *
64 * @param <T> type of the points to cluster
65 * @since 3.3
66 */
67 public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> {
68
69 /** The default value for the convergence criteria. */
70 private static final double DEFAULT_EPSILON = 1e-3;
71
72 /** The number of clusters. */
73 private final int k;
74
75 /** The maximum number of iterations. */
76 private final int maxIterations;
77
78 /** The fuzziness factor. */
79 private final double fuzziness;
80
81 /** The convergence criteria. */
82 private final double epsilon;
83
84 /** Random generator for choosing initial centers. */
85 private final UniformRandomProvider random;
86
87 /** The membership matrix. */
88 private double[][] membershipMatrix;
89
90 /** The list of points used in the last call to {@link #cluster(Collection)}. */
91 private List<T> points;
92
93 /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */
94 private List<CentroidCluster<T>> clusters;
95
96 /**
97 * Creates a new instance of a FuzzyKMeansClusterer.
98 * <p>
99 * The euclidean distance will be used as default distance measure.
100 *
101 * @param k the number of clusters to split the data into
102 * @param fuzziness the fuzziness factor, must be > 1.0
103 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
104 */
105 public FuzzyKMeansClusterer(final int k, final double fuzziness) {
106 this(k, fuzziness, -1, new EuclideanDistance());
107 }
108
109 /**
110 * Creates a new instance of a FuzzyKMeansClusterer.
111 *
112 * @param k the number of clusters to split the data into
113 * @param fuzziness the fuzziness factor, must be > 1.0
114 * @param maxIterations the maximum number of iterations to run the algorithm for.
115 * If negative, no maximum will be used.
116 * @param measure the distance measure to use
117 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
118 */
119 public FuzzyKMeansClusterer(final int k, final double fuzziness,
120 final int maxIterations, final DistanceMeasure measure) {
121 this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, RandomSource.MT_64.create());
122 }
123
124 /**
125 * Creates a new instance of a FuzzyKMeansClusterer.
126 *
127 * @param k the number of clusters to split the data into
128 * @param fuzziness the fuzziness factor, must be > 1.0
129 * @param maxIterations the maximum number of iterations to run the algorithm for.
130 * If negative, no maximum will be used.
131 * @param measure the distance measure to use
132 * @param epsilon the convergence criteria (default is 1e-3)
133 * @param random random generator to use for choosing initial centers
134 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
135 */
136 public FuzzyKMeansClusterer(final int k, final double fuzziness,
137 final int maxIterations, final DistanceMeasure measure,
138 final double epsilon, final UniformRandomProvider random) {
139 super(measure);
140
141 if (fuzziness <= 1.0d) {
142 throw new NumberIsTooSmallException(fuzziness, 1.0, false);
143 }
144 this.k = k;
145 this.fuzziness = fuzziness;
146 this.maxIterations = maxIterations;
147 this.epsilon = epsilon;
148 this.random = random;
149
150 this.membershipMatrix = null;
151 this.points = null;
152 this.clusters = null;
153 }
154
155 /**
156 * Return the number of clusters this instance will use.
157 * @return the number of clusters
158 */
159 public int getK() {
160 return k;
161 }
162
163 /**
164 * Returns the fuzziness factor used by this instance.
165 * @return the fuzziness factor
166 */
167 public double getFuzziness() {
168 return fuzziness;
169 }
170
171 /**
172 * Returns the maximum number of iterations this instance will use.
173 * @return the maximum number of iterations, or -1 if no maximum is set
174 */
175 public int getMaxIterations() {
176 return maxIterations;
177 }
178
179 /**
180 * Returns the convergence criteria used by this instance.
181 * @return the convergence criteria
182 */
183 public double getEpsilon() {
184 return epsilon;
185 }
186
187 /**
188 * Returns the random generator this instance will use.
189 * @return the random generator
190 */
191 public UniformRandomProvider getRandomGenerator() {
192 return random;
193 }
194
195 /**
196 * Returns the {@code nxk} membership matrix, where {@code n} is the number
197 * of data points and {@code k} the number of clusters.
198 * <p>
199 * The element U<sub>i,j</sub> represents the membership value for data point {@code i}
200 * to cluster {@code j}.
201 *
202 * @return the membership matrix
203 * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
204 */
205 public RealMatrix getMembershipMatrix() {
206 if (membershipMatrix == null) {
207 throw new MathIllegalStateException();
208 }
209 return MatrixUtils.createRealMatrix(membershipMatrix);
210 }
211
212 /**
213 * Returns an unmodifiable list of the data points used in the last
214 * call to {@link #cluster(Collection)}.
215 * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has
216 * not been called before.
217 */
218 public List<T> getDataPoints() {
219 return points;
220 }
221
222 /**
223 * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}.
224 * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has
225 * not been called before.
226 */
227 public List<CentroidCluster<T>> getClusters() {
228 return clusters;
229 }
230
231 /**
232 * Get the value of the objective function.
233 * @return the objective function evaluation as double value
234 * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
235 */
236 public double getObjectiveFunctionValue() {
237 if (points == null || clusters == null) {
238 throw new MathIllegalStateException();
239 }
240
241 int i = 0;
242 double objFunction = 0.0;
243 for (final T point : points) {
244 int j = 0;
245 for (final CentroidCluster<T> cluster : clusters) {
246 final double dist = distance(point, cluster.getCenter());
247 objFunction += (dist * dist) * JdkMath.pow(membershipMatrix[i][j], fuzziness);
248 j++;
249 }
250 i++;
251 }
252 return objFunction;
253 }
254
255 /**
256 * Performs Fuzzy K-Means cluster analysis.
257 *
258 * @param dataPoints the points to cluster
259 * @return the list of clusters
260 * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException if
261 * the data points are null or the number of clusters is larger than the number
262 * of data points
263 */
264 @Override
265 public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints) {
266 // sanity checks
267 NullArgumentException.check(dataPoints);
268
269 final int size = dataPoints.size();
270
271 // number of clusters has to be smaller or equal the number of data points
272 if (size < k) {
273 throw new NumberIsTooSmallException(size, k, false);
274 }
275
276 // copy the input collection to an unmodifiable list with indexed access
277 points = Collections.unmodifiableList(new ArrayList<>(dataPoints));
278 clusters = new ArrayList<>();
279 membershipMatrix = new double[size][k];
280 final double[][] oldMatrix = new double[size][k];
281
282 // if no points are provided, return an empty list of clusters
283 if (size == 0) {
284 return clusters;
285 }
286
287 initializeMembershipMatrix();
288
289 // there is at least one point
290 final int pointDimension = points.get(0).getPoint().length;
291 for (int i = 0; i < k; i++) {
292 clusters.add(new CentroidCluster<>(new DoublePoint(new double[pointDimension])));
293 }
294
295 int iteration = 0;
296 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
297 double difference = 0.0;
298
299 do {
300 saveMembershipMatrix(oldMatrix);
301 updateClusterCenters();
302 updateMembershipMatrix();
303 difference = calculateMaxMembershipChange(oldMatrix);
304 } while (difference > epsilon && ++iteration < max);
305
306 return clusters;
307 }
308
309 /**
310 * Update the cluster centers.
311 */
312 private void updateClusterCenters() {
313 int j = 0;
314 final List<CentroidCluster<T>> newClusters = new ArrayList<>(k);
315 for (final CentroidCluster<T> cluster : clusters) {
316 final Clusterable center = cluster.getCenter();
317 int i = 0;
318 double[] arr = new double[center.getPoint().length];
319 double sum = 0.0;
320 for (final T point : points) {
321 final double u = JdkMath.pow(membershipMatrix[i][j], fuzziness);
322 final double[] pointArr = point.getPoint();
323 for (int idx = 0; idx < arr.length; idx++) {
324 arr[idx] += u * pointArr[idx];
325 }
326 sum += u;
327 i++;
328 }
329 MathArrays.scaleInPlace(1.0 / sum, arr);
330 newClusters.add(new CentroidCluster<>(new DoublePoint(arr)));
331 j++;
332 }
333 clusters.clear();
334 clusters = newClusters;
335 }
336
337 /**
338 * Updates the membership matrix and assigns the points to the cluster with
339 * the highest membership.
340 */
341 private void updateMembershipMatrix() {
342 for (int i = 0; i < points.size(); i++) {
343 final T point = points.get(i);
344 double maxMembership = Double.MIN_VALUE;
345 int newCluster = -1;
346 for (int j = 0; j < clusters.size(); j++) {
347 double sum = 0.0;
348 final double distA = JdkMath.abs(distance(point, clusters.get(j).getCenter()));
349
350 if (distA != 0.0) {
351 for (final CentroidCluster<T> c : clusters) {
352 final double distB = JdkMath.abs(distance(point, c.getCenter()));
353 if (distB == 0.0) {
354 sum = Double.POSITIVE_INFINITY;
355 break;
356 }
357 sum += JdkMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
358 }
359 }
360
361 double membership;
362 if (sum == 0.0) {
363 membership = 1.0;
364 } else if (sum == Double.POSITIVE_INFINITY) {
365 membership = 0.0;
366 } else {
367 membership = 1.0 / sum;
368 }
369 membershipMatrix[i][j] = membership;
370
371 if (membershipMatrix[i][j] > maxMembership) {
372 maxMembership = membershipMatrix[i][j];
373 newCluster = j;
374 }
375 }
376 clusters.get(newCluster).addPoint(point);
377 }
378 }
379
380 /**
381 * Initialize the membership matrix with random values.
382 */
383 private void initializeMembershipMatrix() {
384 for (int i = 0; i < points.size(); i++) {
385 for (int j = 0; j < k; j++) {
386 membershipMatrix[i][j] = random.nextDouble();
387 }
388 membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0);
389 }
390 }
391
392 /**
393 * Calculate the maximum element-by-element change of the membership matrix
394 * for the current iteration.
395 *
396 * @param matrix the membership matrix of the previous iteration
397 * @return the maximum membership matrix change
398 */
399 private double calculateMaxMembershipChange(final double[][] matrix) {
400 double maxMembership = 0.0;
401 for (int i = 0; i < points.size(); i++) {
402 for (int j = 0; j < clusters.size(); j++) {
403 double v = JdkMath.abs(membershipMatrix[i][j] - matrix[i][j]);
404 maxMembership = JdkMath.max(v, maxMembership);
405 }
406 }
407 return maxMembership;
408 }
409
410 /**
411 * Copy the membership matrix into the provided matrix.
412 *
413 * @param matrix the place to store the membership matrix
414 */
415 private void saveMembershipMatrix(final double[][] matrix) {
416 for (int i = 0; i < points.size(); i++) {
417 System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size());
418 }
419 }
420 }