View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.examples.sampling;
19  
20  import java.io.PrintWriter;
21  import java.util.EnumSet;
22  import java.util.concurrent.Callable;
23  import java.io.IOException;
24  import org.apache.commons.rng.UniformRandomProvider;
25  import org.apache.commons.rng.simple.RandomSource;
26  
27  import picocli.CommandLine.Command;
28  import picocli.CommandLine.Mixin;
29  import picocli.CommandLine.Option;
30  
31  import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
32  import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
33  import org.apache.commons.rng.sampling.distribution.MarsagliaNormalizedGaussianSampler;
34  import org.apache.commons.rng.sampling.distribution.StableSampler;
35  import org.apache.commons.rng.sampling.distribution.TSampler;
36  import org.apache.commons.rng.sampling.distribution.BoxMullerNormalizedGaussianSampler;
37  import org.apache.commons.rng.sampling.distribution.ChengBetaSampler;
38  import org.apache.commons.rng.sampling.distribution.AhrensDieterExponentialSampler;
39  import org.apache.commons.rng.sampling.distribution.AhrensDieterMarsagliaTsangGammaSampler;
40  import org.apache.commons.rng.sampling.distribution.InverseTransformParetoSampler;
41  import org.apache.commons.rng.sampling.distribution.LevySampler;
42  import org.apache.commons.rng.sampling.distribution.LogNormalSampler;
43  import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
44  import org.apache.commons.rng.sampling.distribution.GaussianSampler;
45  import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
46  
47  /**
48   * Approximation of the probability density by the histogram of the sampler output.
49   */
50  @Command(name = "density",
51           description = {"Approximate the probability density of samplers."})
52  class ProbabilityDensityApproximationCommand implements Callable<Void> {
53      /** The standard options. */
54      @Mixin
55      private StandardOptions reusableOptions;
56  
57      /** Number of (equal-width) bins in the histogram. */
58      @Option(names = {"-b", "--bins"},
59              description = "The number of bins in the histogram (default: ${DEFAULT-VALUE}).")
60      private int numBins = 25_000;
61  
62      /** Number of samples to be generated. */
63      @Option(names = {"-n", "--samples"},
64              description = "The number of samples in the histogram (default: ${DEFAULT-VALUE}).")
65      private long numSamples = 1_000_000_000;
66  
67      /** The samplers. */
68      @Option(names = {"-s", "--samplers"},
69              split = ",",
70              description = {"The samplers (comma-delimited for multiple options).",
71                             "Valid values: ${COMPLETION-CANDIDATES}."})
72      private EnumSet<Sampler> samplers = EnumSet.noneOf(Sampler.class);
73  
74      /** The samplers. */
75      @Option(names = {"-r", "--rng"},
76              description = {"The source of randomness (default: ${DEFAULT-VALUE})."})
77      private RandomSource randomSource = RandomSource.XOR_SHIFT_1024_S_PHI;
78  
79      /** Flag to output all samplers. */
80      @Option(names = {"-a", "--all"},
81              description = "Output all samplers")
82      private boolean allSamplers;
83  
84      /**
85       * The sampler. This enum uses lower case for clarity when matching the distribution name.
86       */
87      enum Sampler {
88          /** The ziggurat gaussian sampler. */
89          ZigguratGaussianSampler,
90          /** The Marsaglia gaussian sampler. */
91          MarsagliaGaussianSampler,
92          /** The Box Muller gaussian sampler. */
93          BoxMullerGaussianSampler,
94          /** The modified ziggurat gaussian sampler. */
95          ModifiedZigguratGaussianSampler,
96          /** The Cheng beta sampler case 1. */
97          ChengBetaSamplerCase1,
98          /** The Cheng beta sampler case 2. */
99          ChengBetaSamplerCase2,
100         /** The Ahrens Dieter exponential sampler. */
101         AhrensDieterExponentialSampler,
102         /** The modified ziggurat exponential sampler. */
103         ModifiedZigguratExponentialSampler,
104         /** The Ahrens Dieter Marsaglia Tsang gamma sampler small gamma. */
105         AhrensDieterMarsagliaTsangGammaSamplerCase1,
106         /** The Ahrens Dieter Marsaglia Tsang gamma sampler large gamma. */
107         AhrensDieterMarsagliaTsangGammaSamplerCase2,
108         /** The inverse transform pareto sampler. */
109         InverseTransformParetoSampler,
110         /** The continuous uniform sampler. */
111         ContinuousUniformSampler,
112         /** The log normal ziggurat gaussian sampler. */
113         LogNormalZigguratGaussianSampler,
114         /** The log normal Marsaglia gaussian sampler. */
115         LogNormalMarsagliaGaussianSampler,
116         /** The log normal Box Muller gaussian sampler. */
117         LogNormalBoxMullerGaussianSampler,
118         /** The log normal modified ziggurat gaussian sampler. */
119         LogNormalModifiedZigguratGaussianSampler,
120         /** The Levy sampler. */
121         LevySampler,
122         /** The stable sampler. */
123         StableSampler,
124         /** The t sampler. */
125         TSampler,
126     }
127 
128     /**
129      * @param sampler Sampler.
130      * @param min Right abscissa of the first bin: every sample smaller
131      * than that value will increment an additional bin (of infinite width)
132      * placed before the first "equal-width" bin.
133      * @param max abscissa of the last bin: every sample larger than or
134      * equal to that value will increment an additional bin (of infinite
135      * width) placed after the last "equal-width" bin.
136      * @param outputFile Filename (final name is "pdf.[filename].txt").
137      * @throws IOException Signals that an I/O exception has occurred.
138      */
139     private void createDensity(ContinuousSampler sampler,
140                                double min,
141                                double max,
142                                String outputFile)
143         throws IOException {
144         final double binSize = (max - min) / numBins;
145         final long[] histogram = new long[numBins];
146 
147         long belowMin = 0;
148         long aboveMax = 0;
149         for (long n = 0; n < numSamples; n++) {
150             final double r = sampler.sample();
151 
152             if (r < min) {
153                 ++belowMin;
154                 continue;
155             }
156 
157             if (r >= max) {
158                 ++aboveMax;
159                 continue;
160             }
161 
162             final int binIndex = (int) ((r - min) / binSize);
163             ++histogram[binIndex];
164         }
165 
166         final double binHalfSize = 0.5 * binSize;
167         final double norm = 1 / (binSize * numSamples);
168 
169         try (PrintWriter out = new PrintWriter("pdf." + outputFile + ".txt", "UTF-8")) {
170             // CHECKSTYLE: stop MultipleStringLiteralsCheck
171             out.println("# Sampler: " + sampler);
172             out.println("# Number of bins: " + numBins);
173             out.println("# Min: " + min + " (fraction of samples below: " + (belowMin / (double) numSamples) + ")");
174             out.println("# Max: " + max + " (fraction of samples above: " + (aboveMax / (double) numSamples) + ")");
175             out.println("# Bin width: " + binSize);
176             out.println("# Histogram normalization factor: " + norm);
177             out.println("#");
178             out.println("# " + (min - binHalfSize) + " " + (belowMin * norm));
179             for (int i = 0; i < numBins; i++) {
180                 out.println((min + (i + 1) * binSize - binHalfSize) + " " + (histogram[i] * norm));
181             }
182             out.println("# " + (max + binHalfSize) + " " + (aboveMax * norm));
183             // CHECKSTYLE: resume MultipleStringLiteralsCheck
184         }
185     }
186 
187     /**
188      * Program entry point.
189      *
190      * @throws IOException if failure occurred while writing to files.
191      */
192     @Override
193     public Void call() throws IOException {
194         if (allSamplers) {
195             samplers = EnumSet.allOf(Sampler.class);
196         } else if (samplers.isEmpty()) {
197             // CHECKSTYLE: stop regexp
198             System.err.println("ERROR: No samplers specified");
199             // CHECKSTYLE: resume regexp
200             System.exit(1);
201         }
202 
203         final UniformRandomProvider rng = randomSource.create();
204 
205         final double gaussMean = 1;
206         final double gaussSigma = 2;
207         final double gaussMin = -9;
208         final double gaussMax = 11;
209         if (samplers.contains(Sampler.ZigguratGaussianSampler)) {
210             createDensity(GaussianSampler.of(ZigguratNormalizedGaussianSampler.of(rng),
211                                              gaussMean, gaussSigma),
212                           gaussMin, gaussMax, "gauss.ziggurat");
213         }
214         if (samplers.contains(Sampler.MarsagliaGaussianSampler)) {
215             createDensity(GaussianSampler.of(MarsagliaNormalizedGaussianSampler.of(rng),
216                                              gaussMean, gaussSigma),
217                           gaussMin, gaussMax, "gauss.marsaglia");
218         }
219         if (samplers.contains(Sampler.BoxMullerGaussianSampler)) {
220             createDensity(GaussianSampler.of(BoxMullerNormalizedGaussianSampler.of(rng),
221                                              gaussMean, gaussSigma),
222                           gaussMin, gaussMax, "gauss.boxmuller");
223         }
224         if (samplers.contains(Sampler.ModifiedZigguratGaussianSampler)) {
225             createDensity(GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
226                                              gaussMean, gaussSigma),
227                           gaussMin, gaussMax, "gauss.modified.ziggurat");
228         }
229 
230         final double betaMin = 0;
231         final double betaMax = 1;
232         if (samplers.contains(Sampler.ChengBetaSamplerCase1)) {
233             final double alphaBeta = 4.3;
234             final double betaBeta = 2.1;
235             createDensity(ChengBetaSampler.of(rng, alphaBeta, betaBeta),
236                           betaMin, betaMax, "beta.case1");
237         }
238         if (samplers.contains(Sampler.ChengBetaSamplerCase2)) {
239             final double alphaBetaAlt = 0.5678;
240             final double betaBetaAlt = 0.1234;
241             createDensity(ChengBetaSampler.of(rng, alphaBetaAlt, betaBetaAlt),
242                           betaMin, betaMax, "beta.case2");
243         }
244 
245         final double meanExp = 3.45;
246         final double expMin = 0;
247         final double expMax = 60;
248         if (samplers.contains(Sampler.AhrensDieterExponentialSampler)) {
249             createDensity(AhrensDieterExponentialSampler.of(rng, meanExp),
250                           expMin, expMax, "exp");
251         }
252         if (samplers.contains(Sampler.ModifiedZigguratExponentialSampler)) {
253             createDensity(ZigguratSampler.Exponential.of(rng, meanExp),
254                           expMin, expMax, "exp.modified.ziggurat");
255         }
256 
257         final double gammaMin = 0;
258         final double gammaMax1 = 40;
259         final double thetaGamma = 3.456;
260         if (samplers.contains(Sampler.AhrensDieterMarsagliaTsangGammaSamplerCase1)) {
261             final double alphaGammaSmallerThanOne = 0.1234;
262             createDensity(AhrensDieterMarsagliaTsangGammaSampler.of(rng, alphaGammaSmallerThanOne, thetaGamma),
263                           gammaMin, gammaMax1, "gamma.case1");
264         }
265         if (samplers.contains(Sampler.AhrensDieterMarsagliaTsangGammaSamplerCase2)) {
266             final double alphaGammaLargerThanOne = 2.345;
267             final double gammaMax2 = 70;
268             createDensity(AhrensDieterMarsagliaTsangGammaSampler.of(rng, alphaGammaLargerThanOne, thetaGamma),
269                           gammaMin, gammaMax2, "gamma.case2");
270         }
271 
272         final double scalePareto = 23.45;
273         final double shapePareto = 0.789;
274         final double paretoMin = 23;
275         final double paretoMax = 400;
276         if (samplers.contains(Sampler.InverseTransformParetoSampler)) {
277             createDensity(InverseTransformParetoSampler.of(rng, scalePareto, shapePareto),
278                           paretoMin, paretoMax, "pareto");
279         }
280 
281         final double loUniform = -9.876;
282         final double hiUniform = 5.432;
283         if (samplers.contains(Sampler.ContinuousUniformSampler)) {
284             createDensity(ContinuousUniformSampler.of(rng, loUniform, hiUniform),
285                           loUniform, hiUniform, "uniform");
286         }
287 
288         final double scaleLogNormal = 2.345;
289         final double shapeLogNormal = 0.1234;
290         final double logNormalMin = 5;
291         final double logNormalMax = 25;
292         if (samplers.contains(Sampler.LogNormalZigguratGaussianSampler)) {
293             createDensity(LogNormalSampler.of(ZigguratNormalizedGaussianSampler.of(rng),
294                                               scaleLogNormal, shapeLogNormal),
295                           logNormalMin, logNormalMax, "lognormal.ziggurat");
296         }
297         if (samplers.contains(Sampler.LogNormalMarsagliaGaussianSampler)) {
298             createDensity(LogNormalSampler.of(MarsagliaNormalizedGaussianSampler.of(rng),
299                                               scaleLogNormal, shapeLogNormal),
300                           logNormalMin, logNormalMax, "lognormal.marsaglia");
301         }
302         if (samplers.contains(Sampler.LogNormalBoxMullerGaussianSampler)) {
303             createDensity(LogNormalSampler.of(BoxMullerNormalizedGaussianSampler.of(rng),
304                                               scaleLogNormal, shapeLogNormal),
305                           logNormalMin, logNormalMax, "lognormal.boxmuller");
306         }
307         if (samplers.contains(Sampler.LogNormalModifiedZigguratGaussianSampler)) {
308             createDensity(LogNormalSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
309                                               scaleLogNormal, shapeLogNormal),
310                           logNormalMin, logNormalMax, "lognormal.modified.ziggurat");
311         }
312 
313         if (samplers.contains(Sampler.LevySampler)) {
314             final double levyLocation = 1.23;
315             final double levyscale = 0.75;
316             final double levyMin = levyLocation;
317             // Quantile 0 to 0.7 (avoid long tail to infinity)
318             final double levyMax = 6.2815;
319             createDensity(LevySampler.of(rng, levyLocation, levyscale),
320                           levyMin, levyMax, "levy");
321         }
322 
323         if (samplers.contains(Sampler.StableSampler)) {
324             final double stableAlpha = 1.23;
325             final double stableBeta = 0.75;
326             // Quantiles 0.05 to 0.9 (avoid long tail to infinity)
327             final double stableMin = -1.7862;
328             final double stableMax = 4.0364;
329             createDensity(StableSampler.of(rng, stableAlpha, stableBeta),
330                           stableMin, stableMax, "stable");
331         }
332 
333         if (samplers.contains(Sampler.TSampler)) {
334             final double tDegreesOfFreedom = 1.23;
335             // Quantiles 0.02 to 0.98 (avoid long tail to infinity)
336             final double tMin = -9.9264;
337             final double tMax = 9.9264;
338             createDensity(TSampler.of(rng, tDegreesOfFreedom),
339                           tMin, tMax, "t");
340         }
341 
342         return null;
343     }
344 }