1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ml.clustering;
19
20 import java.util.ArrayList;
21 import java.util.Arrays;
22 import java.util.Collections;
23 import java.util.List;
24 import org.junit.Assert;
25 import org.junit.Test;
26
27 import org.apache.commons.rng.simple.RandomSource;
28 import org.apache.commons.rng.sampling.shape.BoxSampler;
29 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
30 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
31 import org.apache.commons.math4.legacy.stat.descriptive.moment.VectorialMean;
32 import org.apache.commons.math4.legacy.core.MathArrays;
33
34
35
36
37 public class ElkanKMeansPlusPlusClustererTest {
38 @Test
39 public void validateOneDimensionSingleClusterZeroMean() {
40 final List<DoublePoint> testPoints = Arrays.asList(new DoublePoint(new double[]{1}),
41 new DoublePoint(new double[]{2}),
42 new DoublePoint(new double[]{-3}));
43 final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(1);
44 final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(testPoints);
45 Assert.assertEquals(1, clusters.size());
46 Assert.assertTrue(MathArrays.equals(new double[]{0}, clusters.get(0).getCenter().getPoint()));
47 }
48
49 @Test(expected = NumberIsTooSmallException.class)
50 public void illegalKParameter() {
51 final int n = 20;
52 final int d = 3;
53 final int k = 100;
54
55 final List<DoublePoint> testPoints = generatePoints(n, d);
56 final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(k);
57 clusterer.cluster(testPoints);
58 }
59
60 @Test
61 public void numberOfClustersSameAsInputSize() {
62 final int n = 3;
63 final int d = 2;
64 final int k = 3;
65
66 final List<DoublePoint> testPoints = generatePoints(n, d);
67 final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(k);
68 final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(testPoints);
69 Assert.assertEquals(k, clusters.size());
70 Assert.assertEquals(1, clusters.get(0).getPoints().size());
71 Assert.assertEquals(1, clusters.get(1).getPoints().size());
72 Assert.assertEquals(1, clusters.get(2).getPoints().size());
73 }
74
75 @Test(expected = NullPointerException.class)
76 public void illegalInputParameter() {
77 final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(10);
78 clusterer.cluster(null);
79 }
80
81 @Test(expected = NumberIsTooSmallException.class)
82 public void emptyInputPointsList() {
83 final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(10);
84 clusterer.cluster(Collections.<DoublePoint>emptyList());
85 }
86
87 @Test(expected = NotStrictlyPositiveException.class)
88 public void negativeKParameterValue() {
89 new ElkanKMeansPlusPlusClusterer<>(-1);
90 }
91
92 @Test(expected = NotStrictlyPositiveException.class)
93 public void kParameterEqualsZero() {
94 new ElkanKMeansPlusPlusClusterer<>(0);
95 }
96
97 @Test
98 public void oneClusterCenterShouldBeTheMean() {
99 final int n = 100;
100 final int d = 2;
101
102 final List<DoublePoint> testPoints = generatePoints(n, d);
103 final KMeansPlusPlusClusterer<DoublePoint> clusterer = new KMeansPlusPlusClusterer<>(1);
104
105 final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(testPoints);
106
107 final VectorialMean mean = new VectorialMean(d);
108 for (DoublePoint each : testPoints) {
109 mean.increment(each.getPoint());
110 }
111 Assert.assertEquals(1, clusters.size());
112 Assert.assertArrayEquals(mean.getResult(), clusters.get(0).getCenter().getPoint(), 1e-6);
113 }
114
115
116
117
118
119
120
121
122 private static List<DoublePoint> generatePoints(int n, int d) {
123 final List<DoublePoint> results = new ArrayList<>();
124 final double[] lower = new double[d];
125 final double[] upper = new double[d];
126 Arrays.fill(upper, 1);
127 final BoxSampler rnd = BoxSampler.of(RandomSource.KISS.create(),
128 lower,
129 upper);
130
131 for (int i = 0; i < n; i++) {
132 results.add(new DoublePoint(rnd.sample()));
133 }
134
135 return results;
136 }
137 }