View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.distribution;
18  
19  import java.util.Arrays;
20  import org.apache.commons.statistics.distribution.ContinuousDistribution;
21  import org.apache.commons.statistics.distribution.NormalDistribution;
22  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23  import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
24  import org.apache.commons.math4.legacy.linear.EigenDecomposition;
25  import org.apache.commons.math4.legacy.linear.NonPositiveDefiniteMatrixException;
26  import org.apache.commons.math4.legacy.linear.RealMatrix;
27  import org.apache.commons.math4.legacy.linear.SingularMatrixException;
28  import org.apache.commons.rng.UniformRandomProvider;
29  import org.apache.commons.math4.core.jdkmath.JdkMath;
30  
31  /**
32   * Implementation of the multivariate normal (Gaussian) distribution.
33   *
34   * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">
35   * Multivariate normal distribution (Wikipedia)</a>
36   * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html">
37   * Multivariate normal distribution (MathWorld)</a>
38   *
39   * @since 3.1
40   */
41  public class MultivariateNormalDistribution
42      extends AbstractMultivariateRealDistribution {
43      /** Vector of means. */
44      private final double[] means;
45      /** Covariance matrix. */
46      private final RealMatrix covarianceMatrix;
47      /** The matrix inverse of the covariance matrix. */
48      private final RealMatrix covarianceMatrixInverse;
49      /** The determinant of the covariance matrix. */
50      private final double covarianceMatrixDeterminant;
51      /** Matrix used in computation of samples. */
52      private final RealMatrix samplingMatrix;
53  
54      /**
55       * Creates a multivariate normal distribution with the given mean vector and
56       * covariance matrix.
57       * <p>
58       * The number of dimensions is equal to the length of the mean vector
59       * and to the number of rows and columns of the covariance matrix.
60       * It is frequently written as "p" in formulae.
61       * </p>
62       *
63       * @param means Vector of means.
64       * @param covariances Covariance matrix.
65       * @throws DimensionMismatchException if the arrays length are
66       * inconsistent.
67       * @throws SingularMatrixException if the eigenvalue decomposition cannot
68       * be performed on the provided covariance matrix.
69       * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
70       * negative.
71       */
72      public MultivariateNormalDistribution(final double[] means,
73                                            final double[][] covariances)
74              throws SingularMatrixException,
75                     DimensionMismatchException,
76                     NonPositiveDefiniteMatrixException {
77          super(means.length);
78  
79          final int dim = means.length;
80  
81          if (covariances.length != dim) {
82              throw new DimensionMismatchException(covariances.length, dim);
83          }
84  
85          for (int i = 0; i < dim; i++) {
86              if (dim != covariances[i].length) {
87                  throw new DimensionMismatchException(covariances[i].length, dim);
88              }
89          }
90  
91          this.means = Arrays.copyOf(means, means.length);
92  
93          covarianceMatrix = new Array2DRowRealMatrix(covariances);
94  
95          // Covariance matrix eigen decomposition.
96          final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix);
97  
98          // Compute and store the inverse.
99          covarianceMatrixInverse = covMatDec.getSolver().getInverse();
100         // Compute and store the determinant.
101         covarianceMatrixDeterminant = covMatDec.getDeterminant();
102 
103         // Eigenvalues of the covariance matrix.
104         final double[] covMatEigenvalues = covMatDec.getRealEigenvalues();
105 
106         for (int i = 0; i < covMatEigenvalues.length; i++) {
107             if (covMatEigenvalues[i] < 0) {
108                 throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0);
109             }
110         }
111 
112         // Matrix where each column is an eigenvector of the covariance matrix.
113         final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim);
114         for (int v = 0; v < dim; v++) {
115             final double[] evec = covMatDec.getEigenvector(v).toArray();
116             covMatEigenvectors.setColumn(v, evec);
117         }
118 
119         final RealMatrix tmpMatrix = covMatEigenvectors.transpose();
120 
121         // Scale each eigenvector by the square root of its eigenvalue.
122         for (int row = 0; row < dim; row++) {
123             final double factor = JdkMath.sqrt(covMatEigenvalues[row]);
124             for (int col = 0; col < dim; col++) {
125                 tmpMatrix.multiplyEntry(row, col, factor);
126             }
127         }
128 
129         samplingMatrix = covMatEigenvectors.multiply(tmpMatrix);
130     }
131 
132     /**
133      * Gets the mean vector.
134      *
135      * @return the mean vector.
136      */
137     public double[] getMeans() {
138         return Arrays.copyOf(means, means.length);
139     }
140 
141     /**
142      * Gets the covariance matrix.
143      *
144      * @return the covariance matrix.
145      */
146     public RealMatrix getCovariances() {
147         return covarianceMatrix.copy();
148     }
149 
150     /** {@inheritDoc} */
151     @Override
152     public double density(final double[] vals) throws DimensionMismatchException {
153         final int dim = getDimension();
154         if (vals.length != dim) {
155             throw new DimensionMismatchException(vals.length, dim);
156         }
157 
158         return JdkMath.pow(2 * JdkMath.PI, -0.5 * dim) *
159             JdkMath.pow(covarianceMatrixDeterminant, -0.5) *
160             getExponentTerm(vals);
161     }
162 
163     /**
164      * Gets the square root of each element on the diagonal of the covariance
165      * matrix.
166      *
167      * @return the standard deviations.
168      */
169     public double[] getStandardDeviations() {
170         final int dim = getDimension();
171         final double[] std = new double[dim];
172         final double[][] s = covarianceMatrix.getData();
173         for (int i = 0; i < dim; i++) {
174             std[i] = JdkMath.sqrt(s[i][i]);
175         }
176         return std;
177     }
178 
179     /** {@inheritDoc} */
180     @Override
181     public MultivariateRealDistribution.Sampler createSampler(final UniformRandomProvider rng) {
182         return new MultivariateRealDistribution.Sampler() {
183             /** Normal distribution. */
184             private final ContinuousDistribution.Sampler gauss = NormalDistribution.of(0, 1).createSampler(rng);
185 
186             /** {@inheritDoc} */
187             @Override
188             public double[] sample() {
189                 final int dim = getDimension();
190                 final double[] normalVals = new double[dim];
191 
192                 for (int i = 0; i < dim; i++) {
193                     normalVals[i] = gauss.sample();
194                 }
195 
196                 final double[] vals = samplingMatrix.operate(normalVals);
197 
198                 for (int i = 0; i < dim; i++) {
199                     vals[i] += means[i];
200                 }
201 
202                 return vals;
203             }
204         };
205     }
206 
207     /**
208      * Computes the term used in the exponent (see definition of the distribution).
209      *
210      * @param values Values at which to compute density.
211      * @return the multiplication factor of density calculations.
212      */
213     private double getExponentTerm(final double[] values) {
214         final double[] centered = new double[values.length];
215         for (int i = 0; i < centered.length; i++) {
216             centered[i] = values[i] - means[i];
217         }
218         final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered);
219         double sum = 0;
220         for (int i = 0; i < preMultiplied.length; i++) {
221             sum += preMultiplied[i] * centered[i];
222         }
223         return JdkMath.exp(-0.5 * sum);
224     }
225 }