001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.distribution;
018
019import java.util.ArrayList;
020import java.util.List;
021
022import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
023import org.apache.commons.math4.legacy.exception.MathArithmeticException;
024import org.apache.commons.math4.legacy.exception.NotPositiveException;
025import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
026import org.apache.commons.rng.UniformRandomProvider;
027import org.apache.commons.math4.legacy.core.Pair;
028
029/**
030 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
031 * mixture model</a> distributions.
032 *
033 * @param <T> Type of the mixture components.
034 *
035 * @since 3.1
036 */
037public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
038    extends AbstractMultivariateRealDistribution {
039    /** Normalized weight of each mixture component. */
040    private final double[] weight;
041    /** Mixture components. */
042    private final List<T> distribution;
043
044    /**
045     * Creates a mixture model from a list of distributions and their
046     * associated weights.
047     *
048     * @param components Distributions from which to sample.
049     * @throws NotPositiveException if any of the weights is negative.
050     * @throws DimensionMismatchException if not all components have the same
051     * number of variables.
052     */
053    public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
054        super(components.get(0).getSecond().getDimension());
055
056        final int numComp = components.size();
057        final int dim = getDimension();
058        double weightSum = 0;
059        for (int i = 0; i < numComp; i++) {
060            final Pair<Double, T> comp = components.get(i);
061            if (comp.getSecond().getDimension() != dim) {
062                throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
063            }
064            if (comp.getFirst() < 0) {
065                throw new NotPositiveException(comp.getFirst());
066            }
067            weightSum += comp.getFirst();
068        }
069
070        // Check for overflow.
071        if (Double.isInfinite(weightSum)) {
072            throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
073        }
074
075        // Store each distribution and its normalized weight.
076        distribution = new ArrayList<>();
077        weight = new double[numComp];
078        for (int i = 0; i < numComp; i++) {
079            final Pair<Double, T> comp = components.get(i);
080            weight[i] = comp.getFirst() / weightSum;
081            distribution.add(comp.getSecond());
082        }
083    }
084
085    /** {@inheritDoc} */
086    @Override
087    public double density(final double[] values) {
088        double p = 0;
089        for (int i = 0; i < weight.length; i++) {
090            p += weight[i] * distribution.get(i).density(values);
091        }
092        return p;
093    }
094
095    /**
096     * Gets the distributions that make up the mixture model.
097     *
098     * @return the component distributions and associated weights.
099     */
100    public List<Pair<Double, T>> getComponents() {
101        final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
102
103        for (int i = 0; i < weight.length; i++) {
104            list.add(new Pair<>(weight[i], distribution.get(i)));
105        }
106
107        return list;
108    }
109
110    /** {@inheritDoc} */
111    @Override
112    public MultivariateRealDistribution.Sampler createSampler(UniformRandomProvider rng) {
113        return new MixtureSampler(rng);
114    }
115
116    /**
117     * Sampler.
118     */
119    private class MixtureSampler implements MultivariateRealDistribution.Sampler {
120        /** RNG. */
121        private final UniformRandomProvider rng;
122        /** Sampler for each of the distribution in the mixture. */
123        private final MultivariateRealDistribution.Sampler[] samplers;
124
125        /**
126         * @param generator RNG.
127         */
128        MixtureSampler(UniformRandomProvider generator) {
129            rng = generator;
130
131            samplers = new MultivariateRealDistribution.Sampler[weight.length];
132            for (int i = 0; i < weight.length; i++) {
133                samplers[i] = distribution.get(i).createSampler(rng);
134            }
135        }
136
137        /** {@inheritDoc} */
138        @Override
139        public double[] sample() {
140            // Sampled values.
141            double[] vals = null;
142
143            // Determine which component to sample from.
144            final double randomValue = rng.nextDouble();
145            double sum = 0;
146
147            for (int i = 0; i < weight.length; i++) {
148                sum += weight[i];
149                if (randomValue <= sum) {
150                    // pick model i
151                    vals = samplers[i].sample();
152                    break;
153                }
154            }
155
156            if (vals == null) {
157                // This should never happen, but it ensures we won't return a null in
158                // case the loop above has some floating point inequality problem on
159                // the final iteration.
160                vals = samplers[weight.length - 1].sample();
161            }
162
163            return vals;
164        }
165    }
166}