DirichletSampler.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.rng.sampling.distribution;

import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.SharedStateObjectSampler;

/**
 * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
 * distribution</a>.
 *
 * <p>Each sample is an array holding one value per category. It is produced by drawing
 * a gamma deviate for every category and normalising the deviates by their sum.</p>
 *
 * <p>Sampling uses:</p>
 *
 * <ul>
 *   <li>{@link UniformRandomProvider#nextLong()}
 *   <li>{@link UniformRandomProvider#nextDouble()}
 * </ul>
 *
 * @since 1.4
 */
public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
    /** The smallest supported number of categories. */
    private static final int MIN_CATEGORIES = 2;

    /** Source of randomness; this class reads it only to build the toString() text. */
    private final UniformRandomProvider rng;

    /**
     * Dirichlet sampler in which every category has its own concentration parameter
     * and therefore its own underlying sampler.
     */
    private static final class GeneralDirichletSampler extends DirichletSampler {
        /** One sampler per category, indexed by category. */
        private final SharedStateContinuousSampler[] samplers;

        /**
         * @param rng Generator of uniformly distributed random numbers.
         * @param samplers Samplers for each category.
         */
        GeneralDirichletSampler(UniformRandomProvider rng,
                                SharedStateContinuousSampler[] samplers) {
            super(rng);
            // No defensive copy: the array is only ever created inside DirichletSampler
            this.samplers = samplers;
        }

        /** {@inheritDoc} */
        @Override
        protected int getK() {
            return samplers.length;
        }

        /** {@inheritDoc} */
        @Override
        protected double nextGamma(int category) {
            return samplers[category].sample();
        }

        /** {@inheritDoc} */
        @Override
        public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
            final int count = samplers.length;
            final SharedStateContinuousSampler[] copies = new SharedStateContinuousSampler[count];
            // Rebuild each category sampler, in category order, on top of the new generator
            int index = 0;
            for (final SharedStateContinuousSampler source : samplers) {
                copies[index++] = source.withUniformRandomProvider(rng);
            }
            return new GeneralDirichletSampler(rng, copies);
        }
    }

    /**
     * Dirichlet sampler for the symmetric case: a single concentration parameter, and
     * hence a single underlying sampler, is shared by all categories.
     */
    private static final class SymmetricDirichletSampler extends DirichletSampler {
        /** Number of categories. */
        private final int k;
        /** The sampler used for every category. */
        private final SharedStateContinuousSampler sampler;

        /**
         * @param rng Generator of uniformly distributed random numbers.
         * @param k Number of categories.
         * @param sampler Sampler for the categories.
         */
        SymmetricDirichletSampler(UniformRandomProvider rng,
                                  int k,
                                  SharedStateContinuousSampler sampler) {
            super(rng);
            this.sampler = sampler;
            this.k = k;
        }

        /** {@inheritDoc} */
        @Override
        protected int getK() {
            return k;
        }

        /** {@inheritDoc} */
        @Override
        protected double nextGamma(int category) {
            // The category is irrelevant: all categories draw from the same sampler
            return sampler.sample();
        }

        /** {@inheritDoc} */
        @Override
        public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
            final SharedStateContinuousSampler copy = sampler.withUniformRandomProvider(rng);
            return new SymmetricDirichletSampler(rng, k, copy);
        }
    }

    /**
     * @param rng Generator of uniformly distributed random numbers.
     */
    private DirichletSampler(UniformRandomProvider rng) {
        this.rng = rng;
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        final String source = rng.toString();
        return "Dirichlet deviate [" + source + "]";
    }

    /**
     * Creates a sample from the distribution.
     *
     * <p>One deviate is obtained from {@link #nextGamma(int)} for each of the
     * {@link #getK()} categories, in category order, and each deviate is then multiplied
     * by the reciprocal of their sum. When that reciprocal is not a strictly positive
     * finite number the deviates are discarded and a new sample is generated.</p>
     *
     * @return the sample; a new array with one value per category
     */
    @Override
    public double[] sample() {
        final int k = getK();
        final double[] values = new double[k];
        // Draw a Gamma(alpha_i, 1) deviate for every category, accumulating the total
        double total = 0;
        for (int category = 0; category < k; category++) {
            final double deviate = nextGamma(category);
            total += deviate;
            values[category] = deviate;
        }
        // Normalisation divides by the total; done as a multiplication by the reciprocal
        final double scale = 1.0 / total;
        if (isNonZeroPositiveFinite(scale)) {
            for (int category = 0; category < k; category++) {
                values[category] *= scale;
            }
            return values;
        }
        // The reciprocal is unusable, e.g. every deviate was zero: discard and retry.
        // Recursion is deliberate: with a broken RNG this ends in a stack overflow
        // instead of looping forever.
        return sample();
    }

    /**
     * Gets the number of categories.
     *
     * @return k
     */
    protected abstract int getK();

    /**
     * Create a gamma sample for the given category.
     *
     * @param category Category.
     * @return the sample
     */
    protected abstract double nextGamma(int category);

    /** {@inheritDoc} */
    // Redeclared so the return type is DirichletSampler rather than SharedStateObjectSampler<double[]>
    @Override
    public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);

    /**
     * Creates a new Dirichlet distribution sampler.
     *
     * <p>The number of categories is the number of concentration parameters supplied.</p>
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param alpha Concentration parameters.
     * @return the sampler
     * @throws IllegalArgumentException if the number of concentration parameters
     * is less than 2; or if any concentration parameter is not strictly positive.
     */
    public static DirichletSampler of(UniformRandomProvider rng,
                                      double... alpha) {
        final int k = alpha.length;
        validateNumberOfCategories(k);
        final SharedStateContinuousSampler[] gammas = new SharedStateContinuousSampler[k];
        for (int category = 0; category < k; category++) {
            gammas[category] = createSampler(rng, alpha[category]);
        }
        return new GeneralDirichletSampler(rng, gammas);
    }

    /**
     * Creates a new symmetric Dirichlet distribution sampler using the same concentration
     * parameter for each category.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param k Number of categories.
     * @param alpha Concentration parameter.
     * @return the sampler
     * @throws IllegalArgumentException if the number of categories is
     * less than 2; or if the concentration parameter is not strictly positive.
     */
    public static DirichletSampler symmetric(UniformRandomProvider rng,
                                             int k,
                                             double alpha) {
        // The category count is checked before the concentration parameter
        validateNumberOfCategories(k);
        return new SymmetricDirichletSampler(rng, k, createSampler(rng, alpha));
    }

    /**
     * Validate the number of categories.
     *
     * @param k Number of categories.
     * @throws IllegalArgumentException if the number of categories is
     * less than 2.
     */
    private static void validateNumberOfCategories(int k) {
        if (k >= MIN_CATEGORIES) {
            return;
        }
        throw new IllegalArgumentException("Invalid number of categories: " + k);
    }

    /**
     * Creates a gamma sampler for a category with the given concentration parameter.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param alpha Concentration parameter.
     * @return the sampler
     * @throws IllegalArgumentException if the concentration parameter is not strictly positive.
     */
    private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
                                                              double alpha) {
        // Testing for "valid" and negating, rather than testing for "invalid", also rejects NaN
        final boolean valid = isNonZeroPositiveFinite(alpha);
        if (!valid) {
            throw new IllegalArgumentException("Invalid concentration: " + alpha);
        }
        // The sampler required is Gamma(shape=alpha, scale=1)
        final SharedStateContinuousSampler gamma;
        if (alpha != 1) {
            gamma = AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
        } else {
            // Special case: Gamma(shape=1, scale=1) == Exponential(mean=1)
            gamma = ZigguratSampler.Exponential.of(rng);
        }
        return gamma;
    }

    /**
     * Return true if the value is non-zero, positive and finite.
     * NaN, zero, negative values and positive infinity all return false.
     *
     * @param x Value.
     * @return true if non-zero positive finite
     */
    private static boolean isNonZeroPositiveFinite(double x) {
        // Any comparison with NaN is false, so NaN fails the first test
        return 0 < x && x <= Double.MAX_VALUE;
    }
}