001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.distribution;
018
019import java.util.ArrayList;
020import java.util.HashMap;
021import java.util.List;
022import java.util.Map;
023import java.util.Map.Entry;
024
025import org.apache.commons.statistics.distribution.DiscreteDistribution;
026import org.apache.commons.math4.exception.DimensionMismatchException;
027import org.apache.commons.math4.exception.MathArithmeticException;
028import org.apache.commons.math4.exception.NotANumberException;
029import org.apache.commons.math4.exception.NotFiniteNumberException;
030import org.apache.commons.math4.exception.NotPositiveException;
031import org.apache.commons.rng.UniformRandomProvider;
032import org.apache.commons.math4.util.Pair;
033
034/**
035 * <p>Implementation of an integer-valued {@link EnumeratedDistribution}.</p>
036 *
037 * <p>Values with zero-probability are allowed but they do not extend the
038 * support.<br>
039 * Duplicate values are allowed. Probabilities of duplicate values are combined
040 * when computing cumulative probabilities and statistics.</p>
041 *
042 * @since 3.2
043 */
044public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
045    /** Serializable UID. */
046    private static final long serialVersionUID = 20130308L;
047    /**
048     * {@link EnumeratedDistribution} instance (using the {@link Integer} wrapper)
049     * used to generate the pmf.
050     */
051    protected final EnumeratedDistribution<Integer> innerDistribution;
052
053    /**
054     * Create a discrete distribution.
055     *
056     * @param singletons array of random variable values.
057     * @param probabilities array of probabilities.
058     * @throws DimensionMismatchException if
059     * {@code singletons.length != probabilities.length}
060     * @throws NotPositiveException if any of the probabilities are negative.
061     * @throws NotFiniteNumberException if any of the probabilities are infinite.
062     * @throws NotANumberException if any of the probabilities are NaN.
063     * @throws MathArithmeticException all of the probabilities are 0.
064     */
065    public EnumeratedIntegerDistribution(final int[] singletons,
066                                         final double[] probabilities)
067        throws DimensionMismatchException,
068               NotPositiveException,
069               MathArithmeticException,
070               NotFiniteNumberException,
071               NotANumberException {
072        innerDistribution = new EnumeratedDistribution<>(createDistribution(singletons,
073                                                                                   probabilities));
074    }
075
076    /**
077     * Create a discrete integer-valued distribution from the input data.
078     * Values are assigned mass based on their frequency.
079     *
080     * @param data input dataset
081     */
082    public EnumeratedIntegerDistribution(final int[] data) {
083        final Map<Integer, Integer> dataMap = new HashMap<>();
084        for (int value : data) {
085            Integer count = dataMap.get(value);
086            if (count == null) {
087                count = 0;
088            }
089            dataMap.put(value, ++count);
090        }
091        final int massPoints = dataMap.size();
092        final double denom = data.length;
093        final int[] values = new int[massPoints];
094        final double[] probabilities = new double[massPoints];
095        int index = 0;
096        for (Entry<Integer, Integer> entry : dataMap.entrySet()) {
097            values[index] = entry.getKey();
098            probabilities[index] = entry.getValue().intValue() / denom;
099            index++;
100        }
101        innerDistribution = new EnumeratedDistribution<>(createDistribution(values, probabilities));
102    }
103
104    /**
105     * Create the list of Pairs representing the distribution from singletons and probabilities.
106     *
107     * @param singletons values
108     * @param probabilities probabilities
109     * @return list of value/probability pairs
110     */
111    private static List<Pair<Integer, Double>>  createDistribution(int[] singletons, double[] probabilities) {
112        if (singletons.length != probabilities.length) {
113            throw new DimensionMismatchException(probabilities.length, singletons.length);
114        }
115
116        final List<Pair<Integer, Double>> samples = new ArrayList<>(singletons.length);
117
118        for (int i = 0; i < singletons.length; i++) {
119            samples.add(new Pair<>(singletons[i], probabilities[i]));
120        }
121        return samples;
122    }
123
124    /**
125     * {@inheritDoc}
126     */
127    @Override
128    public double probability(final int x) {
129        return innerDistribution.probability(x);
130    }
131
132    /**
133     * {@inheritDoc}
134     */
135    @Override
136    public double cumulativeProbability(final int x) {
137        double probability = 0;
138
139        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
140            if (sample.getKey() <= x) {
141                probability += sample.getValue();
142            }
143        }
144
145        return probability;
146    }
147
148    /**
149     * {@inheritDoc}
150     *
151     * @return {@code sum(singletons[i] * probabilities[i])}
152     */
153    @Override
154    public double getMean() {
155        double mean = 0;
156
157        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
158            mean += sample.getValue() * sample.getKey();
159        }
160
161        return mean;
162    }
163
164    /**
165     * {@inheritDoc}
166     *
167     * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
168     */
169    @Override
170    public double getVariance() {
171        double mean = 0;
172        double meanOfSquares = 0;
173
174        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
175            mean += sample.getValue() * sample.getKey();
176            meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
177        }
178
179        return meanOfSquares - mean * mean;
180    }
181
182    /**
183     * {@inheritDoc}
184     *
185     * Returns the lowest value with non-zero probability.
186     *
187     * @return the lowest value with non-zero probability.
188     */
189    @Override
190    public int getSupportLowerBound() {
191        int min = Integer.MAX_VALUE;
192        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
193            if (sample.getKey() < min && sample.getValue() > 0) {
194                min = sample.getKey();
195            }
196        }
197
198        return min;
199    }
200
201    /**
202     * {@inheritDoc}
203     *
204     * Returns the highest value with non-zero probability.
205     *
206     * @return the highest value with non-zero probability.
207     */
208    @Override
209    public int getSupportUpperBound() {
210        int max = Integer.MIN_VALUE;
211        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
212            if (sample.getKey() > max && sample.getValue() > 0) {
213                max = sample.getKey();
214            }
215        }
216
217        return max;
218    }
219
220    /**
221     * {@inheritDoc}
222     *
223     * The support of this distribution is connected.
224     *
225     * @return {@code true}
226     */
227    @Override
228    public boolean isSupportConnected() {
229        return true;
230    }
231
232    /** {@inheritDoc} */
233    @Override
234    public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
235        return new DiscreteDistribution.Sampler() {
236            /** Delegate. */
237            private final EnumeratedDistribution<Integer>.Sampler inner =
238                innerDistribution.createSampler(rng);
239
240            /** {@inheritDoc} */
241            @Override
242            public int sample() {
243                return inner.sample();
244            }
245        };
246    }
247}