001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.distribution;
018
019import java.util.ArrayList;
020import java.util.List;
021
022import org.apache.commons.math3.exception.DimensionMismatchException;
023import org.apache.commons.math3.exception.MathArithmeticException;
024import org.apache.commons.math3.exception.NotPositiveException;
025import org.apache.commons.math3.exception.util.LocalizedFormats;
026import org.apache.commons.math3.random.RandomGenerator;
027import org.apache.commons.math3.random.Well19937c;
028import org.apache.commons.math3.util.Pair;
029
030/**
031 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
032 * mixture model</a> distributions.
033 *
034 * @param <T> Type of the mixture components.
035 *
036 * @since 3.1
037 */
038public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
039    extends AbstractMultivariateRealDistribution {
040    /** Normalized weight of each mixture component. */
041    private final double[] weight;
042    /** Mixture components. */
043    private final List<T> distribution;
044
045    /**
046     * Creates a mixture model from a list of distributions and their
047     * associated weights.
048     * <p>
049     * <b>Note:</b> this constructor will implicitly create an instance of
050     * {@link Well19937c} as random generator to be used for sampling only (see
051     * {@link #sample()} and {@link #sample(int)}). In case no sampling is
052     * needed for the created distribution, it is advised to pass {@code null}
053     * as random generator via the appropriate constructors to avoid the
054     * additional initialisation overhead.
055     *
056     * @param components List of (weight, distribution) pairs from which to sample.
057     */
058    public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
059        this(new Well19937c(), components);
060    }
061
062    /**
063     * Creates a mixture model from a list of distributions and their
064     * associated weights.
065     *
066     * @param rng Random number generator.
067     * @param components Distributions from which to sample.
068     * @throws NotPositiveException if any of the weights is negative.
069     * @throws DimensionMismatchException if not all components have the same
070     * number of variables.
071     */
072    public MixtureMultivariateRealDistribution(RandomGenerator rng,
073                                               List<Pair<Double, T>> components) {
074        super(rng, components.get(0).getSecond().getDimension());
075
076        final int numComp = components.size();
077        final int dim = getDimension();
078        double weightSum = 0;
079        for (int i = 0; i < numComp; i++) {
080            final Pair<Double, T> comp = components.get(i);
081            if (comp.getSecond().getDimension() != dim) {
082                throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
083            }
084            if (comp.getFirst() < 0) {
085                throw new NotPositiveException(comp.getFirst());
086            }
087            weightSum += comp.getFirst();
088        }
089
090        // Check for overflow.
091        if (Double.isInfinite(weightSum)) {
092            throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
093        }
094
095        // Store each distribution and its normalized weight.
096        distribution = new ArrayList<T>();
097        weight = new double[numComp];
098        for (int i = 0; i < numComp; i++) {
099            final Pair<Double, T> comp = components.get(i);
100            weight[i] = comp.getFirst() / weightSum;
101            distribution.add(comp.getSecond());
102        }
103    }
104
105    /** {@inheritDoc} */
106    public double density(final double[] values) {
107        double p = 0;
108        for (int i = 0; i < weight.length; i++) {
109            p += weight[i] * distribution.get(i).density(values);
110        }
111        return p;
112    }
113
114    /** {@inheritDoc} */
115    @Override
116    public double[] sample() {
117        // Sampled values.
118        double[] vals = null;
119
120        // Determine which component to sample from.
121        final double randomValue = random.nextDouble();
122        double sum = 0;
123
124        for (int i = 0; i < weight.length; i++) {
125            sum += weight[i];
126            if (randomValue <= sum) {
127                // pick model i
128                vals = distribution.get(i).sample();
129                break;
130            }
131        }
132
133        if (vals == null) {
134            // This should never happen, but it ensures we won't return a null in
135            // case the loop above has some floating point inequality problem on
136            // the final iteration.
137            vals = distribution.get(weight.length - 1).sample();
138        }
139
140        return vals;
141    }
142
143    /** {@inheritDoc} */
144    @Override
145    public void reseedRandomGenerator(long seed) {
146        // Seed needs to be propagated to underlying components
147        // in order to maintain consistency between runs.
148        super.reseedRandomGenerator(seed);
149
150        for (int i = 0; i < distribution.size(); i++) {
151            // Make each component's seed different in order to avoid
152            // using the same sequence of random numbers.
153            distribution.get(i).reseedRandomGenerator(i + 1 + seed);
154        }
155    }
156
157    /**
158     * Gets the distributions that make up the mixture model.
159     *
160     * @return the component distributions and associated weights.
161     */
162    public List<Pair<Double, T>> getComponents() {
163        final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length);
164
165        for (int i = 0; i < weight.length; i++) {
166            list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
167        }
168
169        return list;
170    }
171}