1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ml.clustering.evaluation;
19
20 import org.apache.commons.math4.legacy.exception.InsufficientDataException;
21 import org.apache.commons.math4.legacy.ml.clustering.Cluster;
22 import org.apache.commons.math4.legacy.ml.clustering.ClusterEvaluator;
23 import org.apache.commons.math4.legacy.ml.clustering.Clusterable;
24 import org.apache.commons.math4.legacy.core.MathArrays;
25
26 import java.util.Collection;
27 import java.util.List;
28
29
30
31
32
33
34
35
36
37
38
39
40 public class CalinskiHarabasz implements ClusterEvaluator {
41
42 @Override
43 public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
44 final int dimension = dimensionOfClusters(clusters);
45 final double[] centroid = meanOfClusters(clusters, dimension);
46
47 double intraDistanceProduct = 0.0;
48 double extraDistanceProduct = 0.0;
49 for (Cluster<? extends Clusterable> cluster : clusters) {
50
51 double[] clusterCentroid = mean(cluster.getPoints(), dimension);
52 for (Clusterable p : cluster.getPoints()) {
53
54 intraDistanceProduct += covariance(clusterCentroid, p.getPoint());
55 }
56
57 extraDistanceProduct += cluster.getPoints().size() * covariance(centroid, clusterCentroid);
58 }
59
60 final int pointCount = countAllPoints(clusters);
61 final int clusterCount = clusters.size();
62
63 return intraDistanceProduct == 0.0 ? 1.0 :
64 (extraDistanceProduct * (pointCount - clusterCount) /
65 (intraDistanceProduct * (clusterCount - 1)));
66 }
67
68
69 @Override
70 public boolean isBetterScore(double a,
71 double b) {
72 return a > b;
73 }
74
75
76
77
78
79
80
81
82
83
84
85 private double covariance(double[] p1, double[] p2) {
86 MathArrays.checkEqualLength(p1, p2);
87 double sum = 0;
88 for (int i = 0; i < p1.length; i++) {
89 final double dp = p1[i] - p2[i];
90 sum += dp * dp;
91 }
92 return sum;
93 }
94
95
96
97
98
99
100
101
102 private double[] mean(final Collection<? extends Clusterable> points, final int dimension) {
103 final double[] centroid = new double[dimension];
104 for (final Clusterable p : points) {
105 final double[] point = p.getPoint();
106 for (int i = 0; i < centroid.length; i++) {
107 centroid[i] += point[i];
108 }
109 }
110 for (int i = 0; i < centroid.length; i++) {
111 centroid[i] /= points.size();
112 }
113 return centroid;
114 }
115
116
117
118
119
120
121
122
123 private double[] meanOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters, final int dimension) {
124 final double[] centroid = new double[dimension];
125 int allPointsCount = 0;
126 for (Cluster<? extends Clusterable> cluster : clusters) {
127 for (Clusterable p : cluster.getPoints()) {
128 double[] point = p.getPoint();
129 for (int i = 0; i < centroid.length; i++) {
130 centroid[i] += point[i];
131 }
132 allPointsCount++;
133 }
134 }
135 for (int i = 0; i < centroid.length; i++) {
136 centroid[i] /= allPointsCount;
137 }
138 return centroid;
139 }
140
141
142
143
144
145
146
147 private int countAllPoints(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
148 int pointCount = 0;
149 for (Cluster<? extends Clusterable> cluster : clusters) {
150 pointCount += cluster.getPoints().size();
151 }
152 return pointCount;
153 }
154
155
156
157
158
159
160
161 private int dimensionOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
162
163 for (Cluster<? extends Clusterable> cluster : clusters) {
164 for (Clusterable p : cluster.getPoints()) {
165 return p.getPoint().length;
166 }
167 }
168
169 throw new InsufficientDataException();
170 }
171 }