View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.distribution;
18  
19  import java.util.ArrayList;
20  import java.util.LinkedHashMap;
21  import java.util.List;
22  import java.util.Map;
23  import java.util.Map.Entry;
24  
25  import org.apache.commons.statistics.distribution.ContinuousDistribution;
26  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
27  import org.apache.commons.math4.legacy.exception.MathArithmeticException;
28  import org.apache.commons.math4.legacy.exception.NotANumberException;
29  import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
30  import org.apache.commons.math4.legacy.exception.NotPositiveException;
31  import org.apache.commons.math4.legacy.exception.OutOfRangeException;
32  import org.apache.commons.rng.UniformRandomProvider;
33  import org.apache.commons.math4.legacy.core.Pair;
34  
35  /**
36   * <p>Implementation of a real-valued {@link EnumeratedDistribution}.
37   *
38   * <p>Values with zero-probability are allowed but they do not extend the
39   * support.<br>
40   * Duplicate values are allowed. Probabilities of duplicate values are combined
41   * when computing cumulative probabilities and statistics.</p>
42   *
43   * @since 3.2
44   */
45  public class EnumeratedRealDistribution
46      implements ContinuousDistribution {
47      /**
48       * {@link EnumeratedDistribution} (using the {@link Double} wrapper)
49       * used to generate the pmf.
50       */
51      protected final EnumeratedDistribution<Double> innerDistribution;
52  
53      /**
54       * Create a discrete real-valued distribution using the given random number generator
55       * and probability mass function enumeration.
56       *
57       * @param singletons array of random variable values.
58       * @param probabilities array of probabilities.
59       * @throws DimensionMismatchException if
60       * {@code singletons.length != probabilities.length}
61       * @throws NotPositiveException if any of the probabilities are negative.
62       * @throws NotFiniteNumberException if any of the probabilities are infinite.
63       * @throws NotANumberException if any of the probabilities are NaN.
64       * @throws MathArithmeticException all of the probabilities are 0.
65       */
66      public EnumeratedRealDistribution(final double[] singletons,
67                                        final double[] probabilities)
68          throws DimensionMismatchException,
69                 NotPositiveException,
70                 MathArithmeticException,
71                 NotFiniteNumberException,
72                 NotANumberException {
73          innerDistribution = new EnumeratedDistribution<>(createDistribution(singletons, probabilities));
74      }
75  
76      /**
77       * Creates a discrete real-valued distribution from the input data.
78       * Values are assigned mass based on their frequency.
79       *
80       * @param data input dataset
81       */
82      public EnumeratedRealDistribution(final double[] data) {
83          final Map<Double, Integer> dataMap = new LinkedHashMap<>();
84          for (double value : data) {
85              dataMap.merge(value, 1, Integer::sum);
86          }
87          final int massPoints = dataMap.size();
88          final double denom = data.length;
89          final double[] values = new double[massPoints];
90          final double[] probabilities = new double[massPoints];
91          int index = 0;
92          for (Entry<Double, Integer> entry : dataMap.entrySet()) {
93              values[index] = entry.getKey();
94              probabilities[index] = entry.getValue().intValue() / denom;
95              index++;
96          }
97          innerDistribution = new EnumeratedDistribution<>(createDistribution(values, probabilities));
98      }
99  
100     /**
101      * Create the list of Pairs representing the distribution from singletons and probabilities.
102      *
103      * @param singletons values
104      * @param probabilities probabilities
105      * @return list of value/probability pairs
106      */
107     private static List<Pair<Double, Double>>  createDistribution(double[] singletons, double[] probabilities) {
108         if (singletons.length != probabilities.length) {
109             throw new DimensionMismatchException(probabilities.length, singletons.length);
110         }
111 
112         final List<Pair<Double, Double>> samples = new ArrayList<>(singletons.length);
113 
114         for (int i = 0; i < singletons.length; i++) {
115             samples.add(new Pair<>(singletons[i], probabilities[i]));
116         }
117         return samples;
118     }
119 
120     /**
121      * For a random variable {@code X} whose values are distributed according to
122      * this distribution, this method returns {@code P(X = x)}. In other words,
123      * this method represents the probability mass function (PMF) for the
124      * distribution.
125      *
126      * @param x the point at which the PMF is evaluated
127      * @return the value of the probability mass function at point {@code x}
128      */
129     @Override
130     public double density(final double x) {
131         return innerDistribution.probability(x);
132     }
133 
134     /**
135      * {@inheritDoc}
136      */
137     @Override
138     public double cumulativeProbability(final double x) {
139         double probability = 0;
140 
141         for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
142             if (sample.getKey() <= x) {
143                 probability += sample.getValue();
144             }
145         }
146 
147         return probability;
148     }
149 
150     /**
151      * {@inheritDoc}
152      */
153     @Override
154     public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
155         if (p < 0.0 || p > 1.0) {
156             throw new OutOfRangeException(p, 0, 1);
157         }
158 
159         double probability = 0;
160         double x = getSupportLowerBound();
161         for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
162             if (sample.getValue() == 0.0) {
163                 continue;
164             }
165 
166             probability += sample.getValue();
167             x = sample.getKey();
168 
169             if (probability >= p) {
170                 break;
171             }
172         }
173 
174         return x;
175     }
176 
177     /**
178      * {@inheritDoc}
179      *
180      * @return {@code sum(singletons[i] * probabilities[i])}
181      */
182     @Override
183     public double getMean() {
184         double mean = 0;
185 
186         for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
187             mean += sample.getValue() * sample.getKey();
188         }
189 
190         return mean;
191     }
192 
193     /**
194      * {@inheritDoc}
195      *
196      * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
197      */
198     @Override
199     public double getVariance() {
200         double mean = 0;
201         double meanOfSquares = 0;
202 
203         for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
204             mean += sample.getValue() * sample.getKey();
205             meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
206         }
207 
208         return meanOfSquares - mean * mean;
209     }
210 
211     /**
212      * {@inheritDoc}
213      *
214      * Returns the lowest value with non-zero probability.
215      *
216      * @return the lowest value with non-zero probability.
217      */
218     @Override
219     public double getSupportLowerBound() {
220         double min = Double.POSITIVE_INFINITY;
221         for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
222             if (sample.getKey() < min && sample.getValue() > 0) {
223                 min = sample.getKey();
224             }
225         }
226 
227         return min;
228     }
229 
230     /**
231      * {@inheritDoc}
232      *
233      * Returns the highest value with non-zero probability.
234      *
235      * @return the highest value with non-zero probability.
236      */
237     @Override
238     public double getSupportUpperBound() {
239         double max = Double.NEGATIVE_INFINITY;
240         for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
241             if (sample.getKey() > max && sample.getValue() > 0) {
242                 max = sample.getKey();
243             }
244         }
245 
246         return max;
247     }
248 
249     /** {@inheritDoc} */
250     @Override
251     public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
252         return innerDistribution.createSampler(rng)::sample;
253     }
254 }