001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.linear;
019
020import java.util.Arrays;
021
022import org.apache.commons.math3.exception.DimensionMismatchException;
023import org.apache.commons.math3.util.FastMath;
024
025
026/**
027 * Calculates the QR-decomposition of a matrix.
028 * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
029 * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
030 * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
031 * <p>This class compute the decomposition using Householder reflectors.</p>
032 * <p>For efficiency purposes, the decomposition in packed form is transposed.
033 * This allows inner loop to iterate inside rows, which is much more cache-efficient
034 * in Java.</p>
035 * <p>This class is based on the class with similar name from the
036 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
037 * following changes:</p>
038 * <ul>
039 *   <li>a {@link #getQT() getQT} method has been added,</li>
040 *   <li>the {@code solve} and {@code isFullRank} methods have been replaced
041 *   by a {@link #getSolver() getSolver} method and the equivalent methods
042 *   provided by the returned {@link DecompositionSolver}.</li>
043 * </ul>
044 *
045 * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
046 * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
047 *
048 * @version $Id: QRDecomposition.java 1462423 2013-03-29 07:25:18Z luc $
049 * @since 1.2 (changed to concrete class in 3.0)
050 */
051public class QRDecomposition {
052    /**
053     * A packed TRANSPOSED representation of the QR decomposition.
054     * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
055     * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
056     * from which an explicit form of Q can be recomputed if desired.</p>
057     */
058    private double[][] qrt;
059    /** The diagonal elements of R. */
060    private double[] rDiag;
061    /** Cached value of Q. */
062    private RealMatrix cachedQ;
063    /** Cached value of QT. */
064    private RealMatrix cachedQT;
065    /** Cached value of R. */
066    private RealMatrix cachedR;
067    /** Cached value of H. */
068    private RealMatrix cachedH;
069    /** Singularity threshold. */
070    private final double threshold;
071
072    /**
073     * Calculates the QR-decomposition of the given matrix.
074     * The singularity threshold defaults to zero.
075     *
076     * @param matrix The matrix to decompose.
077     *
078     * @see #QRDecomposition(RealMatrix,double)
079     */
080    public QRDecomposition(RealMatrix matrix) {
081        this(matrix, 0d);
082    }
083
084    /**
085     * Calculates the QR-decomposition of the given matrix.
086     *
087     * @param matrix The matrix to decompose.
088     * @param threshold Singularity threshold.
089     */
090    public QRDecomposition(RealMatrix matrix,
091                           double threshold) {
092        this.threshold = threshold;
093
094        final int m = matrix.getRowDimension();
095        final int n = matrix.getColumnDimension();
096        qrt = matrix.transpose().getData();
097        rDiag = new double[FastMath.min(m, n)];
098        cachedQ  = null;
099        cachedQT = null;
100        cachedR  = null;
101        cachedH  = null;
102
103        decompose(qrt);
104
105    }
106
107    /** Decompose matrix.
108     * @param matrix transposed matrix
109     * @since 3.2
110     */
111    protected void decompose(double[][] matrix) {
112        for (int minor = 0; minor < FastMath.min(qrt.length, qrt[0].length); minor++) {
113            performHouseholderReflection(minor, qrt);
114        }
115    }
116
117    /** Perform Householder reflection for a minor A(minor, minor) of A.
118     * @param minor minor index
119     * @param matrix transposed matrix
120     * @since 3.2
121     */
122    protected void performHouseholderReflection(int minor, double[][] matrix) {
123
124        final double[] qrtMinor = qrt[minor];
125
126        /*
127         * Let x be the first column of the minor, and a^2 = |x|^2.
128         * x will be in the positions qr[minor][minor] through qr[m][minor].
129         * The first column of the transformed minor will be (a,0,0,..)'
130         * The sign of a is chosen to be opposite to the sign of the first
131         * component of x. Let's find a:
132         */
133        double xNormSqr = 0;
134        for (int row = minor; row < qrtMinor.length; row++) {
135            final double c = qrtMinor[row];
136            xNormSqr += c * c;
137        }
138        final double a = (qrtMinor[minor] > 0) ? -FastMath.sqrt(xNormSqr) : FastMath.sqrt(xNormSqr);
139        rDiag[minor] = a;
140
141        if (a != 0.0) {
142
143            /*
144             * Calculate the normalized reflection vector v and transform
145             * the first column. We know the norm of v beforehand: v = x-ae
146             * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
147             * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
148             * Here <x, e> is now qr[minor][minor].
149             * v = x-ae is stored in the column at qr:
150             */
151            qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
152
153            /*
154             * Transform the rest of the columns of the minor:
155             * They will be transformed by the matrix H = I-2vv'/|v|^2.
156             * If x is a column vector of the minor, then
157             * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
158             * Therefore the transformation is easily calculated by
159             * subtracting the column vector (2<x,v>/|v|^2)v from x.
160             *
161             * Let 2<x,v>/|v|^2 = alpha. From above we have
162             * |v|^2 = -2a*(qr[minor][minor]), so
163             * alpha = -<x,v>/(a*qr[minor][minor])
164             */
165            for (int col = minor+1; col < qrt.length; col++) {
166                final double[] qrtCol = qrt[col];
167                double alpha = 0;
168                for (int row = minor; row < qrtCol.length; row++) {
169                    alpha -= qrtCol[row] * qrtMinor[row];
170                }
171                alpha /= a * qrtMinor[minor];
172
173                // Subtract the column vector alpha*v from x.
174                for (int row = minor; row < qrtCol.length; row++) {
175                    qrtCol[row] -= alpha * qrtMinor[row];
176                }
177            }
178        }
179    }
180
181
182    /**
183     * Returns the matrix R of the decomposition.
184     * <p>R is an upper-triangular matrix</p>
185     * @return the R matrix
186     */
187    public RealMatrix getR() {
188
189        if (cachedR == null) {
190
191            // R is supposed to be m x n
192            final int n = qrt.length;
193            final int m = qrt[0].length;
194            double[][] ra = new double[m][n];
195            // copy the diagonal from rDiag and the upper triangle of qr
196            for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
197                ra[row][row] = rDiag[row];
198                for (int col = row + 1; col < n; col++) {
199                    ra[row][col] = qrt[col][row];
200                }
201            }
202            cachedR = MatrixUtils.createRealMatrix(ra);
203        }
204
205        // return the cached matrix
206        return cachedR;
207    }
208
209    /**
210     * Returns the matrix Q of the decomposition.
211     * <p>Q is an orthogonal matrix</p>
212     * @return the Q matrix
213     */
214    public RealMatrix getQ() {
215        if (cachedQ == null) {
216            cachedQ = getQT().transpose();
217        }
218        return cachedQ;
219    }
220
221    /**
222     * Returns the transpose of the matrix Q of the decomposition.
223     * <p>Q is an orthogonal matrix</p>
224     * @return the transpose of the Q matrix, Q<sup>T</sup>
225     */
226    public RealMatrix getQT() {
227        if (cachedQT == null) {
228
229            // QT is supposed to be m x m
230            final int n = qrt.length;
231            final int m = qrt[0].length;
232            double[][] qta = new double[m][m];
233
234            /*
235             * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
236             * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
237             * succession to the result
238             */
239            for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
240                qta[minor][minor] = 1.0d;
241            }
242
243            for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
244                final double[] qrtMinor = qrt[minor];
245                qta[minor][minor] = 1.0d;
246                if (qrtMinor[minor] != 0.0) {
247                    for (int col = minor; col < m; col++) {
248                        double alpha = 0;
249                        for (int row = minor; row < m; row++) {
250                            alpha -= qta[col][row] * qrtMinor[row];
251                        }
252                        alpha /= rDiag[minor] * qrtMinor[minor];
253
254                        for (int row = minor; row < m; row++) {
255                            qta[col][row] += -alpha * qrtMinor[row];
256                        }
257                    }
258                }
259            }
260            cachedQT = MatrixUtils.createRealMatrix(qta);
261        }
262
263        // return the cached matrix
264        return cachedQT;
265    }
266
267    /**
268     * Returns the Householder reflector vectors.
269     * <p>H is a lower trapezoidal matrix whose columns represent
270     * each successive Householder reflector vector. This matrix is used
271     * to compute Q.</p>
272     * @return a matrix containing the Householder reflector vectors
273     */
274    public RealMatrix getH() {
275        if (cachedH == null) {
276
277            final int n = qrt.length;
278            final int m = qrt[0].length;
279            double[][] ha = new double[m][n];
280            for (int i = 0; i < m; ++i) {
281                for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
282                    ha[i][j] = qrt[j][i] / -rDiag[j];
283                }
284            }
285            cachedH = MatrixUtils.createRealMatrix(ha);
286        }
287
288        // return the cached matrix
289        return cachedH;
290    }
291
292    /**
293     * Get a solver for finding the A &times; X = B solution in least square sense.
294     * @return a solver
295     */
296    public DecompositionSolver getSolver() {
297        return new Solver(qrt, rDiag, threshold);
298    }
299
300    /** Specialized solver. */
301    private static class Solver implements DecompositionSolver {
302        /**
303         * A packed TRANSPOSED representation of the QR decomposition.
304         * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
305         * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
306         * from which an explicit form of Q can be recomputed if desired.</p>
307         */
308        private final double[][] qrt;
309        /** The diagonal elements of R. */
310        private final double[] rDiag;
311        /** Singularity threshold. */
312        private final double threshold;
313
314        /**
315         * Build a solver from decomposed matrix.
316         *
317         * @param qrt Packed TRANSPOSED representation of the QR decomposition.
318         * @param rDiag Diagonal elements of R.
319         * @param threshold Singularity threshold.
320         */
321        private Solver(final double[][] qrt,
322                       final double[] rDiag,
323                       final double threshold) {
324            this.qrt   = qrt;
325            this.rDiag = rDiag;
326            this.threshold = threshold;
327        }
328
329        /** {@inheritDoc} */
330        public boolean isNonSingular() {
331            for (double diag : rDiag) {
332                if (FastMath.abs(diag) <= threshold) {
333                    return false;
334                }
335            }
336            return true;
337        }
338
339        /** {@inheritDoc} */
340        public RealVector solve(RealVector b) {
341            final int n = qrt.length;
342            final int m = qrt[0].length;
343            if (b.getDimension() != m) {
344                throw new DimensionMismatchException(b.getDimension(), m);
345            }
346            if (!isNonSingular()) {
347                throw new SingularMatrixException();
348            }
349
350            final double[] x = new double[n];
351            final double[] y = b.toArray();
352
353            // apply Householder transforms to solve Q.y = b
354            for (int minor = 0; minor < FastMath.min(m, n); minor++) {
355
356                final double[] qrtMinor = qrt[minor];
357                double dotProduct = 0;
358                for (int row = minor; row < m; row++) {
359                    dotProduct += y[row] * qrtMinor[row];
360                }
361                dotProduct /= rDiag[minor] * qrtMinor[minor];
362
363                for (int row = minor; row < m; row++) {
364                    y[row] += dotProduct * qrtMinor[row];
365                }
366            }
367
368            // solve triangular system R.x = y
369            for (int row = rDiag.length - 1; row >= 0; --row) {
370                y[row] /= rDiag[row];
371                final double yRow = y[row];
372                final double[] qrtRow = qrt[row];
373                x[row] = yRow;
374                for (int i = 0; i < row; i++) {
375                    y[i] -= yRow * qrtRow[i];
376                }
377            }
378
379            return new ArrayRealVector(x, false);
380        }
381
382        /** {@inheritDoc} */
383        public RealMatrix solve(RealMatrix b) {
384            final int n = qrt.length;
385            final int m = qrt[0].length;
386            if (b.getRowDimension() != m) {
387                throw new DimensionMismatchException(b.getRowDimension(), m);
388            }
389            if (!isNonSingular()) {
390                throw new SingularMatrixException();
391            }
392
393            final int columns        = b.getColumnDimension();
394            final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
395            final int cBlocks        = (columns + blockSize - 1) / blockSize;
396            final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
397            final double[][] y       = new double[b.getRowDimension()][blockSize];
398            final double[]   alpha   = new double[blockSize];
399
400            for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
401                final int kStart = kBlock * blockSize;
402                final int kEnd   = FastMath.min(kStart + blockSize, columns);
403                final int kWidth = kEnd - kStart;
404
405                // get the right hand side vector
406                b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
407
408                // apply Householder transforms to solve Q.y = b
409                for (int minor = 0; minor < FastMath.min(m, n); minor++) {
410                    final double[] qrtMinor = qrt[minor];
411                    final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);
412
413                    Arrays.fill(alpha, 0, kWidth, 0.0);
414                    for (int row = minor; row < m; ++row) {
415                        final double   d    = qrtMinor[row];
416                        final double[] yRow = y[row];
417                        for (int k = 0; k < kWidth; ++k) {
418                            alpha[k] += d * yRow[k];
419                        }
420                    }
421                    for (int k = 0; k < kWidth; ++k) {
422                        alpha[k] *= factor;
423                    }
424
425                    for (int row = minor; row < m; ++row) {
426                        final double   d    = qrtMinor[row];
427                        final double[] yRow = y[row];
428                        for (int k = 0; k < kWidth; ++k) {
429                            yRow[k] += alpha[k] * d;
430                        }
431                    }
432                }
433
434                // solve triangular system R.x = y
435                for (int j = rDiag.length - 1; j >= 0; --j) {
436                    final int      jBlock = j / blockSize;
437                    final int      jStart = jBlock * blockSize;
438                    final double   factor = 1.0 / rDiag[j];
439                    final double[] yJ     = y[j];
440                    final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
441                    int index = (j - jStart) * kWidth;
442                    for (int k = 0; k < kWidth; ++k) {
443                        yJ[k]          *= factor;
444                        xBlock[index++] = yJ[k];
445                    }
446
447                    final double[] qrtJ = qrt[j];
448                    for (int i = 0; i < j; ++i) {
449                        final double rIJ  = qrtJ[i];
450                        final double[] yI = y[i];
451                        for (int k = 0; k < kWidth; ++k) {
452                            yI[k] -= yJ[k] * rIJ;
453                        }
454                    }
455                }
456            }
457
458            return new BlockRealMatrix(n, columns, xBlocks, false);
459        }
460
461        /** {@inheritDoc} */
462        public RealMatrix getInverse() {
463            return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
464        }
465    }
466}