View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.statistics.examples.jmh.distribution;
19  
20  import java.util.SplittableRandom;
21  import java.util.concurrent.ThreadLocalRandom;
22  import java.util.concurrent.TimeUnit;
23  import java.util.function.DoubleUnaryOperator;
24  import org.apache.commons.numbers.rootfinder.BrentSolver;
25  import org.apache.commons.statistics.distribution.BetaDistribution;
26  import org.apache.commons.statistics.distribution.ChiSquaredDistribution;
27  import org.apache.commons.statistics.distribution.ContinuousDistribution;
28  import org.apache.commons.statistics.distribution.FDistribution;
29  import org.apache.commons.statistics.distribution.GammaDistribution;
30  import org.apache.commons.statistics.distribution.NakagamiDistribution;
31  import org.apache.commons.statistics.distribution.TDistribution;
32  import org.openjdk.jmh.annotations.Benchmark;
33  import org.openjdk.jmh.annotations.BenchmarkMode;
34  import org.openjdk.jmh.annotations.Fork;
35  import org.openjdk.jmh.annotations.Measurement;
36  import org.openjdk.jmh.annotations.Mode;
37  import org.openjdk.jmh.annotations.OutputTimeUnit;
38  import org.openjdk.jmh.annotations.Param;
39  import org.openjdk.jmh.annotations.Scope;
40  import org.openjdk.jmh.annotations.Setup;
41  import org.openjdk.jmh.annotations.State;
42  import org.openjdk.jmh.annotations.Warmup;
43  
44  /**
45   * Executes a benchmark of inverse probability function operations
46   * (inverse cumulative distribution function (CDF) and inverse survival function (SF)).
47   */
48  @BenchmarkMode(Mode.AverageTime)
49  @OutputTimeUnit(TimeUnit.NANOSECONDS)
50  @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
51  @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
52  @State(Scope.Benchmark)
53  @Fork(value = 1, jvmArgs = {"-server", "-Xms512M", "-Xmx512M"})
54  public class InverseProbabilityPerformance {
55      /** No-operation for baseline. */
56      private static final String NOOP = "Noop";
57      /** Message prefix for an unknown function. */
58      private static final String UNKNOWN_FUNCTION = "unknown function: ";
59      /** Message prefix for an unknown distribution. */
60      private static final String UNKNOWN_DISTRIBUTION = "unknown distrbution: ";
61  
62      /**
63       * The seed for random number generation. Ensures the same numbers are generated
64       * for each implementation of the function.
65       */
66      private static final long SEED = ThreadLocalRandom.current().nextLong();
67  
68      /**
69       * Contains the inverse function to benchmark.
70       */
71      @State(Scope.Benchmark)
72      public static class InverseData {
73          /** The implementation of the function. */
74          @Param({NOOP,
75              // Worst accuracy cases from STATISTICS-36
76              "Beta:4:0.1",
77              "ChiSquared:0.1",
78              "F:5:6",
79              "Gamma:4:2",
80              "Nakagami:0.33333333333:1",
81              "T:5",
82          })
83          private String implementation;
84  
85          /** The inversion relative accuracy. */
86          @Param({
87              // Default from o.a.c.math4.analysis.solvers.BaseAbstractUnivariateSolver
88              "1e-14",
89              // Lowest value so that 2 * eps * x is 1 ULP. Equal to 2^-53.
90              "1.1102230246251565E-16"})
91          private double relEps;
92  
93          /** The inversion absolute accuracy. */
94          @Param({
95              // Default from o.a.c.math4.analysis.solvers.BaseAbstractUnivariateSolver
96              "1e-9",
97              // Lowest non-zero value. Equal to Double.MIN_VALUE.
98              "4.9e-324"})
99          private double absEps;
100 
101         /** The function to invert. */
102         @Param({"cdf", "sf"})
103         private String invert;
104 
105         /** Source of randomness for probabilities in the range [0, 1]. */
106         private SplittableRandom rng;
107 
108         /** The inverse probability function. */
109         private DoubleUnaryOperator function;
110 
111         /**
112          * Create the next inversion of a probability.
113          *
114          * @return the result
115          */
116         public double next() {
117             return function.applyAsDouble(rng.nextDouble());
118         }
119 
120         /**
121          * Create the source of random probability values and the inverse probability function.
122          */
123         @Setup
124         public void setup() {
125             // Creation with a seed ensures the increment uses the golden ratio
126             // with its known robust statistical properties. Creating with no
127             // seed will use a random increment.
128             rng = new SplittableRandom(SEED);
129             function = createFunction(implementation, relEps, absEps, invert);
130         }
131 
132         /**
133          * Creates the inverse probability function.
134          *
135          * @param implementation Function implementation
136          * @param relativeAccuracy Inversion relative accuracy
137          * @param absoluteAccuracy Inversion absolute accuracy
138          * @param invert Function to invert
139          * @return the function
140          */
141         private static DoubleUnaryOperator createFunction(String implementation,
142                                                           double relativeAccuracy,
143                                                           double absoluteAccuracy,
144                                                           String invert) {
145             if (implementation.startsWith(NOOP)) {
146                 return x -> x;
147             }
148 
149             // Create the distribution
150             final ContinuousDistribution dist = createDistribution(implementation);
151 
152             // Get the function inverter
153             final ContinuousDistributionInverter inverter =
154                 new ContinuousDistributionInverter(dist, relativeAccuracy, absoluteAccuracy);
155             // Support CDF and SF
156             if ("cdf".equals(invert)) {
157                 return inverter::inverseCumulativeProbability;
158             } else if ("sf".equals(invert)) {
159                 return inverter::inverseSurvivalProbability;
160             }
161             throw new IllegalStateException(UNKNOWN_FUNCTION + invert);
162         }
163 
164         /**
165          * Creates the distribution.
166          *
167          * @param implementation Function implementation
168          * @return the continuous distribution
169          */
170         private static ContinuousDistribution createDistribution(String implementation) {
171             // Implementation is:
172             // distribution:param1:param2:...
173             final String[] parts = implementation.split(":");
174             if ("Beta".equals(parts[0])) {
175                 return BetaDistribution.of(Double.parseDouble(parts[1]), Double.parseDouble(parts[2]));
176             } else if ("ChiSquared".equals(parts[0])) {
177                 return ChiSquaredDistribution.of(Double.parseDouble(parts[1]));
178             } else if ("F".equals(parts[0])) {
179                 return FDistribution.of(Double.parseDouble(parts[1]), Double.parseDouble(parts[2]));
180             } else if ("Gamma".equals(parts[0])) {
181                 return GammaDistribution.of(Double.parseDouble(parts[1]), Double.parseDouble(parts[2]));
182             } else if ("Nakagami".equals(parts[0])) {
183                 return NakagamiDistribution.of(Double.parseDouble(parts[1]), Double.parseDouble(parts[2]));
184             } else if ("T".equals(parts[0])) {
185                 return TDistribution.of(Double.parseDouble(parts[1]));
186             }
187             throw new IllegalStateException(UNKNOWN_DISTRIBUTION + implementation);
188         }
189 
190         /**
191          * Class to invert the cumulative or survival probability.
192          * This is based on the implementation in the AbstractContinuousDistribution class
193          * from Commons Statistics version 1.0.
194          */
195         static class ContinuousDistributionInverter {
196             /** BrentSolver function value accuracy.
197              * Set to a very low value to search using Brent's method unless
198              * the starting point is correct. */
199             private static final double SOLVER_FUNCTION_VALUE_ACCURACY = Double.MIN_VALUE;
200 
201             /** BrentSolver relative accuracy. This is used with {@code 2 * eps * abs(b)}
202              * so the minimum non-zero value with an effect is half of machine epsilon (2^-53). */
203             private final double relativeAccuracy;
204             /** BrentSolver absolute accuracy. */
205             private final double absoluteAccuracy;
206             /** The distribution. */
207             private final ContinuousDistribution dist;
208 
209             /**
210              * @param dist The distribution to invert
211              * @param relativeAccuracy Solver relative accuracy
212              * @param absoluteAccuracy Solver absolute accuracy
213              */
214             ContinuousDistributionInverter(ContinuousDistribution dist,
215                                            double relativeAccuracy,
216                                            double absoluteAccuracy) {
217                 this.dist = dist;
218                 this.relativeAccuracy = relativeAccuracy;
219                 this.absoluteAccuracy = absoluteAccuracy;
220             }
221 
222             /**
223              * Checks if the value {@code x} is finite and strictly positive.
224              *
225              * @param x Value
226              * @return true if {@code x > 0} and is finite
227              */
228             private static boolean isFiniteStrictlyPositive(double x) {
229                 return x > 0 && x < Double.POSITIVE_INFINITY;
230             }
231 
232             /**
233              * Check the probability {@code p} is in the interval {@code [0, 1]}.
234              *
235              * @param p Probability
236              * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
237              */
238             private static void checkProbability(double p) {
239                 if (p >= 0 && p <= 1) {
240                     return;
241                 }
242                 // Out-of-range or NaN
243                 throw new IllegalArgumentException("Invalid p: " + p);
244             }
245 
246             /**
247              * Compute the inverse cumulative probability.
248              *
249              * @param p Probability
250              * @return the value
251              * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
252              */
253             public double inverseCumulativeProbability(double p) {
254                 checkProbability(p);
255                 return inverseProbability(p, 1 - p, false);
256             }
257 
258             /**
259              * Compute the inverse survival probability.
260              *
261              * @param p Probability
262              * @return the value
263              * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
264              */
265             public double inverseSurvivalProbability(double p) {
266                 checkProbability(p);
267                 return inverseProbability(1 - p, p, true);
268             }
269 
270             /**
271              * Implementation for the inverse cumulative or survival probability.
272              *
273              * @param p Cumulative probability.
274              * @param q Survival probability.
275              * @param complement Set to true to compute the inverse survival probability
276              * @return the value
277              */
278             private double inverseProbability(final double p, final double q, boolean complement) {
279                 /* IMPLEMENTATION NOTES
280                  * --------------------
281                  * Where applicable, use is made of the one-sided Chebyshev inequality
282                  * to bracket the root. This inequality states that
283                  * P(X - mu >= k * sig) <= 1 / (1 + k^2),
284                  * mu: mean, sig: standard deviation. Equivalently
285                  * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2),
286                  * F(mu + k * sig) >= k^2 / (1 + k^2).
287                  *
288                  * For k = sqrt(p / (1 - p)), we find
289                  * F(mu + k * sig) >= p,
290                  * and (mu + k * sig) is an upper-bound for the root.
291                  *
292                  * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and
293                  * P(Y >= -mu + k * sig) <= 1 / (1 + k^2),
294                  * P(-X >= -mu + k * sig) <= 1 / (1 + k^2),
295                  * P(X <= mu - k * sig) <= 1 / (1 + k^2),
296                  * F(mu - k * sig) <= 1 / (1 + k^2).
297                  *
298                  * For k = sqrt((1 - p) / p), we find
299                  * F(mu - k * sig) <= p,
300                  * and (mu - k * sig) is a lower-bound for the root.
301                  *
302                  * In cases where the Chebyshev inequality does not apply, geometric
303                  * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket
304                  * the root.
305                  *
306                  * In the case of the survival probability the bracket can be set using the same
307                  * bound given that the argument p = 1 - q, with q the survival probability.
308                  */
309 
310                 double lowerBound = dist.getSupportLowerBound();
311                 if (p == 0) {
312                     return lowerBound;
313                 }
314                 double upperBound = dist.getSupportUpperBound();
315                 if (q == 0) {
316                     return upperBound;
317                 }
318 
319                 final double mu = dist.getMean();
320                 final double sig = Math.sqrt(dist.getVariance());
321                 final boolean chebyshevApplies = Double.isFinite(mu) &&
322                                                  isFiniteStrictlyPositive(sig);
323 
324                 if (lowerBound == Double.NEGATIVE_INFINITY) {
325                     lowerBound = createFiniteLowerBound(p, q, complement, upperBound, mu, sig, chebyshevApplies);
326                 }
327 
328                 if (upperBound == Double.POSITIVE_INFINITY) {
329                     upperBound = createFiniteUpperBound(p, q, complement, lowerBound, mu, sig, chebyshevApplies);
330                 }
331 
332                 // Here the bracket [lower, upper] uses finite values. If the support
333                 // is infinite the bracket can truncate the distribution and the target
334                 // probability can be outside the range of [lower, upper].
335                 if (upperBound == Double.MAX_VALUE) {
336                     if (complement) {
337                         if (dist.survivalProbability(upperBound) > q) {
338                             return dist.getSupportUpperBound();
339                         }
340                     } else if (dist.cumulativeProbability(upperBound) < p) {
341                         return dist.getSupportUpperBound();
342                     }
343                 }
344                 if (lowerBound == -Double.MAX_VALUE) {
345                     if (complement) {
346                         if (dist.survivalProbability(lowerBound) < q) {
347                             return dist.getSupportLowerBound();
348                         }
349                     } else if (dist.cumulativeProbability(lowerBound) > p) {
350                         return dist.getSupportLowerBound();
351                     }
352                 }
353 
354                 final DoubleUnaryOperator fun = complement ?
355                     arg -> dist.survivalProbability(arg) - q :
356                     arg -> dist.cumulativeProbability(arg) - p;
357                 // Note the initial value is robust to overflow.
358                 // Do not use 0.5 * (lowerBound + upperBound).
359                 final double x = new BrentSolver(relativeAccuracy,
360                                                  absoluteAccuracy,
361                                                  SOLVER_FUNCTION_VALUE_ACCURACY)
362                     .findRoot(fun,
363                               lowerBound,
364                               lowerBound + 0.5 * (upperBound - lowerBound),
365                               upperBound);
366 
367                 return x;
368             }
369 
370             /**
371              * Create a finite lower bound. Assumes the current lower bound is negative infinity.
372              *
373              * @param p Cumulative probability.
374              * @param q Survival probability.
375              * @param complement Set to true to compute the inverse survival probability
376              * @param upperBound Current upper bound
377              * @param mu Mean
378              * @param sig Standard deviation
379              * @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
380              * @return the finite lower bound
381              */
382             private double createFiniteLowerBound(final double p, final double q, boolean complement,
383                 double upperBound, final double mu, final double sig, final boolean chebyshevApplies) {
384                 double lowerBound;
385                 if (chebyshevApplies) {
386                     lowerBound = mu - sig * Math.sqrt(q / p);
387                 } else {
388                     lowerBound = Double.NEGATIVE_INFINITY;
389                 }
390                 // Bound may have been set as infinite
391                 if (lowerBound == Double.NEGATIVE_INFINITY) {
392                     lowerBound = Math.min(-1, upperBound);
393                     if (complement) {
394                         while (dist.survivalProbability(lowerBound) < q) {
395                             lowerBound *= 2;
396                         }
397                     } else {
398                         while (dist.cumulativeProbability(lowerBound) >= p) {
399                             lowerBound *= 2;
400                         }
401                     }
402                     // Ensure finite
403                     lowerBound = Math.max(lowerBound, -Double.MAX_VALUE);
404                 }
405                 return lowerBound;
406             }
407 
408             /**
409              * Create a finite upper bound. Assumes the current upper bound is positive infinity.
410              *
411              * @param p Cumulative probability.
412              * @param q Survival probability.
413              * @param complement Set to true to compute the inverse survival probability
414              * @param lowerBound Current lower bound
415              * @param mu Mean
416              * @param sig Standard deviation
417              * @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
418              * @return the finite lower bound
419              */
420             private double createFiniteUpperBound(final double p, final double q, boolean complement,
421                 double lowerBound, final double mu, final double sig, final boolean chebyshevApplies) {
422                 double upperBound;
423                 if (chebyshevApplies) {
424                     upperBound = mu + sig * Math.sqrt(p / q);
425                 } else {
426                     upperBound = Double.POSITIVE_INFINITY;
427                 }
428                 // Bound may have been set as infinite
429                 if (upperBound == Double.POSITIVE_INFINITY) {
430                     upperBound = Math.max(1, lowerBound);
431                     if (complement) {
432                         while (dist.survivalProbability(upperBound) >= q) {
433                             upperBound *= 2;
434                         }
435                     } else {
436                         while (dist.cumulativeProbability(upperBound) < p) {
437                             upperBound *= 2;
438                         }
439                     }
440                     // Ensure finite
441                     upperBound = Math.min(upperBound, Double.MAX_VALUE);
442                 }
443                 return upperBound;
444             }
445         }
446     }
447 
448     /**
449      * Benchmark the inverse function.
450      *
451      * @param data Test data.
452      * @return the inverse function value
453      */
454     @Benchmark
455     public double inverse(InverseData data) {
456         return data.next();
457     }
458 }