View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling.distribution;
19  
20  import org.apache.commons.rng.UniformRandomProvider;
21  import org.apache.commons.rng.sampling.SharedStateSampler;
22  
23  /**
24   * Functions used by some of the samplers.
25   * This class is not part of the public API, as it would be
26   * better to group these utilities in a dedicated component.
27   */
28  final class InternalUtils {
29      /** All long-representable factorials, precomputed as the natural
30       * logarithm using Matlab R2023a VPA: log(vpa(x)). */
31      private static final double[] LOG_FACTORIALS = {
32          0,
33          0,
34          0.69314718055994530941723212145818,
35          1.7917594692280550008124773583807,
36          3.1780538303479456196469416012971,
37          4.7874917427820459942477009345232,
38          6.5792512120101009950601782929039,
39          8.5251613610654143001655310363471,
40          10.604602902745250228417227400722,
41          12.801827480081469611207717874567,
42          15.104412573075515295225709329251,
43          17.502307845873885839287652907216,
44          19.987214495661886149517362387055,
45          22.55216385312342288557084982862,
46          25.191221182738681500093434693522,
47          27.89927138384089156608943926367,
48          30.671860106080672803758367749503,
49          33.505073450136888884007902367376,
50          36.39544520803305357621562496268,
51          39.339884187199494036224652394567,
52          42.33561646075348502965987597071
53      };
54  
55      /** The first array index with a non-zero log factorial. */
56      private static final int BEGIN_LOG_FACTORIALS = 2;
57  
58      /**
59       * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
60       * Taken from org.apache.commons.rng.core.util.NumberFactory.
61       */
62      private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;
63  
64      /** Utility class. */
65      private InternalUtils() {}
66  
67      /**
68       * Validate the probabilities sum to a finite positive number.
69       *
70       * @param probabilities the probabilities
71       * @return the sum
72       * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
73       * probability is negative, infinite or {@code NaN}, or the sum of all
74       * probabilities is not strictly positive.
75       */
76      static double validateProbabilities(double[] probabilities) {
77          if (probabilities == null || probabilities.length == 0) {
78              throw new IllegalArgumentException("Probabilities must not be empty.");
79          }
80  
81          double sumProb = 0;
82          for (final double prob : probabilities) {
83              sumProb += requirePositiveFinite(prob, "probability");
84          }
85  
86          return requireStrictlyPositiveFinite(sumProb, "sum of probabilities");
87      }
88  
89      /**
90       * Checks the value {@code x} is finite.
91       *
92       * @param x Value.
93       * @param name Name of the value.
94       * @return x
95       * @throws IllegalArgumentException if {@code x} is non-finite
96       */
97      static double requireFinite(double x, String name) {
98          if (!Double.isFinite(x)) {
99              throw new IllegalArgumentException(name + " is not finite: " + x);
100         }
101         return x;
102     }
103 
104     /**
105      * Checks the value {@code x >= 0} and is finite.
106      * Note: This method allows {@code x == -0.0}.
107      *
108      * @param x Value.
109      * @param name Name of the value.
110      * @return x
111      * @throws IllegalArgumentException if {@code x < 0} or is non-finite
112      */
113     static double requirePositiveFinite(double x, String name) {
114         if (!(x >= 0 && x < Double.POSITIVE_INFINITY)) {
115             throw new IllegalArgumentException(
116                 name + " is not positive and finite: " + x);
117         }
118         return x;
119     }
120 
121     /**
122      * Checks the value {@code x > 0} and is finite.
123      *
124      * @param x Value.
125      * @param name Name of the value.
126      * @return x
127      * @throws IllegalArgumentException if {@code x <= 0} or is non-finite
128      */
129     static double requireStrictlyPositiveFinite(double x, String name) {
130         if (!(x > 0 && x < Double.POSITIVE_INFINITY)) {
131             throw new IllegalArgumentException(
132                 name + " is not strictly positive and finite: " + x);
133         }
134         return x;
135     }
136 
137     /**
138      * Checks the value {@code x >= 0}.
139      * Note: This method allows {@code x == -0.0}.
140      *
141      * @param x Value.
142      * @param name Name of the value.
143      * @return x
144      * @throws IllegalArgumentException if {@code x < 0}
145      */
146     static double requirePositive(double x, String name) {
147         // Logic inversion detects NaN
148         if (!(x >= 0)) {
149             throw new IllegalArgumentException(name + " is not positive: " + x);
150         }
151         return x;
152     }
153 
154     /**
155      * Checks the value {@code x > 0}.
156      *
157      * @param x Value.
158      * @param name Name of the value.
159      * @return x
160      * @throws IllegalArgumentException if {@code x <= 0}
161      */
162     static double requireStrictlyPositive(double x, String name) {
163         // Logic inversion detects NaN
164         if (!(x > 0)) {
165             throw new IllegalArgumentException(name + " is not strictly positive: " + x);
166         }
167         return x;
168     }
169 
170     /**
171      * Checks the value is within the range: {@code min <= x < max}.
172      *
173      * @param min Minimum (inclusive).
174      * @param max Maximum (exclusive).
175      * @param x Value.
176      * @param name Name of the value.
177      * @return x
178      * @throws IllegalArgumentException if {@code x < min || x >= max}.
179      */
180     static double requireRange(double min, double max, double x, String name) {
181         if (!(min <= x && x < max)) {
182             throw new IllegalArgumentException(
183                 String.format("%s not within range: %s <= %s < %s", name, min, x, max));
184         }
185         return x;
186     }
187 
188     /**
189      * Checks the value is within the closed range: {@code min <= x <= max}.
190      *
191      * @param min Minimum (inclusive).
192      * @param max Maximum (inclusive).
193      * @param x Value.
194      * @param name Name of the value.
195      * @return x
196      * @throws IllegalArgumentException if {@code x < min || x > max}.
197      */
198     static double requireRangeClosed(double min, double max, double x, String name) {
199         if (!(min <= x && x <= max)) {
200             throw new IllegalArgumentException(
201                 String.format("%s not within closed range: %s <= %s <= %s", name, min, x, max));
202         }
203         return x;
204     }
205 
206     /**
207      * Create a new instance of the given sampler using
208      * {@link SharedStateSampler#withUniformRandomProvider(UniformRandomProvider)}.
209      *
210      * @param sampler Source sampler.
211      * @param rng Generator of uniformly distributed random numbers.
212      * @return the new sampler
213      * @throws UnsupportedOperationException if the underlying sampler is not a
214      * {@link SharedStateSampler} or does not return a {@link NormalizedGaussianSampler} when
215      * sharing state.
216      */
217     static NormalizedGaussianSampler newNormalizedGaussianSampler(
218             NormalizedGaussianSampler sampler,
219             UniformRandomProvider rng) {
220         if (!(sampler instanceof SharedStateSampler<?>)) {
221             throw new UnsupportedOperationException("The underlying sampler cannot share state");
222         }
223         final Object newSampler = ((SharedStateSampler<?>) sampler).withUniformRandomProvider(rng);
224         if (!(newSampler instanceof NormalizedGaussianSampler)) {
225             throw new UnsupportedOperationException(
226                 "The underlying sampler did not create a normalized Gaussian sampler");
227         }
228         return (NormalizedGaussianSampler) newSampler;
229     }
230 
231     /**
232      * Creates a {@code double} in the interval {@code [0, 1)} from a {@code long} value.
233      *
234      * @param v Number.
235      * @return a {@code double} value in the interval {@code [0, 1)}.
236      */
237     static double makeDouble(long v) {
238         // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
239         // without adding an explicit dependency on that module.
240         return (v >>> 11) * DOUBLE_MULTIPLIER;
241     }
242 
243     /**
244      * Creates a {@code double} in the interval {@code (0, 1]} from a {@code long} value.
245      *
246      * @param v Number.
247      * @return a {@code double} value in the interval {@code (0, 1]}.
248      */
249     static double makeNonZeroDouble(long v) {
250         // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
251         // but shifts the range from [0, 1) to (0, 1].
252         return ((v >>> 11) + 1L) * DOUBLE_MULTIPLIER;
253     }
254 
255     /**
256      * Class for computing the natural logarithm of the factorial of {@code n}.
257      * It allows to allocate a cache of precomputed values.
258      * In case of cache miss, computation is performed by a call to
259      * {@link InternalGamma#logGamma(double)}.
260      *
261      * <p>TODO: Remove public modifiers in next major release. This class is
262      * effectively package-private but changing triggers a false-positive
263      * binary compatibility failure using japicmp.
264      */
265     public static final class FactorialLog {
266         /**
267          * Precomputed values of the function:
268          * {@code LOG_FACTORIALS[i] = log(i!)}.
269          */
270         private final double[] logFactorials;
271 
272         /**
273          * Creates an instance, reusing the already computed values if available.
274          *
275          * @param numValues Number of values of the function to compute.
276          * @param cache Existing cache.
277          * @throws NegativeArraySizeException if {@code numValues < 0}.
278          */
279         private FactorialLog(int numValues,
280                              double[] cache) {
281             logFactorials = new double[numValues];
282 
283             final int endCopy;
284             if (cache != null && cache.length > BEGIN_LOG_FACTORIALS) {
285                 // Copy available values.
286                 endCopy = Math.min(cache.length, numValues);
287                 System.arraycopy(cache, BEGIN_LOG_FACTORIALS, logFactorials, BEGIN_LOG_FACTORIALS,
288                     endCopy - BEGIN_LOG_FACTORIALS);
289             } else {
290                 // All values to be computed
291                 endCopy = BEGIN_LOG_FACTORIALS;
292             }
293 
294             // Compute remaining values.
295             for (int i = endCopy; i < numValues; i++) {
296                 if (i < LOG_FACTORIALS.length) {
297                     logFactorials[i] = LOG_FACTORIALS[i];
298                 } else {
299                     logFactorials[i] = logFactorials[i - 1] + Math.log(i);
300                 }
301             }
302         }
303 
304         /**
305          * Creates an instance with no precomputed values.
306          *
307          * @return an instance with no precomputed values.
308          */
309         public static FactorialLog create() {
310             return new FactorialLog(0, null);
311         }
312 
313         /**
314          * Creates an instance with the specified cache size.
315          *
316          * @param cacheSize Number of precomputed values of the function.
317          * @return a new instance where {@code cacheSize} values have been
318          * precomputed.
319          * @throws IllegalArgumentException if {@code n < 0}.
320          */
321         public FactorialLog withCache(final int cacheSize) {
322             return new FactorialLog(cacheSize, logFactorials);
323         }
324 
325         /**
326          * Computes {@code log(n!)}.
327          *
328          * @param n Argument.
329          * @return {@code log(n!)}.
330          * @throws IndexOutOfBoundsException if {@code numValues < 0}.
331          */
332         public double value(final int n) {
333             // Use cache of precomputed values.
334             if (n < logFactorials.length) {
335                 return logFactorials[n];
336             }
337 
338             // Use cache of precomputed log factorial values.
339             if (n < LOG_FACTORIALS.length) {
340                 return LOG_FACTORIALS[n];
341             }
342 
343             // Delegate.
344             return InternalGamma.logGamma(n + 1.0);
345         }
346     }
347 }