001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.linear;
019    
020    import java.util.Arrays;
021    
022    import org.apache.commons.math3.exception.DimensionMismatchException;
023    import org.apache.commons.math3.util.FastMath;
024    
025    
026    /**
027     * Calculates the QR-decomposition of a matrix.
028     * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
029     * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
030     * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
031     * <p>This class compute the decomposition using Householder reflectors.</p>
032     * <p>For efficiency purposes, the decomposition in packed form is transposed.
033     * This allows inner loop to iterate inside rows, which is much more cache-efficient
034     * in Java.</p>
035     * <p>This class is based on the class with similar name from the
036     * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
037     * following changes:</p>
038     * <ul>
039     *   <li>a {@link #getQT() getQT} method has been added,</li>
040     *   <li>the {@code solve} and {@code isFullRank} methods have been replaced
041     *   by a {@link #getSolver() getSolver} method and the equivalent methods
042     *   provided by the returned {@link DecompositionSolver}.</li>
043     * </ul>
044     *
045     * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
046     * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
047     *
048     * @version $Id: QRDecomposition.java 1462423 2013-03-29 07:25:18Z luc $
049     * @since 1.2 (changed to concrete class in 3.0)
050     */
051    public class QRDecomposition {
052        /**
053         * A packed TRANSPOSED representation of the QR decomposition.
054         * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
055         * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
056         * from which an explicit form of Q can be recomputed if desired.</p>
057         */
058        private double[][] qrt;
059        /** The diagonal elements of R. */
060        private double[] rDiag;
061        /** Cached value of Q. */
062        private RealMatrix cachedQ;
063        /** Cached value of QT. */
064        private RealMatrix cachedQT;
065        /** Cached value of R. */
066        private RealMatrix cachedR;
067        /** Cached value of H. */
068        private RealMatrix cachedH;
069        /** Singularity threshold. */
070        private final double threshold;
071    
072        /**
073         * Calculates the QR-decomposition of the given matrix.
074         * The singularity threshold defaults to zero.
075         *
076         * @param matrix The matrix to decompose.
077         *
078         * @see #QRDecomposition(RealMatrix,double)
079         */
080        public QRDecomposition(RealMatrix matrix) {
081            this(matrix, 0d);
082        }
083    
084        /**
085         * Calculates the QR-decomposition of the given matrix.
086         *
087         * @param matrix The matrix to decompose.
088         * @param threshold Singularity threshold.
089         */
090        public QRDecomposition(RealMatrix matrix,
091                               double threshold) {
092            this.threshold = threshold;
093    
094            final int m = matrix.getRowDimension();
095            final int n = matrix.getColumnDimension();
096            qrt = matrix.transpose().getData();
097            rDiag = new double[FastMath.min(m, n)];
098            cachedQ  = null;
099            cachedQT = null;
100            cachedR  = null;
101            cachedH  = null;
102    
103            decompose(qrt);
104    
105        }
106    
107        /** Decompose matrix.
108         * @param matrix transposed matrix
109         * @since 3.2
110         */
111        protected void decompose(double[][] matrix) {
112            for (int minor = 0; minor < FastMath.min(qrt.length, qrt[0].length); minor++) {
113                performHouseholderReflection(minor, qrt);
114            }
115        }
116    
117        /** Perform Householder reflection for a minor A(minor, minor) of A.
118         * @param minor minor index
119         * @param matrix transposed matrix
120         * @since 3.2
121         */
122        protected void performHouseholderReflection(int minor, double[][] matrix) {
123    
124            final double[] qrtMinor = qrt[minor];
125    
126            /*
127             * Let x be the first column of the minor, and a^2 = |x|^2.
128             * x will be in the positions qr[minor][minor] through qr[m][minor].
129             * The first column of the transformed minor will be (a,0,0,..)'
130             * The sign of a is chosen to be opposite to the sign of the first
131             * component of x. Let's find a:
132             */
133            double xNormSqr = 0;
134            for (int row = minor; row < qrtMinor.length; row++) {
135                final double c = qrtMinor[row];
136                xNormSqr += c * c;
137            }
138            final double a = (qrtMinor[minor] > 0) ? -FastMath.sqrt(xNormSqr) : FastMath.sqrt(xNormSqr);
139            rDiag[minor] = a;
140    
141            if (a != 0.0) {
142    
143                /*
144                 * Calculate the normalized reflection vector v and transform
145                 * the first column. We know the norm of v beforehand: v = x-ae
146                 * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
147                 * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
148                 * Here <x, e> is now qr[minor][minor].
149                 * v = x-ae is stored in the column at qr:
150                 */
151                qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
152    
153                /*
154                 * Transform the rest of the columns of the minor:
155                 * They will be transformed by the matrix H = I-2vv'/|v|^2.
156                 * If x is a column vector of the minor, then
157                 * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
158                 * Therefore the transformation is easily calculated by
159                 * subtracting the column vector (2<x,v>/|v|^2)v from x.
160                 *
161                 * Let 2<x,v>/|v|^2 = alpha. From above we have
162                 * |v|^2 = -2a*(qr[minor][minor]), so
163                 * alpha = -<x,v>/(a*qr[minor][minor])
164                 */
165                for (int col = minor+1; col < qrt.length; col++) {
166                    final double[] qrtCol = qrt[col];
167                    double alpha = 0;
168                    for (int row = minor; row < qrtCol.length; row++) {
169                        alpha -= qrtCol[row] * qrtMinor[row];
170                    }
171                    alpha /= a * qrtMinor[minor];
172    
173                    // Subtract the column vector alpha*v from x.
174                    for (int row = minor; row < qrtCol.length; row++) {
175                        qrtCol[row] -= alpha * qrtMinor[row];
176                    }
177                }
178            }
179        }
180    
181    
182        /**
183         * Returns the matrix R of the decomposition.
184         * <p>R is an upper-triangular matrix</p>
185         * @return the R matrix
186         */
187        public RealMatrix getR() {
188    
189            if (cachedR == null) {
190    
191                // R is supposed to be m x n
192                final int n = qrt.length;
193                final int m = qrt[0].length;
194                double[][] ra = new double[m][n];
195                // copy the diagonal from rDiag and the upper triangle of qr
196                for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
197                    ra[row][row] = rDiag[row];
198                    for (int col = row + 1; col < n; col++) {
199                        ra[row][col] = qrt[col][row];
200                    }
201                }
202                cachedR = MatrixUtils.createRealMatrix(ra);
203            }
204    
205            // return the cached matrix
206            return cachedR;
207        }
208    
209        /**
210         * Returns the matrix Q of the decomposition.
211         * <p>Q is an orthogonal matrix</p>
212         * @return the Q matrix
213         */
214        public RealMatrix getQ() {
215            if (cachedQ == null) {
216                cachedQ = getQT().transpose();
217            }
218            return cachedQ;
219        }
220    
221        /**
222         * Returns the transpose of the matrix Q of the decomposition.
223         * <p>Q is an orthogonal matrix</p>
224         * @return the transpose of the Q matrix, Q<sup>T</sup>
225         */
226        public RealMatrix getQT() {
227            if (cachedQT == null) {
228    
229                // QT is supposed to be m x m
230                final int n = qrt.length;
231                final int m = qrt[0].length;
232                double[][] qta = new double[m][m];
233    
234                /*
235                 * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
236                 * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
237                 * succession to the result
238                 */
239                for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
240                    qta[minor][minor] = 1.0d;
241                }
242    
243                for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
244                    final double[] qrtMinor = qrt[minor];
245                    qta[minor][minor] = 1.0d;
246                    if (qrtMinor[minor] != 0.0) {
247                        for (int col = minor; col < m; col++) {
248                            double alpha = 0;
249                            for (int row = minor; row < m; row++) {
250                                alpha -= qta[col][row] * qrtMinor[row];
251                            }
252                            alpha /= rDiag[minor] * qrtMinor[minor];
253    
254                            for (int row = minor; row < m; row++) {
255                                qta[col][row] += -alpha * qrtMinor[row];
256                            }
257                        }
258                    }
259                }
260                cachedQT = MatrixUtils.createRealMatrix(qta);
261            }
262    
263            // return the cached matrix
264            return cachedQT;
265        }
266    
267        /**
268         * Returns the Householder reflector vectors.
269         * <p>H is a lower trapezoidal matrix whose columns represent
270         * each successive Householder reflector vector. This matrix is used
271         * to compute Q.</p>
272         * @return a matrix containing the Householder reflector vectors
273         */
274        public RealMatrix getH() {
275            if (cachedH == null) {
276    
277                final int n = qrt.length;
278                final int m = qrt[0].length;
279                double[][] ha = new double[m][n];
280                for (int i = 0; i < m; ++i) {
281                    for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
282                        ha[i][j] = qrt[j][i] / -rDiag[j];
283                    }
284                }
285                cachedH = MatrixUtils.createRealMatrix(ha);
286            }
287    
288            // return the cached matrix
289            return cachedH;
290        }
291    
292        /**
293         * Get a solver for finding the A &times; X = B solution in least square sense.
294         * @return a solver
295         */
296        public DecompositionSolver getSolver() {
297            return new Solver(qrt, rDiag, threshold);
298        }
299    
300        /** Specialized solver. */
301        private static class Solver implements DecompositionSolver {
302            /**
303             * A packed TRANSPOSED representation of the QR decomposition.
304             * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
305             * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
306             * from which an explicit form of Q can be recomputed if desired.</p>
307             */
308            private final double[][] qrt;
309            /** The diagonal elements of R. */
310            private final double[] rDiag;
311            /** Singularity threshold. */
312            private final double threshold;
313    
314            /**
315             * Build a solver from decomposed matrix.
316             *
317             * @param qrt Packed TRANSPOSED representation of the QR decomposition.
318             * @param rDiag Diagonal elements of R.
319             * @param threshold Singularity threshold.
320             */
321            private Solver(final double[][] qrt,
322                           final double[] rDiag,
323                           final double threshold) {
324                this.qrt   = qrt;
325                this.rDiag = rDiag;
326                this.threshold = threshold;
327            }
328    
329            /** {@inheritDoc} */
330            public boolean isNonSingular() {
331                for (double diag : rDiag) {
332                    if (FastMath.abs(diag) <= threshold) {
333                        return false;
334                    }
335                }
336                return true;
337            }
338    
339            /** {@inheritDoc} */
340            public RealVector solve(RealVector b) {
341                final int n = qrt.length;
342                final int m = qrt[0].length;
343                if (b.getDimension() != m) {
344                    throw new DimensionMismatchException(b.getDimension(), m);
345                }
346                if (!isNonSingular()) {
347                    throw new SingularMatrixException();
348                }
349    
350                final double[] x = new double[n];
351                final double[] y = b.toArray();
352    
353                // apply Householder transforms to solve Q.y = b
354                for (int minor = 0; minor < FastMath.min(m, n); minor++) {
355    
356                    final double[] qrtMinor = qrt[minor];
357                    double dotProduct = 0;
358                    for (int row = minor; row < m; row++) {
359                        dotProduct += y[row] * qrtMinor[row];
360                    }
361                    dotProduct /= rDiag[minor] * qrtMinor[minor];
362    
363                    for (int row = minor; row < m; row++) {
364                        y[row] += dotProduct * qrtMinor[row];
365                    }
366                }
367    
368                // solve triangular system R.x = y
369                for (int row = rDiag.length - 1; row >= 0; --row) {
370                    y[row] /= rDiag[row];
371                    final double yRow = y[row];
372                    final double[] qrtRow = qrt[row];
373                    x[row] = yRow;
374                    for (int i = 0; i < row; i++) {
375                        y[i] -= yRow * qrtRow[i];
376                    }
377                }
378    
379                return new ArrayRealVector(x, false);
380            }
381    
382            /** {@inheritDoc} */
383            public RealMatrix solve(RealMatrix b) {
384                final int n = qrt.length;
385                final int m = qrt[0].length;
386                if (b.getRowDimension() != m) {
387                    throw new DimensionMismatchException(b.getRowDimension(), m);
388                }
389                if (!isNonSingular()) {
390                    throw new SingularMatrixException();
391                }
392    
393                final int columns        = b.getColumnDimension();
394                final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
395                final int cBlocks        = (columns + blockSize - 1) / blockSize;
396                final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
397                final double[][] y       = new double[b.getRowDimension()][blockSize];
398                final double[]   alpha   = new double[blockSize];
399    
400                for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
401                    final int kStart = kBlock * blockSize;
402                    final int kEnd   = FastMath.min(kStart + blockSize, columns);
403                    final int kWidth = kEnd - kStart;
404    
405                    // get the right hand side vector
406                    b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
407    
408                    // apply Householder transforms to solve Q.y = b
409                    for (int minor = 0; minor < FastMath.min(m, n); minor++) {
410                        final double[] qrtMinor = qrt[minor];
411                        final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);
412    
413                        Arrays.fill(alpha, 0, kWidth, 0.0);
414                        for (int row = minor; row < m; ++row) {
415                            final double   d    = qrtMinor[row];
416                            final double[] yRow = y[row];
417                            for (int k = 0; k < kWidth; ++k) {
418                                alpha[k] += d * yRow[k];
419                            }
420                        }
421                        for (int k = 0; k < kWidth; ++k) {
422                            alpha[k] *= factor;
423                        }
424    
425                        for (int row = minor; row < m; ++row) {
426                            final double   d    = qrtMinor[row];
427                            final double[] yRow = y[row];
428                            for (int k = 0; k < kWidth; ++k) {
429                                yRow[k] += alpha[k] * d;
430                            }
431                        }
432                    }
433    
434                    // solve triangular system R.x = y
435                    for (int j = rDiag.length - 1; j >= 0; --j) {
436                        final int      jBlock = j / blockSize;
437                        final int      jStart = jBlock * blockSize;
438                        final double   factor = 1.0 / rDiag[j];
439                        final double[] yJ     = y[j];
440                        final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
441                        int index = (j - jStart) * kWidth;
442                        for (int k = 0; k < kWidth; ++k) {
443                            yJ[k]          *= factor;
444                            xBlock[index++] = yJ[k];
445                        }
446    
447                        final double[] qrtJ = qrt[j];
448                        for (int i = 0; i < j; ++i) {
449                            final double rIJ  = qrtJ[i];
450                            final double[] yI = y[i];
451                            for (int k = 0; k < kWidth; ++k) {
452                                yI[k] -= yJ[k] * rIJ;
453                            }
454                        }
455                    }
456                }
457    
458                return new BlockRealMatrix(n, columns, xBlocks, false);
459            }
460    
461            /** {@inheritDoc} */
462            public RealMatrix getInverse() {
463                return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
464            }
465        }
466    }