/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.math4.legacy.ml.clustering;
19  
20  import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
21  import org.apache.commons.math4.legacy.ml.clustering.evaluation.CalinskiHarabasz;
22  import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
23  import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
24  import org.apache.commons.rng.UniformRandomProvider;
25  import org.apache.commons.rng.simple.RandomSource;
26  import org.junit.Assert;
27  import org.junit.Test;
28  
29  import java.util.ArrayList;
30  import java.util.List;
31  
32  public class MiniBatchKMeansClustererTest {
33      /**
34       * Assert the illegal parameter throws proper Exceptions.
35       */
36      @Test
37      public void testConstructorParameterChecks() {
38          expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, -1, 3, 300, 10, null, null, null));
39          expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, -2, 300, 10, null, null, null));
40          expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, -300, 10, null, null, null));
41          expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, 300, -10, null, null, null));
42      }
43  
44      /**
45       * Expects block throws NumberIsTooSmallException.
46       * @param block the block need to run.
47       */
48      private void expectNumberIsTooSmallException(Runnable block) {
49          assertException(block, NumberIsTooSmallException.class);
50      }
51  
52      /**
53       * Compare the result to KMeansPlusPlusClusterer
54       */
55      @Test
56      public void testCompareToKMeans() {
57          //Generate 4 cluster
58          final UniformRandomProvider rng = RandomSource.MT_64.create();
59          List<DoublePoint> data = generateCircles(rng);
60          KMeansPlusPlusClusterer<DoublePoint> kMeans =
61              new KMeansPlusPlusClusterer<>(4, Integer.MAX_VALUE, DEFAULT_MEASURE, rng);
62          MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
63              new MiniBatchKMeansClusterer<>(4, Integer.MAX_VALUE, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
64                                             KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
65          // Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
66          for (int i = 0; i < 100; i++) {
67              List<CentroidCluster<DoublePoint>> kMeansClusters = kMeans.cluster(data);
68              List<CentroidCluster<DoublePoint>> miniBatchKMeansClusters = miniBatchKMeans.cluster(data);
69              // Assert cluster result has proper clusters count.
70              Assert.assertEquals(4, kMeansClusters.size());
71              Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size());
72              int totalDiffCount = 0;
73              for (CentroidCluster<DoublePoint> kMeanCluster : kMeansClusters) {
74                  // Find out most similar cluster between two clusters, and summary the points count variances.
75                  CentroidCluster<DoublePoint> miniBatchCluster = predict(miniBatchKMeansClusters, kMeanCluster.getCenter());
76                  totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size());
77              }
78              // Statistic points different ratio.
79              double diffPointsRatio = totalDiffCount * 1.0 / data.size();
80              // Evaluator score different ratio by "CalinskiHarabasz" algorithm.
81              ClusterEvaluator clusterEvaluator = new CalinskiHarabasz();
82              double kMeansScore = clusterEvaluator.score(kMeansClusters);
83              double miniBatchKMeansScore = clusterEvaluator.score(miniBatchKMeansClusters);
84              double scoreDiffRatio = (kMeansScore - miniBatchKMeansScore) /
85                      kMeansScore;
86              // MiniBatchKMeansClusterer has few score differences between KMeansClusterer(less then 10%).
87              Assert.assertTrue(String.format("Different score ratio %f%%!, diff points ratio: %f%%", scoreDiffRatio * 100, diffPointsRatio * 100),
88                      scoreDiffRatio < 0.1);
89          }
90      }
91  
92      /**
93       * Generate points around 4 circles.
94       * @param rng RNG.
95       * @return Generated points.
96       */
97      private List<DoublePoint> generateCircles(UniformRandomProvider random) {
98          List<DoublePoint> data = new ArrayList<>();
99          data.addAll(generateCircle(250, new double[]{-1.0, -1.0}, 1.0, random));
100         data.addAll(generateCircle(260, new double[]{0.0, 0.0}, 0.7, random));
101         data.addAll(generateCircle(270, new double[]{1.0, 1.0}, 0.7, random));
102         data.addAll(generateCircle(280, new double[]{2.0, 2.0}, 0.7, random));
103         return data;
104     }
105 
106     /**
107      * Generate points as circles.
108      * @param count total points count.
109      * @param center circle center point.
110      * @param radius the circle radius points around.
111      * @param random the Random source.
112      * @return Generated points.
113      */
114     List<DoublePoint> generateCircle(int count, double[] center, double radius,
115                                      UniformRandomProvider random) {
116         double x0 = center[0];
117         double y0 = center[1];
118         ArrayList<DoublePoint> list = new ArrayList<>(count);
119         for (int i = 0; i < count; i++) {
120             double ao = random.nextDouble() * 720 - 360;
121             double r = random.nextDouble() * radius * 2 - radius;
122             double x1 = x0 + r * Math.cos(ao * Math.PI / 180);
123             double y1 = y0 + r * Math.sin(ao * Math.PI / 180);
124             list.add(new DoublePoint(new double[]{x1, y1}));
125         }
126         return list;
127     }
128 
129     /**
130      * Assert there should be a exception.
131      *
132      * @param block          The code block need to assert.
133      * @param exceptionClass A exception class.
134      */
135     public static void assertException(Runnable block, Class<? extends Throwable> exceptionClass) {
136         try {
137             block.run();
138             Assert.fail(String.format("Expects %s", exceptionClass.getSimpleName()));
139         } catch (Throwable e) {
140             if (!exceptionClass.isInstance(e)) {
141                 throw e;
142             }
143         }
144     }
145 
146     /**
147      * Use EuclideanDistance as default DistanceMeasure
148      */
149     public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance();
150 
151     /**
152      * Predict which cluster is best for the point
153      *
154      * @param clusters cluster to predict into
155      * @param point    point to predict
156      * @param measure  distance measurer
157      * @param <T>      type of cluster point
158      * @return the cluster which has nearest center to the point
159      */
160     public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point, DistanceMeasure measure) {
161         double minDistance = Double.POSITIVE_INFINITY;
162         CentroidCluster<T> nearestCluster = null;
163         for (CentroidCluster<T> cluster : clusters) {
164             double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint());
165             if (distance < minDistance) {
166                 minDistance = distance;
167                 nearestCluster = cluster;
168             }
169         }
170         return nearestCluster;
171     }
172 
173     /**
174      * Predict which cluster is best for the point
175      *
176      * @param clusters cluster to predict into
177      * @param point    point to predict
178      * @param <T>      type of cluster point
179      * @return the cluster which has nearest center to the point
180      */
181     public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point) {
182         return predict(clusters, point, DEFAULT_MEASURE);
183     }
184 }