1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.statistics.examples.jmh.distribution;
19
20 import java.util.SplittableRandom;
21 import java.util.concurrent.ThreadLocalRandom;
22 import java.util.concurrent.TimeUnit;
23 import java.util.function.DoubleUnaryOperator;
24 import org.apache.commons.numbers.rootfinder.BrentSolver;
25 import org.apache.commons.statistics.distribution.BetaDistribution;
26 import org.apache.commons.statistics.distribution.ChiSquaredDistribution;
27 import org.apache.commons.statistics.distribution.ContinuousDistribution;
28 import org.apache.commons.statistics.distribution.FDistribution;
29 import org.apache.commons.statistics.distribution.GammaDistribution;
30 import org.apache.commons.statistics.distribution.NakagamiDistribution;
31 import org.apache.commons.statistics.distribution.TDistribution;
32 import org.openjdk.jmh.annotations.Benchmark;
33 import org.openjdk.jmh.annotations.BenchmarkMode;
34 import org.openjdk.jmh.annotations.Fork;
35 import org.openjdk.jmh.annotations.Measurement;
36 import org.openjdk.jmh.annotations.Mode;
37 import org.openjdk.jmh.annotations.OutputTimeUnit;
38 import org.openjdk.jmh.annotations.Param;
39 import org.openjdk.jmh.annotations.Scope;
40 import org.openjdk.jmh.annotations.Setup;
41 import org.openjdk.jmh.annotations.State;
42 import org.openjdk.jmh.annotations.Warmup;
43
44 /**
45 * Executes a benchmark of inverse probability function operations
46 * (inverse cumulative distribution function (CDF) and inverse survival function (SF)).
47 */
48 @BenchmarkMode(Mode.AverageTime)
49 @OutputTimeUnit(TimeUnit.NANOSECONDS)
50 @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
51 @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
52 @State(Scope.Benchmark)
53 @Fork(value = 1, jvmArgs = {"-server", "-Xms512M", "-Xmx512M"})
54 public class InverseProbabilityPerformance {
55 /** No-operation for baseline. */
56 private static final String NOOP = "Noop";
57 /** Message prefix for an unknown function. */
58 private static final String UNKNOWN_FUNCTION = "unknown function: ";
59 /** Message prefix for an unknown distribution. */
60 private static final String UNKNOWN_DISTRIBUTION = "unknown distrbution: ";
61
/**
 * The seed for random number generation. Used to seed the source of random
 * probabilities so that the same numbers are generated for each implementation
 * of the function.
 *
 * <p>The value is drawn from {@link ThreadLocalRandom} in a static initialiser.
 * NOTE(review): a static initialiser runs once per JVM, so benchmarks executed in
 * separate forked JVMs presumably each draw their own seed - confirm whether
 * identical input across forks is required.
 */
private static final long SEED = ThreadLocalRandom.current().nextLong();
67
68 /**
69 * Contains the inverse function to benchmark.
70 */
71 @State(Scope.Benchmark)
72 public static class InverseData {
/**
 * The implementation of the function. Either the no-operation baseline or a
 * distribution encoded as {@code name:param1:param2:...}.
 */
@Param({NOOP,
        // Worst accuracy cases from STATISTICS-36
        "Beta:4:0.1",
        "ChiSquared:0.1",
        "F:5:6",
        "Gamma:4:2",
        "Nakagami:0.33333333333:1",
        "T:5",
        })
private String implementation;

/** The inversion relative accuracy. Passed to the inverter as the solver relative accuracy. */
@Param({
        // Default from o.a.c.math4.analysis.solvers.BaseAbstractUnivariateSolver
        "1e-14",
        // Lowest value so that 2 * eps * x is 1 ULP. Equal to 2^-53.
        "1.1102230246251565E-16"})
private double relEps;

/** The inversion absolute accuracy. Passed to the inverter as the solver absolute accuracy. */
@Param({
        // Default from o.a.c.math4.analysis.solvers.BaseAbstractUnivariateSolver
        "1e-9",
        // Lowest non-zero value. Equal to Double.MIN_VALUE.
        "4.9e-324"})
private double absEps;

/**
 * The function to invert: {@code "cdf"} for the cumulative probability or
 * {@code "sf"} for the survival probability.
 */
@Param({"cdf", "sf"})
private String invert;
104
/**
 * Source of randomness for probabilities. Values are taken from
 * {@code nextDouble()}, i.e. the half-open range {@code [0, 1)}.
 */
private SplittableRandom rng;

/** The inverse probability function. Maps a probability to a value; created in the setup method. */
private DoubleUnaryOperator function;
110
111 /**
112 * Create the next inversion of a probability.
113 *
114 * @return the result
115 */
116 public double next() {
117 return function.applyAsDouble(rng.nextDouble());
118 }
119
120 /**
121 * Create the source of random probability values and the inverse probability function.
122 */
123 @Setup
124 public void setup() {
125 // Creation with a seed ensures the increment uses the golden ratio
126 // with its known robust statistical properties. Creating with no
127 // seed will use a random increment.
128 rng = new SplittableRandom(SEED);
129 function = createFunction(implementation, relEps, absEps, invert);
130 }
131
132 /**
133 * Creates the inverse probability function.
134 *
135 * @param implementation Function implementation
136 * @param relativeAccuracy Inversion relative accuracy
137 * @param absoluteAccuracy Inversion absolute accuracy
138 * @param invert Function to invert
139 * @return the function
140 */
141 private static DoubleUnaryOperator createFunction(String implementation,
142 double relativeAccuracy,
143 double absoluteAccuracy,
144 String invert) {
145 if (implementation.startsWith(NOOP)) {
146 return x -> x;
147 }
148
149 // Create the distribution
150 final ContinuousDistribution dist = createDistribution(implementation);
151
152 // Get the function inverter
153 final ContinuousDistributionInverter inverter =
154 new ContinuousDistributionInverter(dist, relativeAccuracy, absoluteAccuracy);
155 // Support CDF and SF
156 if ("cdf".equals(invert)) {
157 return inverter::inverseCumulativeProbability;
158 } else if ("sf".equals(invert)) {
159 return inverter::inverseSurvivalProbability;
160 }
161 throw new IllegalStateException(UNKNOWN_FUNCTION + invert);
162 }
163
164 /**
165 * Creates the distribution.
166 *
167 * @param implementation Function implementation
168 * @return the continuous distribution
169 */
170 private static ContinuousDistribution createDistribution(String implementation) {
171 // Implementation is:
172 // distribution:param1:param2:...
173 final String[] parts = implementation.split(":");
174 if ("Beta".equals(parts[0])) {
175 return BetaDistribution.of(Double.parseDouble(parts[1]), Double.parseDouble(parts[2]));
176 } else if ("ChiSquared".equals(parts[0])) {
177 return ChiSquaredDistribution.of(Double.parseDouble(parts[1]));
178 } else if ("F".equals(parts[0])) {
179 return FDistribution.of(Double.parseDouble(parts[1]), Double.parseDouble(parts[2]));
180 } else if ("Gamma".equals(parts[0])) {
181 return GammaDistribution.of(Double.parseDouble(parts[1]), Double.parseDouble(parts[2]));
182 } else if ("Nakagami".equals(parts[0])) {
183 return NakagamiDistribution.of(Double.parseDouble(parts[1]), Double.parseDouble(parts[2]));
184 } else if ("T".equals(parts[0])) {
185 return TDistribution.of(Double.parseDouble(parts[1]));
186 }
187 throw new IllegalStateException(UNKNOWN_DISTRIBUTION + implementation);
188 }
189
190 /**
191 * Class to invert the cumulative or survival probability.
192 * This is based on the implementation in the AbstractContinuousDistribution class
193 * from Commons Statistics version 1.0.
194 */
195 static class ContinuousDistributionInverter {
/**
 * BrentSolver function value accuracy.
 * Set to a very low value ({@code Double.MIN_VALUE}) to search using Brent's method
 * unless the starting point is correct.
 */
private static final double SOLVER_FUNCTION_VALUE_ACCURACY = Double.MIN_VALUE;

/**
 * BrentSolver relative accuracy. This is used with {@code 2 * eps * abs(b)}
 * so the minimum non-zero value with an effect is half of machine epsilon (2^-53).
 */
private final double relativeAccuracy;
/** BrentSolver absolute accuracy. */
private final double absoluteAccuracy;
/** The distribution whose cumulative or survival probability is inverted. */
private final ContinuousDistribution dist;
208
/**
 * Creates an inverter for the given distribution. The accuracies are stored and
 * later passed to the root solver.
 *
 * @param dist The distribution to invert
 * @param relativeAccuracy Solver relative accuracy
 * @param absoluteAccuracy Solver absolute accuracy
 */
ContinuousDistributionInverter(ContinuousDistribution dist,
                               double relativeAccuracy,
                               double absoluteAccuracy) {
    this.relativeAccuracy = relativeAccuracy;
    this.absoluteAccuracy = absoluteAccuracy;
    this.dist = dist;
}
221
222 /**
223 * Checks if the value {@code x} is finite and strictly positive.
224 *
225 * @param x Value
226 * @return true if {@code x > 0} and is finite
227 */
228 private static boolean isFiniteStrictlyPositive(double x) {
229 return x > 0 && x < Double.POSITIVE_INFINITY;
230 }
231
232 /**
233 * Check the probability {@code p} is in the interval {@code [0, 1]}.
234 *
235 * @param p Probability
236 * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
237 */
238 private static void checkProbability(double p) {
239 if (p >= 0 && p <= 1) {
240 return;
241 }
242 // Out-of-range or NaN
243 throw new IllegalArgumentException("Invalid p: " + p);
244 }
245
246 /**
247 * Compute the inverse cumulative probability.
248 *
249 * @param p Probability
250 * @return the value
251 * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
252 */
253 public double inverseCumulativeProbability(double p) {
254 checkProbability(p);
255 return inverseProbability(p, 1 - p, false);
256 }
257
258 /**
259 * Compute the inverse survival probability.
260 *
261 * @param p Probability
262 * @return the value
263 * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
264 */
265 public double inverseSurvivalProbability(double p) {
266 checkProbability(p);
267 return inverseProbability(1 - p, p, true);
268 }
269
270 /**
271 * Implementation for the inverse cumulative or survival probability.
272 *
273 * @param p Cumulative probability.
274 * @param q Survival probability.
275 * @param complement Set to true to compute the inverse survival probability
276 * @return the value
277 */
278 private double inverseProbability(final double p, final double q, boolean complement) {
279 /* IMPLEMENTATION NOTES
280 * --------------------
281 * Where applicable, use is made of the one-sided Chebyshev inequality
282 * to bracket the root. This inequality states that
283 * P(X - mu >= k * sig) <= 1 / (1 + k^2),
284 * mu: mean, sig: standard deviation. Equivalently
285 * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2),
286 * F(mu + k * sig) >= k^2 / (1 + k^2).
287 *
288 * For k = sqrt(p / (1 - p)), we find
289 * F(mu + k * sig) >= p,
290 * and (mu + k * sig) is an upper-bound for the root.
291 *
292 * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and
293 * P(Y >= -mu + k * sig) <= 1 / (1 + k^2),
294 * P(-X >= -mu + k * sig) <= 1 / (1 + k^2),
295 * P(X <= mu - k * sig) <= 1 / (1 + k^2),
296 * F(mu - k * sig) <= 1 / (1 + k^2).
297 *
298 * For k = sqrt((1 - p) / p), we find
299 * F(mu - k * sig) <= p,
300 * and (mu - k * sig) is a lower-bound for the root.
301 *
302 * In cases where the Chebyshev inequality does not apply, geometric
303 * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket
304 * the root.
305 *
306 * In the case of the survival probability the bracket can be set using the same
307 * bound given that the argument p = 1 - q, with q the survival probability.
308 */
309
310 double lowerBound = dist.getSupportLowerBound();
311 if (p == 0) {
312 return lowerBound;
313 }
314 double upperBound = dist.getSupportUpperBound();
315 if (q == 0) {
316 return upperBound;
317 }
318
319 final double mu = dist.getMean();
320 final double sig = Math.sqrt(dist.getVariance());
321 final boolean chebyshevApplies = Double.isFinite(mu) &&
322 isFiniteStrictlyPositive(sig);
323
324 if (lowerBound == Double.NEGATIVE_INFINITY) {
325 lowerBound = createFiniteLowerBound(p, q, complement, upperBound, mu, sig, chebyshevApplies);
326 }
327
328 if (upperBound == Double.POSITIVE_INFINITY) {
329 upperBound = createFiniteUpperBound(p, q, complement, lowerBound, mu, sig, chebyshevApplies);
330 }
331
332 // Here the bracket [lower, upper] uses finite values. If the support
333 // is infinite the bracket can truncate the distribution and the target
334 // probability can be outside the range of [lower, upper].
335 if (upperBound == Double.MAX_VALUE) {
336 if (complement) {
337 if (dist.survivalProbability(upperBound) > q) {
338 return dist.getSupportUpperBound();
339 }
340 } else if (dist.cumulativeProbability(upperBound) < p) {
341 return dist.getSupportUpperBound();
342 }
343 }
344 if (lowerBound == -Double.MAX_VALUE) {
345 if (complement) {
346 if (dist.survivalProbability(lowerBound) < q) {
347 return dist.getSupportLowerBound();
348 }
349 } else if (dist.cumulativeProbability(lowerBound) > p) {
350 return dist.getSupportLowerBound();
351 }
352 }
353
354 final DoubleUnaryOperator fun = complement ?
355 arg -> dist.survivalProbability(arg) - q :
356 arg -> dist.cumulativeProbability(arg) - p;
357 // Note the initial value is robust to overflow.
358 // Do not use 0.5 * (lowerBound + upperBound).
359 final double x = new BrentSolver(relativeAccuracy,
360 absoluteAccuracy,
361 SOLVER_FUNCTION_VALUE_ACCURACY)
362 .findRoot(fun,
363 lowerBound,
364 lowerBound + 0.5 * (upperBound - lowerBound),
365 upperBound);
366
367 return x;
368 }
369
370 /**
371 * Create a finite lower bound. Assumes the current lower bound is negative infinity.
372 *
373 * @param p Cumulative probability.
374 * @param q Survival probability.
375 * @param complement Set to true to compute the inverse survival probability
376 * @param upperBound Current upper bound
377 * @param mu Mean
378 * @param sig Standard deviation
379 * @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
380 * @return the finite lower bound
381 */
382 private double createFiniteLowerBound(final double p, final double q, boolean complement,
383 double upperBound, final double mu, final double sig, final boolean chebyshevApplies) {
384 double lowerBound;
385 if (chebyshevApplies) {
386 lowerBound = mu - sig * Math.sqrt(q / p);
387 } else {
388 lowerBound = Double.NEGATIVE_INFINITY;
389 }
390 // Bound may have been set as infinite
391 if (lowerBound == Double.NEGATIVE_INFINITY) {
392 lowerBound = Math.min(-1, upperBound);
393 if (complement) {
394 while (dist.survivalProbability(lowerBound) < q) {
395 lowerBound *= 2;
396 }
397 } else {
398 while (dist.cumulativeProbability(lowerBound) >= p) {
399 lowerBound *= 2;
400 }
401 }
402 // Ensure finite
403 lowerBound = Math.max(lowerBound, -Double.MAX_VALUE);
404 }
405 return lowerBound;
406 }
407
408 /**
409 * Create a finite upper bound. Assumes the current upper bound is positive infinity.
410 *
411 * @param p Cumulative probability.
412 * @param q Survival probability.
413 * @param complement Set to true to compute the inverse survival probability
414 * @param lowerBound Current lower bound
415 * @param mu Mean
416 * @param sig Standard deviation
417 * @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
418 * @return the finite lower bound
419 */
420 private double createFiniteUpperBound(final double p, final double q, boolean complement,
421 double lowerBound, final double mu, final double sig, final boolean chebyshevApplies) {
422 double upperBound;
423 if (chebyshevApplies) {
424 upperBound = mu + sig * Math.sqrt(p / q);
425 } else {
426 upperBound = Double.POSITIVE_INFINITY;
427 }
428 // Bound may have been set as infinite
429 if (upperBound == Double.POSITIVE_INFINITY) {
430 upperBound = Math.max(1, lowerBound);
431 if (complement) {
432 while (dist.survivalProbability(upperBound) >= q) {
433 upperBound *= 2;
434 }
435 } else {
436 while (dist.cumulativeProbability(upperBound) < p) {
437 upperBound *= 2;
438 }
439 }
440 // Ensure finite
441 upperBound = Math.min(upperBound, Double.MAX_VALUE);
442 }
443 return upperBound;
444 }
445 }
446 }
447
448 /**
449 * Benchmark the inverse function.
450 *
451 * @param data Test data.
452 * @return the inverse function value
453 */
454 @Benchmark
455 public double inverse(InverseData data) {
456 return data.next();
457 }
458 }