org.apache.commons.math3.distribution.fitting
Class MultivariateNormalMixtureExpectationMaximization

java.lang.Object
  extended by org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization

public class MultivariateNormalMixtureExpectationMaximization
extends Object

Expectation-Maximization algorithm for fitting the parameters of multivariate normal mixture model distributions. This implementation is original code based on "EM Demystified: An Expectation-Maximization Tutorial" by Yihua Chen and Maya R. Gupta, Department of Electrical Engineering, University of Washington, Seattle, WA 98195. It was verified against external tools such as CRAN Mixtools (see the JUnit test cases), but it is not based on Mixtools code. The origin of this class is discussed in the comments of the MATH-817 JIRA issue.
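For orientation, the standard EM iteration for a mixture of K multivariate normals with weights w_k, means μ_k and covariance matrices Σ_k, fitted to n data rows x_i, can be summarized as follows. This follows the usual textbook derivation (such as the tutorial cited above). The notation is introduced here and is not taken from the class, and the implementation may organize the computation differently.

```latex
\begin{aligned}
\text{E-step:}\quad & \gamma_{ik} = \frac{w_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} w_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)}, \qquad i = 1,\dots,n,\; k = 1,\dots,K \\
\text{M-step:}\quad & n_k = \sum_{i=1}^{n} \gamma_{ik}, \qquad w_k = \frac{n_k}{n}, \qquad \mu_k = \frac{1}{n_k} \sum_{i=1}^{n} \gamma_{ik}\, x_i, \\
& \Sigma_k = \frac{1}{n_k} \sum_{i=1}^{n} \gamma_{ik}\, (x_i - \mu_k)(x_i - \mu_k)^{\top} \\
\text{Log-likelihood:}\quad & \ell = \sum_{i=1}^{n} \log \sum_{k=1}^{K} w_k\, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)
\end{aligned}
```

In the M-step, Σ_k is computed with the updated mean μ_k. The E-step and M-step are repeated until ℓ stops changing appreciably.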

Since:
3.2
Version:
$Id: MultivariateNormalMixtureExpectationMaximization.html 857555 2013-04-06 23:30:25Z luc $

Constructor Summary
MultivariateNormalMixtureExpectationMaximization(double[][] data)
          Creates an object to fit a multivariate normal mixture model to data.
 
Method Summary
static MixtureMultivariateNormalDistribution estimate(double[][] data, int numComponents)
          Helper method to create a multivariate normal mixture model which can be used to initialize fit(MixtureMultivariateNormalDistribution).
 void fit(MixtureMultivariateNormalDistribution initialMixture)
          Fit a mixture model to the data supplied to the constructor.
 void fit(MixtureMultivariateNormalDistribution initialMixture, int maxIterations, double threshold)
          Fit a mixture model to the data supplied to the constructor.
 MixtureMultivariateNormalDistribution getFittedModel()
          Gets the fitted model.
 double getLogLikelihood()
          Gets the log likelihood of the data under the fitted model.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

MultivariateNormalMixtureExpectationMaximization

public MultivariateNormalMixtureExpectationMaximization(double[][] data)
                                                 throws NotStrictlyPositiveException,
                                                        DimensionMismatchException,
                                                        NumberIsTooSmallException
Creates an object to fit a multivariate normal mixture model to data.

Parameters:
data - Data to use in fitting procedure
Throws:
NotStrictlyPositiveException - if data has no rows
DimensionMismatchException - if rows of data have different numbers of columns
NumberIsTooSmallException - if the number of columns in the data is less than 2
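The three documented constructor checks can be mirrored in a few lines. This is a hypothetical stand-in, not the class's code. It uses the standard IllegalArgumentException in place of the commons-math exception types listed above, and the names MixtureDataValidation, validate and isValid are illustrative only.

```java
/** Hypothetical stand-in mirroring the documented constructor checks. */
final class MixtureDataValidation {

    /** Throws IllegalArgumentException where the constructor documents an exception. */
    static void validate(double[][] data) {
        if (data.length < 1) {
            // Documented as NotStrictlyPositiveException: data has no rows.
            throw new IllegalArgumentException("data has no rows");
        }
        final int numCols = data[0].length;
        for (double[] row : data) {
            if (row.length != numCols) {
                // Documented as DimensionMismatchException: rows of unequal length.
                throw new IllegalArgumentException(
                        "expected " + numCols + " columns, got " + row.length);
            }
        }
        if (numCols < 2) {
            // Documented as NumberIsTooSmallException: fewer than 2 columns.
            throw new IllegalArgumentException("need at least 2 columns, got " + numCols);
        }
    }

    /** Convenience wrapper: true if validate(data) does not throw. */
    static boolean isValid(double[][] data) {
        try {
            validate(data);
            return true;
        } catch (IllegalArgumentException e) {
            return false;
        }
    }
}
```

For example, a 2 x 2 array passes, while an empty array, a ragged array, or single-column data is rejected.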
Method Detail

fit

public void fit(MixtureMultivariateNormalDistribution initialMixture,
                int maxIterations,
                double threshold)
         throws SingularMatrixException,
                NotStrictlyPositiveException,
                DimensionMismatchException
Fit a mixture model to the data supplied to the constructor. The quality of the fit depends on the concavity of the log-likelihood for the data supplied to the constructor and on the initial mixture passed to this method. If the log-likelihood has many local optima, multiple runs of the fitting method with different initial mixtures may be required to find the optimal solution. If a SingularMatrixException is encountered, another initialization may succeed.

Parameters:
initialMixture - Model containing initial values of weights and multivariate normals
maxIterations - Maximum iterations allowed for fit
threshold - Convergence threshold, compared against the absolute difference in log-likelihood between successive iterations
Throws:
SingularMatrixException - if any component's covariance matrix is singular during fitting
NotStrictlyPositiveException - if maxIterations is less than one or threshold is less than Double.MIN_VALUE
DimensionMismatchException - if the dimension of the initialMixture mean vectors differs from the number of columns in the data
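The interplay of maxIterations and threshold can be illustrated with a simplified EM loop. This is a sketch under stated assumptions, not the class's implementation: it works in one dimension (variances instead of covariance matrices), the class UnivariateMixtureEm and its members are hypothetical, and a collapsed variance stands in for a singular covariance matrix. The loop stops once the absolute change in log-likelihood is at most threshold, or after maxIterations iterations.

```java
/** Hypothetical 1-D Gaussian mixture EM, illustrating maxIterations and threshold. */
final class UnivariateMixtureEm {

    final double[] weights;
    final double[] means;
    final double[] variances;
    double logLikelihood = Double.NEGATIVE_INFINITY;
    int iterations;

    UnivariateMixtureEm(double[] weights, double[] means, double[] variances) {
        this.weights = weights.clone();
        this.means = means.clone();
        this.variances = variances.clone();
    }

    static double normalPdf(double x, double mean, double variance) {
        double d = x - mean;
        return Math.exp(-0.5 * d * d / variance) / Math.sqrt(2 * Math.PI * variance);
    }

    /** Alternates E- and M-steps until |change in log-likelihood| <= threshold or maxIterations. */
    void fit(double[] data, int maxIterations, double threshold) {
        if (maxIterations < 1 || threshold < Double.MIN_VALUE) {
            throw new IllegalArgumentException("maxIterations and threshold must be strictly positive");
        }
        int n = data.length, k = weights.length;
        double previous;
        iterations = 0;
        do {
            previous = logLikelihood;
            // E-step: responsibilities gamma[i][j] and log-likelihood of the current parameters.
            double[][] gamma = new double[n][k];
            double sumLog = 0.0;
            for (int i = 0; i < n; i++) {
                double density = 0.0;
                for (int j = 0; j < k; j++) {
                    gamma[i][j] = weights[j] * normalPdf(data[i], means[j], variances[j]);
                    density += gamma[i][j];
                }
                sumLog += Math.log(density);
                for (int j = 0; j < k; j++) {
                    gamma[i][j] /= density;
                }
            }
            logLikelihood = sumLog;
            // M-step: re-estimate weights, means and variances from the responsibilities.
            for (int j = 0; j < k; j++) {
                double nj = 0.0, weightedSum = 0.0;
                for (int i = 0; i < n; i++) {
                    nj += gamma[i][j];
                    weightedSum += gamma[i][j] * data[i];
                }
                double mean = weightedSum / nj;
                double weightedSq = 0.0;
                for (int i = 0; i < n; i++) {
                    double d = data[i] - mean;
                    weightedSq += gamma[i][j] * d * d;
                }
                weights[j] = nj / n;
                means[j] = mean;
                variances[j] = weightedSq / nj;
                if (!(variances[j] > 0.0)) {
                    // 1-D analogue of a singular covariance matrix.
                    throw new ArithmeticException("component " + j + " collapsed");
                }
            }
            iterations++;
        } while (iterations < maxIterations && Math.abs(logLikelihood - previous) > threshold);
    }
}
```

Starting from means -1 and 1 on two well-separated clusters around -5 and 5, the loop converges to means near -5 and 5 with equal weights, well before 1000 iterations.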

fit

public void fit(MixtureMultivariateNormalDistribution initialMixture)
         throws SingularMatrixException,
                NotStrictlyPositiveException
Fit a mixture model to the data supplied to the constructor. The quality of the fit depends on the concavity of the log-likelihood for the data supplied to the constructor and on the initial mixture passed to this method. If the log-likelihood has many local optima, multiple runs of the fitting method with different initial mixtures may be required to find the optimal solution. If a SingularMatrixException is encountered, another initialization may succeed.

Parameters:
initialMixture - Model containing initial values of weights and multivariate normals
Throws:
SingularMatrixException - if any component's covariance matrix is singular during fitting
NotStrictlyPositiveException - if the default maximum number of iterations is less than one or the default threshold is less than Double.MIN_VALUE
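A common pattern for dealing with local optima and failed initializations is multi-start fitting. You try several initial mixtures, skip the ones that throw, and keep the fit with the highest log-likelihood. The sketch below is generic and hypothetical. Fit and the fitter function stand in for constructing a fresh MultivariateNormalMixtureExpectationMaximization, calling fit(...), and reading getFittedModel() and getLogLikelihood(). Any RuntimeException is treated as a failed start.

```java
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

final class MultiStart {

    /** One fitting attempt: the fitted model and its log-likelihood (hypothetical type). */
    record Fit<M>(M model, double logLikelihood) { }

    /**
     * Applies fitter to each initial model, ignores attempts that throw
     * (for example because a covariance matrix became singular), and returns
     * the successful fit with the highest log-likelihood, if there is one.
     */
    static <I, M> Optional<Fit<M>> best(List<I> initials, Function<I, Fit<M>> fitter) {
        Fit<M> best = null;
        for (I initial : initials) {
            Fit<M> fit;
            try {
                fit = fitter.apply(initial);
            } catch (RuntimeException e) {
                continue; // this start failed; try the next initial model
            }
            if (best == null || fit.logLikelihood() > best.logLikelihood()) {
                best = fit;
            }
        }
        return Optional.ofNullable(best);
    }
}
```

Comparing log-likelihoods across starts is meaningful here because every start is fitted to the same data.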

estimate

public static MixtureMultivariateNormalDistribution estimate(double[][] data,
                                                             int numComponents)
                                                      throws NotStrictlyPositiveException,
                                                             DimensionMismatchException
Helper method to create a multivariate normal mixture model which can be used to initialize fit(MixtureMultivariateNormalDistribution). This static method uses the data passed as its data argument to try to determine a good mixture model at which to start the fit, but it is not guaranteed to supply a model from which the fit finds the optimal solution, or even converges.

Parameters:
data - Data to estimate distribution
numComponents - Number of components for estimated mixture
Returns:
Multivariate normal mixture model estimated from the data
Throws:
NumberIsTooLargeException - if numComponents is greater than the number of data rows.
NumberIsTooSmallException - if numComponents < 2.
NotStrictlyPositiveException - if data has fewer than 2 rows
DimensionMismatchException - if rows of data have different numbers of columns
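One simple way to build a starting mixture from the data alone works in three steps. Sort the rows, cut them into numComponents bins of roughly equal size, and seed each component with its bin's mean vector and an equal weight of 1/numComponents. The sketch below assumes that heuristic, sorting rows by their average value. It is not taken from estimate itself, and the names BinnedMixtureInit and Init are hypothetical. A full initial mixture also needs a covariance matrix per component, which the sketch omits for brevity.

```java
import java.util.Arrays;
import java.util.Comparator;

final class BinnedMixtureInit {

    /** Per-component starting values: equal weights and one mean vector per bin. */
    record Init(double[] weights, double[][] means) { }

    static double rowMean(double[] row) {
        double s = 0.0;
        for (double v : row) {
            s += v;
        }
        return s / row.length;
    }

    static Init estimate(double[][] data, int numComponents) {
        if (numComponents < 2 || numComponents > data.length) {
            throw new IllegalArgumentException("need 2 <= numComponents <= number of rows");
        }
        double[][] sorted = data.clone(); // shallow copy; the rows themselves are not modified
        Arrays.sort(sorted, Comparator.comparingDouble(BinnedMixtureInit::rowMean));
        int n = sorted.length, d = sorted[0].length;
        double[] weights = new double[numComponents];
        double[][] means = new double[numComponents][d];
        for (int b = 0; b < numComponents; b++) {
            // Bin b covers rows [b*n/k, (b+1)*n/k); every bin is non-empty since k <= n.
            int from = b * n / numComponents;
            int to = (b + 1) * n / numComponents;
            for (int i = from; i < to; i++) {
                for (int c = 0; c < d; c++) {
                    means[b][c] += sorted[i][c] / (to - from);
                }
            }
            weights[b] = 1.0 / numComponents;
        }
        return new Init(weights, means);
    }
}
```

For the rows {10, 10}, {0, 0}, {11, 11}, {1, 1} and two components, the bins are the two low rows and the two high rows, giving means (0.5, 0.5) and (10.5, 10.5) with weights 0.5 each.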

getLogLikelihood

public double getLogLikelihood()
Gets the log likelihood of the data under the fitted model.

Returns:
Log likelihood of the data, or zero if no data has been fit
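For illustration, the log-likelihood of a data set under a mixture is the sum over observations of the log of the mixture density. The univariate sketch below (hypothetical names MixtureLogLikelihood and logLikelihood) evaluates it with the log-sum-exp trick, so that points far from every component do not underflow to log(0). It computes the total over all observations; the value returned by getLogLikelihood may be normalized differently.

```java
final class MixtureLogLikelihood {

    static double logNormalPdf(double x, double mean, double variance) {
        double d = x - mean;
        return -0.5 * (Math.log(2 * Math.PI * variance) + d * d / variance);
    }

    /** Sum over data of log(sum_k w_k N(x | mu_k, sigma_k^2)), computed via log-sum-exp. */
    static double logLikelihood(double[] data, double[] weights, double[] means, double[] variances) {
        double total = 0.0;
        double[] terms = new double[weights.length];
        for (double x : data) {
            // terms[k] = log(w_k) + log N(x | mu_k, sigma_k^2); track the largest for stability.
            double max = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < weights.length; k++) {
                terms[k] = Math.log(weights[k]) + logNormalPdf(x, means[k], variances[k]);
                max = Math.max(max, terms[k]);
            }
            double sum = 0.0;
            for (double t : terms) {
                sum += Math.exp(t - max);
            }
            total += max + Math.log(sum);
        }
        return total;
    }
}
```

With a single standard normal component and the observation 0, the result is -0.5 log(2π). An observation at 1000 still yields a finite value, whereas the direct density would underflow to zero.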

getFittedModel

public MixtureMultivariateNormalDistribution getFittedModel()
Gets the fitted model.

Returns:
fitted model or null if no fit has been performed yet.


Copyright © 2003-2013 The Apache Software Foundation. All Rights Reserved.