001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.distribution;
018
019import org.apache.commons.math3.exception.DimensionMismatchException;
020import org.apache.commons.math3.linear.Array2DRowRealMatrix;
021import org.apache.commons.math3.linear.EigenDecomposition;
022import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
023import org.apache.commons.math3.linear.RealMatrix;
024import org.apache.commons.math3.linear.SingularMatrixException;
025import org.apache.commons.math3.random.RandomGenerator;
026import org.apache.commons.math3.random.Well19937c;
027import org.apache.commons.math3.util.FastMath;
028import org.apache.commons.math3.util.MathArrays;
029
030/**
031 * Implementation of the multivariate normal (Gaussian) distribution.
032 *
033 * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">
034 * Multivariate normal distribution (Wikipedia)</a>
035 * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html">
036 * Multivariate normal distribution (MathWorld)</a>
037 *
038 * @version $Id: MultivariateNormalDistribution.java 1503290 2013-07-15 15:16:29Z sebb $
039 * @since 3.1
040 */
041public class MultivariateNormalDistribution
042    extends AbstractMultivariateRealDistribution {
043    /** Vector of means. */
044    private final double[] means;
045    /** Covariance matrix. */
046    private final RealMatrix covarianceMatrix;
047    /** The matrix inverse of the covariance matrix. */
048    private final RealMatrix covarianceMatrixInverse;
049    /** The determinant of the covariance matrix. */
050    private final double covarianceMatrixDeterminant;
051    /** Matrix used in computation of samples. */
052    private final RealMatrix samplingMatrix;
053
054    /**
055     * Creates a multivariate normal distribution with the given mean vector and
056     * covariance matrix.
057     * <br/>
058     * The number of dimensions is equal to the length of the mean vector
059     * and to the number of rows and columns of the covariance matrix.
060     * It is frequently written as "p" in formulae.
061     *
062     * @param means Vector of means.
063     * @param covariances Covariance matrix.
064     * @throws DimensionMismatchException if the arrays length are
065     * inconsistent.
066     * @throws SingularMatrixException if the eigenvalue decomposition cannot
067     * be performed on the provided covariance matrix.
068     * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
069     * negative.
070     */
071    public MultivariateNormalDistribution(final double[] means,
072                                          final double[][] covariances)
073        throws SingularMatrixException,
074               DimensionMismatchException,
075               NonPositiveDefiniteMatrixException {
076        this(new Well19937c(), means, covariances);
077    }
078
079    /**
080     * Creates a multivariate normal distribution with the given mean vector and
081     * covariance matrix.
082     * <br/>
083     * The number of dimensions is equal to the length of the mean vector
084     * and to the number of rows and columns of the covariance matrix.
085     * It is frequently written as "p" in formulae.
086     *
087     * @param rng Random Number Generator.
088     * @param means Vector of means.
089     * @param covariances Covariance matrix.
090     * @throws DimensionMismatchException if the arrays length are
091     * inconsistent.
092     * @throws SingularMatrixException if the eigenvalue decomposition cannot
093     * be performed on the provided covariance matrix.
094     * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
095     * negative.
096     */
097    public MultivariateNormalDistribution(RandomGenerator rng,
098                                          final double[] means,
099                                          final double[][] covariances)
100            throws SingularMatrixException,
101                   DimensionMismatchException,
102                   NonPositiveDefiniteMatrixException {
103        super(rng, means.length);
104
105        final int dim = means.length;
106
107        if (covariances.length != dim) {
108            throw new DimensionMismatchException(covariances.length, dim);
109        }
110
111        for (int i = 0; i < dim; i++) {
112            if (dim != covariances[i].length) {
113                throw new DimensionMismatchException(covariances[i].length, dim);
114            }
115        }
116
117        this.means = MathArrays.copyOf(means);
118
119        covarianceMatrix = new Array2DRowRealMatrix(covariances);
120
121        // Covariance matrix eigen decomposition.
122        final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix);
123
124        // Compute and store the inverse.
125        covarianceMatrixInverse = covMatDec.getSolver().getInverse();
126        // Compute and store the determinant.
127        covarianceMatrixDeterminant = covMatDec.getDeterminant();
128
129        // Eigenvalues of the covariance matrix.
130        final double[] covMatEigenvalues = covMatDec.getRealEigenvalues();
131
132        for (int i = 0; i < covMatEigenvalues.length; i++) {
133            if (covMatEigenvalues[i] < 0) {
134                throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0);
135            }
136        }
137
138        // Matrix where each column is an eigenvector of the covariance matrix.
139        final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim);
140        for (int v = 0; v < dim; v++) {
141            final double[] evec = covMatDec.getEigenvector(v).toArray();
142            covMatEigenvectors.setColumn(v, evec);
143        }
144
145        final RealMatrix tmpMatrix = covMatEigenvectors.transpose();
146
147        // Scale each eigenvector by the square root of its eigenvalue.
148        for (int row = 0; row < dim; row++) {
149            final double factor = FastMath.sqrt(covMatEigenvalues[row]);
150            for (int col = 0; col < dim; col++) {
151                tmpMatrix.multiplyEntry(row, col, factor);
152            }
153        }
154
155        samplingMatrix = covMatEigenvectors.multiply(tmpMatrix);
156    }
157
158    /**
159     * Gets the mean vector.
160     *
161     * @return the mean vector.
162     */
163    public double[] getMeans() {
164        return MathArrays.copyOf(means);
165    }
166
167    /**
168     * Gets the covariance matrix.
169     *
170     * @return the covariance matrix.
171     */
172    public RealMatrix getCovariances() {
173        return covarianceMatrix.copy();
174    }
175
176    /** {@inheritDoc} */
177    public double density(final double[] vals) throws DimensionMismatchException {
178        final int dim = getDimension();
179        if (vals.length != dim) {
180            throw new DimensionMismatchException(vals.length, dim);
181        }
182
183        return FastMath.pow(2 * FastMath.PI, -0.5 * dim) *
184            FastMath.pow(covarianceMatrixDeterminant, -0.5) *
185            getExponentTerm(vals);
186    }
187
188    /**
189     * Gets the square root of each element on the diagonal of the covariance
190     * matrix.
191     *
192     * @return the standard deviations.
193     */
194    public double[] getStandardDeviations() {
195        final int dim = getDimension();
196        final double[] std = new double[dim];
197        final double[][] s = covarianceMatrix.getData();
198        for (int i = 0; i < dim; i++) {
199            std[i] = FastMath.sqrt(s[i][i]);
200        }
201        return std;
202    }
203
204    /** {@inheritDoc} */
205    @Override
206    public double[] sample() {
207        final int dim = getDimension();
208        final double[] normalVals = new double[dim];
209
210        for (int i = 0; i < dim; i++) {
211            normalVals[i] = random.nextGaussian();
212        }
213
214        final double[] vals = samplingMatrix.operate(normalVals);
215
216        for (int i = 0; i < dim; i++) {
217            vals[i] += means[i];
218        }
219
220        return vals;
221    }
222
223    /**
224     * Computes the term used in the exponent (see definition of the distribution).
225     *
226     * @param values Values at which to compute density.
227     * @return the multiplication factor of density calculations.
228     */
229    private double getExponentTerm(final double[] values) {
230        final double[] centered = new double[values.length];
231        for (int i = 0; i < centered.length; i++) {
232            centered[i] = values[i] - getMeans()[i];
233        }
234        final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered);
235        double sum = 0;
236        for (int i = 0; i < preMultiplied.length; i++) {
237            sum += preMultiplied[i] * centered[i];
238        }
239        return FastMath.exp(-0.5 * sum);
240    }
241}