001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.ml.clustering.evaluation;
019
020import org.apache.commons.math4.exception.InsufficientDataException;
021import org.apache.commons.math4.ml.clustering.Cluster;
022import org.apache.commons.math4.ml.clustering.ClusterEvaluator;
023import org.apache.commons.math4.ml.clustering.Clusterable;
024import org.apache.commons.math4.util.MathArrays;
025
026import java.util.Collection;
027import java.util.List;
028
029/**
030 * Compute the Calinski and Harabasz score.
031 * <p>
032 * It is also known as the Variance Ratio Criterion.
033 * <p>
034 * The score is defined as ratio between the within-cluster dispersion and
035 * the between-cluster dispersion.
036 *
037 * @see <a href="https://www.tandfonline.com/doi/abs/10.1080/03610927408827101">A dendrite method for cluster
038 * analysis</a>
039 */
040public class CalinskiHarabasz implements ClusterEvaluator {
041    /** {@inheritDoc} */
042    @Override
043    public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
044        final int dimension = dimensionOfClusters(clusters);
045        final double[] centroid = meanOfClusters(clusters, dimension);
046
047        double intraDistanceProduct = 0.0;
048        double extraDistanceProduct = 0.0;
049        for (Cluster<? extends Clusterable> cluster : clusters) {
050            // Calculate the center of the cluster.
051            double[] clusterCentroid = mean(cluster.getPoints(), dimension);
052            for (Clusterable p : cluster.getPoints()) {
053                // Increase the intra distance sum
054                intraDistanceProduct += covariance(clusterCentroid, p.getPoint());
055            }
056            // Increase the extra distance sum
057            extraDistanceProduct += cluster.getPoints().size() * covariance(centroid, clusterCentroid);
058        }
059
060        final int pointCount = countAllPoints(clusters);
061        final int clusterCount = clusters.size();
062        // Return the ratio of the intraDistranceProduct to extraDistanceProduct
063        return intraDistanceProduct == 0.0 ? 1.0 :
064                (extraDistanceProduct * (pointCount - clusterCount) /
065                        (intraDistanceProduct * (clusterCount - 1)));
066    }
067
068    /** {@inheritDoc} */
069    @Override
070    public boolean isBetterScore(double a,
071                                 double b) {
072        return a > b;
073    }
074
075    /**
076     * Calculate covariance of two double array.
077     * <pre>
078     *   covariance = sum((p1[i]-p2[i])^2)
079     * </pre>
080     *
081     * @param p1 Double array
082     * @param p2 Double array
083     * @return covariance of two double array
084     */
085    private double covariance(double[] p1, double[] p2) {
086        MathArrays.checkEqualLength(p1, p2);
087        double sum = 0;
088        for (int i = 0; i < p1.length; i++) {
089            final double dp = p1[i] - p2[i];
090            sum += dp * dp;
091        }
092        return sum;
093    }
094
095    /**
096     * Calculate the mean of all the points.
097     *
098     * @param points    A collection of points
099     * @param dimension The dimension of each point
100     * @return The mean value.
101     */
102    private double[] mean(final Collection<? extends Clusterable> points, final int dimension) {
103        final double[] centroid = new double[dimension];
104        for (final Clusterable p : points) {
105            final double[] point = p.getPoint();
106            for (int i = 0; i < centroid.length; i++) {
107                centroid[i] += point[i];
108            }
109        }
110        for (int i = 0; i < centroid.length; i++) {
111            centroid[i] /= points.size();
112        }
113        return centroid;
114    }
115
116    /**
117     * Calculate the mean of all the points in the clusters.
118     *
119     * @param clusters  A collection of clusters.
120     * @param dimension The dimension of each point.
121     * @return The mean value.
122     */
123    private double[] meanOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters, final int dimension) {
124        final double[] centroid = new double[dimension];
125        int allPointsCount = 0;
126        for (Cluster<? extends Clusterable> cluster : clusters) {
127            for (Clusterable p : cluster.getPoints()) {
128                double[] point = p.getPoint();
129                for (int i = 0; i < centroid.length; i++) {
130                    centroid[i] += point[i];
131                }
132                allPointsCount++;
133            }
134        }
135        for (int i = 0; i < centroid.length; i++) {
136            centroid[i] /= allPointsCount;
137        }
138        return centroid;
139    }
140
141    /**
142     * Count all the points in collection of cluster.
143     *
144     * @param clusters collection of cluster
145     * @return points count
146     */
147    private int countAllPoints(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
148        int pointCount = 0;
149        for (Cluster<? extends Clusterable> cluster : clusters) {
150            pointCount += cluster.getPoints().size();
151        }
152        return pointCount;
153    }
154
155    /**
156     * Detect the dimension of points in the clusters
157     *
158     * @param clusters collection of cluster
159     * @return The dimension of the first point in clusters
160     */
161    private int dimensionOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
162        // Iteration and find out the first point.
163        for (Cluster<? extends Clusterable> cluster : clusters) {
164            for (Clusterable p : cluster.getPoints()) {
165                return p.getPoint().length;
166            }
167        }
168        // Throw exception if there is no point.
169        throw new InsufficientDataException();
170    }
171}