001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.stat.correlation;
019
020import java.util.ArrayList;
021import java.util.HashSet;
022import java.util.List;
023import java.util.Set;
024
025import org.apache.commons.math3.exception.DimensionMismatchException;
026import org.apache.commons.math3.exception.MathIllegalArgumentException;
027import org.apache.commons.math3.exception.util.LocalizedFormats;
028import org.apache.commons.math3.linear.BlockRealMatrix;
029import org.apache.commons.math3.linear.RealMatrix;
030import org.apache.commons.math3.stat.ranking.NaNStrategy;
031import org.apache.commons.math3.stat.ranking.NaturalRanking;
032import org.apache.commons.math3.stat.ranking.RankingAlgorithm;
033
034/**
035 * Spearman's rank correlation. This implementation performs a rank
036 * transformation on the input data and then computes {@link PearsonsCorrelation}
037 * on the ranked data.
038 * <p>
039 * By default, ranks are computed using {@link NaturalRanking} with default
040 * strategies for handling NaNs and ties in the data (NaNs maximal, ties averaged).
041 * The ranking algorithm can be set using a constructor argument.
042 *
043 * @since 2.0
044 */
045public class SpearmansCorrelation {
046
047    /** Input data */
048    private final RealMatrix data;
049
050    /** Ranking algorithm  */
051    private final RankingAlgorithm rankingAlgorithm;
052
053    /** Rank correlation */
054    private final PearsonsCorrelation rankCorrelation;
055
056    /**
057     * Create a SpearmansCorrelation without data.
058     */
059    public SpearmansCorrelation() {
060        this(new NaturalRanking());
061    }
062
063    /**
064     * Create a SpearmansCorrelation with the given ranking algorithm.
065     * <p>
066     * From version 4.0 onwards this constructor will throw an exception
067     * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy.
068     *
069     * @param rankingAlgorithm ranking algorithm
070     * @since 3.1
071     */
072    public SpearmansCorrelation(final RankingAlgorithm rankingAlgorithm) {
073        data = null;
074        this.rankingAlgorithm = rankingAlgorithm;
075        rankCorrelation = null;
076    }
077
078    /**
079     * Create a SpearmansCorrelation from the given data matrix.
080     *
081     * @param dataMatrix matrix of data with columns representing
082     * variables to correlate
083     */
084    public SpearmansCorrelation(final RealMatrix dataMatrix) {
085        this(dataMatrix, new NaturalRanking());
086    }
087
088    /**
089     * Create a SpearmansCorrelation with the given input data matrix
090     * and ranking algorithm.
091     * <p>
092     * From version 4.0 onwards this constructor will throw an exception
093     * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy.
094     *
095     * @param dataMatrix matrix of data with columns representing
096     * variables to correlate
097     * @param rankingAlgorithm ranking algorithm
098     */
099    public SpearmansCorrelation(final RealMatrix dataMatrix, final RankingAlgorithm rankingAlgorithm) {
100        this.rankingAlgorithm = rankingAlgorithm;
101        this.data = rankTransform(dataMatrix);
102        rankCorrelation = new PearsonsCorrelation(data);
103    }
104
105    /**
106     * Calculate the Spearman Rank Correlation Matrix.
107     *
108     * @return Spearman Rank Correlation Matrix
109     * @throws NullPointerException if this instance was created with no data
110     */
111    public RealMatrix getCorrelationMatrix() {
112        return rankCorrelation.getCorrelationMatrix();
113    }
114
115    /**
116     * Returns a {@link PearsonsCorrelation} instance constructed from the
117     * ranked input data. That is,
118     * <code>new SpearmansCorrelation(matrix).getRankCorrelation()</code>
119     * is equivalent to
120     * <code>new PearsonsCorrelation(rankTransform(matrix))</code> where
121     * <code>rankTransform(matrix)</code> is the result of applying the
122     * configured <code>RankingAlgorithm</code> to each of the columns of
123     * <code>matrix.</code>
124     *
125     * <p>Returns null if this instance was created with no data.</p>
126     *
127     * @return PearsonsCorrelation among ranked column data
128     */
129    public PearsonsCorrelation getRankCorrelation() {
130        return rankCorrelation;
131    }
132
133    /**
134     * Computes the Spearman's rank correlation matrix for the columns of the
135     * input matrix.
136     *
137     * @param matrix matrix with columns representing variables to correlate
138     * @return correlation matrix
139     */
140    public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
141        final RealMatrix matrixCopy = rankTransform(matrix);
142        return new PearsonsCorrelation().computeCorrelationMatrix(matrixCopy);
143    }
144
145    /**
146     * Computes the Spearman's rank correlation matrix for the columns of the
147     * input rectangular array.  The columns of the array represent values
148     * of variables to be correlated.
149     *
150     * @param matrix matrix with columns representing variables to correlate
151     * @return correlation matrix
152     */
153    public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
154       return computeCorrelationMatrix(new BlockRealMatrix(matrix));
155    }
156
157    /**
158     * Computes the Spearman's rank correlation coefficient between the two arrays.
159     *
160     * @param xArray first data array
161     * @param yArray second data array
162     * @return Returns Spearman's rank correlation coefficient for the two arrays
163     * @throws DimensionMismatchException if the arrays lengths do not match
164     * @throws MathIllegalArgumentException if the array length is less than 2
165     */
166    public double correlation(final double[] xArray, final double[] yArray) {
167        if (xArray.length != yArray.length) {
168            throw new DimensionMismatchException(xArray.length, yArray.length);
169        } else if (xArray.length < 2) {
170            throw new MathIllegalArgumentException(LocalizedFormats.INSUFFICIENT_DIMENSION,
171                                                   xArray.length, 2);
172        } else {
173            double[] x = xArray;
174            double[] y = yArray;
175            if (rankingAlgorithm instanceof NaturalRanking &&
176                NaNStrategy.REMOVED == ((NaturalRanking) rankingAlgorithm).getNanStrategy()) {
177                final Set<Integer> nanPositions = new HashSet<Integer>();
178
179                nanPositions.addAll(getNaNPositions(xArray));
180                nanPositions.addAll(getNaNPositions(yArray));
181
182                x = removeValues(xArray, nanPositions);
183                y = removeValues(yArray, nanPositions);
184            }
185            return new PearsonsCorrelation().correlation(rankingAlgorithm.rank(x), rankingAlgorithm.rank(y));
186        }
187    }
188
189    /**
190     * Applies rank transform to each of the columns of <code>matrix</code>
191     * using the current <code>rankingAlgorithm</code>.
192     *
193     * @param matrix matrix to transform
194     * @return a rank-transformed matrix
195     */
196    private RealMatrix rankTransform(final RealMatrix matrix) {
197        RealMatrix transformed = null;
198
199        if (rankingAlgorithm instanceof NaturalRanking &&
200                ((NaturalRanking) rankingAlgorithm).getNanStrategy() == NaNStrategy.REMOVED) {
201            final Set<Integer> nanPositions = new HashSet<Integer>();
202            for (int i = 0; i < matrix.getColumnDimension(); i++) {
203                nanPositions.addAll(getNaNPositions(matrix.getColumn(i)));
204            }
205
206            // if we have found NaN values, we have to update the matrix size
207            if (!nanPositions.isEmpty()) {
208                transformed = new BlockRealMatrix(matrix.getRowDimension() - nanPositions.size(),
209                                                  matrix.getColumnDimension());
210                for (int i = 0; i < transformed.getColumnDimension(); i++) {
211                    transformed.setColumn(i, removeValues(matrix.getColumn(i), nanPositions));
212                }
213            }
214        }
215
216        if (transformed == null) {
217            transformed = matrix.copy();
218        }
219
220        for (int i = 0; i < transformed.getColumnDimension(); i++) {
221            transformed.setColumn(i, rankingAlgorithm.rank(transformed.getColumn(i)));
222        }
223
224        return transformed;
225    }
226
227    /**
228     * Returns a list containing the indices of NaN values in the input array.
229     *
230     * @param input the input array
231     * @return a list of NaN positions in the input array
232     */
233    private List<Integer> getNaNPositions(final double[] input) {
234        final List<Integer> positions = new ArrayList<Integer>();
235        for (int i = 0; i < input.length; i++) {
236            if (Double.isNaN(input[i])) {
237                positions.add(i);
238            }
239        }
240        return positions;
241    }
242
243    /**
244     * Removes all values from the input array at the specified indices.
245     *
246     * @param input the input array
247     * @param indices a set containing the indices to be removed
248     * @return the input array without the values at the specified indices
249     */
250    private double[] removeValues(final double[] input, final Set<Integer> indices) {
251        if (indices.isEmpty()) {
252            return input;
253        }
254        final double[] result = new double[input.length - indices.size()];
255        for (int i = 0, j = 0; i < input.length; i++) {
256            if (!indices.contains(i)) {
257                result[j++] = input[i];
258            }
259        }
260        return result;
261    }
262}