001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.rng.sampling.distribution; 019 020import org.apache.commons.rng.UniformRandomProvider; 021 022/** 023 * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm"> 024 * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian 025 * distribution with mean 0 and standard deviation 1. 026 * 027 * <p>The algorithm is explained in this 028 * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a> 029 * and this implementation has been adapted from the C code provided therein.</p> 030 * 031 * <p>Sampling uses:</p> 032 * 033 * <ul> 034 * <li>{@link UniformRandomProvider#nextLong()} 035 * <li>{@link UniformRandomProvider#nextDouble()} 036 * </ul> 037 * 038 * @since 1.1 039 */ 040public class ZigguratNormalizedGaussianSampler 041 implements NormalizedGaussianSampler, SharedStateContinuousSampler { 042 /** Start of tail. */ 043 private static final double R = 3.6541528853610088; 044 /** Inverse of R. */ 045 private static final double ONE_OVER_R = 1 / R; 046 /** Index of last entry in the tables (which have a size that is a power of 2). */ 047 private static final int LAST = 255; 048 /** Auxiliary table. */ 049 private static final long[] K; 050 /** Auxiliary table. */ 051 private static final double[] W; 052 /** Auxiliary table. */ 053 private static final double[] F; 054 055 /** Underlying source of randomness. */ 056 private final UniformRandomProvider rng; 057 058 static { 059 // Filling the tables. 060 // Rectangle area. 061 final double v = 0.00492867323399; 062 // Direction support uses the sign bit so the maximum magnitude from the long is 2^63 063 final double max = Math.pow(2, 63); 064 final double oneOverMax = 1d / max; 065 066 K = new long[LAST + 1]; 067 W = new double[LAST + 1]; 068 F = new double[LAST + 1]; 069 070 double d = R; 071 double t = d; 072 double fd = pdf(d); 073 final double q = v / fd; 074 075 K[0] = (long) ((d / q) * max); 076 K[1] = 0; 077 078 W[0] = q * oneOverMax; 079 W[LAST] = d * oneOverMax; 080 081 F[0] = 1; 082 F[LAST] = fd; 083 084 for (int i = LAST - 1; i >= 1; i--) { 085 d = Math.sqrt(-2 * Math.log(v / d + fd)); 086 fd = pdf(d); 087 088 K[i + 1] = (long) ((d / t) * max); 089 t = d; 090 091 F[i] = fd; 092 093 W[i] = d * oneOverMax; 094 } 095 } 096 097 /** 098 * @param rng Generator of uniformly distributed random numbers. 099 */ 100 public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) { 101 this.rng = rng; 102 } 103 104 /** {@inheritDoc} */ 105 @Override 106 public double sample() { 107 final long j = rng.nextLong(); 108 final int i = ((int) j) & LAST; 109 if (Math.abs(j) < K[i]) { 110 // This branch is called about 0.985086 times per sample. 111 return j * W[i]; 112 } 113 return fix(j, i); 114 } 115 116 /** {@inheritDoc} */ 117 @Override 118 public String toString() { 119 return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]"; 120 } 121 122 /** 123 * Gets the value from the tail of the distribution. 124 * 125 * @param hz Start random integer. 126 * @param iz Index of cell corresponding to {@code hz}. 127 * @return the requested random value. 128 */ 129 private double fix(long hz, 130 int iz) { 131 if (iz == 0) { 132 // Base strip. 133 // This branch is called about 2.55224E-4 times per sample. 134 double y; 135 double x; 136 do { 137 // Avoid infinity by creating a non-zero double. 138 // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf): 139 // y = 36.74 140 // The largest value x where 2y < x^2 is false is sqrt(2*36.74): 141 // x = 8.571 142 // The extreme tail is: 143 // out = +/- 12.01 144 // To generate this requires longs of 0 and then (1377 << 11). 145 y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())); 146 x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R; 147 } while (y + y < x * x); 148 149 final double out = R + x; 150 return hz > 0 ? out : -out; 151 } 152 // Wedge of other strips. 153 // This branch is called about 0.0146584 times per sample. 154 final double x = hz * W[iz]; 155 if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) { 156 // This branch is called about 0.00797887 times per sample. 157 return x; 158 } 159 // Try again. 160 // This branch is called about 0.00667957 times per sample. 161 return sample(); 162 } 163 164 /** 165 * Compute the Gaussian probability density function {@code f(x) = e^-0.5x^2}. 166 * 167 * @param x Argument. 168 * @return \( e^{-\frac{x^2}{2}} \) 169 */ 170 private static double pdf(double x) { 171 return Math.exp(-0.5 * x * x); 172 } 173 174 /** 175 * {@inheritDoc} 176 * 177 * @since 1.3 178 */ 179 @Override 180 public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { 181 return new ZigguratNormalizedGaussianSampler(rng); 182 } 183 184 /** 185 * Create a new normalised Gaussian sampler. 186 * 187 * @param <S> Sampler type. 188 * @param rng Generator of uniformly distributed random numbers. 189 * @return the sampler 190 * @since 1.3 191 */ 192 @SuppressWarnings("unchecked") 193 public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S 194 of(UniformRandomProvider rng) { 195 return (S) new ZigguratNormalizedGaussianSampler(rng); 196 } 197}