1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.stat.correlation;
18
19 import java.util.Arrays;
20 import java.util.Comparator;
21
22 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23 import org.apache.commons.math4.legacy.linear.BlockRealMatrix;
24 import org.apache.commons.math4.legacy.linear.MatrixUtils;
25 import org.apache.commons.math4.legacy.linear.RealMatrix;
26 import org.apache.commons.math4.core.jdkmath.JdkMath;
27 import org.apache.commons.math4.legacy.core.Pair;
28
29 /**
30 * Implementation of Kendall's Tau-b rank correlation.
31 * <p>
32 * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and
33 * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if
34 * x<sub>1</sub> < x<sub>2</sub> and y<sub>1</sub> < y<sub>2</sub>
35 * or x<sub>2</sub> < x<sub>1</sub> and y<sub>2</sub> < y<sub>1</sub>.
36 * The pair is <i>discordant</i> if x<sub>1</sub> < x<sub>2</sub> and
37 * y<sub>2</sub> < y<sub>1</sub> or x<sub>2</sub> < x<sub>1</sub> and
38 * y<sub>1</sub> < y<sub>2</sub>. If either x<sub>1</sub> = x<sub>2</sub>
39 * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor
40 * discordant.
41 * <p>
42 * Kendall's Tau-b is defined as:
43 * <div style="white-space: pre"><code>
44 * tau<sub>b</sub> = (n<sub>c</sub> - n<sub>d</sub>) / sqrt((n<sub>0</sub> - n<sub>1</sub>) * (n<sub>0</sub> - n<sub>2</sub>))
45 * </code></div>
46 * <p>
47 * where:
48 * <ul>
49 * <li>n<sub>0</sub> = n * (n - 1) / 2</li>
50 * <li>n<sub>c</sub> = Number of concordant pairs</li>
51 * <li>n<sub>d</sub> = Number of discordant pairs</li>
52 * <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li>
53 * <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li>
54 * <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li>
55 * <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li>
56 * </ul>
57 * <p>
58 * This implementation uses the O(n log n) algorithm described in
59 * William R. Knight's 1966 paper "A Computer Method for Calculating
60 * Kendall's Tau with Ungrouped Data" in the Journal of the American
61 * Statistical Association.
62 *
63 * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient">
64 * Kendall tau rank correlation coefficient (Wikipedia)</a>
65 * @see <a href="http://www.jstor.org/stable/2282833">A Computer
66 * Method for Calculating Kendall's Tau with Ungrouped Data</a>
67 *
68 * @since 3.3
69 */
70 public class KendallsCorrelation {
71
72 /** correlation matrix. */
73 private final RealMatrix correlationMatrix;
74
75 /**
76 * Create a KendallsCorrelation instance without data.
77 */
78 public KendallsCorrelation() {
79 correlationMatrix = null;
80 }
81
82 /**
83 * Create a KendallsCorrelation from a rectangular array
84 * whose columns represent values of variables to be correlated.
85 *
86 * @param data rectangular array with columns representing variables
87 * @throws IllegalArgumentException if the input data array is not
88 * rectangular with at least two rows and two columns.
89 */
90 public KendallsCorrelation(double[][] data) {
91 this(MatrixUtils.createRealMatrix(data));
92 }
93
94 /**
95 * Create a KendallsCorrelation from a RealMatrix whose columns
96 * represent variables to be correlated.
97 *
98 * @param matrix matrix with columns representing variables to correlate
99 */
100 public KendallsCorrelation(RealMatrix matrix) {
101 correlationMatrix = computeCorrelationMatrix(matrix);
102 }
103
104 /**
105 * Returns the correlation matrix.
106 *
107 * @return correlation matrix
108 */
109 public RealMatrix getCorrelationMatrix() {
110 return correlationMatrix;
111 }
112
113 /**
114 * Computes the Kendall's Tau rank correlation matrix for the columns of
115 * the input matrix.
116 *
117 * @param matrix matrix with columns representing variables to correlate
118 * @return correlation matrix
119 */
120 public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
121 int nVars = matrix.getColumnDimension();
122 RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
123 for (int i = 0; i < nVars; i++) {
124 for (int j = 0; j < i; j++) {
125 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
126 outMatrix.setEntry(i, j, corr);
127 outMatrix.setEntry(j, i, corr);
128 }
129 outMatrix.setEntry(i, i, 1d);
130 }
131 return outMatrix;
132 }
133
134 /**
135 * Computes the Kendall's Tau rank correlation matrix for the columns of
136 * the input rectangular array. The columns of the array represent values
137 * of variables to be correlated.
138 *
139 * @param matrix matrix with columns representing variables to correlate
140 * @return correlation matrix
141 */
142 public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
143 return computeCorrelationMatrix(new BlockRealMatrix(matrix));
144 }
145
146 /**
147 * Computes the Kendall's Tau rank correlation coefficient between the two arrays.
148 *
149 * @param xArray first data array
150 * @param yArray second data array
151 * @return Returns Kendall's Tau rank correlation coefficient for the two arrays
152 * @throws DimensionMismatchException if the arrays lengths do not match
153 */
154 public double correlation(final double[] xArray, final double[] yArray)
155 throws DimensionMismatchException {
156
157 if (xArray.length != yArray.length) {
158 throw new DimensionMismatchException(xArray.length, yArray.length);
159 }
160
161 final int n = xArray.length;
162 final long numPairs = sum(n - 1);
163
164 @SuppressWarnings("unchecked")
165 Pair<Double, Double>[] pairs = new Pair[n];
166 for (int i = 0; i < n; i++) {
167 pairs[i] = new Pair<>(xArray[i], yArray[i]);
168 }
169
170 Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() {
171 /** {@inheritDoc} */
172 @Override
173 public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) {
174 int compareFirst = pair1.getFirst().compareTo(pair2.getFirst());
175 return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond());
176 }
177 });
178
179 long tiedXPairs = 0;
180 long tiedXYPairs = 0;
181 long consecutiveXTies = 1;
182 long consecutiveXYTies = 1;
183 Pair<Double, Double> prev = pairs[0];
184 for (int i = 1; i < n; i++) {
185 final Pair<Double, Double> curr = pairs[i];
186 if (curr.getFirst().equals(prev.getFirst())) {
187 consecutiveXTies++;
188 if (curr.getSecond().equals(prev.getSecond())) {
189 consecutiveXYTies++;
190 } else {
191 tiedXYPairs += sum(consecutiveXYTies - 1);
192 consecutiveXYTies = 1;
193 }
194 } else {
195 tiedXPairs += sum(consecutiveXTies - 1);
196 consecutiveXTies = 1;
197 tiedXYPairs += sum(consecutiveXYTies - 1);
198 consecutiveXYTies = 1;
199 }
200 prev = curr;
201 }
202 tiedXPairs += sum(consecutiveXTies - 1);
203 tiedXYPairs += sum(consecutiveXYTies - 1);
204
205 long swaps = 0;
206 @SuppressWarnings("unchecked")
207 Pair<Double, Double>[] pairsDestination = new Pair[n];
208 for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
209 for (int offset = 0; offset < n; offset += 2 * segmentSize) {
210 int i = offset;
211 final int iEnd = JdkMath.min(i + segmentSize, n);
212 int j = iEnd;
213 final int jEnd = JdkMath.min(j + segmentSize, n);
214
215 int copyLocation = offset;
216 while (i < iEnd || j < jEnd) {
217 if (i < iEnd) {
218 if (j < jEnd) {
219 if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) {
220 pairsDestination[copyLocation] = pairs[i];
221 i++;
222 } else {
223 pairsDestination[copyLocation] = pairs[j];
224 j++;
225 swaps += iEnd - i;
226 }
227 } else {
228 pairsDestination[copyLocation] = pairs[i];
229 i++;
230 }
231 } else {
232 pairsDestination[copyLocation] = pairs[j];
233 j++;
234 }
235 copyLocation++;
236 }
237 }
238 final Pair<Double, Double>[] pairsTemp = pairs;
239 pairs = pairsDestination;
240 pairsDestination = pairsTemp;
241 }
242
243 long tiedYPairs = 0;
244 long consecutiveYTies = 1;
245 prev = pairs[0];
246 for (int i = 1; i < n; i++) {
247 final Pair<Double, Double> curr = pairs[i];
248 if (curr.getSecond().equals(prev.getSecond())) {
249 consecutiveYTies++;
250 } else {
251 tiedYPairs += sum(consecutiveYTies - 1);
252 consecutiveYTies = 1;
253 }
254 prev = curr;
255 }
256 tiedYPairs += sum(consecutiveYTies - 1);
257
258 final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
259 final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
260 return concordantMinusDiscordant / JdkMath.sqrt(nonTiedPairsMultiplied);
261 }
262
263 /**
264 * Returns the sum of the number from 1 .. n according to Gauss' summation formula:
265 * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \]
266 *
267 * @param n the summation end
268 * @return the sum of the number from 1 to n
269 */
270 private static long sum(long n) {
271 return n * (n + 1) / 2L;
272 }
273 }