001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.stat.correlation; 018 019import java.util.Arrays; 020import java.util.Comparator; 021 022import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 023import org.apache.commons.math4.legacy.linear.BlockRealMatrix; 024import org.apache.commons.math4.legacy.linear.MatrixUtils; 025import org.apache.commons.math4.legacy.linear.RealMatrix; 026import org.apache.commons.math4.core.jdkmath.JdkMath; 027import org.apache.commons.math4.legacy.core.Pair; 028 029/** 030 * Implementation of Kendall's Tau-b rank correlation. 031 * <p> 032 * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and 033 * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if 034 * x<sub>1</sub> < x<sub>2</sub> and y<sub>1</sub> < y<sub>2</sub> 035 * or x<sub>2</sub> < x<sub>1</sub> and y<sub>2</sub> < y<sub>1</sub>. 036 * The pair is <i>discordant</i> if x<sub>1</sub> < x<sub>2</sub> and 037 * y<sub>2</sub> < y<sub>1</sub> or x<sub>2</sub> < x<sub>1</sub> and 038 * y<sub>1</sub> < y<sub>2</sub>. If either x<sub>1</sub> = x<sub>2</sub> 039 * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor 040 * discordant. 041 * <p> 042 * Kendall's Tau-b is defined as: 043 * <div style="white-space: pre"><code> 044 * tau<sub>b</sub> = (n<sub>c</sub> - n<sub>d</sub>) / sqrt((n<sub>0</sub> - n<sub>1</sub>) * (n<sub>0</sub> - n<sub>2</sub>)) 045 * </code></div> 046 * <p> 047 * where: 048 * <ul> 049 * <li>n<sub>0</sub> = n * (n - 1) / 2</li> 050 * <li>n<sub>c</sub> = Number of concordant pairs</li> 051 * <li>n<sub>d</sub> = Number of discordant pairs</li> 052 * <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li> 053 * <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li> 054 * <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li> 055 * <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li> 056 * </ul> 057 * <p> 058 * This implementation uses the O(n log n) algorithm described in 059 * William R. Knight's 1966 paper "A Computer Method for Calculating 060 * Kendall's Tau with Ungrouped Data" in the Journal of the American 061 * Statistical Association. 062 * 063 * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient"> 064 * Kendall tau rank correlation coefficient (Wikipedia)</a> 065 * @see <a href="http://www.jstor.org/stable/2282833">A Computer 066 * Method for Calculating Kendall's Tau with Ungrouped Data</a> 067 * 068 * @since 3.3 069 */ 070public class KendallsCorrelation { 071 072 /** correlation matrix. */ 073 private final RealMatrix correlationMatrix; 074 075 /** 076 * Create a KendallsCorrelation instance without data. 077 */ 078 public KendallsCorrelation() { 079 correlationMatrix = null; 080 } 081 082 /** 083 * Create a KendallsCorrelation from a rectangular array 084 * whose columns represent values of variables to be correlated. 085 * 086 * @param data rectangular array with columns representing variables 087 * @throws IllegalArgumentException if the input data array is not 088 * rectangular with at least two rows and two columns. 089 */ 090 public KendallsCorrelation(double[][] data) { 091 this(MatrixUtils.createRealMatrix(data)); 092 } 093 094 /** 095 * Create a KendallsCorrelation from a RealMatrix whose columns 096 * represent variables to be correlated. 097 * 098 * @param matrix matrix with columns representing variables to correlate 099 */ 100 public KendallsCorrelation(RealMatrix matrix) { 101 correlationMatrix = computeCorrelationMatrix(matrix); 102 } 103 104 /** 105 * Returns the correlation matrix. 106 * 107 * @return correlation matrix 108 */ 109 public RealMatrix getCorrelationMatrix() { 110 return correlationMatrix; 111 } 112 113 /** 114 * Computes the Kendall's Tau rank correlation matrix for the columns of 115 * the input matrix. 116 * 117 * @param matrix matrix with columns representing variables to correlate 118 * @return correlation matrix 119 */ 120 public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) { 121 int nVars = matrix.getColumnDimension(); 122 RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars); 123 for (int i = 0; i < nVars; i++) { 124 for (int j = 0; j < i; j++) { 125 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j)); 126 outMatrix.setEntry(i, j, corr); 127 outMatrix.setEntry(j, i, corr); 128 } 129 outMatrix.setEntry(i, i, 1d); 130 } 131 return outMatrix; 132 } 133 134 /** 135 * Computes the Kendall's Tau rank correlation matrix for the columns of 136 * the input rectangular array. The columns of the array represent values 137 * of variables to be correlated. 138 * 139 * @param matrix matrix with columns representing variables to correlate 140 * @return correlation matrix 141 */ 142 public RealMatrix computeCorrelationMatrix(final double[][] matrix) { 143 return computeCorrelationMatrix(new BlockRealMatrix(matrix)); 144 } 145 146 /** 147 * Computes the Kendall's Tau rank correlation coefficient between the two arrays. 148 * 149 * @param xArray first data array 150 * @param yArray second data array 151 * @return Returns Kendall's Tau rank correlation coefficient for the two arrays 152 * @throws DimensionMismatchException if the arrays lengths do not match 153 */ 154 public double correlation(final double[] xArray, final double[] yArray) 155 throws DimensionMismatchException { 156 157 if (xArray.length != yArray.length) { 158 throw new DimensionMismatchException(xArray.length, yArray.length); 159 } 160 161 final int n = xArray.length; 162 final long numPairs = sum(n - 1); 163 164 @SuppressWarnings("unchecked") 165 Pair<Double, Double>[] pairs = new Pair[n]; 166 for (int i = 0; i < n; i++) { 167 pairs[i] = new Pair<>(xArray[i], yArray[i]); 168 } 169 170 Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() { 171 /** {@inheritDoc} */ 172 @Override 173 public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) { 174 int compareFirst = pair1.getFirst().compareTo(pair2.getFirst()); 175 return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond()); 176 } 177 }); 178 179 long tiedXPairs = 0; 180 long tiedXYPairs = 0; 181 long consecutiveXTies = 1; 182 long consecutiveXYTies = 1; 183 Pair<Double, Double> prev = pairs[0]; 184 for (int i = 1; i < n; i++) { 185 final Pair<Double, Double> curr = pairs[i]; 186 if (curr.getFirst().equals(prev.getFirst())) { 187 consecutiveXTies++; 188 if (curr.getSecond().equals(prev.getSecond())) { 189 consecutiveXYTies++; 190 } else { 191 tiedXYPairs += sum(consecutiveXYTies - 1); 192 consecutiveXYTies = 1; 193 } 194 } else { 195 tiedXPairs += sum(consecutiveXTies - 1); 196 consecutiveXTies = 1; 197 tiedXYPairs += sum(consecutiveXYTies - 1); 198 consecutiveXYTies = 1; 199 } 200 prev = curr; 201 } 202 tiedXPairs += sum(consecutiveXTies - 1); 203 tiedXYPairs += sum(consecutiveXYTies - 1); 204 205 long swaps = 0; 206 @SuppressWarnings("unchecked") 207 Pair<Double, Double>[] pairsDestination = new Pair[n]; 208 for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) { 209 for (int offset = 0; offset < n; offset += 2 * segmentSize) { 210 int i = offset; 211 final int iEnd = JdkMath.min(i + segmentSize, n); 212 int j = iEnd; 213 final int jEnd = JdkMath.min(j + segmentSize, n); 214 215 int copyLocation = offset; 216 while (i < iEnd || j < jEnd) { 217 if (i < iEnd) { 218 if (j < jEnd) { 219 if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) { 220 pairsDestination[copyLocation] = pairs[i]; 221 i++; 222 } else { 223 pairsDestination[copyLocation] = pairs[j]; 224 j++; 225 swaps += iEnd - i; 226 } 227 } else { 228 pairsDestination[copyLocation] = pairs[i]; 229 i++; 230 } 231 } else { 232 pairsDestination[copyLocation] = pairs[j]; 233 j++; 234 } 235 copyLocation++; 236 } 237 } 238 final Pair<Double, Double>[] pairsTemp = pairs; 239 pairs = pairsDestination; 240 pairsDestination = pairsTemp; 241 } 242 243 long tiedYPairs = 0; 244 long consecutiveYTies = 1; 245 prev = pairs[0]; 246 for (int i = 1; i < n; i++) { 247 final Pair<Double, Double> curr = pairs[i]; 248 if (curr.getSecond().equals(prev.getSecond())) { 249 consecutiveYTies++; 250 } else { 251 tiedYPairs += sum(consecutiveYTies - 1); 252 consecutiveYTies = 1; 253 } 254 prev = curr; 255 } 256 tiedYPairs += sum(consecutiveYTies - 1); 257 258 final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps; 259 final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs); 260 return concordantMinusDiscordant / JdkMath.sqrt(nonTiedPairsMultiplied); 261 } 262 263 /** 264 * Returns the sum of the number from 1 .. n according to Gauss' summation formula: 265 * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \] 266 * 267 * @param n the summation end 268 * @return the sum of the number from 1 to n 269 */ 270 private static long sum(long n) { 271 return n * (n + 1) / 2L; 272 } 273}