CalinskiHarabasz.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.ml.clustering.evaluation;

  18. import org.apache.commons.math4.legacy.exception.InsufficientDataException;
  19. import org.apache.commons.math4.legacy.ml.clustering.Cluster;
  20. import org.apache.commons.math4.legacy.ml.clustering.ClusterEvaluator;
  21. import org.apache.commons.math4.legacy.ml.clustering.Clusterable;
  22. import org.apache.commons.math4.legacy.core.MathArrays;

  23. import java.util.Collection;
  24. import java.util.List;

  25. /**
  26.  * Compute the Calinski and Harabasz score.
  27.  * <p>
  28.  * It is also known as the Variance Ratio Criterion.
  29.  * <p>
  30.  * The score is defined as ratio between the within-cluster dispersion and
  31.  * the between-cluster dispersion.
  32.  *
  33.  * @see <a href="https://www.tandfonline.com/doi/abs/10.1080/03610927408827101">A dendrite method for cluster
  34.  * analysis</a>
  35.  */
  36. public class CalinskiHarabasz implements ClusterEvaluator {
  37.     /** {@inheritDoc} */
  38.     @Override
  39.     public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
  40.         final int dimension = dimensionOfClusters(clusters);
  41.         final double[] centroid = meanOfClusters(clusters, dimension);

  42.         double intraDistanceProduct = 0.0;
  43.         double extraDistanceProduct = 0.0;
  44.         for (Cluster<? extends Clusterable> cluster : clusters) {
  45.             // Calculate the center of the cluster.
  46.             double[] clusterCentroid = mean(cluster.getPoints(), dimension);
  47.             for (Clusterable p : cluster.getPoints()) {
  48.                 // Increase the intra distance sum
  49.                 intraDistanceProduct += covariance(clusterCentroid, p.getPoint());
  50.             }
  51.             // Increase the extra distance sum
  52.             extraDistanceProduct += cluster.getPoints().size() * covariance(centroid, clusterCentroid);
  53.         }

  54.         final int pointCount = countAllPoints(clusters);
  55.         final int clusterCount = clusters.size();
  56.         // Return the ratio of the intraDistranceProduct to extraDistanceProduct
  57.         return intraDistanceProduct == 0.0 ? 1.0 :
  58.                 (extraDistanceProduct * (pointCount - clusterCount) /
  59.                         (intraDistanceProduct * (clusterCount - 1)));
  60.     }

  61.     /** {@inheritDoc} */
  62.     @Override
  63.     public boolean isBetterScore(double a,
  64.                                  double b) {
  65.         return a > b;
  66.     }

  67.     /**
  68.      * Calculate covariance of two double array.
  69.      * <pre>
  70.      *   covariance = sum((p1[i]-p2[i])^2)
  71.      * </pre>
  72.      *
  73.      * @param p1 Double array
  74.      * @param p2 Double array
  75.      * @return covariance of two double array
  76.      */
  77.     private double covariance(double[] p1, double[] p2) {
  78.         MathArrays.checkEqualLength(p1, p2);
  79.         double sum = 0;
  80.         for (int i = 0; i < p1.length; i++) {
  81.             final double dp = p1[i] - p2[i];
  82.             sum += dp * dp;
  83.         }
  84.         return sum;
  85.     }

  86.     /**
  87.      * Calculate the mean of all the points.
  88.      *
  89.      * @param points    A collection of points
  90.      * @param dimension The dimension of each point
  91.      * @return The mean value.
  92.      */
  93.     private double[] mean(final Collection<? extends Clusterable> points, final int dimension) {
  94.         final double[] centroid = new double[dimension];
  95.         for (final Clusterable p : points) {
  96.             final double[] point = p.getPoint();
  97.             for (int i = 0; i < centroid.length; i++) {
  98.                 centroid[i] += point[i];
  99.             }
  100.         }
  101.         for (int i = 0; i < centroid.length; i++) {
  102.             centroid[i] /= points.size();
  103.         }
  104.         return centroid;
  105.     }

  106.     /**
  107.      * Calculate the mean of all the points in the clusters.
  108.      *
  109.      * @param clusters  A collection of clusters.
  110.      * @param dimension The dimension of each point.
  111.      * @return The mean value.
  112.      */
  113.     private double[] meanOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters, final int dimension) {
  114.         final double[] centroid = new double[dimension];
  115.         int allPointsCount = 0;
  116.         for (Cluster<? extends Clusterable> cluster : clusters) {
  117.             for (Clusterable p : cluster.getPoints()) {
  118.                 double[] point = p.getPoint();
  119.                 for (int i = 0; i < centroid.length; i++) {
  120.                     centroid[i] += point[i];
  121.                 }
  122.                 allPointsCount++;
  123.             }
  124.         }
  125.         for (int i = 0; i < centroid.length; i++) {
  126.             centroid[i] /= allPointsCount;
  127.         }
  128.         return centroid;
  129.     }

  130.     /**
  131.      * Count all the points in collection of cluster.
  132.      *
  133.      * @param clusters collection of cluster
  134.      * @return points count
  135.      */
  136.     private int countAllPoints(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
  137.         int pointCount = 0;
  138.         for (Cluster<? extends Clusterable> cluster : clusters) {
  139.             pointCount += cluster.getPoints().size();
  140.         }
  141.         return pointCount;
  142.     }

  143.     /**
  144.      * Detect the dimension of points in the clusters.
  145.      *
  146.      * @param clusters collection of cluster
  147.      * @return The dimension of the first point in clusters
  148.      */
  149.     private int dimensionOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
  150.         // Iteration and find out the first point.
  151.         for (Cluster<? extends Clusterable> cluster : clusters) {
  152.             for (Clusterable p : cluster.getPoints()) {
  153.                 return p.getPoint().length;
  154.             }
  155.         }
  156.         // Throw exception if there is no point.
  157.         throw new InsufficientDataException();
  158.     }
  159. }