001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.distribution; 018 019import java.util.ArrayList; 020import java.util.LinkedHashMap; 021import java.util.List; 022import java.util.Map; 023import java.util.Map.Entry; 024 025import org.apache.commons.statistics.distribution.DiscreteDistribution; 026import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 027import org.apache.commons.math4.legacy.exception.MathArithmeticException; 028import org.apache.commons.math4.legacy.exception.NotANumberException; 029import org.apache.commons.math4.legacy.exception.NotFiniteNumberException; 030import org.apache.commons.math4.legacy.exception.NotPositiveException; 031import org.apache.commons.rng.UniformRandomProvider; 032import org.apache.commons.math4.legacy.core.Pair; 033 034/** 035 * <p>Implementation of an integer-valued {@link EnumeratedDistribution}.</p> 036 * 037 * <p>Values with zero-probability are allowed but they do not extend the 038 * support.<br> 039 * Duplicate values are allowed. Probabilities of duplicate values are combined 040 * when computing cumulative probabilities and statistics.</p> 041 * 042 * @since 3.2 043 */ 044public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution { 045 /** 046 * {@link EnumeratedDistribution} instance (using the {@link Integer} wrapper) 047 * used to generate the pmf. 048 */ 049 protected final EnumeratedDistribution<Integer> innerDistribution; 050 051 /** 052 * Create a discrete distribution. 053 * 054 * @param singletons array of random variable values. 055 * @param probabilities array of probabilities. 056 * @throws DimensionMismatchException if 057 * {@code singletons.length != probabilities.length} 058 * @throws NotPositiveException if any of the probabilities are negative. 059 * @throws NotFiniteNumberException if any of the probabilities are infinite. 060 * @throws NotANumberException if any of the probabilities are NaN. 061 * @throws MathArithmeticException all of the probabilities are 0. 062 */ 063 public EnumeratedIntegerDistribution(final int[] singletons, 064 final double[] probabilities) 065 throws DimensionMismatchException, 066 NotPositiveException, 067 MathArithmeticException, 068 NotFiniteNumberException, 069 NotANumberException { 070 innerDistribution = new EnumeratedDistribution<>(createDistribution(singletons, 071 probabilities)); 072 } 073 074 /** 075 * Create a discrete integer-valued distribution from the input data. 076 * Values are assigned mass based on their frequency. 077 * 078 * @param data input dataset 079 */ 080 public EnumeratedIntegerDistribution(final int[] data) { 081 final Map<Integer, Integer> dataMap = new LinkedHashMap<>(); 082 for (int value : data) { 083 dataMap.merge(value, 1, Integer::sum); 084 } 085 final int massPoints = dataMap.size(); 086 final double denom = data.length; 087 final int[] values = new int[massPoints]; 088 final double[] probabilities = new double[massPoints]; 089 int index = 0; 090 for (Entry<Integer, Integer> entry : dataMap.entrySet()) { 091 values[index] = entry.getKey(); 092 probabilities[index] = entry.getValue().intValue() / denom; 093 index++; 094 } 095 innerDistribution = new EnumeratedDistribution<>(createDistribution(values, probabilities)); 096 } 097 098 /** 099 * Create the list of Pairs representing the distribution from singletons and probabilities. 100 * 101 * @param singletons values 102 * @param probabilities probabilities 103 * @return list of value/probability pairs 104 */ 105 private static List<Pair<Integer, Double>> createDistribution(int[] singletons, double[] probabilities) { 106 if (singletons.length != probabilities.length) { 107 throw new DimensionMismatchException(probabilities.length, singletons.length); 108 } 109 110 final List<Pair<Integer, Double>> samples = new ArrayList<>(singletons.length); 111 112 for (int i = 0; i < singletons.length; i++) { 113 samples.add(new Pair<>(singletons[i], probabilities[i])); 114 } 115 return samples; 116 } 117 118 /** 119 * {@inheritDoc} 120 */ 121 @Override 122 public double probability(final int x) { 123 return innerDistribution.probability(x); 124 } 125 126 /** 127 * {@inheritDoc} 128 */ 129 @Override 130 public double cumulativeProbability(final int x) { 131 double probability = 0; 132 133 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) { 134 if (sample.getKey() <= x) { 135 probability += sample.getValue(); 136 } 137 } 138 139 return probability; 140 } 141 142 /** 143 * {@inheritDoc} 144 * 145 * @return {@code sum(singletons[i] * probabilities[i])} 146 */ 147 @Override 148 public double getMean() { 149 double mean = 0; 150 151 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) { 152 mean += sample.getValue() * sample.getKey(); 153 } 154 155 return mean; 156 } 157 158 /** 159 * {@inheritDoc} 160 * 161 * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])} 162 */ 163 @Override 164 public double getVariance() { 165 double mean = 0; 166 double meanOfSquares = 0; 167 168 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) { 169 mean += sample.getValue() * sample.getKey(); 170 meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey(); 171 } 172 173 return meanOfSquares - mean * mean; 174 } 175 176 /** 177 * {@inheritDoc} 178 * 179 * Returns the lowest value with non-zero probability. 180 * 181 * @return the lowest value with non-zero probability. 182 */ 183 @Override 184 public int getSupportLowerBound() { 185 int min = Integer.MAX_VALUE; 186 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) { 187 if (sample.getKey() < min && sample.getValue() > 0) { 188 min = sample.getKey(); 189 } 190 } 191 192 return min; 193 } 194 195 /** 196 * {@inheritDoc} 197 * 198 * Returns the highest value with non-zero probability. 199 * 200 * @return the highest value with non-zero probability. 201 */ 202 @Override 203 public int getSupportUpperBound() { 204 int max = Integer.MIN_VALUE; 205 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) { 206 if (sample.getKey() > max && sample.getValue() > 0) { 207 max = sample.getKey(); 208 } 209 } 210 211 return max; 212 } 213 214 /** 215 * {@inheritDoc} 216 * 217 * Refer to {@link EnumeratedDistribution.Sampler} for implementation details. 218 */ 219 @Override 220 public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) { 221 return innerDistribution.createSampler(rng)::sample; 222 } 223}