001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.statistics.inference;
018
019import java.lang.ref.SoftReference;
020import java.util.Arrays;
021import java.util.EnumSet;
022import java.util.Objects;
023import java.util.concurrent.locks.ReentrantLock;
024import java.util.stream.IntStream;
025import org.apache.commons.numbers.combinatorics.BinomialCoefficientDouble;
026import org.apache.commons.statistics.distribution.NormalDistribution;
027import org.apache.commons.statistics.ranking.NaNStrategy;
028import org.apache.commons.statistics.ranking.NaturalRanking;
029import org.apache.commons.statistics.ranking.RankingAlgorithm;
030import org.apache.commons.statistics.ranking.TiesStrategy;
031
032/**
033 * Implements the Mann-Whitney U test (also called Wilcoxon rank-sum test).
034 *
035 * @see <a href="https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test">
036 * Mann-Whitney U test (Wikipedia)</a>
037 * @since 1.1
038 */
039public final class MannWhitneyUTest {
040    /** Limit on sample size for the exact p-value computation for the auto mode. */
041    private static final int AUTO_LIMIT = 50;
042    /** Ranking instance. */
043    private static final RankingAlgorithm RANKING = new NaturalRanking(NaNStrategy.FAILED, TiesStrategy.AVERAGE);
044    /** Value for an unset f computation. */
045    private static final double UNSET = -1;
046    /** An object to use for synchonization when accessing the cache of F. */
047    private static final ReentrantLock LOCK = new ReentrantLock();
048    /** A reference to a previously computed storage for f.
049     * Use of a SoftReference ensures this is garbage collected before an OutOfMemoryError.
050     * The value should only be accessed, checked for size and optionally
051     * modified when holding the lock. When the storage is determined to be the correct
052     * size it can be returned for read/write to the array when not holding the lock. */
053    private static SoftReference<double[][][]> cacheF = new SoftReference<>(null);
054    /** Default instance. */
055    private static final MannWhitneyUTest DEFAULT = new MannWhitneyUTest(
056        AlternativeHypothesis.TWO_SIDED, PValueMethod.AUTO, true, 0);
057
058    /** Alternative hypothesis. */
059    private final AlternativeHypothesis alternative;
060    /** Method to compute the p-value. */
061    private final PValueMethod pValueMethod;
062    /** Perform continuity correction. */
063    private final boolean continuityCorrection;
064    /** Expected location shift. */
065    private final double mu;
066
067    /**
068     * Result for the Mann-Whitney U test.
069     *
070     * <p>This class is immutable.
071     *
072     * @since 1.1
073     */
074    public static final class Result extends BaseSignificanceResult {
075        /** Flag indicating the data has tied values. */
076        private final boolean tiedValues;
077
078        /**
079         * Create an instance.
080         *
081         * @param statistic Test statistic.
082         * @param tiedValues Flag indicating the data has tied values.
083         * @param p Result p-value.
084         */
085        Result(double statistic, boolean tiedValues, double p) {
086            super(statistic, p);
087            this.tiedValues = tiedValues;
088        }
089
090        /**
091         * {@inheritDoc}
092         *
093         * <p>This is the U<sub>1</sub> statistic. Compute the U<sub>2</sub> statistic using
094         * the original sample lengths {@code n} and {@code m} using:
095         * <pre>
096         * u2 = (long) n * m - u1;
097         * </pre>
098         */
099        @Override
100        public double getStatistic() {
101            // Note: This method is here for documentation
102            return super.getStatistic();
103        }
104
105        /**
106         * Return {@code true} if the data had tied values.
107         *
108         * <p>Note: The exact computation cannot be used when there are tied values.
109         *
110         * @return {@code true} if there were tied values
111         */
112        public boolean hasTiedValues() {
113            return tiedValues;
114        }
115    }
116
117    /**
118     * @param alternative Alternative hypothesis.
119     * @param method P-value method.
120     * @param continuityCorrection true to perform continuity correction.
121     * @param mu Expected location shift.
122     */
123    private MannWhitneyUTest(AlternativeHypothesis alternative, PValueMethod method,
124        boolean continuityCorrection, double mu) {
125        this.alternative = alternative;
126        this.pValueMethod = method;
127        this.continuityCorrection = continuityCorrection;
128        this.mu = mu;
129    }
130
131    /**
132     * Return an instance using the default options.
133     *
134     * <ul>
135     * <li>{@link AlternativeHypothesis#TWO_SIDED}
136     * <li>{@link PValueMethod#AUTO}
137     * <li>{@link ContinuityCorrection#ENABLED}
138     * <li>{@linkplain #withMu(double) mu = 0}
139     * </ul>
140     *
141     * @return default instance
142     */
143    public static MannWhitneyUTest withDefaults() {
144        return DEFAULT;
145    }
146
147    /**
148     * Return an instance with the configured alternative hypothesis.
149     *
150     * @param v Value.
151     * @return an instance
152     */
153    public MannWhitneyUTest with(AlternativeHypothesis v) {
154        return new MannWhitneyUTest(Objects.requireNonNull(v), pValueMethod, continuityCorrection, mu);
155    }
156
157    /**
158     * Return an instance with the configured p-value method.
159     *
160     * @param v Value.
161     * @return an instance
162     * @throws IllegalArgumentException if the value is not in the allowed options or is null
163     */
164    public MannWhitneyUTest with(PValueMethod v) {
165        return new MannWhitneyUTest(alternative,
166            Arguments.checkOption(v, EnumSet.of(PValueMethod.AUTO, PValueMethod.EXACT, PValueMethod.ASYMPTOTIC)),
167            continuityCorrection, mu);
168    }
169
170    /**
171     * Return an instance with the configured continuity correction.
172     *
173     * <p>If {@link ContinuityCorrection#ENABLED ENABLED}, adjust the U rank statistic by
174     * 0.5 towards the mean value when computing the z-statistic if a normal approximation is used
175     * to compute the p-value.
176     *
177     * @param v Value.
178     * @return an instance
179     */
180    public MannWhitneyUTest with(ContinuityCorrection v) {
181        return new MannWhitneyUTest(alternative, pValueMethod,
182            Objects.requireNonNull(v) == ContinuityCorrection.ENABLED, mu);
183    }
184
185    /**
186     * Return an instance with the configured location shift {@code mu}.
187     *
188     * @param v Value.
189     * @return an instance
190     * @throws IllegalArgumentException if the value is not finite
191     */
192    public MannWhitneyUTest withMu(double v) {
193        return new MannWhitneyUTest(alternative, pValueMethod, continuityCorrection, Arguments.checkFinite(v));
194    }
195
196    /**
197     * Computes the Mann-Whitney U statistic comparing two independent
198     * samples possibly of different length.
199     *
200     * <p>This statistic can be used to perform a Mann-Whitney U test evaluating the
201     * null hypothesis that the two independent samples differ by a location shift of {@code mu}.
202     *
203     * <p>This returns the U<sub>1</sub> statistic. Compute the U<sub>2</sub> statistic using:
204     * <pre>
205     * u2 = (long) x.length * y.length - u1;
206     * </pre>
207     *
208     * @param x First sample values.
209     * @param y Second sample values.
210     * @return Mann-Whitney U<sub>1</sub> statistic
211     * @throws IllegalArgumentException if {@code x} or {@code y} are zero-length; or contain
212     * NaN values.
213     * @see #withMu(double)
214     */
215    public double statistic(double[] x, double[] y) {
216        checkSamples(x, y);
217
218        final double[] z = concatenateSamples(mu, x, y);
219        final double[] ranks = RANKING.apply(z);
220
221        // The ranks for x is in the first x.length entries in ranks because x
222        // is in the first x.length entries in z
223        final double sumRankX = Arrays.stream(ranks).limit(x.length).sum();
224
225        // U1 = R1 - (n1 * (n1 + 1)) / 2 where R1 is sum of ranks for sample 1,
226        // e.g. x, n1 is the number of observations in sample 1.
227        return sumRankX - ((long) x.length * (x.length + 1)) * 0.5;
228    }
229
230    /**
231     * Performs a Mann-Whitney U test comparing the location for two independent
232     * samples. The location is specified using {@link #withMu(double) mu}.
233     *
234     * <p>The test is defined by the {@link AlternativeHypothesis}.
235     * <ul>
236     * <li>'two-sided': the distribution underlying {@code (x - mu)} is not equal to the
237     * distribution underlying {@code y}.
238     * <li>'greater': the distribution underlying {@code (x - mu)} is stochastically greater than
239     * the distribution underlying {@code y}.
240     * <li>'less': the distribution underlying {@code (x - mu)} is stochastically less than
241     * the distribution underlying {@code y}.
242     * </ul>
243     *
244     * <p>If the p-value method is {@linkplain PValueMethod#AUTO auto} an exact p-value is
245     * computed if the samples contain less than 50 values; otherwise a normal
246     * approximation is used.
247     *
248     * <p>Computation of the exact p-value is only valid if there are no tied
249     * ranks in the data; otherwise the p-value resorts to the asymptotic
250     * approximation using a tie correction and an optional continuity correction.
251     *
252     * <p><strong>Note: </strong>
253     * Exact computation requires tabulation of values not exceeding size
254     * {@code (n+1)*(m+1)*(u+1)} where {@code u} is the minimum of the U<sub>1</sub> and
255     * U<sub>2</sub> statistics and {@code n} and {@code m} are the sample sizes.
256     * This may use a very large amount of memory and result in an {@link OutOfMemoryError}.
257     * Exact computation requires a finite binomial coefficient {@code binom(n+m, m)}
258     * which is limited to {@code n+m <= 1029} for any {@code n} and {@code m},
259     * or {@code min(n, m) <= 37} for any {@code max(n, m)}.
260     * An {@link OutOfMemoryError} is not expected using the
261     * limits configured for the {@linkplain PValueMethod#AUTO auto} p-value computation
262     * as the maximum required memory is approximately 23 MiB.
263     *
264     * @param x First sample values.
265     * @param y Second sample values.
266     * @return test result
267     * @throws IllegalArgumentException if {@code x} or {@code y} are zero-length; or contain
268     * NaN values.
269     * @throws OutOfMemoryError if the exact computation is <em>user-requested</em> for
270     * large samples and there is not enough memory.
271     * @see #statistic(double[], double[])
272     * @see #withMu(double)
273     * @see #with(AlternativeHypothesis)
274     * @see #with(ContinuityCorrection)
275     */
276    public Result test(double[] x, double[] y) {
277        // Computation as above. The ranks are required for tie correction.
278        checkSamples(x, y);
279        final double[] z = concatenateSamples(mu, x, y);
280        final double[] ranks = RANKING.apply(z);
281        final double sumRankX = Arrays.stream(ranks).limit(x.length).sum();
282        final double u1 = sumRankX - ((long) x.length * (x.length + 1)) * 0.5;
283
284        final double c = WilcoxonSignedRankTest.calculateTieCorrection(ranks);
285        final boolean tiedValues = c != 0;
286
287        PValueMethod method = pValueMethod;
288        final int n = x.length;
289        final int m = y.length;
290        if (method == PValueMethod.AUTO && Math.max(n, m) < AUTO_LIMIT) {
291            method = PValueMethod.EXACT;
292        }
293        // Exact p requires no ties.
294        // The method will fail-fast if the computation is not possible due
295        // to the size of the data.
296        double p = method == PValueMethod.EXACT && !tiedValues ?
297            calculateExactPValue(u1, n, m, alternative) : -1;
298        if (p < 0) {
299            p = calculateAsymptoticPValue(u1, n, m, c);
300        }
301        return new Result(u1, tiedValues, p);
302    }
303
304    /**
305     * Ensures that the provided arrays fulfil the assumptions.
306     *
307     * @param x First sample values.
308     * @param y Second sample values.
309     * @throws IllegalArgumentException if {@code x} or {@code y} are zero-length.
310     */
311    private static void checkSamples(double[] x, double[] y) {
312        Arguments.checkValuesRequiredSize(x.length, 1);
313        Arguments.checkValuesRequiredSize(y.length, 1);
314    }
315
316    /**
317     * Concatenate the samples into one array. Subtract {@code mu} from the first sample.
318     *
319     * @param mu Expected difference between means.
320     * @param x First sample values.
321     * @param y Second sample values.
322     * @return concatenated array
323     */
324    private static double[] concatenateSamples(double mu, double[] x, double[] y) {
325        final double[] z = new double[x.length + y.length];
326        System.arraycopy(x, 0, z, 0, x.length);
327        System.arraycopy(y, 0, z, x.length, y.length);
328        if (mu != 0) {
329            for (int i = 0; i < x.length; i++) {
330                z[i] -= mu;
331            }
332        }
333        return z;
334    }
335
336    /**
337     * Calculate the asymptotic p-value using a Normal approximation.
338     *
339     * @param u Mann-Whitney U value.
340     * @param n1 Number of subjects in first sample.
341     * @param n2 Number of subjects in second sample.
342     * @param c Tie-correction
343     * @return two-sided asymptotic p-value
344     */
345    private double calculateAsymptoticPValue(double u, int n1, int n2, double c) {
346        // Use long to avoid overflow
347        final long n1n2 = (long) n1 * n2;
348        final long n = (long) n1 + n2;
349
350        // https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test#Normal_approximation_and_tie_correction
351        final double e = n1n2 * 0.5;
352        final double variance = (n1n2 / 12.0) * ((n + 1.0) - c / n / (n - 1));
353
354        double z = u - e;
355        if (continuityCorrection) {
356            // +/- 0.5 is a continuity correction towards the expected.
357            if (alternative == AlternativeHypothesis.GREATER_THAN) {
358                z -= 0.5;
359            } else if (alternative == AlternativeHypothesis.LESS_THAN) {
360                z += 0.5;
361            } else {
362                // two-sided. Shift towards the expected of zero.
363                // Use of signum ignores x==0 (i.e. not copySign(0.5, z))
364                z -= Math.signum(z) * 0.5;
365            }
366        }
367        z /= Math.sqrt(variance);
368
369        final NormalDistribution standardNormal = NormalDistribution.of(0, 1);
370        if (alternative == AlternativeHypothesis.GREATER_THAN) {
371            return standardNormal.survivalProbability(z);
372        }
373        if (alternative == AlternativeHypothesis.LESS_THAN) {
374            return standardNormal.cumulativeProbability(z);
375        }
376        // two-sided
377        return 2 * standardNormal.survivalProbability(Math.abs(z));
378    }
379
380    /**
381     * Calculate the exact p-value. If the value cannot be computed this returns -1.
382     *
383     * <p>Note: Computation may run out of memory during array allocation, or method
384     * recursion.
385     *
386     * @param u Mann-Whitney U value.
387     * @param m Number of subjects in first sample.
388     * @param n Number of subjects in second sample.
389     * @param alternative Alternative hypothesis.
390     * @return exact p-value (or -1) (two-sided, greater, or less using the options)
391     */
392    // package-private for testing
393    static double calculateExactPValue(double u, int m, int n, AlternativeHypothesis alternative) {
394        // Check the computation can be attempted.
395        // u must be an integer
396        if ((int) u != u) {
397            return -1;
398        }
399        // Note: n+m will not overflow as we concatenated the samples to a single array.
400        final double binom = BinomialCoefficientDouble.value(n + m, m);
401        if (binom == Double.POSITIVE_INFINITY) {
402            return -1;
403        }
404
405        // Use u_min for the CDF.
406        final int u1 = (int) u;
407        final int u2 = (int) ((long) m * n - u1);
408        // Use m < n to support symmetry.
409        final int n1 = Math.min(m, n);
410        final int n2 = Math.max(m, n);
411
412        // Return the correct side:
413        if (alternative == AlternativeHypothesis.GREATER_THAN) {
414            // sf(u1 - 1)
415            return sf(u1 - 1, u2 + 1, n1, n2, binom);
416        }
417        if (alternative == AlternativeHypothesis.LESS_THAN) {
418            // cdf(u1)
419            return cdf(u1, u2, n1, n2, binom);
420        }
421        // two-sided: 2 * sf(max(u1, u2) - 1) or 2 * cdf(min(u1, u2))
422        final double p = 2 * computeCdf(Math.min(u1, u2), n1, n2, binom);
423        // Clip to range: [0, 1]
424        return Math.min(1, p);
425    }
426
427    /**
428     * Compute the cumulative density function of the Mann-Whitney U1 statistic.
429     * The U2 statistic is passed for convenience to exploit symmetry in the distribution.
430     *
431     * @param u1 Mann-Whitney U1 statistic
432     * @param u2 Mann-Whitney U2 statistic
433     * @param m First sample size.
434     * @param n Second sample size.
435     * @param binom binom(n+m, m) (must be finite)
436     * @return {@code Pr(X <= k)}
437     */
438    private static double cdf(int u1, int u2, int m, int n, double binom) {
439        // Exploit symmetry. Note the distribution is discrete thus requiring (u2 - 1).
440        return u2 > u1 ?
441            computeCdf(u1, m, n, binom) :
442            1 - computeCdf(u2 - 1, m, n, binom);
443    }
444
445    /**
446     * Compute the survival function of the Mann-Whitney U1 statistic.
447     * The U2 statistic is passed for convenience to exploit symmetry in the distribution.
448     *
449     * @param u1 Mann-Whitney U1 statistic
450     * @param u2 Mann-Whitney U2 statistic
451     * @param m First sample size.
452     * @param n Second sample size.
453     * @param binom binom(n+m, m) (must be finite)
454     * @return {@code Pr(X > k)}
455     */
456    private static double sf(int u1, int u2, int m, int n, double binom) {
457        // Opposite of the CDF
458        return u2 > u1 ?
459            1 - computeCdf(u1, m, n, binom) :
460            computeCdf(u2 - 1, m, n, binom);
461    }
462
463    /**
464     * Compute the cumulative density function of the Mann-Whitney U statistic.
465     *
466     * <p>This should be called with the lower of U1 or U2 for computational efficiency.
467     *
468     * <p>Uses the recursive formula provided in Bucchianico, A.D, (1999)
469     * Combinatorics, computer algebra and the Wilcoxon-Mann-Whitney test, Journal
470     * of Statistical Planning and Inference, Volume 79, Issue 2, 349-364.
471     *
472     * @param k Mann-Whitney U statistic
473     * @param m First sample size.
474     * @param n Second sample size.
475     * @param binom binom(n+m, m) (must be finite)
476     * @return {@code Pr(X <= k)}
477     */
478    private static double computeCdf(int k, int m, int n, double binom) {
479        // Theorem 2.5:
480        // f(m, n, k) = 0 if k < 0, m < 0, n < 0, k > nm
481        if (k < 0) {
482            return 0;
483        }
484        // Recursively compute f(m, n, k)
485        final double[][][] f = getF(m, n, k);
486
487        // P(X=k) = f(m, n, k) / binom(m+n, m)
488        // P(X<=k) = sum_0^k (P(X=i))
489
490        // Called with k = min(u1, u2) : max(p) ~ 0.5 so no need to clip to [0, 1]
491        return IntStream.rangeClosed(0, k).mapToDouble(i -> fmnk(f, m, n, i)).sum() / binom;
492    }
493
494    /**
495     * Gets the storage for f(m, n, k).
496     *
497     * <p>This may be cached for performance.
498     *
499     * @param m M.
500     * @param n N.
501     * @param k K.
502     * @return the storage for f
503     */
504    private static double[][][] getF(int m, int n, int k) {
505        // Obtain any previous computation of f and expand it if required.
506        // F is only modified within this synchronized block.
507        // Any concurrent threads using a reference returned by this method
508        // will not receive an index out-of-bounds as f is only ever expanded.
509        try {
510            LOCK.lock();
511            // Note: f(x<m, y<n, z<k) is always the same.
512            // Cache the array and re-use any previous computation.
513            double[][][] f = cacheF.get();
514
515            // Require:
516            // f = new double[m + 1][n + 1][k + 1]
517            // f(m, n, 0) == 1; otherwise -1 if not computed
518            // m+n <= 1029 for any m,n; k < mn/2 (due to symmetry using min(u1, u2))
519            // Size m=n=515: approximately 516^2 * 515^2/2 = 398868 doubles ~ 3.04 GiB
520            if (f == null) {
521                f = new double[m + 1][n + 1][k + 1];
522                for (final double[][] a : f) {
523                    for (final double[] b : a) {
524                        initialize(b);
525                    }
526                }
527                // Cache for reuse.
528                cacheF = new SoftReference<>(f);
529                return f;
530            }
531
532            // Grow if required: m1 < m+1 => m1-(m+1) < 0 => m1 - m < 1
533            final int m1 = f.length;
534            final int n1 = f[0].length;
535            final int k1 = f[0][0].length;
536            final boolean growM = m1 - m < 1;
537            final boolean growN = n1 - n < 1;
538            final boolean growK = k1 - k < 1;
539            if (growM | growN | growK) {
540                // Some part of the previous f is too small.
541                // Atomically grow without destroying the previous computation.
542                // Any other thread using the previous f will not go out of bounds
543                // by keeping the new f dimensions at least as large.
544                // Note: Doing this in-place allows the memory to be gradually
545                // increased rather than allocating a new [m + 1][n + 1][k + 1]
546                // and copying all old values.
547                final int sn = Math.max(n1, n + 1);
548                final int sk = Math.max(k1, k + 1);
549                if (growM) {
550                    // Entirely new region
551                    f = Arrays.copyOf(f, m + 1);
552                    for (int x = m1; x <= m; x++) {
553                        f[x] = new double[sn][sk];
554                        for (final double[] b : f[x]) {
555                            initialize(b);
556                        }
557                    }
558                }
559                // Expand previous in place if required
560                if (growN) {
561                    for (int x = 0; x < m1; x++) {
562                        f[x] = Arrays.copyOf(f[x], sn);
563                        for (int y = n1; y < sn; y++) {
564                            final double[] b = f[x][y] = new double[sk];
565                            initialize(b);
566                        }
567                    }
568                }
569                if (growK) {
570                    for (int x = 0; x < m1; x++) {
571                        for (int y = 0; y < n1; y++) {
572                            final double[] b = f[x][y] = Arrays.copyOf(f[x][y], sk);
573                            for (int z = k1; z < sk; z++) {
574                                b[z] = UNSET;
575                            }
576                        }
577                    }
578                }
579                // Avoided an OutOfMemoryError. Cache for reuse.
580                cacheF = new SoftReference<>(f);
581            }
582            return f;
583        } finally {
584            LOCK.unlock();
585        }
586    }
587
588    /**
589     * Initialize the array for f(m, n, x).
590     * Set value to 1 for x=0; otherwise {@link #UNSET}.
591     *
592     * @param fmn Array.
593     */
594    private static void initialize(double[] fmn) {
595        Arrays.fill(fmn, UNSET);
596        // f(m, n, 0) == 1 if m >= 0, n >= 0
597        fmn[0] = 1;
598    }
599
600    /**
601     * Compute f(m; n; k), the number of subsets of {0; 1; ...; n} with m elements such
602     * that the elements of this subset add up to k.
603     *
604     * <p>The function is computed recursively.
605     *
606     * @param f Tabulated values of f[m][n][k].
607     * @param m M
608     * @param n N
609     * @param k K
610     * @return f(m; n; k)
611     */
612    private static double fmnk(double[][][] f, int m, int n, int k) {
613        // Theorem 2.5:
614        // Omit conditions that will not be met: k > mn
615        // f(m, n, k) = 0 if k < 0, m < 0, n < 0
616        if ((k | m | n) < 0) {
617            return 0;
618        }
619        // Compute on demand
620        double fmnk = f[m][n][k];
621        if (fmnk < 0) {
622            // f(m, n, 0) == 1 if m >= 0, n >= 0
623            // This is already computed.
624
625            // Recursion from formula (3):
626            // f(m, n, k) = f(m-1, n, k-n) + f(m, n-1, k)
627            f[m][n][k] = fmnk = fmnk(f, m - 1, n, k - n) + fmnk(f, m, n - 1, k);
628        }
629        return fmnk;
630    }
631}