View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.distribution;
18  
19  import java.util.ArrayList;
20  import java.util.List;
21  
22  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23  import org.apache.commons.math4.legacy.exception.MathArithmeticException;
24  import org.apache.commons.math4.legacy.exception.NotPositiveException;
25  import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
26  import org.apache.commons.rng.UniformRandomProvider;
27  import org.apache.commons.math4.legacy.core.Pair;
28  
29  /**
30   * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
31   * mixture model</a> distributions.
32   *
33   * @param <T> Type of the mixture components.
34   *
35   * @since 3.1
36   */
37  public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
38      extends AbstractMultivariateRealDistribution {
39      /** Normalized weight of each mixture component. */
40      private final double[] weight;
41      /** Mixture components. */
42      private final List<T> distribution;
43  
44      /**
45       * Creates a mixture model from a list of distributions and their
46       * associated weights.
47       *
48       * @param components Distributions from which to sample.
49       * @throws NotPositiveException if any of the weights is negative.
50       * @throws DimensionMismatchException if not all components have the same
51       * number of variables.
52       */
53      public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
54          super(components.get(0).getSecond().getDimension());
55  
56          final int numComp = components.size();
57          final int dim = getDimension();
58          double weightSum = 0;
59          for (int i = 0; i < numComp; i++) {
60              final Pair<Double, T> comp = components.get(i);
61              if (comp.getSecond().getDimension() != dim) {
62                  throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
63              }
64              if (comp.getFirst() < 0) {
65                  throw new NotPositiveException(comp.getFirst());
66              }
67              weightSum += comp.getFirst();
68          }
69  
70          // Check for overflow.
71          if (Double.isInfinite(weightSum)) {
72              throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
73          }
74  
75          // Store each distribution and its normalized weight.
76          distribution = new ArrayList<>();
77          weight = new double[numComp];
78          for (int i = 0; i < numComp; i++) {
79              final Pair<Double, T> comp = components.get(i);
80              weight[i] = comp.getFirst() / weightSum;
81              distribution.add(comp.getSecond());
82          }
83      }
84  
85      /** {@inheritDoc} */
86      @Override
87      public double density(final double[] values) {
88          double p = 0;
89          for (int i = 0; i < weight.length; i++) {
90              p += weight[i] * distribution.get(i).density(values);
91          }
92          return p;
93      }
94  
95      /**
96       * Gets the distributions that make up the mixture model.
97       *
98       * @return the component distributions and associated weights.
99       */
100     public List<Pair<Double, T>> getComponents() {
101         final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
102 
103         for (int i = 0; i < weight.length; i++) {
104             list.add(new Pair<>(weight[i], distribution.get(i)));
105         }
106 
107         return list;
108     }
109 
110     /** {@inheritDoc} */
111     @Override
112     public MultivariateRealDistribution.Sampler createSampler(UniformRandomProvider rng) {
113         return new MixtureSampler(rng);
114     }
115 
116     /**
117      * Sampler.
118      */
119     private final class MixtureSampler implements MultivariateRealDistribution.Sampler {
120         /** RNG. */
121         private final UniformRandomProvider rng;
122         /** Sampler for each of the distribution in the mixture. */
123         private final MultivariateRealDistribution.Sampler[] samplers;
124 
125         /**
126          * @param generator RNG.
127          */
128         MixtureSampler(UniformRandomProvider generator) {
129             rng = generator;
130 
131             samplers = new MultivariateRealDistribution.Sampler[weight.length];
132             for (int i = 0; i < weight.length; i++) {
133                 samplers[i] = distribution.get(i).createSampler(rng);
134             }
135         }
136 
137         /** {@inheritDoc} */
138         @Override
139         public double[] sample() {
140             // Sampled values.
141             double[] vals = null;
142 
143             // Determine which component to sample from.
144             final double randomValue = rng.nextDouble();
145             double sum = 0;
146 
147             for (int i = 0; i < weight.length; i++) {
148                 sum += weight[i];
149                 if (randomValue <= sum) {
150                     // pick model i
151                     vals = samplers[i].sample();
152                     break;
153                 }
154             }
155 
156             if (vals == null) {
157                 // This should never happen, but it ensures we won't return a null in
158                 // case the loop above has some floating point inequality problem on
159                 // the final iteration.
160                 vals = samplers[weight.length - 1].sample();
161             }
162 
163             return vals;
164         }
165     }
166 }