UnconditionedExactTest.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.inference;

  18. import java.util.Arrays;
  19. import java.util.Objects;
  20. import java.util.function.Consumer;
  21. import java.util.function.DoublePredicate;
  22. import java.util.function.DoubleUnaryOperator;
  23. import java.util.function.IntToDoubleFunction;
  24. import org.apache.commons.numbers.combinatorics.LogBinomialCoefficient;
  25. import org.apache.commons.statistics.inference.BrentOptimizer.PointValuePair;

  26. /**
  27.  * Implements an unconditioned exact test for a contingency table.
  28.  *
  29.  * <p>Performs an exact test for the statistical significance of the association (contingency)
  30.  * between two kinds of categorical classification. A 2x2 contingency table is:
  31.  *
  32.  * <p>\[ \left[ {\begin{array}{cc}
  33.  *         a &amp; b \\
  34.  *         c &amp; d \\
  35.  *       \end{array} } \right] \]
  36.  *
  37.  * <p>This test applies to the case of a 2x2 contingency table with one margin fixed. Note that
  38.  * if both margins are fixed (the row sums and column sums are not random)
  39.  * then Fisher's exact test can be applied.
  40.  *
  41.  * <p>This implementation fixes the column sums \( m = a + c \) and \( n = b + d \).
  42.  * All possible tables can be created using \( 0 \le a \le m \) and \( 0 \le b \le n \).
  43.  * The random values \( a \) and \( b \) follow a binomial distribution with probabilities
  44.  * \( p_0 \) and \( p_1 \) such that \( a \sim B(m, p_0) \) and \( b \sim B(n, p_1) \).
  45.  * The p-value of the 2x2 table is the product of two binomials:
  46.  *
  47.  * <p>\[ \begin{aligned}
  48.  *       p &amp;= Pr(a; m, p_0) \times Pr(b; n, p_1) \\
  49.  *         &amp;= \binom{m}{a} p_0^a (1-p_0)^{m-a} \times \binom{n}{b} p_1^b (1-p_1)^{n-b} \end{aligned} \]
  50.  *
  51.  * <p>For the binomial model, the null hypothesis is the two nuisance parameters are equal
  52.  * \( p_0 = p_1 = \pi\), with \( \pi \) the probability for equal proportions, and the probability
  53.  * of any single table is:
  54.  *
  55.  * <p>\[ p = \binom{m}{a} \binom{n}{b} \pi^{a+b} (1-\pi)^{m+n-a-b} \]
  56.  *
  57.  * <p>The p-value of the observed table is calculated by maximising the sum of the as or more
  58.  * extreme tables over the domain of the nuisance parameter \( 0 \lt \pi \lt 1 \):
  59.  *
  60.  * <p>\[ p(a, b) = \sum_{i,j} \binom{m}{i} \binom{n}{j} \pi^{i+j} (1-\pi)^{m+n-i-j} \]
  61.  *
  62.  * <p>where table \( (i,j) \) is as or more extreme than the observed table \( (a, b) \). The test
  63.  * can be configured to select more extreme tables using various {@linkplain Method methods}.
  64.  *
  65.  * <p>Note that the sum of the joint binomial distribution is a univariate function for
  66.  * the nuisance parameter \( \pi \). This function may have many local maxima and the
  67.  * search enumerates the range with a configured {@linkplain #withInitialPoints(int)
  68.  * number of points}. The best candidates are optionally used as the start point for an
  69.  * {@linkplain #withOptimize(boolean) optimized} search for a local maxima.
  70.  *
  71.  * <p>References:
  72.  * <ol>
  73.  * <li>
  74.  * Barnard, G.A. (1947).
  75.  * <a href="https://doi.org/10.1093/biomet/34.1-2.123">Significance tests for 2x2 tables.</a>
  76.  * Biometrika, 34, Issue 1-2, 123–138.
  77.  * <li>
  78.  * Boschloo, R.D. (1970).
  79.  * <a href="https://doi.org/10.1111/j.1467-9574.1970.tb00104.x">Raised conditional level of
  80.  * significance for the 2 × 2-table when testing the equality of two probabilities.</a>
  81.  * Statistica neerlandica, 24(1), 1–9.
  82.  * <li>
  83. * Suissa, S. and Shuster, J.J. (1985).
  84.  * <a href="https://doi.org/10.2307/2981892">Exact Unconditional Sample Sizes
  85.  * for the 2 × 2 Binomial Trial.</a>
  86.  * Journal of the Royal Statistical Society. Series A (General), 148(4), 317-327.
  87.  * </ol>
  88.  *
  89.  * @see FisherExactTest
  90.  * @see <a href="https://en.wikipedia.org/wiki/Boschloo%27s_test">Boschloo&#39;s test (Wikipedia)</a>
  91.  * @see <a href="https://en.wikipedia.org/wiki/Barnard%27s_test">Barnard&#39;s test (Wikipedia)</a>
  92.  * @since 1.1
  93.  */
  94. public final class UnconditionedExactTest {
  95.     /**
  96.      * Default instance.
  97.      *
  98.      * <p>SciPy's boschloo_exact and barnard_exact tests use 32 points in the interval [0,
  99.      * 1) The R Exact package uses 100 in the interval [1e-5, 1-1e-5]. Barnard's 1947 paper
  100.      * describes the nuisance parameter in the open interval {@code 0 < pi < 1}. Here we
  101.      * respect the open-interval for the initial candidates and ignore 0 and 1. The
  102.      * initial bounds used are the same as the R Exact package. We closely match the inner
  103.      * 31 points from SciPy by using 33 points by default.
  104.      */
  105.     private static final UnconditionedExactTest DEFAULT = new UnconditionedExactTest(
  106.         AlternativeHypothesis.TWO_SIDED, Method.BOSCHLOO, 33, true);
  107.     /** Lower bound for the enumerated interval. The upper bound is {@code 1 - lower}. */
  108.     private static final double LOWER_BOUND = 1e-5;
  109.     /** Relative epsilon for the Brent solver. This is limited for a univariate function
  110.      * to approximately sqrt(eps) with eps = 2^-52. */
  111.     private static final double SOLVER_RELATIVE_EPS = 1.4901161193847656E-8;
  112.     /** Fraction of the increment (interval between enumerated points) to initialise the bracket
  113.      * for the minima. Note the minima should lie between x +/- increment. The bracket should
  114.      * search within this range. Set to 1/8 and so the initial point of the bracket is
  115.      * approximately 1.61 * 1/8 = 0.2 of the increment away from initial points a or b. */
  116.     private static final double INC_FRACTION = 0.125;
  117.     /** Maximum number of candidates to optimize. This is a safety limit to avoid excess
  118.      * optimization. Only candidates within a relative tolerance of the best candidate are
  119.      * stored. If the number of candidates exceeds this value then many candidates have a
  120.      * very similar p-value and the top candidates will be optimized. Using a value of 3
  121.      * allows at least one other candidate to be optimized when there is two-fold
  122.      * symmetry in the energy function. */
  123.     private static final int MAX_CANDIDATES = 3;
  124.     /** Relative distance of candidate minima from the lowest candidate. Used to exclude
  125.      * poor candidates from optimization. */
  126.     private static final double MINIMA_EPS = 0.02;
  127.     /** The maximum number of tables. This is limited by the maximum number of indices that
  128.      * can be maintained in memory. Potentially up to this number of tables must be tracked
  129.      * during computation of the p-value for as or more extreme tables. The limit is set
  130.      * using the same limit for maximum capacity as java.util.ArrayList. In practice any
  131.      * table anywhere near this limit can be computed using an alternative such as a chi-squared
  132.      * or g test. */
  133.     private static final int MAX_TABLES = Integer.MAX_VALUE - 8;
  134.     /** Error message text for zero column sums. */
  135.     private static final String COLUMN_SUM = "Column sum";

  136.     /** Alternative hypothesis. */
  137.     private final AlternativeHypothesis alternative;
  138.     /** Method to identify more extreme tables. */
  139.     private final Method method;
  140.     /** Number of initial points. */
  141.     private final int points;
  142.     /** Option to optimize the best initial point(s). */
  143.     private final boolean optimize;

    /**
     * Define the method to determine the more extreme tables.
     *
     * @since 1.1
     */
    public enum Method {
        /**
         * Uses the test statistic from a Z-test using a pooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}{\sqrt{\hat{p}(1 - \hat{p}) (\frac{1}{m} + \frac{1}{n})}} \]
         *
         * <p>where \( \hat{p}_0 = a / m \), \( \hat{p}_1 = b / n \), and
         * \( \hat{p} = (a+b) / (m+n) \) are the estimators of \( p_0 \), \( p_1 \) and the
         * pooled probability \( p \) assuming \( p_0 = p_1 \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis}:
         * <ul>
         * <li>greater: \( T(X) \ge T(X_0) \)
         * <li>less: \( T(X) \le T(X_0) \)
         * <li>two-sided: \( | T(X) | \ge | T(X_0) | \)
         * </ul>
         *
         * <p>The use of the Z statistic was suggested by Suissa and Shuster (1985).
         * This method is uniformly more powerful than Fisher's test for balanced designs
         * (\( m = n \)).
         */
        Z_POOLED,

        /**
         * Uses the test statistic from a Z-test using an unpooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}
         * {\sqrt{ \frac{\hat{p}_0(1 - \hat{p}_0)}{m} + \frac{\hat{p}_1(1 - \hat{p}_1)}{n}} } \]
         *
         * <p>where \( \hat{p}_0 = a / m \) and \( \hat{p}_1 = b / n \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis} as
         * per the {@link #Z_POOLED} method.
         */
        Z_UNPOOLED,

        /**
         * Uses the p-value from Fisher's exact test. This is also known as Boschloo's test.
         *
         * <p>The p-value for Fisher's test is computed using the
         * {@link AlternativeHypothesis}. The more extreme tables are identified using
         * \( p(X) \le p(X_0) \).
         *
         * <p>This method is always uniformly more powerful than Fisher's test.
         *
         * @see FisherExactTest
         */
        BOSCHLOO;
    }

  196.     /**
  197.      * Result for the unconditioned exact test.
  198.      *
  199.      * <p>This class is immutable.
  200.      *
  201.      * @since 1.1
  202.      */
  203.     public static final class Result extends BaseSignificanceResult {
  204.         /** Nuisance parameter. */
  205.         private final double pi;

  206.         /**
  207.          * Create an instance where all tables are more extreme, i.e. the p-value
  208.          * is 1.0.
  209.          *
  210.          * @param statistic Test statistic.
  211.          */
  212.         Result(double statistic) {
  213.             super(statistic, 1);
  214.             this.pi = 0.5;
  215.         }

  216.         /**
  217.          * @param statistic Test statistic.
  218.          * @param pi Nuisance parameter.
  219.          * @param p Result p-value.
  220.          */
  221.         Result(double statistic, double pi, double p) {
  222.             super(statistic, p);
  223.             this.pi = pi;
  224.         }

  225.         /**
  226.          * {@inheritDoc}
  227.          *
  228.          * <p>The value of the statistic is dependent on the {@linkplain Method method}
  229.          * used to determine the more extreme tables.
  230.          */
  231.         @Override
  232.         public double getStatistic() {
  233.             // Note: This method is here for documentation
  234.             return super.getStatistic();
  235.         }

  236.         /**
  237.          * Gets the nuisance parameter that maximised the probability sum of the as or more
  238.          * extreme tables.
  239.          *
  240.          * @return the nuisance parameter.
  241.          */
  242.         public double getNuisanceParameter() {
  243.             return pi;
  244.         }
  245.     }

  246.     /**
  247.      * An expandable list of (x,y) values. This allows tracking 2D positions stored as
  248.      * a single index.
  249.      */
  250.     private static class XYList {
  251.         /** The maximum size of array to allocate. */
  252.         private final int max;
  253.         /** Width, or maximum x value (exclusive). */
  254.         private final int width;

  255.         /** The size of the list. */
  256.         private int size;
  257.         /** The list data. */
  258.         private int[] data = new int[10];

  259.         /**
  260.          * Create an instance. It is assumed that (maxx+1)*(maxy+1) does not exceed the
  261.          * capacity of an array.
  262.          *
  263.          * @param maxx Maximum x-value (inclusive).
  264.          * @param maxy Maximum y-value (inclusive).
  265.          */
  266.         XYList(int maxx, int maxy) {
  267.             this.width = maxx + 1;
  268.             this.max = width * (maxy + 1);
  269.         }

  270.         /**
  271.          * Gets the width.
  272.          * (x, y) values are stored using y * width + x.
  273.          *
  274.          * @return the width
  275.          */
  276.         int getWidth() {
  277.             return width;
  278.         }

  279.         /**
  280.          * Gets the maximum X value (inclusive).
  281.          *
  282.          * @return the max X
  283.          */
  284.         int getMaxX() {
  285.             return width - 1;
  286.         }

  287.         /**
  288.          * Gets the maximum Y value (inclusive).
  289.          *
  290.          * @return the max Y
  291.          */
  292.         int getMaxY() {
  293.             return max / width - 1;
  294.         }

  295.         /**
  296.          * Adds the value to the list.
  297.          *
  298.          * @param x X value.
  299.          * @param y Y value.
  300.          */
  301.         void add(int x, int y) {
  302.             if (size == data.length) {
  303.                 // Overflow safe doubling of the current size.
  304.                 data = Arrays.copyOf(data, (int) Math.min(max, size * 2L));
  305.             }
  306.             data[size++] = width * y + x;
  307.         }

  308.         /**
  309.          * Gets the 2D index at the specified {@code index}.
  310.          * The index is y * width + x:
  311.          * <pre>
  312.          * x = index % width
  313.          * y = index / width
  314.          * </pre>
  315.          *
  316.          * @param index Element index.
  317.          * @return the 2D index
  318.          */
  319.         int get(int index) {
  320.             return data[index];
  321.         }

  322.         /**
  323.          * Gets the number of elements in the list.
  324.          *
  325.          * @return the size
  326.          */
  327.         int size() {
  328.             return size;
  329.         }

  330.         /**
  331.          * Checks if the list size is zero.
  332.          *
  333.          * @return true if empty
  334.          */
  335.         boolean isEmpty() {
  336.             return size == 0;
  337.         }

  338.         /**
  339.          * Checks if the list is the maximum capacity.
  340.          *
  341.          * @return true if full
  342.          */
  343.         boolean isFull() {
  344.             return size == max;
  345.         }
  346.     }

  347.     /**
  348.      * A container of (key,value) pairs to store candidate minima. Encapsulates the
  349.      * logic of storing multiple initial search points for optimization.
  350.      *
  351.      * <p>Stores all pairs within a relative tolerance of the lowest minima up to a set
  352.      * capacity. When at capacity the worst candidate is replaced by addition of a
  353.      * better candidate.
  354.      *
  355.      * <p>Special handling is provided to store only a single NaN value if no non-NaN
  356.      * values have been observed. This prevents storing a large number of NaN
  357.      * candidates.
  358.      */
  359.     static class Candidates {
  360.         /** The maximum size of array to allocate. */
  361.         private final int max;
  362.         /** Relative distance from lowest candidate. */
  363.         private final double eps;
  364.         /** Candidate (key,value) pairs. */
  365.         private double[][] data;
  366.         /** Current size of the list. */
  367.         private int size;
  368.         /** Current minimum. */
  369.         private double min = Double.POSITIVE_INFINITY;
  370.         /** Current threshold for inclusion. */
  371.         private double threshold = Double.POSITIVE_INFINITY;

  372.         /**
  373.          * Create an instance.
  374.          *
  375.          * @param max Maximum number of allowed candidates (limited to at least 1).
  376.          * @param eps Relative distance of candidate minima from the lowest candidate
  377.          * (assumed to be positive and finite).
  378.          */
  379.         Candidates(int max, double eps) {
  380.             this.max = Math.max(1, max);
  381.             this.eps = eps;
  382.             // Create the initial storage
  383.             data = new double[Math.min(this.max, 4)][];
  384.         }

  385.         /**
  386.          * Adds the (key, value) pair.
  387.          *
  388.          * @param k Key.
  389.          * @param v Value.
  390.          */
  391.         void add(double k, double v) {
  392.             // Store only a single NaN
  393.             if (Double.isNaN(v)) {
  394.                 if (size == 0) {
  395.                     // No requirement to check capacity
  396.                     data[size++] = new double[] {k, v};
  397.                 }
  398.                 return;
  399.             }
  400.             // Here values are non-NaN.
  401.             // If higher then do not store.
  402.             if (v > threshold) {
  403.                 return;
  404.             }
  405.             // Check if lower than the current minima.
  406.             if (v < min) {
  407.                 min = v;
  408.                 // Get new threshold
  409.                 threshold = v + Math.abs(v) * eps;
  410.                 // Remove existing entries above the threshold
  411.                 int s = 0;
  412.                 for (int i = 0; i < size; i++) {
  413.                     // This will filter NaN values
  414.                     if (data[i][1] <= threshold) {
  415.                         data[s++] = data[i];
  416.                     }
  417.                 }
  418.                 size = s;
  419.                 // Caution: This does not clear stale data
  420.                 // by setting all values in [newSize, oldSize) = null
  421.             }
  422.             addPair(k, v);
  423.         }

  424.         /**
  425.          * Add the (key, value) pair to the data.
  426.          * It is assumed the data satisfy the conditions for addition.
  427.          *
  428.          * @param k Key.
  429.          * @param v Value.
  430.          */
  431.         private void addPair(double k, double v) {
  432.             if (size == data.length) {
  433.                 if (size == max) {
  434.                     // At capacity.
  435.                     replaceWorst(k, v);
  436.                     return;
  437.                 }
  438.                 // Expand
  439.                 data = Arrays.copyOfRange(data, 0, (int) Math.min(max, size * 2L));
  440.             }
  441.             data[size++] = new double[] {k, v};
  442.         }

  443.         /**
  444.          * Replace the worst candidate.
  445.          *
  446.          * @param k Key.
  447.          * @param v Value.
  448.          */
  449.         private void replaceWorst(double k, double v) {
  450.             // Note: This only occurs when NaN values have been removed by addition
  451.             // of non-NaN values.
  452.             double[] worst = data[0];
  453.             for (int i = 1; i < size; i++) {
  454.                 if (worst[1] < data[i][1]) {
  455.                     worst = data[i];
  456.                 }
  457.             }
  458.             worst[0] = k;
  459.             worst[1] = v;
  460.         }

  461.         /**
  462.          * Return the minimum (key,value) pair.
  463.          *
  464.          * @return the minimum (or null)
  465.          */
  466.         double[] getMinimum() {
  467.             // This will handle size=0 as data[0] will be null
  468.             double[] best = data[0];
  469.             for (int i = 1; i < size; i++) {
  470.                 if (best[1] > data[i][1]) {
  471.                     best = data[i];
  472.                 }
  473.             }
  474.             return best;
  475.         }

  476.         /**
  477.          * Perform the given action for each (key, value) pair.
  478.          *
  479.          * @param action Action.
  480.          */
  481.         void forEach(Consumer<double[]> action) {
  482.             for (int i = 0; i < size; i++) {
  483.                 action.accept(data[i]);
  484.             }
  485.         }
  486.     }

  487.     /**
  488.      * Compute the statistic for Boschloo's test.
  489.      */
  490.     private interface BoschlooStatistic {
  491.         /**
  492.          * Compute Fisher's p-value for the 2x2 contingency table with the observed
  493.          * value {@code x} in position [0][0]. Note that the table margins are fixed
  494.          * and are defined by the population size, number of successes and sample
  495.          * size of the specified hypergeometric distribution.
  496.          *
  497.          * @param dist Hypergeometric distribution.
  498.          * @param x Value.
  499.          * @return Fisher's p-value
  500.          */
  501.         double value(Hypergeom dist, int x);
  502.     }

    /**
     * Create an instance with the specified configuration.
     *
     * <p>Instances are immutable; the {@code with} methods create new instances.
     *
     * @param alternative Alternative hypothesis.
     * @param method Method to identify more extreme tables.
     * @param points Number of initial points.
     * @param optimize Option to optimize the best initial point(s).
     */
    private UnconditionedExactTest(AlternativeHypothesis alternative,
                                   Method method,
                                   int points,
                                   boolean optimize) {
        this.alternative = alternative;
        this.method = method;
        this.points = points;
        this.optimize = optimize;
    }

    /**
     * Return an instance using the default options.
     *
     * <ul>
     * <li>{@link AlternativeHypothesis#TWO_SIDED}
     * <li>{@link Method#BOSCHLOO}
     * <li>{@linkplain #withInitialPoints(int) points = 33}
     * <li>{@linkplain #withOptimize(boolean) optimize = true}
     * </ul>
     *
     * @return default instance
     */
    public static UnconditionedExactTest withDefaults() {
        // Instances are immutable so the shared default is safe to return
        return DEFAULT;
    }

    /**
     * Return an instance with the configured alternative hypothesis.
     *
     * <p>All other options are copied from this instance.
     *
     * @param v Value.
     * @return an instance
     * @throws NullPointerException if the value is null
     */
    public UnconditionedExactTest with(AlternativeHypothesis v) {
        return new UnconditionedExactTest(Objects.requireNonNull(v), method, points, optimize);
    }

    /**
     * Return an instance with the configured method.
     *
     * <p>All other options are copied from this instance.
     *
     * @param v Value.
     * @return an instance
     * @throws NullPointerException if the value is null
     */
    public UnconditionedExactTest with(Method v) {
        return new UnconditionedExactTest(alternative, Objects.requireNonNull(v), points, optimize);
    }

  551.     /**
  552.      * Return an instance with the configured number of initial points.
  553.      *
  554.      * <p>The search for the nuisance parameter will use \( v \) points in the open interval
  555.      * \( (0, 1) \). The interval is evaluated by including start and end points approximately
  556.      * equal to 0 and 1. Additional internal points are enumerated using increments of
  557.      * approximately \( \frac{1}{v-1} \). The minimum number of points is 2. Increasing the
  558.      * number of points increases the precision of the search at the cost of performance.
  559.      *
  560.      * <p>To approximately double the number of points so that all existing points are included
  561.      * and additional points half-way between them are sampled requires using {@code 2p - 1}
  562.      * where {@code p} is the existing number of points.
  563.      *
  564.      * @param v Value.
  565.      * @return an instance
  566.      * @throws IllegalArgumentException if the value is {@code < 2}.
  567.      */
  568.     public UnconditionedExactTest withInitialPoints(int v) {
  569.         if (v <= 1) {
  570.             throw new InferenceException(InferenceException.X_LT_Y, v, 2);
  571.         }
  572.         return new UnconditionedExactTest(alternative, method, v, optimize);
  573.     }

    /**
     * Return an instance with the configured optimization of initial search points.
     *
     * <p>If enabled then the initial point(s) with the highest probability is/are used as the start
     * for an optimization to find a local maxima.
     *
     * @param v Value.
     * @return an instance
     * @see #withInitialPoints(int)
     */
    public UnconditionedExactTest withOptimize(boolean v) {
        return new UnconditionedExactTest(alternative, method, points, v);
    }

  587.     /**
  588.      * Compute the statistic for the unconditioned exact test. The statistic returned
  589.      * depends on the configured {@linkplain Method method}.
  590.      *
  591.      * @param table 2-by-2 contingency table.
  592.      * @return test statistic
  593.      * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
  594.      * table entry is negative; any column sum is zero; the table sum is zero or not an
  595.      * integer; or the number of possible tables exceeds the maximum array capacity.
  596.      * @see #with(Method)
  597.      * @see #test(int[][])
  598.      */
  599.     public double statistic(int[][] table) {
  600.         checkTable(table);
  601.         final int a = table[0][0];
  602.         final int b = table[0][1];
  603.         final int c = table[1][0];
  604.         final int d = table[1][1];
  605.         final int m = a + c;
  606.         final int n = b + d;
  607.         switch (method) {
  608.         case Z_POOLED:
  609.             return statisticZ(a, b, m, n, true);
  610.         case Z_UNPOOLED:
  611.             return statisticZ(a, b, m, n, false);
  612.         case BOSCHLOO:
  613.             return statisticBoschloo(a, b, m, n);
  614.         default:
  615.             throw new IllegalStateException(String.valueOf(method));
  616.         }
  617.     }

  618.     /**
  619.      * Performs an unconditioned exact test on the 2-by-2 contingency table. The statistic and
  620.      * p-value returned depends on the configured {@linkplain Method method} and
  621.      * {@linkplain AlternativeHypothesis alternative hypothesis}.
  622.      *
  623.      * <p>The search for the nuisance parameter that maximises the p-value can be configured to:
  624.      * start with a number of {@linkplain #withInitialPoints(int) initial points}; and
  625.      * {@linkplain #withOptimize(boolean) optimize} the best points.
  626.      *
  627.      * @param table 2-by-2 contingency table.
  628.      * @return test result
  629.      * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
  630.      * table entry is negative; any column sum is zero; the table sum is zero or not an
  631.      * integer; or the number of possible tables exceeds the maximum array capacity.
  632.      * @see #with(Method)
  633.      * @see #with(AlternativeHypothesis)
  634.      * @see #statistic(int[][])
  635.      */
  636.     public Result test(int[][] table) {
  637.         checkTable(table);
  638.         final int a = table[0][0];
  639.         final int b = table[0][1];
  640.         final int c = table[1][0];
  641.         final int d = table[1][1];
  642.         final int m = a + c;
  643.         final int n = b + d;

  644.         // Used to track more extreme tables
  645.         final XYList tableList = new XYList(m, n);

  646.         final double statistic = findExtremeTables(a, b, tableList);
  647.         if (tableList.isEmpty() || tableList.isFull()) {
  648.             // All possible tables are more extreme, e.g. a two-sided test where the
  649.             // z-statistic is zero.
  650.             return new Result(statistic);
  651.         }
  652.         final double[] opt = computePValue(tableList);

  653.         return new Result(statistic, opt[0], opt[1]);
  654.     }

  655.     /**
  656.      * Find all tables that are as or more extreme than the observed table.
  657.      *
  658.      * <p>If the list of tables is full then all tables are more extreme.
  659.      * Some configurations can detect this without performing a search
  660.      * and in this case the list of tables is returned as empty.
  661.      *
  662.      * @param a Observed value for a.
  663.      * @param b Observed value for b.
  664.      * @param tableList List to track more extreme tables.
  665.      * @return the test statistic
  666.      */
  667.     private double findExtremeTables(int a, int b, XYList tableList) {
  668.         final int m = tableList.getMaxX();
  669.         final int n = tableList.getMaxY();
  670.         switch (method) {
  671.         case Z_POOLED:
  672.             return findExtremeTablesZ(a, b, m, n, true, tableList);
  673.         case Z_UNPOOLED:
  674.             return findExtremeTablesZ(a, b, m, n, false, tableList);
  675.         case BOSCHLOO:
  676.             return findExtremeTablesBoschloo(a, b, m, n, tableList);
  677.         default:
  678.             throw new IllegalStateException(String.valueOf(method));
  679.         }
  680.     }

  681.     /**
  682.      * Compute the statistic from a Z-test.
  683.      *
  684.      * @param a Observed value for a.
  685.      * @param b Observed value for b.
  686.      * @param m Column sum m.
  687.      * @param n Column sum n.
  688.      * @param pooled true to use a pooled variance.
  689.      * @return z
  690.      */
  691.     private static double statisticZ(int a, int b, int m, int n, boolean pooled) {
  692.         final double p0 = (double) a / m;
  693.         final double p1 = (double) b / n;
  694.         // Avoid NaN generation 0 / 0 when the variance is 0
  695.         if (p0 != p1) {
  696.             final double variance;
  697.             if (pooled) {
  698.                 // Integer sums will not overflow
  699.                 final double p = (double) (a + b) / (m + n);
  700.                 variance = p * (1 - p) * (1.0 / m + 1.0 / n);
  701.             } else {
  702.                 variance = p0 * (1 - p0) / m + p1 * (1 - p1) / n;
  703.             }
  704.             return (p0 - p1) / Math.sqrt(variance);
  705.         }
  706.         return 0;
  707.     }

  708.     /**
  709.      * Find all tables that are as or more extreme than the observed table using the Z statistic.
  710.      *
  711.      * @param a Observed value for a.
  712.      * @param b Observed value for b.
  713.      * @param m Column sum m.
  714.      * @param n Column sum n.
  715.      * @param pooled true to use a pooled variance.
  716.      * @param tableList List to track more extreme tables.
  717.      * @return observed z
  718.      */
  719.     private double findExtremeTablesZ(int a, int b, int m, int n, boolean pooled, XYList tableList) {
  720.         final double statistic = statisticZ(a, b, m, n, pooled);
  721.         // Identify more extreme tables using the alternate hypothesis
  722.         final DoublePredicate test;
  723.         if (alternative == AlternativeHypothesis.GREATER_THAN) {
  724.             test = z -> z >= statistic;
  725.         } else if (alternative == AlternativeHypothesis.LESS_THAN) {
  726.             test = z -> z <= statistic;
  727.         } else {
  728.             // two-sided
  729.             if (statistic == 0) {
  730.                 // Early exit: all tables are as extreme
  731.                 return 0;
  732.             }
  733.             final double za = Math.abs(statistic);
  734.             test = z -> Math.abs(z) >= za;
  735.         }
  736.         // Precompute factors
  737.         final double mn = (double) m + n;
  738.         final double norm = 1.0 / m + 1.0 / n;
  739.         double z;
  740.         // Process all possible tables
  741.         for (int i = 0; i <= m; i++) {
  742.             final double p0 = (double) i / m;
  743.             final double vp0 = p0 * (1 - p0) / m;
  744.             for (int j = 0; j <= n; j++) {
  745.                 final double p1 = (double) j / n;
  746.                 // Avoid NaN generation 0 / 0 when the variance is 0
  747.                 if (p0 == p1) {
  748.                     z = 0;
  749.                 } else {
  750.                     final double variance;
  751.                     if (pooled) {
  752.                         // Integer sums will not overflow
  753.                         final double p = (i + j) / mn;
  754.                         variance = p * (1 - p) * norm;
  755.                     } else {
  756.                         variance = vp0 + p1 * (1 - p1) / n;
  757.                     }
  758.                     z = (p0 - p1) / Math.sqrt(variance);
  759.                 }
  760.                 if (test.test(z)) {
  761.                     tableList.add(i, j);
  762.                 }
  763.             }
  764.         }
  765.         return statistic;
  766.     }

  767.     /**
  768.      * Compute the statistic using Fisher's p-value (also known as Boschloo's test).
  769.      *
  770.      * @param a Observed value for a.
  771.      * @param b Observed value for b.
  772.      * @param m Column sum m.
  773.      * @param n Column sum n.
  774.      * @return p-value
  775.      */
  776.     private double statisticBoschloo(int a, int b, int m, int n) {
  777.         final int nn = m + n;
  778.         final int k = a + b;
  779.         // Re-use the cached Hypergeometric implementation to allow the value
  780.         // to be identical for the statistic and test methods.
  781.         final Hypergeom dist = new Hypergeom(nn, k, m);
  782.         if (alternative == AlternativeHypothesis.GREATER_THAN) {
  783.             return dist.sf(a - 1);
  784.         } else if (alternative == AlternativeHypothesis.LESS_THAN) {
  785.             return dist.cdf(a);
  786.         }
  787.         // two-sided: Find all i where Pr(X = i) <= Pr(X = a) and sum them.
  788.         return statisticBoschlooTwoSided(dist, a);
  789.     }

  790.     /**
  791.      * Compute the two-sided statistic using Fisher's p-value (also known as Boschloo's test).
  792.      *
  793.      * @param distribution Hypergeometric distribution.
  794.      * @param k Observed value.
  795.      * @return p-value
  796.      */
  797.     private static double statisticBoschlooTwoSided(Hypergeom distribution, int k) {
  798.         // two-sided: Find all i where Pr(X = i) <= Pr(X = k) and sum them.
  799.         // Logic is the same as FisherExactTest but using the probability (PMF), which
  800.         // is cached, rather than the logProbability.
  801.         final double pk = distribution.pmf(k);

  802.         final int m1 = distribution.getLowerMode();
  803.         final int m2 = distribution.getUpperMode();
  804.         if (k < m1) {
  805.             // Lower half = cdf(k)
  806.             // Find upper half. As k < lower mode i should never
  807.             // reach the lower mode based on the probability alone.
  808.             // Bracket with the upper mode.
  809.             final int i = Searches.searchDescending(m2, distribution.getSupportUpperBound(), pk,
  810.                 distribution::pmf);
  811.             return distribution.cdf(k) +
  812.                    distribution.sf(i - 1);
  813.         } else if (k > m2) {
  814.             // Upper half = sf(k - 1)
  815.             // Find lower half. As k > upper mode i should never
  816.             // reach the upper mode based on the probability alone.
  817.             // Bracket with the lower mode.
  818.             final int i = Searches.searchAscending(distribution.getSupportLowerBound(), m1, pk,
  819.                 distribution::pmf);
  820.             return distribution.cdf(i) +
  821.                    distribution.sf(k - 1);
  822.         }
  823.         // k == mode
  824.         // Edge case where the sum of probabilities will be either
  825.         // 1 or 1 - Pr(X = mode) where mode != k
  826.         final double pm = distribution.pmf(k == m1 ? m2 : m1);
  827.         return pm > pk ? 1 - pm : 1;
  828.     }

  829.     /**
  830.      * Find all tables that are as or more extreme than the observed table using the
  831.      * Fisher's p-value as the statistic (also known as Boschloo's test).
  832.      *
  833.      * @param a Observed value for a.
  834.      * @param b Observed value for b.
  835.      * @param m Column sum m.
  836.      * @param n Column sum n.
  837.      * @param tableList List to track more extreme tables.
  838.      * @return observed p-value
  839.      */
  840.     private double findExtremeTablesBoschloo(int a, int b, int m, int n, XYList tableList) {
  841.         final double statistic = statisticBoschloo(a, b, m, n);

  842.         // Function to compute the statistic
  843.         final BoschlooStatistic func;
  844.         if (alternative == AlternativeHypothesis.GREATER_THAN) {
  845.             func = (dist, x) -> dist.sf(x - 1);
  846.         } else if (alternative == AlternativeHypothesis.LESS_THAN) {
  847.             func = Hypergeom::cdf;
  848.         } else {
  849.             func = UnconditionedExactTest::statisticBoschlooTwoSided;
  850.         }

  851.         // All tables are: 0 <= i <= m  by  0 <= j <= n
  852.         // Diagonal (upper-left to lower-right) strips of the possible
  853.         // tables use the same hypergeometric distribution
  854.         // (i.e. i+j == number of successes). To enumerate all requires
  855.         // using the full range of all distributions: 0 <= i+j <= m+n.
  856.         // Note the column sum m is fixed.
  857.         final int mn = m + n;
  858.         for (int k = 0; k <= mn; k++) {
  859.             final Hypergeom dist = new Hypergeom(mn, k, m);
  860.             final int lo = dist.getSupportLowerBound();
  861.             final int hi = dist.getSupportUpperBound();
  862.             for (int i = lo; i <= hi; i++) {
  863.                 if (func.value(dist, i) <= statistic) {
  864.                     // j = k - i
  865.                     tableList.add(i, k - i);
  866.                 }
  867.             }
  868.         }
  869.         return statistic;
  870.     }

  871.     /**
  872.      * Compute the nuisance parameter and p-value for the binomial model given the list
  873.      * of possible tables.
  874.      *
  875.      * <p>The current method enumerates an initial set of points and stores local
  876.      * extrema as candidates. Any candidate within 2% of the best is optionally
  877.      * optimized; this is limited to the top 3 candidates. These settings
  878.      * could be exposed as configurable options. Currently only the choice to optimize
  879.      * or not is exposed.
  880.      *
  881.      * @param tableList List of tables.
  882.      * @return [nuisance parameter, p-value]
  883.      */
  884.     private double[] computePValue(XYList tableList) {
  885.         final DoubleUnaryOperator func = createBinomialModel(tableList);

  886.         // Enumerate the range [LOWER, 1-LOWER] and save the best points for optimization
  887.         final Candidates minima = new Candidates(MAX_CANDIDATES, MINIMA_EPS);
  888.         final int n = points - 1;
  889.         final double inc = (1.0 - 2 * LOWER_BOUND) / n;
  890.         // Moving window of 3 values to identify minima.
  891.         // px holds the position of the previous evaluated point.
  892.         double v2 = 0;
  893.         double v3 = func.applyAsDouble(LOWER_BOUND);
  894.         double px = LOWER_BOUND;
  895.         for (int i = 1; i < n; i++) {
  896.             final double x = LOWER_BOUND + i * inc;
  897.             final double v1 = v2;
  898.             v2 = v3;
  899.             v3 = func.applyAsDouble(x);
  900.             addCandidate(minima, v1, v2, v3, px);
  901.             px = x;
  902.         }
  903.         // Add the upper bound
  904.         final double x = 1 - LOWER_BOUND;
  905.         final double vn = func.applyAsDouble(x);
  906.         addCandidate(minima, v2, v3, vn, px);
  907.         addCandidate(minima, v3, vn, 0, x);

  908.         final double[] min = minima.getMinimum();

  909.         // Optionally optimize the best point(s) (if not already optimal)
  910.         if (optimize && min[1] > -1) {
  911.             final BrentOptimizer opt = new BrentOptimizer(SOLVER_RELATIVE_EPS, Double.MIN_VALUE);
  912.             final BracketFinder bf = new BracketFinder();
  913.             minima.forEach(candidate -> {
  914.                 double a = candidate[0];
  915.                 final double fa;
  916.                 // Attempt to bracket the minima. Use an initial second point placed relative to
  917.                 // the size of the interval: [x - increment, x + increment].
  918.                 // if a < 0.5 then add a small delta ; otherwise subtract the delta.
  919.                 final double b = a - Math.copySign(inc * INC_FRACTION, a - 0.5);
  920.                 if (bf.search(func, a, b, 0, 1)) {
  921.                     // The bracket a < b < c must have f(b) < min(f(a), f(b))
  922.                     final PointValuePair p = opt.optimize(func, bf.getLo(), bf.getHi(), bf.getMid(), bf.getFMid());
  923.                     a = p.getPoint();
  924.                     fa = p.getValue();
  925.                 } else {
  926.                     // Mid-point is at one of the bounds (i.e. is 0 or 1)
  927.                     a = bf.getMid();
  928.                     fa = bf.getFMid();
  929.                 }
  930.                 if (fa < min[1]) {
  931.                     min[0] = a;
  932.                     min[1] = fa;
  933.                 }
  934.             });
  935.         }
  936.         // Reverse the sign of the p-value to create a maximum.
  937.         // Note that due to the summation the p-value can be above 1 so we clip the final result.
  938.         // Note: Apply max then reverse sign. This will pass through spurious NaN values if
  939.         // the p-value computation produced all NaNs.
  940.         min[1] = -Math.max(-1, min[1]);
  941.         return min;
  942.     }

  943.     /**
  944.      * Creates the binomial model p-value function for the nuisance parameter.
  945.      * Note: This function computes the negative p-value so is suitable for
  946.      * optimization by a search for a minimum.
  947.      *
  948.      * @param tableList List of tables.
  949.      * @return the function
  950.      */
  951.     private static DoubleUnaryOperator createBinomialModel(XYList tableList) {
  952.         final int m = tableList.getMaxX();
  953.         final int n = tableList.getMaxY();
  954.         final int mn = m + n;
  955.         // Compute the probability using logs
  956.         final double[] c = new double[tableList.size()];
  957.         final int[] ij = new int[tableList.size()];
  958.         final int width = tableList.getWidth();

  959.         // Compute the log binomial dynamically for a small number of values
  960.         final IntToDoubleFunction binomM;
  961.         final IntToDoubleFunction binomN;
  962.         if (tableList.size() < mn) {
  963.             binomM = k -> LogBinomialCoefficient.value(m, k);
  964.             binomN = k -> LogBinomialCoefficient.value(n, k);
  965.         } else {
  966.             // Pre-compute all values
  967.             binomM = createLogBinomialCoefficients(m);
  968.             binomN = m == n ? binomM : createLogBinomialCoefficients(n);
  969.         }

  970.         // Handle special cases i+j == 0 and i+j == m+n.
  971.         // These will occur only once, if at all. Mark if they occur.
  972.         int flag = 0;
  973.         int j = 0;
  974.         for (int i = 0; i < c.length; i++) {
  975.             final int index = tableList.get(i);
  976.             final int x = index % width;
  977.             final int y = index / width;
  978.             final int xy = x + y;
  979.             if (xy == 0) {
  980.                 flag |= 1;
  981.             } else if (xy == mn) {
  982.                 flag |= 2;
  983.             } else {
  984.                 ij[j] = xy;
  985.                 c[j] = binomM.applyAsDouble(x) + binomN.applyAsDouble(y);
  986.                 j++;
  987.             }
  988.         }

  989.         final int size = j;
  990.         final boolean ij0 = (flag & 1) != 0;
  991.         final boolean ijmn = (flag & 2) != 0;
  992.         return pi -> {
  993.             final double logp = Math.log(pi);
  994.             final double log1mp = Math.log1p(-pi);
  995.             double sum = 0;
  996.             for (int i = 0; i < size; i++) {
  997.                 // binom(m, i) * binom(n, j) * pi^(i+j) * (1-pi)^(m+n-i-j)
  998.                 sum += Math.exp(ij[i] * logp + (mn - ij[i]) * log1mp + c[i]);
  999.             }
  1000.             // Add the simplified terms where the binomial is 1.0 and one power is x^0 == 1.0.
  1001.             // This avoids 0 * log(x) generating NaN when x is 0 in the case where pi was 0 or 1.
  1002.             // Reuse exp (not pow) to support pi approaching 0 or 1.
  1003.             if (ij0) {
  1004.                 // pow(1-pi, mn)
  1005.                 sum += Math.exp(mn * log1mp);
  1006.             }
  1007.             if (ijmn) {
  1008.                 // pow(pi, mn)
  1009.                 sum += Math.exp(mn * logp);
  1010.             }
  1011.             // The optimizer minimises the function so this returns -p.
  1012.             return -sum;
  1013.         };
  1014.     }

  1015.     /**
  1016.      * Create the natural logarithm of the binomial coefficient for all {@code k = [0, n]}.
  1017.      *
  1018.      * @param n Limit N.
  1019.      * @return ln binom(n, k)
  1020.      */
  1021.     private static IntToDoubleFunction createLogBinomialCoefficients(int n) {
  1022.         final double[] binom = new double[n + 1];
  1023.         // Exploit symmetry.
  1024.         // ignore: binom(n, 0) == binom(n, n) == 1
  1025.         int j = n - 1;
  1026.         for (int i = 1; i <= j; i++, j--) {
  1027.             binom[i] = binom[j] = LogBinomialCoefficient.value(n, i);
  1028.         }
  1029.         return k -> binom[k];
  1030.     }

  1031.     /**
  1032.      * Add point 2 to the list of minima if neither neighbour value is lower.
  1033.      * <pre>
  1034.      * !(v1 < v2 || v3 < v2)
  1035.      * </pre>
  1036.      *
  1037.      * @param minima Candidate minima.
  1038.      * @param v1 First point function value.
  1039.      * @param v2 Second point function value.
  1040.      * @param v3 Third point function value.
  1041.      * @param x2 Second point.
  1042.      */
  1043.     private void addCandidate(Candidates minima, double v1, double v2, double v3, double x2) {
  1044.         final double min = v1 < v3 ? v1 : v3;
  1045.         if (min < v2) {
  1046.             // Lower neighbour(s)
  1047.             return;
  1048.         }
  1049.         // Add the candidate. This could be NaN but the candidate list handles this by storing
  1050.         // NaN only when no non-NaN values have been observed.
  1051.         minima.add(x2, v2);
  1052.     }

  1053.     /**
  1054.      * Check the input is a 2-by-2 contingency table.
  1055.      *
  1056.      * @param table Contingency table.
  1057.      * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
  1058.      * table entry is negative; any column sum is zero; the table sum is zero or not an
  1059.      * integer; or the number of possible tables exceeds the maximum array capacity.
  1060.      */
  1061.     private static void checkTable(int[][] table) {
  1062.         Arguments.checkTable(table);
  1063.         // Must all be positive
  1064.         final int a = table[0][0];
  1065.         final int c = table[1][0];
  1066.         // checkTable has validated the total sum is < 2^31
  1067.         final int m = a + c;
  1068.         if (m == 0) {
  1069.             throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 0);
  1070.         }
  1071.         final int b = table[0][1];
  1072.         final int d = table[1][1];
  1073.         final int n = b + d;
  1074.         if (n == 0) {
  1075.             throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 1);
  1076.         }
  1077.         // Total possible tables must be a size we can track in an array (to compute the p-value)
  1078.         final long size = (m + 1L) * (n + 1L);
  1079.         if (size > MAX_TABLES) {
  1080.             throw new InferenceException(InferenceException.X_GT_Y, size, MAX_TABLES);
  1081.         }
  1082.     }
  1083. }