View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.ml.clustering.evaluation;
19  
20  import org.apache.commons.math4.legacy.exception.InsufficientDataException;
21  import org.apache.commons.math4.legacy.ml.clustering.Cluster;
22  import org.apache.commons.math4.legacy.ml.clustering.ClusterEvaluator;
23  import org.apache.commons.math4.legacy.ml.clustering.Clusterable;
24  import org.apache.commons.math4.legacy.core.MathArrays;
25  
26  import java.util.Collection;
27  import java.util.List;
28  
29  /**
30   * Compute the Calinski and Harabasz score.
31   * <p>
32   * It is also known as the Variance Ratio Criterion.
33   * <p>
34   * The score is defined as ratio between the within-cluster dispersion and
35   * the between-cluster dispersion.
36   *
37   * @see <a href="https://www.tandfonline.com/doi/abs/10.1080/03610927408827101">A dendrite method for cluster
38   * analysis</a>
39   */
40  public class CalinskiHarabasz implements ClusterEvaluator {
41      /** {@inheritDoc} */
42      @Override
43      public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
44          final int dimension = dimensionOfClusters(clusters);
45          final double[] centroid = meanOfClusters(clusters, dimension);
46  
47          double intraDistanceProduct = 0.0;
48          double extraDistanceProduct = 0.0;
49          for (Cluster<? extends Clusterable> cluster : clusters) {
50              // Calculate the center of the cluster.
51              double[] clusterCentroid = mean(cluster.getPoints(), dimension);
52              for (Clusterable p : cluster.getPoints()) {
53                  // Increase the intra distance sum
54                  intraDistanceProduct += covariance(clusterCentroid, p.getPoint());
55              }
56              // Increase the extra distance sum
57              extraDistanceProduct += cluster.getPoints().size() * covariance(centroid, clusterCentroid);
58          }
59  
60          final int pointCount = countAllPoints(clusters);
61          final int clusterCount = clusters.size();
62          // Return the ratio of the intraDistranceProduct to extraDistanceProduct
63          return intraDistanceProduct == 0.0 ? 1.0 :
64                  (extraDistanceProduct * (pointCount - clusterCount) /
65                          (intraDistanceProduct * (clusterCount - 1)));
66      }
67  
68      /** {@inheritDoc} */
69      @Override
70      public boolean isBetterScore(double a,
71                                   double b) {
72          return a > b;
73      }
74  
75      /**
76       * Calculate covariance of two double array.
77       * <pre>
78       *   covariance = sum((p1[i]-p2[i])^2)
79       * </pre>
80       *
81       * @param p1 Double array
82       * @param p2 Double array
83       * @return covariance of two double array
84       */
85      private double covariance(double[] p1, double[] p2) {
86          MathArrays.checkEqualLength(p1, p2);
87          double sum = 0;
88          for (int i = 0; i < p1.length; i++) {
89              final double dp = p1[i] - p2[i];
90              sum += dp * dp;
91          }
92          return sum;
93      }
94  
95      /**
96       * Calculate the mean of all the points.
97       *
98       * @param points    A collection of points
99       * @param dimension The dimension of each point
100      * @return The mean value.
101      */
102     private double[] mean(final Collection<? extends Clusterable> points, final int dimension) {
103         final double[] centroid = new double[dimension];
104         for (final Clusterable p : points) {
105             final double[] point = p.getPoint();
106             for (int i = 0; i < centroid.length; i++) {
107                 centroid[i] += point[i];
108             }
109         }
110         for (int i = 0; i < centroid.length; i++) {
111             centroid[i] /= points.size();
112         }
113         return centroid;
114     }
115 
116     /**
117      * Calculate the mean of all the points in the clusters.
118      *
119      * @param clusters  A collection of clusters.
120      * @param dimension The dimension of each point.
121      * @return The mean value.
122      */
123     private double[] meanOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters, final int dimension) {
124         final double[] centroid = new double[dimension];
125         int allPointsCount = 0;
126         for (Cluster<? extends Clusterable> cluster : clusters) {
127             for (Clusterable p : cluster.getPoints()) {
128                 double[] point = p.getPoint();
129                 for (int i = 0; i < centroid.length; i++) {
130                     centroid[i] += point[i];
131                 }
132                 allPointsCount++;
133             }
134         }
135         for (int i = 0; i < centroid.length; i++) {
136             centroid[i] /= allPointsCount;
137         }
138         return centroid;
139     }
140 
141     /**
142      * Count all the points in collection of cluster.
143      *
144      * @param clusters collection of cluster
145      * @return points count
146      */
147     private int countAllPoints(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
148         int pointCount = 0;
149         for (Cluster<? extends Clusterable> cluster : clusters) {
150             pointCount += cluster.getPoints().size();
151         }
152         return pointCount;
153     }
154 
155     /**
156      * Detect the dimension of points in the clusters.
157      *
158      * @param clusters collection of cluster
159      * @return The dimension of the first point in clusters
160      */
161     private int dimensionOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
162         // Iteration and find out the first point.
163         for (Cluster<? extends Clusterable> cluster : clusters) {
164             for (Clusterable p : cluster.getPoints()) {
165                 return p.getPoint().length;
166             }
167         }
168         // Throw exception if there is no point.
169         throw new InsufficientDataException();
170     }
171 }