001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.distribution;
018
019import java.util.ArrayList;
020import java.util.LinkedHashMap;
021import java.util.List;
022import java.util.Map;
023import java.util.Map.Entry;
024
025import org.apache.commons.statistics.distribution.DiscreteDistribution;
026import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
027import org.apache.commons.math4.legacy.exception.MathArithmeticException;
028import org.apache.commons.math4.legacy.exception.NotANumberException;
029import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
030import org.apache.commons.math4.legacy.exception.NotPositiveException;
031import org.apache.commons.rng.UniformRandomProvider;
032import org.apache.commons.math4.legacy.core.Pair;
033
034/**
035 * <p>Implementation of an integer-valued {@link EnumeratedDistribution}.</p>
036 *
037 * <p>Values with zero-probability are allowed but they do not extend the
038 * support.<br>
039 * Duplicate values are allowed. Probabilities of duplicate values are combined
040 * when computing cumulative probabilities and statistics.</p>
041 *
042 * @since 3.2
043 */
044public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
045    /**
046     * {@link EnumeratedDistribution} instance (using the {@link Integer} wrapper)
047     * used to generate the pmf.
048     */
049    protected final EnumeratedDistribution<Integer> innerDistribution;
050
051    /**
052     * Create a discrete distribution.
053     *
054     * @param singletons array of random variable values.
055     * @param probabilities array of probabilities.
056     * @throws DimensionMismatchException if
057     * {@code singletons.length != probabilities.length}
058     * @throws NotPositiveException if any of the probabilities are negative.
059     * @throws NotFiniteNumberException if any of the probabilities are infinite.
060     * @throws NotANumberException if any of the probabilities are NaN.
061     * @throws MathArithmeticException all of the probabilities are 0.
062     */
063    public EnumeratedIntegerDistribution(final int[] singletons,
064                                         final double[] probabilities)
065        throws DimensionMismatchException,
066               NotPositiveException,
067               MathArithmeticException,
068               NotFiniteNumberException,
069               NotANumberException {
070        innerDistribution = new EnumeratedDistribution<>(createDistribution(singletons,
071                                                                            probabilities));
072    }
073
074    /**
075     * Create a discrete integer-valued distribution from the input data.
076     * Values are assigned mass based on their frequency.
077     *
078     * @param data input dataset
079     */
080    public EnumeratedIntegerDistribution(final int[] data) {
081        final Map<Integer, Integer> dataMap = new LinkedHashMap<>();
082        for (int value : data) {
083            dataMap.merge(value, 1, Integer::sum);
084        }
085        final int massPoints = dataMap.size();
086        final double denom = data.length;
087        final int[] values = new int[massPoints];
088        final double[] probabilities = new double[massPoints];
089        int index = 0;
090        for (Entry<Integer, Integer> entry : dataMap.entrySet()) {
091            values[index] = entry.getKey();
092            probabilities[index] = entry.getValue().intValue() / denom;
093            index++;
094        }
095        innerDistribution = new EnumeratedDistribution<>(createDistribution(values, probabilities));
096    }
097
098    /**
099     * Create the list of Pairs representing the distribution from singletons and probabilities.
100     *
101     * @param singletons values
102     * @param probabilities probabilities
103     * @return list of value/probability pairs
104     */
105    private static List<Pair<Integer, Double>>  createDistribution(int[] singletons, double[] probabilities) {
106        if (singletons.length != probabilities.length) {
107            throw new DimensionMismatchException(probabilities.length, singletons.length);
108        }
109
110        final List<Pair<Integer, Double>> samples = new ArrayList<>(singletons.length);
111
112        for (int i = 0; i < singletons.length; i++) {
113            samples.add(new Pair<>(singletons[i], probabilities[i]));
114        }
115        return samples;
116    }
117
118    /**
119     * {@inheritDoc}
120     */
121    @Override
122    public double probability(final int x) {
123        return innerDistribution.probability(x);
124    }
125
126    /**
127     * {@inheritDoc}
128     */
129    @Override
130    public double cumulativeProbability(final int x) {
131        double probability = 0;
132
133        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
134            if (sample.getKey() <= x) {
135                probability += sample.getValue();
136            }
137        }
138
139        return probability;
140    }
141
142    /**
143     * {@inheritDoc}
144     *
145     * @return {@code sum(singletons[i] * probabilities[i])}
146     */
147    @Override
148    public double getMean() {
149        double mean = 0;
150
151        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
152            mean += sample.getValue() * sample.getKey();
153        }
154
155        return mean;
156    }
157
158    /**
159     * {@inheritDoc}
160     *
161     * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
162     */
163    @Override
164    public double getVariance() {
165        double mean = 0;
166        double meanOfSquares = 0;
167
168        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
169            mean += sample.getValue() * sample.getKey();
170            meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
171        }
172
173        return meanOfSquares - mean * mean;
174    }
175
176    /**
177     * {@inheritDoc}
178     *
179     * Returns the lowest value with non-zero probability.
180     *
181     * @return the lowest value with non-zero probability.
182     */
183    @Override
184    public int getSupportLowerBound() {
185        int min = Integer.MAX_VALUE;
186        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
187            if (sample.getKey() < min && sample.getValue() > 0) {
188                min = sample.getKey();
189            }
190        }
191
192        return min;
193    }
194
195    /**
196     * {@inheritDoc}
197     *
198     * Returns the highest value with non-zero probability.
199     *
200     * @return the highest value with non-zero probability.
201     */
202    @Override
203    public int getSupportUpperBound() {
204        int max = Integer.MIN_VALUE;
205        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
206            if (sample.getKey() > max && sample.getValue() > 0) {
207                max = sample.getKey();
208            }
209        }
210
211        return max;
212    }
213
214    /**
215     * {@inheritDoc}
216     *
217     * Refer to {@link EnumeratedDistribution.Sampler} for implementation details.
218     */
219    @Override
220    public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
221        return innerDistribution.createSampler(rng)::sample;
222    }
223}