/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.legacy.distribution;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.NotPositiveException;
import org.apache.commons.math4.legacy.core.Pair;

/**
 * Multivariate normal mixture distribution.
 * This class is mainly syntactic sugar: it fixes the component type of
 * {@link MixtureMultivariateRealDistribution} to
 * {@link MultivariateNormalDistribution}, and it adds a convenience
 * constructor that builds the components from plain weight, mean and
 * covariance arrays.
 *
 * @see MixtureMultivariateRealDistribution
 * @since 3.2
 */
public class MixtureMultivariateNormalDistribution
    extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {

    /**
     * Creates a mixture model from a list of distributions and their
     * associated weights.
     *
     * @param components Distributions from which to sample. Each pair holds
     * a component's weight (first) and its distribution (second).
     * @throws NotPositiveException if any of the weights is negative.
     * @throws DimensionMismatchException if not all components have the same
     * number of variables.
     */
    public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components)
        throws NotPositiveException,
               DimensionMismatchException {
        super(components);
    }

    /**
     * Creates a multivariate normal mixture distribution.
     * Component {@code i} uses weight {@code weights[i]}, mean vector
     * {@code means[i]} and covariance matrix {@code covariances[i]}.
     *
     * @param weights Weights of each component.
     * @param means Mean vector for each component.
     * @param covariances Covariance matrix for each component.
     * @throws NotPositiveException if any of the weights is negative.
     * @throws DimensionMismatchException if {@code means} or
     * {@code covariances} does not have the same length as {@code weights},
     * or if not all components have the same number of variables.
     */
    public MixtureMultivariateNormalDistribution(double[] weights,
                                                 double[][] means,
                                                 double[][][] covariances)
        throws NotPositiveException,
               DimensionMismatchException {
        this(createComponents(weights, means, covariances));
    }

    /**
     * Creates the components of the mixture model by pairing each weight
     * with a normal distribution built from the matching mean vector and
     * covariance matrix.
     *
     * @param weights Weights of each component.
     * @param means Mean vector for each component.
     * @param covariances Covariance matrix for each component.
     * @return the list of components, in the same order as {@code weights}.
     * @throws DimensionMismatchException if {@code means} or
     * {@code covariances} does not have the same length as {@code weights}.
     */
    private static List<Pair<Double, MultivariateNormalDistribution>> createComponents(double[] weights,
                                                                                       double[][] means,
                                                                                       double[][][] covariances) {
        final int numComponents = weights.length;

        // Validate up front. Without these checks, a shorter "means" or
        // "covariances" array surfaces as an ArrayIndexOutOfBoundsException.
        // A longer one has its extra entries silently ignored.
        if (means.length != numComponents) {
            throw new DimensionMismatchException(means.length, numComponents);
        }
        if (covariances.length != numComponents) {
            throw new DimensionMismatchException(covariances.length, numComponents);
        }

        final List<Pair<Double, MultivariateNormalDistribution>> mvns
            = new ArrayList<>(numComponents);

        for (int i = 0; i < numComponents; i++) {
            final MultivariateNormalDistribution dist
                = new MultivariateNormalDistribution(means[i], covariances[i]);

            mvns.add(new Pair<>(weights[i], dist));
        }

        return mvns;
    }
}