001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.rng.sampling.distribution;
019
020import org.apache.commons.rng.UniformRandomProvider;
021
022/**
023 * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
024 * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian
025 * distribution with mean 0 and standard deviation 1.
026 *
027 * The algorithm is explained in this
028 * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
029 * and this implementation has been adapted from the C code provided therein.
030 *
031 * @since 1.1
032 */
033public class ZigguratNormalizedGaussianSampler
034    implements NormalizedGaussianSampler {
035    /** Start of tail. */
036    private static final double R = 3.442619855899;
037    /** Inverse of R. */
038    private static final double ONE_OVER_R = 1 / R;
039    /** Rectangle area. */
040    private static final double V = 9.91256303526217e-3;
041    /** 2^63 */
042    private static final double MAX = Math.pow(2, 63);
043    /** 2^-63 */
044    private static final double ONE_OVER_MAX = 1d / MAX;
045    /** Number of entries. */
046    private static final int LEN = 128;
047    /** Index of last entry. */
048    private static final int LAST = LEN - 1;
049    /** Auxiliary table. */
050    private static final long[] K = new long[LEN];
051    /** Auxiliary table. */
052    private static final double[] W = new double[LEN];
053    /** Auxiliary table. */
054    private static final double[] F = new double[LEN];
055    /** Underlying source of randomness. */
056    private final UniformRandomProvider rng;
057
058    static {
059        // Filling the tables.
060
061        double d = R;
062        double t = d;
063        double fd = gauss(d);
064        final double q = V / fd;
065
066        K[0] = (long) ((d / q) * MAX);
067        K[1] = 0;
068
069        W[0] = q * ONE_OVER_MAX;
070        W[LAST] = d * ONE_OVER_MAX;
071
072        F[0] = 1;
073        F[LAST] = fd;
074
075        for (int i = LAST - 1; i >= 1; i--) {
076            d = Math.sqrt(-2 * Math.log(V / d + fd));
077            fd = gauss(d);
078
079            K[i + 1] = (long) ((d / t) * MAX);
080            t = d;
081
082            F[i] = fd;
083
084            W[i] = d * ONE_OVER_MAX;
085        }
086    }
087
088    /**
089     * @param rng Generator of uniformly distributed random numbers.
090     */
091    public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
092        this.rng = rng;
093    }
094
095    /** {@inheritDoc} */
096    @Override
097    public double sample() {
098        final long j = rng.nextLong();
099        final int i = (int) (j & LAST);
100        if (Math.abs(j) < K[i]) {
101            return j * W[i];
102        } else {
103            return fix(j, i);
104        }
105    }
106
107    /** {@inheritDoc} */
108    @Override
109    public String toString() {
110        return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
111    }
112
113    /**
114     * Gets the value from the tail of the distribution.
115     *
116     * @param hz Start random integer.
117     * @param iz Index of cell corresponding to {@code hz}.
118     * @return the requested random value.
119     */
120    private double fix(long hz,
121                       int iz) {
122        double x;
123        double y;
124
125        x = hz * W[iz];
126        if (iz == 0) {
127            // Base strip.
128            // This branch is called about 5.7624515E-4 times per sample.
129            do {
130                y = -Math.log(rng.nextDouble());
131                x = -Math.log(rng.nextDouble()) * ONE_OVER_R;
132            } while (y + y < x * x);
133
134            final double out = R + x;
135            return hz > 0 ? out : -out;
136        } else {
137            // Wedge of other strips.
138            // This branch is called about 0.027323 times per sample.
139            if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x)) {
140                return x;
141            } else {
142                // Try again.
143                // This branch is called about 0.012362 times per sample.
144                return sample();
145            }
146        }
147    }
148
149    /**
150     * @param x Argument.
151     * @return \( e^{-\frac{x^2}{2}} \)
152     */
153    private static double gauss(double x) {
154        return Math.exp(-0.5 * x * x);
155    }
156}