1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.rng.sampling.distribution; 19 20 import org.apache.commons.rng.UniformRandomProvider; 21 22 /** 23 * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm"> 24 * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian 25 * distribution with mean 0 and standard deviation 1. 26 * 27 * <p>The algorithm is explained in this 28 * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a> 29 * and this implementation has been adapted from the C code provided therein.</p> 30 * 31 * <p>Sampling uses:</p> 32 * 33 * <ul> 34 * <li>{@link UniformRandomProvider#nextLong()} 35 * <li>{@link UniformRandomProvider#nextDouble()} 36 * </ul> 37 * 38 * @since 1.1 39 */ 40 public class ZigguratNormalizedGaussianSampler 41 implements NormalizedGaussianSampler, SharedStateContinuousSampler { 42 /** Start of tail. */ 43 private static final double R = 3.6541528853610088; 44 /** Inverse of R. */ 45 private static final double ONE_OVER_R = 1 / R; 46 /** Index of last entry in the tables (which have a size that is a power of 2). */ 47 private static final int LAST = 255; 48 /** Auxiliary table. */ 49 private static final long[] K; 50 /** Auxiliary table. */ 51 private static final double[] W; 52 /** Auxiliary table. */ 53 private static final double[] F; 54 55 /** Underlying source of randomness. */ 56 private final UniformRandomProvider rng; 57 58 static { 59 // Filling the tables. 60 // Rectangle area. 61 final double v = 0.00492867323399; 62 // Direction support uses the sign bit so the maximum magnitude from the long is 2^63 63 final double max = Math.pow(2, 63); 64 final double oneOverMax = 1d / max; 65 66 K = new long[LAST + 1]; 67 W = new double[LAST + 1]; 68 F = new double[LAST + 1]; 69 70 double d = R; 71 double t = d; 72 double fd = pdf(d); 73 final double q = v / fd; 74 75 K[0] = (long) ((d / q) * max); 76 K[1] = 0; 77 78 W[0] = q * oneOverMax; 79 W[LAST] = d * oneOverMax; 80 81 F[0] = 1; 82 F[LAST] = fd; 83 84 for (int i = LAST - 1; i >= 1; i--) { 85 d = Math.sqrt(-2 * Math.log(v / d + fd)); 86 fd = pdf(d); 87 88 K[i + 1] = (long) ((d / t) * max); 89 t = d; 90 91 F[i] = fd; 92 93 W[i] = d * oneOverMax; 94 } 95 } 96 97 /** 98 * @param rng Generator of uniformly distributed random numbers. 99 */ 100 public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) { 101 this.rng = rng; 102 } 103 104 /** {@inheritDoc} */ 105 @Override 106 public double sample() { 107 final long j = rng.nextLong(); 108 final int i = ((int) j) & LAST; 109 if (Math.abs(j) < K[i]) { 110 // This branch is called about 0.985086 times per sample. 111 return j * W[i]; 112 } 113 return fix(j, i); 114 } 115 116 /** {@inheritDoc} */ 117 @Override 118 public String toString() { 119 return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]"; 120 } 121 122 /** 123 * Gets the value from the tail of the distribution. 124 * 125 * @param hz Start random integer. 126 * @param iz Index of cell corresponding to {@code hz}. 127 * @return the requested random value. 128 */ 129 private double fix(long hz, 130 int iz) { 131 if (iz == 0) { 132 // Base strip. 133 // This branch is called about 2.55224E-4 times per sample. 134 double y; 135 double x; 136 do { 137 // Avoid infinity by creating a non-zero double. 138 // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf): 139 // y = 36.74 140 // The largest value x where 2y < x^2 is false is sqrt(2*36.74): 141 // x = 8.571 142 // The extreme tail is: 143 // out = +/- 12.01 144 // To generate this requires longs of 0 and then (1377 << 11). 145 y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())); 146 x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R; 147 } while (y + y < x * x); 148 149 final double out = R + x; 150 return hz > 0 ? out : -out; 151 } 152 // Wedge of other strips. 153 // This branch is called about 0.0146584 times per sample. 154 final double x = hz * W[iz]; 155 if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) { 156 // This branch is called about 0.00797887 times per sample. 157 return x; 158 } 159 // Try again. 160 // This branch is called about 0.00667957 times per sample. 161 return sample(); 162 } 163 164 /** 165 * Compute the Gaussian probability density function {@code f(x) = e^-0.5x^2}. 166 * 167 * @param x Argument. 168 * @return \( e^{-\frac{x^2}{2}} \) 169 */ 170 private static double pdf(double x) { 171 return Math.exp(-0.5 * x * x); 172 } 173 174 /** 175 * {@inheritDoc} 176 * 177 * @since 1.3 178 */ 179 @Override 180 public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { 181 return new ZigguratNormalizedGaussianSampler(rng); 182 } 183 184 /** 185 * Create a new normalised Gaussian sampler. 186 * 187 * @param <S> Sampler type. 188 * @param rng Generator of uniformly distributed random numbers. 189 * @return the sampler 190 * @since 1.3 191 */ 192 @SuppressWarnings("unchecked") 193 public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S 194 of(UniformRandomProvider rng) { 195 return (S) new ZigguratNormalizedGaussianSampler(rng); 196 } 197 }