MultivariateNormalMixtureExpectationMaximization.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.legacy.distribution.fitting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.math4.legacy.distribution.MixtureMultivariateNormalDistribution;
import org.apache.commons.math4.legacy.distribution.MultivariateNormalDistribution;
import org.apache.commons.math4.legacy.exception.ConvergenceException;
import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
import org.apache.commons.math4.legacy.linear.RealMatrix;
import org.apache.commons.math4.legacy.linear.SingularMatrixException;
import org.apache.commons.math4.legacy.stat.correlation.Covariance;
import org.apache.commons.math4.core.jdkmath.JdkMath;
import org.apache.commons.math4.legacy.core.MathArrays;
import org.apache.commons.math4.legacy.core.Pair;

/**
 * Expectation-Maximization algorithm for fitting the parameters of
 * multivariate normal mixture model distributions.
 *
 * This implementation is pure original code based on <a
 * href="https://www.ee.washington.edu/techsite/papers/documents/UWEETR-2010-0002.pdf">
 * EM Demystified: An Expectation-Maximization Tutorial</a> by Yihua Chen and Maya R. Gupta,
 * Department of Electrical Engineering, University of Washington, Seattle, WA 98195.
 * It was verified using external tools like <a
 * href="http://cran.r-project.org/web/packages/mixtools/index.html">CRAN Mixtools</a>
 * (see the JUnit test cases) but it is <strong>not</strong> based on Mixtools code at all.
 * The discussion of the origin of this class can be seen in the comments of the <a
 * href="https://issues.apache.org/jira/browse/MATH-817">MATH-817</a> JIRA issue.
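 *
 * <p>A typical usage sketch (the {@code samples} array below is a placeholder
 * for a caller-supplied {@code n x d} data set, not part of this API):</p>
 * <pre>{@code
 * // Derive an initial two-component mixture from the data itself.
 * MixtureMultivariateNormalDistribution initialMix =
 *     MultivariateNormalMixtureExpectationMaximization.estimate(samples, 2);
 *
 * // Run EM from that starting point with the default iteration limit and threshold.
 * MultivariateNormalMixtureExpectationMaximization fitter =
 *     new MultivariateNormalMixtureExpectationMaximization(samples);
 * fitter.fit(initialMix);
 *
 * // Inspect the result.
 * MixtureMultivariateNormalDistribution fitted = fitter.getFittedModel();
 * double averageLogLikelihood = fitter.getLogLikelihood();
 * }</pre>
 *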
 * @since 3.2
 */
public class MultivariateNormalMixtureExpectationMaximization {
    /**
     * Default maximum number of iterations allowed per fitting process.
     */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;
    /**
     * Default convergence threshold for fitting.
     */
    private static final double DEFAULT_THRESHOLD = 1E-5;
    /**
     * The data to fit.
     */
    private final double[][] data;
    /**
     * The model fit against the data.
     */
    private MixtureMultivariateNormalDistribution fittedModel;
    /**
     * The log likelihood of the data given the fitted model.
     */
    private double logLikelihood;

    /**
     * Creates an object to fit a multivariate normal mixture model to data.
     *
     * @param data Data to use in fitting procedure
     * @throws NotStrictlyPositiveException if data has no rows
     * @throws DimensionMismatchException if rows of data have different numbers
     *             of columns
     * @throws NumberIsTooSmallException if the number of columns in the data is
     *             less than 1
     */
    public MultivariateNormalMixtureExpectationMaximization(double[][] data)
        throws NotStrictlyPositiveException,
               DimensionMismatchException,
               NumberIsTooSmallException {
        if (data.length < 1) {
            throw new NotStrictlyPositiveException(data.length);
        }

        this.data = new double[data.length][data[0].length];

        for (int i = 0; i < data.length; i++) {
            if (data[i].length != data[0].length) {
                // Jagged arrays not allowed
                throw new DimensionMismatchException(data[i].length,
                                                     data[0].length);
            }
            if (data[i].length < 1) {
                throw new NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL,
                                                    data[i].length, 1, true);
            }
            this.data[i] = Arrays.copyOf(data[i], data[i].length);
        }
    }

    /**
     * Fit a mixture model to the data supplied to the constructor.
     *
     * The quality of the fit depends on the shape of the likelihood surface
     * defined by the data supplied to the constructor and on the initial
     * mixture provided to this function. If the likelihood has many local
     * optima, multiple runs of the fitting function with different initial
     * mixtures may be required to find the optimal solution. If a
     * SingularMatrixException is encountered, it is possible that another
     * initialization would work.
     *
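     * <p>Illustrative call with an explicit iteration cap and convergence
     * threshold ({@code fitter} and {@code initialMixture} are placeholder
     * names for objects created by the caller, as in the class-level example):</p>
     * <pre>{@code
     * fitter.fit(initialMixture, 500, 1e-6);
     * double averageLogLikelihood = fitter.getLogLikelihood();
     * }</pre>
     *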
     * @param initialMixture Model containing initial values of weights and
     *            multivariate normals
     * @param maxIterations Maximum iterations allowed for fit
     * @param threshold Convergence threshold: the fit stops once the change in
     *             log likelihood (averaged over rows) between successive
     *             iterations no longer exceeds this value
     * @throws SingularMatrixException if any component's covariance matrix is
     *             singular during fitting
     * @throws NotStrictlyPositiveException if maxIterations is less than one
     *             or threshold is less than Double.MIN_VALUE
     * @throws DimensionMismatchException if initialMixture mean vector and data
     *             number of columns are not equal
     * @throws ConvergenceException if the fit has not converged after
     *             maxIterations iterations
     */
    public void fit(final MixtureMultivariateNormalDistribution initialMixture,
                    final int maxIterations,
                    final double threshold)
            throws SingularMatrixException,
                   NotStrictlyPositiveException,
                   DimensionMismatchException {
        if (maxIterations < 1) {
            throw new NotStrictlyPositiveException(maxIterations);
        }

        if (threshold < Double.MIN_VALUE) {
            throw new NotStrictlyPositiveException(threshold);
        }

        final int n = data.length;

        // Number of data columns. Jagged data already rejected in constructor,
        // so we can assume the lengths of each row are equal.
        final int numCols = data[0].length;
        final int k = initialMixture.getComponents().size();

        final int numMeanColumns
            = initialMixture.getComponents().get(0).getSecond().getMeans().length;

        if (numMeanColumns != numCols) {
            throw new DimensionMismatchException(numMeanColumns, numCols);
        }

        int numIterations = 0;
        double previousLogLikelihood = 0d;

        logLikelihood = Double.NEGATIVE_INFINITY;

        // Initialize model to fit to initial mixture.
        fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());

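        // Alternate E and M steps, stopping once the change in (average)
        // log likelihood between successive iterations no longer exceeds the
        // threshold, or the iteration limit is exceeded.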
        while (numIterations++ <= maxIterations &&
               JdkMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
            previousLogLikelihood = logLikelihood;
            double sumLogLikelihood = 0d;

            // Mixture components
            final List<Pair<Double, MultivariateNormalDistribution>> components
                = fittedModel.getComponents();

            // Weight and distribution of each component
            final double[] weights = new double[k];

            final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];

            for (int j = 0; j < k; j++) {
                weights[j] = components.get(j).getFirst();
                mvns[j] = components.get(j).getSecond();
            }
            // E-step: compute the data dependent parameters of the expectation
            // function.
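            // For row i and component j, the responsibility
            //   gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity
            // is the posterior probability that row i was generated by
            // component j (rowDensity being the full mixture density of row i).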
            // Fraction of each row's total density attributed to each
            // component
            final double[][] gamma = new double[n][k];

            // Sum of gamma for each component
            final double[] gammaSums = new double[k];

            // Sum of gamma times its row, for each component
            final double[][] gammaDataProdSums = new double[k][numCols];

            for (int i = 0; i < n; i++) {
                final double rowDensity = fittedModel.density(data[i]);
                sumLogLikelihood += JdkMath.log(rowDensity);

                for (int j = 0; j < k; j++) {
                    gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
                    gammaSums[j] += gamma[i][j];

                    for (int col = 0; col < numCols; col++) {
                        gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
                    }
                }
            }

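            // Store the average log likelihood per row; the convergence test
            // in the loop condition therefore compares per-row averages.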
            logLikelihood = sumLogLikelihood / n;

            // M-step: compute the new parameters based on the expectation
            // function.
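            // New weight: the fraction of total responsibility assigned to the
            // component. New mean: the responsibility-weighted average of the rows.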
            final double[] newWeights = new double[k];
            final double[][] newMeans = new double[k][numCols];

            for (int j = 0; j < k; j++) {
                newWeights[j] = gammaSums[j] / n;
                for (int col = 0; col < numCols; col++) {
                    newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
                }
            }

            // Compute new covariance matrices
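            // Each new covariance matrix accumulates the responsibility-weighted
            // outer products (x - newMean)(x - newMean)^T over all rows; it is
            // normalized below by the component's total responsibility.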
            final RealMatrix[] newCovMats = new RealMatrix[k];
            for (int j = 0; j < k; j++) {
                newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
            }
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < k; j++) {
                    final RealMatrix vec
                        = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
                    final RealMatrix dataCov
                        = vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
                    newCovMats[j] = newCovMats[j].add(dataCov);
                }
            }

            // Converting to arrays for use by fitted model
            final double[][][] newCovMatArrays = new double[k][numCols][numCols];
            for (int j = 0; j < k; j++) {
                newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
                newCovMatArrays[j] = newCovMats[j].getData();
            }

            // Update current model
            fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
                                                                    newMeans,
                                                                    newCovMatArrays);
        }

        if (JdkMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
            // Did not converge before the maximum number of iterations
            throw new ConvergenceException();
        }
    }

    /**
     * Fit a mixture model to the data supplied to the constructor.
     *
     * The quality of the fit depends on the shape of the likelihood surface
     * defined by the data supplied to the constructor and on the initial
     * mixture provided to this function. If the likelihood has many local
     * optima, multiple runs of the fitting function with different initial
     * mixtures may be required to find the optimal solution. If a
     * SingularMatrixException is encountered, it is possible that another
     * initialization would work.
     *
     * @param initialMixture Model containing initial values of weights and
     *            multivariate normals
     * @throws SingularMatrixException if any component's covariance matrix is
     *             singular during fitting
     * @throws NotStrictlyPositiveException if the default threshold is less
     *             than Double.MIN_VALUE
     * @throws ConvergenceException if the fit has not converged after the
     *             default maximum number of iterations
     */
    public void fit(MixtureMultivariateNormalDistribution initialMixture)
        throws SingularMatrixException,
               NotStrictlyPositiveException {
        fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    /**
     * Helper method to create a multivariate normal mixture model which can be
     * used to initialize {@link #fit(MixtureMultivariateNormalDistribution)}.
     *
     * This method uses the supplied data to try to determine a good mixture
     * model at which to start the fit, but it is not guaranteed to supply a
     * model which will find the optimal solution or even converge.
     *
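     * <p>Illustrative usage ({@code samples} is a placeholder for the caller's
     * data array):</p>
     * <pre>{@code
     * MixtureMultivariateNormalDistribution initialMix =
     *     MultivariateNormalMixtureExpectationMaximization.estimate(samples, 3);
     * new MultivariateNormalMixtureExpectationMaximization(samples).fit(initialMix);
     * }</pre>
     *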
     * @param data Data to estimate distribution
     * @param numComponents Number of components for estimated mixture
     * @return Multivariate normal mixture model estimated from the data
     * @throws NumberIsTooLargeException if {@code numComponents} is greater
     * than the number of data rows.
     * @throws NumberIsTooSmallException if {@code numComponents < 1}.
     * @throws NotStrictlyPositiveException if data has fewer than 2 rows
     * @throws DimensionMismatchException if rows of data have different numbers
     *             of columns
     */
    public static MixtureMultivariateNormalDistribution estimate(final double[][] data,
                                                                 final int numComponents)
        throws NotStrictlyPositiveException,
               DimensionMismatchException {
        if (data.length < 2) {
            throw new NotStrictlyPositiveException(data.length);
        }
        if (numComponents < 1) {
            throw new NumberIsTooSmallException(numComponents, 1, true);
        }
        if (numComponents > data.length) {
            throw new NumberIsTooLargeException(numComponents, data.length, true);
        }

        final int numRows = data.length;
        final int numCols = data[0].length;

        // sort the data
        final DataRow[] sortedData = new DataRow[numRows];
        for (int i = 0; i < numRows; i++) {
            sortedData[i] = new DataRow(data[i]);
        }
        Arrays.sort(sortedData);
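        // The rows, now sorted by row mean, are split below into numComponents
        // contiguous bins of (nearly) equal size; each bin's sample mean and
        // covariance define one mixture component with uniform weight.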

        // uniform weight for each bin
        final double weight = 1d / numComponents;

        // components of mixture model to be created
        final List<Pair<Double, MultivariateNormalDistribution>> components =
                new ArrayList<>(numComponents);

        // create a component based on data in each bin
        for (int binIndex = 0; binIndex < numComponents; binIndex++) {
            // minimum index (inclusive) from sorted data for this bin
            final int minIndex = (binIndex * numRows) / numComponents;

            // maximum index (exclusive) from sorted data for this bin
            final int maxIndex = ((binIndex + 1) * numRows) / numComponents;

            // number of data records that will be in this bin
            final int numBinRows = maxIndex - minIndex;

            // data for this bin
            final double[][] binData = new double[numBinRows][numCols];

            // mean of each column for the data in this bin
            final double[] columnMeans = new double[numCols];

            // populate bin and create component
            for (int i = minIndex, iBin = 0; i < maxIndex; i++, iBin++) {
                for (int j = 0; j < numCols; j++) {
                    final double val = sortedData[i].getRow()[j];
                    columnMeans[j] += val;
                    binData[iBin][j] = val;
                }
            }

            MathArrays.scaleInPlace(1d / numBinRows, columnMeans);

            // covariance matrix for this bin
            final double[][] covMat
                = new Covariance(binData).getCovarianceMatrix().getData();
            final MultivariateNormalDistribution mvn
                = new MultivariateNormalDistribution(columnMeans, covMat);

            components.add(new Pair<>(weight, mvn));
        }

        return new MixtureMultivariateNormalDistribution(components);
    }

    /**
     * Gets the log likelihood of the data under the fitted model.
     *
     * @return Log likelihood of data or zero if no fit has been performed
     */
    public double getLogLikelihood() {
        return logLikelihood;
    }

    /**
     * Gets the fitted model.
     *
     * @return fitted model or {@code null} if no fit has been performed yet.
     */
    public MixtureMultivariateNormalDistribution getFittedModel() {
        return fittedModel == null ?
            null :
            new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
    }

    /**
     * Class used for sorting user-supplied data.
     */
    private static final class DataRow implements Comparable<DataRow> {
        /** One data row. */
        private final double[] row;
        /** Mean of the data row. */
        private final double mean;

        /**
         * Create a data row.
         * @param data Data to use for the row
         */
        DataRow(final double[] data) {
            // Store reference.
            row = data;
            // Compute mean.
            double sum = 0d;
            for (int i = 0; i < data.length; i++) {
                sum += data[i];
            }
            mean = sum / data.length;
        }

        /**
         * Compare two data rows.
         * @param other The other row
         * @return int for sorting
         */
        @Override
        public int compareTo(final DataRow other) {
            return Double.compare(mean, other.mean);
        }

        /** {@inheritDoc} */
        @Override
        public boolean equals(Object other) {

            if (this == other) {
                return true;
            }

            if (other instanceof DataRow) {
                return MathArrays.equals(row, ((DataRow) other).row);
            }

            return false;
        }

        /** {@inheritDoc} */
        @Override
        public int hashCode() {
            return Arrays.hashCode(row);
        }

        /**
         * Get a data row.
         * @return data row array
         */
        public double[] getRow() {
            return row;
        }
    }
}