MixtureMultivariateRealDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.distribution;

  18. import java.util.ArrayList;
  19. import java.util.List;

  20. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  21. import org.apache.commons.math4.legacy.exception.MathArithmeticException;
  22. import org.apache.commons.math4.legacy.exception.NotPositiveException;
  23. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  24. import org.apache.commons.rng.UniformRandomProvider;
  25. import org.apache.commons.math4.legacy.core.Pair;

  26. /**
  27.  * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
  28.  * mixture model</a> distributions.
  29.  *
  30.  * @param <T> Type of the mixture components.
  31.  *
  32.  * @since 3.1
  33.  */
  34. public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
  35.     extends AbstractMultivariateRealDistribution {
  36.     /** Normalized weight of each mixture component. */
  37.     private final double[] weight;
  38.     /** Mixture components. */
  39.     private final List<T> distribution;

  40.     /**
  41.      * Creates a mixture model from a list of distributions and their
  42.      * associated weights.
  43.      *
  44.      * @param components Distributions from which to sample.
  45.      * @throws NotPositiveException if any of the weights is negative.
  46.      * @throws DimensionMismatchException if not all components have the same
  47.      * number of variables.
  48.      */
  49.     public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
  50.         super(components.get(0).getSecond().getDimension());

  51.         final int numComp = components.size();
  52.         final int dim = getDimension();
  53.         double weightSum = 0;
  54.         for (int i = 0; i < numComp; i++) {
  55.             final Pair<Double, T> comp = components.get(i);
  56.             if (comp.getSecond().getDimension() != dim) {
  57.                 throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
  58.             }
  59.             if (comp.getFirst() < 0) {
  60.                 throw new NotPositiveException(comp.getFirst());
  61.             }
  62.             weightSum += comp.getFirst();
  63.         }

  64.         // Check for overflow.
  65.         if (Double.isInfinite(weightSum)) {
  66.             throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
  67.         }

  68.         // Store each distribution and its normalized weight.
  69.         distribution = new ArrayList<>();
  70.         weight = new double[numComp];
  71.         for (int i = 0; i < numComp; i++) {
  72.             final Pair<Double, T> comp = components.get(i);
  73.             weight[i] = comp.getFirst() / weightSum;
  74.             distribution.add(comp.getSecond());
  75.         }
  76.     }

  77.     /** {@inheritDoc} */
  78.     @Override
  79.     public double density(final double[] values) {
  80.         double p = 0;
  81.         for (int i = 0; i < weight.length; i++) {
  82.             p += weight[i] * distribution.get(i).density(values);
  83.         }
  84.         return p;
  85.     }

  86.     /**
  87.      * Gets the distributions that make up the mixture model.
  88.      *
  89.      * @return the component distributions and associated weights.
  90.      */
  91.     public List<Pair<Double, T>> getComponents() {
  92.         final List<Pair<Double, T>> list = new ArrayList<>(weight.length);

  93.         for (int i = 0; i < weight.length; i++) {
  94.             list.add(new Pair<>(weight[i], distribution.get(i)));
  95.         }

  96.         return list;
  97.     }

  98.     /** {@inheritDoc} */
  99.     @Override
  100.     public MultivariateRealDistribution.Sampler createSampler(UniformRandomProvider rng) {
  101.         return new MixtureSampler(rng);
  102.     }

  103.     /**
  104.      * Sampler.
  105.      */
  106.     private final class MixtureSampler implements MultivariateRealDistribution.Sampler {
  107.         /** RNG. */
  108.         private final UniformRandomProvider rng;
  109.         /** Sampler for each of the distribution in the mixture. */
  110.         private final MultivariateRealDistribution.Sampler[] samplers;

  111.         /**
  112.          * @param generator RNG.
  113.          */
  114.         MixtureSampler(UniformRandomProvider generator) {
  115.             rng = generator;

  116.             samplers = new MultivariateRealDistribution.Sampler[weight.length];
  117.             for (int i = 0; i < weight.length; i++) {
  118.                 samplers[i] = distribution.get(i).createSampler(rng);
  119.             }
  120.         }

  121.         /** {@inheritDoc} */
  122.         @Override
  123.         public double[] sample() {
  124.             // Sampled values.
  125.             double[] vals = null;

  126.             // Determine which component to sample from.
  127.             final double randomValue = rng.nextDouble();
  128.             double sum = 0;

  129.             for (int i = 0; i < weight.length; i++) {
  130.                 sum += weight[i];
  131.                 if (randomValue <= sum) {
  132.                     // pick model i
  133.                     vals = samplers[i].sample();
  134.                     break;
  135.                 }
  136.             }

  137.             if (vals == null) {
  138.                 // This should never happen, but it ensures we won't return a null in
  139.                 // case the loop above has some floating point inequality problem on
  140.                 // the final iteration.
  141.                 vals = samplers[weight.length - 1].sample();
  142.             }

  143.             return vals;
  144.         }
  145.     }
  146. }