View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.distribution;
18  
19  import java.util.ArrayList;
20  import java.util.List;
21  
22  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23  import org.apache.commons.math4.legacy.exception.NotPositiveException;
24  import org.apache.commons.math4.legacy.core.Pair;
25  
26  /**
27   * Multivariate normal mixture distribution.
28   * This class is mainly syntactic sugar.
29   *
30   * @see MixtureMultivariateRealDistribution
31   * @since 3.2
32   */
33  public class MixtureMultivariateNormalDistribution
34      extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
35      /**
36       * Creates a mixture model from a list of distributions and their
37       * associated weights.
38       *
39       * @param components Distributions from which to sample.
40       * @throws NotPositiveException if any of the weights is negative.
41       * @throws DimensionMismatchException if not all components have the same
42       * number of variables.
43       */
44      public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components)
45          throws NotPositiveException,
46                 DimensionMismatchException {
47          super(components);
48      }
49  
50      /**
51       * Creates a multivariate normal mixture distribution.
52       *
53       * @param weights Weights of each component.
54       * @param means Mean vector for each component.
55       * @param covariances Covariance matrix for each component.
56       * @throws NotPositiveException if any of the weights is negative.
57       * @throws DimensionMismatchException if not all components have the same
58       * number of variables.
59       */
60      public MixtureMultivariateNormalDistribution(double[] weights,
61                                                   double[][] means,
62                                                   double[][][] covariances)
63          throws NotPositiveException,
64                 DimensionMismatchException {
65          this(createComponents(weights, means, covariances));
66      }
67  
68      /**
69       * Creates components of the mixture model.
70       *
71       * @param weights Weights of each component.
72       * @param means Mean vector for each component.
73       * @param covariances Covariance matrix for each component.
74       * @return the list of components.
75       */
76      private static List<Pair<Double, MultivariateNormalDistribution>> createComponents(double[] weights,
77                                                                                         double[][] means,
78                                                                                         double[][][] covariances) {
79          final List<Pair<Double, MultivariateNormalDistribution>> mvns
80              = new ArrayList<>(weights.length);
81  
82          for (int i = 0; i < weights.length; i++) {
83              final MultivariateNormalDistribution dist
84                  = new MultivariateNormalDistribution(means[i], covariances[i]);
85  
86              mvns.add(new Pair<>(weights[i], dist));
87          }
88  
89          return mvns;
90      }
91  }