/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.math4.legacy.linear;
19  
20  import java.util.Arrays;
21  
22  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23  import org.apache.commons.math4.core.jdkmath.JdkMath;
24  import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
25  
26  
27  /**
28   * Calculates the QR-decomposition of a matrix.
29   * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
30   * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
31   * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
32   * <p>This class compute the decomposition using Householder reflectors.</p>
33   * <p>For efficiency purposes, the decomposition in packed form is transposed.
34   * This allows inner loop to iterate inside rows, which is much more cache-efficient
35   * in Java.</p>
36   * <p>This class is based on the class with similar name from the
37   * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
38   * following changes:</p>
39   * <ul>
40   *   <li>a {@link #getQT() getQT} method has been added,</li>
41   *   <li>the {@code solve} and {@code isFullRank} methods have been replaced
42   *   by a {@link #getSolver() getSolver} method and the equivalent methods
43   *   provided by the returned {@link DecompositionSolver}.</li>
44   * </ul>
45   *
46   * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
47   * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
48   *
49   * @since 1.2 (changed to concrete class in 3.0)
50   */
51  public class QRDecomposition {
52      /**
53       * A packed TRANSPOSED representation of the QR decomposition.
54       * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
55       * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
56       * from which an explicit form of Q can be recomputed if desired.</p>
57       */
58      private double[][] qrt;
59      /** The diagonal elements of R. */
60      private double[] rDiag;
61      /** Cached value of Q. */
62      private RealMatrix cachedQ;
63      /** Cached value of QT. */
64      private RealMatrix cachedQT;
65      /** Cached value of R. */
66      private RealMatrix cachedR;
67      /** Cached value of H. */
68      private RealMatrix cachedH;
69      /** Singularity threshold. */
70      private final double threshold;
71  
72      /**
73       * Calculates the QR-decomposition of the given matrix.
74       * The singularity threshold defaults to zero.
75       *
76       * @param matrix The matrix to decompose.
77       *
78       * @see #QRDecomposition(RealMatrix,double)
79       */
80      public QRDecomposition(RealMatrix matrix) {
81          this(matrix, 0d);
82      }
83  
84      /**
85       * Calculates the QR-decomposition of the given matrix.
86       *
87       * @param matrix The matrix to decompose.
88       * @param threshold Singularity threshold.
89       * The matrix will be considered singular if the absolute value of
90       * any of the diagonal elements of the "R" matrix is smaller than
91       * the threshold.
92       */
93      public QRDecomposition(RealMatrix matrix,
94                             double threshold) {
95          this.threshold = threshold;
96  
97          final int m = matrix.getRowDimension();
98          final int n = matrix.getColumnDimension();
99          qrt = matrix.transpose().getData();
100         rDiag = new double[JdkMath.min(m, n)];
101         cachedQ  = null;
102         cachedQT = null;
103         cachedR  = null;
104         cachedH  = null;
105 
106         decompose(qrt);
107     }
108 
109     /** Decompose matrix.
110      * @param matrix transposed matrix
111      * @since 3.2
112      */
113     protected void decompose(double[][] matrix) {
114         for (int minor = 0; minor < JdkMath.min(matrix.length, matrix[0].length); minor++) {
115             performHouseholderReflection(minor, matrix);
116         }
117     }
118 
119     /** Perform Householder reflection for a minor A(minor, minor) of A.
120      * @param minor minor index
121      * @param matrix transposed matrix
122      * @since 3.2
123      */
124     protected void performHouseholderReflection(int minor, double[][] matrix) {
125 
126         final double[] qrtMinor = matrix[minor];
127 
128         /*
129          * Let x be the first column of the minor, and a^2 = |x|^2.
130          * x will be in the positions qr[minor][minor] through qr[m][minor].
131          * The first column of the transformed minor will be (a,0,0,..)'
132          * The sign of a is chosen to be opposite to the sign of the first
133          * component of x. Let's find a:
134          */
135         double xNormSqr = 0;
136         for (int row = minor; row < qrtMinor.length; row++) {
137             final double c = qrtMinor[row];
138             xNormSqr += c * c;
139         }
140         final double a = (qrtMinor[minor] > 0) ? -JdkMath.sqrt(xNormSqr) : JdkMath.sqrt(xNormSqr);
141         rDiag[minor] = a;
142 
143         if (a != 0.0) {
144 
145             /*
146              * Calculate the normalized reflection vector v and transform
147              * the first column. We know the norm of v beforehand: v = x-ae
148              * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
149              * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
150              * Here <x, e> is now qr[minor][minor].
151              * v = x-ae is stored in the column at qr:
152              */
153             qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
154 
155             /*
156              * Transform the rest of the columns of the minor:
157              * They will be transformed by the matrix H = I-2vv'/|v|^2.
158              * If x is a column vector of the minor, then
159              * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
160              * Therefore the transformation is easily calculated by
161              * subtracting the column vector (2<x,v>/|v|^2)v from x.
162              *
163              * Let 2<x,v>/|v|^2 = alpha. From above we have
164              * |v|^2 = -2a*(qr[minor][minor]), so
165              * alpha = -<x,v>/(a*qr[minor][minor])
166              */
167             for (int col = minor+1; col < matrix.length; col++) {
168                 final double[] qrtCol = matrix[col];
169                 double alpha = 0;
170                 for (int row = minor; row < qrtCol.length; row++) {
171                     alpha -= qrtCol[row] * qrtMinor[row];
172                 }
173                 alpha /= a * qrtMinor[minor];
174 
175                 // Subtract the column vector alpha*v from x.
176                 for (int row = minor; row < qrtCol.length; row++) {
177                     qrtCol[row] -= alpha * qrtMinor[row];
178                 }
179             }
180         }
181     }
182 
183 
184     /**
185      * Returns the matrix R of the decomposition.
186      * <p>R is an upper-triangular matrix</p>
187      * @return the R matrix
188      */
189     public RealMatrix getR() {
190 
191         if (cachedR == null) {
192 
193             // R is supposed to be m x n
194             final int n = qrt.length;
195             final int m = qrt[0].length;
196             double[][] ra = new double[m][n];
197             // copy the diagonal from rDiag and the upper triangle of qr
198             for (int row = JdkMath.min(m, n) - 1; row >= 0; row--) {
199                 ra[row][row] = rDiag[row];
200                 for (int col = row + 1; col < n; col++) {
201                     ra[row][col] = qrt[col][row];
202                 }
203             }
204             cachedR = MatrixUtils.createRealMatrix(ra);
205         }
206 
207         // return the cached matrix
208         return cachedR;
209     }
210 
211     /**
212      * Returns the matrix Q of the decomposition.
213      * <p>Q is an orthogonal matrix</p>
214      * @return the Q matrix
215      */
216     public RealMatrix getQ() {
217         if (cachedQ == null) {
218             cachedQ = getQT().transpose();
219         }
220         return cachedQ;
221     }
222 
223     /**
224      * Returns the transpose of the matrix Q of the decomposition.
225      * <p>Q is an orthogonal matrix</p>
226      * @return the transpose of the Q matrix, Q<sup>T</sup>
227      */
228     public RealMatrix getQT() {
229         if (cachedQT == null) {
230 
231             // QT is supposed to be m x m
232             final int n = qrt.length;
233             final int m = qrt[0].length;
234             double[][] qta = new double[m][m];
235 
236             /*
237              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
238              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
239              * succession to the result
240              */
241             for (int minor = m - 1; minor >= JdkMath.min(m, n); minor--) {
242                 qta[minor][minor] = 1.0d;
243             }
244 
245             for (int minor = JdkMath.min(m, n)-1; minor >= 0; minor--){
246                 final double[] qrtMinor = qrt[minor];
247                 qta[minor][minor] = 1.0d;
248                 if (qrtMinor[minor] != 0.0) {
249                     for (int col = minor; col < m; col++) {
250                         double alpha = 0;
251                         for (int row = minor; row < m; row++) {
252                             alpha -= qta[col][row] * qrtMinor[row];
253                         }
254                         alpha /= rDiag[minor] * qrtMinor[minor];
255 
256                         for (int row = minor; row < m; row++) {
257                             qta[col][row] += -alpha * qrtMinor[row];
258                         }
259                     }
260                 }
261             }
262             cachedQT = MatrixUtils.createRealMatrix(qta);
263         }
264 
265         // return the cached matrix
266         return cachedQT;
267     }
268 
269     /**
270      * Returns the Householder reflector vectors.
271      * <p>H is a lower trapezoidal matrix whose columns represent
272      * each successive Householder reflector vector. This matrix is used
273      * to compute Q.</p>
274      * @return a matrix containing the Householder reflector vectors
275      */
276     public RealMatrix getH() {
277         if (cachedH == null) {
278 
279             final int n = qrt.length;
280             final int m = qrt[0].length;
281             double[][] ha = new double[m][n];
282             for (int i = 0; i < m; ++i) {
283                 for (int j = 0; j < JdkMath.min(i + 1, n); ++j) {
284                     ha[i][j] = qrt[j][i] / -rDiag[j];
285                 }
286             }
287             cachedH = MatrixUtils.createRealMatrix(ha);
288         }
289 
290         // return the cached matrix
291         return cachedH;
292     }
293 
294     /**
295      * Get a solver for finding the A &times; X = B solution in least square sense.
296      * <p>
297      * Least Square sense means a solver can be computed for an overdetermined system,
298      * (i.e. a system with more equations than unknowns, which corresponds to a tall A
299      * matrix with more rows than columns). In any case, if the matrix is singular
300      * within the tolerance set at {@link QRDecomposition#QRDecomposition(RealMatrix,
301      * double) construction}, an error will be triggered when
302      * the {@link DecompositionSolver#solve(RealVector) solve} method will be called.
303      * </p>
304      * @return a solver
305      */
306     public DecompositionSolver getSolver() {
307         return new Solver(qrt, rDiag, threshold);
308     }
309 
310     /** Specialized solver. */
311     private static final class Solver implements DecompositionSolver {
312         /**
313          * A packed TRANSPOSED representation of the QR decomposition.
314          * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
315          * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
316          * from which an explicit form of Q can be recomputed if desired.</p>
317          */
318         private final double[][] qrt;
319         /** The diagonal elements of R. */
320         private final double[] rDiag;
321         /** Singularity threshold. */
322         private final double threshold;
323 
324         /**
325          * Build a solver from decomposed matrix.
326          *
327          * @param qrt Packed TRANSPOSED representation of the QR decomposition.
328          * @param rDiag Diagonal elements of R.
329          * @param threshold Singularity threshold.
330          */
331         private Solver(final double[][] qrt,
332                        final double[] rDiag,
333                        final double threshold) {
334             this.qrt   = qrt;
335             this.rDiag = rDiag;
336             this.threshold = threshold;
337         }
338 
339         /** {@inheritDoc} */
340         @Override
341         public boolean isNonSingular() {
342             return !checkSingular(rDiag, threshold, false);
343         }
344 
345         /** {@inheritDoc} */
346         @Override
347         public RealVector solve(RealVector b) {
348             final int n = qrt.length;
349             final int m = qrt[0].length;
350             if (b.getDimension() != m) {
351                 throw new DimensionMismatchException(b.getDimension(), m);
352             }
353             checkSingular(rDiag, threshold, true);
354 
355             final double[] x = new double[n];
356             final double[] y = b.toArray();
357 
358             // apply Householder transforms to solve Q.y = b
359             for (int minor = 0; minor < JdkMath.min(m, n); minor++) {
360 
361                 final double[] qrtMinor = qrt[minor];
362                 double dotProduct = 0;
363                 for (int row = minor; row < m; row++) {
364                     dotProduct += y[row] * qrtMinor[row];
365                 }
366                 dotProduct /= rDiag[minor] * qrtMinor[minor];
367 
368                 for (int row = minor; row < m; row++) {
369                     y[row] += dotProduct * qrtMinor[row];
370                 }
371             }
372 
373             // solve triangular system R.x = y
374             for (int row = rDiag.length - 1; row >= 0; --row) {
375                 y[row] /= rDiag[row];
376                 final double yRow = y[row];
377                 final double[] qrtRow = qrt[row];
378                 x[row] = yRow;
379                 for (int i = 0; i < row; i++) {
380                     y[i] -= yRow * qrtRow[i];
381                 }
382             }
383 
384             return new ArrayRealVector(x, false);
385         }
386 
387         /** {@inheritDoc} */
388         @Override
389         public RealMatrix solve(RealMatrix b) {
390             final int n = qrt.length;
391             final int m = qrt[0].length;
392             if (b.getRowDimension() != m) {
393                 throw new DimensionMismatchException(b.getRowDimension(), m);
394             }
395             checkSingular(rDiag, threshold, true);
396 
397             final int columns        = b.getColumnDimension();
398             final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
399             final int cBlocks        = (columns + blockSize - 1) / blockSize;
400             final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
401             final double[][] y       = new double[b.getRowDimension()][blockSize];
402             final double[]   alpha   = new double[blockSize];
403 
404             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
405                 final int kStart = kBlock * blockSize;
406                 final int kEnd   = JdkMath.min(kStart + blockSize, columns);
407                 final int kWidth = kEnd - kStart;
408 
409                 // get the right hand side vector
410                 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
411 
412                 // apply Householder transforms to solve Q.y = b
413                 for (int minor = 0; minor < JdkMath.min(m, n); minor++) {
414                     final double[] qrtMinor = qrt[minor];
415                     final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);
416 
417                     Arrays.fill(alpha, 0, kWidth, 0.0);
418                     for (int row = minor; row < m; ++row) {
419                         final double   d    = qrtMinor[row];
420                         final double[] yRow = y[row];
421                         for (int k = 0; k < kWidth; ++k) {
422                             alpha[k] += d * yRow[k];
423                         }
424                     }
425                     for (int k = 0; k < kWidth; ++k) {
426                         alpha[k] *= factor;
427                     }
428 
429                     for (int row = minor; row < m; ++row) {
430                         final double   d    = qrtMinor[row];
431                         final double[] yRow = y[row];
432                         for (int k = 0; k < kWidth; ++k) {
433                             yRow[k] += alpha[k] * d;
434                         }
435                     }
436                 }
437 
438                 // solve triangular system R.x = y
439                 for (int j = rDiag.length - 1; j >= 0; --j) {
440                     final int      jBlock = j / blockSize;
441                     final int      jStart = jBlock * blockSize;
442                     final double   factor = 1.0 / rDiag[j];
443                     final double[] yJ     = y[j];
444                     final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
445                     int index = (j - jStart) * kWidth;
446                     for (int k = 0; k < kWidth; ++k) {
447                         yJ[k]          *= factor;
448                         xBlock[index++] = yJ[k];
449                     }
450 
451                     final double[] qrtJ = qrt[j];
452                     for (int i = 0; i < j; ++i) {
453                         final double rIJ  = qrtJ[i];
454                         final double[] yI = y[i];
455                         for (int k = 0; k < kWidth; ++k) {
456                             yI[k] -= yJ[k] * rIJ;
457                         }
458                     }
459                 }
460             }
461 
462             return new BlockRealMatrix(n, columns, xBlocks, false);
463         }
464 
465         /**
466          * {@inheritDoc}
467          * @throws SingularMatrixException if the decomposed matrix is singular.
468          */
469         @Override
470         public RealMatrix getInverse() {
471             return solve(MatrixUtils.createRealIdentityMatrix(qrt[0].length));
472         }
473 
474         /**
475          * Check singularity.
476          *
477          * @param diag Diagonal elements of the R matrix.
478          * @param min Singularity threshold.
479          * @param raise Whether to raise a {@link SingularMatrixException}
480          * if any element of the diagonal fails the check.
481          * @return {@code true} if any element of the diagonal is smaller
482          * or equal to {@code min}.
483          * @throws SingularMatrixException if the matrix is singular and
484          * {@code raise} is {@code true}.
485          */
486         private static boolean checkSingular(double[] diag,
487                                              double min,
488                                              boolean raise) {
489             final int len = diag.length;
490             for (int i = 0; i < len; i++) {
491                 final double d = diag[i];
492                 if (JdkMath.abs(d) <= min) {
493                     if (raise) {
494                         final SingularMatrixException e = new SingularMatrixException();
495                         e.getContext().addMessage(LocalizedFormats.NUMBER_TOO_SMALL, d, min);
496                         e.getContext().addMessage(LocalizedFormats.INDEX, i);
497                         throw e;
498                     } else {
499                         return true;
500                     }
501                 }
502             }
503             return false;
504         }
505     }
506 }