/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.statistics.inference;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.DoublePredicate;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntToDoubleFunction;
import org.apache.commons.numbers.combinatorics.LogBinomialCoefficient;
import org.apache.commons.statistics.inference.BrentOptimizer.PointValuePair;

/**
 * Implements an unconditioned exact test for a contingency table.
 *
 * <p>Performs an exact test for the statistical significance of the association (contingency)
 * between two kinds of categorical classification. A 2x2 contingency table is:
 *
 * <p>\[ \left[ {\begin{array}{cc}
 *         a &amp; b \\
 *         c &amp; d \\
 *       \end{array} } \right] \]
 *
 * <p>This test applies to the case of a 2x2 contingency table with one margin fixed. Note that
 * if both margins are fixed (the row sums and column sums are not random)
 * then Fisher's exact test can be applied.
 *
 * <p>This implementation fixes the column sums \( m = a + c \) and \( n = b + d \).
 * All possible tables can be created using \( 0 \le a \le m \) and \( 0 \le b \le n \).
 * The random values \( a \) and \( b \) follow a binomial distribution with probabilities
 * \( p_0 \) and \( p_1 \) such that \( a \sim B(m, p_0) \) and \( b \sim B(n, p_1) \).
 * The p-value of the 2x2 table is the product of two binomials:
 *
 * <p>\[ \begin{aligned}
 *       p &amp;= Pr(a; m, p_0) \times Pr(b; n, p_1) \\
 *         &amp;= \binom{m}{a} p_0^a (1-p_0)^{m-a} \times \binom{n}{b} p_1^b (1-p_1)^{n-b} \end{aligned} \]
 *
 * <p>For the binomial model, the null hypothesis is that the two nuisance parameters are equal
 * \( p_0 = p_1 = \pi\), with \( \pi \) the probability for equal proportions, and the probability
 * of any single table is:
 *
 * <p>\[ p = \binom{m}{a} \binom{n}{b} \pi^{a+b} (1-\pi)^{m+n-a-b} \]
 *
 * <p>The p-value of the observed table is calculated by maximising the sum of the as or more
 * extreme tables over the domain of the nuisance parameter \( 0 \lt \pi \lt 1 \):
 *
 * <p>\[ p(a, b) = \sum_{i,j} \binom{m}{i} \binom{n}{j} \pi^{i+j} (1-\pi)^{m+n-i-j} \]
 *
 * <p>where table \( (i,j) \) is as or more extreme than the observed table \( (a, b) \). The test
 * can be configured to select more extreme tables using various {@linkplain Method methods}.
 *
 * <p>Note that the sum of the joint binomial distribution is a univariate function of
 * the nuisance parameter \( \pi \). This function may have many local maxima and the
 * search enumerates the range with a configured {@linkplain #withInitialPoints(int)
 * number of points}. The best candidates are optionally used as the start point for an
 * {@linkplain #withOptimize(boolean) optimized} search for a local maximum.
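 *
 * <p>Example usage (a minimal sketch; the {@code getPValue()} accessor is assumed to be
 * provided by the inherited significance result API):
 *
 * <pre>{@code
 * int[][] table = {{7, 12}, {8, 3}};
 * UnconditionedExactTest.Result result = UnconditionedExactTest.withDefaults()
 *     .with(AlternativeHypothesis.TWO_SIDED)
 *     .with(UnconditionedExactTest.Method.BOSCHLOO)
 *     .test(table);
 * double p = result.getPValue();
 * double statistic = result.getStatistic();
 * double pi = result.getNuisanceParameter();
 * }</pre>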
 *
 * <p>References:
 * <ol>
 * <li>
 * Barnard, G.A. (1947).
 * <a href="https://doi.org/10.1093/biomet/34.1-2.123">Significance tests for 2x2 tables.</a>
 * Biometrika, 34, Issue 1-2, 123&#8211;138.
 * <li>
 * Boschloo, R.D. (1970).
 * <a href="https://doi.org/10.1111/j.1467-9574.1970.tb00104.x">Raised conditional level of
 * significance for the 2 × 2-table when testing the equality of two probabilities.</a>
 * Statistica neerlandica, 24(1), 1&#8211;9.
 * <li>
 * Suissa, S. and Shuster, J.J. (1985).
 * <a href="https://doi.org/10.2307/2981892">Exact Unconditional Sample Sizes
 * for the 2 × 2 Binomial Trial.</a>
 * Journal of the Royal Statistical Society. Series A (General), 148(4), 317-327.
 * </ol>
 *
 * @see FisherExactTest
 * @see <a href="https://en.wikipedia.org/wiki/Boschloo%27s_test">Boschloo&#39;s test (Wikipedia)</a>
 * @see <a href="https://en.wikipedia.org/wiki/Barnard%27s_test">Barnard&#39;s test (Wikipedia)</a>
 * @since 1.1
 */
public final class UnconditionedExactTest {
    /**
     * Default instance.
     *
     * <p>SciPy's boschloo_exact and barnard_exact tests use 32 points in the interval
     * [0, 1). The R Exact package uses 100 in the interval [1e-5, 1-1e-5]. Barnard's 1947
     * paper describes the nuisance parameter in the open interval {@code 0 < pi < 1}.
     * Here we respect the open interval for the initial candidates and ignore 0 and 1.
     * The initial bounds used are the same as the R Exact package. We closely match the
     * inner 31 points from SciPy by using 33 points by default.
     */
    private static final UnconditionedExactTest DEFAULT = new UnconditionedExactTest(
        AlternativeHypothesis.TWO_SIDED, Method.BOSCHLOO, 33, true);
    /** Lower bound for the enumerated interval. The upper bound is {@code 1 - lower}. */
    private static final double LOWER_BOUND = 1e-5;
    /** Relative epsilon for the Brent solver. This is limited for a univariate function
     * to approximately sqrt(eps) with eps = 2^-52. */
    private static final double SOLVER_RELATIVE_EPS = 1.4901161193847656E-8;
    /** Fraction of the increment (interval between enumerated points) used to initialise the
     * bracket for the minimum. Note the minimum should lie between x +/- increment. The
     * bracket should search within this range. Set to 1/8 so the initial point of the bracket
     * is approximately 1.61 * 1/8 = 0.2 of the increment away from initial points a or b. */
    private static final double INC_FRACTION = 0.125;
    /** Maximum number of candidates to optimize. This is a safety limit to avoid excess
     * optimization. Only candidates within a relative tolerance of the best candidate are
     * stored. If the number of candidates exceeds this value then many candidates have a
     * very similar p-value and the top candidates will be optimized. Using a value of 3
     * allows at least one other candidate to be optimized when there is two-fold
     * symmetry in the energy function. */
    private static final int MAX_CANDIDATES = 3;
    /** Relative distance of candidate minima from the lowest candidate. Used to exclude
     * poor candidates from optimization. */
    private static final double MINIMA_EPS = 0.02;
    /** The maximum number of tables. This is limited by the maximum number of indices that
     * can be maintained in memory. Potentially up to this number of tables must be tracked
     * during computation of the p-value for as or more extreme tables. The limit is set
     * using the same limit for maximum capacity as java.util.ArrayList. In practice any
     * table anywhere near this limit can be computed using an alternative such as a
     * chi-squared or g test. */
    private static final int MAX_TABLES = Integer.MAX_VALUE - 8;
    /** Error message text for zero column sums. */
    private static final String COLUMN_SUM = "Column sum";

    /** Alternative hypothesis. */
    private final AlternativeHypothesis alternative;
    /** Method to identify more extreme tables. */
    private final Method method;
    /** Number of initial points. */
    private final int points;
    /** Option to optimize the best initial point(s). */
    private final boolean optimize;

    /**
     * Define the method to determine the more extreme tables.
     *
     * @since 1.1
     */
    public enum Method {
        /**
         * Uses the test statistic from a Z-test using a pooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}{\sqrt{\hat{p}(1 - \hat{p}) (\frac{1}{m} + \frac{1}{n})}} \]
         *
         * <p>where \( \hat{p}_0 = a / m \), \( \hat{p}_1 = b / n \), and
         * \( \hat{p} = (a+b) / (m+n) \) are the estimators of \( p_0 \), \( p_1 \) and the
         * pooled probability \( p \) assuming \( p_0 = p_1 \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis}:
         * <ul>
         * <li>greater: \( T(X) \ge T(X_0) \)
         * <li>less: \( T(X) \le T(X_0) \)
         * <li>two-sided: \( | T(X) | \ge | T(X_0) | \)
         * </ul>
         *
         * <p>The use of the Z statistic was suggested by Suissa and Shuster (1985).
         * This method is uniformly more powerful than Fisher's test for balanced designs
         * (\( m = n \)).
         */
        Z_POOLED,

        /**
         * Uses the test statistic from a Z-test using an unpooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}
         * {\sqrt{ \frac{\hat{p}_0(1 - \hat{p}_0)}{m} + \frac{\hat{p}_1(1 - \hat{p}_1)}{n}} } \]
         *
         * <p>where \( \hat{p}_0 = a / m \) and \( \hat{p}_1 = b / n \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis} as
         * per the {@link #Z_POOLED} method.
         */
        Z_UNPOOLED,

        /**
         * Uses the p-value from Fisher's exact test. This is also known as Boschloo's test.
         *
         * <p>The p-value for Fisher's test is computed using the
         * {@link AlternativeHypothesis}. The more extreme tables are identified using
         * \( p(X) \le p(X_0) \).
         *
         * <p>This method is always uniformly more powerful than Fisher's test.
         *
         * @see FisherExactTest
         */
        BOSCHLOO;
    }

    /**
     * Result for the unconditioned exact test.
     *
     * <p>This class is immutable.
     *
     * @since 1.1
     */
    public static final class Result extends BaseSignificanceResult {
        /** Nuisance parameter. */
        private final double pi;

        /**
         * Create an instance where all tables are more extreme, i.e. the p-value
         * is 1.0.
         *
         * @param statistic Test statistic.
         */
        Result(double statistic) {
            super(statistic, 1);
            this.pi = 0.5;
        }

        /**
         * @param statistic Test statistic.
         * @param pi Nuisance parameter.
         * @param p Result p-value.
         */
        Result(double statistic, double pi, double p) {
            super(statistic, p);
            this.pi = pi;
        }

        /**
         * {@inheritDoc}
         *
         * <p>The value of the statistic is dependent on the {@linkplain Method method}
         * used to determine the more extreme tables.
         */
        @Override
        public double getStatistic() {
            // Note: This method is here for documentation
            return super.getStatistic();
        }

        /**
         * Gets the nuisance parameter that maximised the probability sum of the as or more
         * extreme tables.
         *
         * @return the nuisance parameter.
         */
        public double getNuisanceParameter() {
            return pi;
        }
    }

    /**
     * An expandable list of (x,y) values. This allows tracking 2D positions stored as
     * a single index.
     */
    private static class XYList {
        /** The maximum size of array to allocate. */
        private final int max;
        /** Width, or maximum x value (exclusive). */
        private final int width;

        /** The size of the list. */
        private int size;
        /** The list data. */
        private int[] data = new int[10];

        /**
         * Create an instance. It is assumed that (maxx+1)*(maxy+1) does not exceed the
         * capacity of an array.
         *
         * @param maxx Maximum x-value (inclusive).
         * @param maxy Maximum y-value (inclusive).
         */
        XYList(int maxx, int maxy) {
            this.width = maxx + 1;
            this.max = width * (maxy + 1);
        }

        /**
         * Gets the width.
         * (x, y) values are stored using y * width + x.
         *
         * @return the width
         */
        int getWidth() {
            return width;
        }

        /**
         * Gets the maximum X value (inclusive).
         *
         * @return the max X
         */
        int getMaxX() {
            return width - 1;
        }

        /**
         * Gets the maximum Y value (inclusive).
         *
         * @return the max Y
         */
        int getMaxY() {
            return max / width - 1;
        }

        /**
         * Adds the value to the list.
         *
         * @param x X value.
         * @param y Y value.
         */
        void add(int x, int y) {
            if (size == data.length) {
                // Overflow safe doubling of the current size.
                data = Arrays.copyOf(data, (int) Math.min(max, size * 2L));
            }
            data[size++] = width * y + x;
        }

        /**
         * Gets the 2D index at the specified {@code index}.
         * The index is y * width + x:
         * <pre>
         * x = index % width
         * y = index / width
         * </pre>
         *
         * @param index Element index.
         * @return the 2D index
         */
        int get(int index) {
            return data[index];
        }

        /**
         * Gets the number of elements in the list.
         *
         * @return the size
         */
        int size() {
            return size;
        }

        /**
         * Checks if the list size is zero.
         *
         * @return true if empty
         */
        boolean isEmpty() {
            return size == 0;
        }
        /**
         * Checks if the list is at maximum capacity.
         *
         * @return true if full
         */
        boolean isFull() {
            return size == max;
        }
    }

    /**
     * A container of (key,value) pairs to store candidate minima. Encapsulates the
     * logic of storing multiple initial search points for optimization.
     *
     * <p>Stores all pairs within a relative tolerance of the lowest minimum up to a set
     * capacity. When at capacity the worst candidate is replaced by the addition of a
     * better candidate.
     *
     * <p>Special handling is provided to store only a single NaN value if no non-NaN
     * values have been observed. This prevents storing a large number of NaN
     * candidates.
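     *
     * <p>For example, with {@code max = 3} and {@code eps = 0.02}, adding the values
     * 10, 10.1, 10.3 and 9.9 in turn stores 10 (threshold {@code 10 + 0.2 = 10.2}),
     * stores 10.1, rejects 10.3, then stores 9.9 as the new minimum. The threshold
     * drops to {@code 9.9 + 0.198 = 10.098} and the 10.1 entry is evicted, leaving
     * candidates valued 10 and 9.9.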
     */
    static class Candidates {
        /** The maximum size of array to allocate. */
        private final int max;
        /** Relative distance from lowest candidate. */
        private final double eps;
        /** Candidate (key,value) pairs. */
        private double[][] data;
        /** Current size of the list. */
        private int size;
        /** Current minimum. */
        private double min = Double.POSITIVE_INFINITY;
        /** Current threshold for inclusion. */
        private double threshold = Double.POSITIVE_INFINITY;

        /**
         * Create an instance.
         *
         * @param max Maximum number of allowed candidates (limited to at least 1).
         * @param eps Relative distance of candidate minima from the lowest candidate
         * (assumed to be positive and finite).
         */
        Candidates(int max, double eps) {
            this.max = Math.max(1, max);
            this.eps = eps;
            // Create the initial storage
            data = new double[Math.min(this.max, 4)][];
        }

        /**
         * Adds the (key, value) pair.
         *
         * @param k Key.
         * @param v Value.
         */
        void add(double k, double v) {
            // Store only a single NaN
            if (Double.isNaN(v)) {
                if (size == 0) {
                    // No requirement to check capacity
                    data[size++] = new double[] {k, v};
                }
                return;
            }
            // Here values are non-NaN.
            // If higher then do not store.
            if (v > threshold) {
                return;
            }
            // Check if lower than the current minimum.
            if (v < min) {
                min = v;
                // Get new threshold
                threshold = v + Math.abs(v) * eps;
                // Remove existing entries above the threshold
                int s = 0;
                for (int i = 0; i < size; i++) {
                    // This will filter NaN values
                    if (data[i][1] <= threshold) {
                        data[s++] = data[i];
                    }
                }
                size = s;
                // Caution: This does not clear stale data
                // by setting all values in [newSize, oldSize) = null
            }
            addPair(k, v);
        }

        /**
         * Add the (key, value) pair to the data.
         * It is assumed the data satisfy the conditions for addition.
         *
         * @param k Key.
         * @param v Value.
         */
        private void addPair(double k, double v) {
            if (size == data.length) {
                if (size == max) {
                    // At capacity.
                    replaceWorst(k, v);
                    return;
                }
                // Expand
                data = Arrays.copyOfRange(data, 0, (int) Math.min(max, size * 2L));
            }
            data[size++] = new double[] {k, v};
        }

        /**
         * Replace the worst candidate.
         *
         * @param k Key.
         * @param v Value.
         */
        private void replaceWorst(double k, double v) {
            // Note: This only occurs when NaN values have been removed by addition
            // of non-NaN values.
            double[] worst = data[0];
            for (int i = 1; i < size; i++) {
                if (worst[1] < data[i][1]) {
                    worst = data[i];
                }
            }
            worst[0] = k;
            worst[1] = v;
        }

        /**
         * Return the minimum (key,value) pair.
         *
         * @return the minimum (or null)
         */
        double[] getMinimum() {
            // This will handle size=0 as data[0] will be null
            double[] best = data[0];
            for (int i = 1; i < size; i++) {
                if (best[1] > data[i][1]) {
                    best = data[i];
                }
            }
            return best;
        }

        /**
         * Perform the given action for each (key, value) pair.
         *
         * @param action Action.
         */
        void forEach(Consumer<double[]> action) {
            for (int i = 0; i < size; i++) {
                action.accept(data[i]);
            }
        }
    }

    /**
     * Compute the statistic for Boschloo's test.
     */
    @FunctionalInterface
    private interface BoschlooStatistic {
        /**
         * Compute Fisher's p-value for the 2x2 contingency table with the observed
         * value {@code x} in position [0][0]. Note that the table margins are fixed
         * and are defined by the population size, number of successes and sample
         * size of the specified hypergeometric distribution.
         *
         * @param dist Hypergeometric distribution.
         * @param x Value.
         * @return Fisher's p-value
         */
        double value(Hypergeom dist, int x);
    }

    /**
     * @param alternative Alternative hypothesis.
     * @param method Method to identify more extreme tables.
     * @param points Number of initial points.
     * @param optimize Option to optimize the best initial point(s).
     */
    private UnconditionedExactTest(AlternativeHypothesis alternative,
                                   Method method,
                                   int points,
                                   boolean optimize) {
        this.alternative = alternative;
        this.method = method;
        this.points = points;
        this.optimize = optimize;
    }

    /**
     * Return an instance using the default options.
     *
     * <ul>
     * <li>{@link AlternativeHypothesis#TWO_SIDED}
     * <li>{@link Method#BOSCHLOO}
     * <li>{@linkplain #withInitialPoints(int) points = 33}
     * <li>{@linkplain #withOptimize(boolean) optimize = true}
     * </ul>
     *
     * @return default instance
     */
    public static UnconditionedExactTest withDefaults() {
        return DEFAULT;
    }

    /**
     * Return an instance with the configured alternative hypothesis.
     *
     * @param v Value.
     * @return an instance
     */
    public UnconditionedExactTest with(AlternativeHypothesis v) {
        return new UnconditionedExactTest(Objects.requireNonNull(v), method, points, optimize);
    }

    /**
     * Return an instance with the configured method.
     *
     * @param v Value.
     * @return an instance
     */
    public UnconditionedExactTest with(Method v) {
        return new UnconditionedExactTest(alternative, Objects.requireNonNull(v), points, optimize);
    }

    /**
     * Return an instance with the configured number of initial points.
     *
     * <p>The search for the nuisance parameter will use \( v \) points in the open interval
     * \( (0, 1) \). The interval is evaluated by including start and end points approximately
     * equal to 0 and 1. Additional internal points are enumerated using increments of
     * approximately \( \frac{1}{v-1} \). The minimum number of points is 2. Increasing the
     * number of points increases the precision of the search at the cost of performance.
     *
     * <p>To approximately double the number of points, such that all existing points are
     * retained and a new point is sampled half-way between each pair, use {@code 2p - 1}
     * where {@code p} is the existing number of points.
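     * For example, refining from the default of 33 points uses 65 points, then 129 points.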
     *
     * @param v Value.
     * @return an instance
     * @throws IllegalArgumentException if the value is {@code < 2}.
     */
    public UnconditionedExactTest withInitialPoints(int v) {
        if (v <= 1) {
            throw new InferenceException(InferenceException.X_LT_Y, v, 2);
        }
        return new UnconditionedExactTest(alternative, method, v, optimize);
    }

    /**
     * Return an instance with the configured optimization of initial search points.
     *
     * <p>If enabled then the initial point(s) with the highest probability is/are used as the
     * start for an optimization to find a local maximum.
     *
     * @param v Value.
     * @return an instance
     * @see #withInitialPoints(int)
     */
    public UnconditionedExactTest withOptimize(boolean v) {
        return new UnconditionedExactTest(alternative, method, points, v);
    }

    /**
     * Compute the statistic for the unconditioned exact test. The statistic returned
     * depends on the configured {@linkplain Method method}.
     *
     * @param table 2-by-2 contingency table.
     * @return test statistic
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     * @see #with(Method)
     * @see #test(int[][])
     */
    public double statistic(int[][] table) {
        checkTable(table);
        final int a = table[0][0];
        final int b = table[0][1];
        final int c = table[1][0];
        final int d = table[1][1];
        final int m = a + c;
        final int n = b + d;
        // Exhaustive switch statement
        switch (method) {
        case Z_POOLED:
            return statisticZ(a, b, m, n, true);
        case Z_UNPOOLED:
            return statisticZ(a, b, m, n, false);
        case BOSCHLOO:
            return statisticBoschloo(a, b, m, n);
        }
        throw new IllegalStateException(String.valueOf(method));
    }

    /**
     * Performs an unconditioned exact test on the 2-by-2 contingency table. The statistic and
     * p-value returned depend on the configured {@linkplain Method method} and
     * {@linkplain AlternativeHypothesis alternative hypothesis}.
     *
     * <p>The search for the nuisance parameter that maximises the p-value can be configured to:
     * start with a number of {@linkplain #withInitialPoints(int) initial points}; and
     * {@linkplain #withOptimize(boolean) optimize} the best points.
     *
     * @param table 2-by-2 contingency table.
     * @return test result
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     * @see #with(Method)
     * @see #with(AlternativeHypothesis)
     * @see #statistic(int[][])
     */
    public Result test(int[][] table) {
        checkTable(table);
        final int a = table[0][0];
        final int b = table[0][1];
        final int c = table[1][0];
        final int d = table[1][1];
        final int m = a + c;
        final int n = b + d;

        // Used to track more extreme tables
        final XYList tableList = new XYList(m, n);

        final double statistic = findExtremeTables(a, b, tableList);
        if (tableList.isEmpty() || tableList.isFull()) {
            // All possible tables are more extreme, e.g. a two-sided test where the
            // z-statistic is zero.
            return new Result(statistic);
        }
        final double[] opt = computePValue(tableList);

        return new Result(statistic, opt[0], opt[1]);
    }

    /**
     * Find all tables that are as or more extreme than the observed table.
     *
     * <p>If the list of tables is full then all tables are more extreme.
     * Some configurations can detect this without performing a search
     * and in this case the list of tables is returned as empty.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param tableList List to track more extreme tables.
     * @return the test statistic
     */
    private double findExtremeTables(int a, int b, XYList tableList) {
        final int m = tableList.getMaxX();
        final int n = tableList.getMaxY();
        // Exhaustive switch statement
        switch (method) {
        case Z_POOLED:
            return findExtremeTablesZ(a, b, m, n, true, tableList);
        case Z_UNPOOLED:
            return findExtremeTablesZ(a, b, m, n, false, tableList);
        case BOSCHLOO:
            return findExtremeTablesBoschloo(a, b, m, n, tableList);
        }
        throw new IllegalStateException(String.valueOf(method));
    }

    /**
     * Compute the statistic from a Z-test.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param pooled true to use a pooled variance.
     * @return z
     */
    private static double statisticZ(int a, int b, int m, int n, boolean pooled) {
        final double p0 = (double) a / m;
        final double p1 = (double) b / n;
        // Avoid NaN generation from 0 / 0 when the variance is 0
        if (p0 != p1) {
            final double variance;
            if (pooled) {
                // Integer sums will not overflow
                final double p = (double) (a + b) / (m + n);
                variance = p * (1 - p) * (1.0 / m + 1.0 / n);
            } else {
                variance = p0 * (1 - p0) / m + p1 * (1 - p1) / n;
            }
            return (p0 - p1) / Math.sqrt(variance);
        }
        return 0;
    }

    /**
     * Find all tables that are as or more extreme than the observed table using the Z statistic.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param pooled true to use a pooled variance.
     * @param tableList List to track more extreme tables.
     * @return observed z
     */
    private double findExtremeTablesZ(int a, int b, int m, int n, boolean pooled, XYList tableList) {
        final double statistic = statisticZ(a, b, m, n, pooled);
        // Identify more extreme tables using the alternative hypothesis
        final DoublePredicate test;
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            test = z -> z >= statistic;
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            test = z -> z <= statistic;
        } else {
            // two-sided
            if (statistic == 0) {
                // Early exit: all tables are as extreme
                return 0;
            }
            final double za = Math.abs(statistic);
            test = z -> Math.abs(z) >= za;
        }
        // Precompute factors
        final double mn = (double) m + n;
        final double norm = 1.0 / m + 1.0 / n;
        double z;
        // Process all possible tables
        for (int i = 0; i <= m; i++) {
            final double p0 = (double) i / m;
            final double vp0 = p0 * (1 - p0) / m;
            for (int j = 0; j <= n; j++) {
                final double p1 = (double) j / n;
                // Avoid NaN generation from 0 / 0 when the variance is 0
                if (p0 == p1) {
                    z = 0;
                } else {
                    final double variance;
                    if (pooled) {
                        // Integer sums will not overflow
                        final double p = (i + j) / mn;
                        variance = p * (1 - p) * norm;
                    } else {
                        variance = vp0 + p1 * (1 - p1) / n;
                    }
                    z = (p0 - p1) / Math.sqrt(variance);
                }
                if (test.test(z)) {
                    tableList.add(i, j);
                }
            }
        }
        return statistic;
    }

    /**
     * Compute the statistic using Fisher's p-value (also known as Boschloo's test).
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @return p-value
     */
    private double statisticBoschloo(int a, int b, int m, int n) {
        final int nn = m + n;
        final int k = a + b;
        // Re-use the cached Hypergeometric implementation to allow the value
        // to be identical for the statistic and test methods.
        final Hypergeom dist = new Hypergeom(nn, k, m);
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            return dist.sf(a - 1);
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            return dist.cdf(a);
        }
        // two-sided: Find all i where Pr(X = i) <= Pr(X = a) and sum them.
        return statisticBoschlooTwoSided(dist, a);
    }

    /**
     * Compute the two-sided statistic using Fisher's p-value (also known as Boschloo's test).
     *
     * @param distribution Hypergeometric distribution.
     * @param k Observed value.
     * @return p-value
     */
    private static double statisticBoschlooTwoSided(Hypergeom distribution, int k) {
        // two-sided: Find all i where Pr(X = i) <= Pr(X = k) and sum them.
        // Logic is the same as FisherExactTest but using the probability (PMF), which
        // is cached, rather than the logProbability.
        final double pk = distribution.pmf(k);

        final int m1 = distribution.getLowerMode();
        final int m2 = distribution.getUpperMode();
        if (k < m1) {
            // Lower half = cdf(k)
            // Find upper half. As k < lower mode i should never
            // reach the lower mode based on the probability alone.
            // Bracket with the upper mode.
            final int i = Searches.searchDescending(m2, distribution.getSupportUpperBound(), pk,
                distribution::pmf);
            return distribution.cdf(k) +
                   distribution.sf(i - 1);
        } else if (k > m2) {
            // Upper half = sf(k - 1)
            // Find lower half. As k > upper mode i should never
            // reach the upper mode based on the probability alone.
            // Bracket with the lower mode.
            final int i = Searches.searchAscending(distribution.getSupportLowerBound(), m1, pk,
                distribution::pmf);
            return distribution.cdf(i) +
                   distribution.sf(k - 1);
        }
        // k == mode
        // Edge case where the sum of probabilities will be either
        // 1 or 1 - Pr(X = mode) where mode != k
        final double pm = distribution.pmf(k == m1 ? m2 : m1);
        return pm > pk ? 1 - pm : 1;
    }

    /**
     * Find all tables that are as or more extreme than the observed table using
     * Fisher's p-value as the statistic (also known as Boschloo's test).
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param tableList List to track more extreme tables.
     * @return observed p-value
     */
    private double findExtremeTablesBoschloo(int a, int b, int m, int n, XYList tableList) {
        final double statistic = statisticBoschloo(a, b, m, n);

        // Function to compute the statistic
        final BoschlooStatistic func;
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            func = (dist, x) -> dist.sf(x - 1);
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            func = Hypergeom::cdf;
        } else {
            func = UnconditionedExactTest::statisticBoschlooTwoSided;
        }

        // All tables are: 0 <= i <= m  by  0 <= j <= n
        // Diagonal (upper-left to lower-right) strips of the possible
        // tables use the same hypergeometric distribution
        // (i.e. i+j == number of successes). To enumerate all requires
        // using the full range of all distributions: 0 <= i+j <= m+n.
        // Note the column sum m is fixed.
        final int mn = m + n;
        for (int k = 0; k <= mn; k++) {
            final Hypergeom dist = new Hypergeom(mn, k, m);
            final int lo = dist.getSupportLowerBound();
            final int hi = dist.getSupportUpperBound();
            for (int i = lo; i <= hi; i++) {
                if (func.value(dist, i) <= statistic) {
                    // j = k - i
                    tableList.add(i, k - i);
                }
            }
        }
        return statistic;
    }

    /**
     * Compute the nuisance parameter and p-value for the binomial model given the list
     * of possible tables.
     *
     * <p>The current method enumerates an initial set of points and stores local
     * extrema as candidates. Any candidate within 2% of the best is optionally
     * optimized; this is limited to the top 3 candidates. These settings
     * could be exposed as configurable options. Currently only the choice to optimize
     * or not is exposed.
     *
     * @param tableList List of tables.
     * @return [nuisance parameter, p-value]
     */
    private double[] computePValue(XYList tableList) {
        final DoubleUnaryOperator func = createBinomialModel(tableList);

        // Enumerate the range [LOWER, 1-LOWER] and save the best points for optimization
        final Candidates minima = new Candidates(MAX_CANDIDATES, MINIMA_EPS);
        final int n = points - 1;
        final double inc = (1.0 - 2 * LOWER_BOUND) / n;
        // Moving window of 3 values to identify minima.
        // px holds the position of the previous evaluated point.
        double v2 = 0;
        double v3 = func.applyAsDouble(LOWER_BOUND);
        double px = LOWER_BOUND;
        for (int i = 1; i < n; i++) {
            final double x = LOWER_BOUND + i * inc;
            final double v1 = v2;
            v2 = v3;
            v3 = func.applyAsDouble(x);
            addCandidate(minima, v1, v2, v3, px);
            px = x;
        }
        // Add the upper bound
        final double x = 1 - LOWER_BOUND;
        final double vn = func.applyAsDouble(x);
        addCandidate(minima, v2, v3, vn, px);
        addCandidate(minima, v3, vn, 0, x);

        final double[] min = minima.getMinimum();

        // Optionally optimize the best point(s) (if not already optimal)
        if (optimize && min[1] > -1) {
            final BrentOptimizer opt = new BrentOptimizer(SOLVER_RELATIVE_EPS, Double.MIN_VALUE);
            final BracketFinder bf = new BracketFinder();
            minima.forEach(candidate -> {
                double a = candidate[0];
                final double fa;
                // Attempt to bracket the minimum. Use an initial second point placed relative to
                // the size of the interval: [x - increment, x + increment].
                // If a < 0.5 then add a small delta; otherwise subtract the delta.
                final double b = a - Math.copySign(inc * INC_FRACTION, a - 0.5);
                if (bf.search(func, a, b, 0, 1)) {
                    // The bracket a < b < c must have f(b) < min(f(a), f(c))
                    final PointValuePair p = opt.optimize(func, bf.getLo(), bf.getHi(), bf.getMid(), bf.getFMid());
                    a = p.getPoint();
                    fa = p.getValue();
                } else {
                    // Mid-point is at one of the bounds (i.e. is 0 or 1)
                    a = bf.getMid();
                    fa = bf.getFMid();
                }
                if (fa < min[1]) {
                    min[0] = a;
                    min[1] = fa;
                }
            });
        }
        // Reverse the sign of the p-value to create a maximum.
        // Note that due to the summation the p-value can be above 1 so we clip the final result.
        // Note: Apply max then reverse sign. This will pass through spurious NaN values if
        // the p-value computation produced all NaNs.
        min[1] = -Math.max(-1, min[1]);
        return min;
    }

    /**
     * Creates the binomial model p-value function for the nuisance parameter.
     * Note: This function computes the negative p-value so is suitable for
     * optimization by a search for a minimum.
     *
     * @param tableList List of tables.
     * @return the function
     */
    private static DoubleUnaryOperator createBinomialModel(XYList tableList) {
        final int m = tableList.getMaxX();
        final int n = tableList.getMaxY();
        final int mn = m + n;
        // Compute the probability using logs
        final double[] c = new double[tableList.size()];
        final int[] ij = new int[tableList.size()];
        final int width = tableList.getWidth();

        // Compute the log binomial dynamically for a small number of values
        final IntToDoubleFunction binomM;
        final IntToDoubleFunction binomN;
        if (tableList.size() < mn) {
            binomM = k -> LogBinomialCoefficient.value(m, k);
            binomN = k -> LogBinomialCoefficient.value(n, k);
        } else {
            // Pre-compute all values
            binomM = createLogBinomialCoefficients(m);
            binomN = m == n ? binomM : createLogBinomialCoefficients(n);
        }

        // Handle special cases i+j == 0 and i+j == m+n.
        // These will occur only once, if at all. Mark if they occur.
        int flag = 0;
        int j = 0;
        for (int i = 0; i < c.length; i++) {
            final int index = tableList.get(i);
            final int x = index % width;
            final int y = index / width;
            final int xy = x + y;
            if (xy == 0) {
                flag |= 1;
            } else if (xy == mn) {
                flag |= 2;
            } else {
                ij[j] = xy;
                c[j] = binomM.applyAsDouble(x) + binomN.applyAsDouble(y);
                j++;
            }
        }

        final int size = j;
        final boolean ij0 = (flag & 1) != 0;
        final boolean ijmn = (flag & 2) != 0;
        return pi -> {
            final double logp = Math.log(pi);
            final double log1mp = Math.log1p(-pi);
            double sum = 0;
            for (int i = 0; i < size; i++) {
                // binom(m, i) * binom(n, j) * pi^(i+j) * (1-pi)^(m+n-i-j)
                sum += Math.exp(ij[i] * logp + (mn - ij[i]) * log1mp + c[i]);
            }
            // Add the simplified terms where the binomial is 1.0 and one power is x^0 == 1.0.
            // This avoids 0 * log(x) generating NaN when x is 0 in the case where pi was 0 or 1.
            // Reuse exp (not pow) to support pi approaching 0 or 1.
            if (ij0) {
                // pow(1-pi, mn)
                sum += Math.exp(mn * log1mp);
            }
            if (ijmn) {
                // pow(pi, mn)
                sum += Math.exp(mn * logp);
            }
            // The optimizer minimises the function so this returns -p.
            return -sum;
        };
    }

    /**
     * Create the natural logarithm of the binomial coefficient for all {@code k = [0, n]}.
     *
     * @param n Limit N.
     * @return ln binom(n, k)
     */
    private static IntToDoubleFunction createLogBinomialCoefficients(int n) {
        final double[] binom = new double[n + 1];
        // Exploit symmetry.
        // ignore: binom(n, 0) == binom(n, n) == 1
        int j = n - 1;
        for (int i = 1; i <= j; i++, j--) {
            binom[i] = binom[j] = LogBinomialCoefficient.value(n, i);
        }
        return k -> binom[k];
    }

    /**
     * Add point 2 to the list of minima if neither neighbour value is lower.
     * <pre>
     * !(v1 < v2 || v3 < v2)
     * </pre>
     *
     * @param minima Candidate minima.
     * @param v1 First point function value.
     * @param v2 Second point function value.
     * @param v3 Third point function value.
     * @param x2 Second point.
     */
    private void addCandidate(Candidates minima, double v1, double v2, double v3, double x2) {
        final double min = v1 < v3 ? v1 : v3;
        if (min < v2) {
            // Lower neighbour(s)
            return;
        }
        // Add the candidate. This could be NaN but the candidate list handles this by storing
        // NaN only when no non-NaN values have been observed.
        minima.add(x2, v2);
    }

    /**
     * Check the input is a 2-by-2 contingency table.
     *
     * @param table Contingency table.
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     */
    private static void checkTable(int[][] table) {
        Arguments.checkTable(table);
        // checkTable has validated all entries are non-negative
        final int a = table[0][0];
        final int c = table[1][0];
        // checkTable has validated the total sum is < 2^31
        final int m = a + c;
        if (m == 0) {
            throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 0);
        }
        final int b = table[0][1];
        final int d = table[1][1];
        final int n = b + d;
        if (n == 0) {
            throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 1);
        }
        // Total possible tables must be a size we can track in an array (to compute the p-value)
        final long size = (m + 1L) * (n + 1L);
        if (size > MAX_TABLES) {
            throw new InferenceException(InferenceException.X_GT_Y, size, MAX_TABLES);
        }
    }
}