DirichletSampler.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.rng.sampling.distribution;
- import org.apache.commons.rng.UniformRandomProvider;
- import org.apache.commons.rng.sampling.SharedStateObjectSampler;
- /**
- * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
- * distribution</a>.
- *
- * <p>Sampling uses:</p>
- *
- * <ul>
- * <li>{@link UniformRandomProvider#nextLong()}
- * <li>{@link UniformRandomProvider#nextDouble()}
- * </ul>
- *
- * @since 1.4
- */
- public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
- /** The minimum number of categories. */
- private static final int MIN_CATGEORIES = 2;
- /** RNG (used for the toString() method). */
- private final UniformRandomProvider rng;
- /**
- * Sample from a Dirichlet distribution with different concentration parameters
- * for each category.
- */
- private static final class GeneralDirichletSampler extends DirichletSampler {
- /** Samplers for each category. */
- private final SharedStateContinuousSampler[] samplers;
- /**
- * @param rng Generator of uniformly distributed random numbers.
- * @param samplers Samplers for each category.
- */
- GeneralDirichletSampler(UniformRandomProvider rng,
- SharedStateContinuousSampler[] samplers) {
- super(rng);
- // Array is stored directly as it is generated within the DirichletSampler class
- this.samplers = samplers;
- }
- @Override
- protected int getK() {
- return samplers.length;
- }
- @Override
- protected double nextGamma(int i) {
- return samplers[i].sample();
- }
- @Override
- public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
- final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length];
- for (int i = 0; i < newSamplers.length; i++) {
- newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
- }
- return new GeneralDirichletSampler(rng, newSamplers);
- }
- }
- /**
- * Sample from a symmetric Dirichlet distribution with the same concentration parameter
- * for each category.
- */
- private static final class SymmetricDirichletSampler extends DirichletSampler {
- /** Number of categories. */
- private final int k;
- /** Sampler for the categories. */
- private final SharedStateContinuousSampler sampler;
- /**
- * @param rng Generator of uniformly distributed random numbers.
- * @param k Number of categories.
- * @param sampler Sampler for the categories.
- */
- SymmetricDirichletSampler(UniformRandomProvider rng,
- int k,
- SharedStateContinuousSampler sampler) {
- super(rng);
- this.k = k;
- this.sampler = sampler;
- }
- @Override
- protected int getK() {
- return k;
- }
- @Override
- protected double nextGamma(int i) {
- return sampler.sample();
- }
- @Override
- public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
- return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
- }
- }
- /**
- * @param rng Generator of uniformly distributed random numbers.
- */
- DirichletSampler(UniformRandomProvider rng) {
- this.rng = rng;
- }
- /** {@inheritDoc} */
- @Override
- public String toString() {
- return "Dirichlet deviate [" + rng.toString() + "]";
- }
- /** {@inheritDoc} */
- @Override
- public double[] sample() {
- // Create Gamma(alpha_i, 1) deviates for all alpha
- final double[] y = new double[getK()];
- double norm = 0;
- for (int i = 0; i < y.length; i++) {
- final double yi = nextGamma(i);
- norm += yi;
- y[i] = yi;
- }
- // Normalize by dividing by the sum of the samples
- norm = 1.0 / norm;
- // Detect an invalid normalization, e.g. cases of all zero samples
- if (!isNonZeroPositiveFinite(norm)) {
- // Sample again using recursion.
- // A stack overflow due to a broken RNG will eventually occur
- // rather than the alternative which is an infinite loop.
- return sample();
- }
- // Normalise
- for (int i = 0; i < y.length; i++) {
- y[i] *= norm;
- }
- return y;
- }
- /**
- * Gets the number of categories.
- *
- * @return k
- */
- protected abstract int getK();
- /**
- * Create a gamma sample for the given category.
- *
- * @param category Category.
- * @return the sample
- */
- protected abstract double nextGamma(int category);
- /** {@inheritDoc} */
- // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]>
- @Override
- public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);
- /**
- * Creates a new Dirichlet distribution sampler.
- *
- * @param rng Generator of uniformly distributed random numbers.
- * @param alpha Concentration parameters.
- * @return the sampler
- * @throws IllegalArgumentException if the number of concentration parameters
- * is less than 2; or if any concentration parameter is not strictly positive.
- */
- public static DirichletSampler of(UniformRandomProvider rng,
- double... alpha) {
- validateNumberOfCategories(alpha.length);
- final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
- for (int i = 0; i < samplers.length; i++) {
- samplers[i] = createSampler(rng, alpha[i]);
- }
- return new GeneralDirichletSampler(rng, samplers);
- }
- /**
- * Creates a new symmetric Dirichlet distribution sampler using the same concentration
- * parameter for each category.
- *
- * @param rng Generator of uniformly distributed random numbers.
- * @param k Number of categories.
- * @param alpha Concentration parameter.
- * @return the sampler
- * @throws IllegalArgumentException if the number of categories is
- * less than 2; or if the concentration parameter is not strictly positive.
- */
- public static DirichletSampler symmetric(UniformRandomProvider rng,
- int k,
- double alpha) {
- validateNumberOfCategories(k);
- final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
- return new SymmetricDirichletSampler(rng, k, sampler);
- }
- /**
- * Validate the number of categories.
- *
- * @param k Number of categories.
- * @throws IllegalArgumentException if the number of categories is
- * less than 2.
- */
- private static void validateNumberOfCategories(int k) {
- if (k < MIN_CATGEORIES) {
- throw new IllegalArgumentException("Invalid number of categories: " + k);
- }
- }
- /**
- * Creates a gamma sampler for a category with the given concentration parameter.
- *
- * @param rng Generator of uniformly distributed random numbers.
- * @param alpha Concentration parameter.
- * @return the sampler
- * @throws IllegalArgumentException if the concentration parameter is not strictly positive.
- */
- private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
- double alpha) {
- InternalUtils.requireStrictlyPositiveFinite(alpha, "alpha concentration");
- // Create a Gamma(shape=alpha, scale=1) sampler.
- if (alpha == 1) {
- // Special case
- // Gamma(shape=1, scale=1) == Exponential(mean=1)
- return ZigguratSampler.Exponential.of(rng);
- }
- return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
- }
- /**
- * Return true if the value is non-zero, positive and finite.
- *
- * @param x Value.
- * @return true if non-zero positive finite
- */
- private static boolean isNonZeroPositiveFinite(double x) {
- return x > 0 && x < Double.POSITIVE_INFINITY;
- }
- }