/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.legacy.ml.clustering.evaluation;

import org.apache.commons.math4.legacy.exception.InsufficientDataException;
import org.apache.commons.math4.legacy.ml.clustering.Cluster;
import org.apache.commons.math4.legacy.ml.clustering.ClusterEvaluator;
import org.apache.commons.math4.legacy.ml.clustering.Clusterable;
import org.apache.commons.math4.legacy.core.MathArrays;

import java.util.Collection;
import java.util.List;

/**
 * Compute the Calinski and Harabasz score.
 * <p>
 * It is also known as the Variance Ratio Criterion.
 * <p>
 * The score is defined as the ratio of the between-cluster dispersion to
 * the within-cluster dispersion, each one scaled by its degrees of freedom:
 * <pre>
 *   score = (B * (n - k)) / (W * (k - 1))
 * </pre>
 * where {@code B} is the sum, over the clusters, of the cluster size times the
 * squared distance between the cluster centroid and the global centroid,
 * {@code W} is the sum of the squared distances between every point and the
 * centroid of its cluster, {@code n} is the total number of points and
 * {@code k} is the number of non-empty clusters.
 * A higher score is considered better (see {@link #isBetterScore(double, double)}).
 *
 * @see <a href="https://www.tandfonline.com/doi/abs/10.1080/03610927408827101">A dendrite method for cluster
 * analysis</a>
 */
public class CalinskiHarabasz implements ClusterEvaluator {
    /**
     * {@inheritDoc}
     * <p>
     * Clusters that contain no point are ignored: they contribute neither to
     * the dispersions nor to the cluster count.
     * When the within-cluster dispersion is exactly zero, {@code 1.0} is returned.
     * The formula divides by {@code (k - 1)}, hence the result is not a
     * meaningful finite value when there is a single non-empty cluster.
     *
     * @throws InsufficientDataException if the clusters do not contain any point.
     */
    @Override
    public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
        final int dimension = dimensionOfClusters(clusters);
        final double[] centroid = meanOfClusters(clusters, dimension);

        double intraDistanceProduct = 0.0;
        double extraDistanceProduct = 0.0;
        int clusterCount = 0;
        for (Cluster<? extends Clusterable> cluster : clusters) {
            final Collection<? extends Clusterable> points = cluster.getPoints();
            if (points.isEmpty()) {
                // The centroid of an empty cluster is undefined (0 / 0 = NaN) and
                // would turn the whole score into NaN: skip such clusters.
                continue;
            }
            clusterCount++;

            // Calculate the center of the cluster.
            final double[] clusterCentroid = mean(points, dimension);
            for (Clusterable p : points) {
                // Increase the within-cluster (intra) dispersion.
                intraDistanceProduct += covariance(clusterCentroid, p.getPoint());
            }
            // Increase the between-cluster (extra) dispersion, weighted by the cluster size.
            extraDistanceProduct += points.size() * covariance(centroid, clusterCentroid);
        }

        final int pointCount = countAllPoints(clusters);
        // Return the ratio of extraDistanceProduct to intraDistanceProduct,
        // each one scaled by its degrees of freedom.
        return intraDistanceProduct == 0.0 ? 1.0 :
            (extraDistanceProduct * (pointCount - clusterCount) /
             (intraDistanceProduct * (clusterCount - 1)));
    }

    /**
     * {@inheritDoc}
     * <p>
     * For this evaluator, the larger score is the better one.
     */
    @Override
    public boolean isBetterScore(double a,
                                 double b) {
        return a > b;
    }

    /**
     * Calculate the sum of the squared differences between the components
     * of two arrays (i.e. the squared Euclidean distance).
     * <pre>
     *   covariance = sum((p1[i]-p2[i])^2)
     * </pre>
     *
     * @param p1 Double array
     * @param p2 Double array of the same length as {@code p1}
     * @return the sum of the squared component-wise differences.
     */
    private double covariance(double[] p1, double[] p2) {
        // Both arrays must have the same length.
        MathArrays.checkEqualLength(p1, p2);
        double sum = 0;
        for (int i = 0; i < p1.length; i++) {
            final double dp = p1[i] - p2[i];
            sum += dp * dp;
        }
        return sum;
    }

    /**
     * Calculate the mean of all the points.
     *
     * @param points A collection of points; the caller must ensure that it is
     * not empty, otherwise every component of the result is {@code NaN}.
     * @param dimension The dimension of each point
     * @return The mean value, as a new array of length {@code dimension}.
     */
    private double[] mean(final Collection<? extends Clusterable> points, final int dimension) {
        final double[] centroid = new double[dimension];
        // Accumulate the component-wise sum ...
        for (final Clusterable p : points) {
            final double[] point = p.getPoint();
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += point[i];
            }
        }
        // ... and divide by the number of points.
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= points.size();
        }
        return centroid;
    }

    /**
     * Calculate the mean of all the points in the clusters.
     *
     * @param clusters A collection of clusters.
     * @param dimension The dimension of each point.
     * @return The mean value, as a new array of length {@code dimension}.
     */
    private double[] meanOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters,
                                    final int dimension) {
        final double[] centroid = new double[dimension];
        int allPointsCount = 0;
        for (Cluster<? extends Clusterable> cluster : clusters) {
            for (Clusterable p : cluster.getPoints()) {
                final double[] point = p.getPoint();
                for (int i = 0; i < centroid.length; i++) {
                    centroid[i] += point[i];
                }
                allPointsCount++;
            }
        }
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= allPointsCount;
        }
        return centroid;
    }

    /**
     * Count all the points in collection of cluster.
     *
     * @param clusters collection of cluster
     * @return points count
     */
    private int countAllPoints(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
        int pointCount = 0;
        for (Cluster<? extends Clusterable> cluster : clusters) {
            pointCount += cluster.getPoints().size();
        }
        return pointCount;
    }

    /**
     * Detect the dimension of points in the clusters.
     *
     * @param clusters collection of cluster
     * @return The dimension of the first point in clusters
     * @throws InsufficientDataException if there is no point in the clusters.
     */
    private int dimensionOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
        // Iteration and find out the first point.
        for (Cluster<? extends Clusterable> cluster : clusters) {
            for (Clusterable p : cluster.getPoints()) {
                return p.getPoint().length;
            }
        }
        // Throw exception if there is no point.
        throw new InsufficientDataException();
    }
}