ZigguratNormalizedGaussianSampler.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.rng.sampling.distribution;

  18. import org.apache.commons.rng.UniformRandomProvider;

  19. /**
  20.  * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
  21.  * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian
  22.  * distribution with mean 0 and standard deviation 1.
  23.  *
  24.  * <p>The algorithm is explained in this
  25.  * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
  26.  * and this implementation has been adapted from the C code provided therein.</p>
  27.  *
  28.  * <p>Sampling uses:</p>
  29.  *
  30.  * <ul>
  31.  *   <li>{@link UniformRandomProvider#nextLong()}
  32.  *   <li>{@link UniformRandomProvider#nextDouble()}
  33.  * </ul>
  34.  *
  35.  * @since 1.1
  36.  */
  37. public class ZigguratNormalizedGaussianSampler
  38.     implements NormalizedGaussianSampler, SharedStateContinuousSampler {
  39.     /** Start of tail. */
  40.     private static final double R = 3.6541528853610088;
  41.     /** Inverse of R. */
  42.     private static final double ONE_OVER_R = 1 / R;
  43.     /** Index of last entry in the tables (which have a size that is a power of 2). */
  44.     private static final int LAST = 255;
  45.     /** Auxiliary table. */
  46.     private static final long[] K;
  47.     /** Auxiliary table. */
  48.     private static final double[] W;
  49.     /** Auxiliary table. */
  50.     private static final double[] F;

  51.     /** Underlying source of randomness. */
  52.     private final UniformRandomProvider rng;

  53.     static {
  54.         // Filling the tables.
  55.         // Rectangle area.
  56.         final double v = 0.00492867323399;
  57.         // Direction support uses the sign bit so the maximum magnitude from the long is 2^63
  58.         final double max = Math.pow(2, 63);
  59.         final double oneOverMax = 1d / max;

  60.         K = new long[LAST + 1];
  61.         W = new double[LAST + 1];
  62.         F = new double[LAST + 1];

  63.         double d = R;
  64.         double t = d;
  65.         double fd = pdf(d);
  66.         final double q = v / fd;

  67.         K[0] = (long) ((d / q) * max);
  68.         K[1] = 0;

  69.         W[0] = q * oneOverMax;
  70.         W[LAST] = d * oneOverMax;

  71.         F[0] = 1;
  72.         F[LAST] = fd;

  73.         for (int i = LAST - 1; i >= 1; i--) {
  74.             d = Math.sqrt(-2 * Math.log(v / d + fd));
  75.             fd = pdf(d);

  76.             K[i + 1] = (long) ((d / t) * max);
  77.             t = d;

  78.             F[i] = fd;

  79.             W[i] = d * oneOverMax;
  80.         }
  81.     }

  82.     /**
  83.      * Create an instance.
  84.      *
  85.      * @param rng Generator of uniformly distributed random numbers.
  86.      */
  87.     public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
  88.         this.rng = rng;
  89.     }

  90.     /** {@inheritDoc} */
  91.     @Override
  92.     public double sample() {
  93.         final long j = rng.nextLong();
  94.         final int i = ((int) j) & LAST;
  95.         if (Math.abs(j) < K[i]) {
  96.             // This branch is called about 0.985086 times per sample.
  97.             return j * W[i];
  98.         }
  99.         return fix(j, i);
  100.     }

  101.     /** {@inheritDoc} */
  102.     @Override
  103.     public String toString() {
  104.         return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
  105.     }

  106.     /**
  107.      * Gets the value from the tail of the distribution.
  108.      *
  109.      * @param hz Start random integer.
  110.      * @param iz Index of cell corresponding to {@code hz}.
  111.      * @return the requested random value.
  112.      */
  113.     private double fix(long hz,
  114.                        int iz) {
  115.         if (iz == 0) {
  116.             // Base strip.
  117.             // This branch is called about 2.55224E-4 times per sample.
  118.             double y;
  119.             double x;
  120.             do {
  121.                 // Avoid infinity by creating a non-zero double.
  122.                 // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf):
  123.                 // y = 36.74
  124.                 // The largest value x where 2y < x^2 is false is sqrt(2*36.74):
  125.                 // x = 8.571
  126.                 // The extreme tail is:
  127.                 // out = +/- 12.01
  128.                 // To generate this requires longs of 0 and then (1377 << 11).
  129.                 y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong()));
  130.                 x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R;
  131.             } while (y + y < x * x);

  132.             final double out = R + x;
  133.             return hz > 0 ? out : -out;
  134.         }
  135.         // Wedge of other strips.
  136.         // This branch is called about 0.0146584 times per sample.
  137.         final double x = hz * W[iz];
  138.         if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) {
  139.             // This branch is called about 0.00797887 times per sample.
  140.             return x;
  141.         }
  142.         // Try again.
  143.         // This branch is called about 0.00667957 times per sample.
  144.         return sample();
  145.     }

  146.     /**
  147.      * Compute the Gaussian probability density function {@code f(x) = e^-0.5x^2}.
  148.      *
  149.      * @param x Argument.
  150.      * @return \( e^{-\frac{x^2}{2}} \)
  151.      */
  152.     private static double pdf(double x) {
  153.         return Math.exp(-0.5 * x * x);
  154.     }

  155.     /**
  156.      * {@inheritDoc}
  157.      *
  158.      * @since 1.3
  159.      */
  160.     @Override
  161.     public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
  162.         return new ZigguratNormalizedGaussianSampler(rng);
  163.     }

  164.     /**
  165.      * Create a new normalised Gaussian sampler.
  166.      *
  167.      * @param <S> Sampler type.
  168.      * @param rng Generator of uniformly distributed random numbers.
  169.      * @return the sampler
  170.      * @since 1.3
  171.      */
  172.     @SuppressWarnings("unchecked")
  173.     public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S
  174.             of(UniformRandomProvider rng) {
  175.         return (S) new ZigguratNormalizedGaussianSampler(rng);
  176.     }
  177. }