View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.stat.correlation;
18  
19  import java.util.Arrays;
20  import java.util.Comparator;
21  
22  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23  import org.apache.commons.math4.legacy.linear.BlockRealMatrix;
24  import org.apache.commons.math4.legacy.linear.MatrixUtils;
25  import org.apache.commons.math4.legacy.linear.RealMatrix;
26  import org.apache.commons.math4.core.jdkmath.JdkMath;
27  import org.apache.commons.math4.legacy.core.Pair;
28  
29  /**
30   * Implementation of Kendall's Tau-b rank correlation.
31   * <p>
32   * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and
33   * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if
34   * x<sub>1</sub> &lt; x<sub>2</sub> and y<sub>1</sub> &lt; y<sub>2</sub>
35   * or x<sub>2</sub> &lt; x<sub>1</sub> and y<sub>2</sub> &lt; y<sub>1</sub>.
36   * The pair is <i>discordant</i> if x<sub>1</sub> &lt; x<sub>2</sub> and
37   * y<sub>2</sub> &lt; y<sub>1</sub> or x<sub>2</sub> &lt; x<sub>1</sub> and
38   * y<sub>1</sub> &lt; y<sub>2</sub>.  If either x<sub>1</sub> = x<sub>2</sub>
39   * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor
40   * discordant.
41   * <p>
42   * Kendall's Tau-b is defined as:
43   * <div style="white-space: pre"><code>
44   * tau<sub>b</sub> = (n<sub>c</sub> - n<sub>d</sub>) / sqrt((n<sub>0</sub> - n<sub>1</sub>) * (n<sub>0</sub> - n<sub>2</sub>))
45   * </code></div>
46   * <p>
47   * where:
48   * <ul>
49   *     <li>n<sub>0</sub> = n * (n - 1) / 2</li>
50   *     <li>n<sub>c</sub> = Number of concordant pairs</li>
51   *     <li>n<sub>d</sub> = Number of discordant pairs</li>
52   *     <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li>
53   *     <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li>
54   *     <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li>
55   *     <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li>
56   * </ul>
57   * <p>
58   * This implementation uses the O(n log n) algorithm described in
59   * William R. Knight's 1966 paper "A Computer Method for Calculating
60   * Kendall's Tau with Ungrouped Data" in the Journal of the American
61   * Statistical Association.
62   *
63   * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient">
64   * Kendall tau rank correlation coefficient (Wikipedia)</a>
65   * @see <a href="http://www.jstor.org/stable/2282833">A Computer
66   * Method for Calculating Kendall's Tau with Ungrouped Data</a>
67   *
68   * @since 3.3
69   */
70  public class KendallsCorrelation {
71  
72      /** correlation matrix. */
73      private final RealMatrix correlationMatrix;
74  
75      /**
76       * Create a KendallsCorrelation instance without data.
77       */
78      public KendallsCorrelation() {
79          correlationMatrix = null;
80      }
81  
82      /**
83       * Create a KendallsCorrelation from a rectangular array
84       * whose columns represent values of variables to be correlated.
85       *
86       * @param data rectangular array with columns representing variables
87       * @throws IllegalArgumentException if the input data array is not
88       * rectangular with at least two rows and two columns.
89       */
90      public KendallsCorrelation(double[][] data) {
91          this(MatrixUtils.createRealMatrix(data));
92      }
93  
94      /**
95       * Create a KendallsCorrelation from a RealMatrix whose columns
96       * represent variables to be correlated.
97       *
98       * @param matrix matrix with columns representing variables to correlate
99       */
100     public KendallsCorrelation(RealMatrix matrix) {
101         correlationMatrix = computeCorrelationMatrix(matrix);
102     }
103 
104     /**
105      * Returns the correlation matrix.
106      *
107      * @return correlation matrix
108      */
109     public RealMatrix getCorrelationMatrix() {
110         return correlationMatrix;
111     }
112 
113     /**
114      * Computes the Kendall's Tau rank correlation matrix for the columns of
115      * the input matrix.
116      *
117      * @param matrix matrix with columns representing variables to correlate
118      * @return correlation matrix
119      */
120     public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
121         int nVars = matrix.getColumnDimension();
122         RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
123         for (int i = 0; i < nVars; i++) {
124             for (int j = 0; j < i; j++) {
125                 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
126                 outMatrix.setEntry(i, j, corr);
127                 outMatrix.setEntry(j, i, corr);
128             }
129             outMatrix.setEntry(i, i, 1d);
130         }
131         return outMatrix;
132     }
133 
134     /**
135      * Computes the Kendall's Tau rank correlation matrix for the columns of
136      * the input rectangular array.  The columns of the array represent values
137      * of variables to be correlated.
138      *
139      * @param matrix matrix with columns representing variables to correlate
140      * @return correlation matrix
141      */
142     public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
143        return computeCorrelationMatrix(new BlockRealMatrix(matrix));
144     }
145 
146     /**
147      * Computes the Kendall's Tau rank correlation coefficient between the two arrays.
148      *
149      * @param xArray first data array
150      * @param yArray second data array
151      * @return Returns Kendall's Tau rank correlation coefficient for the two arrays
152      * @throws DimensionMismatchException if the arrays lengths do not match
153      */
154     public double correlation(final double[] xArray, final double[] yArray)
155             throws DimensionMismatchException {
156 
157         if (xArray.length != yArray.length) {
158             throw new DimensionMismatchException(xArray.length, yArray.length);
159         }
160 
161         final int n = xArray.length;
162         final long numPairs = sum(n - 1);
163 
164         @SuppressWarnings("unchecked")
165         Pair<Double, Double>[] pairs = new Pair[n];
166         for (int i = 0; i < n; i++) {
167             pairs[i] = new Pair<>(xArray[i], yArray[i]);
168         }
169 
170         Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() {
171             /** {@inheritDoc} */
172             @Override
173             public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) {
174                 int compareFirst = pair1.getFirst().compareTo(pair2.getFirst());
175                 return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond());
176             }
177         });
178 
179         long tiedXPairs = 0;
180         long tiedXYPairs = 0;
181         long consecutiveXTies = 1;
182         long consecutiveXYTies = 1;
183         Pair<Double, Double> prev = pairs[0];
184         for (int i = 1; i < n; i++) {
185             final Pair<Double, Double> curr = pairs[i];
186             if (curr.getFirst().equals(prev.getFirst())) {
187                 consecutiveXTies++;
188                 if (curr.getSecond().equals(prev.getSecond())) {
189                     consecutiveXYTies++;
190                 } else {
191                     tiedXYPairs += sum(consecutiveXYTies - 1);
192                     consecutiveXYTies = 1;
193                 }
194             } else {
195                 tiedXPairs += sum(consecutiveXTies - 1);
196                 consecutiveXTies = 1;
197                 tiedXYPairs += sum(consecutiveXYTies - 1);
198                 consecutiveXYTies = 1;
199             }
200             prev = curr;
201         }
202         tiedXPairs += sum(consecutiveXTies - 1);
203         tiedXYPairs += sum(consecutiveXYTies - 1);
204 
205         long swaps = 0;
206         @SuppressWarnings("unchecked")
207         Pair<Double, Double>[] pairsDestination = new Pair[n];
208         for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
209             for (int offset = 0; offset < n; offset += 2 * segmentSize) {
210                 int i = offset;
211                 final int iEnd = JdkMath.min(i + segmentSize, n);
212                 int j = iEnd;
213                 final int jEnd = JdkMath.min(j + segmentSize, n);
214 
215                 int copyLocation = offset;
216                 while (i < iEnd || j < jEnd) {
217                     if (i < iEnd) {
218                         if (j < jEnd) {
219                             if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) {
220                                 pairsDestination[copyLocation] = pairs[i];
221                                 i++;
222                             } else {
223                                 pairsDestination[copyLocation] = pairs[j];
224                                 j++;
225                                 swaps += iEnd - i;
226                             }
227                         } else {
228                             pairsDestination[copyLocation] = pairs[i];
229                             i++;
230                         }
231                     } else {
232                         pairsDestination[copyLocation] = pairs[j];
233                         j++;
234                     }
235                     copyLocation++;
236                 }
237             }
238             final Pair<Double, Double>[] pairsTemp = pairs;
239             pairs = pairsDestination;
240             pairsDestination = pairsTemp;
241         }
242 
243         long tiedYPairs = 0;
244         long consecutiveYTies = 1;
245         prev = pairs[0];
246         for (int i = 1; i < n; i++) {
247             final Pair<Double, Double> curr = pairs[i];
248             if (curr.getSecond().equals(prev.getSecond())) {
249                 consecutiveYTies++;
250             } else {
251                 tiedYPairs += sum(consecutiveYTies - 1);
252                 consecutiveYTies = 1;
253             }
254             prev = curr;
255         }
256         tiedYPairs += sum(consecutiveYTies - 1);
257 
258         final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
259         final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
260         return concordantMinusDiscordant / JdkMath.sqrt(nonTiedPairsMultiplied);
261     }
262 
263     /**
264      * Returns the sum of the number from 1 .. n according to Gauss' summation formula:
265      * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \]
266      *
267      * @param n the summation end
268      * @return the sum of the number from 1 to n
269      */
270     private static long sum(long n) {
271         return n * (n + 1) / 2L;
272     }
273 }