LUDecomposition.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.linear;

  18. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  19. import org.apache.commons.math4.core.jdkmath.JdkMath;

  20. /**
  21.  * Calculates the LUP-decomposition of a square matrix.
  22.  * <p>The LUP-decomposition of a matrix A consists of three matrices L, U and
  23.  * P that satisfy: P&times;A = L&times;U. L is lower triangular (with unit
  24.  * diagonal terms), U is upper triangular and P is a permutation matrix. All
  25.  * matrices are m&times;m.</p>
  26.  * <p>As shown by the presence of the P matrix, this decomposition is
  27.  * implemented using partial pivoting.</p>
  28.  * <p>This class is based on the class with similar name from the
  29.  * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
  30.  * <ul>
  31.  *   <li>a {@link #getP() getP} method has been added,</li>
  32.  *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
  33.  *   getDeterminant},</li>
  34.  *   <li>the {@code getDoublePivot} method has been removed (but the int based
  35.  *   {@link #getPivot() getPivot} method has been kept),</li>
  36.  *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
  37.  *   by a {@link #getSolver() getSolver} method and the equivalent methods
  38.  *   provided by the returned {@link DecompositionSolver}.</li>
  39.  * </ul>
  40.  *
  41.  * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
  42.  * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
  43.  * @since 2.0 (changed to concrete class in 3.0)
  44.  */
  45. public class LUDecomposition {
  46.     /** Default bound to determine effective singularity in LU decomposition. */
  47.     private static final double DEFAULT_TOO_SMALL = 1e-11;
  48.     /** Entries of LU decomposition. */
  49.     private final double[][] lu;
  50.     /** Pivot permutation associated with LU decomposition. */
  51.     private final int[] pivot;
  52.     /** Parity of the permutation associated with the LU decomposition. */
  53.     private boolean even;
  54.     /** Singularity indicator. */
  55.     private boolean singular;
  56.     /** Cached value of L. */
  57.     private RealMatrix cachedL;
  58.     /** Cached value of U. */
  59.     private RealMatrix cachedU;
  60.     /** Cached value of P. */
  61.     private RealMatrix cachedP;

  62.     /**
  63.      * Calculates the LU-decomposition of the given matrix.
  64.      * This constructor uses 1e-11 as default value for the singularity
  65.      * threshold.
  66.      *
  67.      * @param matrix Matrix to decompose.
  68.      * @throws NonSquareMatrixException if matrix is not square.
  69.      */
  70.     public LUDecomposition(RealMatrix matrix) {
  71.         this(matrix, DEFAULT_TOO_SMALL);
  72.     }

  73.     /**
  74.      * Calculates the LU-decomposition of the given matrix.
  75.      * @param matrix The matrix to decompose.
  76.      * @param singularityThreshold threshold (based on partial row norm)
  77.      * under which a matrix is considered singular
  78.      * @throws NonSquareMatrixException if matrix is not square
  79.      */
  80.     public LUDecomposition(RealMatrix matrix, double singularityThreshold) {
  81.         if (!matrix.isSquare()) {
  82.             throw new NonSquareMatrixException(matrix.getRowDimension(),
  83.                                                matrix.getColumnDimension());
  84.         }

  85.         final int m = matrix.getColumnDimension();
  86.         lu = matrix.getData();
  87.         pivot = new int[m];
  88.         cachedL = null;
  89.         cachedU = null;
  90.         cachedP = null;

  91.         // Initialize permutation array and parity
  92.         for (int row = 0; row < m; row++) {
  93.             pivot[row] = row;
  94.         }
  95.         even     = true;
  96.         singular = false;

  97.         // Loop over columns
  98.         for (int col = 0; col < m; col++) {

  99.             // upper
  100.             for (int row = 0; row < col; row++) {
  101.                 final double[] luRow = lu[row];
  102.                 double sum = luRow[col];
  103.                 for (int i = 0; i < row; i++) {
  104.                     sum -= luRow[i] * lu[i][col];
  105.                 }
  106.                 luRow[col] = sum;
  107.             }

  108.             // lower
  109.             int max = col; // permutation row
  110.             double largest = Double.NEGATIVE_INFINITY;
  111.             for (int row = col; row < m; row++) {
  112.                 final double[] luRow = lu[row];
  113.                 double sum = luRow[col];
  114.                 for (int i = 0; i < col; i++) {
  115.                     sum -= luRow[i] * lu[i][col];
  116.                 }
  117.                 luRow[col] = sum;

  118.                 // maintain best permutation choice
  119.                 if (JdkMath.abs(sum) > largest) {
  120.                     largest = JdkMath.abs(sum);
  121.                     max = row;
  122.                 }
  123.             }

  124.             // Singularity check
  125.             if (JdkMath.abs(lu[max][col]) < singularityThreshold) {
  126.                 singular = true;
  127.                 return;
  128.             }

  129.             // Pivot if necessary
  130.             if (max != col) {
  131.                 double tmp = 0;
  132.                 final double[] luMax = lu[max];
  133.                 final double[] luCol = lu[col];
  134.                 for (int i = 0; i < m; i++) {
  135.                     tmp = luMax[i];
  136.                     luMax[i] = luCol[i];
  137.                     luCol[i] = tmp;
  138.                 }
  139.                 int temp = pivot[max];
  140.                 pivot[max] = pivot[col];
  141.                 pivot[col] = temp;
  142.                 even = !even;
  143.             }

  144.             // Divide the lower elements by the "winning" diagonal elt.
  145.             final double luDiag = lu[col][col];
  146.             for (int row = col + 1; row < m; row++) {
  147.                 lu[row][col] /= luDiag;
  148.             }
  149.         }
  150.     }

  151.     /**
  152.      * Returns the matrix L of the decomposition.
  153.      * <p>L is a lower-triangular matrix</p>
  154.      * @return the L matrix (or null if decomposed matrix is singular)
  155.      */
  156.     public RealMatrix getL() {
  157.         if (cachedL == null && !singular) {
  158.             final int m = pivot.length;
  159.             cachedL = MatrixUtils.createRealMatrix(m, m);
  160.             for (int i = 0; i < m; ++i) {
  161.                 final double[] luI = lu[i];
  162.                 for (int j = 0; j < i; ++j) {
  163.                     cachedL.setEntry(i, j, luI[j]);
  164.                 }
  165.                 cachedL.setEntry(i, i, 1.0);
  166.             }
  167.         }
  168.         return cachedL;
  169.     }

  170.     /**
  171.      * Returns the matrix U of the decomposition.
  172.      * <p>U is an upper-triangular matrix</p>
  173.      * @return the U matrix (or null if decomposed matrix is singular)
  174.      */
  175.     public RealMatrix getU() {
  176.         if (cachedU == null && !singular) {
  177.             final int m = pivot.length;
  178.             cachedU = MatrixUtils.createRealMatrix(m, m);
  179.             for (int i = 0; i < m; ++i) {
  180.                 final double[] luI = lu[i];
  181.                 for (int j = i; j < m; ++j) {
  182.                     cachedU.setEntry(i, j, luI[j]);
  183.                 }
  184.             }
  185.         }
  186.         return cachedU;
  187.     }

  188.     /**
  189.      * Returns the P rows permutation matrix.
  190.      * <p>P is a sparse matrix with exactly one element set to 1.0 in
  191.      * each row and each column, all other elements being set to 0.0.</p>
  192.      * <p>The positions of the 1 elements are given by the {@link #getPivot()
  193.      * pivot permutation vector}.</p>
  194.      * @return the P rows permutation matrix (or null if decomposed matrix is singular)
  195.      * @see #getPivot()
  196.      */
  197.     public RealMatrix getP() {
  198.         if (cachedP == null && !singular) {
  199.             final int m = pivot.length;
  200.             cachedP = MatrixUtils.createRealMatrix(m, m);
  201.             for (int i = 0; i < m; ++i) {
  202.                 cachedP.setEntry(i, pivot[i], 1.0);
  203.             }
  204.         }
  205.         return cachedP;
  206.     }

  207.     /**
  208.      * Returns the pivot permutation vector.
  209.      * @return the pivot permutation vector
  210.      * @see #getP()
  211.      */
  212.     public int[] getPivot() {
  213.         return pivot.clone();
  214.     }

  215.     /**
  216.      * Return the determinant of the matrix.
  217.      * @return determinant of the matrix
  218.      */
  219.     public double getDeterminant() {
  220.         if (singular) {
  221.             return 0;
  222.         } else {
  223.             final int m = pivot.length;
  224.             double determinant = even ? 1 : -1;
  225.             for (int i = 0; i < m; i++) {
  226.                 determinant *= lu[i][i];
  227.             }
  228.             return determinant;
  229.         }
  230.     }

  231.     /**
  232.      * Get a solver for finding the A &times; X = B solution in exact linear
  233.      * sense.
  234.      * @return a solver
  235.      */
  236.     public DecompositionSolver getSolver() {
  237.         return new Solver(lu, pivot, singular);
  238.     }

  239.     /** Specialized solver. */
  240.     private static final class Solver implements DecompositionSolver {

  241.         /** Entries of LU decomposition. */
  242.         private final double[][] lu;

  243.         /** Pivot permutation associated with LU decomposition. */
  244.         private final int[] pivot;

  245.         /** Singularity indicator. */
  246.         private final boolean singular;

  247.         /**
  248.          * Build a solver from decomposed matrix.
  249.          * @param lu entries of LU decomposition
  250.          * @param pivot pivot permutation associated with LU decomposition
  251.          * @param singular singularity indicator
  252.          */
  253.         private Solver(final double[][] lu, final int[] pivot, final boolean singular) {
  254.             this.lu       = lu;
  255.             this.pivot    = pivot;
  256.             this.singular = singular;
  257.         }

  258.         /** {@inheritDoc} */
  259.         @Override
  260.         public boolean isNonSingular() {
  261.             return !singular;
  262.         }

  263.         /** {@inheritDoc} */
  264.         @Override
  265.         public RealVector solve(RealVector b) {
  266.             final int m = pivot.length;
  267.             if (b.getDimension() != m) {
  268.                 throw new DimensionMismatchException(b.getDimension(), m);
  269.             }
  270.             if (singular) {
  271.                 throw new SingularMatrixException();
  272.             }

  273.             final double[] bp = new double[m];

  274.             // Apply permutations to b
  275.             for (int row = 0; row < m; row++) {
  276.                 bp[row] = b.getEntry(pivot[row]);
  277.             }

  278.             // Solve LY = b
  279.             for (int col = 0; col < m; col++) {
  280.                 final double bpCol = bp[col];
  281.                 for (int i = col + 1; i < m; i++) {
  282.                     bp[i] -= bpCol * lu[i][col];
  283.                 }
  284.             }

  285.             // Solve UX = Y
  286.             for (int col = m - 1; col >= 0; col--) {
  287.                 bp[col] /= lu[col][col];
  288.                 final double bpCol = bp[col];
  289.                 for (int i = 0; i < col; i++) {
  290.                     bp[i] -= bpCol * lu[i][col];
  291.                 }
  292.             }

  293.             return new ArrayRealVector(bp, false);
  294.         }

  295.         /** {@inheritDoc} */
  296.         @Override
  297.         public RealMatrix solve(RealMatrix b) {

  298.             final int m = pivot.length;
  299.             if (b.getRowDimension() != m) {
  300.                 throw new DimensionMismatchException(b.getRowDimension(), m);
  301.             }
  302.             if (singular) {
  303.                 throw new SingularMatrixException();
  304.             }

  305.             final int nColB = b.getColumnDimension();

  306.             // Apply permutations to b
  307.             final double[][] bp = new double[m][nColB];
  308.             for (int row = 0; row < m; row++) {
  309.                 final double[] bpRow = bp[row];
  310.                 final int pRow = pivot[row];
  311.                 for (int col = 0; col < nColB; col++) {
  312.                     bpRow[col] = b.getEntry(pRow, col);
  313.                 }
  314.             }

  315.             // Solve LY = b
  316.             for (int col = 0; col < m; col++) {
  317.                 final double[] bpCol = bp[col];
  318.                 for (int i = col + 1; i < m; i++) {
  319.                     final double[] bpI = bp[i];
  320.                     final double luICol = lu[i][col];
  321.                     for (int j = 0; j < nColB; j++) {
  322.                         bpI[j] -= bpCol[j] * luICol;
  323.                     }
  324.                 }
  325.             }

  326.             // Solve UX = Y
  327.             for (int col = m - 1; col >= 0; col--) {
  328.                 final double[] bpCol = bp[col];
  329.                 final double luDiag = lu[col][col];
  330.                 for (int j = 0; j < nColB; j++) {
  331.                     bpCol[j] /= luDiag;
  332.                 }
  333.                 for (int i = 0; i < col; i++) {
  334.                     final double[] bpI = bp[i];
  335.                     final double luICol = lu[i][col];
  336.                     for (int j = 0; j < nColB; j++) {
  337.                         bpI[j] -= bpCol[j] * luICol;
  338.                     }
  339.                 }
  340.             }

  341.             return new Array2DRowRealMatrix(bp, false);
  342.         }

  343.         /**
  344.          * Get the inverse of the decomposed matrix.
  345.          *
  346.          * @return the inverse matrix.
  347.          * @throws SingularMatrixException if the decomposed matrix is singular.
  348.          */
  349.         @Override
  350.         public RealMatrix getInverse() {
  351.             return solve(MatrixUtils.createRealIdentityMatrix(pivot.length));
  352.         }
  353.     }
  354. }