1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ml.clustering;
19
20 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
21 import org.apache.commons.math4.legacy.ml.clustering.evaluation.CalinskiHarabasz;
22 import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
23 import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
24 import org.apache.commons.rng.UniformRandomProvider;
25 import org.apache.commons.rng.simple.RandomSource;
26 import org.junit.Assert;
27 import org.junit.Test;
28
29 import java.util.ArrayList;
30 import java.util.List;
31
32 public class MiniBatchKMeansClustererTest {
33
34
35
36 @Test
37 public void testConstructorParameterChecks() {
38 expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, -1, 3, 300, 10, null, null, null));
39 expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, -2, 300, 10, null, null, null));
40 expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, -300, 10, null, null, null));
41 expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, 300, -10, null, null, null));
42 }
43
44
45
46
47
48 private void expectNumberIsTooSmallException(Runnable block) {
49 assertException(block, NumberIsTooSmallException.class);
50 }
51
52
53
54
55 @Test
56 public void testCompareToKMeans() {
57
58 final UniformRandomProvider rng = RandomSource.MT_64.create();
59 List<DoublePoint> data = generateCircles(rng);
60 KMeansPlusPlusClusterer<DoublePoint> kMeans =
61 new KMeansPlusPlusClusterer<>(4, Integer.MAX_VALUE, DEFAULT_MEASURE, rng);
62 MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
63 new MiniBatchKMeansClusterer<>(4, Integer.MAX_VALUE, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
64 KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
65
66 for (int i = 0; i < 100; i++) {
67 List<CentroidCluster<DoublePoint>> kMeansClusters = kMeans.cluster(data);
68 List<CentroidCluster<DoublePoint>> miniBatchKMeansClusters = miniBatchKMeans.cluster(data);
69
70 Assert.assertEquals(4, kMeansClusters.size());
71 Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size());
72 int totalDiffCount = 0;
73 for (CentroidCluster<DoublePoint> kMeanCluster : kMeansClusters) {
74
75 CentroidCluster<DoublePoint> miniBatchCluster = predict(miniBatchKMeansClusters, kMeanCluster.getCenter());
76 totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size());
77 }
78
79 double diffPointsRatio = totalDiffCount * 1.0 / data.size();
80
81 ClusterEvaluator clusterEvaluator = new CalinskiHarabasz();
82 double kMeansScore = clusterEvaluator.score(kMeansClusters);
83 double miniBatchKMeansScore = clusterEvaluator.score(miniBatchKMeansClusters);
84 double scoreDiffRatio = (kMeansScore - miniBatchKMeansScore) /
85 kMeansScore;
86
87 Assert.assertTrue(String.format("Different score ratio %f%%!, diff points ratio: %f%%", scoreDiffRatio * 100, diffPointsRatio * 100),
88 scoreDiffRatio < 0.1);
89 }
90 }
91
92
93
94
95
96
97 private List<DoublePoint> generateCircles(UniformRandomProvider random) {
98 List<DoublePoint> data = new ArrayList<>();
99 data.addAll(generateCircle(250, new double[]{-1.0, -1.0}, 1.0, random));
100 data.addAll(generateCircle(260, new double[]{0.0, 0.0}, 0.7, random));
101 data.addAll(generateCircle(270, new double[]{1.0, 1.0}, 0.7, random));
102 data.addAll(generateCircle(280, new double[]{2.0, 2.0}, 0.7, random));
103 return data;
104 }
105
106
107
108
109
110
111
112
113
114 List<DoublePoint> generateCircle(int count, double[] center, double radius,
115 UniformRandomProvider random) {
116 double x0 = center[0];
117 double y0 = center[1];
118 ArrayList<DoublePoint> list = new ArrayList<>(count);
119 for (int i = 0; i < count; i++) {
120 double ao = random.nextDouble() * 720 - 360;
121 double r = random.nextDouble() * radius * 2 - radius;
122 double x1 = x0 + r * Math.cos(ao * Math.PI / 180);
123 double y1 = y0 + r * Math.sin(ao * Math.PI / 180);
124 list.add(new DoublePoint(new double[]{x1, y1}));
125 }
126 return list;
127 }
128
129
130
131
132
133
134
135 public static void assertException(Runnable block, Class<? extends Throwable> exceptionClass) {
136 try {
137 block.run();
138 Assert.fail(String.format("Expects %s", exceptionClass.getSimpleName()));
139 } catch (Throwable e) {
140 if (!exceptionClass.isInstance(e)) {
141 throw e;
142 }
143 }
144 }
145
146
147
148
149 public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance();
150
151
152
153
154
155
156
157
158
159
160 public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point, DistanceMeasure measure) {
161 double minDistance = Double.POSITIVE_INFINITY;
162 CentroidCluster<T> nearestCluster = null;
163 for (CentroidCluster<T> cluster : clusters) {
164 double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint());
165 if (distance < minDistance) {
166 minDistance = distance;
167 nearestCluster = cluster;
168 }
169 }
170 return nearestCluster;
171 }
172
173
174
175
176
177
178
179
180
181 public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point) {
182 return predict(clusters, point, DEFAULT_MEASURE);
183 }
184 }