001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.distribution; 018 019import java.util.ArrayList; 020import java.util.List; 021 022import org.apache.commons.math3.exception.DimensionMismatchException; 023import org.apache.commons.math3.exception.NotPositiveException; 024import org.apache.commons.math3.random.RandomGenerator; 025import org.apache.commons.math3.util.Pair; 026 027/** 028 * Multivariate normal mixture distribution. 029 * This class is mainly syntactic sugar. 030 * 031 * @see MixtureMultivariateRealDistribution 032 * @since 3.2 033 */ 034public class MixtureMultivariateNormalDistribution 035 extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> { 036 037 /** 038 * Creates a multivariate normal mixture distribution. 039 * <p> 040 * <b>Note:</b> this constructor will implicitly create an instance of 041 * {@link org.apache.commons.math3.random.Well19937c Well19937c} as random 042 * generator to be used for sampling only (see {@link #sample()} and 043 * {@link #sample(int)}). In case no sampling is needed for the created 044 * distribution, it is advised to pass {@code null} as random generator via 045 * the appropriate constructors to avoid the additional initialisation 046 * overhead. 047 * 048 * @param weights Weights of each component. 049 * @param means Mean vector for each component. 050 * @param covariances Covariance matrix for each component. 051 */ 052 public MixtureMultivariateNormalDistribution(double[] weights, 053 double[][] means, 054 double[][][] covariances) { 055 super(createComponents(weights, means, covariances)); 056 } 057 058 /** 059 * Creates a mixture model from a list of distributions and their 060 * associated weights. 061 * <p> 062 * <b>Note:</b> this constructor will implicitly create an instance of 063 * {@link org.apache.commons.math3.random.Well19937c Well19937c} as random 064 * generator to be used for sampling only (see {@link #sample()} and 065 * {@link #sample(int)}). In case no sampling is needed for the created 066 * distribution, it is advised to pass {@code null} as random generator via 067 * the appropriate constructors to avoid the additional initialisation 068 * overhead. 069 * 070 * @param components List of (weight, distribution) pairs from which to sample. 071 */ 072 public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) { 073 super(components); 074 } 075 076 /** 077 * Creates a mixture model from a list of distributions and their 078 * associated weights. 079 * 080 * @param rng Random number generator. 081 * @param components Distributions from which to sample. 082 * @throws NotPositiveException if any of the weights is negative. 083 * @throws DimensionMismatchException if not all components have the same 084 * number of variables. 085 */ 086 public MixtureMultivariateNormalDistribution(RandomGenerator rng, 087 List<Pair<Double, MultivariateNormalDistribution>> components) 088 throws NotPositiveException, DimensionMismatchException { 089 super(rng, components); 090 } 091 092 /** 093 * @param weights Weights of each component. 094 * @param means Mean vector for each component. 095 * @param covariances Covariance matrix for each component. 096 * @return the list of components. 097 */ 098 private static List<Pair<Double, MultivariateNormalDistribution>> createComponents(double[] weights, 099 double[][] means, 100 double[][][] covariances) { 101 final List<Pair<Double, MultivariateNormalDistribution>> mvns 102 = new ArrayList<Pair<Double, MultivariateNormalDistribution>>(weights.length); 103 104 for (int i = 0; i < weights.length; i++) { 105 final MultivariateNormalDistribution dist 106 = new MultivariateNormalDistribution(means[i], covariances[i]); 107 108 mvns.add(new Pair<Double, MultivariateNormalDistribution>(weights[i], dist)); 109 } 110 111 return mvns; 112 } 113}