001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.distribution;
018
019import java.util.ArrayList;
020import java.util.HashMap;
021import java.util.List;
022import java.util.Map;
023import java.util.Map.Entry;
024import java.io.Serializable;
025
026import org.apache.commons.statistics.distribution.ContinuousDistribution;
027import org.apache.commons.math4.exception.DimensionMismatchException;
028import org.apache.commons.math4.exception.MathArithmeticException;
029import org.apache.commons.math4.exception.NotANumberException;
030import org.apache.commons.math4.exception.NotFiniteNumberException;
031import org.apache.commons.math4.exception.NotPositiveException;
032import org.apache.commons.math4.exception.OutOfRangeException;
033import org.apache.commons.rng.UniformRandomProvider;
034import org.apache.commons.math4.util.Pair;
035
036/**
037 * <p>Implementation of a real-valued {@link EnumeratedDistribution}.
038 *
039 * <p>Values with zero-probability are allowed but they do not extend the
040 * support.<br>
041 * Duplicate values are allowed. Probabilities of duplicate values are combined
042 * when computing cumulative probabilities and statistics.</p>
043 *
044 * @since 3.2
045 */
046public class EnumeratedRealDistribution
047    implements ContinuousDistribution,
048               Serializable {
049    /** Serializable UID. */
050    private static final long serialVersionUID = 20160311L;
051
052    /**
053     * {@link EnumeratedDistribution} (using the {@link Double} wrapper)
054     * used to generate the pmf.
055     */
056    protected final EnumeratedDistribution<Double> innerDistribution;
057
058    /**
059     * Create a discrete real-valued distribution using the given random number generator
060     * and probability mass function enumeration.
061     *
062     * @param singletons array of random variable values.
063     * @param probabilities array of probabilities.
064     * @throws DimensionMismatchException if
065     * {@code singletons.length != probabilities.length}
066     * @throws NotPositiveException if any of the probabilities are negative.
067     * @throws NotFiniteNumberException if any of the probabilities are infinite.
068     * @throws NotANumberException if any of the probabilities are NaN.
069     * @throws MathArithmeticException all of the probabilities are 0.
070     */
071    public EnumeratedRealDistribution(final double[] singletons,
072                                      final double[] probabilities)
073        throws DimensionMismatchException,
074               NotPositiveException,
075               MathArithmeticException,
076               NotFiniteNumberException,
077               NotANumberException {
078        innerDistribution = new EnumeratedDistribution<>(createDistribution(singletons, probabilities));
079    }
080
081    /**
082     * Creates a discrete real-valued distribution from the input data.
083     * Values are assigned mass based on their frequency.
084     *
085     * @param data input dataset
086     */
087    public EnumeratedRealDistribution(final double[] data) {
088        final Map<Double, Integer> dataMap = new HashMap<>();
089        for (double value : data) {
090            Integer count = dataMap.get(value);
091            if (count == null) {
092                count = 0;
093            }
094            dataMap.put(value, ++count);
095        }
096        final int massPoints = dataMap.size();
097        final double denom = data.length;
098        final double[] values = new double[massPoints];
099        final double[] probabilities = new double[massPoints];
100        int index = 0;
101        for (Entry<Double, Integer> entry : dataMap.entrySet()) {
102            values[index] = entry.getKey();
103            probabilities[index] = entry.getValue().intValue() / denom;
104            index++;
105        }
106        innerDistribution = new EnumeratedDistribution<>(createDistribution(values, probabilities));
107    }
108
109    /**
110     * Create the list of Pairs representing the distribution from singletons and probabilities.
111     *
112     * @param singletons values
113     * @param probabilities probabilities
114     * @return list of value/probability pairs
115     */
116    private static List<Pair<Double, Double>>  createDistribution(double[] singletons, double[] probabilities) {
117        if (singletons.length != probabilities.length) {
118            throw new DimensionMismatchException(probabilities.length, singletons.length);
119        }
120
121        final List<Pair<Double, Double>> samples = new ArrayList<>(singletons.length);
122
123        for (int i = 0; i < singletons.length; i++) {
124            samples.add(new Pair<>(singletons[i], probabilities[i]));
125        }
126        return samples;
127
128    }
129
130    /**
131     * {@inheritDoc}
132     */
133    @Override
134    public double probability(final double x) {
135        return innerDistribution.probability(x);
136    }
137
138    /**
139     * For a random variable {@code X} whose values are distributed according to
140     * this distribution, this method returns {@code P(X = x)}. In other words,
141     * this method represents the probability mass function (PMF) for the
142     * distribution.
143     *
144     * @param x the point at which the PMF is evaluated
145     * @return the value of the probability mass function at point {@code x}
146     */
147    @Override
148    public double density(final double x) {
149        return probability(x);
150    }
151
152    /**
153     * {@inheritDoc}
154     */
155    @Override
156    public double cumulativeProbability(final double x) {
157        double probability = 0;
158
159        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
160            if (sample.getKey() <= x) {
161                probability += sample.getValue();
162            }
163        }
164
165        return probability;
166    }
167
168    /**
169     * {@inheritDoc}
170     */
171    @Override
172    public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
173        if (p < 0.0 || p > 1.0) {
174            throw new OutOfRangeException(p, 0, 1);
175        }
176
177        double probability = 0;
178        double x = getSupportLowerBound();
179        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
180            if (sample.getValue() == 0.0) {
181                continue;
182            }
183
184            probability += sample.getValue();
185            x = sample.getKey();
186
187            if (probability >= p) {
188                break;
189            }
190        }
191
192        return x;
193    }
194
195    /**
196     * {@inheritDoc}
197     *
198     * @return {@code sum(singletons[i] * probabilities[i])}
199     */
200    @Override
201    public double getMean() {
202        double mean = 0;
203
204        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
205            mean += sample.getValue() * sample.getKey();
206        }
207
208        return mean;
209    }
210
211    /**
212     * {@inheritDoc}
213     *
214     * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
215     */
216    @Override
217    public double getVariance() {
218        double mean = 0;
219        double meanOfSquares = 0;
220
221        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
222            mean += sample.getValue() * sample.getKey();
223            meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
224        }
225
226        return meanOfSquares - mean * mean;
227    }
228
229    /**
230     * {@inheritDoc}
231     *
232     * Returns the lowest value with non-zero probability.
233     *
234     * @return the lowest value with non-zero probability.
235     */
236    @Override
237    public double getSupportLowerBound() {
238        double min = Double.POSITIVE_INFINITY;
239        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
240            if (sample.getKey() < min && sample.getValue() > 0) {
241                min = sample.getKey();
242            }
243        }
244
245        return min;
246    }
247
248    /**
249     * {@inheritDoc}
250     *
251     * Returns the highest value with non-zero probability.
252     *
253     * @return the highest value with non-zero probability.
254     */
255    @Override
256    public double getSupportUpperBound() {
257        double max = Double.NEGATIVE_INFINITY;
258        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
259            if (sample.getKey() > max && sample.getValue() > 0) {
260                max = sample.getKey();
261            }
262        }
263
264        return max;
265    }
266
267    /**
268     * {@inheritDoc}
269     *
270     * The support of this distribution is connected.
271     *
272     * @return {@code true}
273     */
274    @Override
275    public boolean isSupportConnected() {
276        return true;
277    }
278
279    /** {@inheritDoc} */
280    @Override
281    public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
282        return new ContinuousDistribution.Sampler() {
283            /** Delegate. */
284            private final EnumeratedDistribution<Double>.Sampler inner =
285                innerDistribution.createSampler(rng);
286
287            /** {@inheritDoc} */
288            @Override
289            public double sample() {
290                return inner.sample();
291            }
292        };
293    }
294}