View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling.distribution;
19  
20  import org.apache.commons.rng.UniformRandomProvider;
21  
22  /**
23   * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
24   * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian
25   * distribution with mean 0 and standard deviation 1.
26   *
27   * <p>The algorithm is explained in this
28   * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
29   * and this implementation has been adapted from the C code provided therein.</p>
30   *
31   * <p>Sampling uses:</p>
32   *
33   * <ul>
34   *   <li>{@link UniformRandomProvider#nextLong()}
35   *   <li>{@link UniformRandomProvider#nextDouble()}
36   * </ul>
37   *
38   * @since 1.1
39   */
40  public class ZigguratNormalizedGaussianSampler
41      implements NormalizedGaussianSampler, SharedStateContinuousSampler {
42      /** Start of tail. */
43      private static final double R = 3.6541528853610088;
44      /** Inverse of R. */
45      private static final double ONE_OVER_R = 1 / R;
46      /** Index of last entry in the tables (which have a size that is a power of 2). */
47      private static final int LAST = 255;
48      /** Auxiliary table. */
49      private static final long[] K;
50      /** Auxiliary table. */
51      private static final double[] W;
52      /** Auxiliary table. */
53      private static final double[] F;
54  
55      /** Underlying source of randomness. */
56      private final UniformRandomProvider rng;
57  
58      static {
59          // Filling the tables.
60          // Rectangle area.
61          final double v = 0.00492867323399;
62          // Direction support uses the sign bit so the maximum magnitude from the long is 2^63
63          final double max = Math.pow(2, 63);
64          final double oneOverMax = 1d / max;
65  
66          K = new long[LAST + 1];
67          W = new double[LAST + 1];
68          F = new double[LAST + 1];
69  
70          double d = R;
71          double t = d;
72          double fd = pdf(d);
73          final double q = v / fd;
74  
75          K[0] = (long) ((d / q) * max);
76          K[1] = 0;
77  
78          W[0] = q * oneOverMax;
79          W[LAST] = d * oneOverMax;
80  
81          F[0] = 1;
82          F[LAST] = fd;
83  
84          for (int i = LAST - 1; i >= 1; i--) {
85              d = Math.sqrt(-2 * Math.log(v / d + fd));
86              fd = pdf(d);
87  
88              K[i + 1] = (long) ((d / t) * max);
89              t = d;
90  
91              F[i] = fd;
92  
93              W[i] = d * oneOverMax;
94          }
95      }
96  
97      /**
98       * @param rng Generator of uniformly distributed random numbers.
99       */
100     public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
101         this.rng = rng;
102     }
103 
104     /** {@inheritDoc} */
105     @Override
106     public double sample() {
107         final long j = rng.nextLong();
108         final int i = ((int) j) & LAST;
109         if (Math.abs(j) < K[i]) {
110             // This branch is called about 0.985086 times per sample.
111             return j * W[i];
112         }
113         return fix(j, i);
114     }
115 
116     /** {@inheritDoc} */
117     @Override
118     public String toString() {
119         return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
120     }
121 
122     /**
123      * Gets the value from the tail of the distribution.
124      *
125      * @param hz Start random integer.
126      * @param iz Index of cell corresponding to {@code hz}.
127      * @return the requested random value.
128      */
129     private double fix(long hz,
130                        int iz) {
131         if (iz == 0) {
132             // Base strip.
133             // This branch is called about 2.55224E-4 times per sample.
134             double y;
135             double x;
136             do {
137                 // Avoid infinity by creating a non-zero double.
138                 // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf):
139                 // y = 36.74
140                 // The largest value x where 2y < x^2 is false is sqrt(2*36.74):
141                 // x = 8.571
142                 // The extreme tail is:
143                 // out = +/- 12.01
144                 // To generate this requires longs of 0 and then (1377 << 11).
145                 y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong()));
146                 x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R;
147             } while (y + y < x * x);
148 
149             final double out = R + x;
150             return hz > 0 ? out : -out;
151         }
152         // Wedge of other strips.
153         // This branch is called about 0.0146584 times per sample.
154         final double x = hz * W[iz];
155         if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) {
156             // This branch is called about 0.00797887 times per sample.
157             return x;
158         }
159         // Try again.
160         // This branch is called about 0.00667957 times per sample.
161         return sample();
162     }
163 
164     /**
165      * Compute the Gaussian probability density function {@code f(x) = e^-0.5x^2}.
166      *
167      * @param x Argument.
168      * @return \( e^{-\frac{x^2}{2}} \)
169      */
170     private static double pdf(double x) {
171         return Math.exp(-0.5 * x * x);
172     }
173 
174     /**
175      * {@inheritDoc}
176      *
177      * @since 1.3
178      */
179     @Override
180     public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
181         return new ZigguratNormalizedGaussianSampler(rng);
182     }
183 
184     /**
185      * Create a new normalised Gaussian sampler.
186      *
187      * @param <S> Sampler type.
188      * @param rng Generator of uniformly distributed random numbers.
189      * @return the sampler
190      * @since 1.3
191      */
192     @SuppressWarnings("unchecked")
193     public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S
194             of(UniformRandomProvider rng) {
195         return (S) new ZigguratNormalizedGaussianSampler(rng);
196     }
197 }