/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

18  package org.apache.commons.math4.legacy.linear;
19  
20  import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
21  import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
22  import org.apache.commons.math4.core.jdkmath.JdkMath;
23  import org.apache.commons.numbers.core.Precision;
24  
25  /**
26   * Class transforming a general real matrix to Schur form.
27   * <p>A m &times; m matrix A can be written as the product of three matrices: A = P
28   * &times; T &times; P<sup>T</sup> with P an orthogonal matrix and T an quasi-triangular
29   * matrix. Both P and T are m &times; m matrices.</p>
30   * <p>Transformation to Schur form is often not a goal by itself, but it is an
31   * intermediate step in more general decomposition algorithms like
32   * {@link EigenDecomposition eigen decomposition}. This class is therefore
33   * intended for internal use by the library and is not public. As a consequence
34   * of this explicitly limited scope, many methods directly returns references to
35   * internal arrays, not copies.</p>
36   * <p>This class is based on the method hqr2 in class EigenvalueDecomposition
37   * from the <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
38   *
39   * @see <a href="http://mathworld.wolfram.com/SchurDecomposition.html">Schur Decomposition - MathWorld</a>
40   * @see <a href="http://en.wikipedia.org/wiki/Schur_decomposition">Schur Decomposition - Wikipedia</a>
41   * @see <a href="http://en.wikipedia.org/wiki/Householder_transformation">Householder Transformations</a>
42   * @since 3.1
43   */
44  class SchurTransformer {
45      /** Maximum allowed iterations for convergence of the transformation. */
46      private static final int MAX_ITERATIONS = 100;
47  
48      /** P matrix. */
49      private final double[][] matrixP;
50      /** T matrix. */
51      private final double[][] matrixT;
52      /** Cached value of P. */
53      private RealMatrix cachedP;
54      /** Cached value of T. */
55      private RealMatrix cachedT;
56      /** Cached value of PT. */
57      private RealMatrix cachedPt;
58  
59      /** Epsilon criteria taken from JAMA code (originally was 2^-52). */
60      private final double epsilon = Precision.EPSILON;
61  
62      /**
63       * Build the transformation to Schur form of a general real matrix.
64       *
65       * @param matrix matrix to transform
66       * @throws NonSquareMatrixException if the matrix is not square
67       */
68      SchurTransformer(final RealMatrix matrix) {
69          if (!matrix.isSquare()) {
70              throw new NonSquareMatrixException(matrix.getRowDimension(),
71                                                 matrix.getColumnDimension());
72          }
73  
74          HessenbergTransformer transformer = new HessenbergTransformer(matrix);
75          matrixT = transformer.getH().getData();
76          matrixP = transformer.getP().getData();
77          cachedT = null;
78          cachedP = null;
79          cachedPt = null;
80  
81          // transform matrix
82          transform();
83      }
84  
85      /**
86       * Returns the matrix P of the transform.
87       * <p>P is an orthogonal matrix, i.e. its inverse is also its transpose.</p>
88       *
89       * @return the P matrix
90       */
91      public RealMatrix getP() {
92          if (cachedP == null) {
93              cachedP = MatrixUtils.createRealMatrix(matrixP);
94          }
95          return cachedP;
96      }
97  
98      /**
99       * Returns the transpose of the matrix P of the transform.
100      * <p>P is an orthogonal matrix, i.e. its inverse is also its transpose.</p>
101      *
102      * @return the transpose of the P matrix
103      */
104     public RealMatrix getPT() {
105         if (cachedPt == null) {
106             cachedPt = getP().transpose();
107         }
108 
109         // return the cached matrix
110         return cachedPt;
111     }
112 
113     /**
114      * Returns the quasi-triangular Schur matrix T of the transform.
115      *
116      * @return the T matrix
117      */
118     public RealMatrix getT() {
119         if (cachedT == null) {
120             cachedT = MatrixUtils.createRealMatrix(matrixT);
121         }
122 
123         // return the cached matrix
124         return cachedT;
125     }
126 
127     /**
128      * Transform original matrix to Schur form.
129      * @throws MaxCountExceededException if the transformation does not converge
130      */
131     private void transform() {
132         final int n = matrixT.length;
133 
134         // compute matrix norm
135         final double norm = getNorm();
136 
137         // shift information
138         final ShiftInfo shift = new ShiftInfo();
139 
140         // Outer loop over eigenvalue index
141         int iteration = 0;
142         int iu = n - 1;
143         while (iu >= 0) {
144 
145             // Look for single small sub-diagonal element
146             final int il = findSmallSubDiagonalElement(iu, norm);
147 
148             // Check for convergence
149             if (il == iu) {
150                 // One root found
151                 matrixT[iu][iu] += shift.exShift;
152                 iu--;
153                 iteration = 0;
154             } else if (il == iu - 1) {
155                 // Two roots found
156                 double p = (matrixT[iu - 1][iu - 1] - matrixT[iu][iu]) / 2.0;
157                 double q = p * p + matrixT[iu][iu - 1] * matrixT[iu - 1][iu];
158                 matrixT[iu][iu] += shift.exShift;
159                 matrixT[iu - 1][iu - 1] += shift.exShift;
160 
161                 if (q >= 0) {
162                     double z = JdkMath.sqrt(JdkMath.abs(q));
163                     if (p >= 0) {
164                         z = p + z;
165                     } else {
166                         z = p - z;
167                     }
168                     final double x = matrixT[iu][iu - 1];
169                     final double s = JdkMath.abs(x) + JdkMath.abs(z);
170                     p = x / s;
171                     q = z / s;
172                     final double r = JdkMath.sqrt(p * p + q * q);
173                     p /= r;
174                     q /= r;
175 
176                     // Row modification
177                     for (int j = iu - 1; j < n; j++) {
178                         z = matrixT[iu - 1][j];
179                         matrixT[iu - 1][j] = q * z + p * matrixT[iu][j];
180                         matrixT[iu][j] = q * matrixT[iu][j] - p * z;
181                     }
182 
183                     // Column modification
184                     for (int i = 0; i <= iu; i++) {
185                         z = matrixT[i][iu - 1];
186                         matrixT[i][iu - 1] = q * z + p * matrixT[i][iu];
187                         matrixT[i][iu] = q * matrixT[i][iu] - p * z;
188                     }
189 
190                     // Accumulate transformations
191                     for (int i = 0; i <= n - 1; i++) {
192                         z = matrixP[i][iu - 1];
193                         matrixP[i][iu - 1] = q * z + p * matrixP[i][iu];
194                         matrixP[i][iu] = q * matrixP[i][iu] - p * z;
195                     }
196                 }
197                 iu -= 2;
198                 iteration = 0;
199             } else {
200                 // No convergence yet
201                 computeShift(il, iu, iteration, shift);
202 
203                 // stop transformation after too many iterations
204                 if (++iteration > MAX_ITERATIONS) {
205                     throw new MaxCountExceededException(LocalizedFormats.CONVERGENCE_FAILED,
206                                                         MAX_ITERATIONS);
207                 }
208 
209                 // the initial houseHolder vector for the QR step
210                 final double[] hVec = new double[3];
211 
212                 final int im = initQRStep(il, iu, shift, hVec);
213                 performDoubleQRStep(il, im, iu, shift, hVec);
214             }
215         }
216     }
217 
218     /**
219      * Computes the L1 norm of the (quasi-)triangular matrix T.
220      *
221      * @return the L1 norm of matrix T
222      */
223     private double getNorm() {
224         double norm = 0.0;
225         for (int i = 0; i < matrixT.length; i++) {
226             // as matrix T is (quasi-)triangular, also take the sub-diagonal element into account
227             for (int j = JdkMath.max(i - 1, 0); j < matrixT.length; j++) {
228                 norm += JdkMath.abs(matrixT[i][j]);
229             }
230         }
231         return norm;
232     }
233 
234     /**
235      * Find the first small sub-diagonal element and returns its index.
236      *
237      * @param startIdx the starting index for the search
238      * @param norm the L1 norm of the matrix
239      * @return the index of the first small sub-diagonal element
240      */
241     private int findSmallSubDiagonalElement(final int startIdx, final double norm) {
242         int l = startIdx;
243         while (l > 0) {
244             double s = JdkMath.abs(matrixT[l - 1][l - 1]) + JdkMath.abs(matrixT[l][l]);
245             if (s == 0.0) {
246                 s = norm;
247             }
248             if (JdkMath.abs(matrixT[l][l - 1]) < epsilon * s) {
249                 break;
250             }
251             l--;
252         }
253         return l;
254     }
255 
256     /**
257      * Compute the shift for the current iteration.
258      *
259      * @param l the index of the small sub-diagonal element
260      * @param idx the current eigenvalue index
261      * @param iteration the current iteration
262      * @param shift holder for shift information
263      */
264     private void computeShift(final int l, final int idx, final int iteration, final ShiftInfo shift) {
265         // Form shift
266         shift.x = matrixT[idx][idx];
267         shift.y = shift.w = 0.0;
268         if (l < idx) {
269             shift.y = matrixT[idx - 1][idx - 1];
270             shift.w = matrixT[idx][idx - 1] * matrixT[idx - 1][idx];
271         }
272 
273         // Wilkinson's original ad hoc shift
274         if (iteration == 10) {
275             shift.exShift += shift.x;
276             for (int i = 0; i <= idx; i++) {
277                 matrixT[i][i] -= shift.x;
278             }
279             final double s = JdkMath.abs(matrixT[idx][idx - 1]) + JdkMath.abs(matrixT[idx - 1][idx - 2]);
280             shift.x = 0.75 * s;
281             shift.y = 0.75 * s;
282             shift.w = -0.4375 * s * s;
283         }
284 
285         // MATLAB's new ad hoc shift
286         if (iteration == 30) {
287             double s = (shift.y - shift.x) / 2.0;
288             s = s * s + shift.w;
289             if (s > 0.0) {
290                 s = JdkMath.sqrt(s);
291                 if (shift.y < shift.x) {
292                     s = -s;
293                 }
294                 s = shift.x - shift.w / ((shift.y - shift.x) / 2.0 + s);
295                 for (int i = 0; i <= idx; i++) {
296                     matrixT[i][i] -= s;
297                 }
298                 shift.exShift += s;
299                 shift.x = shift.y = shift.w = 0.964;
300             }
301         }
302     }
303 
304     /**
305      * Initialize the householder vectors for the QR step.
306      *
307      * @param il the index of the small sub-diagonal element
308      * @param iu the current eigenvalue index
309      * @param shift shift information holder
310      * @param hVec the initial houseHolder vector
311      * @return the start index for the QR step
312      */
313     private int initQRStep(int il, final int iu, final ShiftInfo shift, double[] hVec) {
314         // Look for two consecutive small sub-diagonal elements
315         int im = iu - 2;
316         while (im >= il) {
317             final double z = matrixT[im][im];
318             final double r = shift.x - z;
319             double s = shift.y - z;
320             hVec[0] = (r * s - shift.w) / matrixT[im + 1][im] + matrixT[im][im + 1];
321             hVec[1] = matrixT[im + 1][im + 1] - z - r - s;
322             hVec[2] = matrixT[im + 2][im + 1];
323 
324             if (im == il) {
325                 break;
326             }
327 
328             final double lhs = JdkMath.abs(matrixT[im][im - 1]) * (JdkMath.abs(hVec[1]) + JdkMath.abs(hVec[2]));
329             final double rhs = JdkMath.abs(hVec[0]) * (JdkMath.abs(matrixT[im - 1][im - 1]) +
330                                                         JdkMath.abs(z) +
331                                                         JdkMath.abs(matrixT[im + 1][im + 1]));
332 
333             if (lhs < epsilon * rhs) {
334                 break;
335             }
336             im--;
337         }
338 
339         return im;
340     }
341 
342     /**
343      * Perform a double QR step involving rows l:idx and columns m:n.
344      *
345      * @param il the index of the small sub-diagonal element
346      * @param im the start index for the QR step
347      * @param iu the current eigenvalue index
348      * @param shift shift information holder
349      * @param hVec the initial houseHolder vector
350      */
351     private void performDoubleQRStep(final int il, final int im, final int iu,
352                                      final ShiftInfo shift, final double[] hVec) {
353 
354         final int n = matrixT.length;
355         double p = hVec[0];
356         double q = hVec[1];
357         double r = hVec[2];
358 
359         for (int k = im; k <= iu - 1; k++) {
360             boolean notlast = k != (iu - 1);
361             if (k != im) {
362                 p = matrixT[k][k - 1];
363                 q = matrixT[k + 1][k - 1];
364                 r = notlast ? matrixT[k + 2][k - 1] : 0.0;
365                 shift.x = JdkMath.abs(p) + JdkMath.abs(q) + JdkMath.abs(r);
366                 if (Precision.equals(shift.x, 0.0, epsilon)) {
367                     continue;
368                 }
369                 p /= shift.x;
370                 q /= shift.x;
371                 r /= shift.x;
372             }
373             double s = JdkMath.sqrt(p * p + q * q + r * r);
374             if (p < 0.0) {
375                 s = -s;
376             }
377             if (s != 0.0) {
378                 if (k != im) {
379                     matrixT[k][k - 1] = -s * shift.x;
380                 } else if (il != im) {
381                     matrixT[k][k - 1] = -matrixT[k][k - 1];
382                 }
383                 p += s;
384                 shift.x = p / s;
385                 shift.y = q / s;
386                 double z = r / s;
387                 q /= p;
388                 r /= p;
389 
390                 // Row modification
391                 for (int j = k; j < n; j++) {
392                     p = matrixT[k][j] + q * matrixT[k + 1][j];
393                     if (notlast) {
394                         p += r * matrixT[k + 2][j];
395                         matrixT[k + 2][j] -= p * z;
396                     }
397                     matrixT[k][j] -= p * shift.x;
398                     matrixT[k + 1][j] -= p * shift.y;
399                 }
400 
401                 // Column modification
402                 for (int i = 0; i <= JdkMath.min(iu, k + 3); i++) {
403                     p = shift.x * matrixT[i][k] + shift.y * matrixT[i][k + 1];
404                     if (notlast) {
405                         p += z * matrixT[i][k + 2];
406                         matrixT[i][k + 2] -= p * r;
407                     }
408                     matrixT[i][k] -= p;
409                     matrixT[i][k + 1] -= p * q;
410                 }
411 
412                 // Accumulate transformations
413                 final int high = matrixT.length - 1;
414                 for (int i = 0; i <= high; i++) {
415                     p = shift.x * matrixP[i][k] + shift.y * matrixP[i][k + 1];
416                     if (notlast) {
417                         p += z * matrixP[i][k + 2];
418                         matrixP[i][k + 2] -= p * r;
419                     }
420                     matrixP[i][k] -= p;
421                     matrixP[i][k + 1] -= p * q;
422                 }
423             }  // (s != 0)
424         }  // k loop
425 
426         // clean up pollution due to round-off errors
427         for (int i = im + 2; i <= iu; i++) {
428             matrixT[i][i-2] = 0.0;
429             if (i > im + 2) {
430                 matrixT[i][i-3] = 0.0;
431             }
432         }
433     }
434 
435     /**
436      * Internal data structure holding the current shift information.
437      * Contains variable names as present in the original JAMA code.
438      */
439     private static final class ShiftInfo {
440         // CHECKSTYLE: stop all
441 
442         /** x shift info. */
443         double x;
444         /** y shift info. */
445         double y;
446         /** w shift info. */
447         double w;
448         /** Indicates an exceptional shift. */
449         double exShift;
450 
451         // CHECKSTYLE: resume all
452     }
453 }