/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.rng.sampling.distribution;
19  
20  import org.apache.commons.rng.UniformRandomProvider;
21  
22  /**
23   * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
24   * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian
25   * distribution with mean 0 and standard deviation 1.
26   *
27   * <p>The algorithm is explained in this
28   * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
29   * and this implementation has been adapted from the C code provided therein.</p>
30   *
31   * <p>Sampling uses:</p>
32   *
33   * <ul>
34   *   <li>{@link UniformRandomProvider#nextLong()}
35   *   <li>{@link UniformRandomProvider#nextDouble()}
36   * </ul>
37   *
38   * @since 1.1
39   */
40  public class ZigguratNormalizedGaussianSampler
41      implements NormalizedGaussianSampler, SharedStateContinuousSampler {
42      /** Start of tail. */
43      private static final double R = 3.6541528853610088;
44      /** Inverse of R. */
45      private static final double ONE_OVER_R = 1 / R;
46      /** Index of last entry in the tables (which have a size that is a power of 2). */
47      private static final int LAST = 255;
48      /** Auxiliary table. */
49      private static final long[] K;
50      /** Auxiliary table. */
51      private static final double[] W;
52      /** Auxiliary table. */
53      private static final double[] F;
54  
55      /** Underlying source of randomness. */
56      private final UniformRandomProvider rng;
57  
58      static {
59          // Filling the tables.
60          // Rectangle area.
61          final double v = 0.00492867323399;
62          // Direction support uses the sign bit so the maximum magnitude from the long is 2^63
63          final double max = Math.pow(2, 63);
64          final double oneOverMax = 1d / max;
65  
66          K = new long[LAST + 1];
67          W = new double[LAST + 1];
68          F = new double[LAST + 1];
69  
70          double d = R;
71          double t = d;
72          double fd = pdf(d);
73          final double q = v / fd;
74  
75          K[0] = (long) ((d / q) * max);
76          K[1] = 0;
77  
78          W[0] = q * oneOverMax;
79          W[LAST] = d * oneOverMax;
80  
81          F[0] = 1;
82          F[LAST] = fd;
83  
84          for (int i = LAST - 1; i >= 1; i--) {
85              d = Math.sqrt(-2 * Math.log(v / d + fd));
86              fd = pdf(d);
87  
88              K[i + 1] = (long) ((d / t) * max);
89              t = d;
90  
91              F[i] = fd;
92  
93              W[i] = d * oneOverMax;
94          }
95      }
96  
97      /**
98       * Create an instance.
99       *
100      * @param rng Generator of uniformly distributed random numbers.
101      */
102     public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
103         this.rng = rng;
104     }
105 
106     /** {@inheritDoc} */
107     @Override
108     public double sample() {
109         final long j = rng.nextLong();
110         final int i = ((int) j) & LAST;
111         if (Math.abs(j) < K[i]) {
112             // This branch is called about 0.985086 times per sample.
113             return j * W[i];
114         }
115         return fix(j, i);
116     }
117 
118     /** {@inheritDoc} */
119     @Override
120     public String toString() {
121         return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
122     }
123 
124     /**
125      * Gets the value from the tail of the distribution.
126      *
127      * @param hz Start random integer.
128      * @param iz Index of cell corresponding to {@code hz}.
129      * @return the requested random value.
130      */
131     private double fix(long hz,
132                        int iz) {
133         if (iz == 0) {
134             // Base strip.
135             // This branch is called about 2.55224E-4 times per sample.
136             double y;
137             double x;
138             do {
139                 // Avoid infinity by creating a non-zero double.
140                 // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf):
141                 // y = 36.74
142                 // The largest value x where 2y < x^2 is false is sqrt(2*36.74):
143                 // x = 8.571
144                 // The extreme tail is:
145                 // out = +/- 12.01
146                 // To generate this requires longs of 0 and then (1377 << 11).
147                 y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong()));
148                 x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R;
149             } while (y + y < x * x);
150 
151             final double out = R + x;
152             return hz > 0 ? out : -out;
153         }
154         // Wedge of other strips.
155         // This branch is called about 0.0146584 times per sample.
156         final double x = hz * W[iz];
157         if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) {
158             // This branch is called about 0.00797887 times per sample.
159             return x;
160         }
161         // Try again.
162         // This branch is called about 0.00667957 times per sample.
163         return sample();
164     }
165 
166     /**
167      * Compute the Gaussian probability density function {@code f(x) = e^-0.5x^2}.
168      *
169      * @param x Argument.
170      * @return \( e^{-\frac{x^2}{2}} \)
171      */
172     private static double pdf(double x) {
173         return Math.exp(-0.5 * x * x);
174     }
175 
176     /**
177      * {@inheritDoc}
178      *
179      * @since 1.3
180      */
181     @Override
182     public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
183         return new ZigguratNormalizedGaussianSampler(rng);
184     }
185 
186     /**
187      * Create a new normalised Gaussian sampler.
188      *
189      * @param <S> Sampler type.
190      * @param rng Generator of uniformly distributed random numbers.
191      * @return the sampler
192      * @since 1.3
193      */
194     @SuppressWarnings("unchecked")
195     public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S
196             of(UniformRandomProvider rng) {
197         return (S) new ZigguratNormalizedGaussianSampler(rng);
198     }
199 }