001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.distribution;
018
019import java.util.ArrayList;
020import java.util.LinkedHashMap;
021import java.util.List;
022import java.util.Map;
023import java.util.Map.Entry;
024
025import org.apache.commons.statistics.distribution.ContinuousDistribution;
026import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
027import org.apache.commons.math4.legacy.exception.MathArithmeticException;
028import org.apache.commons.math4.legacy.exception.NotANumberException;
029import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
030import org.apache.commons.math4.legacy.exception.NotPositiveException;
031import org.apache.commons.math4.legacy.exception.OutOfRangeException;
032import org.apache.commons.rng.UniformRandomProvider;
033import org.apache.commons.math4.legacy.core.Pair;
034
035/**
036 * <p>Implementation of a real-valued {@link EnumeratedDistribution}.
037 *
038 * <p>Values with zero-probability are allowed but they do not extend the
039 * support.<br>
040 * Duplicate values are allowed. Probabilities of duplicate values are combined
041 * when computing cumulative probabilities and statistics.</p>
042 *
043 * @since 3.2
044 */
045public class EnumeratedRealDistribution
046    implements ContinuousDistribution {
047    /**
048     * {@link EnumeratedDistribution} (using the {@link Double} wrapper)
049     * used to generate the pmf.
050     */
051    protected final EnumeratedDistribution<Double> innerDistribution;
052
053    /**
054     * Create a discrete real-valued distribution using the given random number generator
055     * and probability mass function enumeration.
056     *
057     * @param singletons array of random variable values.
058     * @param probabilities array of probabilities.
059     * @throws DimensionMismatchException if
060     * {@code singletons.length != probabilities.length}
061     * @throws NotPositiveException if any of the probabilities are negative.
062     * @throws NotFiniteNumberException if any of the probabilities are infinite.
063     * @throws NotANumberException if any of the probabilities are NaN.
064     * @throws MathArithmeticException all of the probabilities are 0.
065     */
066    public EnumeratedRealDistribution(final double[] singletons,
067                                      final double[] probabilities)
068        throws DimensionMismatchException,
069               NotPositiveException,
070               MathArithmeticException,
071               NotFiniteNumberException,
072               NotANumberException {
073        innerDistribution = new EnumeratedDistribution<>(createDistribution(singletons, probabilities));
074    }
075
076    /**
077     * Creates a discrete real-valued distribution from the input data.
078     * Values are assigned mass based on their frequency.
079     *
080     * @param data input dataset
081     */
082    public EnumeratedRealDistribution(final double[] data) {
083        final Map<Double, Integer> dataMap = new LinkedHashMap<>();
084        for (double value : data) {
085            dataMap.merge(value, 1, Integer::sum);
086        }
087        final int massPoints = dataMap.size();
088        final double denom = data.length;
089        final double[] values = new double[massPoints];
090        final double[] probabilities = new double[massPoints];
091        int index = 0;
092        for (Entry<Double, Integer> entry : dataMap.entrySet()) {
093            values[index] = entry.getKey();
094            probabilities[index] = entry.getValue().intValue() / denom;
095            index++;
096        }
097        innerDistribution = new EnumeratedDistribution<>(createDistribution(values, probabilities));
098    }
099
100    /**
101     * Create the list of Pairs representing the distribution from singletons and probabilities.
102     *
103     * @param singletons values
104     * @param probabilities probabilities
105     * @return list of value/probability pairs
106     */
107    private static List<Pair<Double, Double>>  createDistribution(double[] singletons, double[] probabilities) {
108        if (singletons.length != probabilities.length) {
109            throw new DimensionMismatchException(probabilities.length, singletons.length);
110        }
111
112        final List<Pair<Double, Double>> samples = new ArrayList<>(singletons.length);
113
114        for (int i = 0; i < singletons.length; i++) {
115            samples.add(new Pair<>(singletons[i], probabilities[i]));
116        }
117        return samples;
118    }
119
120    /**
121     * For a random variable {@code X} whose values are distributed according to
122     * this distribution, this method returns {@code P(X = x)}. In other words,
123     * this method represents the probability mass function (PMF) for the
124     * distribution.
125     *
126     * @param x the point at which the PMF is evaluated
127     * @return the value of the probability mass function at point {@code x}
128     */
129    @Override
130    public double density(final double x) {
131        return innerDistribution.probability(x);
132    }
133
134    /**
135     * {@inheritDoc}
136     */
137    @Override
138    public double cumulativeProbability(final double x) {
139        double probability = 0;
140
141        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
142            if (sample.getKey() <= x) {
143                probability += sample.getValue();
144            }
145        }
146
147        return probability;
148    }
149
150    /**
151     * {@inheritDoc}
152     */
153    @Override
154    public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
155        if (p < 0.0 || p > 1.0) {
156            throw new OutOfRangeException(p, 0, 1);
157        }
158
159        double probability = 0;
160        double x = getSupportLowerBound();
161        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
162            if (sample.getValue() == 0.0) {
163                continue;
164            }
165
166            probability += sample.getValue();
167            x = sample.getKey();
168
169            if (probability >= p) {
170                break;
171            }
172        }
173
174        return x;
175    }
176
177    /**
178     * {@inheritDoc}
179     *
180     * @return {@code sum(singletons[i] * probabilities[i])}
181     */
182    @Override
183    public double getMean() {
184        double mean = 0;
185
186        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
187            mean += sample.getValue() * sample.getKey();
188        }
189
190        return mean;
191    }
192
193    /**
194     * {@inheritDoc}
195     *
196     * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
197     */
198    @Override
199    public double getVariance() {
200        double mean = 0;
201        double meanOfSquares = 0;
202
203        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
204            mean += sample.getValue() * sample.getKey();
205            meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
206        }
207
208        return meanOfSquares - mean * mean;
209    }
210
211    /**
212     * {@inheritDoc}
213     *
214     * Returns the lowest value with non-zero probability.
215     *
216     * @return the lowest value with non-zero probability.
217     */
218    @Override
219    public double getSupportLowerBound() {
220        double min = Double.POSITIVE_INFINITY;
221        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
222            if (sample.getKey() < min && sample.getValue() > 0) {
223                min = sample.getKey();
224            }
225        }
226
227        return min;
228    }
229
230    /**
231     * {@inheritDoc}
232     *
233     * Returns the highest value with non-zero probability.
234     *
235     * @return the highest value with non-zero probability.
236     */
237    @Override
238    public double getSupportUpperBound() {
239        double max = Double.NEGATIVE_INFINITY;
240        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
241            if (sample.getKey() > max && sample.getValue() > 0) {
242                max = sample.getKey();
243            }
244        }
245
246        return max;
247    }
248
249    /** {@inheritDoc} */
250    @Override
251    public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
252        return innerDistribution.createSampler(rng)::sample;
253    }
254}