/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.math4.legacy.ml.clustering;
19  
20  import java.util.ArrayList;
21  import java.util.Arrays;
22  import java.util.Collections;
23  import java.util.List;
24  import org.junit.Assert;
25  import org.junit.Test;
26  
27  import org.apache.commons.rng.simple.RandomSource;
28  import org.apache.commons.rng.sampling.shape.BoxSampler;
29  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
30  import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
31  import org.apache.commons.math4.legacy.stat.descriptive.moment.VectorialMean;
32  import org.apache.commons.math4.legacy.core.MathArrays;
33  
34  /**
35   * Tests for {@link ElkanKmeansPlusPlusClusterer}.
36   */
37  public class ElkanKMeansPlusPlusClustererTest {
38      @Test
39      public void validateOneDimensionSingleClusterZeroMean() {
40          final List<DoublePoint> testPoints = Arrays.asList(new DoublePoint(new double[]{1}),
41                                                             new DoublePoint(new double[]{2}),
42                                                             new DoublePoint(new double[]{-3}));
43          final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(1);
44          final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(testPoints);
45          Assert.assertEquals(1, clusters.size());
46          Assert.assertTrue(MathArrays.equals(new double[]{0}, clusters.get(0).getCenter().getPoint()));
47      }
48  
49      @Test(expected = NumberIsTooSmallException.class)
50      public void illegalKParameter() {
51          final int n = 20;
52          final int d = 3;
53          final int k = 100;
54  
55          final List<DoublePoint> testPoints = generatePoints(n, d);
56          final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(k);
57          clusterer.cluster(testPoints);
58      }
59  
60      @Test
61      public void numberOfClustersSameAsInputSize() {
62          final int n = 3;
63          final int d = 2;
64          final int k = 3;
65  
66          final List<DoublePoint> testPoints = generatePoints(n, d);
67          final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(k);
68          final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(testPoints);
69          Assert.assertEquals(k, clusters.size());
70          Assert.assertEquals(1, clusters.get(0).getPoints().size());
71          Assert.assertEquals(1, clusters.get(1).getPoints().size());
72          Assert.assertEquals(1, clusters.get(2).getPoints().size());
73      }
74  
75      @Test(expected = NullPointerException.class)
76      public void illegalInputParameter() {
77          final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(10);
78          clusterer.cluster(null);
79      }
80  
81      @Test(expected = NumberIsTooSmallException.class)
82      public void emptyInputPointsList() {
83          final ElkanKMeansPlusPlusClusterer<DoublePoint> clusterer = new ElkanKMeansPlusPlusClusterer<>(10);
84          clusterer.cluster(Collections.<DoublePoint>emptyList());
85      }
86  
87      @Test(expected = NotStrictlyPositiveException.class)
88      public void negativeKParameterValue() {
89          new ElkanKMeansPlusPlusClusterer<>(-1);
90      }
91  
92      @Test(expected = NotStrictlyPositiveException.class)
93      public void kParameterEqualsZero() {
94          new ElkanKMeansPlusPlusClusterer<>(0);
95      }
96  
97      @Test
98      public void oneClusterCenterShouldBeTheMean() {
99          final int n = 100;
100         final int d = 2;
101 
102         final List<DoublePoint> testPoints = generatePoints(n, d);
103         final KMeansPlusPlusClusterer<DoublePoint> clusterer = new KMeansPlusPlusClusterer<>(1);
104 
105         final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(testPoints);
106 
107         final VectorialMean mean = new VectorialMean(d);
108         for (DoublePoint each : testPoints) {
109             mean.increment(each.getPoint());
110         }
111         Assert.assertEquals(1, clusters.size());
112         Assert.assertArrayEquals(mean.getResult(), clusters.get(0).getCenter().getPoint(), 1e-6);
113     }
114 
115     /**
116      * Generates a list of random uncorrelated points to cluster.
117      *
118      * @param n number of points
119      * @param d dimensionality
120      * @return list of n generated random vectors of dimension d.
121      */
122     private static List<DoublePoint> generatePoints(int n, int d) {
123         final List<DoublePoint> results = new ArrayList<>();
124         final double[] lower = new double[d];
125         final double[] upper = new double[d];
126         Arrays.fill(upper, 1);
127         final BoxSampler rnd = BoxSampler.of(RandomSource.KISS.create(),
128                                              lower,
129                                              upper);
130 
131         for (int i = 0; i < n; i++) {
132             results.add(new DoublePoint(rnd.sample()));
133         }
134 
135         return results;
136     }
137 }