View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  import org.apache.commons.rng.sampling.SharedStateObjectSampler;
21  
22  /**
23   * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
24   * distribution</a>.
25   *
26   * <p>Sampling uses:</p>
27   *
28   * <ul>
29   *   <li>{@link UniformRandomProvider#nextLong()}
30   *   <li>{@link UniformRandomProvider#nextDouble()}
31   * </ul>
32   *
33   * @since 1.4
34   */
35  public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
36      /** The minimum number of categories. */
37      private static final int MIN_CATGEORIES = 2;
38  
39      /** RNG (used for the toString() method). */
40      private final UniformRandomProvider rng;
41  
42      /**
43       * Sample from a Dirichlet distribution with different concentration parameters
44       * for each category.
45       */
46      private static final class GeneralDirichletSampler extends DirichletSampler {
47          /** Samplers for each category. */
48          private final SharedStateContinuousSampler[] samplers;
49  
50          /**
51           * @param rng Generator of uniformly distributed random numbers.
52           * @param samplers Samplers for each category.
53           */
54          GeneralDirichletSampler(UniformRandomProvider rng,
55                                  SharedStateContinuousSampler[] samplers) {
56              super(rng);
57              // Array is stored directly as it is generated within the DirichletSampler class
58              this.samplers = samplers;
59          }
60  
61          @Override
62          protected int getK() {
63              return samplers.length;
64          }
65  
66          @Override
67          protected double nextGamma(int i) {
68              return samplers[i].sample();
69          }
70  
71          @Override
72          public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
73              final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length];
74              for (int i = 0; i < newSamplers.length; i++) {
75                  newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
76              }
77              return new GeneralDirichletSampler(rng, newSamplers);
78          }
79      }
80  
81      /**
82       * Sample from a symmetric Dirichlet distribution with the same concentration parameter
83       * for each category.
84       */
85      private static final class SymmetricDirichletSampler extends DirichletSampler {
86          /** Number of categories. */
87          private final int k;
88          /** Sampler for the categories. */
89          private final SharedStateContinuousSampler sampler;
90  
91          /**
92           * @param rng Generator of uniformly distributed random numbers.
93           * @param k Number of categories.
94           * @param sampler Sampler for the categories.
95           */
96          SymmetricDirichletSampler(UniformRandomProvider rng,
97                                    int k,
98                                    SharedStateContinuousSampler sampler) {
99              super(rng);
100             this.k = k;
101             this.sampler = sampler;
102         }
103 
104         @Override
105         protected int getK() {
106             return k;
107         }
108 
109         @Override
110         protected double nextGamma(int i) {
111             return sampler.sample();
112         }
113 
114         @Override
115         public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
116             return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
117         }
118     }
119 
120     /**
121      * @param rng Generator of uniformly distributed random numbers.
122      */
123     private DirichletSampler(UniformRandomProvider rng) {
124         this.rng = rng;
125     }
126 
127     /** {@inheritDoc} */
128     @Override
129     public String toString() {
130         return "Dirichlet deviate [" + rng.toString() + "]";
131     }
132 
133     @Override
134     public double[] sample() {
135         // Create Gamma(alpha_i, 1) deviates for all alpha
136         final double[] y = new double[getK()];
137         double norm = 0;
138         for (int i = 0; i < y.length; i++) {
139             final double yi = nextGamma(i);
140             norm += yi;
141             y[i] = yi;
142         }
143         // Normalize by dividing by the sum of the samples
144         norm = 1.0 / norm;
145         // Detect an invalid normalization, e.g. cases of all zero samples
146         if (!isNonZeroPositiveFinite(norm)) {
147             // Sample again using recursion.
148             // A stack overflow due to a broken RNG will eventually occur
149             // rather than the alternative which is an infinite loop.
150             return sample();
151         }
152         // Normalise
153         for (int i = 0; i < y.length; i++) {
154             y[i] *= norm;
155         }
156         return y;
157     }
158 
159     /**
160      * Gets the number of categories.
161      *
162      * @return k
163      */
164     protected abstract int getK();
165 
166     /**
167      * Create a gamma sample for the given category.
168      *
169      * @param category Category.
170      * @return the sample
171      */
172     protected abstract double nextGamma(int category);
173 
174     /** {@inheritDoc} */
175     // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]>
176     @Override
177     public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);
178 
179     /**
180      * Creates a new Dirichlet distribution sampler.
181      *
182      * @param rng Generator of uniformly distributed random numbers.
183      * @param alpha Concentration parameters.
184      * @return the sampler
185      * @throws IllegalArgumentException if the number of concentration parameters
186      * is less than 2; or if any concentration parameter is not strictly positive.
187      */
188     public static DirichletSampler of(UniformRandomProvider rng,
189                                       double... alpha) {
190         validateNumberOfCategories(alpha.length);
191         final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
192         for (int i = 0; i < samplers.length; i++) {
193             samplers[i] = createSampler(rng, alpha[i]);
194         }
195         return new GeneralDirichletSampler(rng, samplers);
196     }
197 
198     /**
199      * Creates a new symmetric Dirichlet distribution sampler using the same concentration
200      * parameter for each category.
201      *
202      * @param rng Generator of uniformly distributed random numbers.
203      * @param k Number of categories.
204      * @param alpha Concentration parameter.
205      * @return the sampler
206      * @throws IllegalArgumentException if the number of categories is
207      * less than 2; or if the concentration parameter is not strictly positive.
208      */
209     public static DirichletSampler symmetric(UniformRandomProvider rng,
210                                              int k,
211                                              double alpha) {
212         validateNumberOfCategories(k);
213         final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
214         return new SymmetricDirichletSampler(rng, k, sampler);
215     }
216 
217     /**
218      * Validate the number of categories.
219      *
220      * @param k Number of categories.
221      * @throws IllegalArgumentException if the number of categories is
222      * less than 2.
223      */
224     private static void validateNumberOfCategories(int k) {
225         if (k < MIN_CATGEORIES) {
226             throw new IllegalArgumentException("Invalid number of categories: " + k);
227         }
228     }
229 
230     /**
231      * Creates a gamma sampler for a category with the given concentration parameter.
232      *
233      * @param rng Generator of uniformly distributed random numbers.
234      * @param alpha Concentration parameter.
235      * @return the sampler
236      * @throws IllegalArgumentException if the concentration parameter is not strictly positive.
237      */
238     private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
239                                                               double alpha) {
240         // Negation of logic will detect NaN
241         if (!isNonZeroPositiveFinite(alpha)) {
242             throw new IllegalArgumentException("Invalid concentration: " + alpha);
243         }
244         // Create a Gamma(shape=alpha, scale=1) sampler.
245         if (alpha == 1) {
246             // Special case
247             // Gamma(shape=1, scale=1) == Exponential(mean=1)
248             return ZigguratSampler.Exponential.of(rng);
249         }
250         return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
251     }
252 
253     /**
254      * Return true if the value is non-zero, positive and finite.
255      *
256      * @param x Value.
257      * @return true if non-zero positive finite
258      */
259     private static boolean isNonZeroPositiveFinite(double x) {
260         return x > 0 && x < Double.POSITIVE_INFINITY;
261     }
262 }