/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.statistics.inference;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.DoublePredicate;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntToDoubleFunction;
import org.apache.commons.numbers.combinatorics.LogBinomialCoefficient;
import org.apache.commons.statistics.inference.BrentOptimizer.PointValuePair;

/**
 * Implements an unconditioned exact test for a contingency table.
 *
 * <p>Performs an exact test for the statistical significance of the association (contingency)
 * between two kinds of categorical classification. A 2x2 contingency table is:
 *
 * <p>\[ \left[ {\begin{array}{cc}
 * a & b \\
 * c & d \\
 * \end{array} } \right] \]
 *
 * <p>This test applies to the case of a 2x2 contingency table with one margin fixed. Note that
 * if both margins are fixed (the row sums and column sums are not random)
 * then Fisher's exact test can be applied.
 *
 * <p>This implementation fixes the column sums \( m = a + c \) and \( n = b + d \).
 * All possible tables can be created using \( 0 \le a \le m \) and \( 0 \le b \le n \).
 * The random values \( a \) and \( b \) follow a binomial distribution with probabilities
 * \( p_0 \) and \( p_1 \) such that \( a \sim B(m, p_0) \) and \( b \sim B(n, p_1) \).
 * The probability of the 2x2 table is the product of two binomials:
 *
 * <p>\[ \begin{aligned}
 * p &= Pr(a; m, p_0) \times Pr(b; n, p_1) \\
 * &= \binom{m}{a} p_0^a (1-p_0)^{m-a} \times \binom{n}{b} p_1^b (1-p_1)^{n-b} \end{aligned} \]
 *
 * <p>For the binomial model, the null hypothesis is that the two probabilities are equal,
 * \( p_0 = p_1 = \pi \), with \( \pi \) the nuisance parameter for the common proportion,
 * and the probability of any single table is:
 *
 * <p>\[ p = \binom{m}{a} \binom{n}{b} \pi^{a+b} (1-\pi)^{m+n-a-b} \]
 *
 * <p>The p-value of the observed table is calculated by maximising the sum of the
 * probabilities of the as or more extreme tables over the domain of the nuisance
 * parameter \( 0 \lt \pi \lt 1 \):
 *
 * <p>\[ p(a, b) = \sum_{i,j} \binom{m}{i} \binom{n}{j} \pi^{i+j} (1-\pi)^{m+n-i-j} \]
 *
 * <p>where table \( (i,j) \) is as or more extreme than the observed table \( (a, b) \). The test
 * can be configured to select more extreme tables using various {@linkplain Method methods}.
 *
 * <p>Note that the sum of the joint binomial distribution is a univariate function of
 * the nuisance parameter \( \pi \). This function may have many local maxima and the
 * search enumerates the range with a configured {@linkplain #withInitialPoints(int)
 * number of points}. The best candidates are optionally used as the start points for an
 * {@linkplain #withOptimize(boolean) optimized} search for a local maximum.
 *
 * <p>References:
 * <ol>
 * <li>
 * Barnard, G.A. (1947).
 * <a href="https://doi.org/10.1093/biomet/34.1-2.123">Significance tests for 2x2 tables.</a>
 * Biometrika, 34, Issue 1-2, 123–138.
 * <li>
 * Boschloo, R.D. (1970).
 * <a href="https://doi.org/10.1111/j.1467-9574.1970.tb00104.x">Raised conditional level of
 * significance for the 2 × 2-table when testing the equality of two probabilities.</a>
 * Statistica Neerlandica, 24(1), 1–9.
 * <li>
 * Suissa, S. and Shuster, J.J. (1985).
 * <a href="https://doi.org/10.2307/2981892">Exact Unconditional Sample Sizes
 * for the 2 × 2 Binomial Trial.</a>
 * Journal of the Royal Statistical Society. Series A (General), 148(4), 317–327.
 * </ol>
 *
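 * <p>Example usage (a minimal sketch; the table counts are illustrative only):
 *
 * <pre>{@code
 * // Hypothetical 2x2 table with fixed column sums m = 14, n = 11
 * int[][] table = {{10, 3}, {4, 8}};
 * UnconditionedExactTest test = UnconditionedExactTest.withDefaults()
 *     .with(AlternativeHypothesis.GREATER_THAN)
 *     .with(UnconditionedExactTest.Method.Z_POOLED);
 * UnconditionedExactTest.Result result = test.test(table);
 * double p = result.getPValue();
 * }</pre>
 *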
 * @see FisherExactTest
 * @see <a href="https://en.wikipedia.org/wiki/Boschloo%27s_test">Boschloo's test (Wikipedia)</a>
 * @see <a href="https://en.wikipedia.org/wiki/Barnard%27s_test">Barnard's test (Wikipedia)</a>
 * @since 1.1
 */
public final class UnconditionedExactTest {
    /**
     * Default instance.
     *
     * <p>SciPy's boschloo_exact and barnard_exact tests use 32 points in the interval
     * [0, 1). The R Exact package uses 100 points in the interval [1e-5, 1-1e-5].
     * Barnard's 1947 paper describes the nuisance parameter in the open interval
     * {@code 0 < pi < 1}. Here we respect the open interval for the initial candidates
     * and ignore 0 and 1. The initial bounds used are the same as the R Exact package.
     * We closely match the inner 31 points from SciPy by using 33 points by default.
     */
    private static final UnconditionedExactTest DEFAULT = new UnconditionedExactTest(
        AlternativeHypothesis.TWO_SIDED, Method.BOSCHLOO, 33, true);
    /** Lower bound for the enumerated interval. The upper bound is {@code 1 - lower}. */
    private static final double LOWER_BOUND = 1e-5;
    /** Relative epsilon for the Brent solver. This is limited for a univariate function
     * to approximately sqrt(eps) with eps = 2^-52. */
    private static final double SOLVER_RELATIVE_EPS = 1.4901161193847656E-8;
    /** Fraction of the increment (interval between enumerated points) used to initialise the
     * bracket for the minimum. Note the minimum should lie within x +/- increment. The bracket
     * should search within this range. Set to 1/8 so the initial point of the bracket is
     * approximately 1.61 * 1/8 = 0.2 of the increment away from initial points a or b. */
    private static final double INC_FRACTION = 0.125;
    /** Maximum number of candidates to optimize. This is a safety limit to avoid excess
     * optimization. Only candidates within a relative tolerance of the best candidate are
     * stored. If the number of candidates exceeds this value then many candidates have a
     * very similar p-value and the top candidates will be optimized. Using a value of 3
     * allows at least one other candidate to be optimized when there is two-fold
     * symmetry in the energy function. */
    private static final int MAX_CANDIDATES = 3;
    /** Relative distance of candidate minima from the lowest candidate. Used to exclude
     * poor candidates from optimization. */
    private static final double MINIMA_EPS = 0.02;
    /** The maximum number of tables. This is limited by the maximum number of indices that
     * can be maintained in memory. Potentially up to this number of tables must be tracked
     * during computation of the p-value for the as or more extreme tables. The limit is set
     * using the same limit for maximum capacity as java.util.ArrayList. In practice any
     * table anywhere near this limit can be computed using an alternative such as a
     * chi-squared or G-test. */
    private static final int MAX_TABLES = Integer.MAX_VALUE - 8;
    /** Error message text for zero column sums. */
    private static final String COLUMN_SUM = "Column sum";

    /** Alternative hypothesis. */
    private final AlternativeHypothesis alternative;
    /** Method to identify more extreme tables. */
    private final Method method;
    /** Number of initial points. */
    private final int points;
    /** Option to optimize the best initial point(s). */
    private final boolean optimize;

    /**
     * Define the method to determine the more extreme tables.
     *
     * @since 1.1
     */
    public enum Method {
        /**
         * Uses the test statistic from a Z-test using a pooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}{\sqrt{\hat{p}(1 - \hat{p}) (\frac{1}{m} + \frac{1}{n})}} \]
         *
         * <p>where \( \hat{p}_0 = a / m \), \( \hat{p}_1 = b / n \), and
         * \( \hat{p} = (a+b) / (m+n) \) are the estimators of \( p_0 \), \( p_1 \) and the
         * pooled probability \( p \) assuming \( p_0 = p_1 \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis}:
         * <ul>
         * <li>greater: \( T(X) \ge T(X_0) \)
         * <li>less: \( T(X) \le T(X_0) \)
         * <li>two-sided: \( | T(X) | \ge | T(X_0) | \)
         * </ul>
         *
         * <p>The use of the Z statistic was suggested by Suissa and Shuster (1985).
         * This method is uniformly more powerful than Fisher's test for balanced designs
         * (\( m = n \)).
         */
        Z_POOLED,

        /**
         * Uses the test statistic from a Z-test using an unpooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}
         * {\sqrt{ \frac{\hat{p}_0(1 - \hat{p}_0)}{m} + \frac{\hat{p}_1(1 - \hat{p}_1)}{n}} } \]
         *
         * <p>where \( \hat{p}_0 = a / m \) and \( \hat{p}_1 = b / n \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis} as
         * per the {@link #Z_POOLED} method.
         */
        Z_UNPOOLED,

        /**
         * Uses the p-value from Fisher's exact test. This is also known as Boschloo's test.
         *
         * <p>The p-value for Fisher's test is computed using the
         * {@link AlternativeHypothesis}. The more extreme tables are identified using
         * \( p(X) \le p(X_0) \).
         *
         * <p>This method is always uniformly more powerful than Fisher's test.
         *
         * @see FisherExactTest
         */
        BOSCHLOO;
    }

    /**
     * Result for the unconditioned exact test.
     *
     * <p>This class is immutable.
     *
     * @since 1.1
     */
    public static final class Result extends BaseSignificanceResult {
        /** Nuisance parameter. */
        private final double pi;

        /**
         * Create an instance where all tables are more extreme, i.e. the p-value
         * is 1.0.
         *
         * @param statistic Test statistic.
         */
        Result(double statistic) {
            super(statistic, 1);
            this.pi = 0.5;
        }

        /**
         * @param statistic Test statistic.
         * @param pi Nuisance parameter.
         * @param p Result p-value.
         */
        Result(double statistic, double pi, double p) {
            super(statistic, p);
            this.pi = pi;
        }

        /**
         * {@inheritDoc}
         *
         * <p>The value of the statistic is dependent on the {@linkplain Method method}
         * used to determine the more extreme tables.
         */
        @Override
        public double getStatistic() {
            // Note: This method is here for documentation
            return super.getStatistic();
        }

        /**
         * Gets the nuisance parameter that maximised the probability sum of the as or more
         * extreme tables.
         *
         * @return the nuisance parameter.
         */
        public double getNuisanceParameter() {
            return pi;
        }
    }

    /**
     * An expandable list of (x,y) values. This allows tracking 2D positions stored as
     * a single index.
     */
    private static class XYList {
        /** The maximum size of array to allocate. */
        private final int max;
        /** Width, or maximum x value (exclusive). */
        private final int width;

        /** The size of the list. */
        private int size;
        /** The list data. */
        private int[] data = new int[10];

        /**
         * Create an instance. It is assumed that (maxx+1)*(maxy+1) does not exceed the
         * capacity of an array.
         *
         * @param maxx Maximum x-value (inclusive).
         * @param maxy Maximum y-value (inclusive).
         */
        XYList(int maxx, int maxy) {
            this.width = maxx + 1;
            this.max = width * (maxy + 1);
        }

        /**
         * Gets the width.
         * (x, y) values are stored using y * width + x.
         *
         * @return the width
         */
        int getWidth() {
            return width;
        }

        /**
         * Gets the maximum X value (inclusive).
         *
         * @return the max X
         */
        int getMaxX() {
            return width - 1;
        }

        /**
         * Gets the maximum Y value (inclusive).
         *
         * @return the max Y
         */
        int getMaxY() {
            return max / width - 1;
        }

        /**
         * Adds the value to the list.
         *
         * @param x X value.
         * @param y Y value.
         */
        void add(int x, int y) {
            if (size == data.length) {
                // Overflow safe doubling of the current size.
                data = Arrays.copyOf(data, (int) Math.min(max, size * 2L));
            }
            data[size++] = width * y + x;
        }

        /**
         * Gets the 2D index at the specified {@code index}.
         * The index is y * width + x:
         * <pre>
         * x = index % width
         * y = index / width
         * </pre>
         *
         * @param index Element index.
         * @return the 2D index
         */
        int get(int index) {
            return data[index];
        }

        /**
         * Gets the number of elements in the list.
         *
         * @return the size
         */
        int size() {
            return size;
        }

        /**
         * Checks if the list size is zero.
         *
         * @return true if empty
         */
        boolean isEmpty() {
            return size == 0;
        }

        /**
         * Checks if the list is at maximum capacity.
         *
         * @return true if full
         */
        boolean isFull() {
            return size == max;
        }
    }

    /**
     * A container of (key,value) pairs to store candidate minima. Encapsulates the
     * logic of storing multiple initial search points for optimization.
     *
     * <p>Stores all pairs within a relative tolerance of the lowest minimum up to a set
     * capacity. When at capacity the worst candidate is replaced by the addition of a
     * better candidate.
     *
     * <p>Special handling is provided to store only a single NaN value if no non-NaN
     * values have been observed. This prevents storing a large number of NaN
     * candidates.
     */
    static class Candidates {
        /** The maximum size of array to allocate. */
        private final int max;
        /** Relative distance from lowest candidate. */
        private final double eps;
        /** Candidate (key,value) pairs. */
        private double[][] data;
        /** Current size of the list. */
        private int size;
        /** Current minimum. */
        private double min = Double.POSITIVE_INFINITY;
        /** Current threshold for inclusion. */
        private double threshold = Double.POSITIVE_INFINITY;

        /**
         * Create an instance.
         *
         * @param max Maximum number of allowed candidates (limited to at least 1).
         * @param eps Relative distance of candidate minima from the lowest candidate
         * (assumed to be positive and finite).
         */
        Candidates(int max, double eps) {
            this.max = Math.max(1, max);
            this.eps = eps;
            // Create the initial storage
            data = new double[Math.min(this.max, 4)][];
        }

        /**
         * Adds the (key, value) pair.
         *
         * @param k Key.
         * @param v Value.
         */
        void add(double k, double v) {
            // Store only a single NaN
            if (Double.isNaN(v)) {
                if (size == 0) {
                    // No requirement to check capacity
                    data[size++] = new double[] {k, v};
                }
                return;
            }
            // Here values are non-NaN.
            // If higher then do not store.
            if (v > threshold) {
                return;
            }
            // Check if lower than the current minimum.
            if (v < min) {
                min = v;
                // Get new threshold
                threshold = v + Math.abs(v) * eps;
                // Remove existing entries above the threshold
                int s = 0;
                for (int i = 0; i < size; i++) {
                    // This will filter NaN values
                    if (data[i][1] <= threshold) {
                        data[s++] = data[i];
                    }
                }
                size = s;
                // Caution: This does not clear stale data
                // by setting all values in [newSize, oldSize) = null
            }
            addPair(k, v);
        }

        /**
         * Add the (key, value) pair to the data.
         * It is assumed the data satisfy the conditions for addition.
         *
         * @param k Key.
         * @param v Value.
         */
        private void addPair(double k, double v) {
            if (size == data.length) {
                if (size == max) {
                    // At capacity.
                    replaceWorst(k, v);
                    return;
                }
                // Expand
                data = Arrays.copyOfRange(data, 0, (int) Math.min(max, size * 2L));
            }
            data[size++] = new double[] {k, v};
        }

        /**
         * Replace the worst candidate.
         *
         * @param k Key.
         * @param v Value.
         */
        private void replaceWorst(double k, double v) {
            // Note: This only occurs when NaN values have been removed by addition
            // of non-NaN values.
            double[] worst = data[0];
            for (int i = 1; i < size; i++) {
                if (worst[1] < data[i][1]) {
                    worst = data[i];
                }
            }
            worst[0] = k;
            worst[1] = v;
        }

        /**
         * Return the minimum (key,value) pair.
         *
         * @return the minimum (or null)
         */
        double[] getMinimum() {
            // This will handle size=0 as data[0] will be null
            double[] best = data[0];
            for (int i = 1; i < size; i++) {
                if (best[1] > data[i][1]) {
                    best = data[i];
                }
            }
            return best;
        }

        /**
         * Perform the given action for each (key, value) pair.
         *
         * @param action Action.
         */
        void forEach(Consumer<double[]> action) {
            for (int i = 0; i < size; i++) {
                action.accept(data[i]);
            }
        }
    }

    /**
     * Compute the statistic for Boschloo's test.
     */
    @FunctionalInterface
    private interface BoschlooStatistic {
        /**
         * Compute Fisher's p-value for the 2x2 contingency table with the observed
         * value {@code x} in position [0][0]. Note that the table margins are fixed
         * and are defined by the population size, number of successes and sample
         * size of the specified hypergeometric distribution.
         *
         * @param dist Hypergeometric distribution.
         * @param x Value.
         * @return Fisher's p-value
         */
        double value(Hypergeom dist, int x);
    }

    /**
     * @param alternative Alternative hypothesis.
     * @param method Method to identify more extreme tables.
     * @param points Number of initial points.
     * @param optimize Option to optimize the best initial point(s).
     */
    private UnconditionedExactTest(AlternativeHypothesis alternative,
                                   Method method,
                                   int points,
                                   boolean optimize) {
        this.alternative = alternative;
        this.method = method;
        this.points = points;
        this.optimize = optimize;
    }

    /**
     * Return an instance using the default options.
     *
     * <ul>
     * <li>{@link AlternativeHypothesis#TWO_SIDED}
     * <li>{@link Method#BOSCHLOO}
     * <li>{@linkplain #withInitialPoints(int) points = 33}
     * <li>{@linkplain #withOptimize(boolean) optimize = true}
     * </ul>
     *
     * @return default instance
     */
    public static UnconditionedExactTest withDefaults() {
        return DEFAULT;
    }

    /**
     * Return an instance with the configured alternative hypothesis.
     *
     * @param v Value.
     * @return an instance
     */
    public UnconditionedExactTest with(AlternativeHypothesis v) {
        return new UnconditionedExactTest(Objects.requireNonNull(v), method, points, optimize);
    }

    /**
     * Return an instance with the configured method.
     *
     * @param v Value.
     * @return an instance
     */
    public UnconditionedExactTest with(Method v) {
        return new UnconditionedExactTest(alternative, Objects.requireNonNull(v), points, optimize);
    }

    /**
     * Return an instance with the configured number of initial points.
     *
     * <p>The search for the nuisance parameter will use \( v \) points in the open interval
     * \( (0, 1) \). The interval is evaluated by including start and end points approximately
     * equal to 0 and 1. Additional internal points are enumerated using increments of
     * approximately \( \frac{1}{v-1} \). The minimum number of points is 2. Increasing the
     * number of points increases the precision of the search at the cost of performance.
     *
     * <p>To approximately double the number of points so that all existing points are included
     * and additional points half-way between them are sampled requires using {@code 2p - 1}
     * where {@code p} is the existing number of points.
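     *
     * <p>For example, a minimal sketch that doubles the resolution of the default
     * instance (which uses 33 points):
     *
     * <pre>{@code
     * UnconditionedExactTest test = UnconditionedExactTest.withDefaults();  // 33 points
     * // Re-evaluate all 33 points plus the midpoints between them
     * UnconditionedExactTest finer = test.withInitialPoints(2 * 33 - 1);    // 65 points
     * }</pre>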
     *
     * @param v Value.
     * @return an instance
     * @throws IllegalArgumentException if the value is {@code < 2}.
     */
    public UnconditionedExactTest withInitialPoints(int v) {
        if (v <= 1) {
            throw new InferenceException(InferenceException.X_LT_Y, v, 2);
        }
        return new UnconditionedExactTest(alternative, method, v, optimize);
    }

    /**
     * Return an instance with the configured optimization of initial search points.
     *
     * <p>If enabled then the initial point(s) with the highest probability is/are used as the
     * start for an optimization to find a local maximum.
     *
     * @param v Value.
     * @return an instance
     * @see #withInitialPoints(int)
     */
    public UnconditionedExactTest withOptimize(boolean v) {
        return new UnconditionedExactTest(alternative, method, points, v);
    }

    /**
     * Compute the statistic for the unconditioned exact test. The statistic returned
     * depends on the configured {@linkplain Method method}.
     *
     * @param table 2-by-2 contingency table.
     * @return test statistic
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     * @see #with(Method)
     * @see #test(int[][])
     */
    public double statistic(int[][] table) {
        checkTable(table);
        final int a = table[0][0];
        final int b = table[0][1];
        final int c = table[1][0];
        final int d = table[1][1];
        final int m = a + c;
        final int n = b + d;
        // Exhaustive switch statement
        switch (method) {
        case Z_POOLED:
            return statisticZ(a, b, m, n, true);
        case Z_UNPOOLED:
            return statisticZ(a, b, m, n, false);
        case BOSCHLOO:
            return statisticBoschloo(a, b, m, n);
        }
        throw new IllegalStateException(String.valueOf(method));
    }

    /**
     * Performs an unconditioned exact test on the 2-by-2 contingency table. The statistic and
     * p-value returned depend on the configured {@linkplain Method method} and
     * {@linkplain AlternativeHypothesis alternative hypothesis}.
     *
     * <p>The search for the nuisance parameter that maximises the p-value can be configured to:
     * start with a number of {@linkplain #withInitialPoints(int) initial points}; and
     * {@linkplain #withOptimize(boolean) optimize} the best points.
     *
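     * <p>Example (a minimal sketch; the table counts are hypothetical):
     *
     * <pre>{@code
     * int[][] table = {{7, 1}, {3, 9}};
     * Result r = UnconditionedExactTest.withDefaults().test(table);
     * double statistic = r.getStatistic();
     * double p = r.getPValue();
     * double pi = r.getNuisanceParameter();
     * }</pre>
     *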
     * @param table 2-by-2 contingency table.
     * @return test result
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     * @see #with(Method)
     * @see #with(AlternativeHypothesis)
     * @see #statistic(int[][])
     */
    public Result test(int[][] table) {
        checkTable(table);
        final int a = table[0][0];
        final int b = table[0][1];
        final int c = table[1][0];
        final int d = table[1][1];
        final int m = a + c;
        final int n = b + d;

        // Used to track more extreme tables
        final XYList tableList = new XYList(m, n);

        final double statistic = findExtremeTables(a, b, tableList);
        if (tableList.isEmpty() || tableList.isFull()) {
            // All possible tables are more extreme, e.g. a two-sided test where the
            // z-statistic is zero.
            return new Result(statistic);
        }
        final double[] opt = computePValue(tableList);

        return new Result(statistic, opt[0], opt[1]);
    }

    /**
     * Find all tables that are as or more extreme than the observed table.
     *
     * <p>If the list of tables is full then all tables are more extreme.
     * Some configurations can detect this without performing a search
     * and in this case the list of tables is returned as empty.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param tableList List to track more extreme tables.
     * @return the test statistic
     */
    private double findExtremeTables(int a, int b, XYList tableList) {
        final int m = tableList.getMaxX();
        final int n = tableList.getMaxY();
        // Exhaustive switch statement
        switch (method) {
        case Z_POOLED:
            return findExtremeTablesZ(a, b, m, n, true, tableList);
        case Z_UNPOOLED:
            return findExtremeTablesZ(a, b, m, n, false, tableList);
        case BOSCHLOO:
            return findExtremeTablesBoschloo(a, b, m, n, tableList);
        }
        throw new IllegalStateException(String.valueOf(method));
    }

    /**
     * Compute the statistic from a Z-test.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param pooled true to use a pooled variance.
     * @return z
     */
    private static double statisticZ(int a, int b, int m, int n, boolean pooled) {
        final double p0 = (double) a / m;
        final double p1 = (double) b / n;
        // Avoid NaN generation 0 / 0 when the variance is 0
        if (p0 != p1) {
            final double variance;
            if (pooled) {
                // Integer sums will not overflow
                final double p = (double) (a + b) / (m + n);
                variance = p * (1 - p) * (1.0 / m + 1.0 / n);
            } else {
                variance = p0 * (1 - p0) / m + p1 * (1 - p1) / n;
            }
            return (p0 - p1) / Math.sqrt(variance);
        }
        return 0;
    }

    /**
     * Find all tables that are as or more extreme than the observed table using the Z statistic.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param pooled true to use a pooled variance.
     * @param tableList List to track more extreme tables.
     * @return observed z
     */
    private double findExtremeTablesZ(int a, int b, int m, int n, boolean pooled, XYList tableList) {
        final double statistic = statisticZ(a, b, m, n, pooled);
        // Identify more extreme tables using the alternative hypothesis
        final DoublePredicate test;
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            test = z -> z >= statistic;
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            test = z -> z <= statistic;
        } else {
            // two-sided
            if (statistic == 0) {
                // Early exit: all tables are as extreme
                return 0;
            }
            final double za = Math.abs(statistic);
            test = z -> Math.abs(z) >= za;
        }
        // Precompute factors
        final double mn = (double) m + n;
        final double norm = 1.0 / m + 1.0 / n;
        double z;
        // Process all possible tables
        for (int i = 0; i <= m; i++) {
            final double p0 = (double) i / m;
            final double vp0 = p0 * (1 - p0) / m;
            for (int j = 0; j <= n; j++) {
                final double p1 = (double) j / n;
                // Avoid NaN generation 0 / 0 when the variance is 0
                if (p0 == p1) {
                    z = 0;
                } else {
                    final double variance;
                    if (pooled) {
                        // Integer sums will not overflow
                        final double p = (i + j) / mn;
                        variance = p * (1 - p) * norm;
                    } else {
                        variance = vp0 + p1 * (1 - p1) / n;
                    }
                    z = (p0 - p1) / Math.sqrt(variance);
                }
                if (test.test(z)) {
                    tableList.add(i, j);
                }
            }
        }
        return statistic;
    }

    /**
     * Compute the statistic using Fisher's p-value (also known as Boschloo's test).
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @return p-value
     */
    private double statisticBoschloo(int a, int b, int m, int n) {
        final int nn = m + n;
        final int k = a + b;
        // Re-use the cached Hypergeometric implementation to allow the value
        // to be identical for the statistic and test methods.
        final Hypergeom dist = new Hypergeom(nn, k, m);
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            return dist.sf(a - 1);
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            return dist.cdf(a);
        }
        // two-sided: Find all i where Pr(X = i) <= Pr(X = a) and sum them.
        return statisticBoschlooTwoSided(dist, a);
    }

    /**
     * Compute the two-sided statistic using Fisher's p-value (also known as Boschloo's test).
     *
     * @param distribution Hypergeometric distribution.
     * @param k Observed value.
     * @return p-value
     */
    private static double statisticBoschlooTwoSided(Hypergeom distribution, int k) {
        // two-sided: Find all i where Pr(X = i) <= Pr(X = k) and sum them.
        // Logic is the same as FisherExactTest but using the probability (PMF), which
        // is cached, rather than the logProbability.
        final double pk = distribution.pmf(k);

        final int m1 = distribution.getLowerMode();
        final int m2 = distribution.getUpperMode();
        if (k < m1) {
            // Lower half = cdf(k)
            // Find upper half. As k < lower mode, i should never
            // reach the lower mode based on the probability alone.
            // Bracket with the upper mode.
            final int i = Searches.searchDescending(m2, distribution.getSupportUpperBound(), pk,
                distribution::pmf);
            return distribution.cdf(k) +
                   distribution.sf(i - 1);
        } else if (k > m2) {
            // Upper half = sf(k - 1)
            // Find lower half. As k > upper mode, i should never
            // reach the upper mode based on the probability alone.
            // Bracket with the lower mode.
            final int i = Searches.searchAscending(distribution.getSupportLowerBound(), m1, pk,
                distribution::pmf);
            return distribution.cdf(i) +
                   distribution.sf(k - 1);
        }
        // k == mode
        // Edge case where the sum of probabilities will be either
        // 1 or 1 - Pr(X = mode) where mode != k
        final double pm = distribution.pmf(k == m1 ? m2 : m1);
        return pm > pk ? 1 - pm : 1;
    }

    /**
     * Find all tables that are as or more extreme than the observed table using
     * Fisher's p-value as the statistic (also known as Boschloo's test).
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param tableList List to track more extreme tables.
     * @return observed p-value
     */
    private double findExtremeTablesBoschloo(int a, int b, int m, int n, XYList tableList) {
        final double statistic = statisticBoschloo(a, b, m, n);

        // Function to compute the statistic
        final BoschlooStatistic func;
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            func = (dist, x) -> dist.sf(x - 1);
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            func = Hypergeom::cdf;
        } else {
            func = UnconditionedExactTest::statisticBoschlooTwoSided;
        }

        // All tables are: 0 <= i <= m by 0 <= j <= n
        // Diagonal (upper-left to lower-right) strips of the possible
        // tables use the same hypergeometric distribution
        // (i.e. i+j == number of successes). To enumerate all requires
        // using the full range of all distributions: 0 <= i+j <= m+n.
        // Note the column sum m is fixed.
        final int mn = m + n;
        for (int k = 0; k <= mn; k++) {
            final Hypergeom dist = new Hypergeom(mn, k, m);
            final int lo = dist.getSupportLowerBound();
            final int hi = dist.getSupportUpperBound();
            for (int i = lo; i <= hi; i++) {
                if (func.value(dist, i) <= statistic) {
                    // j = k - i
                    tableList.add(i, k - i);
                }
            }
        }
        return statistic;
    }

    /**
     * Compute the nuisance parameter and p-value for the binomial model given the list
     * of possible tables.
     *
     * <p>The current method enumerates an initial set of points and stores local
     * extrema as candidates. Any candidate within 2% of the best is optionally
     * optimized; this is limited to the top 3 candidates. These settings
     * could be exposed as configurable options. Currently only the choice to optimize
     * or not is exposed.
     *
     * @param tableList List of tables.
     * @return [nuisance parameter, p-value]
     */
    private double[] computePValue(XYList tableList) {
        final DoubleUnaryOperator func = createBinomialModel(tableList);

        // Enumerate the range [LOWER, 1-LOWER] and save the best points for optimization
        final Candidates minima = new Candidates(MAX_CANDIDATES, MINIMA_EPS);
        final int n = points - 1;
        final double inc = (1.0 - 2 * LOWER_BOUND) / n;
        // Moving window of 3 values to identify minima.
        // px holds the position of the previous evaluated point.
        double v2 = 0;
        double v3 = func.applyAsDouble(LOWER_BOUND);
        double px = LOWER_BOUND;
        for (int i = 1; i < n; i++) {
            final double x = LOWER_BOUND + i * inc;
            final double v1 = v2;
            v2 = v3;
            v3 = func.applyAsDouble(x);
            addCandidate(minima, v1, v2, v3, px);
            px = x;
        }
        // Add the upper bound
        final double x = 1 - LOWER_BOUND;
        final double vn = func.applyAsDouble(x);
        addCandidate(minima, v2, v3, vn, px);
        addCandidate(minima, v3, vn, 0, x);

        final double[] min = minima.getMinimum();

        // Optionally optimize the best point(s) (if not already optimal)
        if (optimize && min[1] > -1) {
            final BrentOptimizer opt = new BrentOptimizer(SOLVER_RELATIVE_EPS, Double.MIN_VALUE);
            final BracketFinder bf = new BracketFinder();
            minima.forEach(candidate -> {
                double a = candidate[0];
                final double fa;
                // Attempt to bracket the minimum. Use an initial second point placed relative to
                // the size of the interval: [x - increment, x + increment].
                // If a < 0.5 then add a small delta; otherwise subtract the delta.
                final double b = a - Math.copySign(inc * INC_FRACTION, a - 0.5);
                if (bf.search(func, a, b, 0, 1)) {
                    // The bracket a < b < c must have f(b) < min(f(a), f(c))
                    final PointValuePair p = opt.optimize(func, bf.getLo(), bf.getHi(), bf.getMid(), bf.getFMid());
                    a = p.getPoint();
                    fa = p.getValue();
                } else {
                    // Mid-point is at one of the bounds (i.e. is 0 or 1)
                    a = bf.getMid();
                    fa = bf.getFMid();
                }
                if (fa < min[1]) {
                    min[0] = a;
                    min[1] = fa;
                }
            });
        }
        // Reverse the sign of the p-value to create a maximum.
        // Note that due to the summation the p-value can be above 1 so we clip the final result.
        // Note: Apply max then reverse sign. This will pass through spurious NaN values if
        // the p-value computation produced all NaNs.
        min[1] = -Math.max(-1, min[1]);
        return min;
    }

    /**
     * Creates the binomial model p-value function for the nuisance parameter.
     * Note: This function computes the negative p-value so is suitable for
     * optimization by a search for a minimum.
     *
     * @param tableList List of tables.
     * @return the function
     */
    private static DoubleUnaryOperator createBinomialModel(XYList tableList) {
        final int m = tableList.getMaxX();
        final int n = tableList.getMaxY();
        final int mn = m + n;
        // Compute the probability using logs
        final double[] c = new double[tableList.size()];
        final int[] ij = new int[tableList.size()];
        final int width = tableList.getWidth();

        // Compute the log binomial dynamically for a small number of values
        final IntToDoubleFunction binomM;
        final IntToDoubleFunction binomN;
        if (tableList.size() < mn) {
            binomM = k -> LogBinomialCoefficient.value(m, k);
            binomN = k -> LogBinomialCoefficient.value(n, k);
        } else {
            // Pre-compute all values
            binomM = createLogBinomialCoefficients(m);
            binomN = m == n ? binomM : createLogBinomialCoefficients(n);
        }

        // Handle special cases i+j == 0 and i+j == m+n.
        // These will occur only once, if at all. Mark if they occur.
        int flag = 0;
        int j = 0;
        for (int i = 0; i < c.length; i++) {
            final int index = tableList.get(i);
            final int x = index % width;
            final int y = index / width;
            final int xy = x + y;
            if (xy == 0) {
                flag |= 1;
            } else if (xy == mn) {
                flag |= 2;
            } else {
                ij[j] = xy;
                c[j] = binomM.applyAsDouble(x) + binomN.applyAsDouble(y);
                j++;
            }
        }

        final int size = j;
        final boolean ij0 = (flag & 1) != 0;
        final boolean ijmn = (flag & 2) != 0;
        return pi -> {
            final double logp = Math.log(pi);
            final double log1mp = Math.log1p(-pi);
            double sum = 0;
            for (int i = 0; i < size; i++) {
                // binom(m, i) * binom(n, j) * pi^(i+j) * (1-pi)^(m+n-i-j)
                sum += Math.exp(ij[i] * logp + (mn - ij[i]) * log1mp + c[i]);
            }
            // Add the simplified terms where the binomial is 1.0 and one power is x^0 == 1.0.
            // This avoids 0 * log(x) generating NaN when x is 0 in the case where pi was 0 or 1.
            // Reuse exp (not pow) to support pi approaching 0 or 1.
            if (ij0) {
                // pow(1-pi, mn)
                sum += Math.exp(mn * log1mp);
            }
            if (ijmn) {
                // pow(pi, mn)
                sum += Math.exp(mn * logp);
            }
            // The optimizer minimises the function so this returns -p.
            return -sum;
        };
    }

    /**
     * Create the natural logarithm of the binomial coefficient for all {@code k = [0, n]}.
     *
     * @param n Limit N.
     * @return ln binom(n, k)
     */
    private static IntToDoubleFunction createLogBinomialCoefficients(int n) {
        final double[] binom = new double[n + 1];
        // Exploit symmetry.
        // Ignore: binom(n, 0) == binom(n, n) == 1
        int j = n - 1;
        for (int i = 1; i <= j; i++, j--) {
            binom[i] = binom[j] = LogBinomialCoefficient.value(n, i);
        }
        return k -> binom[k];
    }

    /**
     * Add point 2 to the list of minima if neither neighbour value is lower.
     * <pre>
     * !(v1 < v2 || v3 < v2)
     * </pre>
     *
     * @param minima Candidate minima.
     * @param v1 First point function value.
     * @param v2 Second point function value.
     * @param v3 Third point function value.
     * @param x2 Second point.
     */
    private void addCandidate(Candidates minima, double v1, double v2, double v3, double x2) {
        final double min = v1 < v3 ? v1 : v3;
        if (min < v2) {
            // Lower neighbour(s)
            return;
        }
        // Add the candidate. This could be NaN but the candidate list handles this by storing
        // NaN only when no non-NaN values have been observed.
        minima.add(x2, v2);
    }

    /**
     * Check the input is a 2-by-2 contingency table.
     *
     * @param table Contingency table.
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     */
    private static void checkTable(int[][] table) {
        Arguments.checkTable(table);
        // Entries must all be non-negative (validated above); the column sums must be non-zero
        final int a = table[0][0];
        final int c = table[1][0];
        // checkTable has validated the total sum is < 2^31
        final int m = a + c;
        if (m == 0) {
            throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 0);
        }
        final int b = table[0][1];
        final int d = table[1][1];
        final int n = b + d;
        if (n == 0) {
            throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 1);
        }
        // Total possible tables must be a size we can track in an array (to compute the p-value)
        final long size = (m + 1L) * (n + 1L);
        if (size > MAX_TABLES) {
            throw new InferenceException(InferenceException.X_GT_Y, size, MAX_TABLES);
        }
    }
}