1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.rng.sampling.distribution;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21
22 /**
23 * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
24 * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian
25 * distribution with mean 0 and standard deviation 1.
26 *
27 * <p>The algorithm is explained in this
28 * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
29 * and this implementation has been adapted from the C code provided therein.</p>
30 *
31 * <p>Sampling uses:</p>
32 *
33 * <ul>
34 * <li>{@link UniformRandomProvider#nextLong()}
35 * <li>{@link UniformRandomProvider#nextDouble()}
36 * </ul>
37 *
38 * @since 1.1
39 */
40 public class ZigguratNormalizedGaussianSampler
41 implements NormalizedGaussianSampler, SharedStateContinuousSampler {
42 /** Start of tail. */
43 private static final double R = 3.6541528853610088;
44 /** Inverse of R. */
45 private static final double ONE_OVER_R = 1 / R;
46 /** Index of last entry in the tables (which have a size that is a power of 2). */
47 private static final int LAST = 255;
48 /** Auxiliary table. */
49 private static final long[] K;
50 /** Auxiliary table. */
51 private static final double[] W;
52 /** Auxiliary table. */
53 private static final double[] F;
54
55 /** Underlying source of randomness. */
56 private final UniformRandomProvider rng;
57
58 static {
59 // Filling the tables.
60 // Rectangle area.
61 final double v = 0.00492867323399;
62 // Direction support uses the sign bit so the maximum magnitude from the long is 2^63
63 final double max = Math.pow(2, 63);
64 final double oneOverMax = 1d / max;
65
66 K = new long[LAST + 1];
67 W = new double[LAST + 1];
68 F = new double[LAST + 1];
69
70 double d = R;
71 double t = d;
72 double fd = pdf(d);
73 final double q = v / fd;
74
75 K[0] = (long) ((d / q) * max);
76 K[1] = 0;
77
78 W[0] = q * oneOverMax;
79 W[LAST] = d * oneOverMax;
80
81 F[0] = 1;
82 F[LAST] = fd;
83
84 for (int i = LAST - 1; i >= 1; i--) {
85 d = Math.sqrt(-2 * Math.log(v / d + fd));
86 fd = pdf(d);
87
88 K[i + 1] = (long) ((d / t) * max);
89 t = d;
90
91 F[i] = fd;
92
93 W[i] = d * oneOverMax;
94 }
95 }
96
97 /**
98 * Create an instance.
99 *
100 * @param rng Generator of uniformly distributed random numbers.
101 */
102 public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
103 this.rng = rng;
104 }
105
106 /** {@inheritDoc} */
107 @Override
108 public double sample() {
109 final long j = rng.nextLong();
110 final int i = ((int) j) & LAST;
111 if (Math.abs(j) < K[i]) {
112 // This branch is called about 0.985086 times per sample.
113 return j * W[i];
114 }
115 return fix(j, i);
116 }
117
118 /** {@inheritDoc} */
119 @Override
120 public String toString() {
121 return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
122 }
123
124 /**
125 * Gets the value from the tail of the distribution.
126 *
127 * @param hz Start random integer.
128 * @param iz Index of cell corresponding to {@code hz}.
129 * @return the requested random value.
130 */
131 private double fix(long hz,
132 int iz) {
133 if (iz == 0) {
134 // Base strip.
135 // This branch is called about 2.55224E-4 times per sample.
136 double y;
137 double x;
138 do {
139 // Avoid infinity by creating a non-zero double.
140 // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf):
141 // y = 36.74
142 // The largest value x where 2y < x^2 is false is sqrt(2*36.74):
143 // x = 8.571
144 // The extreme tail is:
145 // out = +/- 12.01
146 // To generate this requires longs of 0 and then (1377 << 11).
147 y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong()));
148 x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R;
149 } while (y + y < x * x);
150
151 final double out = R + x;
152 return hz > 0 ? out : -out;
153 }
154 // Wedge of other strips.
155 // This branch is called about 0.0146584 times per sample.
156 final double x = hz * W[iz];
157 if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) {
158 // This branch is called about 0.00797887 times per sample.
159 return x;
160 }
161 // Try again.
162 // This branch is called about 0.00667957 times per sample.
163 return sample();
164 }
165
166 /**
167 * Compute the Gaussian probability density function {@code f(x) = e^-0.5x^2}.
168 *
169 * @param x Argument.
170 * @return \( e^{-\frac{x^2}{2}} \)
171 */
172 private static double pdf(double x) {
173 return Math.exp(-0.5 * x * x);
174 }
175
176 /**
177 * {@inheritDoc}
178 *
179 * @since 1.3
180 */
181 @Override
182 public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
183 return new ZigguratNormalizedGaussianSampler(rng);
184 }
185
186 /**
187 * Create a new normalised Gaussian sampler.
188 *
189 * @param <S> Sampler type.
190 * @param rng Generator of uniformly distributed random numbers.
191 * @return the sampler
192 * @since 1.3
193 */
194 @SuppressWarnings("unchecked")
195 public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S
196 of(UniformRandomProvider rng) {
197 return (S) new ZigguratNormalizedGaussianSampler(rng);
198 }
199 }