1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.rng.sampling.distribution;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.sampling.SharedStateSampler;
22
23 /**
24 * Functions used by some of the samplers.
25 * This class is not part of the public API, as it would be
26 * better to group these utilities in a dedicated component.
27 */
28 final class InternalUtils {
29 /** All long-representable factorials, precomputed as the natural
30 * logarithm using Matlab R2023a VPA: log(vpa(x)).
31 *
32 * <p>Note: This table could be any length. Previously this stored
33 * the long value of n!, not log(n!). Using the previous length
34 * maintains behaviour. */
35 private static final double[] LOG_FACTORIALS = {
36 0,
37 0,
38 0.69314718055994530941723212145818,
39 1.7917594692280550008124773583807,
40 3.1780538303479456196469416012971,
41 4.7874917427820459942477009345232,
42 6.5792512120101009950601782929039,
43 8.5251613610654143001655310363471,
44 10.604602902745250228417227400722,
45 12.801827480081469611207717874567,
46 15.104412573075515295225709329251,
47 17.502307845873885839287652907216,
48 19.987214495661886149517362387055,
49 22.55216385312342288557084982862,
50 25.191221182738681500093434693522,
51 27.89927138384089156608943926367,
52 30.671860106080672803758367749503,
53 33.505073450136888884007902367376,
54 36.39544520803305357621562496268,
55 39.339884187199494036224652394567,
56 42.33561646075348502965987597071
57 };
58
59 /** The first array index with a non-zero log factorial. */
60 private static final int BEGIN_LOG_FACTORIALS = 2;
61
62 /**
63 * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
64 * Taken from org.apache.commons.rng.core.util.NumberFactory.
65 */
66 private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;
67
68 /** Utility class. */
69 private InternalUtils() {}
70
71 /**
72 * @param n Argument.
73 * @return {@code n!}
74 * @throws IndexOutOfBoundsException if the result is too large to be represented
75 * by a {@code long} (i.e. if {@code n > 20}), or {@code n} is negative.
76 */
77 static double logFactorial(int n) {
78 return LOG_FACTORIALS[n];
79 }
80
81 /**
82 * Validate the probabilities sum to a finite positive number.
83 *
84 * @param probabilities the probabilities
85 * @return the sum
86 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
87 * probability is negative, infinite or {@code NaN}, or the sum of all
88 * probabilities is not strictly positive.
89 */
90 static double validateProbabilities(double[] probabilities) {
91 if (probabilities == null || probabilities.length == 0) {
92 throw new IllegalArgumentException("Probabilities must not be empty.");
93 }
94
95 double sumProb = 0;
96 for (final double prob : probabilities) {
97 sumProb += requirePositiveFinite(prob, "probability");
98 }
99
100 return requireStrictlyPositiveFinite(sumProb, "sum of probabilities");
101 }
102
103 /**
104 * Checks the value {@code x} is finite.
105 *
106 * @param x Value.
107 * @param name Name of the value.
108 * @return x
109 * @throws IllegalArgumentException if {@code x} is non-finite
110 */
111 static double requireFinite(double x, String name) {
112 if (!Double.isFinite(x)) {
113 throw new IllegalArgumentException(name + " is not finite: " + x);
114 }
115 return x;
116 }
117
118 /**
119 * Checks the value {@code x >= 0} and is finite.
120 * Note: This method allows {@code x == -0.0}.
121 *
122 * @param x Value.
123 * @param name Name of the value.
124 * @return x
125 * @throws IllegalArgumentException if {@code x < 0} or is non-finite
126 */
127 static double requirePositiveFinite(double x, String name) {
128 if (!(x >= 0 && x < Double.POSITIVE_INFINITY)) {
129 throw new IllegalArgumentException(
130 name + " is not positive and finite: " + x);
131 }
132 return x;
133 }
134
135 /**
136 * Checks the value {@code x > 0} and is finite.
137 *
138 * @param x Value.
139 * @param name Name of the value.
140 * @return x
141 * @throws IllegalArgumentException if {@code x <= 0} or is non-finite
142 */
143 static double requireStrictlyPositiveFinite(double x, String name) {
144 if (!(x > 0 && x < Double.POSITIVE_INFINITY)) {
145 throw new IllegalArgumentException(
146 name + " is not strictly positive and finite: " + x);
147 }
148 return x;
149 }
150
151 /**
152 * Checks the value {@code x >= 0}.
153 * Note: This method allows {@code x == -0.0}.
154 *
155 * @param x Value.
156 * @param name Name of the value.
157 * @return x
158 * @throws IllegalArgumentException if {@code x < 0}
159 */
160 static double requirePositive(double x, String name) {
161 // Logic inversion detects NaN
162 if (!(x >= 0)) {
163 throw new IllegalArgumentException(name + " is not positive: " + x);
164 }
165 return x;
166 }
167
168 /**
169 * Checks the value {@code x > 0}.
170 *
171 * @param x Value.
172 * @param name Name of the value.
173 * @return x
174 * @throws IllegalArgumentException if {@code x <= 0}
175 */
176 static double requireStrictlyPositive(double x, String name) {
177 // Logic inversion detects NaN
178 if (!(x > 0)) {
179 throw new IllegalArgumentException(name + " is not strictly positive: " + x);
180 }
181 return x;
182 }
183
184 /**
185 * Checks the value is within the range: {@code min <= x < max}.
186 *
187 * @param min Minimum (inclusive).
188 * @param max Maximum (exclusive).
189 * @param x Value.
190 * @param name Name of the value.
191 * @return x
192 * @throws IllegalArgumentException if {@code x < min || x >= max}.
193 */
194 static double requireRange(double min, double max, double x, String name) {
195 if (!(min <= x && x < max)) {
196 throw new IllegalArgumentException(
197 String.format("%s not within range: %s <= %s < %s", name, min, x, max));
198 }
199 return x;
200 }
201
202 /**
203 * Checks the value is within the closed range: {@code min <= x <= max}.
204 *
205 * @param min Minimum (inclusive).
206 * @param max Maximum (inclusive).
207 * @param x Value.
208 * @param name Name of the value.
209 * @return x
210 * @throws IllegalArgumentException if {@code x < min || x > max}.
211 */
212 static double requireRangeClosed(double min, double max, double x, String name) {
213 if (!(min <= x && x <= max)) {
214 throw new IllegalArgumentException(
215 String.format("%s not within closed range: %s <= %s <= %s", name, min, x, max));
216 }
217 return x;
218 }
219
220 /**
221 * Create a new instance of the given sampler using
222 * {@link SharedStateSampler#withUniformRandomProvider(UniformRandomProvider)}.
223 *
224 * @param sampler Source sampler.
225 * @param rng Generator of uniformly distributed random numbers.
226 * @return the new sampler
227 * @throws UnsupportedOperationException if the underlying sampler is not a
228 * {@link SharedStateSampler} or does not return a {@link NormalizedGaussianSampler} when
229 * sharing state.
230 */
231 static NormalizedGaussianSampler newNormalizedGaussianSampler(
232 NormalizedGaussianSampler sampler,
233 UniformRandomProvider rng) {
234 if (!(sampler instanceof SharedStateSampler<?>)) {
235 throw new UnsupportedOperationException("The underlying sampler cannot share state");
236 }
237 final Object newSampler = ((SharedStateSampler<?>) sampler).withUniformRandomProvider(rng);
238 if (!(newSampler instanceof NormalizedGaussianSampler)) {
239 throw new UnsupportedOperationException(
240 "The underlying sampler did not create a normalized Gaussian sampler");
241 }
242 return (NormalizedGaussianSampler) newSampler;
243 }
244
245 /**
246 * Creates a {@code double} in the interval {@code [0, 1)} from a {@code long} value.
247 *
248 * @param v Number.
249 * @return a {@code double} value in the interval {@code [0, 1)}.
250 */
251 static double makeDouble(long v) {
252 // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
253 // without adding an explicit dependency on that module.
254 return (v >>> 11) * DOUBLE_MULTIPLIER;
255 }
256
257 /**
258 * Creates a {@code double} in the interval {@code (0, 1]} from a {@code long} value.
259 *
260 * @param v Number.
261 * @return a {@code double} value in the interval {@code (0, 1]}.
262 */
263 static double makeNonZeroDouble(long v) {
264 // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
265 // but shifts the range from [0, 1) to (0, 1].
266 return ((v >>> 11) + 1L) * DOUBLE_MULTIPLIER;
267 }
268
269 /**
270 * Class for computing the natural logarithm of the factorial of {@code n}.
271 * It allows to allocate a cache of precomputed values.
272 * In case of cache miss, computation is performed by a call to
273 * {@link InternalGamma#logGamma(double)}.
274 */
275 public static final class FactorialLog {
276 /**
277 * Precomputed values of the function:
278 * {@code LOG_FACTORIALS[i] = log(i!)}.
279 */
280 private final double[] logFactorials;
281
282 /**
283 * Creates an instance, reusing the already computed values if available.
284 *
285 * @param numValues Number of values of the function to compute.
286 * @param cache Existing cache.
287 * @throws NegativeArraySizeException if {@code numValues < 0}.
288 */
289 private FactorialLog(int numValues,
290 double[] cache) {
291 logFactorials = new double[numValues];
292
293 final int endCopy;
294 if (cache != null && cache.length > BEGIN_LOG_FACTORIALS) {
295 // Copy available values.
296 endCopy = Math.min(cache.length, numValues);
297 System.arraycopy(cache, BEGIN_LOG_FACTORIALS, logFactorials, BEGIN_LOG_FACTORIALS,
298 endCopy - BEGIN_LOG_FACTORIALS);
299 } else {
300 // All values to be computed
301 endCopy = BEGIN_LOG_FACTORIALS;
302 }
303
304 // Compute remaining values.
305 for (int i = endCopy; i < numValues; i++) {
306 if (i < LOG_FACTORIALS.length) {
307 logFactorials[i] = LOG_FACTORIALS[i];
308 } else {
309 logFactorials[i] = logFactorials[i - 1] + Math.log(i);
310 }
311 }
312 }
313
314 /**
315 * Creates an instance with no precomputed values.
316 *
317 * @return an instance with no precomputed values.
318 */
319 public static FactorialLog create() {
320 return new FactorialLog(0, null);
321 }
322
323 /**
324 * Creates an instance with the specified cache size.
325 *
326 * @param cacheSize Number of precomputed values of the function.
327 * @return a new instance where {@code cacheSize} values have been
328 * precomputed.
329 * @throws IllegalArgumentException if {@code n < 0}.
330 */
331 public FactorialLog withCache(final int cacheSize) {
332 return new FactorialLog(cacheSize, logFactorials);
333 }
334
335 /**
336 * Computes {@code log(n!)}.
337 *
338 * @param n Argument.
339 * @return {@code log(n!)}.
340 * @throws IndexOutOfBoundsException if {@code numValues < 0}.
341 */
342 public double value(final int n) {
343 // Use cache of precomputed values.
344 if (n < logFactorials.length) {
345 return logFactorials[n];
346 }
347
348 // Use cache of precomputed log factorial values.
349 if (n < LOG_FACTORIALS.length) {
350 return LOG_FACTORIALS[n];
351 }
352
353 // Delegate.
354 return InternalGamma.logGamma(n + 1.0);
355 }
356 }
357 }