/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.distribution;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.NotANumberException;
import org.apache.commons.math3.exception.NotFiniteNumberException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;

/**
 * <p>Implementation of an integer-valued {@link EnumeratedDistribution}.</p>
 *
 * <p>Values with zero-probability are allowed but they do not extend the
 * support.</p>
 *
 * <p>Duplicate values are allowed. Probabilities of duplicate values are combined
 * when computing cumulative probabilities and statistics.</p>
 *
 * @since 3.2
 */
public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {

    /** Serializable UID. */
    private static final long serialVersionUID = 20130308L;

    /**
     * {@link EnumeratedDistribution} instance (using the {@link Integer} wrapper)
     * used to generate the pmf.
     */
    protected final EnumeratedDistribution<Integer> innerDistribution;

    /**
     * Create a discrete distribution using the given probability mass function
     * definition.
     * <p>
     * <b>Note:</b> this constructor will implicitly create an instance of
     * {@link Well19937c} as random generator to be used for sampling only (see
     * {@link #sample()} and {@link #sample(int)}). In case no sampling is
     * needed for the created distribution, it is advised to pass {@code null}
     * as random generator via the appropriate constructors to avoid the
     * additional initialisation overhead.
     *
     * @param singletons array of random variable values.
     * @param probabilities array of probabilities.
     * @throws DimensionMismatchException if
     * {@code singletons.length != probabilities.length}
     * @throws NotPositiveException if any of the probabilities are negative.
     * @throws NotFiniteNumberException if any of the probabilities are infinite.
     * @throws NotANumberException if any of the probabilities are NaN.
     * @throws MathArithmeticException if all of the probabilities are 0.
     */
    public EnumeratedIntegerDistribution(final int[] singletons, final double[] probabilities)
        throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
               NotFiniteNumberException, NotANumberException {
        this(new Well19937c(), singletons, probabilities);
    }

    /**
     * Create a discrete distribution using the given random number generator
     * and probability mass function definition.
     *
     * @param rng random number generator.
     * @param singletons array of random variable values.
     * @param probabilities array of probabilities.
     * @throws DimensionMismatchException if
     * {@code singletons.length != probabilities.length}
     * @throws NotPositiveException if any of the probabilities are negative.
     * @throws NotFiniteNumberException if any of the probabilities are infinite.
     * @throws NotANumberException if any of the probabilities are NaN.
     * @throws MathArithmeticException if all of the probabilities are 0.
     */
    public EnumeratedIntegerDistribution(final RandomGenerator rng,
                                         final int[] singletons, final double[] probabilities)
        throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
               NotFiniteNumberException, NotANumberException {
        super(rng);
        innerDistribution = new EnumeratedDistribution<Integer>(
                rng, createDistribution(singletons, probabilities));
    }

    /**
     * Create a discrete integer-valued distribution from the input data. Values are assigned
     * mass based on their frequency: each distinct value receives
     * {@code (number of occurrences) / data.length}.
     *
     * @param rng random number generator used for sampling
     * @param data input dataset
     * @since 3.6
     */
    public EnumeratedIntegerDistribution(final RandomGenerator rng, final int[] data) {
        super(rng);

        // Tally how many times each distinct value occurs in the data.
        final Map<Integer, Integer> dataMap = new HashMap<Integer, Integer>();
        for (final int value : data) {
            final Integer count = dataMap.get(value);
            dataMap.put(value, count == null ? 1 : count + 1);
        }

        // Turn the counts into relative frequencies (count / data.length).
        final int massPoints = dataMap.size();
        final double denom = data.length;
        final int[] values = new int[massPoints];
        final double[] probabilities = new double[massPoints];
        int index = 0;
        for (final Entry<Integer, Integer> entry : dataMap.entrySet()) {
            values[index] = entry.getKey();
            probabilities[index] = entry.getValue().intValue() / denom;
            index++;
        }
        innerDistribution = new EnumeratedDistribution<Integer>(rng, createDistribution(values, probabilities));
    }

    /**
     * Create a discrete integer-valued distribution from the input data. Values are assigned
     * mass based on their frequency. For example, [0,1,1,2] as input creates a distribution
     * with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively.
     *
     * @param data input dataset
     * @since 3.6
     */
    public EnumeratedIntegerDistribution(final int[] data) {
        this(new Well19937c(), data);
    }

    /**
     * Create the list of Pairs representing the distribution from singletons and probabilities.
     *
     * @param singletons values
     * @param probabilities probabilities
     * @return list of value/probability pairs, in the same order as the input arrays
     * @throws DimensionMismatchException if the two arrays differ in length
     */
    private static List<Pair<Integer, Double>> createDistribution(int[] singletons, double[] probabilities) {
        if (singletons.length != probabilities.length) {
            throw new DimensionMismatchException(probabilities.length, singletons.length);
        }

        final List<Pair<Integer, Double>> samples = new ArrayList<Pair<Integer, Double>>(singletons.length);

        for (int i = 0; i < singletons.length; i++) {
            samples.add(new Pair<Integer, Double>(singletons[i], probabilities[i]));
        }
        return samples;
    }

    /**
     * {@inheritDoc}
     */
    public double probability(final int x) {
        return innerDistribution.probability(x);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Sums the mass of every pmf entry whose value is {@code <= x}, so duplicate
     * values contribute once per entry.</p>
     */
    public double cumulativeProbability(final int x) {
        double probability = 0;

        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
            if (sample.getKey() <= x) {
                probability += sample.getValue();
            }
        }

        return probability;
    }

    /**
     * {@inheritDoc}
     *
     * @return {@code sum(singletons[i] * probabilities[i])}
     */
    public double getNumericalMean() {
        double mean = 0;

        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
            mean += sample.getValue() * sample.getKey();
        }

        return mean;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Computed with the two-pass formula (mean first, then weighted squared
     * deviations). The single-pass form {@code E[X^2] - E[X]^2} suffers
     * catastrophic cancellation when the values are large relative to their
     * spread. For example, values {@code {1000000001, 1000000003}} with equal
     * mass have variance exactly 1, which the single-pass form does not
     * reproduce in double precision.</p>
     *
     * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
     */
    public double getNumericalVariance() {
        // Fetch the pmf once; both passes iterate over the same entries.
        final List<Pair<Integer, Double>> pmf = innerDistribution.getPmf();

        // First pass: probability-weighted mean.
        double mean = 0;
        for (final Pair<Integer, Double> sample : pmf) {
            mean += sample.getValue() * sample.getKey();
        }

        // Second pass: probability-weighted squared deviations from the mean.
        double variance = 0;
        for (final Pair<Integer, Double> sample : pmf) {
            final double deviation = sample.getKey() - mean;
            variance += sample.getValue() * deviation * deviation;
        }

        return variance;
    }

    /**
     * {@inheritDoc}
     *
     * Returns the lowest value with non-zero probability.
     *
     * @return the lowest value with non-zero probability.
     */
    public int getSupportLowerBound() {
        int min = Integer.MAX_VALUE;
        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
            // Zero-mass entries are skipped: they do not extend the support.
            if (sample.getKey() < min && sample.getValue() > 0) {
                min = sample.getKey();
            }
        }

        return min;
    }

    /**
     * {@inheritDoc}
     *
     * Returns the highest value with non-zero probability.
     *
     * @return the highest value with non-zero probability.
     */
    public int getSupportUpperBound() {
        int max = Integer.MIN_VALUE;
        for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
            // Zero-mass entries are skipped: they do not extend the support.
            if (sample.getKey() > max && sample.getValue() > 0) {
                max = sample.getKey();
            }
        }

        return max;
    }

    /**
     * {@inheritDoc}
     *
     * <p>This implementation unconditionally returns {@code true}. That holds even
     * when the enumerated values leave gaps. For example, with values
     * {@code {0, 5}} the integers 1 to 4 carry no mass.
     * NOTE(review): the return value is kept as-is for backward compatibility;
     * callers should not rely on it to detect gaps.</p>
     *
     * @return {@code true}
     */
    public boolean isSupportConnected() {
        return true;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Delegates to the inner {@link EnumeratedDistribution}.</p>
     */
    @Override
    public int sample() {
        return innerDistribution.sample();
    }
}