View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  import org.apache.commons.rng.sampling.SharedStateObjectSampler;
21  
22  /**
23   * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
24   * distribution</a>.
25   *
26   * <p>Sampling uses:</p>
27   *
28   * <ul>
29   *   <li>{@link UniformRandomProvider#nextLong()}
30   *   <li>{@link UniformRandomProvider#nextDouble()}
31   * </ul>
32   *
33   * @since 1.4
34   */
35  public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
36      /** The minimum number of categories. */
37      private static final int MIN_CATGEORIES = 2;
38  
39      /** RNG (used for the toString() method). */
40      private final UniformRandomProvider rng;
41  
42      /**
43       * Sample from a Dirichlet distribution with different concentration parameters
44       * for each category.
45       */
46      private static final class GeneralDirichletSampler extends DirichletSampler {
47          /** Samplers for each category. */
48          private final SharedStateContinuousSampler[] samplers;
49  
50          /**
51           * @param rng Generator of uniformly distributed random numbers.
52           * @param samplers Samplers for each category.
53           */
54          GeneralDirichletSampler(UniformRandomProvider rng,
55                                  SharedStateContinuousSampler[] samplers) {
56              super(rng);
57              // Array is stored directly as it is generated within the DirichletSampler class
58              this.samplers = samplers;
59          }
60  
61          @Override
62          protected int getK() {
63              return samplers.length;
64          }
65  
66          @Override
67          protected double nextGamma(int i) {
68              return samplers[i].sample();
69          }
70  
71          @Override
72          public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
73              final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length];
74              for (int i = 0; i < newSamplers.length; i++) {
75                  newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
76              }
77              return new GeneralDirichletSampler(rng, newSamplers);
78          }
79      }
80  
81      /**
82       * Sample from a symmetric Dirichlet distribution with the same concentration parameter
83       * for each category.
84       */
85      private static final class SymmetricDirichletSampler extends DirichletSampler {
86          /** Number of categories. */
87          private final int k;
88          /** Sampler for the categories. */
89          private final SharedStateContinuousSampler sampler;
90  
91          /**
92           * @param rng Generator of uniformly distributed random numbers.
93           * @param k Number of categories.
94           * @param sampler Sampler for the categories.
95           */
96          SymmetricDirichletSampler(UniformRandomProvider rng,
97                                    int k,
98                                    SharedStateContinuousSampler sampler) {
99              super(rng);
100             this.k = k;
101             this.sampler = sampler;
102         }
103 
104         @Override
105         protected int getK() {
106             return k;
107         }
108 
109         @Override
110         protected double nextGamma(int i) {
111             return sampler.sample();
112         }
113 
114         @Override
115         public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
116             return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
117         }
118     }
119 
120     /**
121      * @param rng Generator of uniformly distributed random numbers.
122      */
123     DirichletSampler(UniformRandomProvider rng) {
124         this.rng = rng;
125     }
126 
127     /** {@inheritDoc} */
128     @Override
129     public String toString() {
130         return "Dirichlet deviate [" + rng.toString() + "]";
131     }
132 
133     /** {@inheritDoc} */
134     @Override
135     public double[] sample() {
136         // Create Gamma(alpha_i, 1) deviates for all alpha
137         final double[] y = new double[getK()];
138         double norm = 0;
139         for (int i = 0; i < y.length; i++) {
140             final double yi = nextGamma(i);
141             norm += yi;
142             y[i] = yi;
143         }
144         // Normalize by dividing by the sum of the samples
145         norm = 1.0 / norm;
146         // Detect an invalid normalization, e.g. cases of all zero samples
147         if (!isNonZeroPositiveFinite(norm)) {
148             // Sample again using recursion.
149             // A stack overflow due to a broken RNG will eventually occur
150             // rather than the alternative which is an infinite loop.
151             return sample();
152         }
153         // Normalise
154         for (int i = 0; i < y.length; i++) {
155             y[i] *= norm;
156         }
157         return y;
158     }
159 
160     /**
161      * Gets the number of categories.
162      *
163      * @return k
164      */
165     protected abstract int getK();
166 
167     /**
168      * Create a gamma sample for the given category.
169      *
170      * @param category Category.
171      * @return the sample
172      */
173     protected abstract double nextGamma(int category);
174 
175     /** {@inheritDoc} */
176     // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]>
177     @Override
178     public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);
179 
180     /**
181      * Creates a new Dirichlet distribution sampler.
182      *
183      * @param rng Generator of uniformly distributed random numbers.
184      * @param alpha Concentration parameters.
185      * @return the sampler
186      * @throws IllegalArgumentException if the number of concentration parameters
187      * is less than 2; or if any concentration parameter is not strictly positive.
188      */
189     public static DirichletSampler of(UniformRandomProvider rng,
190                                       double... alpha) {
191         validateNumberOfCategories(alpha.length);
192         final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
193         for (int i = 0; i < samplers.length; i++) {
194             samplers[i] = createSampler(rng, alpha[i]);
195         }
196         return new GeneralDirichletSampler(rng, samplers);
197     }
198 
199     /**
200      * Creates a new symmetric Dirichlet distribution sampler using the same concentration
201      * parameter for each category.
202      *
203      * @param rng Generator of uniformly distributed random numbers.
204      * @param k Number of categories.
205      * @param alpha Concentration parameter.
206      * @return the sampler
207      * @throws IllegalArgumentException if the number of categories is
208      * less than 2; or if the concentration parameter is not strictly positive.
209      */
210     public static DirichletSampler symmetric(UniformRandomProvider rng,
211                                              int k,
212                                              double alpha) {
213         validateNumberOfCategories(k);
214         final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
215         return new SymmetricDirichletSampler(rng, k, sampler);
216     }
217 
218     /**
219      * Validate the number of categories.
220      *
221      * @param k Number of categories.
222      * @throws IllegalArgumentException if the number of categories is
223      * less than 2.
224      */
225     private static void validateNumberOfCategories(int k) {
226         if (k < MIN_CATGEORIES) {
227             throw new IllegalArgumentException("Invalid number of categories: " + k);
228         }
229     }
230 
231     /**
232      * Creates a gamma sampler for a category with the given concentration parameter.
233      *
234      * @param rng Generator of uniformly distributed random numbers.
235      * @param alpha Concentration parameter.
236      * @return the sampler
237      * @throws IllegalArgumentException if the concentration parameter is not strictly positive.
238      */
239     private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
240                                                               double alpha) {
241         InternalUtils.requireStrictlyPositiveFinite(alpha, "alpha concentration");
242         // Create a Gamma(shape=alpha, scale=1) sampler.
243         if (alpha == 1) {
244             // Special case
245             // Gamma(shape=1, scale=1) == Exponential(mean=1)
246             return ZigguratSampler.Exponential.of(rng);
247         }
248         return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
249     }
250 
251     /**
252      * Return true if the value is non-zero, positive and finite.
253      *
254      * @param x Value.
255      * @return true if non-zero positive finite
256      */
257     private static boolean isNonZeroPositiveFinite(double x) {
258         return x > 0 && x < Double.POSITIVE_INFINITY;
259     }
260 }