SchurTransformer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.linear;

  18. import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
  19. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  20. import org.apache.commons.math4.core.jdkmath.JdkMath;
  21. import org.apache.commons.numbers.core.Precision;

  22. /**
  23.  * Class transforming a general real matrix to Schur form.
  24.  * <p>A m &times; m matrix A can be written as the product of three matrices: A = P
  25.  * &times; T &times; P<sup>T</sup> with P an orthogonal matrix and T an quasi-triangular
  26.  * matrix. Both P and T are m &times; m matrices.</p>
  27.  * <p>Transformation to Schur form is often not a goal by itself, but it is an
  28.  * intermediate step in more general decomposition algorithms like
  29.  * {@link EigenDecomposition eigen decomposition}. This class is therefore
  30.  * intended for internal use by the library and is not public. As a consequence
  31.  * of this explicitly limited scope, many methods directly returns references to
  32.  * internal arrays, not copies.</p>
  33.  * <p>This class is based on the method hqr2 in class EigenvalueDecomposition
  34.  * from the <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
  35.  *
  36.  * @see <a href="http://mathworld.wolfram.com/SchurDecomposition.html">Schur Decomposition - MathWorld</a>
  37.  * @see <a href="http://en.wikipedia.org/wiki/Schur_decomposition">Schur Decomposition - Wikipedia</a>
  38.  * @see <a href="http://en.wikipedia.org/wiki/Householder_transformation">Householder Transformations</a>
  39.  * @since 3.1
  40.  */
  41. class SchurTransformer {
  42.     /** Maximum allowed iterations for convergence of the transformation. */
  43.     private static final int MAX_ITERATIONS = 100;

  44.     /** P matrix. */
  45.     private final double[][] matrixP;
  46.     /** T matrix. */
  47.     private final double[][] matrixT;
  48.     /** Cached value of P. */
  49.     private RealMatrix cachedP;
  50.     /** Cached value of T. */
  51.     private RealMatrix cachedT;
  52.     /** Cached value of PT. */
  53.     private RealMatrix cachedPt;

  54.     /** Epsilon criteria taken from JAMA code (originally was 2^-52). */
  55.     private final double epsilon = Precision.EPSILON;

  56.     /**
  57.      * Build the transformation to Schur form of a general real matrix.
  58.      *
  59.      * @param matrix matrix to transform
  60.      * @throws NonSquareMatrixException if the matrix is not square
  61.      */
  62.     SchurTransformer(final RealMatrix matrix) {
  63.         if (!matrix.isSquare()) {
  64.             throw new NonSquareMatrixException(matrix.getRowDimension(),
  65.                                                matrix.getColumnDimension());
  66.         }

  67.         HessenbergTransformer transformer = new HessenbergTransformer(matrix);
  68.         matrixT = transformer.getH().getData();
  69.         matrixP = transformer.getP().getData();
  70.         cachedT = null;
  71.         cachedP = null;
  72.         cachedPt = null;

  73.         // transform matrix
  74.         transform();
  75.     }

  76.     /**
  77.      * Returns the matrix P of the transform.
  78.      * <p>P is an orthogonal matrix, i.e. its inverse is also its transpose.</p>
  79.      *
  80.      * @return the P matrix
  81.      */
  82.     public RealMatrix getP() {
  83.         if (cachedP == null) {
  84.             cachedP = MatrixUtils.createRealMatrix(matrixP);
  85.         }
  86.         return cachedP;
  87.     }

  88.     /**
  89.      * Returns the transpose of the matrix P of the transform.
  90.      * <p>P is an orthogonal matrix, i.e. its inverse is also its transpose.</p>
  91.      *
  92.      * @return the transpose of the P matrix
  93.      */
  94.     public RealMatrix getPT() {
  95.         if (cachedPt == null) {
  96.             cachedPt = getP().transpose();
  97.         }

  98.         // return the cached matrix
  99.         return cachedPt;
  100.     }

  101.     /**
  102.      * Returns the quasi-triangular Schur matrix T of the transform.
  103.      *
  104.      * @return the T matrix
  105.      */
  106.     public RealMatrix getT() {
  107.         if (cachedT == null) {
  108.             cachedT = MatrixUtils.createRealMatrix(matrixT);
  109.         }

  110.         // return the cached matrix
  111.         return cachedT;
  112.     }

  113.     /**
  114.      * Transform original matrix to Schur form.
  115.      * @throws MaxCountExceededException if the transformation does not converge
  116.      */
  117.     private void transform() {
  118.         final int n = matrixT.length;

  119.         // compute matrix norm
  120.         final double norm = getNorm();

  121.         // shift information
  122.         final ShiftInfo shift = new ShiftInfo();

  123.         // Outer loop over eigenvalue index
  124.         int iteration = 0;
  125.         int iu = n - 1;
  126.         while (iu >= 0) {

  127.             // Look for single small sub-diagonal element
  128.             final int il = findSmallSubDiagonalElement(iu, norm);

  129.             // Check for convergence
  130.             if (il == iu) {
  131.                 // One root found
  132.                 matrixT[iu][iu] += shift.exShift;
  133.                 iu--;
  134.                 iteration = 0;
  135.             } else if (il == iu - 1) {
  136.                 // Two roots found
  137.                 double p = (matrixT[iu - 1][iu - 1] - matrixT[iu][iu]) / 2.0;
  138.                 double q = p * p + matrixT[iu][iu - 1] * matrixT[iu - 1][iu];
  139.                 matrixT[iu][iu] += shift.exShift;
  140.                 matrixT[iu - 1][iu - 1] += shift.exShift;

  141.                 if (q >= 0) {
  142.                     double z = JdkMath.sqrt(JdkMath.abs(q));
  143.                     if (p >= 0) {
  144.                         z = p + z;
  145.                     } else {
  146.                         z = p - z;
  147.                     }
  148.                     final double x = matrixT[iu][iu - 1];
  149.                     final double s = JdkMath.abs(x) + JdkMath.abs(z);
  150.                     p = x / s;
  151.                     q = z / s;
  152.                     final double r = JdkMath.sqrt(p * p + q * q);
  153.                     p /= r;
  154.                     q /= r;

  155.                     // Row modification
  156.                     for (int j = iu - 1; j < n; j++) {
  157.                         z = matrixT[iu - 1][j];
  158.                         matrixT[iu - 1][j] = q * z + p * matrixT[iu][j];
  159.                         matrixT[iu][j] = q * matrixT[iu][j] - p * z;
  160.                     }

  161.                     // Column modification
  162.                     for (int i = 0; i <= iu; i++) {
  163.                         z = matrixT[i][iu - 1];
  164.                         matrixT[i][iu - 1] = q * z + p * matrixT[i][iu];
  165.                         matrixT[i][iu] = q * matrixT[i][iu] - p * z;
  166.                     }

  167.                     // Accumulate transformations
  168.                     for (int i = 0; i <= n - 1; i++) {
  169.                         z = matrixP[i][iu - 1];
  170.                         matrixP[i][iu - 1] = q * z + p * matrixP[i][iu];
  171.                         matrixP[i][iu] = q * matrixP[i][iu] - p * z;
  172.                     }
  173.                 }
  174.                 iu -= 2;
  175.                 iteration = 0;
  176.             } else {
  177.                 // No convergence yet
  178.                 computeShift(il, iu, iteration, shift);

  179.                 // stop transformation after too many iterations
  180.                 if (++iteration > MAX_ITERATIONS) {
  181.                     throw new MaxCountExceededException(LocalizedFormats.CONVERGENCE_FAILED,
  182.                                                         MAX_ITERATIONS);
  183.                 }

  184.                 // the initial houseHolder vector for the QR step
  185.                 final double[] hVec = new double[3];

  186.                 final int im = initQRStep(il, iu, shift, hVec);
  187.                 performDoubleQRStep(il, im, iu, shift, hVec);
  188.             }
  189.         }
  190.     }

  191.     /**
  192.      * Computes the L1 norm of the (quasi-)triangular matrix T.
  193.      *
  194.      * @return the L1 norm of matrix T
  195.      */
  196.     private double getNorm() {
  197.         double norm = 0.0;
  198.         for (int i = 0; i < matrixT.length; i++) {
  199.             // as matrix T is (quasi-)triangular, also take the sub-diagonal element into account
  200.             for (int j = JdkMath.max(i - 1, 0); j < matrixT.length; j++) {
  201.                 norm += JdkMath.abs(matrixT[i][j]);
  202.             }
  203.         }
  204.         return norm;
  205.     }

  206.     /**
  207.      * Find the first small sub-diagonal element and returns its index.
  208.      *
  209.      * @param startIdx the starting index for the search
  210.      * @param norm the L1 norm of the matrix
  211.      * @return the index of the first small sub-diagonal element
  212.      */
  213.     private int findSmallSubDiagonalElement(final int startIdx, final double norm) {
  214.         int l = startIdx;
  215.         while (l > 0) {
  216.             double s = JdkMath.abs(matrixT[l - 1][l - 1]) + JdkMath.abs(matrixT[l][l]);
  217.             if (s == 0.0) {
  218.                 s = norm;
  219.             }
  220.             if (JdkMath.abs(matrixT[l][l - 1]) < epsilon * s) {
  221.                 break;
  222.             }
  223.             l--;
  224.         }
  225.         return l;
  226.     }

  227.     /**
  228.      * Compute the shift for the current iteration.
  229.      *
  230.      * @param l the index of the small sub-diagonal element
  231.      * @param idx the current eigenvalue index
  232.      * @param iteration the current iteration
  233.      * @param shift holder for shift information
  234.      */
  235.     private void computeShift(final int l, final int idx, final int iteration, final ShiftInfo shift) {
  236.         // Form shift
  237.         shift.x = matrixT[idx][idx];
  238.         shift.y = shift.w = 0.0;
  239.         if (l < idx) {
  240.             shift.y = matrixT[idx - 1][idx - 1];
  241.             shift.w = matrixT[idx][idx - 1] * matrixT[idx - 1][idx];
  242.         }

  243.         // Wilkinson's original ad hoc shift
  244.         if (iteration == 10) {
  245.             shift.exShift += shift.x;
  246.             for (int i = 0; i <= idx; i++) {
  247.                 matrixT[i][i] -= shift.x;
  248.             }
  249.             final double s = JdkMath.abs(matrixT[idx][idx - 1]) + JdkMath.abs(matrixT[idx - 1][idx - 2]);
  250.             shift.x = 0.75 * s;
  251.             shift.y = 0.75 * s;
  252.             shift.w = -0.4375 * s * s;
  253.         }

  254.         // MATLAB's new ad hoc shift
  255.         if (iteration == 30) {
  256.             double s = (shift.y - shift.x) / 2.0;
  257.             s = s * s + shift.w;
  258.             if (s > 0.0) {
  259.                 s = JdkMath.sqrt(s);
  260.                 if (shift.y < shift.x) {
  261.                     s = -s;
  262.                 }
  263.                 s = shift.x - shift.w / ((shift.y - shift.x) / 2.0 + s);
  264.                 for (int i = 0; i <= idx; i++) {
  265.                     matrixT[i][i] -= s;
  266.                 }
  267.                 shift.exShift += s;
  268.                 shift.x = shift.y = shift.w = 0.964;
  269.             }
  270.         }
  271.     }

  272.     /**
  273.      * Initialize the householder vectors for the QR step.
  274.      *
  275.      * @param il the index of the small sub-diagonal element
  276.      * @param iu the current eigenvalue index
  277.      * @param shift shift information holder
  278.      * @param hVec the initial houseHolder vector
  279.      * @return the start index for the QR step
  280.      */
  281.     private int initQRStep(int il, final int iu, final ShiftInfo shift, double[] hVec) {
  282.         // Look for two consecutive small sub-diagonal elements
  283.         int im = iu - 2;
  284.         while (im >= il) {
  285.             final double z = matrixT[im][im];
  286.             final double r = shift.x - z;
  287.             double s = shift.y - z;
  288.             hVec[0] = (r * s - shift.w) / matrixT[im + 1][im] + matrixT[im][im + 1];
  289.             hVec[1] = matrixT[im + 1][im + 1] - z - r - s;
  290.             hVec[2] = matrixT[im + 2][im + 1];

  291.             if (im == il) {
  292.                 break;
  293.             }

  294.             final double lhs = JdkMath.abs(matrixT[im][im - 1]) * (JdkMath.abs(hVec[1]) + JdkMath.abs(hVec[2]));
  295.             final double rhs = JdkMath.abs(hVec[0]) * (JdkMath.abs(matrixT[im - 1][im - 1]) +
  296.                                                         JdkMath.abs(z) +
  297.                                                         JdkMath.abs(matrixT[im + 1][im + 1]));

  298.             if (lhs < epsilon * rhs) {
  299.                 break;
  300.             }
  301.             im--;
  302.         }

  303.         return im;
  304.     }

  305.     /**
  306.      * Perform a double QR step involving rows l:idx and columns m:n.
  307.      *
  308.      * @param il the index of the small sub-diagonal element
  309.      * @param im the start index for the QR step
  310.      * @param iu the current eigenvalue index
  311.      * @param shift shift information holder
  312.      * @param hVec the initial houseHolder vector
  313.      */
  314.     private void performDoubleQRStep(final int il, final int im, final int iu,
  315.                                      final ShiftInfo shift, final double[] hVec) {

  316.         final int n = matrixT.length;
  317.         double p = hVec[0];
  318.         double q = hVec[1];
  319.         double r = hVec[2];

  320.         for (int k = im; k <= iu - 1; k++) {
  321.             boolean notlast = k != (iu - 1);
  322.             if (k != im) {
  323.                 p = matrixT[k][k - 1];
  324.                 q = matrixT[k + 1][k - 1];
  325.                 r = notlast ? matrixT[k + 2][k - 1] : 0.0;
  326.                 shift.x = JdkMath.abs(p) + JdkMath.abs(q) + JdkMath.abs(r);
  327.                 if (Precision.equals(shift.x, 0.0, epsilon)) {
  328.                     continue;
  329.                 }
  330.                 p /= shift.x;
  331.                 q /= shift.x;
  332.                 r /= shift.x;
  333.             }
  334.             double s = JdkMath.sqrt(p * p + q * q + r * r);
  335.             if (p < 0.0) {
  336.                 s = -s;
  337.             }
  338.             if (s != 0.0) {
  339.                 if (k != im) {
  340.                     matrixT[k][k - 1] = -s * shift.x;
  341.                 } else if (il != im) {
  342.                     matrixT[k][k - 1] = -matrixT[k][k - 1];
  343.                 }
  344.                 p += s;
  345.                 shift.x = p / s;
  346.                 shift.y = q / s;
  347.                 double z = r / s;
  348.                 q /= p;
  349.                 r /= p;

  350.                 // Row modification
  351.                 for (int j = k; j < n; j++) {
  352.                     p = matrixT[k][j] + q * matrixT[k + 1][j];
  353.                     if (notlast) {
  354.                         p += r * matrixT[k + 2][j];
  355.                         matrixT[k + 2][j] -= p * z;
  356.                     }
  357.                     matrixT[k][j] -= p * shift.x;
  358.                     matrixT[k + 1][j] -= p * shift.y;
  359.                 }

  360.                 // Column modification
  361.                 for (int i = 0; i <= JdkMath.min(iu, k + 3); i++) {
  362.                     p = shift.x * matrixT[i][k] + shift.y * matrixT[i][k + 1];
  363.                     if (notlast) {
  364.                         p += z * matrixT[i][k + 2];
  365.                         matrixT[i][k + 2] -= p * r;
  366.                     }
  367.                     matrixT[i][k] -= p;
  368.                     matrixT[i][k + 1] -= p * q;
  369.                 }

  370.                 // Accumulate transformations
  371.                 final int high = matrixT.length - 1;
  372.                 for (int i = 0; i <= high; i++) {
  373.                     p = shift.x * matrixP[i][k] + shift.y * matrixP[i][k + 1];
  374.                     if (notlast) {
  375.                         p += z * matrixP[i][k + 2];
  376.                         matrixP[i][k + 2] -= p * r;
  377.                     }
  378.                     matrixP[i][k] -= p;
  379.                     matrixP[i][k + 1] -= p * q;
  380.                 }
  381.             }  // (s != 0)
  382.         }  // k loop

  383.         // clean up pollution due to round-off errors
  384.         for (int i = im + 2; i <= iu; i++) {
  385.             matrixT[i][i-2] = 0.0;
  386.             if (i > im + 2) {
  387.                 matrixT[i][i-3] = 0.0;
  388.             }
  389.         }
  390.     }

  391.     /**
  392.      * Internal data structure holding the current shift information.
  393.      * Contains variable names as present in the original JAMA code.
  394.      */
  395.     private static final class ShiftInfo {
  396.         // CHECKSTYLE: stop all

  397.         /** x shift info. */
  398.         double x;
  399.         /** y shift info. */
  400.         double y;
  401.         /** w shift info. */
  402.         double w;
  403.         /** Indicates an exceptional shift. */
  404.         double exShift;

  405.         // CHECKSTYLE: resume all
  406.     }
  407. }