View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.linear;
19  
20  import java.util.Arrays;
21  
22  import org.apache.commons.math3.exception.DimensionMismatchException;
23  import org.apache.commons.math3.util.FastMath;
24  
25  
26  /**
27   * Calculates the QR-decomposition of a matrix.
28   * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
29   * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
30   * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
31   * <p>This class compute the decomposition using Householder reflectors.</p>
32   * <p>For efficiency purposes, the decomposition in packed form is transposed.
33   * This allows inner loop to iterate inside rows, which is much more cache-efficient
34   * in Java.</p>
35   * <p>This class is based on the class with similar name from the
36   * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
37   * following changes:</p>
38   * <ul>
39   *   <li>a {@link #getQT() getQT} method has been added,</li>
40   *   <li>the {@code solve} and {@code isFullRank} methods have been replaced
41   *   by a {@link #getSolver() getSolver} method and the equivalent methods
42   *   provided by the returned {@link DecompositionSolver}.</li>
43   * </ul>
44   *
45   * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
46   * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
47   *
48   * @since 1.2 (changed to concrete class in 3.0)
49   */
50  public class QRDecomposition {
51      /**
52       * A packed TRANSPOSED representation of the QR decomposition.
53       * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
54       * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
55       * from which an explicit form of Q can be recomputed if desired.</p>
56       */
57      private double[][] qrt;
58      /** The diagonal elements of R. */
59      private double[] rDiag;
60      /** Cached value of Q. */
61      private RealMatrix cachedQ;
62      /** Cached value of QT. */
63      private RealMatrix cachedQT;
64      /** Cached value of R. */
65      private RealMatrix cachedR;
66      /** Cached value of H. */
67      private RealMatrix cachedH;
68      /** Singularity threshold. */
69      private final double threshold;
70  
71      /**
72       * Calculates the QR-decomposition of the given matrix.
73       * The singularity threshold defaults to zero.
74       *
75       * @param matrix The matrix to decompose.
76       *
77       * @see #QRDecomposition(RealMatrix,double)
78       */
79      public QRDecomposition(RealMatrix matrix) {
80          this(matrix, 0d);
81      }
82  
83      /**
84       * Calculates the QR-decomposition of the given matrix.
85       *
86       * @param matrix The matrix to decompose.
87       * @param threshold Singularity threshold.
88       */
89      public QRDecomposition(RealMatrix matrix,
90                             double threshold) {
91          this.threshold = threshold;
92  
93          final int m = matrix.getRowDimension();
94          final int n = matrix.getColumnDimension();
95          qrt = matrix.transpose().getData();
96          rDiag = new double[FastMath.min(m, n)];
97          cachedQ  = null;
98          cachedQT = null;
99          cachedR  = null;
100         cachedH  = null;
101 
102         decompose(qrt);
103 
104     }
105 
106     /** Decompose matrix.
107      * @param matrix transposed matrix
108      * @since 3.2
109      */
110     protected void decompose(double[][] matrix) {
111         for (int minor = 0; minor < FastMath.min(qrt.length, qrt[0].length); minor++) {
112             performHouseholderReflection(minor, qrt);
113         }
114     }
115 
116     /** Perform Householder reflection for a minor A(minor, minor) of A.
117      * @param minor minor index
118      * @param matrix transposed matrix
119      * @since 3.2
120      */
121     protected void performHouseholderReflection(int minor, double[][] matrix) {
122 
123         final double[] qrtMinor = qrt[minor];
124 
125         /*
126          * Let x be the first column of the minor, and a^2 = |x|^2.
127          * x will be in the positions qr[minor][minor] through qr[m][minor].
128          * The first column of the transformed minor will be (a,0,0,..)'
129          * The sign of a is chosen to be opposite to the sign of the first
130          * component of x. Let's find a:
131          */
132         double xNormSqr = 0;
133         for (int row = minor; row < qrtMinor.length; row++) {
134             final double c = qrtMinor[row];
135             xNormSqr += c * c;
136         }
137         final double a = (qrtMinor[minor] > 0) ? -FastMath.sqrt(xNormSqr) : FastMath.sqrt(xNormSqr);
138         rDiag[minor] = a;
139 
140         if (a != 0.0) {
141 
142             /*
143              * Calculate the normalized reflection vector v and transform
144              * the first column. We know the norm of v beforehand: v = x-ae
145              * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
146              * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
147              * Here <x, e> is now qr[minor][minor].
148              * v = x-ae is stored in the column at qr:
149              */
150             qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
151 
152             /*
153              * Transform the rest of the columns of the minor:
154              * They will be transformed by the matrix H = I-2vv'/|v|^2.
155              * If x is a column vector of the minor, then
156              * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
157              * Therefore the transformation is easily calculated by
158              * subtracting the column vector (2<x,v>/|v|^2)v from x.
159              *
160              * Let 2<x,v>/|v|^2 = alpha. From above we have
161              * |v|^2 = -2a*(qr[minor][minor]), so
162              * alpha = -<x,v>/(a*qr[minor][minor])
163              */
164             for (int col = minor+1; col < qrt.length; col++) {
165                 final double[] qrtCol = qrt[col];
166                 double alpha = 0;
167                 for (int row = minor; row < qrtCol.length; row++) {
168                     alpha -= qrtCol[row] * qrtMinor[row];
169                 }
170                 alpha /= a * qrtMinor[minor];
171 
172                 // Subtract the column vector alpha*v from x.
173                 for (int row = minor; row < qrtCol.length; row++) {
174                     qrtCol[row] -= alpha * qrtMinor[row];
175                 }
176             }
177         }
178     }
179 
180 
181     /**
182      * Returns the matrix R of the decomposition.
183      * <p>R is an upper-triangular matrix</p>
184      * @return the R matrix
185      */
186     public RealMatrix getR() {
187 
188         if (cachedR == null) {
189 
190             // R is supposed to be m x n
191             final int n = qrt.length;
192             final int m = qrt[0].length;
193             double[][] ra = new double[m][n];
194             // copy the diagonal from rDiag and the upper triangle of qr
195             for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
196                 ra[row][row] = rDiag[row];
197                 for (int col = row + 1; col < n; col++) {
198                     ra[row][col] = qrt[col][row];
199                 }
200             }
201             cachedR = MatrixUtils.createRealMatrix(ra);
202         }
203 
204         // return the cached matrix
205         return cachedR;
206     }
207 
208     /**
209      * Returns the matrix Q of the decomposition.
210      * <p>Q is an orthogonal matrix</p>
211      * @return the Q matrix
212      */
213     public RealMatrix getQ() {
214         if (cachedQ == null) {
215             cachedQ = getQT().transpose();
216         }
217         return cachedQ;
218     }
219 
220     /**
221      * Returns the transpose of the matrix Q of the decomposition.
222      * <p>Q is an orthogonal matrix</p>
223      * @return the transpose of the Q matrix, Q<sup>T</sup>
224      */
225     public RealMatrix getQT() {
226         if (cachedQT == null) {
227 
228             // QT is supposed to be m x m
229             final int n = qrt.length;
230             final int m = qrt[0].length;
231             double[][] qta = new double[m][m];
232 
233             /*
234              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
235              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
236              * succession to the result
237              */
238             for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
239                 qta[minor][minor] = 1.0d;
240             }
241 
242             for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
243                 final double[] qrtMinor = qrt[minor];
244                 qta[minor][minor] = 1.0d;
245                 if (qrtMinor[minor] != 0.0) {
246                     for (int col = minor; col < m; col++) {
247                         double alpha = 0;
248                         for (int row = minor; row < m; row++) {
249                             alpha -= qta[col][row] * qrtMinor[row];
250                         }
251                         alpha /= rDiag[minor] * qrtMinor[minor];
252 
253                         for (int row = minor; row < m; row++) {
254                             qta[col][row] += -alpha * qrtMinor[row];
255                         }
256                     }
257                 }
258             }
259             cachedQT = MatrixUtils.createRealMatrix(qta);
260         }
261 
262         // return the cached matrix
263         return cachedQT;
264     }
265 
266     /**
267      * Returns the Householder reflector vectors.
268      * <p>H is a lower trapezoidal matrix whose columns represent
269      * each successive Householder reflector vector. This matrix is used
270      * to compute Q.</p>
271      * @return a matrix containing the Householder reflector vectors
272      */
273     public RealMatrix getH() {
274         if (cachedH == null) {
275 
276             final int n = qrt.length;
277             final int m = qrt[0].length;
278             double[][] ha = new double[m][n];
279             for (int i = 0; i < m; ++i) {
280                 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
281                     ha[i][j] = qrt[j][i] / -rDiag[j];
282                 }
283             }
284             cachedH = MatrixUtils.createRealMatrix(ha);
285         }
286 
287         // return the cached matrix
288         return cachedH;
289     }
290 
291     /**
292      * Get a solver for finding the A &times; X = B solution in least square sense.
293      * <p>
294      * Least Square sense means a solver can be computed for an overdetermined system,
295      * (i.e. a system with more equations than unknowns, which corresponds to a tall A
296      * matrix with more rows than columns). In any case, if the matrix is singular
297      * within the tolerance set at {@link QRDecomposition#QRDecomposition(RealMatrix,
298      * double) construction}, an error will be triggered when
299      * the {@link DecompositionSolver#solve(RealVector) solve} method will be called.
300      * </p>
301      * @return a solver
302      */
303     public DecompositionSolver getSolver() {
304         return new Solver(qrt, rDiag, threshold);
305     }
306 
307     /** Specialized solver. */
308     private static class Solver implements DecompositionSolver {
309         /**
310          * A packed TRANSPOSED representation of the QR decomposition.
311          * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
312          * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
313          * from which an explicit form of Q can be recomputed if desired.</p>
314          */
315         private final double[][] qrt;
316         /** The diagonal elements of R. */
317         private final double[] rDiag;
318         /** Singularity threshold. */
319         private final double threshold;
320 
321         /**
322          * Build a solver from decomposed matrix.
323          *
324          * @param qrt Packed TRANSPOSED representation of the QR decomposition.
325          * @param rDiag Diagonal elements of R.
326          * @param threshold Singularity threshold.
327          */
328         private Solver(final double[][] qrt,
329                        final double[] rDiag,
330                        final double threshold) {
331             this.qrt   = qrt;
332             this.rDiag = rDiag;
333             this.threshold = threshold;
334         }
335 
336         /** {@inheritDoc} */
337         public boolean isNonSingular() {
338             for (double diag : rDiag) {
339                 if (FastMath.abs(diag) <= threshold) {
340                     return false;
341                 }
342             }
343             return true;
344         }
345 
346         /** {@inheritDoc} */
347         public RealVector solve(RealVector b) {
348             final int n = qrt.length;
349             final int m = qrt[0].length;
350             if (b.getDimension() != m) {
351                 throw new DimensionMismatchException(b.getDimension(), m);
352             }
353             if (!isNonSingular()) {
354                 throw new SingularMatrixException();
355             }
356 
357             final double[] x = new double[n];
358             final double[] y = b.toArray();
359 
360             // apply Householder transforms to solve Q.y = b
361             for (int minor = 0; minor < FastMath.min(m, n); minor++) {
362 
363                 final double[] qrtMinor = qrt[minor];
364                 double dotProduct = 0;
365                 for (int row = minor; row < m; row++) {
366                     dotProduct += y[row] * qrtMinor[row];
367                 }
368                 dotProduct /= rDiag[minor] * qrtMinor[minor];
369 
370                 for (int row = minor; row < m; row++) {
371                     y[row] += dotProduct * qrtMinor[row];
372                 }
373             }
374 
375             // solve triangular system R.x = y
376             for (int row = rDiag.length - 1; row >= 0; --row) {
377                 y[row] /= rDiag[row];
378                 final double yRow = y[row];
379                 final double[] qrtRow = qrt[row];
380                 x[row] = yRow;
381                 for (int i = 0; i < row; i++) {
382                     y[i] -= yRow * qrtRow[i];
383                 }
384             }
385 
386             return new ArrayRealVector(x, false);
387         }
388 
389         /** {@inheritDoc} */
390         public RealMatrix solve(RealMatrix b) {
391             final int n = qrt.length;
392             final int m = qrt[0].length;
393             if (b.getRowDimension() != m) {
394                 throw new DimensionMismatchException(b.getRowDimension(), m);
395             }
396             if (!isNonSingular()) {
397                 throw new SingularMatrixException();
398             }
399 
400             final int columns        = b.getColumnDimension();
401             final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
402             final int cBlocks        = (columns + blockSize - 1) / blockSize;
403             final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
404             final double[][] y       = new double[b.getRowDimension()][blockSize];
405             final double[]   alpha   = new double[blockSize];
406 
407             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
408                 final int kStart = kBlock * blockSize;
409                 final int kEnd   = FastMath.min(kStart + blockSize, columns);
410                 final int kWidth = kEnd - kStart;
411 
412                 // get the right hand side vector
413                 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
414 
415                 // apply Householder transforms to solve Q.y = b
416                 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
417                     final double[] qrtMinor = qrt[minor];
418                     final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);
419 
420                     Arrays.fill(alpha, 0, kWidth, 0.0);
421                     for (int row = minor; row < m; ++row) {
422                         final double   d    = qrtMinor[row];
423                         final double[] yRow = y[row];
424                         for (int k = 0; k < kWidth; ++k) {
425                             alpha[k] += d * yRow[k];
426                         }
427                     }
428                     for (int k = 0; k < kWidth; ++k) {
429                         alpha[k] *= factor;
430                     }
431 
432                     for (int row = minor; row < m; ++row) {
433                         final double   d    = qrtMinor[row];
434                         final double[] yRow = y[row];
435                         for (int k = 0; k < kWidth; ++k) {
436                             yRow[k] += alpha[k] * d;
437                         }
438                     }
439                 }
440 
441                 // solve triangular system R.x = y
442                 for (int j = rDiag.length - 1; j >= 0; --j) {
443                     final int      jBlock = j / blockSize;
444                     final int      jStart = jBlock * blockSize;
445                     final double   factor = 1.0 / rDiag[j];
446                     final double[] yJ     = y[j];
447                     final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
448                     int index = (j - jStart) * kWidth;
449                     for (int k = 0; k < kWidth; ++k) {
450                         yJ[k]          *= factor;
451                         xBlock[index++] = yJ[k];
452                     }
453 
454                     final double[] qrtJ = qrt[j];
455                     for (int i = 0; i < j; ++i) {
456                         final double rIJ  = qrtJ[i];
457                         final double[] yI = y[i];
458                         for (int k = 0; k < kWidth; ++k) {
459                             yI[k] -= yJ[k] * rIJ;
460                         }
461                     }
462                 }
463             }
464 
465             return new BlockRealMatrix(n, columns, xBlocks, false);
466         }
467 
468         /**
469          * {@inheritDoc}
470          * @throws SingularMatrixException if the decomposed matrix is singular.
471          */
472         public RealMatrix getInverse() {
473             return solve(MatrixUtils.createRealIdentityMatrix(qrt[0].length));
474         }
475     }
476 }