001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.rng.sampling; 019 020import java.util.List; 021import java.util.Map; 022import java.util.ArrayList; 023 024import org.apache.commons.rng.UniformRandomProvider; 025import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler; 026import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler; 027 028/** 029 * Sampling from a collection of items with user-defined 030 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> 031 * probabilities</a>. 032 * Note that if all unique items are assigned the same probability, 033 * it is much more efficient to use {@link CollectionSampler}. 034 * 035 * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p> 036 * 037 * @param <T> Type of items in the collection. 038 * 039 * @since 1.1 040 */ 041public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjectSampler<T> { 042 /** The error message for an empty collection. */ 043 private static final String EMPTY_COLLECTION = "Empty collection"; 044 /** Collection to be sampled from. */ 045 private final List<T> items; 046 /** Sampler for the probabilities. */ 047 private final SharedStateDiscreteSampler sampler; 048 049 /** 050 * Creates a sampler. 051 * 052 * @param rng Generator of uniformly distributed random numbers. 053 * @param collection Collection to be sampled, with the probabilities 054 * associated to each of its items. 055 * A (shallow) copy of the items will be stored in the created instance. 056 * The probabilities must be non-negative, but zero values are allowed 057 * and their sum does not have to equal one (input will be normalized 058 * to make the probabilities sum to one). 059 * @throws IllegalArgumentException if {@code collection} is empty, a 060 * probability is negative, infinite or {@code NaN}, or the sum of all 061 * probabilities is not strictly positive. 062 */ 063 public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, 064 Map<T, Double> collection) { 065 if (collection.isEmpty()) { 066 throw new IllegalArgumentException(EMPTY_COLLECTION); 067 } 068 069 // Extract the items and probabilities 070 final int size = collection.size(); 071 items = new ArrayList<>(size); 072 final double[] probabilities = new double[size]; 073 074 int count = 0; 075 for (final Map.Entry<T, Double> e : collection.entrySet()) { 076 items.add(e.getKey()); 077 probabilities[count++] = e.getValue(); 078 } 079 080 // Delegate sampling 081 sampler = createSampler(rng, probabilities); 082 } 083 084 /** 085 * Creates a sampler. 086 * 087 * @param rng Generator of uniformly distributed random numbers. 088 * @param collection Collection to be sampled. 089 * A (shallow) copy of the items will be stored in the created instance. 090 * @param probabilities Probability associated to each item of the 091 * {@code collection}. 092 * The probabilities must be non-negative, but zero values are allowed 093 * and their sum does not have to equal one (input will be normalized 094 * to make the probabilities sum to one). 095 * @throws IllegalArgumentException if {@code collection} is empty or 096 * a probability is negative, infinite or {@code NaN}, or if the number 097 * of items in the {@code collection} is not equal to the number of 098 * provided {@code probabilities}. 099 */ 100 public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, 101 List<T> collection, 102 double[] probabilities) { 103 if (collection.isEmpty()) { 104 throw new IllegalArgumentException(EMPTY_COLLECTION); 105 } 106 final int len = probabilities.length; 107 if (len != collection.size()) { 108 throw new IllegalArgumentException("Size mismatch: " + 109 len + " != " + 110 collection.size()); 111 } 112 // Shallow copy the list 113 items = new ArrayList<>(collection); 114 // Delegate sampling 115 sampler = createSampler(rng, probabilities); 116 } 117 118 /** 119 * @param rng Generator of uniformly distributed random numbers. 120 * @param source Source to copy. 121 */ 122 private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, 123 DiscreteProbabilityCollectionSampler<T> source) { 124 this.items = source.items; 125 this.sampler = source.sampler.withUniformRandomProvider(rng); 126 } 127 128 /** 129 * Picks one of the items from the collection passed to the constructor. 130 * 131 * @return a random sample. 132 */ 133 @Override 134 public T sample() { 135 return items.get(sampler.sample()); 136 } 137 138 /** 139 * {@inheritDoc} 140 * 141 * @since 1.3 142 */ 143 @Override 144 public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) { 145 return new DiscreteProbabilityCollectionSampler<>(rng, this); 146 } 147 148 /** 149 * Creates the sampler of the enumerated probability distribution. 150 * 151 * @param rng Generator of uniformly distributed random numbers. 152 * @param probabilities Probability associated to each item. 153 * @return the sampler 154 */ 155 private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng, 156 double[] probabilities) { 157 return GuideTableDiscreteSampler.of(rng, probabilities); 158 } 159}