InternalUtils.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.rng.sampling.distribution;

  18. import org.apache.commons.rng.UniformRandomProvider;
  19. import org.apache.commons.rng.sampling.SharedStateSampler;

  20. /**
  21.  * Functions used by some of the samplers.
  22.  * This class is not part of the public API, as it would be
  23.  * better to group these utilities in a dedicated component.
  24.  */
  25. final class InternalUtils {
  26.     /** All long-representable factorials, precomputed as the natural
  27.      * logarithm using Matlab R2023a VPA: log(vpa(x)).
  28.      *
  29.      * <p>Note: This table could be any length. Previously this stored
  30.      * the long value of n!, not log(n!). Using the previous length
  31.      * maintains behaviour. */
  32.     private static final double[] LOG_FACTORIALS = {
  33.         0,
  34.         0,
  35.         0.69314718055994530941723212145818,
  36.         1.7917594692280550008124773583807,
  37.         3.1780538303479456196469416012971,
  38.         4.7874917427820459942477009345232,
  39.         6.5792512120101009950601782929039,
  40.         8.5251613610654143001655310363471,
  41.         10.604602902745250228417227400722,
  42.         12.801827480081469611207717874567,
  43.         15.104412573075515295225709329251,
  44.         17.502307845873885839287652907216,
  45.         19.987214495661886149517362387055,
  46.         22.55216385312342288557084982862,
  47.         25.191221182738681500093434693522,
  48.         27.89927138384089156608943926367,
  49.         30.671860106080672803758367749503,
  50.         33.505073450136888884007902367376,
  51.         36.39544520803305357621562496268,
  52.         39.339884187199494036224652394567,
  53.         42.33561646075348502965987597071
  54.     };

  55.     /** The first array index with a non-zero log factorial. */
  56.     private static final int BEGIN_LOG_FACTORIALS = 2;

  57.     /**
  58.      * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
  59.      * Taken from org.apache.commons.rng.core.util.NumberFactory.
  60.      */
  61.     private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;

  62.     /** Utility class. */
  63.     private InternalUtils() {}

  64.     /**
  65.      * @param n Argument.
  66.      * @return {@code n!}
  67.      * @throws IndexOutOfBoundsException if the result is too large to be represented
  68.      * by a {@code long} (i.e. if {@code n > 20}), or {@code n} is negative.
  69.      */
  70.     static double logFactorial(int n)  {
  71.         return LOG_FACTORIALS[n];
  72.     }

  73.     /**
  74.      * Validate the probabilities sum to a finite positive number.
  75.      *
  76.      * @param probabilities the probabilities
  77.      * @return the sum
  78.      * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
  79.      * probability is negative, infinite or {@code NaN}, or the sum of all
  80.      * probabilities is not strictly positive.
  81.      */
  82.     static double validateProbabilities(double[] probabilities) {
  83.         if (probabilities == null || probabilities.length == 0) {
  84.             throw new IllegalArgumentException("Probabilities must not be empty.");
  85.         }

  86.         double sumProb = 0;
  87.         for (final double prob : probabilities) {
  88.             sumProb += requirePositiveFinite(prob, "probability");
  89.         }

  90.         return requireStrictlyPositiveFinite(sumProb, "sum of probabilities");
  91.     }

  92.     /**
  93.      * Checks the value {@code x} is finite.
  94.      *
  95.      * @param x Value.
  96.      * @param name Name of the value.
  97.      * @return x
  98.      * @throws IllegalArgumentException if {@code x} is non-finite
  99.      */
  100.     static double requireFinite(double x, String name) {
  101.         if (!Double.isFinite(x)) {
  102.             throw new IllegalArgumentException(name + " is not finite: " + x);
  103.         }
  104.         return x;
  105.     }

  106.     /**
  107.      * Checks the value {@code x >= 0} and is finite.
  108.      * Note: This method allows {@code x == -0.0}.
  109.      *
  110.      * @param x Value.
  111.      * @param name Name of the value.
  112.      * @return x
  113.      * @throws IllegalArgumentException if {@code x < 0} or is non-finite
  114.      */
  115.     static double requirePositiveFinite(double x, String name) {
  116.         if (!(x >= 0 && x < Double.POSITIVE_INFINITY)) {
  117.             throw new IllegalArgumentException(
  118.                 name + " is not positive and finite: " + x);
  119.         }
  120.         return x;
  121.     }

  122.     /**
  123.      * Checks the value {@code x > 0} and is finite.
  124.      *
  125.      * @param x Value.
  126.      * @param name Name of the value.
  127.      * @return x
  128.      * @throws IllegalArgumentException if {@code x <= 0} or is non-finite
  129.      */
  130.     static double requireStrictlyPositiveFinite(double x, String name) {
  131.         if (!(x > 0 && x < Double.POSITIVE_INFINITY)) {
  132.             throw new IllegalArgumentException(
  133.                 name + " is not strictly positive and finite: " + x);
  134.         }
  135.         return x;
  136.     }

  137.     /**
  138.      * Checks the value {@code x >= 0}.
  139.      * Note: This method allows {@code x == -0.0}.
  140.      *
  141.      * @param x Value.
  142.      * @param name Name of the value.
  143.      * @return x
  144.      * @throws IllegalArgumentException if {@code x < 0}
  145.      */
  146.     static double requirePositive(double x, String name) {
  147.         // Logic inversion detects NaN
  148.         if (!(x >= 0)) {
  149.             throw new IllegalArgumentException(name + " is not positive: " + x);
  150.         }
  151.         return x;
  152.     }

  153.     /**
  154.      * Checks the value {@code x > 0}.
  155.      *
  156.      * @param x Value.
  157.      * @param name Name of the value.
  158.      * @return x
  159.      * @throws IllegalArgumentException if {@code x <= 0}
  160.      */
  161.     static double requireStrictlyPositive(double x, String name) {
  162.         // Logic inversion detects NaN
  163.         if (!(x > 0)) {
  164.             throw new IllegalArgumentException(name + " is not strictly positive: " + x);
  165.         }
  166.         return x;
  167.     }

  168.     /**
  169.      * Checks the value is within the range: {@code min <= x < max}.
  170.      *
  171.      * @param min Minimum (inclusive).
  172.      * @param max Maximum (exclusive).
  173.      * @param x Value.
  174.      * @param name Name of the value.
  175.      * @return x
  176.      * @throws IllegalArgumentException if {@code x < min || x >= max}.
  177.      */
  178.     static double requireRange(double min, double max, double x, String name) {
  179.         if (!(min <= x && x < max)) {
  180.             throw new IllegalArgumentException(
  181.                 String.format("%s not within range: %s <= %s < %s", name, min, x, max));
  182.         }
  183.         return x;
  184.     }

  185.     /**
  186.      * Checks the value is within the closed range: {@code min <= x <= max}.
  187.      *
  188.      * @param min Minimum (inclusive).
  189.      * @param max Maximum (inclusive).
  190.      * @param x Value.
  191.      * @param name Name of the value.
  192.      * @return x
  193.      * @throws IllegalArgumentException if {@code x < min || x > max}.
  194.      */
  195.     static double requireRangeClosed(double min, double max, double x, String name) {
  196.         if (!(min <= x && x <= max)) {
  197.             throw new IllegalArgumentException(
  198.                 String.format("%s not within closed range: %s <= %s <= %s", name, min, x, max));
  199.         }
  200.         return x;
  201.     }

  202.     /**
  203.      * Create a new instance of the given sampler using
  204.      * {@link SharedStateSampler#withUniformRandomProvider(UniformRandomProvider)}.
  205.      *
  206.      * @param sampler Source sampler.
  207.      * @param rng Generator of uniformly distributed random numbers.
  208.      * @return the new sampler
  209.      * @throws UnsupportedOperationException if the underlying sampler is not a
  210.      * {@link SharedStateSampler} or does not return a {@link NormalizedGaussianSampler} when
  211.      * sharing state.
  212.      */
  213.     static NormalizedGaussianSampler newNormalizedGaussianSampler(
  214.             NormalizedGaussianSampler sampler,
  215.             UniformRandomProvider rng) {
  216.         if (!(sampler instanceof SharedStateSampler<?>)) {
  217.             throw new UnsupportedOperationException("The underlying sampler cannot share state");
  218.         }
  219.         final Object newSampler = ((SharedStateSampler<?>) sampler).withUniformRandomProvider(rng);
  220.         if (!(newSampler instanceof NormalizedGaussianSampler)) {
  221.             throw new UnsupportedOperationException(
  222.                 "The underlying sampler did not create a normalized Gaussian sampler");
  223.         }
  224.         return (NormalizedGaussianSampler) newSampler;
  225.     }

  226.     /**
  227.      * Creates a {@code double} in the interval {@code [0, 1)} from a {@code long} value.
  228.      *
  229.      * @param v Number.
  230.      * @return a {@code double} value in the interval {@code [0, 1)}.
  231.      */
  232.     static double makeDouble(long v) {
  233.         // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
  234.         // without adding an explicit dependency on that module.
  235.         return (v >>> 11) * DOUBLE_MULTIPLIER;
  236.     }

  237.     /**
  238.      * Creates a {@code double} in the interval {@code (0, 1]} from a {@code long} value.
  239.      *
  240.      * @param v Number.
  241.      * @return a {@code double} value in the interval {@code (0, 1]}.
  242.      */
  243.     static double makeNonZeroDouble(long v) {
  244.         // This matches the method in o.a.c.rng.core.util.NumberFactory.makeDouble(long)
  245.         // but shifts the range from [0, 1) to (0, 1].
  246.         return ((v >>> 11) + 1L) * DOUBLE_MULTIPLIER;
  247.     }

  248.     /**
  249.      * Class for computing the natural logarithm of the factorial of {@code n}.
  250.      * It allows to allocate a cache of precomputed values.
  251.      * In case of cache miss, computation is performed by a call to
  252.      * {@link InternalGamma#logGamma(double)}.
  253.      */
  254.     public static final class FactorialLog {
  255.         /**
  256.          * Precomputed values of the function:
  257.          * {@code LOG_FACTORIALS[i] = log(i!)}.
  258.          */
  259.         private final double[] logFactorials;

  260.         /**
  261.          * Creates an instance, reusing the already computed values if available.
  262.          *
  263.          * @param numValues Number of values of the function to compute.
  264.          * @param cache Existing cache.
  265.          * @throws NegativeArraySizeException if {@code numValues < 0}.
  266.          */
  267.         private FactorialLog(int numValues,
  268.                              double[] cache) {
  269.             logFactorials = new double[numValues];

  270.             final int endCopy;
  271.             if (cache != null && cache.length > BEGIN_LOG_FACTORIALS) {
  272.                 // Copy available values.
  273.                 endCopy = Math.min(cache.length, numValues);
  274.                 System.arraycopy(cache, BEGIN_LOG_FACTORIALS, logFactorials, BEGIN_LOG_FACTORIALS,
  275.                     endCopy - BEGIN_LOG_FACTORIALS);
  276.             } else {
  277.                 // All values to be computed
  278.                 endCopy = BEGIN_LOG_FACTORIALS;
  279.             }

  280.             // Compute remaining values.
  281.             for (int i = endCopy; i < numValues; i++) {
  282.                 if (i < LOG_FACTORIALS.length) {
  283.                     logFactorials[i] = LOG_FACTORIALS[i];
  284.                 } else {
  285.                     logFactorials[i] = logFactorials[i - 1] + Math.log(i);
  286.                 }
  287.             }
  288.         }

  289.         /**
  290.          * Creates an instance with no precomputed values.
  291.          *
  292.          * @return an instance with no precomputed values.
  293.          */
  294.         public static FactorialLog create() {
  295.             return new FactorialLog(0, null);
  296.         }

  297.         /**
  298.          * Creates an instance with the specified cache size.
  299.          *
  300.          * @param cacheSize Number of precomputed values of the function.
  301.          * @return a new instance where {@code cacheSize} values have been
  302.          * precomputed.
  303.          * @throws IllegalArgumentException if {@code n < 0}.
  304.          */
  305.         public FactorialLog withCache(final int cacheSize) {
  306.             return new FactorialLog(cacheSize, logFactorials);
  307.         }

  308.         /**
  309.          * Computes {@code log(n!)}.
  310.          *
  311.          * @param n Argument.
  312.          * @return {@code log(n!)}.
  313.          * @throws IndexOutOfBoundsException if {@code numValues < 0}.
  314.          */
  315.         public double value(final int n) {
  316.             // Use cache of precomputed values.
  317.             if (n < logFactorials.length) {
  318.                 return logFactorials[n];
  319.             }

  320.             // Use cache of precomputed log factorial values.
  321.             if (n < LOG_FACTORIALS.length) {
  322.                 return LOG_FACTORIALS[n];
  323.             }

  324.             // Delegate.
  325.             return InternalGamma.logGamma(n + 1.0);
  326.         }
  327.     }
  328. }