/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  package org.apache.commons.math4.legacy.distribution;
18  
19  import java.util.ArrayList;
20  import java.util.LinkedHashMap;
21  import java.util.List;
22  import java.util.Map;
23  import java.util.Map.Entry;
24  
25  import org.apache.commons.statistics.distribution.DiscreteDistribution;
26  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
27  import org.apache.commons.math4.legacy.exception.MathArithmeticException;
28  import org.apache.commons.math4.legacy.exception.NotANumberException;
29  import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
30  import org.apache.commons.math4.legacy.exception.NotPositiveException;
31  import org.apache.commons.rng.UniformRandomProvider;
32  import org.apache.commons.math4.legacy.core.Pair;
33  
34  /**
35   * <p>Implementation of an integer-valued {@link EnumeratedDistribution}.</p>
36   *
37   * <p>Values with zero-probability are allowed but they do not extend the
38   * support.<br>
39   * Duplicate values are allowed. Probabilities of duplicate values are combined
40   * when computing cumulative probabilities and statistics.</p>
41   *
42   * @since 3.2
43   */
44  public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
45      /**
46       * {@link EnumeratedDistribution} instance (using the {@link Integer} wrapper)
47       * used to generate the pmf.
48       */
49      protected final EnumeratedDistribution<Integer> innerDistribution;
50  
51      /**
52       * Create a discrete distribution.
53       *
54       * @param singletons array of random variable values.
55       * @param probabilities array of probabilities.
56       * @throws DimensionMismatchException if
57       * {@code singletons.length != probabilities.length}
58       * @throws NotPositiveException if any of the probabilities are negative.
59       * @throws NotFiniteNumberException if any of the probabilities are infinite.
60       * @throws NotANumberException if any of the probabilities are NaN.
61       * @throws MathArithmeticException all of the probabilities are 0.
62       */
63      public EnumeratedIntegerDistribution(final int[] singletons,
64                                           final double[] probabilities)
65          throws DimensionMismatchException,
66                 NotPositiveException,
67                 MathArithmeticException,
68                 NotFiniteNumberException,
69                 NotANumberException {
70          innerDistribution = new EnumeratedDistribution<>(createDistribution(singletons,
71                                                                              probabilities));
72      }
73  
74      /**
75       * Create a discrete integer-valued distribution from the input data.
76       * Values are assigned mass based on their frequency.
77       *
78       * @param data input dataset
79       */
80      public EnumeratedIntegerDistribution(final int[] data) {
81          final Map<Integer, Integer> dataMap = new LinkedHashMap<>();
82          for (int value : data) {
83              dataMap.merge(value, 1, Integer::sum);
84          }
85          final int massPoints = dataMap.size();
86          final double denom = data.length;
87          final int[] values = new int[massPoints];
88          final double[] probabilities = new double[massPoints];
89          int index = 0;
90          for (Entry<Integer, Integer> entry : dataMap.entrySet()) {
91              values[index] = entry.getKey();
92              probabilities[index] = entry.getValue().intValue() / denom;
93              index++;
94          }
95          innerDistribution = new EnumeratedDistribution<>(createDistribution(values, probabilities));
96      }
97  
98      /**
99       * Create the list of Pairs representing the distribution from singletons and probabilities.
100      *
101      * @param singletons values
102      * @param probabilities probabilities
103      * @return list of value/probability pairs
104      */
105     private static List<Pair<Integer, Double>>  createDistribution(int[] singletons, double[] probabilities) {
106         if (singletons.length != probabilities.length) {
107             throw new DimensionMismatchException(probabilities.length, singletons.length);
108         }
109 
110         final List<Pair<Integer, Double>> samples = new ArrayList<>(singletons.length);
111 
112         for (int i = 0; i < singletons.length; i++) {
113             samples.add(new Pair<>(singletons[i], probabilities[i]));
114         }
115         return samples;
116     }
117 
118     /**
119      * {@inheritDoc}
120      */
121     @Override
122     public double probability(final int x) {
123         return innerDistribution.probability(x);
124     }
125 
126     /**
127      * {@inheritDoc}
128      */
129     @Override
130     public double cumulativeProbability(final int x) {
131         double probability = 0;
132 
133         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
134             if (sample.getKey() <= x) {
135                 probability += sample.getValue();
136             }
137         }
138 
139         return probability;
140     }
141 
142     /**
143      * {@inheritDoc}
144      *
145      * @return {@code sum(singletons[i] * probabilities[i])}
146      */
147     @Override
148     public double getMean() {
149         double mean = 0;
150 
151         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
152             mean += sample.getValue() * sample.getKey();
153         }
154 
155         return mean;
156     }
157 
158     /**
159      * {@inheritDoc}
160      *
161      * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
162      */
163     @Override
164     public double getVariance() {
165         double mean = 0;
166         double meanOfSquares = 0;
167 
168         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
169             mean += sample.getValue() * sample.getKey();
170             meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
171         }
172 
173         return meanOfSquares - mean * mean;
174     }
175 
176     /**
177      * {@inheritDoc}
178      *
179      * Returns the lowest value with non-zero probability.
180      *
181      * @return the lowest value with non-zero probability.
182      */
183     @Override
184     public int getSupportLowerBound() {
185         int min = Integer.MAX_VALUE;
186         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
187             if (sample.getKey() < min && sample.getValue() > 0) {
188                 min = sample.getKey();
189             }
190         }
191 
192         return min;
193     }
194 
195     /**
196      * {@inheritDoc}
197      *
198      * Returns the highest value with non-zero probability.
199      *
200      * @return the highest value with non-zero probability.
201      */
202     @Override
203     public int getSupportUpperBound() {
204         int max = Integer.MIN_VALUE;
205         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
206             if (sample.getKey() > max && sample.getValue() > 0) {
207                 max = sample.getKey();
208             }
209         }
210 
211         return max;
212     }
213 
214     /**
215      * {@inheritDoc}
216      *
217      * Refer to {@link EnumeratedDistribution.Sampler} for implementation details.
218      */
219     @Override
220     public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
221         return innerDistribution.createSampler(rng)::sample;
222     }
223 }