1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.stat.correlation;
18
19 import java.util.Arrays;
20 import java.util.Comparator;
21
22 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23 import org.apache.commons.math4.legacy.linear.BlockRealMatrix;
24 import org.apache.commons.math4.legacy.linear.MatrixUtils;
25 import org.apache.commons.math4.legacy.linear.RealMatrix;
26 import org.apache.commons.math4.core.jdkmath.JdkMath;
27 import org.apache.commons.math4.legacy.core.Pair;
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70 public class KendallsCorrelation {
71
72
73 private final RealMatrix correlationMatrix;
74
75
76
77
78 public KendallsCorrelation() {
79 correlationMatrix = null;
80 }
81
82
83
84
85
86
87
88
89
90 public KendallsCorrelation(double[][] data) {
91 this(MatrixUtils.createRealMatrix(data));
92 }
93
94
95
96
97
98
99
100 public KendallsCorrelation(RealMatrix matrix) {
101 correlationMatrix = computeCorrelationMatrix(matrix);
102 }
103
104
105
106
107
108
109 public RealMatrix getCorrelationMatrix() {
110 return correlationMatrix;
111 }
112
113
114
115
116
117
118
119
120 public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
121 int nVars = matrix.getColumnDimension();
122 RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
123 for (int i = 0; i < nVars; i++) {
124 for (int j = 0; j < i; j++) {
125 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
126 outMatrix.setEntry(i, j, corr);
127 outMatrix.setEntry(j, i, corr);
128 }
129 outMatrix.setEntry(i, i, 1d);
130 }
131 return outMatrix;
132 }
133
134
135
136
137
138
139
140
141
142 public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
143 return computeCorrelationMatrix(new BlockRealMatrix(matrix));
144 }
145
146
147
148
149
150
151
152
153
154 public double correlation(final double[] xArray, final double[] yArray)
155 throws DimensionMismatchException {
156
157 if (xArray.length != yArray.length) {
158 throw new DimensionMismatchException(xArray.length, yArray.length);
159 }
160
161 final int n = xArray.length;
162 final long numPairs = sum(n - 1);
163
164 @SuppressWarnings("unchecked")
165 Pair<Double, Double>[] pairs = new Pair[n];
166 for (int i = 0; i < n; i++) {
167 pairs[i] = new Pair<>(xArray[i], yArray[i]);
168 }
169
170 Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() {
171
172 @Override
173 public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) {
174 int compareFirst = pair1.getFirst().compareTo(pair2.getFirst());
175 return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond());
176 }
177 });
178
179 long tiedXPairs = 0;
180 long tiedXYPairs = 0;
181 long consecutiveXTies = 1;
182 long consecutiveXYTies = 1;
183 Pair<Double, Double> prev = pairs[0];
184 for (int i = 1; i < n; i++) {
185 final Pair<Double, Double> curr = pairs[i];
186 if (curr.getFirst().equals(prev.getFirst())) {
187 consecutiveXTies++;
188 if (curr.getSecond().equals(prev.getSecond())) {
189 consecutiveXYTies++;
190 } else {
191 tiedXYPairs += sum(consecutiveXYTies - 1);
192 consecutiveXYTies = 1;
193 }
194 } else {
195 tiedXPairs += sum(consecutiveXTies - 1);
196 consecutiveXTies = 1;
197 tiedXYPairs += sum(consecutiveXYTies - 1);
198 consecutiveXYTies = 1;
199 }
200 prev = curr;
201 }
202 tiedXPairs += sum(consecutiveXTies - 1);
203 tiedXYPairs += sum(consecutiveXYTies - 1);
204
205 long swaps = 0;
206 @SuppressWarnings("unchecked")
207 Pair<Double, Double>[] pairsDestination = new Pair[n];
208 for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
209 for (int offset = 0; offset < n; offset += 2 * segmentSize) {
210 int i = offset;
211 final int iEnd = JdkMath.min(i + segmentSize, n);
212 int j = iEnd;
213 final int jEnd = JdkMath.min(j + segmentSize, n);
214
215 int copyLocation = offset;
216 while (i < iEnd || j < jEnd) {
217 if (i < iEnd) {
218 if (j < jEnd) {
219 if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) {
220 pairsDestination[copyLocation] = pairs[i];
221 i++;
222 } else {
223 pairsDestination[copyLocation] = pairs[j];
224 j++;
225 swaps += iEnd - i;
226 }
227 } else {
228 pairsDestination[copyLocation] = pairs[i];
229 i++;
230 }
231 } else {
232 pairsDestination[copyLocation] = pairs[j];
233 j++;
234 }
235 copyLocation++;
236 }
237 }
238 final Pair<Double, Double>[] pairsTemp = pairs;
239 pairs = pairsDestination;
240 pairsDestination = pairsTemp;
241 }
242
243 long tiedYPairs = 0;
244 long consecutiveYTies = 1;
245 prev = pairs[0];
246 for (int i = 1; i < n; i++) {
247 final Pair<Double, Double> curr = pairs[i];
248 if (curr.getSecond().equals(prev.getSecond())) {
249 consecutiveYTies++;
250 } else {
251 tiedYPairs += sum(consecutiveYTies - 1);
252 consecutiveYTies = 1;
253 }
254 prev = curr;
255 }
256 tiedYPairs += sum(consecutiveYTies - 1);
257
258 final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
259 final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
260 return concordantMinusDiscordant / JdkMath.sqrt(nonTiedPairsMultiplied);
261 }
262
263
264
265
266
267
268
269
270 private static long sum(long n) {
271 return n * (n + 1) / 2L;
272 }
273 }