/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.distribution;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.NoDataException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;

/**
 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
 * mixture model</a> distributions.
 *
 * @param <T> Type of the mixture components.
 *
 * @since 3.1
 */
public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
    extends AbstractMultivariateRealDistribution {
    /** Normalized weight of each mixture component (non-negative, summing to 1). */
    private final double[] weight;
    /** Mixture components, in the same order as {@link #weight}. */
    private final List<T> distribution;

    /**
     * Creates a mixture model from a list of distributions and their
     * associated weights.
     * <p>
     * <b>Note:</b> this constructor will implicitly create an instance of
     * {@link Well19937c} as random generator to be used for sampling only (see
     * {@link #sample()} and {@link #sample(int)}). In case no sampling is
     * needed for the created distribution, it is advised to pass {@code null}
     * as random generator via the appropriate constructors to avoid the
     * additional initialisation overhead.
     *
     * @param components List of (weight, distribution) pairs from which to sample.
     * @throws NoDataException if {@code components} is empty.
     * @throws NotPositiveException if any of the weights is negative or NaN.
     * @throws DimensionMismatchException if not all components have the same
     * number of variables.
     * @throws MathArithmeticException if the weights sum to zero or their sum
     * overflows to infinity.
     */
    public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
        this(new Well19937c(), components);
    }

    /**
     * Creates a mixture model from a list of distributions and their
     * associated weights.
     * <p>
     * Weights are normalized so that they sum to 1; they need not be
     * normalized on input.
     *
     * @param rng Random number generator.
     * @param components List of (weight, distribution) pairs from which to sample.
     * @throws NoDataException if {@code components} is empty.
     * @throws NotPositiveException if any of the weights is negative or NaN.
     * @throws DimensionMismatchException if not all components have the same
     * number of variables.
     * @throws MathArithmeticException if the weights sum to zero or their sum
     * overflows to infinity.
     */
    public MixtureMultivariateRealDistribution(RandomGenerator rng,
                                               List<Pair<Double, T>> components) {
        // The dimension must be known before super() runs, so the emptiness
        // check lives in a static helper rather than in this constructor body.
        super(rng, firstComponentDimension(components));

        final int numComp = components.size();
        final int dim = getDimension();
        double weightSum = 0;
        for (int i = 0; i < numComp; i++) {
            final Pair<Double, T> comp = components.get(i);
            if (comp.getSecond().getDimension() != dim) {
                throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
            }
            final double w = comp.getFirst();
            // "!(w >= 0)" also rejects NaN, which a plain "w < 0" test would
            // accept and which would then turn every normalized weight into NaN.
            if (!(w >= 0)) {
                throw new NotPositiveException(w);
            }
            weightSum += w;
        }

        // Check for overflow.
        if (Double.isInfinite(weightSum)) {
            throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
        }
        // All weights are zero: normalization below would compute 0 / 0 = NaN.
        if (weightSum == 0) {
            throw new MathArithmeticException(LocalizedFormats.ARRAY_SUMS_TO_ZERO);
        }

        // Store each distribution and its normalized weight.
        distribution = new ArrayList<T>(numComp);
        weight = new double[numComp];
        for (int i = 0; i < numComp; i++) {
            final Pair<Double, T> comp = components.get(i);
            weight[i] = comp.getFirst() / weightSum;
            distribution.add(comp.getSecond());
        }
    }

    /**
     * Returns the dimension of the first mixture component.
     *
     * @param <S> Type of the mixture components.
     * @param components List of (weight, distribution) pairs.
     * @return the number of variables of the first component.
     * @throws NoDataException if {@code components} is empty.
     */
    private static <S extends MultivariateRealDistribution> int firstComponentDimension(
            List<Pair<Double, S>> components) {
        if (components.isEmpty()) {
            throw new NoDataException();
        }
        return components.get(0).getSecond().getDimension();
    }

    /**
     * {@inheritDoc}
     * <p>
     * Computed as the weight-averaged density of the components.
     */
    public double density(final double[] values) {
        double p = 0;
        for (int i = 0; i < weight.length; i++) {
            // Zero-weight components contribute nothing. Skipping them also
            // avoids 0 * Infinity = NaN when such a component's density diverges.
            if (weight[i] != 0) {
                p += weight[i] * distribution.get(i).density(values);
            }
        }
        return p;
    }

    /**
     * {@inheritDoc}
     * <p>
     * A component is chosen with probability equal to its normalized weight,
     * and a single sample is then drawn from that component.
     */
    @Override
    public double[] sample() {
        // Determine which component to sample from.
        final double randomValue = random.nextDouble();
        double sum = 0;
        int lastPositive = -1;

        for (int i = 0; i < weight.length; i++) {
            if (weight[i] == 0) {
                // A zero-weight component must never be picked. With "<=" it
                // could be, e.g. when randomValue == 0 and weight[0] == 0.
                continue;
            }
            sum += weight[i];
            lastPositive = i;
            if (randomValue < sum) {
                // pick model i
                return distribution.get(i).sample();
            }
        }

        // Only reached if rounding left the cumulative sum slightly below
        // randomValue. Fall back to the last component with a positive weight.
        // The constructor guarantees that at least one such component exists,
        // because the largest normalized weight is at least 1 / numComp.
        return distribution.get(lastPositive).sample();
    }

    /**
     * {@inheritDoc}
     * <p>
     * The seed is also propagated to every component, using a distinct
     * derived seed per component.
     */
    @Override
    public void reseedRandomGenerator(long seed) {
        // Seed needs to be propagated to underlying components
        // in order to maintain consistency between runs.
        super.reseedRandomGenerator(seed);

        for (int i = 0; i < distribution.size(); i++) {
            // Make each component's seed different in order to avoid
            // using the same sequence of random numbers.
            distribution.get(i).reseedRandomGenerator(i + 1 + seed);
        }
    }

    /**
     * Gets the distributions that make up the mixture model.
     * <p>
     * The returned list is a new list, so modifying it does not affect this
     * distribution. The weights it contains are the normalized weights.
     *
     * @return the component distributions and associated weights.
     */
    public List<Pair<Double, T>> getComponents() {
        final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length);

        for (int i = 0; i < weight.length; i++) {
            list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
        }

        return list;
    }
}