/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
017package org.apache.commons.rng.sampling.distribution;
018
019import org.apache.commons.rng.UniformRandomProvider;
020import org.apache.commons.rng.sampling.SharedStateObjectSampler;
021
022/**
023 * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
024 * distribution</a>.
025 *
026 * <p>Sampling uses:</p>
027 *
028 * <ul>
029 *   <li>{@link UniformRandomProvider#nextLong()}
030 *   <li>{@link UniformRandomProvider#nextDouble()}
031 * </ul>
032 *
033 * @since 1.4
034 */
035public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
036    /** The minimum number of categories. */
037    private static final int MIN_CATGEORIES = 2;
038
039    /** RNG (used for the toString() method). */
040    private final UniformRandomProvider rng;
041
042    /**
043     * Sample from a Dirichlet distribution with different concentration parameters
044     * for each category.
045     */
046    private static final class GeneralDirichletSampler extends DirichletSampler {
047        /** Samplers for each category. */
048        private final SharedStateContinuousSampler[] samplers;
049
050        /**
051         * @param rng Generator of uniformly distributed random numbers.
052         * @param samplers Samplers for each category.
053         */
054        GeneralDirichletSampler(UniformRandomProvider rng,
055                                SharedStateContinuousSampler[] samplers) {
056            super(rng);
057            // Array is stored directly as it is generated within the DirichletSampler class
058            this.samplers = samplers;
059        }
060
061        @Override
062        protected int getK() {
063            return samplers.length;
064        }
065
066        @Override
067        protected double nextGamma(int i) {
068            return samplers[i].sample();
069        }
070
071        @Override
072        public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
073            final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length];
074            for (int i = 0; i < newSamplers.length; i++) {
075                newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
076            }
077            return new GeneralDirichletSampler(rng, newSamplers);
078        }
079    }
080
081    /**
082     * Sample from a symmetric Dirichlet distribution with the same concentration parameter
083     * for each category.
084     */
085    private static final class SymmetricDirichletSampler extends DirichletSampler {
086        /** Number of categories. */
087        private final int k;
088        /** Sampler for the categories. */
089        private final SharedStateContinuousSampler sampler;
090
091        /**
092         * @param rng Generator of uniformly distributed random numbers.
093         * @param k Number of categories.
094         * @param sampler Sampler for the categories.
095         */
096        SymmetricDirichletSampler(UniformRandomProvider rng,
097                                  int k,
098                                  SharedStateContinuousSampler sampler) {
099            super(rng);
100            this.k = k;
101            this.sampler = sampler;
102        }
103
104        @Override
105        protected int getK() {
106            return k;
107        }
108
109        @Override
110        protected double nextGamma(int i) {
111            return sampler.sample();
112        }
113
114        @Override
115        public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
116            return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
117        }
118    }
119
120    /**
121     * @param rng Generator of uniformly distributed random numbers.
122     */
123    private DirichletSampler(UniformRandomProvider rng) {
124        this.rng = rng;
125    }
126
127    /** {@inheritDoc} */
128    @Override
129    public String toString() {
130        return "Dirichlet deviate [" + rng.toString() + "]";
131    }
132
133    @Override
134    public double[] sample() {
135        // Create Gamma(alpha_i, 1) deviates for all alpha
136        final double[] y = new double[getK()];
137        double norm = 0;
138        for (int i = 0; i < y.length; i++) {
139            final double yi = nextGamma(i);
140            norm += yi;
141            y[i] = yi;
142        }
143        // Normalize by dividing by the sum of the samples
144        norm = 1.0 / norm;
145        // Detect an invalid normalization, e.g. cases of all zero samples
146        if (!isNonZeroPositiveFinite(norm)) {
147            // Sample again using recursion.
148            // A stack overflow due to a broken RNG will eventually occur
149            // rather than the alternative which is an infinite loop.
150            return sample();
151        }
152        // Normalise
153        for (int i = 0; i < y.length; i++) {
154            y[i] *= norm;
155        }
156        return y;
157    }
158
159    /**
160     * Gets the number of categories.
161     *
162     * @return k
163     */
164    protected abstract int getK();
165
166    /**
167     * Create a gamma sample for the given category.
168     *
169     * @param category Category.
170     * @return the sample
171     */
172    protected abstract double nextGamma(int category);
173
174    /** {@inheritDoc} */
175    // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]>
176    @Override
177    public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);
178
179    /**
180     * Creates a new Dirichlet distribution sampler.
181     *
182     * @param rng Generator of uniformly distributed random numbers.
183     * @param alpha Concentration parameters.
184     * @return the sampler
185     * @throws IllegalArgumentException if the number of concentration parameters
186     * is less than 2; or if any concentration parameter is not strictly positive.
187     */
188    public static DirichletSampler of(UniformRandomProvider rng,
189                                      double... alpha) {
190        validateNumberOfCategories(alpha.length);
191        final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
192        for (int i = 0; i < samplers.length; i++) {
193            samplers[i] = createSampler(rng, alpha[i]);
194        }
195        return new GeneralDirichletSampler(rng, samplers);
196    }
197
198    /**
199     * Creates a new symmetric Dirichlet distribution sampler using the same concentration
200     * parameter for each category.
201     *
202     * @param rng Generator of uniformly distributed random numbers.
203     * @param k Number of categories.
204     * @param alpha Concentration parameter.
205     * @return the sampler
206     * @throws IllegalArgumentException if the number of categories is
207     * less than 2; or if the concentration parameter is not strictly positive.
208     */
209    public static DirichletSampler symmetric(UniformRandomProvider rng,
210                                             int k,
211                                             double alpha) {
212        validateNumberOfCategories(k);
213        final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
214        return new SymmetricDirichletSampler(rng, k, sampler);
215    }
216
217    /**
218     * Validate the number of categories.
219     *
220     * @param k Number of categories.
221     * @throws IllegalArgumentException if the number of categories is
222     * less than 2.
223     */
224    private static void validateNumberOfCategories(int k) {
225        if (k < MIN_CATGEORIES) {
226            throw new IllegalArgumentException("Invalid number of categories: " + k);
227        }
228    }
229
230    /**
231     * Creates a gamma sampler for a category with the given concentration parameter.
232     *
233     * @param rng Generator of uniformly distributed random numbers.
234     * @param alpha Concentration parameter.
235     * @return the sampler
236     * @throws IllegalArgumentException if the concentration parameter is not strictly positive.
237     */
238    private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
239                                                              double alpha) {
240        // Negation of logic will detect NaN
241        if (!isNonZeroPositiveFinite(alpha)) {
242            throw new IllegalArgumentException("Invalid concentration: " + alpha);
243        }
244        // Create a Gamma(shape=alpha, scale=1) sampler.
245        if (alpha == 1) {
246            // Special case
247            // Gamma(shape=1, scale=1) == Exponential(mean=1)
248            return ZigguratSampler.Exponential.of(rng);
249        }
250        return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
251    }
252
253    /**
254     * Return true if the value is non-zero, positive and finite.
255     *
256     * @param x Value.
257     * @return true if non-zero positive finite
258     */
259    private static boolean isNonZeroPositiveFinite(double x) {
260        return x > 0 && x < Double.POSITIVE_INFINITY;
261    }
262}