1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.rng.sampling.shape;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.sampling.SharedStateObjectSampler;
22 import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
23 import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
24 import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
25
26 /**
27 * Generate coordinates <a href="http://mathworld.wolfram.com/BallPointPicking.html">
28 * uniformly distributed within the unit n-ball</a>.
29 *
30 * <p>Sampling uses:</p>
31 *
32 * <ul>
33 * <li>{@link UniformRandomProvider#nextLong()}
34 * <li>{@link UniformRandomProvider#nextDouble()} (only for dimensions above 2)
35 * </ul>
36 *
37 * @since 1.4
38 */
39 public abstract class UnitBallSampler implements SharedStateObjectSampler<double[]> {
40 /** The dimension for 1D sampling. */
41 private static final int ONE_D = 1;
42 /** The dimension for 2D sampling. */
43 private static final int TWO_D = 2;
44 /** The dimension for 3D sampling. */
45 private static final int THREE_D = 3;
46 /**
47 * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
48 * Taken from o.a.c.rng.core.utils.NumberFactory.
49 *
50 * <p>This is equivalent to 1.0 / (1L << 53).
51 */
52 private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;
53
54 /**
55 * Sample uniformly from a 1D unit line.
56 */
57 private static class UnitBallSampler1D extends UnitBallSampler {
58 /** The source of randomness. */
59 private final UniformRandomProvider rng;
60
61 /**
62 * @param rng Source of randomness.
63 */
64 UnitBallSampler1D(UniformRandomProvider rng) {
65 this.rng = rng;
66 }
67
68 @Override
69 public double[] sample() {
70 return new double[] {makeSignedDouble(rng.nextLong())};
71 }
72
73 @Override
74 public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
75 return new UnitBallSampler1D(rng);
76 }
77 }
78
79 /**
80 * Sample uniformly from a 2D unit disk.
81 */
82 private static class UnitBallSampler2D extends UnitBallSampler {
83 /** The source of randomness. */
84 private final UniformRandomProvider rng;
85
86 /**
87 * @param rng Source of randomness.
88 */
89 UnitBallSampler2D(UniformRandomProvider rng) {
90 this.rng = rng;
91 }
92
93 @Override
94 public double[] sample() {
95 // Generate via rejection method of a circle inside a square of edge length 2.
96 // This should compute approximately 2^2 / pi = 1.27 square positions per sample.
97 double x;
98 double y;
99 do {
100 x = makeSignedDouble(rng.nextLong());
101 y = makeSignedDouble(rng.nextLong());
102 } while (x * x + y * y > 1.0);
103 return new double[] {x, y};
104 }
105
106 @Override
107 public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
108 return new UnitBallSampler2D(rng);
109 }
110 }
111
112 /**
113 * Sample uniformly from a 3D unit ball. This is an non-array based specialisation of
114 * {@link UnitBallSamplerND} for performance.
115 */
116 private static class UnitBallSampler3D extends UnitBallSampler {
117 /** The standard normal distribution. */
118 private final NormalizedGaussianSampler normal;
119 /** The exponential distribution (mean=1). */
120 private final ContinuousSampler exp;
121
122 /**
123 * @param rng Source of randomness.
124 */
125 UnitBallSampler3D(UniformRandomProvider rng) {
126 normal = ZigguratSampler.NormalizedGaussian.of(rng);
127 // Require an Exponential(mean=2).
128 // Here we use mean = 1 and scale the output later.
129 exp = ZigguratSampler.Exponential.of(rng);
130 }
131
132 @Override
133 public double[] sample() {
134 final double x = normal.sample();
135 final double y = normal.sample();
136 final double z = normal.sample();
137 // Include the exponential sample. It has mean 1 so multiply by 2.
138 final double sum = exp.sample() * 2 + x * x + y * y + z * z;
139 // Note: Handle the possibility of a zero sum and invalid inverse
140 if (sum == 0) {
141 return sample();
142 }
143 final double f = 1.0 / Math.sqrt(sum);
144 return new double[] {x * f, y * f, z * f};
145 }
146
147 @Override
148 public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
149 return new UnitBallSampler3D(rng);
150 }
151 }
152
153 /**
154 * Sample using ball point picking.
155 * @see <a href="https://mathworld.wolfram.com/BallPointPicking.html">Ball point picking</a>
156 */
157 private static class UnitBallSamplerND extends UnitBallSampler {
158 /** The dimension. */
159 private final int dimension;
160 /** The standard normal distribution. */
161 private final NormalizedGaussianSampler normal;
162 /** The exponential distribution (mean=1). */
163 private final ContinuousSampler exp;
164
165 /**
166 * @param rng Source of randomness.
167 * @param dimension Space dimension.
168 */
169 UnitBallSamplerND(UniformRandomProvider rng, int dimension) {
170 this.dimension = dimension;
171 normal = ZigguratSampler.NormalizedGaussian.of(rng);
172 // Require an Exponential(mean=2).
173 // Here we use mean = 1 and scale the output later.
174 exp = ZigguratSampler.Exponential.of(rng);
175 }
176
177 @Override
178 public double[] sample() {
179 final double[] sample = new double[dimension];
180 // Include the exponential sample. It has mean 1 so multiply by 2.
181 double sum = exp.sample() * 2;
182 for (int i = 0; i < dimension; i++) {
183 final double x = normal.sample();
184 sum += x * x;
185 sample[i] = x;
186 }
187 // Note: Handle the possibility of a zero sum and invalid inverse
188 if (sum == 0) {
189 return sample();
190 }
191 final double f = 1.0 / Math.sqrt(sum);
192 for (int i = 0; i < dimension; i++) {
193 sample[i] *= f;
194 }
195 return sample;
196 }
197
198 @Override
199 public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
200 return new UnitBallSamplerND(rng, dimension);
201 }
202 }
203
204 /**
205 * @return a random Cartesian coordinate within the unit n-ball.
206 */
207 @Override
208 public abstract double[] sample();
209
210 /** {@inheritDoc} */
211 // Redeclare the signature to return a UnitBallSampler not a SharedStateObjectSampler<double[]>
212 @Override
213 public abstract UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng);
214
215 /**
216 * Create a unit n-ball sampler for the given dimension.
217 * Sampled points are uniformly distributed within the unit n-ball.
218 *
219 * <p>Sampling is supported in dimensions of 1 or above.
220 *
221 * @param rng Source of randomness.
222 * @param dimension Space dimension.
223 * @return the sampler
224 * @throws IllegalArgumentException If {@code dimension <= 0}
225 */
226 public static UnitBallSampler of(UniformRandomProvider rng,
227 int dimension) {
228 if (dimension <= 0) {
229 throw new IllegalArgumentException("Dimension must be strictly positive");
230 } else if (dimension == ONE_D) {
231 return new UnitBallSampler1D(rng);
232 } else if (dimension == TWO_D) {
233 return new UnitBallSampler2D(rng);
234 } else if (dimension == THREE_D) {
235 return new UnitBallSampler3D(rng);
236 }
237 return new UnitBallSamplerND(rng, dimension);
238 }
239
240 /**
241 * Creates a signed double in the range {@code [-1, 1)}. The magnitude is sampled evenly
242 * from the 2<sup>54</sup> dyadic rationals in the range.
243 *
244 * <p>Note: This method will not return samples for both -0.0 and 0.0.
245 *
246 * @param bits the bits
247 * @return the double
248 */
249 private static double makeSignedDouble(long bits) {
250 // As per o.a.c.rng.core.utils.NumberFactory.makeDouble(long) but using a signed
251 // shift of 10 in place of an unsigned shift of 11.
252 // Use the upper 54 bits on the assumption they are more random.
253 // The sign bit is maintained by the signed shift.
254 // The next 53 bits generates a magnitude in the range [0, 2^53) or [-2^53, 0).
255 return (bits >> 10) * DOUBLE_MULTIPLIER;
256 }
257 }