001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.distribution; 018 019import java.util.Arrays; 020import org.apache.commons.statistics.distribution.ContinuousDistribution; 021import org.apache.commons.statistics.distribution.NormalDistribution; 022import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 023import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix; 024import org.apache.commons.math4.legacy.linear.EigenDecomposition; 025import org.apache.commons.math4.legacy.linear.NonPositiveDefiniteMatrixException; 026import org.apache.commons.math4.legacy.linear.RealMatrix; 027import org.apache.commons.math4.legacy.linear.SingularMatrixException; 028import org.apache.commons.rng.UniformRandomProvider; 029import org.apache.commons.math4.core.jdkmath.JdkMath; 030 031/** 032 * Implementation of the multivariate normal (Gaussian) distribution. 033 * 034 * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution"> 035 * Multivariate normal distribution (Wikipedia)</a> 036 * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html"> 037 * Multivariate normal distribution (MathWorld)</a> 038 * 039 * @since 3.1 040 */ 041public class MultivariateNormalDistribution 042 extends AbstractMultivariateRealDistribution { 043 /** Vector of means. */ 044 private final double[] means; 045 /** Covariance matrix. */ 046 private final RealMatrix covarianceMatrix; 047 /** The matrix inverse of the covariance matrix. */ 048 private final RealMatrix covarianceMatrixInverse; 049 /** The determinant of the covariance matrix. */ 050 private final double covarianceMatrixDeterminant; 051 /** Matrix used in computation of samples. */ 052 private final RealMatrix samplingMatrix; 053 054 /** 055 * Creates a multivariate normal distribution with the given mean vector and 056 * covariance matrix. 057 * <p> 058 * The number of dimensions is equal to the length of the mean vector 059 * and to the number of rows and columns of the covariance matrix. 060 * It is frequently written as "p" in formulae. 061 * </p> 062 * 063 * @param means Vector of means. 064 * @param covariances Covariance matrix. 065 * @throws DimensionMismatchException if the arrays length are 066 * inconsistent. 067 * @throws SingularMatrixException if the eigenvalue decomposition cannot 068 * be performed on the provided covariance matrix. 069 * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is 070 * negative. 071 */ 072 public MultivariateNormalDistribution(final double[] means, 073 final double[][] covariances) 074 throws SingularMatrixException, 075 DimensionMismatchException, 076 NonPositiveDefiniteMatrixException { 077 super(means.length); 078 079 final int dim = means.length; 080 081 if (covariances.length != dim) { 082 throw new DimensionMismatchException(covariances.length, dim); 083 } 084 085 for (int i = 0; i < dim; i++) { 086 if (dim != covariances[i].length) { 087 throw new DimensionMismatchException(covariances[i].length, dim); 088 } 089 } 090 091 this.means = Arrays.copyOf(means, means.length); 092 093 covarianceMatrix = new Array2DRowRealMatrix(covariances); 094 095 // Covariance matrix eigen decomposition. 096 final EigenDecomposition covMatDec = new EigenDecomposition(covarianceMatrix); 097 098 // Compute and store the inverse. 099 covarianceMatrixInverse = covMatDec.getSolver().getInverse(); 100 // Compute and store the determinant. 101 covarianceMatrixDeterminant = covMatDec.getDeterminant(); 102 103 // Eigenvalues of the covariance matrix. 104 final double[] covMatEigenvalues = covMatDec.getRealEigenvalues(); 105 106 for (int i = 0; i < covMatEigenvalues.length; i++) { 107 if (covMatEigenvalues[i] < 0) { 108 throw new NonPositiveDefiniteMatrixException(covMatEigenvalues[i], i, 0); 109 } 110 } 111 112 // Matrix where each column is an eigenvector of the covariance matrix. 113 final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim); 114 for (int v = 0; v < dim; v++) { 115 final double[] evec = covMatDec.getEigenvector(v).toArray(); 116 covMatEigenvectors.setColumn(v, evec); 117 } 118 119 final RealMatrix tmpMatrix = covMatEigenvectors.transpose(); 120 121 // Scale each eigenvector by the square root of its eigenvalue. 122 for (int row = 0; row < dim; row++) { 123 final double factor = JdkMath.sqrt(covMatEigenvalues[row]); 124 for (int col = 0; col < dim; col++) { 125 tmpMatrix.multiplyEntry(row, col, factor); 126 } 127 } 128 129 samplingMatrix = covMatEigenvectors.multiply(tmpMatrix); 130 } 131 132 /** 133 * Gets the mean vector. 134 * 135 * @return the mean vector. 136 */ 137 public double[] getMeans() { 138 return Arrays.copyOf(means, means.length); 139 } 140 141 /** 142 * Gets the covariance matrix. 143 * 144 * @return the covariance matrix. 145 */ 146 public RealMatrix getCovariances() { 147 return covarianceMatrix.copy(); 148 } 149 150 /** {@inheritDoc} */ 151 @Override 152 public double density(final double[] vals) throws DimensionMismatchException { 153 final int dim = getDimension(); 154 if (vals.length != dim) { 155 throw new DimensionMismatchException(vals.length, dim); 156 } 157 158 return JdkMath.pow(2 * JdkMath.PI, -0.5 * dim) * 159 JdkMath.pow(covarianceMatrixDeterminant, -0.5) * 160 getExponentTerm(vals); 161 } 162 163 /** 164 * Gets the square root of each element on the diagonal of the covariance 165 * matrix. 166 * 167 * @return the standard deviations. 168 */ 169 public double[] getStandardDeviations() { 170 final int dim = getDimension(); 171 final double[] std = new double[dim]; 172 final double[][] s = covarianceMatrix.getData(); 173 for (int i = 0; i < dim; i++) { 174 std[i] = JdkMath.sqrt(s[i][i]); 175 } 176 return std; 177 } 178 179 /** {@inheritDoc} */ 180 @Override 181 public MultivariateRealDistribution.Sampler createSampler(final UniformRandomProvider rng) { 182 return new MultivariateRealDistribution.Sampler() { 183 /** Normal distribution. */ 184 private final ContinuousDistribution.Sampler gauss = NormalDistribution.of(0, 1).createSampler(rng); 185 186 /** {@inheritDoc} */ 187 @Override 188 public double[] sample() { 189 final int dim = getDimension(); 190 final double[] normalVals = new double[dim]; 191 192 for (int i = 0; i < dim; i++) { 193 normalVals[i] = gauss.sample(); 194 } 195 196 final double[] vals = samplingMatrix.operate(normalVals); 197 198 for (int i = 0; i < dim; i++) { 199 vals[i] += means[i]; 200 } 201 202 return vals; 203 } 204 }; 205 } 206 207 /** 208 * Computes the term used in the exponent (see definition of the distribution). 209 * 210 * @param values Values at which to compute density. 211 * @return the multiplication factor of density calculations. 212 */ 213 private double getExponentTerm(final double[] values) { 214 final double[] centered = new double[values.length]; 215 for (int i = 0; i < centered.length; i++) { 216 centered[i] = values[i] - means[i]; 217 } 218 final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered); 219 double sum = 0; 220 for (int i = 0; i < preMultiplied.length; i++) { 221 sum += preMultiplied[i] * centered[i]; 222 } 223 return JdkMath.exp(-0.5 * sum); 224 } 225}