1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.rng.sampling.distribution;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.sampling.SharedStateSampler;
22
23 /**
24 * Functions used by some of the samplers.
25 * This class is not part of the public API, as it would be
26 * better to group these utilities in a dedicated component.
27 */
28 final class InternalUtils {
29 /** All long-representable factorials, precomputed as the natural
30 * logarithm using Matlab R2023a VPA: log(vpa(x)). */
31 private static final double[] LOG_FACTORIALS = {
32 0,
33 0,
34 0.69314718055994530941723212145818,
35 1.7917594692280550008124773583807,
36 3.1780538303479456196469416012971,
37 4.7874917427820459942477009345232,
38 6.5792512120101009950601782929039,
39 8.5251613610654143001655310363471,
40 10.604602902745250228417227400722,
41 12.801827480081469611207717874567,
42 15.104412573075515295225709329251,
43 17.502307845873885839287652907216,
44 19.987214495661886149517362387055,
45 22.55216385312342288557084982862,
46 25.191221182738681500093434693522,
47 27.89927138384089156608943926367,
48 30.671860106080672803758367749503,
49 33.505073450136888884007902367376,
50 36.39544520803305357621562496268,
51 39.339884187199494036224652394567,
52 42.33561646075348502965987597071
53 };
54
55 /** The first array index with a non-zero log factorial. */
56 private static final int BEGIN_LOG_FACTORIALS = 2;
57
58 /**
59 * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
60 * Taken from org.apache.commons.rng.core.util.NumberFactory.
61 */
62 private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;
63
64 /** Utility class. */
65 private InternalUtils() {}
66
67 /**
68 * Validate the probabilities sum to a finite positive number.
69 *
70 * @param probabilities the probabilities
71 * @return the sum
72 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
73 * probability is negative, infinite or {@code NaN}, or the sum of all
74 * probabilities is not strictly positive.
75 */
76 static double validateProbabilities(double[] probabilities) {
77 if (probabilities == null || probabilities.length == 0) {
78 throw new IllegalArgumentException("Probabilities must not be empty.");
79 }
80
81 double sumProb = 0;
82 for (final double prob : probabilities) {
83 sumProb += requirePositiveFinite(prob, "probability");
84 }
85
86 return requireStrictlyPositiveFinite(sumProb, "sum of probabilities");
87 }
88
89 /**
90 * Checks the value {@code x} is finite.
91 *
92 * @param x Value.
93 * @param name Name of the value.
94 * @return x
95 * @throws IllegalArgumentException if {@code x} is non-finite
96 */
97 static double requireFinite(double x, String name) {
98 if (!Double.isFinite(x)) {
99 throw new IllegalArgumentException(name + " is not finite: " + x);
100 }
101 return x;
102 }
103
104 /**
105 * Checks the value {@code x >= 0} and is finite.
106 * Note: This method allows {@code x == -0.0}.
107 *
108 * @param x Value.
109 * @param name Name of the value.
110 * @return x
111 * @throws IllegalArgumentException if {@code x < 0} or is non-finite
112 */
113 static double requirePositiveFinite(double x, String name) {
114 if (!(x >= 0 && x < Double.POSITIVE_INFINITY)) {
115 throw new IllegalArgumentException(
116 name + " is not positive and finite: " + x);
117 }
118 return x;
119 }
120
121 /**
122 * Checks the value {@code x > 0} and is finite.
123 *
124 * @param x Value.
125 * @param name Name of the value.
126 * @return x
127 * @throws IllegalArgumentException if {@code x <= 0} or is non-finite
128 */
129 static double requireStrictlyPositiveFinite(double x, String name) {
130 if (!(x > 0 && x < Double.POSITIVE_INFINITY)) {
131 throw new IllegalArgumentException(
132 name + " is not strictly positive and finite: " + x);
133 }
134 return x;
135 }
136
137 /**
138 * Checks the value {@code x >= 0}.
139 * Note: This method allows {@code x == -0.0}.
140 *
141 * @param x Value.
142 * @param name Name of the value.
143 * @return x
144 * @throws IllegalArgumentException if {@code x < 0}
145 */
146 static double requirePositive(double x, String name) {
147 // Logic inversion detects NaN
148 if (!(x >= 0)) {
149 throw new IllegalArgumentException(name + " is not positive: " + x);
150 }
151 return x;
152 }
153
154 /**
155 * Checks the value {@code x > 0}.
156 *
157 * @param x Value.
158 * @param name Name of the value.
159 * @return x
160 * @throws IllegalArgumentException if {@code x <= 0}
161 */
162 static double requireStrictlyPositive(double x, String name) {
163 // Logic inversion detects NaN
164 if (!(x > 0)) {
165 throw new IllegalArgumentException(name + " is not strictly positive: " + x);
166 }
167 return x;
168 }
169
170 /**
171 * Checks the value is within the range: {@code min <= x < max}.
172 *
173 * @param min Minimum (inclusive).
174 * @param max Maximum (exclusive).
175 * @param x Value.
176 * @param name Name of the value.
177 * @return x
178 * @throws IllegalArgumentException if {@code x < min || x >= max}.
179 */
180 static double requireRange(double min, double max, double x, String name) {
181 if (!(min <= x && x < max)) {
182 throw new IllegalArgumentException(
183 String.format("%s not within range: %s <= %s < %s", name, min, x, max));
184 }
185 return x;
186 }
187
188 /**
189 * Checks the value is within the closed range: {@code min <= x <= max}.
190 *
191 * @param min Minimum (inclusive).
192 * @param max Maximum (inclusive).
193 * @param x Value.
194 * @param name Name of the value.
195 * @return x
196 * @throws IllegalArgumentException if {@code x < min || x > max}.
197 */
198 static double requireRangeClosed(double min, double max, double x, String name) {
199 if (!(min <= x && x <= max)) {
200 throw new IllegalArgumentException(
201 String.format("%s not within closed range: %s <= %s <= %s", name, min, x, max));
202 }
203 return x;
204 }
205
206 /**
207 * Create a new instance of the given sampler using
208 * {@link SharedStateSampler#withUniformRandomProvider(UniformRandomProvider)}.
209 *
210 * @param sampler Source sampler.
211 * @param rng Generator of uniformly distributed random numbers.
212 * @return the new sampler
213 * @throws UnsupportedOperationException if the underlying sampler is not a
214 * {@link SharedStateSampler} or does not return a {@link NormalizedGaussianSampler} when
215 * sharing state.
216 */
217 static NormalizedGaussianSampler newNormalizedGaussianSampler(
218 NormalizedGaussianSampler sampler,
219 UniformRandomProvider rng) {
220 if (!(sampler instanceof SharedStateSampler<?>)) {
221 throw new UnsupportedOperationException("The underlying sampler cannot share state");
222 }
223 final Object newSampler = ((SharedStateSampler<?>) sampler).withUniformRandomProvider(rng);
224 if (!(newSampler instanceof NormalizedGaussianSampler)) {
225 throw new UnsupportedOperationException(
226 "The underlying sampler did not create a normalized Gaussian sampler");
227 }
228 return (NormalizedGaussianSampler) newSampler;
229 }
230
231 /**
232 * Creates a {@code double} in the interval {@code [0, 1)} from a {@code long} value.
233 *
234 * @param v Number.
235 * @return a {@code double} value in the interval {@code [0, 1)}.
236 */
237 static double makeDouble(long v) {
238 // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
239 // without adding an explicit dependency on that module.
240 return (v >>> 11) * DOUBLE_MULTIPLIER;
241 }
242
243 /**
244 * Creates a {@code double} in the interval {@code (0, 1]} from a {@code long} value.
245 *
246 * @param v Number.
247 * @return a {@code double} value in the interval {@code (0, 1]}.
248 */
249 static double makeNonZeroDouble(long v) {
250 // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
251 // but shifts the range from [0, 1) to (0, 1].
252 return ((v >>> 11) + 1L) * DOUBLE_MULTIPLIER;
253 }
254
255 /**
256 * Class for computing the natural logarithm of the factorial of {@code n}.
257 * It allows to allocate a cache of precomputed values.
258 * In case of cache miss, computation is performed by a call to
259 * {@link InternalGamma#logGamma(double)}.
260 *
261 * <p>TODO: Remove public modifiers in next major release. This class is
262 * effectively package-private but changing triggers a false-positive
263 * binary compatibility failure using japicmp.
264 */
265 public static final class FactorialLog {
266 /**
267 * Precomputed values of the function:
268 * {@code LOG_FACTORIALS[i] = log(i!)}.
269 */
270 private final double[] logFactorials;
271
272 /**
273 * Creates an instance, reusing the already computed values if available.
274 *
275 * @param numValues Number of values of the function to compute.
276 * @param cache Existing cache.
277 * @throws NegativeArraySizeException if {@code numValues < 0}.
278 */
279 private FactorialLog(int numValues,
280 double[] cache) {
281 logFactorials = new double[numValues];
282
283 final int endCopy;
284 if (cache != null && cache.length > BEGIN_LOG_FACTORIALS) {
285 // Copy available values.
286 endCopy = Math.min(cache.length, numValues);
287 System.arraycopy(cache, BEGIN_LOG_FACTORIALS, logFactorials, BEGIN_LOG_FACTORIALS,
288 endCopy - BEGIN_LOG_FACTORIALS);
289 } else {
290 // All values to be computed
291 endCopy = BEGIN_LOG_FACTORIALS;
292 }
293
294 // Compute remaining values.
295 for (int i = endCopy; i < numValues; i++) {
296 if (i < LOG_FACTORIALS.length) {
297 logFactorials[i] = LOG_FACTORIALS[i];
298 } else {
299 logFactorials[i] = logFactorials[i - 1] + Math.log(i);
300 }
301 }
302 }
303
304 /**
305 * Creates an instance with no precomputed values.
306 *
307 * @return an instance with no precomputed values.
308 */
309 public static FactorialLog create() {
310 return new FactorialLog(0, null);
311 }
312
313 /**
314 * Creates an instance with the specified cache size.
315 *
316 * @param cacheSize Number of precomputed values of the function.
317 * @return a new instance where {@code cacheSize} values have been
318 * precomputed.
319 * @throws IllegalArgumentException if {@code n < 0}.
320 */
321 public FactorialLog withCache(final int cacheSize) {
322 return new FactorialLog(cacheSize, logFactorials);
323 }
324
325 /**
326 * Computes {@code log(n!)}.
327 *
328 * @param n Argument.
329 * @return {@code log(n!)}.
330 * @throws IndexOutOfBoundsException if {@code numValues < 0}.
331 */
332 public double value(final int n) {
333 // Use cache of precomputed values.
334 if (n < logFactorials.length) {
335 return logFactorials[n];
336 }
337
338 // Use cache of precomputed log factorial values.
339 if (n < LOG_FACTORIALS.length) {
340 return LOG_FACTORIALS[n];
341 }
342
343 // Delegate.
344 return InternalGamma.logGamma(n + 1.0);
345 }
346 }
347 }