1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.distribution;
18
19 import java.util.ArrayList;
20 import java.util.List;
21
22 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23 import org.apache.commons.math4.legacy.exception.NotPositiveException;
24 import org.apache.commons.math4.legacy.core.Pair;
25
26 /**
27 * Multivariate normal mixture distribution.
28 * This class is mainly syntactic sugar.
29 *
30 * @see MixtureMultivariateRealDistribution
31 * @since 3.2
32 */
33 public class MixtureMultivariateNormalDistribution
34 extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
35 /**
36 * Creates a mixture model from a list of distributions and their
37 * associated weights.
38 *
39 * @param components Distributions from which to sample.
40 * @throws NotPositiveException if any of the weights is negative.
41 * @throws DimensionMismatchException if not all components have the same
42 * number of variables.
43 */
44 public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components)
45 throws NotPositiveException,
46 DimensionMismatchException {
47 super(components);
48 }
49
50 /**
51 * Creates a multivariate normal mixture distribution.
52 *
53 * @param weights Weights of each component.
54 * @param means Mean vector for each component.
55 * @param covariances Covariance matrix for each component.
56 * @throws NotPositiveException if any of the weights is negative.
57 * @throws DimensionMismatchException if not all components have the same
58 * number of variables.
59 */
60 public MixtureMultivariateNormalDistribution(double[] weights,
61 double[][] means,
62 double[][][] covariances)
63 throws NotPositiveException,
64 DimensionMismatchException {
65 this(createComponents(weights, means, covariances));
66 }
67
68 /**
69 * Creates components of the mixture model.
70 *
71 * @param weights Weights of each component.
72 * @param means Mean vector for each component.
73 * @param covariances Covariance matrix for each component.
74 * @return the list of components.
75 */
76 private static List<Pair<Double, MultivariateNormalDistribution>> createComponents(double[] weights,
77 double[][] means,
78 double[][][] covariances) {
79 final List<Pair<Double, MultivariateNormalDistribution>> mvns
80 = new ArrayList<>(weights.length);
81
82 for (int i = 0; i < weights.length; i++) {
83 final MultivariateNormalDistribution dist
84 = new MultivariateNormalDistribution(means[i], covariances[i]);
85
86 mvns.add(new Pair<>(weights[i], dist));
87 }
88
89 return mvns;
90 }
91 }