001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.legacy.linear;
019
020import java.util.Arrays;
021
022import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
023import org.apache.commons.math4.core.jdkmath.JdkMath;
024import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
025
026
027/**
028 * Calculates the QR-decomposition of a matrix.
029 * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
030 * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
031 * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
032 * <p>This class compute the decomposition using Householder reflectors.</p>
033 * <p>For efficiency purposes, the decomposition in packed form is transposed.
034 * This allows inner loop to iterate inside rows, which is much more cache-efficient
035 * in Java.</p>
036 * <p>This class is based on the class with similar name from the
037 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
038 * following changes:</p>
039 * <ul>
040 *   <li>a {@link #getQT() getQT} method has been added,</li>
041 *   <li>the {@code solve} and {@code isFullRank} methods have been replaced
042 *   by a {@link #getSolver() getSolver} method and the equivalent methods
043 *   provided by the returned {@link DecompositionSolver}.</li>
044 * </ul>
045 *
046 * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
047 * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
048 *
049 * @since 1.2 (changed to concrete class in 3.0)
050 */
051public class QRDecomposition {
052    /**
053     * A packed TRANSPOSED representation of the QR decomposition.
054     * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
055     * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
056     * from which an explicit form of Q can be recomputed if desired.</p>
057     */
058    private double[][] qrt;
059    /** The diagonal elements of R. */
060    private double[] rDiag;
061    /** Cached value of Q. */
062    private RealMatrix cachedQ;
063    /** Cached value of QT. */
064    private RealMatrix cachedQT;
065    /** Cached value of R. */
066    private RealMatrix cachedR;
067    /** Cached value of H. */
068    private RealMatrix cachedH;
069    /** Singularity threshold. */
070    private final double threshold;
071
072    /**
073     * Calculates the QR-decomposition of the given matrix.
074     * The singularity threshold defaults to zero.
075     *
076     * @param matrix The matrix to decompose.
077     *
078     * @see #QRDecomposition(RealMatrix,double)
079     */
080    public QRDecomposition(RealMatrix matrix) {
081        this(matrix, 0d);
082    }
083
084    /**
085     * Calculates the QR-decomposition of the given matrix.
086     *
087     * @param matrix The matrix to decompose.
088     * @param threshold Singularity threshold.
089     * The matrix will be considered singular if the absolute value of
090     * any of the diagonal elements of the "R" matrix is smaller than
091     * the threshold.
092     */
093    public QRDecomposition(RealMatrix matrix,
094                           double threshold) {
095        this.threshold = threshold;
096
097        final int m = matrix.getRowDimension();
098        final int n = matrix.getColumnDimension();
099        qrt = matrix.transpose().getData();
100        rDiag = new double[JdkMath.min(m, n)];
101        cachedQ  = null;
102        cachedQT = null;
103        cachedR  = null;
104        cachedH  = null;
105
106        decompose(qrt);
107    }
108
109    /** Decompose matrix.
110     * @param matrix transposed matrix
111     * @since 3.2
112     */
113    protected void decompose(double[][] matrix) {
114        for (int minor = 0; minor < JdkMath.min(matrix.length, matrix[0].length); minor++) {
115            performHouseholderReflection(minor, matrix);
116        }
117    }
118
119    /** Perform Householder reflection for a minor A(minor, minor) of A.
120     * @param minor minor index
121     * @param matrix transposed matrix
122     * @since 3.2
123     */
124    protected void performHouseholderReflection(int minor, double[][] matrix) {
125
126        final double[] qrtMinor = matrix[minor];
127
128        /*
129         * Let x be the first column of the minor, and a^2 = |x|^2.
130         * x will be in the positions qr[minor][minor] through qr[m][minor].
131         * The first column of the transformed minor will be (a,0,0,..)'
132         * The sign of a is chosen to be opposite to the sign of the first
133         * component of x. Let's find a:
134         */
135        double xNormSqr = 0;
136        for (int row = minor; row < qrtMinor.length; row++) {
137            final double c = qrtMinor[row];
138            xNormSqr += c * c;
139        }
140        final double a = (qrtMinor[minor] > 0) ? -JdkMath.sqrt(xNormSqr) : JdkMath.sqrt(xNormSqr);
141        rDiag[minor] = a;
142
143        if (a != 0.0) {
144
145            /*
146             * Calculate the normalized reflection vector v and transform
147             * the first column. We know the norm of v beforehand: v = x-ae
148             * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
149             * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
150             * Here <x, e> is now qr[minor][minor].
151             * v = x-ae is stored in the column at qr:
152             */
153            qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
154
155            /*
156             * Transform the rest of the columns of the minor:
157             * They will be transformed by the matrix H = I-2vv'/|v|^2.
158             * If x is a column vector of the minor, then
159             * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
160             * Therefore the transformation is easily calculated by
161             * subtracting the column vector (2<x,v>/|v|^2)v from x.
162             *
163             * Let 2<x,v>/|v|^2 = alpha. From above we have
164             * |v|^2 = -2a*(qr[minor][minor]), so
165             * alpha = -<x,v>/(a*qr[minor][minor])
166             */
167            for (int col = minor+1; col < matrix.length; col++) {
168                final double[] qrtCol = matrix[col];
169                double alpha = 0;
170                for (int row = minor; row < qrtCol.length; row++) {
171                    alpha -= qrtCol[row] * qrtMinor[row];
172                }
173                alpha /= a * qrtMinor[minor];
174
175                // Subtract the column vector alpha*v from x.
176                for (int row = minor; row < qrtCol.length; row++) {
177                    qrtCol[row] -= alpha * qrtMinor[row];
178                }
179            }
180        }
181    }
182
183
184    /**
185     * Returns the matrix R of the decomposition.
186     * <p>R is an upper-triangular matrix</p>
187     * @return the R matrix
188     */
189    public RealMatrix getR() {
190
191        if (cachedR == null) {
192
193            // R is supposed to be m x n
194            final int n = qrt.length;
195            final int m = qrt[0].length;
196            double[][] ra = new double[m][n];
197            // copy the diagonal from rDiag and the upper triangle of qr
198            for (int row = JdkMath.min(m, n) - 1; row >= 0; row--) {
199                ra[row][row] = rDiag[row];
200                for (int col = row + 1; col < n; col++) {
201                    ra[row][col] = qrt[col][row];
202                }
203            }
204            cachedR = MatrixUtils.createRealMatrix(ra);
205        }
206
207        // return the cached matrix
208        return cachedR;
209    }
210
211    /**
212     * Returns the matrix Q of the decomposition.
213     * <p>Q is an orthogonal matrix</p>
214     * @return the Q matrix
215     */
216    public RealMatrix getQ() {
217        if (cachedQ == null) {
218            cachedQ = getQT().transpose();
219        }
220        return cachedQ;
221    }
222
223    /**
224     * Returns the transpose of the matrix Q of the decomposition.
225     * <p>Q is an orthogonal matrix</p>
226     * @return the transpose of the Q matrix, Q<sup>T</sup>
227     */
228    public RealMatrix getQT() {
229        if (cachedQT == null) {
230
231            // QT is supposed to be m x m
232            final int n = qrt.length;
233            final int m = qrt[0].length;
234            double[][] qta = new double[m][m];
235
236            /*
237             * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
238             * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
239             * succession to the result
240             */
241            for (int minor = m - 1; minor >= JdkMath.min(m, n); minor--) {
242                qta[minor][minor] = 1.0d;
243            }
244
245            for (int minor = JdkMath.min(m, n)-1; minor >= 0; minor--){
246                final double[] qrtMinor = qrt[minor];
247                qta[minor][minor] = 1.0d;
248                if (qrtMinor[minor] != 0.0) {
249                    for (int col = minor; col < m; col++) {
250                        double alpha = 0;
251                        for (int row = minor; row < m; row++) {
252                            alpha -= qta[col][row] * qrtMinor[row];
253                        }
254                        alpha /= rDiag[minor] * qrtMinor[minor];
255
256                        for (int row = minor; row < m; row++) {
257                            qta[col][row] += -alpha * qrtMinor[row];
258                        }
259                    }
260                }
261            }
262            cachedQT = MatrixUtils.createRealMatrix(qta);
263        }
264
265        // return the cached matrix
266        return cachedQT;
267    }
268
269    /**
270     * Returns the Householder reflector vectors.
271     * <p>H is a lower trapezoidal matrix whose columns represent
272     * each successive Householder reflector vector. This matrix is used
273     * to compute Q.</p>
274     * @return a matrix containing the Householder reflector vectors
275     */
276    public RealMatrix getH() {
277        if (cachedH == null) {
278
279            final int n = qrt.length;
280            final int m = qrt[0].length;
281            double[][] ha = new double[m][n];
282            for (int i = 0; i < m; ++i) {
283                for (int j = 0; j < JdkMath.min(i + 1, n); ++j) {
284                    ha[i][j] = qrt[j][i] / -rDiag[j];
285                }
286            }
287            cachedH = MatrixUtils.createRealMatrix(ha);
288        }
289
290        // return the cached matrix
291        return cachedH;
292    }
293
294    /**
295     * Get a solver for finding the A &times; X = B solution in least square sense.
296     * <p>
297     * Least Square sense means a solver can be computed for an overdetermined system,
298     * (i.e. a system with more equations than unknowns, which corresponds to a tall A
299     * matrix with more rows than columns). In any case, if the matrix is singular
300     * within the tolerance set at {@link QRDecomposition#QRDecomposition(RealMatrix,
301     * double) construction}, an error will be triggered when
302     * the {@link DecompositionSolver#solve(RealVector) solve} method will be called.
303     * </p>
304     * @return a solver
305     */
306    public DecompositionSolver getSolver() {
307        return new Solver(qrt, rDiag, threshold);
308    }
309
310    /** Specialized solver. */
311    private static final class Solver implements DecompositionSolver {
312        /**
313         * A packed TRANSPOSED representation of the QR decomposition.
314         * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
315         * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
316         * from which an explicit form of Q can be recomputed if desired.</p>
317         */
318        private final double[][] qrt;
319        /** The diagonal elements of R. */
320        private final double[] rDiag;
321        /** Singularity threshold. */
322        private final double threshold;
323
324        /**
325         * Build a solver from decomposed matrix.
326         *
327         * @param qrt Packed TRANSPOSED representation of the QR decomposition.
328         * @param rDiag Diagonal elements of R.
329         * @param threshold Singularity threshold.
330         */
331        private Solver(final double[][] qrt,
332                       final double[] rDiag,
333                       final double threshold) {
334            this.qrt   = qrt;
335            this.rDiag = rDiag;
336            this.threshold = threshold;
337        }
338
339        /** {@inheritDoc} */
340        @Override
341        public boolean isNonSingular() {
342            return !checkSingular(rDiag, threshold, false);
343        }
344
345        /** {@inheritDoc} */
346        @Override
347        public RealVector solve(RealVector b) {
348            final int n = qrt.length;
349            final int m = qrt[0].length;
350            if (b.getDimension() != m) {
351                throw new DimensionMismatchException(b.getDimension(), m);
352            }
353            checkSingular(rDiag, threshold, true);
354
355            final double[] x = new double[n];
356            final double[] y = b.toArray();
357
358            // apply Householder transforms to solve Q.y = b
359            for (int minor = 0; minor < JdkMath.min(m, n); minor++) {
360
361                final double[] qrtMinor = qrt[minor];
362                double dotProduct = 0;
363                for (int row = minor; row < m; row++) {
364                    dotProduct += y[row] * qrtMinor[row];
365                }
366                dotProduct /= rDiag[minor] * qrtMinor[minor];
367
368                for (int row = minor; row < m; row++) {
369                    y[row] += dotProduct * qrtMinor[row];
370                }
371            }
372
373            // solve triangular system R.x = y
374            for (int row = rDiag.length - 1; row >= 0; --row) {
375                y[row] /= rDiag[row];
376                final double yRow = y[row];
377                final double[] qrtRow = qrt[row];
378                x[row] = yRow;
379                for (int i = 0; i < row; i++) {
380                    y[i] -= yRow * qrtRow[i];
381                }
382            }
383
384            return new ArrayRealVector(x, false);
385        }
386
387        /** {@inheritDoc} */
388        @Override
389        public RealMatrix solve(RealMatrix b) {
390            final int n = qrt.length;
391            final int m = qrt[0].length;
392            if (b.getRowDimension() != m) {
393                throw new DimensionMismatchException(b.getRowDimension(), m);
394            }
395            checkSingular(rDiag, threshold, true);
396
397            final int columns        = b.getColumnDimension();
398            final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
399            final int cBlocks        = (columns + blockSize - 1) / blockSize;
400            final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
401            final double[][] y       = new double[b.getRowDimension()][blockSize];
402            final double[]   alpha   = new double[blockSize];
403
404            for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
405                final int kStart = kBlock * blockSize;
406                final int kEnd   = JdkMath.min(kStart + blockSize, columns);
407                final int kWidth = kEnd - kStart;
408
409                // get the right hand side vector
410                b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
411
412                // apply Householder transforms to solve Q.y = b
413                for (int minor = 0; minor < JdkMath.min(m, n); minor++) {
414                    final double[] qrtMinor = qrt[minor];
415                    final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);
416
417                    Arrays.fill(alpha, 0, kWidth, 0.0);
418                    for (int row = minor; row < m; ++row) {
419                        final double   d    = qrtMinor[row];
420                        final double[] yRow = y[row];
421                        for (int k = 0; k < kWidth; ++k) {
422                            alpha[k] += d * yRow[k];
423                        }
424                    }
425                    for (int k = 0; k < kWidth; ++k) {
426                        alpha[k] *= factor;
427                    }
428
429                    for (int row = minor; row < m; ++row) {
430                        final double   d    = qrtMinor[row];
431                        final double[] yRow = y[row];
432                        for (int k = 0; k < kWidth; ++k) {
433                            yRow[k] += alpha[k] * d;
434                        }
435                    }
436                }
437
438                // solve triangular system R.x = y
439                for (int j = rDiag.length - 1; j >= 0; --j) {
440                    final int      jBlock = j / blockSize;
441                    final int      jStart = jBlock * blockSize;
442                    final double   factor = 1.0 / rDiag[j];
443                    final double[] yJ     = y[j];
444                    final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
445                    int index = (j - jStart) * kWidth;
446                    for (int k = 0; k < kWidth; ++k) {
447                        yJ[k]          *= factor;
448                        xBlock[index++] = yJ[k];
449                    }
450
451                    final double[] qrtJ = qrt[j];
452                    for (int i = 0; i < j; ++i) {
453                        final double rIJ  = qrtJ[i];
454                        final double[] yI = y[i];
455                        for (int k = 0; k < kWidth; ++k) {
456                            yI[k] -= yJ[k] * rIJ;
457                        }
458                    }
459                }
460            }
461
462            return new BlockRealMatrix(n, columns, xBlocks, false);
463        }
464
465        /**
466         * {@inheritDoc}
467         * @throws SingularMatrixException if the decomposed matrix is singular.
468         */
469        @Override
470        public RealMatrix getInverse() {
471            return solve(MatrixUtils.createRealIdentityMatrix(qrt[0].length));
472        }
473
474        /**
475         * Check singularity.
476         *
477         * @param diag Diagonal elements of the R matrix.
478         * @param min Singularity threshold.
479         * @param raise Whether to raise a {@link SingularMatrixException}
480         * if any element of the diagonal fails the check.
481         * @return {@code true} if any element of the diagonal is smaller
482         * or equal to {@code min}.
483         * @throws SingularMatrixException if the matrix is singular and
484         * {@code raise} is {@code true}.
485         */
486        private static boolean checkSingular(double[] diag,
487                                             double min,
488                                             boolean raise) {
489            final int len = diag.length;
490            for (int i = 0; i < len; i++) {
491                final double d = diag[i];
492                if (JdkMath.abs(d) <= min) {
493                    if (raise) {
494                        final SingularMatrixException e = new SingularMatrixException();
495                        e.getContext().addMessage(LocalizedFormats.NUMBER_TOO_SMALL, d, min);
496                        e.getContext().addMessage(LocalizedFormats.INDEX, i);
497                        throw e;
498                    } else {
499                        return true;
500                    }
501                }
502            }
503            return false;
504        }
505    }
506}