DirichletSampler.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.rng.sampling.distribution;

  18. import org.apache.commons.rng.UniformRandomProvider;
  19. import org.apache.commons.rng.sampling.SharedStateObjectSampler;

  20. /**
  21.  * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
  22.  * distribution</a>.
  23.  *
  24.  * <p>Sampling uses:</p>
  25.  *
  26.  * <ul>
  27.  *   <li>{@link UniformRandomProvider#nextLong()}
  28.  *   <li>{@link UniformRandomProvider#nextDouble()}
  29.  * </ul>
  30.  *
  31.  * @since 1.4
  32.  */
  33. public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
  34.     /** The minimum number of categories. */
  35.     private static final int MIN_CATGEORIES = 2;

  36.     /** RNG (used for the toString() method). */
  37.     private final UniformRandomProvider rng;

  38.     /**
  39.      * Sample from a Dirichlet distribution with different concentration parameters
  40.      * for each category.
  41.      */
  42.     private static final class GeneralDirichletSampler extends DirichletSampler {
  43.         /** Samplers for each category. */
  44.         private final SharedStateContinuousSampler[] samplers;

  45.         /**
  46.          * @param rng Generator of uniformly distributed random numbers.
  47.          * @param samplers Samplers for each category.
  48.          */
  49.         GeneralDirichletSampler(UniformRandomProvider rng,
  50.                                 SharedStateContinuousSampler[] samplers) {
  51.             super(rng);
  52.             // Array is stored directly as it is generated within the DirichletSampler class
  53.             this.samplers = samplers;
  54.         }

  55.         @Override
  56.         protected int getK() {
  57.             return samplers.length;
  58.         }

  59.         @Override
  60.         protected double nextGamma(int i) {
  61.             return samplers[i].sample();
  62.         }

  63.         @Override
  64.         public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
  65.             final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length];
  66.             for (int i = 0; i < newSamplers.length; i++) {
  67.                 newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
  68.             }
  69.             return new GeneralDirichletSampler(rng, newSamplers);
  70.         }
  71.     }

  72.     /**
  73.      * Sample from a symmetric Dirichlet distribution with the same concentration parameter
  74.      * for each category.
  75.      */
  76.     private static final class SymmetricDirichletSampler extends DirichletSampler {
  77.         /** Number of categories. */
  78.         private final int k;
  79.         /** Sampler for the categories. */
  80.         private final SharedStateContinuousSampler sampler;

  81.         /**
  82.          * @param rng Generator of uniformly distributed random numbers.
  83.          * @param k Number of categories.
  84.          * @param sampler Sampler for the categories.
  85.          */
  86.         SymmetricDirichletSampler(UniformRandomProvider rng,
  87.                                   int k,
  88.                                   SharedStateContinuousSampler sampler) {
  89.             super(rng);
  90.             this.k = k;
  91.             this.sampler = sampler;
  92.         }

  93.         @Override
  94.         protected int getK() {
  95.             return k;
  96.         }

  97.         @Override
  98.         protected double nextGamma(int i) {
  99.             return sampler.sample();
  100.         }

  101.         @Override
  102.         public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
  103.             return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
  104.         }
  105.     }

  106.     /**
  107.      * @param rng Generator of uniformly distributed random numbers.
  108.      */
  109.     DirichletSampler(UniformRandomProvider rng) {
  110.         this.rng = rng;
  111.     }

  112.     /** {@inheritDoc} */
  113.     @Override
  114.     public String toString() {
  115.         return "Dirichlet deviate [" + rng.toString() + "]";
  116.     }

  117.     /** {@inheritDoc} */
  118.     @Override
  119.     public double[] sample() {
  120.         // Create Gamma(alpha_i, 1) deviates for all alpha
  121.         final double[] y = new double[getK()];
  122.         double norm = 0;
  123.         for (int i = 0; i < y.length; i++) {
  124.             final double yi = nextGamma(i);
  125.             norm += yi;
  126.             y[i] = yi;
  127.         }
  128.         // Normalize by dividing by the sum of the samples
  129.         norm = 1.0 / norm;
  130.         // Detect an invalid normalization, e.g. cases of all zero samples
  131.         if (!isNonZeroPositiveFinite(norm)) {
  132.             // Sample again using recursion.
  133.             // A stack overflow due to a broken RNG will eventually occur
  134.             // rather than the alternative which is an infinite loop.
  135.             return sample();
  136.         }
  137.         // Normalise
  138.         for (int i = 0; i < y.length; i++) {
  139.             y[i] *= norm;
  140.         }
  141.         return y;
  142.     }

  143.     /**
  144.      * Gets the number of categories.
  145.      *
  146.      * @return k
  147.      */
  148.     protected abstract int getK();

  149.     /**
  150.      * Create a gamma sample for the given category.
  151.      *
  152.      * @param category Category.
  153.      * @return the sample
  154.      */
  155.     protected abstract double nextGamma(int category);

  156.     /** {@inheritDoc} */
  157.     // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]>
  158.     @Override
  159.     public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);

  160.     /**
  161.      * Creates a new Dirichlet distribution sampler.
  162.      *
  163.      * @param rng Generator of uniformly distributed random numbers.
  164.      * @param alpha Concentration parameters.
  165.      * @return the sampler
  166.      * @throws IllegalArgumentException if the number of concentration parameters
  167.      * is less than 2; or if any concentration parameter is not strictly positive.
  168.      */
  169.     public static DirichletSampler of(UniformRandomProvider rng,
  170.                                       double... alpha) {
  171.         validateNumberOfCategories(alpha.length);
  172.         final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
  173.         for (int i = 0; i < samplers.length; i++) {
  174.             samplers[i] = createSampler(rng, alpha[i]);
  175.         }
  176.         return new GeneralDirichletSampler(rng, samplers);
  177.     }

  178.     /**
  179.      * Creates a new symmetric Dirichlet distribution sampler using the same concentration
  180.      * parameter for each category.
  181.      *
  182.      * @param rng Generator of uniformly distributed random numbers.
  183.      * @param k Number of categories.
  184.      * @param alpha Concentration parameter.
  185.      * @return the sampler
  186.      * @throws IllegalArgumentException if the number of categories is
  187.      * less than 2; or if the concentration parameter is not strictly positive.
  188.      */
  189.     public static DirichletSampler symmetric(UniformRandomProvider rng,
  190.                                              int k,
  191.                                              double alpha) {
  192.         validateNumberOfCategories(k);
  193.         final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
  194.         return new SymmetricDirichletSampler(rng, k, sampler);
  195.     }

  196.     /**
  197.      * Validate the number of categories.
  198.      *
  199.      * @param k Number of categories.
  200.      * @throws IllegalArgumentException if the number of categories is
  201.      * less than 2.
  202.      */
  203.     private static void validateNumberOfCategories(int k) {
  204.         if (k < MIN_CATGEORIES) {
  205.             throw new IllegalArgumentException("Invalid number of categories: " + k);
  206.         }
  207.     }

  208.     /**
  209.      * Creates a gamma sampler for a category with the given concentration parameter.
  210.      *
  211.      * @param rng Generator of uniformly distributed random numbers.
  212.      * @param alpha Concentration parameter.
  213.      * @return the sampler
  214.      * @throws IllegalArgumentException if the concentration parameter is not strictly positive.
  215.      */
  216.     private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
  217.                                                               double alpha) {
  218.         InternalUtils.requireStrictlyPositiveFinite(alpha, "alpha concentration");
  219.         // Create a Gamma(shape=alpha, scale=1) sampler.
  220.         if (alpha == 1) {
  221.             // Special case
  222.             // Gamma(shape=1, scale=1) == Exponential(mean=1)
  223.             return ZigguratSampler.Exponential.of(rng);
  224.         }
  225.         return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
  226.     }

  227.     /**
  228.      * Return true if the value is non-zero, positive and finite.
  229.      *
  230.      * @param x Value.
  231.      * @return true if non-zero positive finite
  232.      */
  233.     private static boolean isNonZeroPositiveFinite(double x) {
  234.         return x > 0 && x < Double.POSITIVE_INFINITY;
  235.     }
  236. }