CholeskyDecomposition.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.linear;

  18. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  19. import org.apache.commons.math4.core.jdkmath.JdkMath;


  20. /**
  21.  * Calculates the Cholesky decomposition of a matrix.
  22.  * <p>The Cholesky decomposition of a real symmetric positive-definite
  23.  * matrix A consists of a lower triangular matrix L with same size such
  24.  * that: A = LL<sup>T</sup>. In a sense, this is the square root of A.</p>
  25.  * <p>This class is based on the class with similar name from the
  26.  * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
  27.  * following changes:</p>
  28.  * <ul>
  29.  *   <li>a {@link #getLT() getLT} method has been added,</li>
  30.  *   <li>the {@code isspd} method has been removed, since the constructor of
  31.  *   this class throws a {@link NonPositiveDefiniteMatrixException} when a
  32.  *   matrix cannot be decomposed,</li>
  33.  *   <li>a {@link #getDeterminant() getDeterminant} method has been added,</li>
  34.  *   <li>the {@code solve} method has been replaced by a {@link #getSolver()
  35.  *   getSolver} method and the equivalent method provided by the returned
  36.  *   {@link DecompositionSolver}.</li>
  37.  * </ul>
  38.  *
  39.  * @see <a href="http://mathworld.wolfram.com/CholeskyDecomposition.html">MathWorld</a>
  40.  * @see <a href="http://en.wikipedia.org/wiki/Cholesky_decomposition">Wikipedia</a>
  41.  * @since 2.0 (changed to concrete class in 3.0)
  42.  */
  43. public class CholeskyDecomposition {
  44.     /**
  45.      * Default threshold above which off-diagonal elements are considered too different
  46.      * and matrix not symmetric.
  47.      */
  48.     public static final double DEFAULT_RELATIVE_SYMMETRY_THRESHOLD = 1.0e-15;
  49.     /**
  50.      * Default threshold below which diagonal elements are considered null
  51.      * and matrix not positive definite.
  52.      */
  53.     public static final double DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD = 1.0e-10;
  54.     /** Row-oriented storage for L<sup>T</sup> matrix data. */
  55.     private final double[][] lTData;
  56.     /** Cached value of L. */
  57.     private RealMatrix cachedL;
  58.     /** Cached value of LT. */
  59.     private RealMatrix cachedLT;

  60.     /**
  61.      * Calculates the Cholesky decomposition of the given matrix.
  62.      * <p>
  63.      * Calling this constructor is equivalent to call {@link
  64.      * #CholeskyDecomposition(RealMatrix, double, double)} with the
  65.      * thresholds set to the default values {@link
  66.      * #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD} and {@link
  67.      * #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD}
  68.      * </p>
  69.      * @param matrix the matrix to decompose
  70.      * @throws NonSquareMatrixException if the matrix is not square.
  71.      * @throws NonSymmetricMatrixException if the matrix is not symmetric.
  72.      * @throws NonPositiveDefiniteMatrixException if the matrix is not
  73.      * strictly positive definite.
  74.      * @see #CholeskyDecomposition(RealMatrix, double, double)
  75.      * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
  76.      * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
  77.      */
  78.     public CholeskyDecomposition(final RealMatrix matrix) {
  79.         this(matrix, DEFAULT_RELATIVE_SYMMETRY_THRESHOLD,
  80.              DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD);
  81.     }

  82.     /**
  83.      * Calculates the Cholesky decomposition of the given matrix.
  84.      * @param matrix the matrix to decompose
  85.      * @param relativeSymmetryThreshold threshold above which off-diagonal
  86.      * elements are considered too different and matrix not symmetric
  87.      * @param absolutePositivityThreshold threshold below which diagonal
  88.      * elements are considered null and matrix not positive definite
  89.      * @throws NonSquareMatrixException if the matrix is not square.
  90.      * @throws NonSymmetricMatrixException if the matrix is not symmetric.
  91.      * @throws NonPositiveDefiniteMatrixException if the matrix is not
  92.      * strictly positive definite.
  93.      * @see #CholeskyDecomposition(RealMatrix)
  94.      * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
  95.      * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
  96.      */
  97.     public CholeskyDecomposition(final RealMatrix matrix,
  98.                                      final double relativeSymmetryThreshold,
  99.                                      final double absolutePositivityThreshold) {
  100.         if (!matrix.isSquare()) {
  101.             throw new NonSquareMatrixException(matrix.getRowDimension(),
  102.                                                matrix.getColumnDimension());
  103.         }

  104.         final int order = matrix.getRowDimension();
  105.         lTData   = matrix.getData();
  106.         cachedL  = null;
  107.         cachedLT = null;

  108.         // check the matrix before transformation
  109.         for (int i = 0; i < order; ++i) {
  110.             final double[] lI = lTData[i];

  111.             // check off-diagonal elements (and reset them to 0)
  112.             for (int j = i + 1; j < order; ++j) {
  113.                 final double[] lJ = lTData[j];
  114.                 final double lIJ = lI[j];
  115.                 final double lJI = lJ[i];
  116.                 final double maxDelta =
  117.                     relativeSymmetryThreshold * JdkMath.max(JdkMath.abs(lIJ), JdkMath.abs(lJI));
  118.                 if (JdkMath.abs(lIJ - lJI) > maxDelta) {
  119.                     throw new NonSymmetricMatrixException(i, j, relativeSymmetryThreshold);
  120.                 }
  121.                 lJ[i] = 0;
  122.            }
  123.         }

  124.         // transform the matrix
  125.         for (int i = 0; i < order; ++i) {

  126.             final double[] ltI = lTData[i];

  127.             // check diagonal element
  128.             if (ltI[i] <= absolutePositivityThreshold) {
  129.                 throw new NonPositiveDefiniteMatrixException(ltI[i], i, absolutePositivityThreshold);
  130.             }

  131.             ltI[i] = JdkMath.sqrt(ltI[i]);
  132.             final double inverse = 1.0 / ltI[i];

  133.             for (int q = order - 1; q > i; --q) {
  134.                 ltI[q] *= inverse;
  135.                 final double[] ltQ = lTData[q];
  136.                 for (int p = q; p < order; ++p) {
  137.                     ltQ[p] -= ltI[q] * ltI[p];
  138.                 }
  139.             }
  140.         }
  141.     }

  142.     /**
  143.      * Returns the matrix L of the decomposition.
  144.      * <p>L is an lower-triangular matrix</p>
  145.      * @return the L matrix
  146.      */
  147.     public RealMatrix getL() {
  148.         if (cachedL == null) {
  149.             cachedL = getLT().transpose();
  150.         }
  151.         return cachedL;
  152.     }

  153.     /**
  154.      * Returns the transpose of the matrix L of the decomposition.
  155.      * <p>L<sup>T</sup> is an upper-triangular matrix</p>
  156.      * @return the transpose of the matrix L of the decomposition
  157.      */
  158.     public RealMatrix getLT() {

  159.         if (cachedLT == null) {
  160.             cachedLT = MatrixUtils.createRealMatrix(lTData);
  161.         }

  162.         // return the cached matrix
  163.         return cachedLT;
  164.     }

  165.     /**
  166.      * Return the determinant of the matrix.
  167.      * @return determinant of the matrix
  168.      */
  169.     public double getDeterminant() {
  170.         double determinant = 1.0;
  171.         for (int i = 0; i < lTData.length; ++i) {
  172.             double lTii = lTData[i][i];
  173.             determinant *= lTii * lTii;
  174.         }
  175.         return determinant;
  176.     }

  177.     /**
  178.      * Get a solver for finding the A &times; X = B solution in least square sense.
  179.      * @return a solver
  180.      */
  181.     public DecompositionSolver getSolver() {
  182.         return new Solver(lTData);
  183.     }

  184.     /** Specialized solver. */
  185.     private static final class Solver implements DecompositionSolver {
  186.         /** Row-oriented storage for L<sup>T</sup> matrix data. */
  187.         private final double[][] lTData;

  188.         /**
  189.          * Build a solver from decomposed matrix.
  190.          * @param lTData row-oriented storage for L<sup>T</sup> matrix data
  191.          */
  192.         private Solver(final double[][] lTData) {
  193.             this.lTData = lTData;
  194.         }

  195.         /** {@inheritDoc} */
  196.         @Override
  197.         public boolean isNonSingular() {
  198.             // if we get this far, the matrix was positive definite, hence non-singular
  199.             return true;
  200.         }

  201.         /** {@inheritDoc} */
  202.         @Override
  203.         public RealVector solve(final RealVector b) {
  204.             final int m = lTData.length;
  205.             if (b.getDimension() != m) {
  206.                 throw new DimensionMismatchException(b.getDimension(), m);
  207.             }

  208.             final double[] x = b.toArray();

  209.             // Solve LY = b
  210.             for (int j = 0; j < m; j++) {
  211.                 final double[] lJ = lTData[j];
  212.                 x[j] /= lJ[j];
  213.                 final double xJ = x[j];
  214.                 for (int i = j + 1; i < m; i++) {
  215.                     x[i] -= xJ * lJ[i];
  216.                 }
  217.             }

  218.             // Solve LTX = Y
  219.             for (int j = m - 1; j >= 0; j--) {
  220.                 x[j] /= lTData[j][j];
  221.                 final double xJ = x[j];
  222.                 for (int i = 0; i < j; i++) {
  223.                     x[i] -= xJ * lTData[i][j];
  224.                 }
  225.             }

  226.             return new ArrayRealVector(x, false);
  227.         }

  228.         /** {@inheritDoc} */
  229.         @Override
  230.         public RealMatrix solve(RealMatrix b) {
  231.             final int m = lTData.length;
  232.             if (b.getRowDimension() != m) {
  233.                 throw new DimensionMismatchException(b.getRowDimension(), m);
  234.             }

  235.             final int nColB = b.getColumnDimension();
  236.             final double[][] x = b.getData();

  237.             // Solve LY = b
  238.             for (int j = 0; j < m; j++) {
  239.                 final double[] lJ = lTData[j];
  240.                 final double lJJ = lJ[j];
  241.                 final double[] xJ = x[j];
  242.                 for (int k = 0; k < nColB; ++k) {
  243.                     xJ[k] /= lJJ;
  244.                 }
  245.                 for (int i = j + 1; i < m; i++) {
  246.                     final double[] xI = x[i];
  247.                     final double lJI = lJ[i];
  248.                     for (int k = 0; k < nColB; ++k) {
  249.                         xI[k] -= xJ[k] * lJI;
  250.                     }
  251.                 }
  252.             }

  253.             // Solve LTX = Y
  254.             for (int j = m - 1; j >= 0; j--) {
  255.                 final double lJJ = lTData[j][j];
  256.                 final double[] xJ = x[j];
  257.                 for (int k = 0; k < nColB; ++k) {
  258.                     xJ[k] /= lJJ;
  259.                 }
  260.                 for (int i = 0; i < j; i++) {
  261.                     final double[] xI = x[i];
  262.                     final double lIJ = lTData[i][j];
  263.                     for (int k = 0; k < nColB; ++k) {
  264.                         xI[k] -= xJ[k] * lIJ;
  265.                     }
  266.                 }
  267.             }

  268.             return new Array2DRowRealMatrix(x);
  269.         }

  270.         /**
  271.          * Get the inverse of the decomposed matrix.
  272.          *
  273.          * @return the inverse matrix.
  274.          */
  275.         @Override
  276.         public RealMatrix getInverse() {
  277.             return solve(MatrixUtils.createRealIdentityMatrix(lTData.length));
  278.         }
  279.     }
  280. }