FieldLUDecomposition.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

  17. package org.apache.commons.math4.legacy.linear;

  18. import org.apache.commons.math4.legacy.core.Field;
  19. import org.apache.commons.math4.legacy.core.FieldElement;
  20. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  21. import org.apache.commons.math4.legacy.core.MathArrays;

  22. /**
  23.  * Calculates the LUP-decomposition of a square matrix.
  24.  * <p>The LUP-decomposition of a matrix A consists of three matrices
  25.  * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
  26.  * upper triangular and P is a permutation matrix. All matrices are
  27.  * m&times;m.</p>
  28.  * <p>Since {@link FieldElement field elements} do not provide an ordering
  29.  * operator, the permutation matrix is computed here only in order to avoid
  30.  * a zero pivot element, no attempt is done to get the largest pivot
  31.  * element.</p>
  32.  * <p>This class is based on the class with similar name from the
  33.  * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
  34.  * <ul>
  35.  *   <li>a {@link #getP() getP} method has been added,</li>
  36.  *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
  37.  *   getDeterminant},</li>
  38.  *   <li>the {@code getDoublePivot} method has been removed (but the int based
  39.  *   {@link #getPivot() getPivot} method has been kept),</li>
  40.  *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
  41.  *   by a {@link #getSolver() getSolver} method and the equivalent methods
  42.  *   provided by the returned {@link DecompositionSolver}.</li>
  43.  * </ul>
  44.  *
  45.  * @param <T> the type of the field elements
  46.  * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
  47.  * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
  48.  * @since 2.0 (changed to concrete class in 3.0)
  49.  */
  50. public class FieldLUDecomposition<T extends FieldElement<T>> {

  51.     /** Field to which the elements belong. */
  52.     private final Field<T> field;

  53.     /** Entries of LU decomposition. */
  54.     private T[][] lu;

  55.     /** Pivot permutation associated with LU decomposition. */
  56.     private int[] pivot;

  57.     /** Parity of the permutation associated with the LU decomposition. */
  58.     private boolean even;

  59.     /** Singularity indicator. */
  60.     private boolean singular;

  61.     /** Cached value of L. */
  62.     private FieldMatrix<T> cachedL;

  63.     /** Cached value of U. */
  64.     private FieldMatrix<T> cachedU;

  65.     /** Cached value of P. */
  66.     private FieldMatrix<T> cachedP;

  67.     /**
  68.      * Calculates the LU-decomposition of the given matrix.
  69.      * @param matrix The matrix to decompose.
  70.      * @throws NonSquareMatrixException if matrix is not square
  71.      */
  72.     public FieldLUDecomposition(FieldMatrix<T> matrix) {
  73.         if (!matrix.isSquare()) {
  74.             throw new NonSquareMatrixException(matrix.getRowDimension(),
  75.                                                matrix.getColumnDimension());
  76.         }

  77.         final int m = matrix.getColumnDimension();
  78.         field = matrix.getField();
  79.         lu = matrix.getData();
  80.         pivot = new int[m];
  81.         cachedL = null;
  82.         cachedU = null;
  83.         cachedP = null;

  84.         // Initialize permutation array and parity
  85.         for (int row = 0; row < m; row++) {
  86.             pivot[row] = row;
  87.         }
  88.         even     = true;
  89.         singular = false;

  90.         // Loop over columns
  91.         for (int col = 0; col < m; col++) {

  92.             T sum = field.getZero();

  93.             // upper
  94.             for (int row = 0; row < col; row++) {
  95.                 final T[] luRow = lu[row];
  96.                 sum = luRow[col];
  97.                 for (int i = 0; i < row; i++) {
  98.                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
  99.                 }
  100.                 luRow[col] = sum;
  101.             }

  102.             // lower
  103.             int nonZero = col; // permutation row
  104.             for (int row = col; row < m; row++) {
  105.                 final T[] luRow = lu[row];
  106.                 sum = luRow[col];
  107.                 for (int i = 0; i < col; i++) {
  108.                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
  109.                 }
  110.                 luRow[col] = sum;

  111.                 if (lu[nonZero][col].equals(field.getZero())) {
  112.                     // try to select a better permutation choice
  113.                     ++nonZero;
  114.                 }
  115.             }

  116.             // Singularity check
  117.             if (nonZero >= m) {
  118.                 singular = true;
  119.                 return;
  120.             }

  121.             // Pivot if necessary
  122.             if (nonZero != col) {
  123.                 T tmp = field.getZero();
  124.                 for (int i = 0; i < m; i++) {
  125.                     tmp = lu[nonZero][i];
  126.                     lu[nonZero][i] = lu[col][i];
  127.                     lu[col][i] = tmp;
  128.                 }
  129.                 int temp = pivot[nonZero];
  130.                 pivot[nonZero] = pivot[col];
  131.                 pivot[col] = temp;
  132.                 even = !even;
  133.             }

  134.             // Divide the lower elements by the "winning" diagonal elt.
  135.             final T luDiag = lu[col][col];
  136.             for (int row = col + 1; row < m; row++) {
  137.                 final T[] luRow = lu[row];
  138.                 luRow[col] = luRow[col].divide(luDiag);
  139.             }
  140.         }
  141.     }

  142.     /**
  143.      * Returns the matrix L of the decomposition.
  144.      * <p>L is a lower-triangular matrix</p>
  145.      * @return the L matrix (or null if decomposed matrix is singular)
  146.      */
  147.     public FieldMatrix<T> getL() {
  148.         if (cachedL == null && !singular) {
  149.             final int m = pivot.length;
  150.             cachedL = new Array2DRowFieldMatrix<>(field, m, m);
  151.             for (int i = 0; i < m; ++i) {
  152.                 final T[] luI = lu[i];
  153.                 for (int j = 0; j < i; ++j) {
  154.                     cachedL.setEntry(i, j, luI[j]);
  155.                 }
  156.                 cachedL.setEntry(i, i, field.getOne());
  157.             }
  158.         }
  159.         return cachedL;
  160.     }

  161.     /**
  162.      * Returns the matrix U of the decomposition.
  163.      * <p>U is an upper-triangular matrix</p>
  164.      * @return the U matrix (or null if decomposed matrix is singular)
  165.      */
  166.     public FieldMatrix<T> getU() {
  167.         if (cachedU == null && !singular) {
  168.             final int m = pivot.length;
  169.             cachedU = new Array2DRowFieldMatrix<>(field, m, m);
  170.             for (int i = 0; i < m; ++i) {
  171.                 final T[] luI = lu[i];
  172.                 for (int j = i; j < m; ++j) {
  173.                     cachedU.setEntry(i, j, luI[j]);
  174.                 }
  175.             }
  176.         }
  177.         return cachedU;
  178.     }

  179.     /**
  180.      * Returns the P rows permutation matrix.
  181.      * <p>P is a sparse matrix with exactly one element set to 1.0 in
  182.      * each row and each column, all other elements being set to 0.0.</p>
  183.      * <p>The positions of the 1 elements are given by the {@link #getPivot()
  184.      * pivot permutation vector}.</p>
  185.      * @return the P rows permutation matrix (or null if decomposed matrix is singular)
  186.      * @see #getPivot()
  187.      */
  188.     public FieldMatrix<T> getP() {
  189.         if (cachedP == null && !singular) {
  190.             final int m = pivot.length;
  191.             cachedP = new Array2DRowFieldMatrix<>(field, m, m);
  192.             for (int i = 0; i < m; ++i) {
  193.                 cachedP.setEntry(i, pivot[i], field.getOne());
  194.             }
  195.         }
  196.         return cachedP;
  197.     }

  198.     /**
  199.      * Returns the pivot permutation vector.
  200.      * @return the pivot permutation vector
  201.      * @see #getP()
  202.      */
  203.     public int[] getPivot() {
  204.         return pivot.clone();
  205.     }

  206.     /**
  207.      * Return the determinant of the matrix.
  208.      * @return determinant of the matrix
  209.      */
  210.     public T getDeterminant() {
  211.         if (singular) {
  212.             return field.getZero();
  213.         } else {
  214.             final int m = pivot.length;
  215.             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
  216.             for (int i = 0; i < m; i++) {
  217.                 determinant = determinant.multiply(lu[i][i]);
  218.             }
  219.             return determinant;
  220.         }
  221.     }

  222.     /**
  223.      * Get a solver for finding the A &times; X = B solution in exact linear sense.
  224.      * @return a solver
  225.      */
  226.     public FieldDecompositionSolver<T> getSolver() {
  227.         return new Solver<>(field, lu, pivot, singular);
  228.     }

  229.     /** Specialized solver.
  230.      * @param <T> the type of the field elements
  231.      */
  232.     private static final class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {

  233.         /** Field to which the elements belong. */
  234.         private final Field<T> field;

  235.         /** Entries of LU decomposition. */
  236.         private final T[][] lu;

  237.         /** Pivot permutation associated with LU decomposition. */
  238.         private final int[] pivot;

  239.         /** Singularity indicator. */
  240.         private final boolean singular;

  241.         /**
  242.          * Build a solver from decomposed matrix.
  243.          * @param field field to which the matrix elements belong
  244.          * @param lu entries of LU decomposition
  245.          * @param pivot pivot permutation associated with LU decomposition
  246.          * @param singular singularity indicator
  247.          */
  248.         private Solver(final Field<T> field, final T[][] lu,
  249.                        final int[] pivot, final boolean singular) {
  250.             this.field    = field;
  251.             this.lu       = lu;
  252.             this.pivot    = pivot;
  253.             this.singular = singular;
  254.         }

  255.         /** {@inheritDoc} */
  256.         @Override
  257.         public boolean isNonSingular() {
  258.             return !singular;
  259.         }

  260.         /** {@inheritDoc} */
  261.         @Override
  262.         public FieldVector<T> solve(FieldVector<T> b) {
  263.             if (b instanceof ArrayFieldVector) {
  264.                 return solve((ArrayFieldVector<T>) b);
  265.             }

  266.             final int m = pivot.length;
  267.             if (b.getDimension() != m) {
  268.                 throw new DimensionMismatchException(b.getDimension(), m);
  269.             }
  270.             if (singular) {
  271.                 throw new SingularMatrixException();
  272.             }

  273.             // Apply permutations to b
  274.             final T[] bp = MathArrays.buildArray(field, m);
  275.             for (int row = 0; row < m; row++) {
  276.                 bp[row] = b.getEntry(pivot[row]);
  277.             }

  278.             // Solve LY = b
  279.             for (int col = 0; col < m; col++) {
  280.                 final T bpCol = bp[col];
  281.                 for (int i = col + 1; i < m; i++) {
  282.                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
  283.                 }
  284.             }

  285.             // Solve UX = Y
  286.             for (int col = m - 1; col >= 0; col--) {
  287.                 bp[col] = bp[col].divide(lu[col][col]);
  288.                 final T bpCol = bp[col];
  289.                 for (int i = 0; i < col; i++) {
  290.                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
  291.                 }
  292.             }

  293.             return new ArrayFieldVector<>(field, bp, false);
  294.         }

  295.         /** Solve the linear equation A &times; X = B.
  296.          * <p>The A matrix is implicit here. It is </p>
  297.          * @param b right-hand side of the equation A &times; X = B
  298.          * @return a vector X such that A &times; X = B
  299.          * @throws DimensionMismatchException if the matrices dimensions do not match.
  300.          * @throws SingularMatrixException if the decomposed matrix is singular.
  301.          */
  302.         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
  303.             final int m = pivot.length;
  304.             final int length = b.getDimension();
  305.             if (length != m) {
  306.                 throw new DimensionMismatchException(length, m);
  307.             }
  308.             if (singular) {
  309.                 throw new SingularMatrixException();
  310.             }

  311.             // Apply permutations to b
  312.             final T[] bp = MathArrays.buildArray(field, m);
  313.             for (int row = 0; row < m; row++) {
  314.                 bp[row] = b.getEntry(pivot[row]);
  315.             }

  316.             // Solve LY = b
  317.             for (int col = 0; col < m; col++) {
  318.                 final T bpCol = bp[col];
  319.                 for (int i = col + 1; i < m; i++) {
  320.                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
  321.                 }
  322.             }

  323.             // Solve UX = Y
  324.             for (int col = m - 1; col >= 0; col--) {
  325.                 bp[col] = bp[col].divide(lu[col][col]);
  326.                 final T bpCol = bp[col];
  327.                 for (int i = 0; i < col; i++) {
  328.                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
  329.                 }
  330.             }

  331.             return new ArrayFieldVector<>(bp, false);
  332.         }

  333.         /** {@inheritDoc} */
  334.         @Override
  335.         public FieldMatrix<T> solve(FieldMatrix<T> b) {
  336.             final int m = pivot.length;
  337.             if (b.getRowDimension() != m) {
  338.                 throw new DimensionMismatchException(b.getRowDimension(), m);
  339.             }
  340.             if (singular) {
  341.                 throw new SingularMatrixException();
  342.             }

  343.             final int nColB = b.getColumnDimension();

  344.             // Apply permutations to b
  345.             final T[][] bp = MathArrays.buildArray(field, m, nColB);
  346.             for (int row = 0; row < m; row++) {
  347.                 final T[] bpRow = bp[row];
  348.                 final int pRow = pivot[row];
  349.                 for (int col = 0; col < nColB; col++) {
  350.                     bpRow[col] = b.getEntry(pRow, col);
  351.                 }
  352.             }

  353.             // Solve LY = b
  354.             for (int col = 0; col < m; col++) {
  355.                 final T[] bpCol = bp[col];
  356.                 for (int i = col + 1; i < m; i++) {
  357.                     final T[] bpI = bp[i];
  358.                     final T luICol = lu[i][col];
  359.                     for (int j = 0; j < nColB; j++) {
  360.                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
  361.                     }
  362.                 }
  363.             }

  364.             // Solve UX = Y
  365.             for (int col = m - 1; col >= 0; col--) {
  366.                 final T[] bpCol = bp[col];
  367.                 final T luDiag = lu[col][col];
  368.                 for (int j = 0; j < nColB; j++) {
  369.                     bpCol[j] = bpCol[j].divide(luDiag);
  370.                 }
  371.                 for (int i = 0; i < col; i++) {
  372.                     final T[] bpI = bp[i];
  373.                     final T luICol = lu[i][col];
  374.                     for (int j = 0; j < nColB; j++) {
  375.                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
  376.                     }
  377.                 }
  378.             }

  379.             return new Array2DRowFieldMatrix<>(field, bp, false);
  380.         }

  381.         /** {@inheritDoc} */
  382.         @Override
  383.         public FieldMatrix<T> getInverse() {
  384.             final int m = pivot.length;
  385.             final T one = field.getOne();
  386.             FieldMatrix<T> identity = new Array2DRowFieldMatrix<>(field, m, m);
  387.             for (int i = 0; i < m; ++i) {
  388.                 identity.setEntry(i, i, one);
  389.             }
  390.             return solve(identity);
  391.         }
  392.     }
  393. }