QRDecomposition.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.linear;

  18. import java.util.Arrays;

  19. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  20. import org.apache.commons.math4.core.jdkmath.JdkMath;
  21. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;


  22. /**
  23.  * Calculates the QR-decomposition of a matrix.
  24.  * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
  25.  * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
  26.  * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
  27.  * <p>This class compute the decomposition using Householder reflectors.</p>
  28.  * <p>For efficiency purposes, the decomposition in packed form is transposed.
  29.  * This allows inner loop to iterate inside rows, which is much more cache-efficient
  30.  * in Java.</p>
  31.  * <p>This class is based on the class with similar name from the
  32.  * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
  33.  * following changes:</p>
  34.  * <ul>
  35.  *   <li>a {@link #getQT() getQT} method has been added,</li>
  36.  *   <li>the {@code solve} and {@code isFullRank} methods have been replaced
  37.  *   by a {@link #getSolver() getSolver} method and the equivalent methods
  38.  *   provided by the returned {@link DecompositionSolver}.</li>
  39.  * </ul>
  40.  *
  41.  * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
  42.  * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
  43.  *
  44.  * @since 1.2 (changed to concrete class in 3.0)
  45.  */
  46. public class QRDecomposition {
  47.     /**
  48.      * A packed TRANSPOSED representation of the QR decomposition.
  49.      * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
  50.      * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
  51.      * from which an explicit form of Q can be recomputed if desired.</p>
  52.      */
  53.     private double[][] qrt;
  54.     /** The diagonal elements of R. */
  55.     private double[] rDiag;
  56.     /** Cached value of Q. */
  57.     private RealMatrix cachedQ;
  58.     /** Cached value of QT. */
  59.     private RealMatrix cachedQT;
  60.     /** Cached value of R. */
  61.     private RealMatrix cachedR;
  62.     /** Cached value of H. */
  63.     private RealMatrix cachedH;
  64.     /** Singularity threshold. */
  65.     private final double threshold;

  66.     /**
  67.      * Calculates the QR-decomposition of the given matrix.
  68.      * The singularity threshold defaults to zero.
  69.      *
  70.      * @param matrix The matrix to decompose.
  71.      *
  72.      * @see #QRDecomposition(RealMatrix,double)
  73.      */
  74.     public QRDecomposition(RealMatrix matrix) {
  75.         this(matrix, 0d);
  76.     }

  77.     /**
  78.      * Calculates the QR-decomposition of the given matrix.
  79.      *
  80.      * @param matrix The matrix to decompose.
  81.      * @param threshold Singularity threshold.
  82.      * The matrix will be considered singular if the absolute value of
  83.      * any of the diagonal elements of the "R" matrix is smaller than
  84.      * the threshold.
  85.      */
  86.     public QRDecomposition(RealMatrix matrix,
  87.                            double threshold) {
  88.         this.threshold = threshold;

  89.         final int m = matrix.getRowDimension();
  90.         final int n = matrix.getColumnDimension();
  91.         qrt = matrix.transpose().getData();
  92.         rDiag = new double[JdkMath.min(m, n)];
  93.         cachedQ  = null;
  94.         cachedQT = null;
  95.         cachedR  = null;
  96.         cachedH  = null;

  97.         decompose(qrt);
  98.     }

  99.     /** Decompose matrix.
  100.      * @param matrix transposed matrix
  101.      * @since 3.2
  102.      */
  103.     protected void decompose(double[][] matrix) {
  104.         for (int minor = 0; minor < JdkMath.min(matrix.length, matrix[0].length); minor++) {
  105.             performHouseholderReflection(minor, matrix);
  106.         }
  107.     }

  108.     /** Perform Householder reflection for a minor A(minor, minor) of A.
  109.      * @param minor minor index
  110.      * @param matrix transposed matrix
  111.      * @since 3.2
  112.      */
  113.     protected void performHouseholderReflection(int minor, double[][] matrix) {

  114.         final double[] qrtMinor = matrix[minor];

  115.         /*
  116.          * Let x be the first column of the minor, and a^2 = |x|^2.
  117.          * x will be in the positions qr[minor][minor] through qr[m][minor].
  118.          * The first column of the transformed minor will be (a,0,0,..)'
  119.          * The sign of a is chosen to be opposite to the sign of the first
  120.          * component of x. Let's find a:
  121.          */
  122.         double xNormSqr = 0;
  123.         for (int row = minor; row < qrtMinor.length; row++) {
  124.             final double c = qrtMinor[row];
  125.             xNormSqr += c * c;
  126.         }
  127.         final double a = (qrtMinor[minor] > 0) ? -JdkMath.sqrt(xNormSqr) : JdkMath.sqrt(xNormSqr);
  128.         rDiag[minor] = a;

  129.         if (a != 0.0) {

  130.             /*
  131.              * Calculate the normalized reflection vector v and transform
  132.              * the first column. We know the norm of v beforehand: v = x-ae
  133.              * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
  134.              * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
  135.              * Here <x, e> is now qr[minor][minor].
  136.              * v = x-ae is stored in the column at qr:
  137.              */
  138.             qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])

  139.             /*
  140.              * Transform the rest of the columns of the minor:
  141.              * They will be transformed by the matrix H = I-2vv'/|v|^2.
  142.              * If x is a column vector of the minor, then
  143.              * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
  144.              * Therefore the transformation is easily calculated by
  145.              * subtracting the column vector (2<x,v>/|v|^2)v from x.
  146.              *
  147.              * Let 2<x,v>/|v|^2 = alpha. From above we have
  148.              * |v|^2 = -2a*(qr[minor][minor]), so
  149.              * alpha = -<x,v>/(a*qr[minor][minor])
  150.              */
  151.             for (int col = minor+1; col < matrix.length; col++) {
  152.                 final double[] qrtCol = matrix[col];
  153.                 double alpha = 0;
  154.                 for (int row = minor; row < qrtCol.length; row++) {
  155.                     alpha -= qrtCol[row] * qrtMinor[row];
  156.                 }
  157.                 alpha /= a * qrtMinor[minor];

  158.                 // Subtract the column vector alpha*v from x.
  159.                 for (int row = minor; row < qrtCol.length; row++) {
  160.                     qrtCol[row] -= alpha * qrtMinor[row];
  161.                 }
  162.             }
  163.         }
  164.     }


  165.     /**
  166.      * Returns the matrix R of the decomposition.
  167.      * <p>R is an upper-triangular matrix</p>
  168.      * @return the R matrix
  169.      */
  170.     public RealMatrix getR() {

  171.         if (cachedR == null) {

  172.             // R is supposed to be m x n
  173.             final int n = qrt.length;
  174.             final int m = qrt[0].length;
  175.             double[][] ra = new double[m][n];
  176.             // copy the diagonal from rDiag and the upper triangle of qr
  177.             for (int row = JdkMath.min(m, n) - 1; row >= 0; row--) {
  178.                 ra[row][row] = rDiag[row];
  179.                 for (int col = row + 1; col < n; col++) {
  180.                     ra[row][col] = qrt[col][row];
  181.                 }
  182.             }
  183.             cachedR = MatrixUtils.createRealMatrix(ra);
  184.         }

  185.         // return the cached matrix
  186.         return cachedR;
  187.     }

  188.     /**
  189.      * Returns the matrix Q of the decomposition.
  190.      * <p>Q is an orthogonal matrix</p>
  191.      * @return the Q matrix
  192.      */
  193.     public RealMatrix getQ() {
  194.         if (cachedQ == null) {
  195.             cachedQ = getQT().transpose();
  196.         }
  197.         return cachedQ;
  198.     }

  199.     /**
  200.      * Returns the transpose of the matrix Q of the decomposition.
  201.      * <p>Q is an orthogonal matrix</p>
  202.      * @return the transpose of the Q matrix, Q<sup>T</sup>
  203.      */
  204.     public RealMatrix getQT() {
  205.         if (cachedQT == null) {

  206.             // QT is supposed to be m x m
  207.             final int n = qrt.length;
  208.             final int m = qrt[0].length;
  209.             double[][] qta = new double[m][m];

  210.             /*
  211.              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
  212.              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
  213.              * succession to the result
  214.              */
  215.             for (int minor = m - 1; minor >= JdkMath.min(m, n); minor--) {
  216.                 qta[minor][minor] = 1.0d;
  217.             }

  218.             for (int minor = JdkMath.min(m, n)-1; minor >= 0; minor--){
  219.                 final double[] qrtMinor = qrt[minor];
  220.                 qta[minor][minor] = 1.0d;
  221.                 if (qrtMinor[minor] != 0.0) {
  222.                     for (int col = minor; col < m; col++) {
  223.                         double alpha = 0;
  224.                         for (int row = minor; row < m; row++) {
  225.                             alpha -= qta[col][row] * qrtMinor[row];
  226.                         }
  227.                         alpha /= rDiag[minor] * qrtMinor[minor];

  228.                         for (int row = minor; row < m; row++) {
  229.                             qta[col][row] += -alpha * qrtMinor[row];
  230.                         }
  231.                     }
  232.                 }
  233.             }
  234.             cachedQT = MatrixUtils.createRealMatrix(qta);
  235.         }

  236.         // return the cached matrix
  237.         return cachedQT;
  238.     }

  239.     /**
  240.      * Returns the Householder reflector vectors.
  241.      * <p>H is a lower trapezoidal matrix whose columns represent
  242.      * each successive Householder reflector vector. This matrix is used
  243.      * to compute Q.</p>
  244.      * @return a matrix containing the Householder reflector vectors
  245.      */
  246.     public RealMatrix getH() {
  247.         if (cachedH == null) {

  248.             final int n = qrt.length;
  249.             final int m = qrt[0].length;
  250.             double[][] ha = new double[m][n];
  251.             for (int i = 0; i < m; ++i) {
  252.                 for (int j = 0; j < JdkMath.min(i + 1, n); ++j) {
  253.                     ha[i][j] = qrt[j][i] / -rDiag[j];
  254.                 }
  255.             }
  256.             cachedH = MatrixUtils.createRealMatrix(ha);
  257.         }

  258.         // return the cached matrix
  259.         return cachedH;
  260.     }

  261.     /**
  262.      * Get a solver for finding the A &times; X = B solution in least square sense.
  263.      * <p>
  264.      * Least Square sense means a solver can be computed for an overdetermined system,
  265.      * (i.e. a system with more equations than unknowns, which corresponds to a tall A
  266.      * matrix with more rows than columns). In any case, if the matrix is singular
  267.      * within the tolerance set at {@link QRDecomposition#QRDecomposition(RealMatrix,
  268.      * double) construction}, an error will be triggered when
  269.      * the {@link DecompositionSolver#solve(RealVector) solve} method will be called.
  270.      * </p>
  271.      * @return a solver
  272.      */
  273.     public DecompositionSolver getSolver() {
  274.         return new Solver(qrt, rDiag, threshold);
  275.     }

  276.     /** Specialized solver. */
  277.     private static final class Solver implements DecompositionSolver {
  278.         /**
  279.          * A packed TRANSPOSED representation of the QR decomposition.
  280.          * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
  281.          * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
  282.          * from which an explicit form of Q can be recomputed if desired.</p>
  283.          */
  284.         private final double[][] qrt;
  285.         /** The diagonal elements of R. */
  286.         private final double[] rDiag;
  287.         /** Singularity threshold. */
  288.         private final double threshold;

  289.         /**
  290.          * Build a solver from decomposed matrix.
  291.          *
  292.          * @param qrt Packed TRANSPOSED representation of the QR decomposition.
  293.          * @param rDiag Diagonal elements of R.
  294.          * @param threshold Singularity threshold.
  295.          */
  296.         private Solver(final double[][] qrt,
  297.                        final double[] rDiag,
  298.                        final double threshold) {
  299.             this.qrt   = qrt;
  300.             this.rDiag = rDiag;
  301.             this.threshold = threshold;
  302.         }

  303.         /** {@inheritDoc} */
  304.         @Override
  305.         public boolean isNonSingular() {
  306.             return !checkSingular(rDiag, threshold, false);
  307.         }

  308.         /** {@inheritDoc} */
  309.         @Override
  310.         public RealVector solve(RealVector b) {
  311.             final int n = qrt.length;
  312.             final int m = qrt[0].length;
  313.             if (b.getDimension() != m) {
  314.                 throw new DimensionMismatchException(b.getDimension(), m);
  315.             }
  316.             checkSingular(rDiag, threshold, true);

  317.             final double[] x = new double[n];
  318.             final double[] y = b.toArray();

  319.             // apply Householder transforms to solve Q.y = b
  320.             for (int minor = 0; minor < JdkMath.min(m, n); minor++) {

  321.                 final double[] qrtMinor = qrt[minor];
  322.                 double dotProduct = 0;
  323.                 for (int row = minor; row < m; row++) {
  324.                     dotProduct += y[row] * qrtMinor[row];
  325.                 }
  326.                 dotProduct /= rDiag[minor] * qrtMinor[minor];

  327.                 for (int row = minor; row < m; row++) {
  328.                     y[row] += dotProduct * qrtMinor[row];
  329.                 }
  330.             }

  331.             // solve triangular system R.x = y
  332.             for (int row = rDiag.length - 1; row >= 0; --row) {
  333.                 y[row] /= rDiag[row];
  334.                 final double yRow = y[row];
  335.                 final double[] qrtRow = qrt[row];
  336.                 x[row] = yRow;
  337.                 for (int i = 0; i < row; i++) {
  338.                     y[i] -= yRow * qrtRow[i];
  339.                 }
  340.             }

  341.             return new ArrayRealVector(x, false);
  342.         }

  343.         /** {@inheritDoc} */
  344.         @Override
  345.         public RealMatrix solve(RealMatrix b) {
  346.             final int n = qrt.length;
  347.             final int m = qrt[0].length;
  348.             if (b.getRowDimension() != m) {
  349.                 throw new DimensionMismatchException(b.getRowDimension(), m);
  350.             }
  351.             checkSingular(rDiag, threshold, true);

  352.             final int columns        = b.getColumnDimension();
  353.             final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
  354.             final int cBlocks        = (columns + blockSize - 1) / blockSize;
  355.             final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
  356.             final double[][] y       = new double[b.getRowDimension()][blockSize];
  357.             final double[]   alpha   = new double[blockSize];

  358.             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
  359.                 final int kStart = kBlock * blockSize;
  360.                 final int kEnd   = JdkMath.min(kStart + blockSize, columns);
  361.                 final int kWidth = kEnd - kStart;

  362.                 // get the right hand side vector
  363.                 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);

  364.                 // apply Householder transforms to solve Q.y = b
  365.                 for (int minor = 0; minor < JdkMath.min(m, n); minor++) {
  366.                     final double[] qrtMinor = qrt[minor];
  367.                     final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);

  368.                     Arrays.fill(alpha, 0, kWidth, 0.0);
  369.                     for (int row = minor; row < m; ++row) {
  370.                         final double   d    = qrtMinor[row];
  371.                         final double[] yRow = y[row];
  372.                         for (int k = 0; k < kWidth; ++k) {
  373.                             alpha[k] += d * yRow[k];
  374.                         }
  375.                     }
  376.                     for (int k = 0; k < kWidth; ++k) {
  377.                         alpha[k] *= factor;
  378.                     }

  379.                     for (int row = minor; row < m; ++row) {
  380.                         final double   d    = qrtMinor[row];
  381.                         final double[] yRow = y[row];
  382.                         for (int k = 0; k < kWidth; ++k) {
  383.                             yRow[k] += alpha[k] * d;
  384.                         }
  385.                     }
  386.                 }

  387.                 // solve triangular system R.x = y
  388.                 for (int j = rDiag.length - 1; j >= 0; --j) {
  389.                     final int      jBlock = j / blockSize;
  390.                     final int      jStart = jBlock * blockSize;
  391.                     final double   factor = 1.0 / rDiag[j];
  392.                     final double[] yJ     = y[j];
  393.                     final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
  394.                     int index = (j - jStart) * kWidth;
  395.                     for (int k = 0; k < kWidth; ++k) {
  396.                         yJ[k]          *= factor;
  397.                         xBlock[index++] = yJ[k];
  398.                     }

  399.                     final double[] qrtJ = qrt[j];
  400.                     for (int i = 0; i < j; ++i) {
  401.                         final double rIJ  = qrtJ[i];
  402.                         final double[] yI = y[i];
  403.                         for (int k = 0; k < kWidth; ++k) {
  404.                             yI[k] -= yJ[k] * rIJ;
  405.                         }
  406.                     }
  407.                 }
  408.             }

  409.             return new BlockRealMatrix(n, columns, xBlocks, false);
  410.         }

  411.         /**
  412.          * {@inheritDoc}
  413.          * @throws SingularMatrixException if the decomposed matrix is singular.
  414.          */
  415.         @Override
  416.         public RealMatrix getInverse() {
  417.             return solve(MatrixUtils.createRealIdentityMatrix(qrt[0].length));
  418.         }

  419.         /**
  420.          * Check singularity.
  421.          *
  422.          * @param diag Diagonal elements of the R matrix.
  423.          * @param min Singularity threshold.
  424.          * @param raise Whether to raise a {@link SingularMatrixException}
  425.          * if any element of the diagonal fails the check.
  426.          * @return {@code true} if any element of the diagonal is smaller
  427.          * or equal to {@code min}.
  428.          * @throws SingularMatrixException if the matrix is singular and
  429.          * {@code raise} is {@code true}.
  430.          */
  431.         private static boolean checkSingular(double[] diag,
  432.                                              double min,
  433.                                              boolean raise) {
  434.             final int len = diag.length;
  435.             for (int i = 0; i < len; i++) {
  436.                 final double d = diag[i];
  437.                 if (JdkMath.abs(d) <= min) {
  438.                     if (raise) {
  439.                         final SingularMatrixException e = new SingularMatrixException();
  440.                         e.getContext().addMessage(LocalizedFormats.NUMBER_TOO_SMALL, d, min);
  441.                         e.getContext().addMessage(LocalizedFormats.INDEX, i);
  442.                         throw e;
  443.                     } else {
  444.                         return true;
  445.                     }
  446.                 }
  447.             }
  448.             return false;
  449.         }
  450.     }
  451. }