001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.stat.correlation;
019
020import java.util.ArrayList;
021import java.util.HashSet;
022import java.util.List;
023import java.util.Set;
024
025import org.apache.commons.math3.exception.DimensionMismatchException;
026import org.apache.commons.math3.exception.MathIllegalArgumentException;
027import org.apache.commons.math3.exception.util.LocalizedFormats;
028import org.apache.commons.math3.linear.BlockRealMatrix;
029import org.apache.commons.math3.linear.RealMatrix;
030import org.apache.commons.math3.stat.ranking.NaNStrategy;
031import org.apache.commons.math3.stat.ranking.NaturalRanking;
032import org.apache.commons.math3.stat.ranking.RankingAlgorithm;
033
034/**
035 * Spearman's rank correlation. This implementation performs a rank
036 * transformation on the input data and then computes {@link PearsonsCorrelation}
037 * on the ranked data.
038 * <p>
039 * By default, ranks are computed using {@link NaturalRanking} with default
040 * strategies for handling NaNs and ties in the data (NaNs maximal, ties averaged).
041 * The ranking algorithm can be set using a constructor argument.
042 *
043 * @since 2.0
044 */
045public class SpearmansCorrelation {
046
047    /** Input data */
048    private final RealMatrix data;
049
050    /** Ranking algorithm  */
051    private final RankingAlgorithm rankingAlgorithm;
052
053    /** Rank correlation */
054    private final PearsonsCorrelation rankCorrelation;
055
056    /**
057     * Create a SpearmansCorrelation without data.
058     */
059    public SpearmansCorrelation() {
060        this(new NaturalRanking());
061    }
062
063    /**
064     * Create a SpearmansCorrelation with the given ranking algorithm.
065     * <p>
066     * From version 4.0 onwards this constructor will throw an exception
067     * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy.
068     *
069     * @param rankingAlgorithm ranking algorithm
070     * @since 3.1
071     */
072    public SpearmansCorrelation(final RankingAlgorithm rankingAlgorithm) {
073        data = null;
074        this.rankingAlgorithm = rankingAlgorithm;
075        rankCorrelation = null;
076    }
077
078    /**
079     * Create a SpearmansCorrelation from the given data matrix.
080     *
081     * @param dataMatrix matrix of data with columns representing
082     * variables to correlate
083     */
084    public SpearmansCorrelation(final RealMatrix dataMatrix) {
085        this(dataMatrix, new NaturalRanking());
086    }
087
088    /**
089     * Create a SpearmansCorrelation with the given input data matrix
090     * and ranking algorithm.
091     * <p>
092     * From version 4.0 onwards this constructor will throw an exception
093     * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy.
094     *
095     * @param dataMatrix matrix of data with columns representing
096     * variables to correlate
097     * @param rankingAlgorithm ranking algorithm
098     */
099    public SpearmansCorrelation(final RealMatrix dataMatrix, final RankingAlgorithm rankingAlgorithm) {
100        this.rankingAlgorithm = rankingAlgorithm;
101        this.data = rankTransform(dataMatrix);
102        rankCorrelation = new PearsonsCorrelation(data);
103    }
104
105    /**
106     * Calculate the Spearman Rank Correlation Matrix.
107     *
108     * @return Spearman Rank Correlation Matrix
109     */
110    public RealMatrix getCorrelationMatrix() {
111        return rankCorrelation.getCorrelationMatrix();
112    }
113
114    /**
115     * Returns a {@link PearsonsCorrelation} instance constructed from the
116     * ranked input data. That is,
117     * <code>new SpearmansCorrelation(matrix).getRankCorrelation()</code>
118     * is equivalent to
119     * <code>new PearsonsCorrelation(rankTransform(matrix))</code> where
120     * <code>rankTransform(matrix)</code> is the result of applying the
121     * configured <code>RankingAlgorithm</code> to each of the columns of
122     * <code>matrix.</code>
123     *
124     * @return PearsonsCorrelation among ranked column data
125     */
126    public PearsonsCorrelation getRankCorrelation() {
127        return rankCorrelation;
128    }
129
130    /**
131     * Computes the Spearman's rank correlation matrix for the columns of the
132     * input matrix.
133     *
134     * @param matrix matrix with columns representing variables to correlate
135     * @return correlation matrix
136     */
137    public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
138        final RealMatrix matrixCopy = rankTransform(matrix);
139        return new PearsonsCorrelation().computeCorrelationMatrix(matrixCopy);
140    }
141
142    /**
143     * Computes the Spearman's rank correlation matrix for the columns of the
144     * input rectangular array.  The columns of the array represent values
145     * of variables to be correlated.
146     *
147     * @param matrix matrix with columns representing variables to correlate
148     * @return correlation matrix
149     */
150    public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
151       return computeCorrelationMatrix(new BlockRealMatrix(matrix));
152    }
153
154    /**
155     * Computes the Spearman's rank correlation coefficient between the two arrays.
156     *
157     * @param xArray first data array
158     * @param yArray second data array
159     * @return Returns Spearman's rank correlation coefficient for the two arrays
160     * @throws DimensionMismatchException if the arrays lengths do not match
161     * @throws MathIllegalArgumentException if the array length is less than 2
162     */
163    public double correlation(final double[] xArray, final double[] yArray) {
164        if (xArray.length != yArray.length) {
165            throw new DimensionMismatchException(xArray.length, yArray.length);
166        } else if (xArray.length < 2) {
167            throw new MathIllegalArgumentException(LocalizedFormats.INSUFFICIENT_DIMENSION,
168                                                   xArray.length, 2);
169        } else {
170            double[] x = xArray;
171            double[] y = yArray;
172            if (rankingAlgorithm instanceof NaturalRanking &&
173                NaNStrategy.REMOVED == ((NaturalRanking) rankingAlgorithm).getNanStrategy()) {
174                final Set<Integer> nanPositions = new HashSet<Integer>();
175
176                nanPositions.addAll(getNaNPositions(xArray));
177                nanPositions.addAll(getNaNPositions(yArray));
178
179                x = removeValues(xArray, nanPositions);
180                y = removeValues(yArray, nanPositions);
181            }
182            return new PearsonsCorrelation().correlation(rankingAlgorithm.rank(x), rankingAlgorithm.rank(y));
183        }
184    }
185
186    /**
187     * Applies rank transform to each of the columns of <code>matrix</code>
188     * using the current <code>rankingAlgorithm</code>.
189     *
190     * @param matrix matrix to transform
191     * @return a rank-transformed matrix
192     */
193    private RealMatrix rankTransform(final RealMatrix matrix) {
194        RealMatrix transformed = null;
195
196        if (rankingAlgorithm instanceof NaturalRanking &&
197                ((NaturalRanking) rankingAlgorithm).getNanStrategy() == NaNStrategy.REMOVED) {
198            final Set<Integer> nanPositions = new HashSet<Integer>();
199            for (int i = 0; i < matrix.getColumnDimension(); i++) {
200                nanPositions.addAll(getNaNPositions(matrix.getColumn(i)));
201            }
202
203            // if we have found NaN values, we have to update the matrix size
204            if (!nanPositions.isEmpty()) {
205                transformed = new BlockRealMatrix(matrix.getRowDimension() - nanPositions.size(),
206                                                  matrix.getColumnDimension());
207                for (int i = 0; i < transformed.getColumnDimension(); i++) {
208                    transformed.setColumn(i, removeValues(matrix.getColumn(i), nanPositions));
209                }
210            }
211        }
212
213        if (transformed == null) {
214            transformed = matrix.copy();
215        }
216
217        for (int i = 0; i < transformed.getColumnDimension(); i++) {
218            transformed.setColumn(i, rankingAlgorithm.rank(transformed.getColumn(i)));
219        }
220
221        return transformed;
222    }
223
224    /**
225     * Returns a list containing the indices of NaN values in the input array.
226     *
227     * @param input the input array
228     * @return a list of NaN positions in the input array
229     */
230    private List<Integer> getNaNPositions(final double[] input) {
231        final List<Integer> positions = new ArrayList<Integer>();
232        for (int i = 0; i < input.length; i++) {
233            if (Double.isNaN(input[i])) {
234                positions.add(i);
235            }
236        }
237        return positions;
238    }
239
240    /**
241     * Removes all values from the input array at the specified indices.
242     *
243     * @param input the input array
244     * @param indices a set containing the indices to be removed
245     * @return the input array without the values at the specified indices
246     */
247    private double[] removeValues(final double[] input, final Set<Integer> indices) {
248        if (indices.isEmpty()) {
249            return input;
250        }
251        final double[] result = new double[input.length - indices.size()];
252        for (int i = 0, j = 0; i < input.length; i++) {
253            if (!indices.contains(i)) {
254                result[j++] = input[i];
255            }
256        }
257        return result;
258    }
259}