ZigguratNormalizedGaussianSampler.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.rng.sampling.distribution;
- import org.apache.commons.rng.UniformRandomProvider;
- /**
- * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
- * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian
- * distribution with mean 0 and standard deviation 1.
- *
- * <p>The algorithm is explained in this
- * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
- * and this implementation has been adapted from the C code provided therein.</p>
- *
- * <p>Sampling uses:</p>
- *
- * <ul>
- * <li>{@link UniformRandomProvider#nextLong()}
- * <li>{@link UniformRandomProvider#nextDouble()}
- * </ul>
- *
- * @since 1.1
- */
- public class ZigguratNormalizedGaussianSampler
- implements NormalizedGaussianSampler, SharedStateContinuousSampler {
- /** Start of tail. */
- private static final double R = 3.6541528853610088;
- /** Inverse of R. */
- private static final double ONE_OVER_R = 1 / R;
- /** Index of last entry in the tables (which have a size that is a power of 2). */
- private static final int LAST = 255;
- /** Auxiliary table. */
- private static final long[] K;
- /** Auxiliary table. */
- private static final double[] W;
- /** Auxiliary table. */
- private static final double[] F;
- /** Underlying source of randomness. */
- private final UniformRandomProvider rng;
- static {
- // Filling the tables.
- // Rectangle area.
- final double v = 0.00492867323399;
- // Direction support uses the sign bit so the maximum magnitude from the long is 2^63
- final double max = Math.pow(2, 63);
- final double oneOverMax = 1d / max;
- K = new long[LAST + 1];
- W = new double[LAST + 1];
- F = new double[LAST + 1];
- double d = R;
- double t = d;
- double fd = pdf(d);
- final double q = v / fd;
- K[0] = (long) ((d / q) * max);
- K[1] = 0;
- W[0] = q * oneOverMax;
- W[LAST] = d * oneOverMax;
- F[0] = 1;
- F[LAST] = fd;
- for (int i = LAST - 1; i >= 1; i--) {
- d = Math.sqrt(-2 * Math.log(v / d + fd));
- fd = pdf(d);
- K[i + 1] = (long) ((d / t) * max);
- t = d;
- F[i] = fd;
- W[i] = d * oneOverMax;
- }
- }
- /**
- * Create an instance.
- *
- * @param rng Generator of uniformly distributed random numbers.
- */
- public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
- this.rng = rng;
- }
- /** {@inheritDoc} */
- @Override
- public double sample() {
- final long j = rng.nextLong();
- final int i = ((int) j) & LAST;
- if (Math.abs(j) < K[i]) {
- // This branch is called about 0.985086 times per sample.
- return j * W[i];
- }
- return fix(j, i);
- }
- /** {@inheritDoc} */
- @Override
- public String toString() {
- return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
- }
- /**
- * Gets the value from the tail of the distribution.
- *
- * @param hz Start random integer.
- * @param iz Index of cell corresponding to {@code hz}.
- * @return the requested random value.
- */
- private double fix(long hz,
- int iz) {
- if (iz == 0) {
- // Base strip.
- // This branch is called about 2.55224E-4 times per sample.
- double y;
- double x;
- do {
- // Avoid infinity by creating a non-zero double.
- // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf):
- // y = 36.74
- // The largest value x where 2y < x^2 is false is sqrt(2*36.74):
- // x = 8.571
- // The extreme tail is:
- // out = +/- 12.01
- // To generate this requires longs of 0 and then (1377 << 11).
- y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong()));
- x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R;
- } while (y + y < x * x);
- final double out = R + x;
- return hz > 0 ? out : -out;
- }
- // Wedge of other strips.
- // This branch is called about 0.0146584 times per sample.
- final double x = hz * W[iz];
- if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) {
- // This branch is called about 0.00797887 times per sample.
- return x;
- }
- // Try again.
- // This branch is called about 0.00667957 times per sample.
- return sample();
- }
- /**
- * Compute the Gaussian probability density function {@code f(x) = e^-0.5x^2}.
- *
- * @param x Argument.
- * @return \( e^{-\frac{x^2}{2}} \)
- */
- private static double pdf(double x) {
- return Math.exp(-0.5 * x * x);
- }
- /**
- * {@inheritDoc}
- *
- * @since 1.3
- */
- @Override
- public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
- return new ZigguratNormalizedGaussianSampler(rng);
- }
- /**
- * Create a new normalised Gaussian sampler.
- *
- * @param <S> Sampler type.
- * @param rng Generator of uniformly distributed random numbers.
- * @return the sampler
- * @since 1.3
- */
- @SuppressWarnings("unchecked")
- public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S
- of(UniformRandomProvider rng) {
- return (S) new ZigguratNormalizedGaussianSampler(rng);
- }
- }