001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.stat.correlation;
018
019import org.apache.commons.math3.exception.DimensionMismatchException;
020import org.apache.commons.math3.linear.BlockRealMatrix;
021import org.apache.commons.math3.linear.MatrixUtils;
022import org.apache.commons.math3.linear.RealMatrix;
023import org.apache.commons.math3.util.FastMath;
024import org.apache.commons.math3.util.Pair;
025
026import java.util.Arrays;
027import java.util.Comparator;
028
029/**
 * Implementation of Kendall's Tau-b rank correlation.
031 * <p>
032 * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and
033 * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if
034 * x<sub>1</sub> &lt; x<sub>2</sub> and y<sub>1</sub> &lt; y<sub>2</sub>
035 * or x<sub>2</sub> &lt; x<sub>1</sub> and y<sub>2</sub> &lt; y<sub>1</sub>.
036 * The pair is <i>discordant</i> if x<sub>1</sub> &lt; x<sub>2</sub> and
037 * y<sub>2</sub> &lt; y<sub>1</sub> or x<sub>2</sub> &lt; x<sub>1</sub> and
038 * y<sub>1</sub> &lt; y<sub>2</sub>.  If either x<sub>1</sub> = x<sub>2</sub>
039 * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor
040 * discordant.
041 * <p>
042 * Kendall's Tau-b is defined as:
043 * <pre>
044 * tau<sub>b</sub> = (n<sub>c</sub> - n<sub>d</sub>) / sqrt((n<sub>0</sub> - n<sub>1</sub>) * (n<sub>0</sub> - n<sub>2</sub>))
045 * </pre>
046 * <p>
047 * where:
048 * <ul>
049 *     <li>n<sub>0</sub> = n * (n - 1) / 2</li>
050 *     <li>n<sub>c</sub> = Number of concordant pairs</li>
051 *     <li>n<sub>d</sub> = Number of discordant pairs</li>
052 *     <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li>
053 *     <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li>
054 *     <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li>
055 *     <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li>
056 * </ul>
057 * <p>
058 * This implementation uses the O(n log n) algorithm described in
059 * William R. Knight's 1966 paper "A Computer Method for Calculating
060 * Kendall's Tau with Ungrouped Data" in the Journal of the American
061 * Statistical Association.
062 *
063 * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient">
064 * Kendall tau rank correlation coefficient (Wikipedia)</a>
065 * @see <a href="http://www.jstor.org/stable/2282833">A Computer
066 * Method for Calculating Kendall's Tau with Ungrouped Data</a>
067 *
068 * @since 3.3
069 */
070public class KendallsCorrelation {
071
072    /** correlation matrix */
073    private final RealMatrix correlationMatrix;
074
075    /**
076     * Create a KendallsCorrelation instance without data.
077     */
078    public KendallsCorrelation() {
079        correlationMatrix = null;
080    }
081
082    /**
083     * Create a KendallsCorrelation from a rectangular array
084     * whose columns represent values of variables to be correlated.
085     *
086     * @param data rectangular array with columns representing variables
087     * @throws IllegalArgumentException if the input data array is not
088     * rectangular with at least two rows and two columns.
089     */
090    public KendallsCorrelation(double[][] data) {
091        this(MatrixUtils.createRealMatrix(data));
092    }
093
094    /**
095     * Create a KendallsCorrelation from a RealMatrix whose columns
096     * represent variables to be correlated.
097     *
098     * @param matrix matrix with columns representing variables to correlate
099     */
100    public KendallsCorrelation(RealMatrix matrix) {
101        correlationMatrix = computeCorrelationMatrix(matrix);
102    }
103
104    /**
105     * Returns the correlation matrix.
106     *
107     * @return correlation matrix
108     */
109    public RealMatrix getCorrelationMatrix() {
110        return correlationMatrix;
111    }
112
113    /**
114     * Computes the Kendall's Tau rank correlation matrix for the columns of
115     * the input matrix.
116     *
117     * @param matrix matrix with columns representing variables to correlate
118     * @return correlation matrix
119     */
120    public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
121        int nVars = matrix.getColumnDimension();
122        RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
123        for (int i = 0; i < nVars; i++) {
124            for (int j = 0; j < i; j++) {
125                double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
126                outMatrix.setEntry(i, j, corr);
127                outMatrix.setEntry(j, i, corr);
128            }
129            outMatrix.setEntry(i, i, 1d);
130        }
131        return outMatrix;
132    }
133
134    /**
135     * Computes the Kendall's Tau rank correlation matrix for the columns of
136     * the input rectangular array.  The columns of the array represent values
137     * of variables to be correlated.
138     *
139     * @param matrix matrix with columns representing variables to correlate
140     * @return correlation matrix
141     */
142    public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
143       return computeCorrelationMatrix(new BlockRealMatrix(matrix));
144    }
145
146    /**
147     * Computes the Kendall's Tau rank correlation coefficient between the two arrays.
148     *
149     * @param xArray first data array
150     * @param yArray second data array
151     * @return Returns Kendall's Tau rank correlation coefficient for the two arrays
152     * @throws DimensionMismatchException if the arrays lengths do not match
153     */
154    public double correlation(final double[] xArray, final double[] yArray)
155            throws DimensionMismatchException {
156
157        if (xArray.length != yArray.length) {
158            throw new DimensionMismatchException(xArray.length, yArray.length);
159        }
160
161        final int n = xArray.length;
162        final long numPairs = sum(n - 1);
163
164        @SuppressWarnings("unchecked")
165        Pair<Double, Double>[] pairs = new Pair[n];
166        for (int i = 0; i < n; i++) {
167            pairs[i] = new Pair<Double, Double>(xArray[i], yArray[i]);
168        }
169
170        Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() {
171            /** {@inheritDoc} */
172            public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) {
173                int compareFirst = pair1.getFirst().compareTo(pair2.getFirst());
174                return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond());
175            }
176        });
177
178        long tiedXPairs = 0;
179        long tiedXYPairs = 0;
180        long consecutiveXTies = 1;
181        long consecutiveXYTies = 1;
182        Pair<Double, Double> prev = pairs[0];
183        for (int i = 1; i < n; i++) {
184            final Pair<Double, Double> curr = pairs[i];
185            if (curr.getFirst().equals(prev.getFirst())) {
186                consecutiveXTies++;
187                if (curr.getSecond().equals(prev.getSecond())) {
188                    consecutiveXYTies++;
189                } else {
190                    tiedXYPairs += sum(consecutiveXYTies - 1);
191                    consecutiveXYTies = 1;
192                }
193            } else {
194                tiedXPairs += sum(consecutiveXTies - 1);
195                consecutiveXTies = 1;
196                tiedXYPairs += sum(consecutiveXYTies - 1);
197                consecutiveXYTies = 1;
198            }
199            prev = curr;
200        }
201        tiedXPairs += sum(consecutiveXTies - 1);
202        tiedXYPairs += sum(consecutiveXYTies - 1);
203
204        long swaps = 0;
205        @SuppressWarnings("unchecked")
206        Pair<Double, Double>[] pairsDestination = new Pair[n];
207        for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
208            for (int offset = 0; offset < n; offset += 2 * segmentSize) {
209                int i = offset;
210                final int iEnd = FastMath.min(i + segmentSize, n);
211                int j = iEnd;
212                final int jEnd = FastMath.min(j + segmentSize, n);
213
214                int copyLocation = offset;
215                while (i < iEnd || j < jEnd) {
216                    if (i < iEnd) {
217                        if (j < jEnd) {
218                            if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) {
219                                pairsDestination[copyLocation] = pairs[i];
220                                i++;
221                            } else {
222                                pairsDestination[copyLocation] = pairs[j];
223                                j++;
224                                swaps += iEnd - i;
225                            }
226                        } else {
227                            pairsDestination[copyLocation] = pairs[i];
228                            i++;
229                        }
230                    } else {
231                        pairsDestination[copyLocation] = pairs[j];
232                        j++;
233                    }
234                    copyLocation++;
235                }
236            }
237            final Pair<Double, Double>[] pairsTemp = pairs;
238            pairs = pairsDestination;
239            pairsDestination = pairsTemp;
240        }
241
242        long tiedYPairs = 0;
243        long consecutiveYTies = 1;
244        prev = pairs[0];
245        for (int i = 1; i < n; i++) {
246            final Pair<Double, Double> curr = pairs[i];
247            if (curr.getSecond().equals(prev.getSecond())) {
248                consecutiveYTies++;
249            } else {
250                tiedYPairs += sum(consecutiveYTies - 1);
251                consecutiveYTies = 1;
252            }
253            prev = curr;
254        }
255        tiedYPairs += sum(consecutiveYTies - 1);
256
257        final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
258        final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
259        return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied);
260    }
261
262    /**
263     * Returns the sum of the number from 1 .. n according to Gauss' summation formula:
264     * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \]
265     *
266     * @param n the summation end
267     * @return the sum of the number from 1 to n
268     */
269    private static long sum(long n) {
270        return n * (n + 1) / 2l;
271    }
272}