001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.stat.correlation;
019
020import java.util.ArrayList;
021import java.util.HashSet;
022import java.util.List;
023import java.util.Set;
024
025import org.apache.commons.math3.exception.DimensionMismatchException;
026import org.apache.commons.math3.exception.MathIllegalArgumentException;
027import org.apache.commons.math3.exception.util.LocalizedFormats;
028import org.apache.commons.math3.linear.BlockRealMatrix;
029import org.apache.commons.math3.linear.RealMatrix;
030import org.apache.commons.math3.stat.ranking.NaNStrategy;
031import org.apache.commons.math3.stat.ranking.NaturalRanking;
032import org.apache.commons.math3.stat.ranking.RankingAlgorithm;
033
034/**
035 * Spearman's rank correlation. This implementation performs a rank
036 * transformation on the input data and then computes {@link PearsonsCorrelation}
037 * on the ranked data.
038 * <p>
039 * By default, ranks are computed using {@link NaturalRanking} with default
040 * strategies for handling NaNs and ties in the data (NaNs maximal, ties averaged).
041 * The ranking algorithm can be set using a constructor argument.
042 *
043 * @since 2.0
044 * @version $Id: SpearmansCorrelation.java 1461822 2013-03-27 19:44:22Z tn $
045 */
046public class SpearmansCorrelation {
047
048    /** Input data */
049    private final RealMatrix data;
050
051    /** Ranking algorithm  */
052    private final RankingAlgorithm rankingAlgorithm;
053
054    /** Rank correlation */
055    private final PearsonsCorrelation rankCorrelation;
056
057    /**
058     * Create a SpearmansCorrelation without data.
059     */
060    public SpearmansCorrelation() {
061        this(new NaturalRanking());
062    }
063
064    /**
065     * Create a SpearmansCorrelation with the given ranking algorithm.
066     * <p>
067     * From version 4.0 onwards this constructor will throw an exception
068     * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy.
069     *
070     * @param rankingAlgorithm ranking algorithm
071     * @since 3.1
072     */
073    public SpearmansCorrelation(final RankingAlgorithm rankingAlgorithm) {
074        data = null;
075        this.rankingAlgorithm = rankingAlgorithm;
076        rankCorrelation = null;
077    }
078
079    /**
080     * Create a SpearmansCorrelation from the given data matrix.
081     *
082     * @param dataMatrix matrix of data with columns representing
083     * variables to correlate
084     */
085    public SpearmansCorrelation(final RealMatrix dataMatrix) {
086        this(dataMatrix, new NaturalRanking());
087    }
088
089    /**
090     * Create a SpearmansCorrelation with the given input data matrix
091     * and ranking algorithm.
092     * <p>
093     * From version 4.0 onwards this constructor will throw an exception
094     * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy.
095     *
096     * @param dataMatrix matrix of data with columns representing
097     * variables to correlate
098     * @param rankingAlgorithm ranking algorithm
099     */
100    public SpearmansCorrelation(final RealMatrix dataMatrix, final RankingAlgorithm rankingAlgorithm) {
101        this.rankingAlgorithm = rankingAlgorithm;
102        this.data = rankTransform(dataMatrix);
103        rankCorrelation = new PearsonsCorrelation(data);
104    }
105
106    /**
107     * Calculate the Spearman Rank Correlation Matrix.
108     *
109     * @return Spearman Rank Correlation Matrix
110     */
111    public RealMatrix getCorrelationMatrix() {
112        return rankCorrelation.getCorrelationMatrix();
113    }
114
115    /**
116     * Returns a {@link PearsonsCorrelation} instance constructed from the
117     * ranked input data. That is,
118     * <code>new SpearmansCorrelation(matrix).getRankCorrelation()</code>
119     * is equivalent to
120     * <code>new PearsonsCorrelation(rankTransform(matrix))</code> where
121     * <code>rankTransform(matrix)</code> is the result of applying the
122     * configured <code>RankingAlgorithm</code> to each of the columns of
123     * <code>matrix.</code>
124     *
125     * @return PearsonsCorrelation among ranked column data
126     */
127    public PearsonsCorrelation getRankCorrelation() {
128        return rankCorrelation;
129    }
130
131    /**
132     * Computes the Spearman's rank correlation matrix for the columns of the
133     * input matrix.
134     *
135     * @param matrix matrix with columns representing variables to correlate
136     * @return correlation matrix
137     */
138    public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
139        final RealMatrix matrixCopy = rankTransform(matrix);
140        return new PearsonsCorrelation().computeCorrelationMatrix(matrixCopy);
141    }
142
143    /**
144     * Computes the Spearman's rank correlation matrix for the columns of the
145     * input rectangular array.  The columns of the array represent values
146     * of variables to be correlated.
147     *
148     * @param matrix matrix with columns representing variables to correlate
149     * @return correlation matrix
150     */
151    public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
152       return computeCorrelationMatrix(new BlockRealMatrix(matrix));
153    }
154
155    /**
156     * Computes the Spearman's rank correlation coefficient between the two arrays.
157     *
158     * @param xArray first data array
159     * @param yArray second data array
160     * @return Returns Spearman's rank correlation coefficient for the two arrays
161     * @throws DimensionMismatchException if the arrays lengths do not match
162     * @throws MathIllegalArgumentException if the array length is less than 2
163     */
164    public double correlation(final double[] xArray, final double[] yArray) {
165        if (xArray.length != yArray.length) {
166            throw new DimensionMismatchException(xArray.length, yArray.length);
167        } else if (xArray.length < 2) {
168            throw new MathIllegalArgumentException(LocalizedFormats.INSUFFICIENT_DIMENSION,
169                                                   xArray.length, 2);
170        } else {
171            double[] x = xArray;
172            double[] y = yArray;
173            if (rankingAlgorithm instanceof NaturalRanking &&
174                NaNStrategy.REMOVED == ((NaturalRanking) rankingAlgorithm).getNanStrategy()) {
175                final Set<Integer> nanPositions = new HashSet<Integer>();
176
177                nanPositions.addAll(getNaNPositions(xArray));
178                nanPositions.addAll(getNaNPositions(yArray));
179
180                x = removeValues(xArray, nanPositions);
181                y = removeValues(yArray, nanPositions);
182            }
183            return new PearsonsCorrelation().correlation(rankingAlgorithm.rank(x), rankingAlgorithm.rank(y));
184        }
185    }
186
187    /**
188     * Applies rank transform to each of the columns of <code>matrix</code>
189     * using the current <code>rankingAlgorithm</code>.
190     *
191     * @param matrix matrix to transform
192     * @return a rank-transformed matrix
193     */
194    private RealMatrix rankTransform(final RealMatrix matrix) {
195        RealMatrix transformed = null;
196
197        if (rankingAlgorithm instanceof NaturalRanking &&
198                ((NaturalRanking) rankingAlgorithm).getNanStrategy() == NaNStrategy.REMOVED) {
199            final Set<Integer> nanPositions = new HashSet<Integer>();
200            for (int i = 0; i < matrix.getColumnDimension(); i++) {
201                nanPositions.addAll(getNaNPositions(matrix.getColumn(i)));
202            }
203
204            // if we have found NaN values, we have to update the matrix size
205            if (!nanPositions.isEmpty()) {
206                transformed = new BlockRealMatrix(matrix.getRowDimension() - nanPositions.size(),
207                                                  matrix.getColumnDimension());
208                for (int i = 0; i < transformed.getColumnDimension(); i++) {
209                    transformed.setColumn(i, removeValues(matrix.getColumn(i), nanPositions));
210                }
211            }
212        }
213
214        if (transformed == null) {
215            transformed = matrix.copy();
216        }
217
218        for (int i = 0; i < transformed.getColumnDimension(); i++) {
219            transformed.setColumn(i, rankingAlgorithm.rank(transformed.getColumn(i)));
220        }
221
222        return transformed;
223    }
224
225    /**
226     * Returns a list containing the indices of NaN values in the input array.
227     *
228     * @param input the input array
229     * @return a list of NaN positions in the input array
230     */
231    private List<Integer> getNaNPositions(final double[] input) {
232        final List<Integer> positions = new ArrayList<Integer>();
233        for (int i = 0; i < input.length; i++) {
234            if (Double.isNaN(input[i])) {
235                positions.add(i);
236            }
237        }
238        return positions;
239    }
240
241    /**
242     * Removes all values from the input array at the specified indices.
243     *
244     * @param input the input array
245     * @param indices a set containing the indices to be removed
246     * @return the input array without the values at the specified indices
247     */
248    private double[] removeValues(final double[] input, final Set<Integer> indices) {
249        if (indices.isEmpty()) {
250            return input;
251        }
252        final double[] result = new double[input.length - indices.size()];
253        for (int i = 0, j = 0; i < input.length; i++) {
254            if (!indices.contains(i)) {
255                result[j++] = input[i];
256            }
257        }
258        return result;
259    }
260}