View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.inference;
18  
19  import java.util.Arrays;
20  import java.util.Objects;
21  import java.util.function.Consumer;
22  import java.util.function.DoublePredicate;
23  import java.util.function.DoubleUnaryOperator;
24  import java.util.function.IntToDoubleFunction;
25  import org.apache.commons.numbers.combinatorics.LogBinomialCoefficient;
26  import org.apache.commons.statistics.inference.BrentOptimizer.PointValuePair;
27  
28  /**
29   * Implements an unconditioned exact test for a contingency table.
30   *
31   * <p>Performs an exact test for the statistical significance of the association (contingency)
32   * between two kinds of categorical classification. A 2x2 contingency table is:
33   *
34   * <p>\[ \left[ {\begin{array}{cc}
35   *         a &amp; b \\
36   *         c &amp; d \\
37   *       \end{array} } \right] \]
38   *
39   * <p>This test applies to the case of a 2x2 contingency table with one margin fixed. Note that
40   * if both margins are fixed (the row sums and column sums are not random)
41   * then Fisher's exact test can be applied.
42   *
43   * <p>This implementation fixes the column sums \( m = a + c \) and \( n = b + d \).
44   * All possible tables can be created using \( 0 \le a \le m \) and \( 0 \le b \le n \).
45   * The random values \( a \) and \( b \) follow a binomial distribution with probabilities
46   * \( p_0 \) and \( p_1 \) such that \( a \sim B(m, p_0) \) and \( b \sim B(n, p_1) \).
47   * The p-value of the 2x2 table is the product of two binomials:
48   *
49   * <p>\[ \begin{aligned}
50   *       p &amp;= Pr(a; m, p_0) \times Pr(b; n, p_1) \\
51   *         &amp;= \binom{m}{a} p_0^a (1-p_0)^{m-a} \times \binom{n}{b} p_1^b (1-p_1)^{n-b} \end{aligned} \]
52   *
53   * <p>For the binomial model, the null hypothesis is that the two nuisance parameters are equal
54   * \( p_0 = p_1 = \pi\), with \( \pi \) the probability for equal proportions, and the probability
55   * of any single table is:
56   *
57   * <p>\[ p = \binom{m}{a} \binom{n}{b} \pi^{a+b} (1-\pi)^{m+n-a-b} \]
58   *
59   * <p>The p-value of the observed table is calculated by maximising the sum of the as or more
60   * extreme tables over the domain of the nuisance parameter \( 0 \lt \pi \lt 1 \):
61   *
62   * <p>\[ p(a, b) = \sum_{i,j} \binom{m}{i} \binom{n}{j} \pi^{i+j} (1-\pi)^{m+n-i-j} \]
63   *
64   * <p>where table \( (i,j) \) is as or more extreme than the observed table \( (a, b) \). The test
65   * can be configured to select more extreme tables using various {@linkplain Method methods}.
66   *
67   * <p>Note that the sum of the joint binomial distribution is a univariate function for
68   * the nuisance parameter \( \pi \). This function may have many local maxima and the
69   * search enumerates the range with a configured {@linkplain #withInitialPoints(int)
70   * number of points}. The best candidates are optionally used as the start point for an
71   * {@linkplain #withOptimize(boolean) optimized} search for a local maxima.
72   *
73   * <p>References:
74   * <ol>
75   * <li>
76   * Barnard, G.A. (1947).
77   * <a href="https://doi.org/10.1093/biomet/34.1-2.123">Significance tests for 2x2 tables.</a>
78   * Biometrika, 34, Issue 1-2, 123–138.
79   * <li>
80   * Boschloo, R.D. (1970).
81   * <a href="https://doi.org/10.1111/j.1467-9574.1970.tb00104.x">Raised conditional level of
82   * significance for the 2 × 2-table when testing the equality of two probabilities.</a>
83   * Statistica neerlandica, 24(1), 1–9.
84   * <li>
85   * Suissa, A. and Shuster, J.J. (1985).
86   * <a href="https://doi.org/10.2307/2981892">Exact Unconditional Sample Sizes
87   * for the 2 × 2 Binomial Trial.</a>
88   * Journal of the Royal Statistical Society. Series A (General), 148(4), 317-327.
89   * </ol>
90   *
91   * @see FisherExactTest
92   * @see <a href="https://en.wikipedia.org/wiki/Boschloo%27s_test">Boschloo&#39;s test (Wikipedia)</a>
93   * @see <a href="https://en.wikipedia.org/wiki/Barnard%27s_test">Barnard&#39;s test (Wikipedia)</a>
94   * @since 1.1
95   */
96  public final class UnconditionedExactTest {
97      /**
98       * Default instance.
99       *
100      * <p>SciPy's boschloo_exact and barnard_exact tests use 32 points in the interval [0,
101      * 1). The R Exact package uses 100 in the interval [1e-5, 1-1e-5]. Barnard's 1947 paper
102      * describes the nuisance parameter in the open interval {@code 0 < pi < 1}. Here we
103      * respect the open-interval for the initial candidates and ignore 0 and 1. The
104      * initial bounds used are the same as the R Exact package. We closely match the inner
105      * 31 points from SciPy by using 33 points by default.
106      */
107     private static final UnconditionedExactTest DEFAULT = new UnconditionedExactTest(
108         AlternativeHypothesis.TWO_SIDED, Method.BOSCHLOO, 33, true);
109     /** Lower bound for the enumerated interval. The upper bound is {@code 1 - lower}. */
110     private static final double LOWER_BOUND = 1e-5;
111     /** Relative epsilon for the Brent solver. This is limited for a univariate function
112      * to approximately sqrt(eps) with eps = 2^-52. */
113     private static final double SOLVER_RELATIVE_EPS = 1.4901161193847656E-8;
114     /** Fraction of the increment (interval between enumerated points) to initialise the bracket
115      * for the minima. Note the minima should lie between x +/- increment. The bracket should
116      * search within this range. Set to 1/8 and so the initial point of the bracket is
117      * approximately 1.61 * 1/8 = 0.2 of the increment away from initial points a or b. */
118     private static final double INC_FRACTION = 0.125;
119     /** Maximum number of candidates to optimize. This is a safety limit to avoid excess
120      * optimization. Only candidates within a relative tolerance of the best candidate are
121      * stored. If the number of candidates exceeds this value then many candidates have a
122      * very similar p-value and the top candidates will be optimized. Using a value of 3
123      * allows at least one other candidate to be optimized when there is two-fold
124      * symmetry in the energy function. */
125     private static final int MAX_CANDIDATES = 3;
126     /** Relative distance of candidate minima from the lowest candidate. Used to exclude
127      * poor candidates from optimization. */
128     private static final double MINIMA_EPS = 0.02;
129     /** The maximum number of tables. This is limited by the maximum number of indices that
130      * can be maintained in memory. Potentially up to this number of tables must be tracked
131      * during computation of the p-value for as or more extreme tables. The limit is set
132      * using the same limit for maximum capacity as java.util.ArrayList. In practice any
133      * table anywhere near this limit can be computed using an alternative such as a chi-squared
134      * or g test. */
135     private static final int MAX_TABLES = Integer.MAX_VALUE - 8;
136     /** Error message text for zero column sums. */
137     private static final String COLUMN_SUM = "Column sum";
138 
139     /** Alternative hypothesis. */
140     private final AlternativeHypothesis alternative;
141     /** Method to identify more extreme tables. */
142     private final Method method;
143     /** Number of initial points. */
144     private final int points;
145     /** Option to optimize the best initial point(s). */
146     private final boolean optimize;
147 
    /**
     * Define the method to determine the more extreme tables.
     *
     * @since 1.1
     */
    public enum Method {
        /**
         * Uses the test statistic from a Z-test using a pooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}{\sqrt{\hat{p}(1 - \hat{p}) (\frac{1}{m} + \frac{1}{n})}} \]
         *
         * <p>where \( \hat{p}_0 = a / m \), \( \hat{p}_1 = b / n \), and
         * \( \hat{p} = (a+b) / (m+n) \) are the estimators of \( p_0 \), \( p_1 \) and the
         * pooled probability \( p \) assuming \( p_0 = p_1 \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis}:
         * <ul>
         * <li>greater: \( T(X) \ge T(X_0) \)
         * <li>less: \( T(X) \le T(X_0) \)
         * <li>two-sided: \( | T(X) | \ge | T(X_0) | \)
         * </ul>
         *
         * <p>The use of the Z statistic was suggested by Suissa and Shuster (1985).
         * This method is uniformly more powerful than Fisher's test for balanced designs
         * (\( m = n \)).
         */
        Z_POOLED,

        /**
         * Uses the test statistic from a Z-test using an unpooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}
         * {\sqrt{ \frac{\hat{p}_0(1 - \hat{p}_0)}{m} + \frac{\hat{p}_1(1 - \hat{p}_1)}{n}} } \]
         *
         * <p>where \( \hat{p}_0 = a / m \) and \( \hat{p}_1 = b / n \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis} as
         * per the {@link #Z_POOLED} method.
         */
        Z_UNPOOLED,

        /**
         * Uses the p-value from Fisher's exact test. This is also known as Boschloo's test.
         *
         * <p>The p-value for Fisher's test is computed using the
         * {@link AlternativeHypothesis}. The more extreme tables are identified using
         * \( p(X) \le p(X_0) \).
         *
         * <p>This method is always uniformly more powerful than Fisher's test.
         *
         * @see FisherExactTest
         */
        BOSCHLOO;
    }
202 
203     /**
204      * Result for the unconditioned exact test.
205      *
206      * <p>This class is immutable.
207      *
208      * @since 1.1
209      */
210     public static final class Result extends BaseSignificanceResult {
211         /** Nuisance parameter. */
212         private final double pi;
213 
214         /**
215          * Create an instance where all tables are more extreme, i.e. the p-value
216          * is 1.0.
217          *
218          * @param statistic Test statistic.
219          */
220         Result(double statistic) {
221             super(statistic, 1);
222             this.pi = 0.5;
223         }
224 
225         /**
226          * @param statistic Test statistic.
227          * @param pi Nuisance parameter.
228          * @param p Result p-value.
229          */
230         Result(double statistic, double pi, double p) {
231             super(statistic, p);
232             this.pi = pi;
233         }
234 
235         /**
236          * {@inheritDoc}
237          *
238          * <p>The value of the statistic is dependent on the {@linkplain Method method}
239          * used to determine the more extreme tables.
240          */
241         @Override
242         public double getStatistic() {
243             // Note: This method is here for documentation
244             return super.getStatistic();
245         }
246 
247         /**
248          * Gets the nuisance parameter that maximised the probability sum of the as or more
249          * extreme tables.
250          *
251          * @return the nuisance parameter.
252          */
253         public double getNuisanceParameter() {
254             return pi;
255         }
256     }
257 
258     /**
259      * An expandable list of (x,y) values. This allows tracking 2D positions stored as
260      * a single index.
261      */
262     private static class XYList {
263         /** The maximum size of array to allocate. */
264         private final int max;
265         /** Width, or maximum x value (exclusive). */
266         private final int width;
267 
268         /** The size of the list. */
269         private int size;
270         /** The list data. */
271         private int[] data = new int[10];
272 
273         /**
274          * Create an instance. It is assumed that (maxx+1)*(maxy+1) does not exceed the
275          * capacity of an array.
276          *
277          * @param maxx Maximum x-value (inclusive).
278          * @param maxy Maximum y-value (inclusive).
279          */
280         XYList(int maxx, int maxy) {
281             this.width = maxx + 1;
282             this.max = width * (maxy + 1);
283         }
284 
285         /**
286          * Gets the width.
287          * (x, y) values are stored using y * width + x.
288          *
289          * @return the width
290          */
291         int getWidth() {
292             return width;
293         }
294 
295         /**
296          * Gets the maximum X value (inclusive).
297          *
298          * @return the max X
299          */
300         int getMaxX() {
301             return width - 1;
302         }
303 
304         /**
305          * Gets the maximum Y value (inclusive).
306          *
307          * @return the max Y
308          */
309         int getMaxY() {
310             return max / width - 1;
311         }
312 
313         /**
314          * Adds the value to the list.
315          *
316          * @param x X value.
317          * @param y Y value.
318          */
319         void add(int x, int y) {
320             if (size == data.length) {
321                 // Overflow safe doubling of the current size.
322                 data = Arrays.copyOf(data, (int) Math.min(max, size * 2L));
323             }
324             data[size++] = width * y + x;
325         }
326 
327         /**
328          * Gets the 2D index at the specified {@code index}.
329          * The index is y * width + x:
330          * <pre>
331          * x = index % width
332          * y = index / width
333          * </pre>
334          *
335          * @param index Element index.
336          * @return the 2D index
337          */
338         int get(int index) {
339             return data[index];
340         }
341 
342         /**
343          * Gets the number of elements in the list.
344          *
345          * @return the size
346          */
347         int size() {
348             return size;
349         }
350 
351         /**
352          * Checks if the list size is zero.
353          *
354          * @return true if empty
355          */
356         boolean isEmpty() {
357             return size == 0;
358         }
359 
360         /**
361          * Checks if the list is the maximum capacity.
362          *
363          * @return true if full
364          */
365         boolean isFull() {
366             return size == max;
367         }
368     }
369 
    /**
     * A container of (key,value) pairs to store candidate minima. Encapsulates the
     * logic of storing multiple initial search points for optimization.
     *
     * <p>Stores all pairs within a relative tolerance of the lowest minima up to a set
     * capacity. When at capacity the worst candidate is replaced by addition of a
     * better candidate.
     *
     * <p>Special handling is provided to store only a single NaN value if no non-NaN
     * values have been observed. This prevents storing a large number of NaN
     * candidates.
     */
    static class Candidates {
        /** The maximum size of array to allocate. */
        private final int max;
        /** Relative distance from lowest candidate. */
        private final double eps;
        /** Candidate (key,value) pairs. */
        private double[][] data;
        /** Current size of the list. */
        private int size;
        /** Current minimum. */
        private double min = Double.POSITIVE_INFINITY;
        /** Current threshold for inclusion: {@code min + |min| * eps}. */
        private double threshold = Double.POSITIVE_INFINITY;

        /**
         * Create an instance.
         *
         * @param max Maximum number of allowed candidates (limited to at least 1).
         * @param eps Relative distance of candidate minima from the lowest candidate
         * (assumed to be positive and finite).
         */
        Candidates(int max, double eps) {
            this.max = Math.max(1, max);
            this.eps = eps;
            // Create the initial storage; grown on demand up to the maximum capacity
            data = new double[Math.min(this.max, 4)][];
        }

        /**
         * Adds the (key, value) pair if it lies within the inclusion threshold of the
         * current minimum. A new minimum updates the threshold and evicts existing
         * entries that fall outside it.
         *
         * @param k Key.
         * @param v Value.
         */
        void add(double k, double v) {
            // Store only a single NaN
            if (Double.isNaN(v)) {
                if (size == 0) {
                    // No requirement to check capacity
                    data[size++] = new double[] {k, v};
                }
                return;
            }
            // Here values are non-NaN.
            // If higher then do not store.
            if (v > threshold) {
                return;
            }
            // Check if lower than the current minima.
            if (v < min) {
                min = v;
                // Get new threshold
                threshold = v + Math.abs(v) * eps;
                // Remove existing entries above the threshold (compact in place)
                int s = 0;
                for (int i = 0; i < size; i++) {
                    // This will filter NaN values: (NaN <= threshold) is false
                    if (data[i][1] <= threshold) {
                        data[s++] = data[i];
                    }
                }
                size = s;
                // Caution: This does not clear stale data
                // by setting all values in [newSize, oldSize) = null
            }
            addPair(k, v);
        }

        /**
         * Add the (key, value) pair to the data.
         * It is assumed the data satisfy the conditions for addition.
         *
         * @param k Key.
         * @param v Value.
         */
        private void addPair(double k, double v) {
            if (size == data.length) {
                if (size == max) {
                    // At capacity.
                    replaceWorst(k, v);
                    return;
                }
                // Expand: overflow-safe doubling capped at the maximum capacity
                data = Arrays.copyOfRange(data, 0, (int) Math.min(max, size * 2L));
            }
            data[size++] = new double[] {k, v};
        }

        /**
         * Replace the worst candidate (the entry with the highest value).
         *
         * @param k Key.
         * @param v Value.
         */
        private void replaceWorst(double k, double v) {
            // Note: This only occurs when NaN values have been removed by addition
            // of non-NaN values.
            double[] worst = data[0];
            for (int i = 1; i < size; i++) {
                if (worst[1] < data[i][1]) {
                    worst = data[i];
                }
            }
            // Reuse the existing pair array to avoid an allocation
            worst[0] = k;
            worst[1] = v;
        }

        /**
         * Return the minimum (key,value) pair.
         *
         * @return the minimum (or null when empty)
         */
        double[] getMinimum() {
            // This will handle size=0 as data[0] will be null
            double[] best = data[0];
            for (int i = 1; i < size; i++) {
                if (best[1] > data[i][1]) {
                    best = data[i];
                }
            }
            return best;
        }

        /**
         * Perform the given action for each (key, value) pair.
         *
         * @param action Action.
         */
        void forEach(Consumer<double[]> action) {
            for (int i = 0; i < size; i++) {
                action.accept(data[i]);
            }
        }
    }
516 
    /**
     * Compute the statistic for Boschloo's test, i.e. Fisher's exact test p-value
     * used to rank tables by extremeness.
     */
    @FunctionalInterface
    private interface BoschlooStatistic {
        /**
         * Compute Fisher's p-value for the 2x2 contingency table with the observed
         * value {@code x} in position [0][0]. Note that the table margins are fixed
         * and are defined by the population size, number of successes and sample
         * size of the specified hypergeometric distribution.
         *
         * @param dist Hypergeometric distribution.
         * @param x Value.
         * @return Fisher's p-value
         */
        double value(Hypergeom dist, int x);
    }
534 
535     /**
536      * @param alternative Alternative hypothesis.
537      * @param method Method to identify more extreme tables.
538      * @param points Number of initial points.
539      * @param optimize Option to optimize the best initial point(s).
540      */
541     private UnconditionedExactTest(AlternativeHypothesis alternative,
542                                    Method method,
543                                    int points,
544                                    boolean optimize) {
545         this.alternative = alternative;
546         this.method = method;
547         this.points = points;
548         this.optimize = optimize;
549     }
550 
    /**
     * Return an instance using the default options.
     *
     * <ul>
     * <li>{@link AlternativeHypothesis#TWO_SIDED}
     * <li>{@link Method#BOSCHLOO}
     * <li>{@linkplain #withInitialPoints(int) points = 33}
     * <li>{@linkplain #withOptimize(boolean) optimize = true}
     * </ul>
     *
     * @return default instance
     */
    public static UnconditionedExactTest withDefaults() {
        // Instances are immutable so the shared default can be returned directly
        return DEFAULT;
    }
566 
567     /**
568      * Return an instance with the configured alternative hypothesis.
569      *
570      * @param v Value.
571      * @return an instance
572      */
573     public UnconditionedExactTest with(AlternativeHypothesis v) {
574         return new UnconditionedExactTest(Objects.requireNonNull(v), method, points, optimize);
575     }
576 
577     /**
578      * Return an instance with the configured method.
579      *
580      * @param v Value.
581      * @return an instance
582      */
583     public UnconditionedExactTest with(Method v) {
584         return new UnconditionedExactTest(alternative, Objects.requireNonNull(v), points, optimize);
585     }
586 
587     /**
588      * Return an instance with the configured number of initial points.
589      *
590      * <p>The search for the nuisance parameter will use \( v \) points in the open interval
591      * \( (0, 1) \). The interval is evaluated by including start and end points approximately
592      * equal to 0 and 1. Additional internal points are enumerated using increments of
593      * approximately \( \frac{1}{v-1} \). The minimum number of points is 2. Increasing the
594      * number of points increases the precision of the search at the cost of performance.
595      *
596      * <p>To approximately double the number of points so that all existing points are included
597      * and additional points half-way between them are sampled requires using {@code 2p - 1}
598      * where {@code p} is the existing number of points.
599      *
600      * @param v Value.
601      * @return an instance
602      * @throws IllegalArgumentException if the value is {@code < 2}.
603      */
604     public UnconditionedExactTest withInitialPoints(int v) {
605         if (v <= 1) {
606             throw new InferenceException(InferenceException.X_LT_Y, v, 2);
607         }
608         return new UnconditionedExactTest(alternative, method, v, optimize);
609     }
610 
    /**
     * Return an instance with the configured optimization of initial search points.
     *
     * <p>If enabled then the initial point(s) with the highest probability is/are used as the start
     * for an optimization to find a local maxima.
     *
     * @param v Value.
     * @return an instance
     * @see #withInitialPoints(int)
     */
    public UnconditionedExactTest withOptimize(boolean v) {
        // No validation required for a boolean flag
        return new UnconditionedExactTest(alternative, method, points, v);
    }
624 
625     /**
626      * Compute the statistic for the unconditioned exact test. The statistic returned
627      * depends on the configured {@linkplain Method method}.
628      *
629      * @param table 2-by-2 contingency table.
630      * @return test statistic
631      * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
632      * table entry is negative; any column sum is zero; the table sum is zero or not an
633      * integer; or the number of possible tables exceeds the maximum array capacity.
634      * @see #with(Method)
635      * @see #test(int[][])
636      */
637     public double statistic(int[][] table) {
638         checkTable(table);
639         final int a = table[0][0];
640         final int b = table[0][1];
641         final int c = table[1][0];
642         final int d = table[1][1];
643         final int m = a + c;
644         final int n = b + d;
645         // Exhaustive switch statement
646         switch (method) {
647         case Z_POOLED:
648             return statisticZ(a, b, m, n, true);
649         case Z_UNPOOLED:
650             return statisticZ(a, b, m, n, false);
651         case BOSCHLOO:
652             return statisticBoschloo(a, b, m, n);
653         }
654         throw new IllegalStateException(String.valueOf(method));
655     }
656 
657     /**
658      * Performs an unconditioned exact test on the 2-by-2 contingency table. The statistic and
659      * p-value returned depends on the configured {@linkplain Method method} and
660      * {@linkplain AlternativeHypothesis alternative hypothesis}.
661      *
662      * <p>The search for the nuisance parameter that maximises the p-value can be configured to:
663      * start with a number of {@linkplain #withInitialPoints(int) initial points}; and
664      * {@linkplain #withOptimize(boolean) optimize} the best points.
665      *
666      * @param table 2-by-2 contingency table.
667      * @return test result
668      * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
669      * table entry is negative; any column sum is zero; the table sum is zero or not an
670      * integer; or the number of possible tables exceeds the maximum array capacity.
671      * @see #with(Method)
672      * @see #with(AlternativeHypothesis)
673      * @see #statistic(int[][])
674      */
675     public Result test(int[][] table) {
676         checkTable(table);
677         final int a = table[0][0];
678         final int b = table[0][1];
679         final int c = table[1][0];
680         final int d = table[1][1];
681         final int m = a + c;
682         final int n = b + d;
683 
684         // Used to track more extreme tables
685         final XYList tableList = new XYList(m, n);
686 
687         final double statistic = findExtremeTables(a, b, tableList);
688         if (tableList.isEmpty() || tableList.isFull()) {
689             // All possible tables are more extreme, e.g. a two-sided test where the
690             // z-statistic is zero.
691             return new Result(statistic);
692         }
693         final double[] opt = computePValue(tableList);
694 
695         return new Result(statistic, opt[0], opt[1]);
696     }
697 
698     /**
699      * Find all tables that are as or more extreme than the observed table.
700      *
701      * <p>If the list of tables is full then all tables are more extreme.
702      * Some configurations can detect this without performing a search
703      * and in this case the list of tables is returned as empty.
704      *
705      * @param a Observed value for a.
706      * @param b Observed value for b.
707      * @param tableList List to track more extreme tables.
708      * @return the test statistic
709      */
710     private double findExtremeTables(int a, int b, XYList tableList) {
711         final int m = tableList.getMaxX();
712         final int n = tableList.getMaxY();
713         // Exhaustive switch statement
714         switch (method) {
715         case Z_POOLED:
716             return findExtremeTablesZ(a, b, m, n, true, tableList);
717         case Z_UNPOOLED:
718             return findExtremeTablesZ(a, b, m, n, false, tableList);
719         case BOSCHLOO:
720             return findExtremeTablesBoschloo(a, b, m, n, tableList);
721         }
722         throw new IllegalStateException(String.valueOf(method));
723     }
724 
725     /**
726      * Compute the statistic from a Z-test.
727      *
728      * @param a Observed value for a.
729      * @param b Observed value for b.
730      * @param m Column sum m.
731      * @param n Column sum n.
732      * @param pooled true to use a pooled variance.
733      * @return z
734      */
735     private static double statisticZ(int a, int b, int m, int n, boolean pooled) {
736         final double p0 = (double) a / m;
737         final double p1 = (double) b / n;
738         // Avoid NaN generation 0 / 0 when the variance is 0
739         if (p0 != p1) {
740             final double variance;
741             if (pooled) {
742                 // Integer sums will not overflow
743                 final double p = (double) (a + b) / (m + n);
744                 variance = p * (1 - p) * (1.0 / m + 1.0 / n);
745             } else {
746                 variance = p0 * (1 - p0) / m + p1 * (1 - p1) / n;
747             }
748             return (p0 - p1) / Math.sqrt(variance);
749         }
750         return 0;
751     }
752 
    /**
     * Find all tables that are as or more extreme than the observed table using the Z statistic.
     *
     * <p>Enumerates every table (i, j) with {@code 0 <= i <= m} and {@code 0 <= j <= n},
     * adding those whose z-statistic satisfies the alternative hypothesis to the list.
     * For the two-sided case with an observed statistic of zero the list is left empty
     * as a short-cut signal that all tables are as extreme.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param pooled true to use a pooled variance.
     * @param tableList List to track more extreme tables.
     * @return observed z
     */
    private double findExtremeTablesZ(int a, int b, int m, int n, boolean pooled, XYList tableList) {
        final double statistic = statisticZ(a, b, m, n, pooled);
        // Identify more extreme tables using the alternate hypothesis
        final DoublePredicate test;
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            test = z -> z >= statistic;
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            test = z -> z <= statistic;
        } else {
            // two-sided
            if (statistic == 0) {
                // Early exit: all tables are as extreme.
                // The empty list signals this to the caller.
                return 0;
            }
            final double za = Math.abs(statistic);
            test = z -> Math.abs(z) >= za;
        }
        // Precompute loop-invariant factors
        final double mn = (double) m + n;
        final double norm = 1.0 / m + 1.0 / n;
        double z;
        // Process all possible tables
        for (int i = 0; i <= m; i++) {
            final double p0 = (double) i / m;
            // Unpooled variance contribution of column m (invariant in the inner loop)
            final double vp0 = p0 * (1 - p0) / m;
            for (int j = 0; j <= n; j++) {
                final double p1 = (double) j / n;
                // Avoid NaN generation 0 / 0 when the variance is 0
                if (p0 == p1) {
                    z = 0;
                } else {
                    final double variance;
                    if (pooled) {
                        // Integer sums will not overflow
                        final double p = (i + j) / mn;
                        variance = p * (1 - p) * norm;
                    } else {
                        variance = vp0 + p1 * (1 - p1) / n;
                    }
                    z = (p0 - p1) / Math.sqrt(variance);
                }
                if (test.test(z)) {
                    tableList.add(i, j);
                }
            }
        }
        return statistic;
    }
812 
813     /**
814      * Compute the statistic using Fisher's p-value (also known as Boschloo's test).
815      *
816      * @param a Observed value for a.
817      * @param b Observed value for b.
818      * @param m Column sum m.
819      * @param n Column sum n.
820      * @return p-value
821      */
822     private double statisticBoschloo(int a, int b, int m, int n) {
823         final int nn = m + n;
824         final int k = a + b;
825         // Re-use the cached Hypergeometric implementation to allow the value
826         // to be identical for the statistic and test methods.
827         final Hypergeom dist = new Hypergeom(nn, k, m);
828         if (alternative == AlternativeHypothesis.GREATER_THAN) {
829             return dist.sf(a - 1);
830         } else if (alternative == AlternativeHypothesis.LESS_THAN) {
831             return dist.cdf(a);
832         }
833         // two-sided: Find all i where Pr(X = i) <= Pr(X = a) and sum them.
834         return statisticBoschlooTwoSided(dist, a);
835     }
836 
837     /**
838      * Compute the two-sided statistic using Fisher's p-value (also known as Boschloo's test).
839      *
840      * @param distribution Hypergeometric distribution.
841      * @param k Observed value.
842      * @return p-value
843      */
    private static double statisticBoschlooTwoSided(Hypergeom distribution, int k) {
        // two-sided: Find all i where Pr(X = i) <= Pr(X = k) and sum them.
        // Logic is the same as FisherExactTest but using the probability (PMF), which
        // is cached, rather than the logProbability.
        final double pk = distribution.pmf(k);

        // The hypergeometric PMF is unimodal (the comments below rely on this):
        // it rises to the mode(s) and falls after, so each tail whose total is
        // required can be located with a single monotonic search from a mode.
        final int m1 = distribution.getLowerMode();
        final int m2 = distribution.getUpperMode();
        if (k < m1) {
            // Lower half = cdf(k)
            // Find upper half. As k < lower mode i should never
            // reach the lower mode based on the probability alone.
            // Bracket with the upper mode.
            final int i = Searches.searchDescending(m2, distribution.getSupportUpperBound(), pk,
                distribution::pmf);
            // Sum of the two tails: Pr(X <= k) + Pr(X >= i)
            return distribution.cdf(k) +
                   distribution.sf(i - 1);
        } else if (k > m2) {
            // Upper half = sf(k - 1)
            // Find lower half. As k > upper mode i should never
            // reach the upper mode based on the probability alone.
            // Bracket with the lower mode.
            final int i = Searches.searchAscending(distribution.getSupportLowerBound(), m1, pk,
                distribution::pmf);
            // Sum of the two tails: Pr(X <= i) + Pr(X >= k)
            return distribution.cdf(i) +
                   distribution.sf(k - 1);
        }
        // k == mode
        // Edge case where the sum of probabilities will be either
        // 1 or 1 - Pr(X = mode) where mode != k
        final double pm = distribution.pmf(k == m1 ? m2 : m1);
        return pm > pk ? 1 - pm : 1;
    }
877 
878     /**
879      * Find all tables that are as or more extreme than the observed table using the
880      * Fisher's p-value as the statistic (also known as Boschloo's test).
881      *
882      * @param a Observed value for a.
883      * @param b Observed value for b.
884      * @param m Column sum m.
885      * @param n Column sum n.
886      * @param tableList List to track more extreme tables.
887      * @return observed p-value
888      */
889     private double findExtremeTablesBoschloo(int a, int b, int m, int n, XYList tableList) {
890         final double statistic = statisticBoschloo(a, b, m, n);
891 
892         // Function to compute the statistic
893         final BoschlooStatistic func;
894         if (alternative == AlternativeHypothesis.GREATER_THAN) {
895             func = (dist, x) -> dist.sf(x - 1);
896         } else if (alternative == AlternativeHypothesis.LESS_THAN) {
897             func = Hypergeom::cdf;
898         } else {
899             func = UnconditionedExactTest::statisticBoschlooTwoSided;
900         }
901 
902         // All tables are: 0 <= i <= m  by  0 <= j <= n
903         // Diagonal (upper-left to lower-right) strips of the possible
904         // tables use the same hypergeometric distribution
905         // (i.e. i+j == number of successes). To enumerate all requires
906         // using the full range of all distributions: 0 <= i+j <= m+n.
907         // Note the column sum m is fixed.
908         final int mn = m + n;
909         for (int k = 0; k <= mn; k++) {
910             final Hypergeom dist = new Hypergeom(mn, k, m);
911             final int lo = dist.getSupportLowerBound();
912             final int hi = dist.getSupportUpperBound();
913             for (int i = lo; i <= hi; i++) {
914                 if (func.value(dist, i) <= statistic) {
915                     // j = k - i
916                     tableList.add(i, k - i);
917                 }
918             }
919         }
920         return statistic;
921     }
922 
923     /**
924      * Compute the nuisance parameter and p-value for the binomial model given the list
925      * of possible tables.
926      *
927      * <p>The current method enumerates an initial set of points and stores local
928      * extrema as candidates. Any candidate within 2% of the best is optionally
929      * optimized; this is limited to the top 3 candidates. These settings
930      * could be exposed as configurable options. Currently only the choice to optimize
931      * or not is exposed.
932      *
933      * @param tableList List of tables.
934      * @return [nuisance parameter, p-value]
935      */
    private double[] computePValue(XYList tableList) {
        // The model returns the NEGATED p-value for a given nuisance parameter,
        // so maximising the p-value is a search for a minimum.
        final DoubleUnaryOperator func = createBinomialModel(tableList);

        // Enumerate the range [LOWER, 1-LOWER] and save the best points for optimization
        final Candidates minima = new Candidates(MAX_CANDIDATES, MINIMA_EPS);
        final int n = points - 1;
        final double inc = (1.0 - 2 * LOWER_BOUND) / n;
        // Moving window of 3 values to identify minima.
        // px holds the position of the previous evaluated point.
        // v2 starts at 0; the model only returns values <= 0 so the phantom left
        // neighbour cannot wrongly reject the lower bound as a candidate.
        double v2 = 0;
        double v3 = func.applyAsDouble(LOWER_BOUND);
        double px = LOWER_BOUND;
        for (int i = 1; i < n; i++) {
            final double x = LOWER_BOUND + i * inc;
            final double v1 = v2;
            v2 = v3;
            v3 = func.applyAsDouble(x);
            // The candidate is the previous point px whose value is v2
            addCandidate(minima, v1, v2, v3, px);
            px = x;
        }
        // Add the upper bound
        final double x = 1 - LOWER_BOUND;
        final double vn = func.applyAsDouble(x);
        addCandidate(minima, v2, v3, vn, px);
        // The end point has no right neighbour; use 0 which is never lower
        addCandidate(minima, v3, vn, 0, x);

        final double[] min = minima.getMinimum();

        // Optionally optimize the best point(s) (if not already optimal)
        if (optimize && min[1] > -1) {
            final BrentOptimizer opt = new BrentOptimizer(SOLVER_RELATIVE_EPS, Double.MIN_VALUE);
            final BracketFinder bf = new BracketFinder();
            minima.forEach(candidate -> {
                double a = candidate[0];
                final double fa;
                // Attempt to bracket the minima. Use an initial second point placed relative to
                // the size of the interval: [x - increment, x + increment].
                // if a < 0.5 then add a small delta ; otherwise subtract the delta.
                final double b = a - Math.copySign(inc * INC_FRACTION, a - 0.5);
                if (bf.search(func, a, b, 0, 1)) {
                    // The bracket a < b < c must have f(b) < min(f(a), f(c))
                    final PointValuePair p = opt.optimize(func, bf.getLo(), bf.getHi(), bf.getMid(), bf.getFMid());
                    a = p.getPoint();
                    fa = p.getValue();
                } else {
                    // Mid-point is at one of the bounds (i.e. is 0 or 1)
                    a = bf.getMid();
                    fa = bf.getFMid();
                }
                // Keep the lowest (i.e. the largest p-value after negation)
                if (fa < min[1]) {
                    min[0] = a;
                    min[1] = fa;
                }
            });
        }
        // Reverse the sign of the p-value to create a maximum.
        // Note that due to the summation the p-value can be above 1 so we clip the final result.
        // Note: Apply max then reverse sign. This will pass through spurious NaN values if
        // the p-value computation produced all NaNs.
        min[1] = -Math.max(-1, min[1]);
        return min;
    }
998 
999     /**
1000      * Creates the binomial model p-value function for the nuisance parameter.
1001      * Note: This function computes the negative p-value so is suitable for
1002      * optimization by a search for a minimum.
1003      *
1004      * @param tableList List of tables.
1005      * @return the function
1006      */
    private static DoubleUnaryOperator createBinomialModel(XYList tableList) {
        // Bounds of the table counts: 0 <= x <= m, 0 <= y <= n
        final int m = tableList.getMaxX();
        final int n = tableList.getMaxY();
        final int mn = m + n;
        // Compute the probability using logs.
        // For each table store i+j (in ij) and ln(binom(m, i) * binom(n, j)) (in c).
        final double[] c = new double[tableList.size()];
        final int[] ij = new int[tableList.size()];
        final int width = tableList.getWidth();

        // Compute the log binomial dynamically for a small number of values
        final IntToDoubleFunction binomM;
        final IntToDoubleFunction binomN;
        if (tableList.size() < mn) {
            binomM = k -> LogBinomialCoefficient.value(m, k);
            binomN = k -> LogBinomialCoefficient.value(n, k);
        } else {
            // Pre-compute all values
            binomM = createLogBinomialCoefficients(m);
            binomN = m == n ? binomM : createLogBinomialCoefficients(n);
        }

        // Handle special cases i+j == 0 and i+j == m+n.
        // These will occur only once, if at all. Mark if they occur.
        int flag = 0;
        int j = 0;
        for (int i = 0; i < c.length; i++) {
            // The list packs a table as index = y * width + x
            final int index = tableList.get(i);
            final int x = index % width;
            final int y = index / width;
            final int xy = x + y;
            if (xy == 0) {
                flag |= 1;
            } else if (xy == mn) {
                flag |= 2;
            } else {
                ij[j] = xy;
                c[j] = binomM.applyAsDouble(x) + binomN.applyAsDouble(y);
                j++;
            }
        }

        // j is now the count of general (non special-case) terms
        final int size = j;
        final boolean ij0 = (flag & 1) != 0;
        final boolean ijmn = (flag & 2) != 0;
        return pi -> {
            final double logp = Math.log(pi);
            final double log1mp = Math.log1p(-pi);
            double sum = 0;
            for (int i = 0; i < size; i++) {
                // binom(m, i) * binom(n, j) * pi^(i+j) * (1-pi)^(m+n-i-j)
                sum += Math.exp(ij[i] * logp + (mn - ij[i]) * log1mp + c[i]);
            }
            // Add the simplified terms where the binomial is 1.0 and one power is x^0 == 1.0.
            // This avoids 0 * log(x) generating NaN when x is 0 in the case where pi was 0 or 1.
            // Reuse exp (not pow) to support pi approaching 0 or 1.
            if (ij0) {
                // pow(1-pi, mn)
                sum += Math.exp(mn * log1mp);
            }
            if (ijmn) {
                // pow(pi, mn)
                sum += Math.exp(mn * logp);
            }
            // The optimizer minimises the function so this returns -p.
            return -sum;
        };
    }
1074 
1075     /**
1076      * Create the natural logarithm of the binomial coefficient for all {@code k = [0, n]}.
1077      *
1078      * @param n Limit N.
1079      * @return ln binom(n, k)
1080      */
1081     private static IntToDoubleFunction createLogBinomialCoefficients(int n) {
1082         final double[] binom = new double[n + 1];
1083         // Exploit symmetry.
1084         // ignore: binom(n, 0) == binom(n, n) == 1
1085         int j = n - 1;
1086         for (int i = 1; i <= j; i++, j--) {
1087             binom[i] = binom[j] = LogBinomialCoefficient.value(n, i);
1088         }
1089         return k -> binom[k];
1090     }
1091 
1092     /**
1093      * Add point 2 to the list of minima if neither neighbour value is lower.
1094      * <pre>
1095      * !(v1 < v2 || v3 < v2)
1096      * </pre>
1097      *
1098      * @param minima Candidate minima.
1099      * @param v1 First point function value.
1100      * @param v2 Second point function value.
1101      * @param v3 Third point function value.
1102      * @param x2 Second point.
1103      */
1104     private void addCandidate(Candidates minima, double v1, double v2, double v3, double x2) {
1105         final double min = v1 < v3 ? v1 : v3;
1106         if (min < v2) {
1107             // Lower neighbour(s)
1108             return;
1109         }
1110         // Add the candidate. This could be NaN but the candidate list handles this by storing
1111         // NaN only when no non-NaN values have been observed.
1112         minima.add(x2, v2);
1113     }
1114 
1115     /**
1116      * Check the input is a 2-by-2 contingency table.
1117      *
1118      * @param table Contingency table.
1119      * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
1120      * table entry is negative; any column sum is zero; the table sum is zero or not an
1121      * integer; or the number of possible tables exceeds the maximum array capacity.
1122      */
1123     private static void checkTable(int[][] table) {
1124         Arguments.checkTable(table);
1125         // Must all be positive
1126         final int a = table[0][0];
1127         final int c = table[1][0];
1128         // checkTable has validated the total sum is < 2^31
1129         final int m = a + c;
1130         if (m == 0) {
1131             throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 0);
1132         }
1133         final int b = table[0][1];
1134         final int d = table[1][1];
1135         final int n = b + d;
1136         if (n == 0) {
1137             throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 1);
1138         }
1139         // Total possible tables must be a size we can track in an array (to compute the p-value)
1140         final long size = (m + 1L) * (n + 1L);
1141         if (size > MAX_TABLES) {
1142             throw new InferenceException(InferenceException.X_GT_Y, size, MAX_TABLES);
1143         }
1144     }
1145 }