001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.distribution;
018
019import java.util.Arrays;
020import org.apache.commons.statistics.distribution.ContinuousDistribution;
021import org.apache.commons.statistics.distribution.NormalDistribution;
022import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
023import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
024import org.apache.commons.math4.legacy.linear.EigenDecomposition;
025import org.apache.commons.math4.legacy.linear.NonPositiveDefiniteMatrixException;
026import org.apache.commons.math4.legacy.linear.RealMatrix;
027import org.apache.commons.math4.legacy.linear.SingularMatrixException;
028import org.apache.commons.rng.UniformRandomProvider;
029import org.apache.commons.math4.core.jdkmath.JdkMath;
030
031/**
032 * Implementation of the multivariate normal (Gaussian) distribution.
033 *
034 * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">
035 * Multivariate normal distribution (Wikipedia)</a>
036 * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html">
037 * Multivariate normal distribution (MathWorld)</a>
038 *
039 * @since 3.1
040 */
041public class MultivariateNormalDistribution
042    extends AbstractMultivariateRealDistribution {
043    /** Vector of means. */
044    private final double[] means;
045    /** Covariance matrix. */
046    private final RealMatrix covarianceMatrix;
047    /** The matrix inverse of the covariance matrix. */
048    private final RealMatrix covarianceMatrixInverse;
049    /** The determinant of the covariance matrix. */
050    private final double covarianceMatrixDeterminant;
051    /** Matrix used in computation of samples. */
052    private final RealMatrix samplingMatrix;
053
054    /**
055     * Creates a multivariate normal distribution with the given mean vector and
056     * covariance matrix.
057     * <p>
058     * The number of dimensions is equal to the length of the mean vector
059     * and to the number of rows and columns of the covariance matrix.
060     * It is frequently written as "p" in formulae.
061     * </p>
062     *
063     * @param means Vector of means.
064     * @param covariances Covariance matrix.
065     * @throws DimensionMismatchException if the arrays length are
066     * inconsistent.
067     * @throws SingularMatrixException if the eigenvalue decomposition cannot
068     * be performed on the provided covariance matrix.
069     * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
070     * negative.
071     */
072    public MultivariateNormalDistribution(final double[] means,
073                                          final double[][] covariances)
074            throws SingularMatrixException,
075                   DimensionMismatchException,
076                   NonPositiveDefiniteMatrixException {
077        super(means.length);
078
079        final int dim = means.length;
080
081        if (covariances.length != dim) {
082            throw new DimensionMismatchException(covariances.length, dim);
083        }
084
085        for (int i = 0; i < dim; i++) {
086            if (dim != covariances[i].length) {
087                throw new DimensionMismatchException(covariances[i].length, dim);
088            }
089        }
090
091        this.means = Arrays.copyOf(means, means.length);
092
093        covarianceMatrix = new Array2DRowRealMatrix(covariances);
094
095        // Covariance matrix eigen decomposition.
096        final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix);
097
098        // Compute and store the inverse.
099        covarianceMatrixInverse = covMatDec.getSolver().getInverse();
100        // Compute and store the determinant.
101        covarianceMatrixDeterminant = covMatDec.getDeterminant();
102
103        // Eigenvalues of the covariance matrix.
104        final double[] covMatEigenvalues = covMatDec.getRealEigenvalues();
105
106        for (int i = 0; i < covMatEigenvalues.length; i++) {
107            if (covMatEigenvalues[i] < 0) {
108                throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0);
109            }
110        }
111
112        // Matrix where each column is an eigenvector of the covariance matrix.
113        final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim);
114        for (int v = 0; v < dim; v++) {
115            final double[] evec = covMatDec.getEigenvector(v).toArray();
116            covMatEigenvectors.setColumn(v, evec);
117        }
118
119        final RealMatrix tmpMatrix = covMatEigenvectors.transpose();
120
121        // Scale each eigenvector by the square root of its eigenvalue.
122        for (int row = 0; row < dim; row++) {
123            final double factor = JdkMath.sqrt(covMatEigenvalues[row]);
124            for (int col = 0; col < dim; col++) {
125                tmpMatrix.multiplyEntry(row, col, factor);
126            }
127        }
128
129        samplingMatrix = covMatEigenvectors.multiply(tmpMatrix);
130    }
131
132    /**
133     * Gets the mean vector.
134     *
135     * @return the mean vector.
136     */
137    public double[] getMeans() {
138        return Arrays.copyOf(means, means.length);
139    }
140
141    /**
142     * Gets the covariance matrix.
143     *
144     * @return the covariance matrix.
145     */
146    public RealMatrix getCovariances() {
147        return covarianceMatrix.copy();
148    }
149
150    /** {@inheritDoc} */
151    @Override
152    public double density(final double[] vals) throws DimensionMismatchException {
153        final int dim = getDimension();
154        if (vals.length != dim) {
155            throw new DimensionMismatchException(vals.length, dim);
156        }
157
158        return JdkMath.pow(2 * JdkMath.PI, -0.5 * dim) *
159            JdkMath.pow(covarianceMatrixDeterminant, -0.5) *
160            getExponentTerm(vals);
161    }
162
163    /**
164     * Gets the square root of each element on the diagonal of the covariance
165     * matrix.
166     *
167     * @return the standard deviations.
168     */
169    public double[] getStandardDeviations() {
170        final int dim = getDimension();
171        final double[] std = new double[dim];
172        final double[][] s = covarianceMatrix.getData();
173        for (int i = 0; i < dim; i++) {
174            std[i] = JdkMath.sqrt(s[i][i]);
175        }
176        return std;
177    }
178
179    /** {@inheritDoc} */
180    @Override
181    public MultivariateRealDistribution.Sampler createSampler(final UniformRandomProvider rng) {
182        return new MultivariateRealDistribution.Sampler() {
183            /** Normal distribution. */
184            private final ContinuousDistribution.Sampler gauss = NormalDistribution.of(0, 1).createSampler(rng);
185
186            /** {@inheritDoc} */
187            @Override
188            public double[] sample() {
189                final int dim = getDimension();
190                final double[] normalVals = new double[dim];
191
192                for (int i = 0; i < dim; i++) {
193                    normalVals[i] = gauss.sample();
194                }
195
196                final double[] vals = samplingMatrix.operate(normalVals);
197
198                for (int i = 0; i < dim; i++) {
199                    vals[i] += means[i];
200                }
201
202                return vals;
203            }
204        };
205    }
206
207    /**
208     * Computes the term used in the exponent (see definition of the distribution).
209     *
210     * @param values Values at which to compute density.
211     * @return the multiplication factor of density calculations.
212     */
213    private double getExponentTerm(final double[] values) {
214        final double[] centered = new double[values.length];
215        for (int i = 0; i < centered.length; i++) {
216            centered[i] = values[i] - means[i];
217        }
218        final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered);
219        double sum = 0;
220        for (int i = 0; i < preMultiplied.length; i++) {
221            sum += preMultiplied[i] * centered[i];
222        }
223        return JdkMath.exp(-0.5 * sum);
224    }
225}