001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.distribution; 018 019import java.util.ArrayList; 020import java.util.List; 021 022import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 023import org.apache.commons.math4.legacy.exception.MathArithmeticException; 024import org.apache.commons.math4.legacy.exception.NotPositiveException; 025import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; 026import org.apache.commons.rng.UniformRandomProvider; 027import org.apache.commons.math4.legacy.core.Pair; 028 029/** 030 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model"> 031 * mixture model</a> distributions. 032 * 033 * @param <T> Type of the mixture components. 034 * 035 * @since 3.1 036 */ 037public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution> 038 extends AbstractMultivariateRealDistribution { 039 /** Normalized weight of each mixture component. */ 040 private final double[] weight; 041 /** Mixture components. */ 042 private final List<T> distribution; 043 044 /** 045 * Creates a mixture model from a list of distributions and their 046 * associated weights. 047 * 048 * @param components Distributions from which to sample. 049 * @throws NotPositiveException if any of the weights is negative. 050 * @throws DimensionMismatchException if not all components have the same 051 * number of variables. 052 */ 053 public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) { 054 super(components.get(0).getSecond().getDimension()); 055 056 final int numComp = components.size(); 057 final int dim = getDimension(); 058 double weightSum = 0; 059 for (int i = 0; i < numComp; i++) { 060 final Pair<Double, T> comp = components.get(i); 061 if (comp.getSecond().getDimension() != dim) { 062 throw new DimensionMismatchException(comp.getSecond().getDimension(), dim); 063 } 064 if (comp.getFirst() < 0) { 065 throw new NotPositiveException(comp.getFirst()); 066 } 067 weightSum += comp.getFirst(); 068 } 069 070 // Check for overflow. 071 if (Double.isInfinite(weightSum)) { 072 throw new MathArithmeticException(LocalizedFormats.OVERFLOW); 073 } 074 075 // Store each distribution and its normalized weight. 076 distribution = new ArrayList<>(); 077 weight = new double[numComp]; 078 for (int i = 0; i < numComp; i++) { 079 final Pair<Double, T> comp = components.get(i); 080 weight[i] = comp.getFirst() / weightSum; 081 distribution.add(comp.getSecond()); 082 } 083 } 084 085 /** {@inheritDoc} */ 086 @Override 087 public double density(final double[] values) { 088 double p = 0; 089 for (int i = 0; i < weight.length; i++) { 090 p += weight[i] * distribution.get(i).density(values); 091 } 092 return p; 093 } 094 095 /** 096 * Gets the distributions that make up the mixture model. 097 * 098 * @return the component distributions and associated weights. 099 */ 100 public List<Pair<Double, T>> getComponents() { 101 final List<Pair<Double, T>> list = new ArrayList<>(weight.length); 102 103 for (int i = 0; i < weight.length; i++) { 104 list.add(new Pair<>(weight[i], distribution.get(i))); 105 } 106 107 return list; 108 } 109 110 /** {@inheritDoc} */ 111 @Override 112 public MultivariateRealDistribution.Sampler createSampler(UniformRandomProvider rng) { 113 return new MixtureSampler(rng); 114 } 115 116 /** 117 * Sampler. 118 */ 119 private final class MixtureSampler implements MultivariateRealDistribution.Sampler { 120 /** RNG. */ 121 private final UniformRandomProvider rng; 122 /** Sampler for each of the distribution in the mixture. */ 123 private final MultivariateRealDistribution.Sampler[] samplers; 124 125 /** 126 * @param generator RNG. 127 */ 128 MixtureSampler(UniformRandomProvider generator) { 129 rng = generator; 130 131 samplers = new MultivariateRealDistribution.Sampler[weight.length]; 132 for (int i = 0; i < weight.length; i++) { 133 samplers[i] = distribution.get(i).createSampler(rng); 134 } 135 } 136 137 /** {@inheritDoc} */ 138 @Override 139 public double[] sample() { 140 // Sampled values. 141 double[] vals = null; 142 143 // Determine which component to sample from. 144 final double randomValue = rng.nextDouble(); 145 double sum = 0; 146 147 for (int i = 0; i < weight.length; i++) { 148 sum += weight[i]; 149 if (randomValue <= sum) { 150 // pick model i 151 vals = samplers[i].sample(); 152 break; 153 } 154 } 155 156 if (vals == null) { 157 // This should never happen, but it ensures we won't return a null in 158 // case the loop above has some floating point inequality problem on 159 // the final iteration. 160 vals = samplers[weight.length - 1].sample(); 161 } 162 163 return vals; 164 } 165 } 166}