/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.distribution;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.NotANumberException;
import org.apache.commons.math3.exception.NotFiniteNumberException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;

/**
 * <p>Implementation of a real-valued {@link EnumeratedDistribution}.</p>
 *
 * <p>Values with zero-probability are allowed but they do not extend the
 * support.<br>
 * Duplicate values are allowed. Probabilities of duplicate values are combined
 * when computing cumulative probabilities and statistics.</p>
 *
 * @since 3.2
 */
public class EnumeratedRealDistribution extends AbstractRealDistribution {

    /** Serializable UID. */
    private static final long serialVersionUID = 20130308L;

    /**
     * {@link EnumeratedDistribution} (using the {@link Double} wrapper)
     * used to generate the pmf.
     */
    protected final EnumeratedDistribution<Double> innerDistribution;

    /**
     * Create a discrete real-valued distribution using the given probability mass function
     * enumeration.
     * <p>
     * <b>Note:</b> this constructor will implicitly create an instance of
     * {@link Well19937c} as random generator to be used for sampling only (see
     * {@link #sample()} and {@link #sample(int)}). In case no sampling is
     * needed for the created distribution, it is advised to pass {@code null}
     * as random generator via the appropriate constructors to avoid the
     * additional initialisation overhead.
     *
     * @param singletons array of random variable values.
     * @param probabilities array of probabilities.
     * @throws DimensionMismatchException if
     * {@code singletons.length != probabilities.length}
     * @throws NotPositiveException if any of the probabilities are negative.
     * @throws NotFiniteNumberException if any of the probabilities are infinite.
     * @throws NotANumberException if any of the probabilities are NaN.
     * @throws MathArithmeticException if all of the probabilities are 0.
     */
    public EnumeratedRealDistribution(final double[] singletons, final double[] probabilities)
        throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
               NotFiniteNumberException, NotANumberException {
        this(new Well19937c(), singletons, probabilities);
    }

    /**
     * Create a discrete real-valued distribution using the given random number generator
     * and probability mass function enumeration.
     *
     * @param rng random number generator.
     * @param singletons array of random variable values.
     * @param probabilities array of probabilities.
     * @throws DimensionMismatchException if
     * {@code singletons.length != probabilities.length}
     * @throws NotPositiveException if any of the probabilities are negative.
     * @throws NotFiniteNumberException if any of the probabilities are infinite.
     * @throws NotANumberException if any of the probabilities are NaN.
     * @throws MathArithmeticException if all of the probabilities are 0.
     */
    public EnumeratedRealDistribution(final RandomGenerator rng,
                                      final double[] singletons, final double[] probabilities)
        throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
               NotFiniteNumberException, NotANumberException {
        super(rng);

        innerDistribution = new EnumeratedDistribution<Double>(
                rng, createDistribution(singletons, probabilities));
    }

    /**
     * Create a discrete real-valued distribution from the input data.  Values are assigned
     * mass based on their frequency.  For example, [0,1,1,2] as input creates a distribution
     * with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively.
     *
     * @param rng random number generator used for sampling
     * @param data input dataset
     * @since 3.6
     */
    public EnumeratedRealDistribution(final RandomGenerator rng, final double[] data) {
        super(rng);
        // Count how many times each distinct value occurs in the data.
        final Map<Double, Integer> dataMap = new HashMap<Double, Integer>();
        for (double value : data) {
            Integer count = dataMap.get(value);
            if (count == null) {
                count = 0;
            }
            dataMap.put(value, ++count);
        }
        final int massPoints = dataMap.size();
        final double denom = data.length;
        final double[] values = new double[massPoints];
        final double[] probabilities = new double[massPoints];
        int index = 0;
        // The mass of a value is its relative frequency; the entries are visited
        // in the map's iteration order, which is not sorted by value.
        for (Entry<Double, Integer> entry : dataMap.entrySet()) {
            values[index] = entry.getKey();
            probabilities[index] = entry.getValue().intValue() / denom;
            index++;
        }
        innerDistribution = new EnumeratedDistribution<Double>(rng, createDistribution(values, probabilities));
    }

    /**
     * Create a discrete real-valued distribution from the input data.  Values are assigned
     * mass based on their frequency.  For example, [0,1,1,2] as input creates a distribution
     * with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively.
     * <p>
     * A new {@link Well19937c} instance is created and used as random generator.
     *
     * @param data input dataset
     * @since 3.6
     */
    public EnumeratedRealDistribution(final double[] data) {
        this(new Well19937c(), data);
    }

    /**
     * Create the list of Pairs representing the distribution from singletons and probabilities.
     * The pairs are listed in the order of the input arrays.
     *
     * @param singletons values
     * @param probabilities probabilities
     * @return list of value/probability pairs
     * @throws DimensionMismatchException if
     * {@code singletons.length != probabilities.length}
     */
    private static List<Pair<Double, Double>> createDistribution(double[] singletons, double[] probabilities) {
        if (singletons.length != probabilities.length) {
            throw new DimensionMismatchException(probabilities.length, singletons.length);
        }

        final List<Pair<Double, Double>> samples = new ArrayList<Pair<Double, Double>>(singletons.length);

        for (int i = 0; i < singletons.length; i++) {
            samples.add(new Pair<Double, Double>(singletons[i], probabilities[i]));
        }
        return samples;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public double probability(final double x) {
        return innerDistribution.probability(x);
    }

    /**
     * For a random variable {@code X} whose values are distributed according to
     * this distribution, this method returns {@code P(X = x)}. In other words,
     * this method represents the probability mass function (PMF) for the
     * distribution.
     *
     * @param x the point at which the PMF is evaluated
     * @return the value of the probability mass function at point {@code x}
     */
    public double density(final double x) {
        return probability(x);
    }

    /**
     * {@inheritDoc}
     *
     * The result is the sum of the masses of all enumerated values that are
     * less than or equal to {@code x}.
     */
    public double cumulativeProbability(final double x) {
        double probability = 0;

        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
            if (sample.getKey() <= x) {
                probability += sample.getValue();
            }
        }

        return probability;
    }

    /**
     * {@inheritDoc}
     *
     * The probability mass is accumulated in ascending order of the enumerated
     * values (masses of duplicate values are added together, zero-mass values
     * are skipped), and the first value at which the accumulated mass reaches
     * {@code p} is returned.  The result therefore does not depend on the order
     * in which the values were supplied.  If the accumulated mass never reaches
     * {@code p}, the largest value with non-zero mass is returned.
     *
     * @throws OutOfRangeException if {@code p < 0} or {@code p > 1}
     */
    @Override
    public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
        if (p < 0.0 || p > 1.0) {
            throw new OutOfRangeException(p, 0, 1);
        }

        // Order the positive-mass points by value: the PMF list is in insertion
        // order, and walking it unsorted would return a value that is not the
        // smallest one whose cumulative probability reaches p.
        final Map<Double, Double> massByValue = new TreeMap<Double, Double>();
        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
            final double mass = sample.getValue();
            if (mass == 0.0) {
                continue;
            }
            final Double previous = massByValue.get(sample.getKey());
            massByValue.put(sample.getKey(), previous == null ? mass : previous + mass);
        }

        double probability = 0;
        double x = getSupportLowerBound();
        for (final Entry<Double, Double> point : massByValue.entrySet()) {
            probability += point.getValue();
            x = point.getKey();

            if (probability >= p) {
                break;
            }
        }

        return x;
    }

    /**
     * {@inheritDoc}
     *
     * @return {@code sum(singletons[i] * probabilities[i])}
     */
    public double getNumericalMean() {
        double mean = 0;

        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
            mean += sample.getValue() * sample.getKey();
        }

        return mean;
    }

    /**
     * {@inheritDoc}
     *
     * The value is evaluated in a single pass as
     * {@code sum(singletons[i] ^ 2 * probabilities[i]) - mean ^ 2}.
     *
     * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
     */
    public double getNumericalVariance() {
        double mean = 0;
        double meanOfSquares = 0;

        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
            mean += sample.getValue() * sample.getKey();
            meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
        }

        return meanOfSquares - mean * mean;
    }

    /**
     * {@inheritDoc}
     *
     * Returns the lowest value with non-zero probability.
     *
     * @return the lowest value with non-zero probability.
     */
    public double getSupportLowerBound() {
        double min = Double.POSITIVE_INFINITY;
        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
            if (sample.getKey() < min && sample.getValue() > 0) {
                min = sample.getKey();
            }
        }

        return min;
    }

    /**
     * {@inheritDoc}
     *
     * Returns the highest value with non-zero probability.
     *
     * @return the highest value with non-zero probability.
     */
    public double getSupportUpperBound() {
        double max = Double.NEGATIVE_INFINITY;
        for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
            if (sample.getKey() > max && sample.getValue() > 0) {
                max = sample.getKey();
            }
        }

        return max;
    }

    /**
     * {@inheritDoc}
     *
     * The support of this distribution includes the lower bound.
     *
     * @return {@code true}
     */
    public boolean isSupportLowerBoundInclusive() {
        return true;
    }

    /**
     * {@inheritDoc}
     *
     * The support of this distribution includes the upper bound.
     *
     * @return {@code true}
     */
    public boolean isSupportUpperBoundInclusive() {
        return true;
    }

    /**
     * {@inheritDoc}
     *
     * This implementation always reports the support as connected.
     *
     * @return {@code true}
     */
    public boolean isSupportConnected() {
        return true;
    }

    /**
     * {@inheritDoc}
     *
     * The value is drawn by the inner {@link EnumeratedDistribution}.
     */
    @Override
    public double sample() {
        return innerDistribution.sample();
    }
}