001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.distribution;
018
019import org.apache.commons.math3.exception.DimensionMismatchException;
020import org.apache.commons.math3.linear.Array2DRowRealMatrix;
021import org.apache.commons.math3.linear.EigenDecomposition;
022import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
023import org.apache.commons.math3.linear.RealMatrix;
024import org.apache.commons.math3.linear.SingularMatrixException;
025import org.apache.commons.math3.random.RandomGenerator;
026import org.apache.commons.math3.random.Well19937c;
027import org.apache.commons.math3.util.FastMath;
028import org.apache.commons.math3.util.MathArrays;
029
030/**
031 * Implementation of the multivariate normal (Gaussian) distribution.
032 *
033 * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">
034 * Multivariate normal distribution (Wikipedia)</a>
035 * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html">
036 * Multivariate normal distribution (MathWorld)</a>
037 *
038 * @since 3.1
039 */
040public class MultivariateNormalDistribution
041    extends AbstractMultivariateRealDistribution {
042    /** Vector of means. */
043    private final double[] means;
044    /** Covariance matrix. */
045    private final RealMatrix covarianceMatrix;
046    /** The matrix inverse of the covariance matrix. */
047    private final RealMatrix covarianceMatrixInverse;
048    /** The determinant of the covariance matrix. */
049    private final double covarianceMatrixDeterminant;
050    /** Matrix used in computation of samples. */
051    private final RealMatrix samplingMatrix;
052
053    /**
054     * Creates a multivariate normal distribution with the given mean vector and
055     * covariance matrix.
056     * <br/>
057     * The number of dimensions is equal to the length of the mean vector
058     * and to the number of rows and columns of the covariance matrix.
059     * It is frequently written as "p" in formulae.
060     * <p>
061     * <b>Note:</b> this constructor will implicitly create an instance of
062     * {@link Well19937c} as random generator to be used for sampling only (see
063     * {@link #sample()} and {@link #sample(int)}). In case no sampling is
064     * needed for the created distribution, it is advised to pass {@code null}
065     * as random generator via the appropriate constructors to avoid the
066     * additional initialisation overhead.
067     *
068     * @param means Vector of means.
069     * @param covariances Covariance matrix.
070     * @throws DimensionMismatchException if the arrays length are
071     * inconsistent.
072     * @throws SingularMatrixException if the eigenvalue decomposition cannot
073     * be performed on the provided covariance matrix.
074     * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
075     * negative.
076     */
077    public MultivariateNormalDistribution(final double[] means,
078                                          final double[][] covariances)
079        throws SingularMatrixException,
080               DimensionMismatchException,
081               NonPositiveDefiniteMatrixException {
082        this(new Well19937c(), means, covariances);
083    }
084
085    /**
086     * Creates a multivariate normal distribution with the given mean vector and
087     * covariance matrix.
088     * <br/>
089     * The number of dimensions is equal to the length of the mean vector
090     * and to the number of rows and columns of the covariance matrix.
091     * It is frequently written as "p" in formulae.
092     *
093     * @param rng Random Number Generator.
094     * @param means Vector of means.
095     * @param covariances Covariance matrix.
096     * @throws DimensionMismatchException if the arrays length are
097     * inconsistent.
098     * @throws SingularMatrixException if the eigenvalue decomposition cannot
099     * be performed on the provided covariance matrix.
100     * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is
101     * negative.
102     */
103    public MultivariateNormalDistribution(RandomGenerator rng,
104                                          final double[] means,
105                                          final double[][] covariances)
106            throws SingularMatrixException,
107                   DimensionMismatchException,
108                   NonPositiveDefiniteMatrixException {
109        super(rng, means.length);
110
111        final int dim = means.length;
112
113        if (covariances.length != dim) {
114            throw new DimensionMismatchException(covariances.length, dim);
115        }
116
117        for (int i = 0; i < dim; i++) {
118            if (dim != covariances[i].length) {
119                throw new DimensionMismatchException(covariances[i].length, dim);
120            }
121        }
122
123        this.means = MathArrays.copyOf(means);
124
125        covarianceMatrix = new Array2DRowRealMatrix(covariances);
126
127        // Covariance matrix eigen decomposition.
128        final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix);
129
130        // Compute and store the inverse.
131        covarianceMatrixInverse = covMatDec.getSolver().getInverse();
132        // Compute and store the determinant.
133        covarianceMatrixDeterminant = covMatDec.getDeterminant();
134
135        // Eigenvalues of the covariance matrix.
136        final double[] covMatEigenvalues = covMatDec.getRealEigenvalues();
137
138        for (int i = 0; i < covMatEigenvalues.length; i++) {
139            if (covMatEigenvalues[i] < 0) {
140                throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0);
141            }
142        }
143
144        // Matrix where each column is an eigenvector of the covariance matrix.
145        final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim);
146        for (int v = 0; v < dim; v++) {
147            final double[] evec = covMatDec.getEigenvector(v).toArray();
148            covMatEigenvectors.setColumn(v, evec);
149        }
150
151        final RealMatrix tmpMatrix = covMatEigenvectors.transpose();
152
153        // Scale each eigenvector by the square root of its eigenvalue.
154        for (int row = 0; row < dim; row++) {
155            final double factor = FastMath.sqrt(covMatEigenvalues[row]);
156            for (int col = 0; col < dim; col++) {
157                tmpMatrix.multiplyEntry(row, col, factor);
158            }
159        }
160
161        samplingMatrix = covMatEigenvectors.multiply(tmpMatrix);
162    }
163
164    /**
165     * Gets the mean vector.
166     *
167     * @return the mean vector.
168     */
169    public double[] getMeans() {
170        return MathArrays.copyOf(means);
171    }
172
173    /**
174     * Gets the covariance matrix.
175     *
176     * @return the covariance matrix.
177     */
178    public RealMatrix getCovariances() {
179        return covarianceMatrix.copy();
180    }
181
182    /** {@inheritDoc} */
183    public double density(final double[] vals) throws DimensionMismatchException {
184        final int dim = getDimension();
185        if (vals.length != dim) {
186            throw new DimensionMismatchException(vals.length, dim);
187        }
188
189        return FastMath.pow(2 * FastMath.PI, -0.5 * dim) *
190            FastMath.pow(covarianceMatrixDeterminant, -0.5) *
191            getExponentTerm(vals);
192    }
193
194    /**
195     * Gets the square root of each element on the diagonal of the covariance
196     * matrix.
197     *
198     * @return the standard deviations.
199     */
200    public double[] getStandardDeviations() {
201        final int dim = getDimension();
202        final double[] std = new double[dim];
203        final double[][] s = covarianceMatrix.getData();
204        for (int i = 0; i < dim; i++) {
205            std[i] = FastMath.sqrt(s[i][i]);
206        }
207        return std;
208    }
209
210    /** {@inheritDoc} */
211    @Override
212    public double[] sample() {
213        final int dim = getDimension();
214        final double[] normalVals = new double[dim];
215
216        for (int i = 0; i < dim; i++) {
217            normalVals[i] = random.nextGaussian();
218        }
219
220        final double[] vals = samplingMatrix.operate(normalVals);
221
222        for (int i = 0; i < dim; i++) {
223            vals[i] += means[i];
224        }
225
226        return vals;
227    }
228
229    /**
230     * Computes the term used in the exponent (see definition of the distribution).
231     *
232     * @param values Values at which to compute density.
233     * @return the multiplication factor of density calculations.
234     */
235    private double getExponentTerm(final double[] values) {
236        final double[] centered = new double[values.length];
237        for (int i = 0; i < centered.length; i++) {
238            centered[i] = values[i] - getMeans()[i];
239        }
240        final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered);
241        double sum = 0;
242        for (int i = 0; i < preMultiplied.length; i++) {
243            sum += preMultiplied[i] * centered[i];
244        }
245        return FastMath.exp(-0.5 * sum);
246    }
247}