InternalUtils.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.rng.sampling.distribution;
- import org.apache.commons.rng.UniformRandomProvider;
- import org.apache.commons.rng.sampling.SharedStateSampler;
- /**
- * Functions used by some of the samplers.
- * This class is not part of the public API, as it would be
- * better to group these utilities in a dedicated component.
- */
- final class InternalUtils {
- /** All long-representable factorials, precomputed as the natural
- * logarithm using Matlab R2023a VPA: log(vpa(x)).
- *
- * <p>Note: This table could be any length. Previously this stored
- * the long value of n!, not log(n!). Using the previous length
- * maintains behaviour. */
- private static final double[] LOG_FACTORIALS = {
- 0,
- 0,
- 0.69314718055994530941723212145818,
- 1.7917594692280550008124773583807,
- 3.1780538303479456196469416012971,
- 4.7874917427820459942477009345232,
- 6.5792512120101009950601782929039,
- 8.5251613610654143001655310363471,
- 10.604602902745250228417227400722,
- 12.801827480081469611207717874567,
- 15.104412573075515295225709329251,
- 17.502307845873885839287652907216,
- 19.987214495661886149517362387055,
- 22.55216385312342288557084982862,
- 25.191221182738681500093434693522,
- 27.89927138384089156608943926367,
- 30.671860106080672803758367749503,
- 33.505073450136888884007902367376,
- 36.39544520803305357621562496268,
- 39.339884187199494036224652394567,
- 42.33561646075348502965987597071
- };
- /** The first array index with a non-zero log factorial. */
- private static final int BEGIN_LOG_FACTORIALS = 2;
- /**
- * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
- * Taken from org.apache.commons.rng.core.util.NumberFactory.
- */
- private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;
- /** Utility class. */
- private InternalUtils() {}
- /**
- * @param n Argument.
- * @return {@code n!}
- * @throws IndexOutOfBoundsException if the result is too large to be represented
- * by a {@code long} (i.e. if {@code n > 20}), or {@code n} is negative.
- */
- static double logFactorial(int n) {
- return LOG_FACTORIALS[n];
- }
- /**
- * Validate the probabilities sum to a finite positive number.
- *
- * @param probabilities the probabilities
- * @return the sum
- * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
- * probability is negative, infinite or {@code NaN}, or the sum of all
- * probabilities is not strictly positive.
- */
- static double validateProbabilities(double[] probabilities) {
- if (probabilities == null || probabilities.length == 0) {
- throw new IllegalArgumentException("Probabilities must not be empty.");
- }
- double sumProb = 0;
- for (final double prob : probabilities) {
- sumProb += requirePositiveFinite(prob, "probability");
- }
- return requireStrictlyPositiveFinite(sumProb, "sum of probabilities");
- }
- /**
- * Checks the value {@code x} is finite.
- *
- * @param x Value.
- * @param name Name of the value.
- * @return x
- * @throws IllegalArgumentException if {@code x} is non-finite
- */
- static double requireFinite(double x, String name) {
- if (!Double.isFinite(x)) {
- throw new IllegalArgumentException(name + " is not finite: " + x);
- }
- return x;
- }
- /**
- * Checks the value {@code x >= 0} and is finite.
- * Note: This method allows {@code x == -0.0}.
- *
- * @param x Value.
- * @param name Name of the value.
- * @return x
- * @throws IllegalArgumentException if {@code x < 0} or is non-finite
- */
- static double requirePositiveFinite(double x, String name) {
- if (!(x >= 0 && x < Double.POSITIVE_INFINITY)) {
- throw new IllegalArgumentException(
- name + " is not positive and finite: " + x);
- }
- return x;
- }
- /**
- * Checks the value {@code x > 0} and is finite.
- *
- * @param x Value.
- * @param name Name of the value.
- * @return x
- * @throws IllegalArgumentException if {@code x <= 0} or is non-finite
- */
- static double requireStrictlyPositiveFinite(double x, String name) {
- if (!(x > 0 && x < Double.POSITIVE_INFINITY)) {
- throw new IllegalArgumentException(
- name + " is not strictly positive and finite: " + x);
- }
- return x;
- }
- /**
- * Checks the value {@code x >= 0}.
- * Note: This method allows {@code x == -0.0}.
- *
- * @param x Value.
- * @param name Name of the value.
- * @return x
- * @throws IllegalArgumentException if {@code x < 0}
- */
- static double requirePositive(double x, String name) {
- // Logic inversion detects NaN
- if (!(x >= 0)) {
- throw new IllegalArgumentException(name + " is not positive: " + x);
- }
- return x;
- }
- /**
- * Checks the value {@code x > 0}.
- *
- * @param x Value.
- * @param name Name of the value.
- * @return x
- * @throws IllegalArgumentException if {@code x <= 0}
- */
- static double requireStrictlyPositive(double x, String name) {
- // Logic inversion detects NaN
- if (!(x > 0)) {
- throw new IllegalArgumentException(name + " is not strictly positive: " + x);
- }
- return x;
- }
- /**
- * Checks the value is within the range: {@code min <= x < max}.
- *
- * @param min Minimum (inclusive).
- * @param max Maximum (exclusive).
- * @param x Value.
- * @param name Name of the value.
- * @return x
- * @throws IllegalArgumentException if {@code x < min || x >= max}.
- */
- static double requireRange(double min, double max, double x, String name) {
- if (!(min <= x && x < max)) {
- throw new IllegalArgumentException(
- String.format("%s not within range: %s <= %s < %s", name, min, x, max));
- }
- return x;
- }
- /**
- * Checks the value is within the closed range: {@code min <= x <= max}.
- *
- * @param min Minimum (inclusive).
- * @param max Maximum (inclusive).
- * @param x Value.
- * @param name Name of the value.
- * @return x
- * @throws IllegalArgumentException if {@code x < min || x > max}.
- */
- static double requireRangeClosed(double min, double max, double x, String name) {
- if (!(min <= x && x <= max)) {
- throw new IllegalArgumentException(
- String.format("%s not within closed range: %s <= %s <= %s", name, min, x, max));
- }
- return x;
- }
- /**
- * Create a new instance of the given sampler using
- * {@link SharedStateSampler#withUniformRandomProvider(UniformRandomProvider)}.
- *
- * @param sampler Source sampler.
- * @param rng Generator of uniformly distributed random numbers.
- * @return the new sampler
- * @throws UnsupportedOperationException if the underlying sampler is not a
- * {@link SharedStateSampler} or does not return a {@link NormalizedGaussianSampler} when
- * sharing state.
- */
- static NormalizedGaussianSampler newNormalizedGaussianSampler(
- NormalizedGaussianSampler sampler,
- UniformRandomProvider rng) {
- if (!(sampler instanceof SharedStateSampler<?>)) {
- throw new UnsupportedOperationException("The underlying sampler cannot share state");
- }
- final Object newSampler = ((SharedStateSampler<?>) sampler).withUniformRandomProvider(rng);
- if (!(newSampler instanceof NormalizedGaussianSampler)) {
- throw new UnsupportedOperationException(
- "The underlying sampler did not create a normalized Gaussian sampler");
- }
- return (NormalizedGaussianSampler) newSampler;
- }
- /**
- * Creates a {@code double} in the interval {@code [0, 1)} from a {@code long} value.
- *
- * @param v Number.
- * @return a {@code double} value in the interval {@code [0, 1)}.
- */
- static double makeDouble(long v) {
- // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
- // without adding an explicit dependency on that module.
- return (v >>> 11) * DOUBLE_MULTIPLIER;
- }
- /**
- * Creates a {@code double} in the interval {@code (0, 1]} from a {@code long} value.
- *
- * @param v Number.
- * @return a {@code double} value in the interval {@code (0, 1]}.
- */
- static double makeNonZeroDouble(long v) {
- // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
- // but shifts the range from [0, 1) to (0, 1].
- return ((v >>> 11) + 1L) * DOUBLE_MULTIPLIER;
- }
- /**
- * Class for computing the natural logarithm of the factorial of {@code n}.
- * It allows to allocate a cache of precomputed values.
- * In case of cache miss, computation is performed by a call to
- * {@link InternalGamma#logGamma(double)}.
- */
- public static final class FactorialLog {
- /**
- * Precomputed values of the function:
- * {@code LOG_FACTORIALS[i] = log(i!)}.
- */
- private final double[] logFactorials;
- /**
- * Creates an instance, reusing the already computed values if available.
- *
- * @param numValues Number of values of the function to compute.
- * @param cache Existing cache.
- * @throws NegativeArraySizeException if {@code numValues < 0}.
- */
- private FactorialLog(int numValues,
- double[] cache) {
- logFactorials = new double[numValues];
- final int endCopy;
- if (cache != null && cache.length > BEGIN_LOG_FACTORIALS) {
- // Copy available values.
- endCopy = Math.min(cache.length, numValues);
- System.arraycopy(cache, BEGIN_LOG_FACTORIALS, logFactorials, BEGIN_LOG_FACTORIALS,
- endCopy - BEGIN_LOG_FACTORIALS);
- } else {
- // All values to be computed
- endCopy = BEGIN_LOG_FACTORIALS;
- }
- // Compute remaining values.
- for (int i = endCopy; i < numValues; i++) {
- if (i < LOG_FACTORIALS.length) {
- logFactorials[i] = LOG_FACTORIALS[i];
- } else {
- logFactorials[i] = logFactorials[i - 1] + Math.log(i);
- }
- }
- }
- /**
- * Creates an instance with no precomputed values.
- *
- * @return an instance with no precomputed values.
- */
- public static FactorialLog create() {
- return new FactorialLog(0, null);
- }
- /**
- * Creates an instance with the specified cache size.
- *
- * @param cacheSize Number of precomputed values of the function.
- * @return a new instance where {@code cacheSize} values have been
- * precomputed.
- * @throws IllegalArgumentException if {@code n < 0}.
- */
- public FactorialLog withCache(final int cacheSize) {
- return new FactorialLog(cacheSize, logFactorials);
- }
- /**
- * Computes {@code log(n!)}.
- *
- * @param n Argument.
- * @return {@code log(n!)}.
- * @throws IndexOutOfBoundsException if {@code numValues < 0}.
- */
- public double value(final int n) {
- // Use cache of precomputed values.
- if (n < logFactorials.length) {
- return logFactorials[n];
- }
- // Use cache of precomputed log factorial values.
- if (n < LOG_FACTORIALS.length) {
- return LOG_FACTORIALS[n];
- }
- // Delegate.
- return InternalGamma.logGamma(n + 1.0);
- }
- }
- }