KendallsCorrelation.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.stat.correlation;

  18. import java.util.Arrays;
  19. import java.util.Comparator;

  20. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  21. import org.apache.commons.math4.legacy.linear.BlockRealMatrix;
  22. import org.apache.commons.math4.legacy.linear.MatrixUtils;
  23. import org.apache.commons.math4.legacy.linear.RealMatrix;
  24. import org.apache.commons.math4.core.jdkmath.JdkMath;
  25. import org.apache.commons.math4.legacy.core.Pair;

  26. /**
  27.  * Implementation of Kendall's Tau-b rank correlation.
  28.  * <p>
  29.  * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and
  30.  * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if
  31.  * x<sub>1</sub> &lt; x<sub>2</sub> and y<sub>1</sub> &lt; y<sub>2</sub>
  32.  * or x<sub>2</sub> &lt; x<sub>1</sub> and y<sub>2</sub> &lt; y<sub>1</sub>.
  33.  * The pair is <i>discordant</i> if x<sub>1</sub> &lt; x<sub>2</sub> and
  34.  * y<sub>2</sub> &lt; y<sub>1</sub> or x<sub>2</sub> &lt; x<sub>1</sub> and
  35.  * y<sub>1</sub> &lt; y<sub>2</sub>.  If either x<sub>1</sub> = x<sub>2</sub>
  36.  * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor
  37.  * discordant.
  38.  * <p>
  39.  * Kendall's Tau-b is defined as:
  40.  * <div style="white-space: pre"><code>
  41.  * tau<sub>b</sub> = (n<sub>c</sub> - n<sub>d</sub>) / sqrt((n<sub>0</sub> - n<sub>1</sub>) * (n<sub>0</sub> - n<sub>2</sub>))
  42.  * </code></div>
  43.  * <p>
  44.  * where:
  45.  * <ul>
  46.  *     <li>n<sub>0</sub> = n * (n - 1) / 2</li>
  47.  *     <li>n<sub>c</sub> = Number of concordant pairs</li>
  48.  *     <li>n<sub>d</sub> = Number of discordant pairs</li>
  49.  *     <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li>
  50.  *     <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li>
  51.  *     <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li>
  52.  *     <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li>
  53.  * </ul>
  54.  * <p>
  55.  * This implementation uses the O(n log n) algorithm described in
  56.  * William R. Knight's 1966 paper "A Computer Method for Calculating
  57.  * Kendall's Tau with Ungrouped Data" in the Journal of the American
  58.  * Statistical Association.
  59.  *
  60.  * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient">
  61.  * Kendall tau rank correlation coefficient (Wikipedia)</a>
  62.  * @see <a href="http://www.jstor.org/stable/2282833">A Computer
  63.  * Method for Calculating Kendall's Tau with Ungrouped Data</a>
  64.  *
  65.  * @since 3.3
  66.  */
  67. public class KendallsCorrelation {

  68.     /** correlation matrix. */
  69.     private final RealMatrix correlationMatrix;

  70.     /**
  71.      * Create a KendallsCorrelation instance without data.
  72.      */
  73.     public KendallsCorrelation() {
  74.         correlationMatrix = null;
  75.     }

  76.     /**
  77.      * Create a KendallsCorrelation from a rectangular array
  78.      * whose columns represent values of variables to be correlated.
  79.      *
  80.      * @param data rectangular array with columns representing variables
  81.      * @throws IllegalArgumentException if the input data array is not
  82.      * rectangular with at least two rows and two columns.
  83.      */
  84.     public KendallsCorrelation(double[][] data) {
  85.         this(MatrixUtils.createRealMatrix(data));
  86.     }

  87.     /**
  88.      * Create a KendallsCorrelation from a RealMatrix whose columns
  89.      * represent variables to be correlated.
  90.      *
  91.      * @param matrix matrix with columns representing variables to correlate
  92.      */
  93.     public KendallsCorrelation(RealMatrix matrix) {
  94.         correlationMatrix = computeCorrelationMatrix(matrix);
  95.     }

  96.     /**
  97.      * Returns the correlation matrix.
  98.      *
  99.      * @return correlation matrix
  100.      */
  101.     public RealMatrix getCorrelationMatrix() {
  102.         return correlationMatrix;
  103.     }

  104.     /**
  105.      * Computes the Kendall's Tau rank correlation matrix for the columns of
  106.      * the input matrix.
  107.      *
  108.      * @param matrix matrix with columns representing variables to correlate
  109.      * @return correlation matrix
  110.      */
  111.     public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
  112.         int nVars = matrix.getColumnDimension();
  113.         RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
  114.         for (int i = 0; i < nVars; i++) {
  115.             for (int j = 0; j < i; j++) {
  116.                 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
  117.                 outMatrix.setEntry(i, j, corr);
  118.                 outMatrix.setEntry(j, i, corr);
  119.             }
  120.             outMatrix.setEntry(i, i, 1d);
  121.         }
  122.         return outMatrix;
  123.     }

  124.     /**
  125.      * Computes the Kendall's Tau rank correlation matrix for the columns of
  126.      * the input rectangular array.  The columns of the array represent values
  127.      * of variables to be correlated.
  128.      *
  129.      * @param matrix matrix with columns representing variables to correlate
  130.      * @return correlation matrix
  131.      */
  132.     public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
  133.        return computeCorrelationMatrix(new BlockRealMatrix(matrix));
  134.     }

  135.     /**
  136.      * Computes the Kendall's Tau rank correlation coefficient between the two arrays.
  137.      *
  138.      * @param xArray first data array
  139.      * @param yArray second data array
  140.      * @return Returns Kendall's Tau rank correlation coefficient for the two arrays
  141.      * @throws DimensionMismatchException if the arrays lengths do not match
  142.      */
  143.     public double correlation(final double[] xArray, final double[] yArray)
  144.             throws DimensionMismatchException {

  145.         if (xArray.length != yArray.length) {
  146.             throw new DimensionMismatchException(xArray.length, yArray.length);
  147.         }

  148.         final int n = xArray.length;
  149.         final long numPairs = sum(n - 1);

  150.         @SuppressWarnings("unchecked")
  151.         Pair<Double, Double>[] pairs = new Pair[n];
  152.         for (int i = 0; i < n; i++) {
  153.             pairs[i] = new Pair<>(xArray[i], yArray[i]);
  154.         }

  155.         Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() {
  156.             /** {@inheritDoc} */
  157.             @Override
  158.             public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) {
  159.                 int compareFirst = pair1.getFirst().compareTo(pair2.getFirst());
  160.                 return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond());
  161.             }
  162.         });

  163.         long tiedXPairs = 0;
  164.         long tiedXYPairs = 0;
  165.         long consecutiveXTies = 1;
  166.         long consecutiveXYTies = 1;
  167.         Pair<Double, Double> prev = pairs[0];
  168.         for (int i = 1; i < n; i++) {
  169.             final Pair<Double, Double> curr = pairs[i];
  170.             if (curr.getFirst().equals(prev.getFirst())) {
  171.                 consecutiveXTies++;
  172.                 if (curr.getSecond().equals(prev.getSecond())) {
  173.                     consecutiveXYTies++;
  174.                 } else {
  175.                     tiedXYPairs += sum(consecutiveXYTies - 1);
  176.                     consecutiveXYTies = 1;
  177.                 }
  178.             } else {
  179.                 tiedXPairs += sum(consecutiveXTies - 1);
  180.                 consecutiveXTies = 1;
  181.                 tiedXYPairs += sum(consecutiveXYTies - 1);
  182.                 consecutiveXYTies = 1;
  183.             }
  184.             prev = curr;
  185.         }
  186.         tiedXPairs += sum(consecutiveXTies - 1);
  187.         tiedXYPairs += sum(consecutiveXYTies - 1);

  188.         long swaps = 0;
  189.         @SuppressWarnings("unchecked")
  190.         Pair<Double, Double>[] pairsDestination = new Pair[n];
  191.         for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
  192.             for (int offset = 0; offset < n; offset += 2 * segmentSize) {
  193.                 int i = offset;
  194.                 final int iEnd = JdkMath.min(i + segmentSize, n);
  195.                 int j = iEnd;
  196.                 final int jEnd = JdkMath.min(j + segmentSize, n);

  197.                 int copyLocation = offset;
  198.                 while (i < iEnd || j < jEnd) {
  199.                     if (i < iEnd) {
  200.                         if (j < jEnd) {
  201.                             if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) {
  202.                                 pairsDestination[copyLocation] = pairs[i];
  203.                                 i++;
  204.                             } else {
  205.                                 pairsDestination[copyLocation] = pairs[j];
  206.                                 j++;
  207.                                 swaps += iEnd - i;
  208.                             }
  209.                         } else {
  210.                             pairsDestination[copyLocation] = pairs[i];
  211.                             i++;
  212.                         }
  213.                     } else {
  214.                         pairsDestination[copyLocation] = pairs[j];
  215.                         j++;
  216.                     }
  217.                     copyLocation++;
  218.                 }
  219.             }
  220.             final Pair<Double, Double>[] pairsTemp = pairs;
  221.             pairs = pairsDestination;
  222.             pairsDestination = pairsTemp;
  223.         }

  224.         long tiedYPairs = 0;
  225.         long consecutiveYTies = 1;
  226.         prev = pairs[0];
  227.         for (int i = 1; i < n; i++) {
  228.             final Pair<Double, Double> curr = pairs[i];
  229.             if (curr.getSecond().equals(prev.getSecond())) {
  230.                 consecutiveYTies++;
  231.             } else {
  232.                 tiedYPairs += sum(consecutiveYTies - 1);
  233.                 consecutiveYTies = 1;
  234.             }
  235.             prev = curr;
  236.         }
  237.         tiedYPairs += sum(consecutiveYTies - 1);

  238.         final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
  239.         final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
  240.         return concordantMinusDiscordant / JdkMath.sqrt(nonTiedPairsMultiplied);
  241.     }

  242.     /**
  243.      * Returns the sum of the number from 1 .. n according to Gauss' summation formula:
  244.      * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \]
  245.      *
  246.      * @param n the summation end
  247.      * @return the sum of the number from 1 to n
  248.      */
  249.     private static long sum(long n) {
  250.         return n * (n + 1) / 2L;
  251.     }
  252. }