/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.math4.legacy.ml.clustering;
19  
20  import java.util.ArrayList;
21  import java.util.Arrays;
22  import java.util.Collection;
23  import java.util.List;
24  
25  import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
26  import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
27  import org.apache.commons.rng.simple.RandomSource;
28  import org.apache.commons.rng.UniformRandomProvider;
29  import org.junit.Assert;
30  import org.junit.Before;
31  import org.junit.Test;
32  
33  public class KMeansPlusPlusClustererTest {
34  
35      private UniformRandomProvider random;
36  
37      @Before
38      public void setUp() {
39          random = RandomSource.MT_64.create(1746432956321L);
40      }
41  
42      /**
43       * JIRA: MATH-305
44       *
45       * Two points, one cluster, one iteration
46       */
47      @Test
48      public void testPerformClusterAnalysisDegenerate() {
49          KMeansPlusPlusClusterer<DoublePoint> transformer =
50                  new KMeansPlusPlusClusterer<>(1, 1);
51  
52          DoublePoint[] points = new DoublePoint[] {
53                  new DoublePoint(new int[] { 1959, 325100 }),
54                  new DoublePoint(new int[] { 1960, 373200 }), };
55          List<? extends Cluster<DoublePoint>> clusters = transformer.cluster(Arrays.asList(points));
56          Assert.assertEquals(1, clusters.size());
57          Assert.assertEquals(2, clusters.get(0).getPoints().size());
58          DoublePoint pt1 = new DoublePoint(new int[] { 1959, 325100 });
59          DoublePoint pt2 = new DoublePoint(new int[] { 1960, 373200 });
60          Assert.assertTrue(clusters.get(0).getPoints().contains(pt1));
61          Assert.assertTrue(clusters.get(0).getPoints().contains(pt2));
62      }
63  
64      @Test
65      public void testCertainSpace() {
66          KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
67              KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
68              KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
69              KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
70          };
71          for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
72              int numberOfVariables = 27;
73              // initialize test values
74              int position1 = 1;
75              int position2 = position1 + numberOfVariables;
76              int position3 = position2 + numberOfVariables;
77              int position4 = position3 + numberOfVariables;
78              // test values will be multiplied
79              int multiplier = 1000000;
80  
81              DoublePoint[] breakingPoints = new DoublePoint[numberOfVariables];
82              // define the space which will break the cluster algorithm
83              for (int i = 0; i < numberOfVariables; i++) {
84                  int points[] = { position1, position2, position3, position4 };
85                  // multiply the values
86                  for (int j = 0; j < points.length; j++) {
87                      points[j] *= multiplier;
88                  }
89                  DoublePoint DoublePoint = new DoublePoint(points);
90                  breakingPoints[i] = DoublePoint;
91                  position1 += numberOfVariables;
92                  position2 += numberOfVariables;
93                  position3 += numberOfVariables;
94                  position4 += numberOfVariables;
95              }
96  
97              for (int n = 2; n < 27; ++n) {
98                  KMeansPlusPlusClusterer<DoublePoint> transformer =
99                      new KMeansPlusPlusClusterer<>(n, 100, new EuclideanDistance(), random, strategy);
100 
101                 List<? extends Cluster<DoublePoint>> clusters =
102                         transformer.cluster(Arrays.asList(breakingPoints));
103 
104                 Assert.assertEquals(n, clusters.size());
105                 int sum = 0;
106                 for (Cluster<DoublePoint> cluster : clusters) {
107                     sum += cluster.getPoints().size();
108                 }
109                 Assert.assertEquals(numberOfVariables, sum);
110             }
111         }
112     }
113 
114     /**
115      * A helper class for testSmallDistances(). This class is similar to DoublePoint, but
116      * it defines a different distanceFrom() method that tends to return distances less than 1.
117      */
118     private static final class CloseDistance extends EuclideanDistance {
119         private static final long serialVersionUID = 1L;
120 
121         @Override
122         public double compute(double[] a, double[] b) {
123             return super.compute(a, b) * 0.001;
124         }
125     }
126 
127     /**
128      * Test points that are very close together. See issue MATH-546.
129      */
130     @Test
131     public void testSmallDistances() {
132         // Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
133         // small distance.
134         final int[] repeatedArray = { 0 };
135         final int[] uniqueArray = { 1 };
136         final DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
137         final DoublePoint uniquePoint = new DoublePoint(uniqueArray);
138 
139         final Collection<DoublePoint> points = new ArrayList<>();
140         final int numRepeated = 10000;
141         for (int i = 0; i < numRepeated; i++) {
142             points.add(repeatedPoint);
143         }
144         points.add(uniquePoint);
145 
146         final KMeansPlusPlusClusterer<DoublePoint> clusterer =
147             new KMeansPlusPlusClusterer<>(2, 1, new CloseDistance(), random);
148         final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
149 
150         // Check that one of the chosen centers is the unique point.
151         boolean uniquePointIsCenter = false;
152         for (CentroidCluster<DoublePoint> cluster : clusters) {
153             if (cluster.getCenter().equals(uniquePoint)) {
154                 uniquePointIsCenter = true;
155             }
156         }
157         Assert.assertTrue(uniquePointIsCenter);
158     }
159 
160     /**
161      * 2 variables cannot be clustered into 3 clusters. See issue MATH-436.
162      */
163     @Test(expected=NumberIsTooSmallException.class)
164     public void testPerformClusterAnalysisToManyClusters() {
165         KMeansPlusPlusClusterer<DoublePoint> transformer =
166             new KMeansPlusPlusClusterer<>(3, 1, new EuclideanDistance(), random);
167 
168         DoublePoint[] points = new DoublePoint[] {
169             new DoublePoint(new int[] {
170                 1959, 325100
171             }), new DoublePoint(new int[] {
172                 1960, 373200
173             })
174         };
175 
176         transformer.cluster(Arrays.asList(points));
177     }
178 }