View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling.distribution;
19  
20  import org.apache.commons.rng.UniformRandomProvider;
21  import org.apache.commons.rng.sampling.SharedStateSampler;
22  
23  /**
24   * Functions used by some of the samplers.
25   * This class is not part of the public API, as it would be
26   * better to group these utilities in a dedicated component.
27   */
28  final class InternalUtils {
29      /** All long-representable factorials, precomputed as the natural
30       * logarithm using Matlab R2023a VPA: log(vpa(x)).
31       *
32       * <p>Note: This table could be any length. Previously this stored
33       * the long value of n!, not log(n!). Using the previous length
34       * maintains behaviour. */
35      private static final double[] LOG_FACTORIALS = {
36          0,
37          0,
38          0.69314718055994530941723212145818,
39          1.7917594692280550008124773583807,
40          3.1780538303479456196469416012971,
41          4.7874917427820459942477009345232,
42          6.5792512120101009950601782929039,
43          8.5251613610654143001655310363471,
44          10.604602902745250228417227400722,
45          12.801827480081469611207717874567,
46          15.104412573075515295225709329251,
47          17.502307845873885839287652907216,
48          19.987214495661886149517362387055,
49          22.55216385312342288557084982862,
50          25.191221182738681500093434693522,
51          27.89927138384089156608943926367,
52          30.671860106080672803758367749503,
53          33.505073450136888884007902367376,
54          36.39544520803305357621562496268,
55          39.339884187199494036224652394567,
56          42.33561646075348502965987597071
57      };
58  
59      /** The first array index with a non-zero log factorial. */
60      private static final int BEGIN_LOG_FACTORIALS = 2;
61  
62      /**
63       * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
64       * Taken from org.apache.commons.rng.core.util.NumberFactory.
65       */
66      private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;
67  
68      /** Utility class. */
69      private InternalUtils() {}
70  
71      /**
72       * @param n Argument.
73       * @return {@code n!}
74       * @throws IndexOutOfBoundsException if the result is too large to be represented
75       * by a {@code long} (i.e. if {@code n > 20}), or {@code n} is negative.
76       */
77      static double logFactorial(int n)  {
78          return LOG_FACTORIALS[n];
79      }
80  
81      /**
82       * Validate the probabilities sum to a finite positive number.
83       *
84       * @param probabilities the probabilities
85       * @return the sum
86       * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
87       * probability is negative, infinite or {@code NaN}, or the sum of all
88       * probabilities is not strictly positive.
89       */
90      static double validateProbabilities(double[] probabilities) {
91          if (probabilities == null || probabilities.length == 0) {
92              throw new IllegalArgumentException("Probabilities must not be empty.");
93          }
94  
95          double sumProb = 0;
96          for (final double prob : probabilities) {
97              sumProb += requirePositiveFinite(prob, "probability");
98          }
99  
100         return requireStrictlyPositiveFinite(sumProb, "sum of probabilities");
101     }
102 
103     /**
104      * Checks the value {@code x} is finite.
105      *
106      * @param x Value.
107      * @param name Name of the value.
108      * @return x
109      * @throws IllegalArgumentException if {@code x} is non-finite
110      */
111     static double requireFinite(double x, String name) {
112         if (!Double.isFinite(x)) {
113             throw new IllegalArgumentException(name + " is not finite: " + x);
114         }
115         return x;
116     }
117 
118     /**
119      * Checks the value {@code x >= 0} and is finite.
120      * Note: This method allows {@code x == -0.0}.
121      *
122      * @param x Value.
123      * @param name Name of the value.
124      * @return x
125      * @throws IllegalArgumentException if {@code x < 0} or is non-finite
126      */
127     static double requirePositiveFinite(double x, String name) {
128         if (!(x >= 0 && x < Double.POSITIVE_INFINITY)) {
129             throw new IllegalArgumentException(
130                 name + " is not positive and finite: " + x);
131         }
132         return x;
133     }
134 
135     /**
136      * Checks the value {@code x > 0} and is finite.
137      *
138      * @param x Value.
139      * @param name Name of the value.
140      * @return x
141      * @throws IllegalArgumentException if {@code x <= 0} or is non-finite
142      */
143     static double requireStrictlyPositiveFinite(double x, String name) {
144         if (!(x > 0 && x < Double.POSITIVE_INFINITY)) {
145             throw new IllegalArgumentException(
146                 name + " is not strictly positive and finite: " + x);
147         }
148         return x;
149     }
150 
151     /**
152      * Checks the value {@code x >= 0}.
153      * Note: This method allows {@code x == -0.0}.
154      *
155      * @param x Value.
156      * @param name Name of the value.
157      * @return x
158      * @throws IllegalArgumentException if {@code x < 0}
159      */
160     static double requirePositive(double x, String name) {
161         // Logic inversion detects NaN
162         if (!(x >= 0)) {
163             throw new IllegalArgumentException(name + " is not positive: " + x);
164         }
165         return x;
166     }
167 
168     /**
169      * Checks the value {@code x > 0}.
170      *
171      * @param x Value.
172      * @param name Name of the value.
173      * @return x
174      * @throws IllegalArgumentException if {@code x <= 0}
175      */
176     static double requireStrictlyPositive(double x, String name) {
177         // Logic inversion detects NaN
178         if (!(x > 0)) {
179             throw new IllegalArgumentException(name + " is not strictly positive: " + x);
180         }
181         return x;
182     }
183 
184     /**
185      * Checks the value is within the range: {@code min <= x < max}.
186      *
187      * @param min Minimum (inclusive).
188      * @param max Maximum (exclusive).
189      * @param x Value.
190      * @param name Name of the value.
191      * @return x
192      * @throws IllegalArgumentException if {@code x < min || x >= max}.
193      */
194     static double requireRange(double min, double max, double x, String name) {
195         if (!(min <= x && x < max)) {
196             throw new IllegalArgumentException(
197                 String.format("%s not within range: %s <= %s < %s", name, min, x, max));
198         }
199         return x;
200     }
201 
202     /**
203      * Checks the value is within the closed range: {@code min <= x <= max}.
204      *
205      * @param min Minimum (inclusive).
206      * @param max Maximum (inclusive).
207      * @param x Value.
208      * @param name Name of the value.
209      * @return x
210      * @throws IllegalArgumentException if {@code x < min || x > max}.
211      */
212     static double requireRangeClosed(double min, double max, double x, String name) {
213         if (!(min <= x && x <= max)) {
214             throw new IllegalArgumentException(
215                 String.format("%s not within closed range: %s <= %s <= %s", name, min, x, max));
216         }
217         return x;
218     }
219 
220     /**
221      * Create a new instance of the given sampler using
222      * {@link SharedStateSampler#withUniformRandomProvider(UniformRandomProvider)}.
223      *
224      * @param sampler Source sampler.
225      * @param rng Generator of uniformly distributed random numbers.
226      * @return the new sampler
227      * @throws UnsupportedOperationException if the underlying sampler is not a
228      * {@link SharedStateSampler} or does not return a {@link NormalizedGaussianSampler} when
229      * sharing state.
230      */
231     static NormalizedGaussianSampler newNormalizedGaussianSampler(
232             NormalizedGaussianSampler sampler,
233             UniformRandomProvider rng) {
234         if (!(sampler instanceof SharedStateSampler<?>)) {
235             throw new UnsupportedOperationException("The underlying sampler cannot share state");
236         }
237         final Object newSampler = ((SharedStateSampler<?>) sampler).withUniformRandomProvider(rng);
238         if (!(newSampler instanceof NormalizedGaussianSampler)) {
239             throw new UnsupportedOperationException(
240                 "The underlying sampler did not create a normalized Gaussian sampler");
241         }
242         return (NormalizedGaussianSampler) newSampler;
243     }
244 
245     /**
246      * Creates a {@code double} in the interval {@code [0, 1)} from a {@code long} value.
247      *
248      * @param v Number.
249      * @return a {@code double} value in the interval {@code [0, 1)}.
250      */
251     static double makeDouble(long v) {
252         // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
253         // without adding an explicit dependency on that module.
254         return (v >>> 11) * DOUBLE_MULTIPLIER;
255     }
256 
257     /**
258      * Creates a {@code double} in the interval {@code (0, 1]} from a {@code long} value.
259      *
260      * @param v Number.
261      * @return a {@code double} value in the interval {@code (0, 1]}.
262      */
263     static double makeNonZeroDouble(long v) {
264         // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
265         // but shifts the range from [0, 1) to (0, 1].
266         return ((v >>> 11) + 1L) * DOUBLE_MULTIPLIER;
267     }
268 
269     /**
270      * Class for computing the natural logarithm of the factorial of {@code n}.
271      * It allows to allocate a cache of precomputed values.
272      * In case of cache miss, computation is performed by a call to
273      * {@link InternalGamma#logGamma(double)}.
274      */
275     public static final class FactorialLog {
276         /**
277          * Precomputed values of the function:
278          * {@code LOG_FACTORIALS[i] = log(i!)}.
279          */
280         private final double[] logFactorials;
281 
282         /**
283          * Creates an instance, reusing the already computed values if available.
284          *
285          * @param numValues Number of values of the function to compute.
286          * @param cache Existing cache.
287          * @throws NegativeArraySizeException if {@code numValues < 0}.
288          */
289         private FactorialLog(int numValues,
290                              double[] cache) {
291             logFactorials = new double[numValues];
292 
293             final int endCopy;
294             if (cache != null && cache.length > BEGIN_LOG_FACTORIALS) {
295                 // Copy available values.
296                 endCopy = Math.min(cache.length, numValues);
297                 System.arraycopy(cache, BEGIN_LOG_FACTORIALS, logFactorials, BEGIN_LOG_FACTORIALS,
298                     endCopy - BEGIN_LOG_FACTORIALS);
299             } else {
300                 // All values to be computed
301                 endCopy = BEGIN_LOG_FACTORIALS;
302             }
303 
304             // Compute remaining values.
305             for (int i = endCopy; i < numValues; i++) {
306                 if (i < LOG_FACTORIALS.length) {
307                     logFactorials[i] = LOG_FACTORIALS[i];
308                 } else {
309                     logFactorials[i] = logFactorials[i - 1] + Math.log(i);
310                 }
311             }
312         }
313 
314         /**
315          * Creates an instance with no precomputed values.
316          *
317          * @return an instance with no precomputed values.
318          */
319         public static FactorialLog create() {
320             return new FactorialLog(0, null);
321         }
322 
323         /**
324          * Creates an instance with the specified cache size.
325          *
326          * @param cacheSize Number of precomputed values of the function.
327          * @return a new instance where {@code cacheSize} values have been
328          * precomputed.
329          * @throws IllegalArgumentException if {@code n < 0}.
330          */
331         public FactorialLog withCache(final int cacheSize) {
332             return new FactorialLog(cacheSize, logFactorials);
333         }
334 
335         /**
336          * Computes {@code log(n!)}.
337          *
338          * @param n Argument.
339          * @return {@code log(n!)}.
340          * @throws IndexOutOfBoundsException if {@code numValues < 0}.
341          */
342         public double value(final int n) {
343             // Use cache of precomputed values.
344             if (n < logFactorials.length) {
345                 return logFactorials[n];
346             }
347 
348             // Use cache of precomputed log factorial values.
349             if (n < LOG_FACTORIALS.length) {
350                 return LOG_FACTORIALS[n];
351             }
352 
353             // Delegate.
354             return InternalGamma.logGamma(n + 1.0);
355         }
356     }
357 }