001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.linear;
019    
020    import java.util.Arrays;
021    
022    import org.apache.commons.math.exception.DimensionMismatchException;
023    import org.apache.commons.math.util.FastMath;
024    
025    
026    /**
027     * Calculates the QR-decomposition of a matrix.
028     * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
029     * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
030     * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
031     * <p>This class compute the decomposition using Householder reflectors.</p>
032     * <p>For efficiency purposes, the decomposition in packed form is transposed.
033     * This allows inner loop to iterate inside rows, which is much more cache-efficient
034     * in Java.</p>
035     *
036     * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
037     * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
038     *
039     * @version $Id: QRDecompositionImpl.java 1131229 2011-06-03 20:49:25Z luc $
040     * @since 1.2
041     */
042    public class QRDecompositionImpl implements QRDecomposition {
043    
044        /**
045         * A packed TRANSPOSED representation of the QR decomposition.
046         * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
047         * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
048         * from which an explicit form of Q can be recomputed if desired.</p>
049         */
050        private double[][] qrt;
051    
052        /** The diagonal elements of R. */
053        private double[] rDiag;
054    
055        /** Cached value of Q. */
056        private RealMatrix cachedQ;
057    
058        /** Cached value of QT. */
059        private RealMatrix cachedQT;
060    
061        /** Cached value of R. */
062        private RealMatrix cachedR;
063    
064        /** Cached value of H. */
065        private RealMatrix cachedH;
066    
067        /**
068         * Calculates the QR-decomposition of the given matrix.
069         * @param matrix The matrix to decompose.
070         */
071        public QRDecompositionImpl(RealMatrix matrix) {
072    
073            final int m = matrix.getRowDimension();
074            final int n = matrix.getColumnDimension();
075            qrt = matrix.transpose().getData();
076            rDiag = new double[FastMath.min(m, n)];
077            cachedQ  = null;
078            cachedQT = null;
079            cachedR  = null;
080            cachedH  = null;
081    
082            /*
083             * The QR decomposition of a matrix A is calculated using Householder
084             * reflectors by repeating the following operations to each minor
085             * A(minor,minor) of A:
086             */
087            for (int minor = 0; minor < FastMath.min(m, n); minor++) {
088    
089                final double[] qrtMinor = qrt[minor];
090    
091                /*
092                 * Let x be the first column of the minor, and a^2 = |x|^2.
093                 * x will be in the positions qr[minor][minor] through qr[m][minor].
094                 * The first column of the transformed minor will be (a,0,0,..)'
095                 * The sign of a is chosen to be opposite to the sign of the first
096                 * component of x. Let's find a:
097                 */
098                double xNormSqr = 0;
099                for (int row = minor; row < m; row++) {
100                    final double c = qrtMinor[row];
101                    xNormSqr += c * c;
102                }
103                final double a = (qrtMinor[minor] > 0) ? -FastMath.sqrt(xNormSqr) : FastMath.sqrt(xNormSqr);
104                rDiag[minor] = a;
105    
106                if (a != 0.0) {
107    
108                    /*
109                     * Calculate the normalized reflection vector v and transform
110                     * the first column. We know the norm of v beforehand: v = x-ae
111                     * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
112                     * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
113                     * Here <x, e> is now qr[minor][minor].
114                     * v = x-ae is stored in the column at qr:
115                     */
116                    qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
117    
118                    /*
119                     * Transform the rest of the columns of the minor:
120                     * They will be transformed by the matrix H = I-2vv'/|v|^2.
121                     * If x is a column vector of the minor, then
122                     * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
123                     * Therefore the transformation is easily calculated by
124                     * subtracting the column vector (2<x,v>/|v|^2)v from x.
125                     *
126                     * Let 2<x,v>/|v|^2 = alpha. From above we have
127                     * |v|^2 = -2a*(qr[minor][minor]), so
128                     * alpha = -<x,v>/(a*qr[minor][minor])
129                     */
130                    for (int col = minor+1; col < n; col++) {
131                        final double[] qrtCol = qrt[col];
132                        double alpha = 0;
133                        for (int row = minor; row < m; row++) {
134                            alpha -= qrtCol[row] * qrtMinor[row];
135                        }
136                        alpha /= a * qrtMinor[minor];
137    
138                        // Subtract the column vector alpha*v from x.
139                        for (int row = minor; row < m; row++) {
140                            qrtCol[row] -= alpha * qrtMinor[row];
141                        }
142                    }
143                }
144            }
145        }
146    
147        /** {@inheritDoc} */
148        public RealMatrix getR() {
149    
150            if (cachedR == null) {
151    
152                // R is supposed to be m x n
153                final int n = qrt.length;
154                final int m = qrt[0].length;
155                cachedR = MatrixUtils.createRealMatrix(m, n);
156    
157                // copy the diagonal from rDiag and the upper triangle of qr
158                for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
159                    cachedR.setEntry(row, row, rDiag[row]);
160                    for (int col = row + 1; col < n; col++) {
161                        cachedR.setEntry(row, col, qrt[col][row]);
162                    }
163                }
164            }
165    
166            // return the cached matrix
167            return cachedR;
168        }
169    
170        /** {@inheritDoc} */
171        public RealMatrix getQ() {
172            if (cachedQ == null) {
173                cachedQ = getQT().transpose();
174            }
175            return cachedQ;
176        }
177    
178        /** {@inheritDoc} */
179        public RealMatrix getQT() {
180            if (cachedQT == null) {
181    
182                // QT is supposed to be m x m
183                final int n = qrt.length;
184                final int m = qrt[0].length;
185                cachedQT = MatrixUtils.createRealMatrix(m, m);
186    
187                /*
188                 * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
189                 * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
190                 * succession to the result
191                 */
192                for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
193                    cachedQT.setEntry(minor, minor, 1.0);
194                }
195    
196                for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
197                    final double[] qrtMinor = qrt[minor];
198                    cachedQT.setEntry(minor, minor, 1.0);
199                    if (qrtMinor[minor] != 0.0) {
200                        for (int col = minor; col < m; col++) {
201                            double alpha = 0;
202                            for (int row = minor; row < m; row++) {
203                                alpha -= cachedQT.getEntry(col, row) * qrtMinor[row];
204                            }
205                            alpha /= rDiag[minor] * qrtMinor[minor];
206    
207                            for (int row = minor; row < m; row++) {
208                                cachedQT.addToEntry(col, row, -alpha * qrtMinor[row]);
209                            }
210                        }
211                    }
212                }
213            }
214    
215            // return the cached matrix
216            return cachedQT;
217        }
218    
219        /** {@inheritDoc} */
220        public RealMatrix getH() {
221            if (cachedH == null) {
222    
223                final int n = qrt.length;
224                final int m = qrt[0].length;
225                cachedH = MatrixUtils.createRealMatrix(m, n);
226                for (int i = 0; i < m; ++i) {
227                    for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
228                        cachedH.setEntry(i, j, qrt[j][i] / -rDiag[j]);
229                    }
230                }
231            }
232    
233            // return the cached matrix
234            return cachedH;
235        }
236    
237        /** {@inheritDoc} */
238        public DecompositionSolver getSolver() {
239            return new Solver(qrt, rDiag);
240        }
241    
242        /** Specialized solver. */
243        private static class Solver implements DecompositionSolver {
244    
245            /**
246             * A packed TRANSPOSED representation of the QR decomposition.
247             * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
248             * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
249             * from which an explicit form of Q can be recomputed if desired.</p>
250             */
251            private final double[][] qrt;
252    
253            /** The diagonal elements of R. */
254            private final double[] rDiag;
255    
256            /**
257             * Build a solver from decomposed matrix.
258             * @param qrt packed TRANSPOSED representation of the QR decomposition
259             * @param rDiag diagonal elements of R
260             */
261            private Solver(final double[][] qrt, final double[] rDiag) {
262                this.qrt   = qrt;
263                this.rDiag = rDiag;
264            }
265    
266            /** {@inheritDoc} */
267            public boolean isNonSingular() {
268    
269                for (double diag : rDiag) {
270                    if (diag == 0) {
271                        return false;
272                    }
273                }
274                return true;
275            }
276    
277            /** {@inheritDoc} */
278            public double[] solve(double[] b) {
279                final int n = qrt.length;
280                final int m = qrt[0].length;
281                if (b.length != m) {
282                    throw new DimensionMismatchException(b.length, m);
283                }
284                if (!isNonSingular()) {
285                    throw new SingularMatrixException();
286                }
287    
288                final double[] x = new double[n];
289                final double[] y = b.clone();
290    
291                // apply Householder transforms to solve Q.y = b
292                for (int minor = 0; minor < FastMath.min(m, n); minor++) {
293    
294                    final double[] qrtMinor = qrt[minor];
295                    double dotProduct = 0;
296                    for (int row = minor; row < m; row++) {
297                        dotProduct += y[row] * qrtMinor[row];
298                    }
299                    dotProduct /= rDiag[minor] * qrtMinor[minor];
300    
301                    for (int row = minor; row < m; row++) {
302                        y[row] += dotProduct * qrtMinor[row];
303                    }
304                }
305    
306                // solve triangular system R.x = y
307                for (int row = rDiag.length - 1; row >= 0; --row) {
308                    y[row] /= rDiag[row];
309                    final double yRow   = y[row];
310                    final double[] qrtRow = qrt[row];
311                    x[row] = yRow;
312                    for (int i = 0; i < row; i++) {
313                        y[i] -= yRow * qrtRow[i];
314                    }
315                }
316    
317                return x;
318            }
319    
320            /** {@inheritDoc} */
321            public RealVector solve(RealVector b) {
322                try {
323                    return solve((ArrayRealVector) b);
324                } catch (ClassCastException cce) {
325                    return new ArrayRealVector(solve(b.getData()), false);
326                }
327            }
328    
329            /** Solve the linear equation A &times; X = B.
330             * <p>The A matrix is implicit here. It is </p>
331             * @param b right-hand side of the equation A &times; X = B
332             * @return a vector X that minimizes the two norm of A &times; X - B
333             * @throws DimensionMismatchException if the matrices dimensions do not match.
334             * @throws SingularMatrixException if the decomposed matrix is singular.
335             */
336            public ArrayRealVector solve(ArrayRealVector b) {
337                return new ArrayRealVector(solve(b.getDataRef()), false);
338            }
339    
340            /** {@inheritDoc} */
341            public double[][] solve(double[][] b) {
342                return solve(new BlockRealMatrix(b)).getData();
343            }
344    
345            /** {@inheritDoc} */
346            public RealMatrix solve(RealMatrix b) {
347                final int n = qrt.length;
348                final int m = qrt[0].length;
349                if (b.getRowDimension() != m) {
350                    throw new DimensionMismatchException(b.getRowDimension(), m);
351                }
352                if (!isNonSingular()) {
353                    throw new SingularMatrixException();
354                }
355    
356                final int columns        = b.getColumnDimension();
357                final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
358                final int cBlocks        = (columns + blockSize - 1) / blockSize;
359                final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
360                final double[][] y       = new double[b.getRowDimension()][blockSize];
361                final double[]   alpha   = new double[blockSize];
362    
363                for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
364                    final int kStart = kBlock * blockSize;
365                    final int kEnd   = FastMath.min(kStart + blockSize, columns);
366                    final int kWidth = kEnd - kStart;
367    
368                    // get the right hand side vector
369                    b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
370    
371                    // apply Householder transforms to solve Q.y = b
372                    for (int minor = 0; minor < FastMath.min(m, n); minor++) {
373                        final double[] qrtMinor = qrt[minor];
374                        final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);
375    
376                        Arrays.fill(alpha, 0, kWidth, 0.0);
377                        for (int row = minor; row < m; ++row) {
378                            final double   d    = qrtMinor[row];
379                            final double[] yRow = y[row];
380                            for (int k = 0; k < kWidth; ++k) {
381                                alpha[k] += d * yRow[k];
382                            }
383                        }
384                        for (int k = 0; k < kWidth; ++k) {
385                            alpha[k] *= factor;
386                        }
387    
388                        for (int row = minor; row < m; ++row) {
389                            final double   d    = qrtMinor[row];
390                            final double[] yRow = y[row];
391                            for (int k = 0; k < kWidth; ++k) {
392                                yRow[k] += alpha[k] * d;
393                            }
394                        }
395                    }
396    
397                    // solve triangular system R.x = y
398                    for (int j = rDiag.length - 1; j >= 0; --j) {
399                        final int      jBlock = j / blockSize;
400                        final int      jStart = jBlock * blockSize;
401                        final double   factor = 1.0 / rDiag[j];
402                        final double[] yJ     = y[j];
403                        final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
404                        int index = (j - jStart) * kWidth;
405                        for (int k = 0; k < kWidth; ++k) {
406                            yJ[k]          *= factor;
407                            xBlock[index++] = yJ[k];
408                        }
409    
410                        final double[] qrtJ = qrt[j];
411                        for (int i = 0; i < j; ++i) {
412                            final double rIJ  = qrtJ[i];
413                            final double[] yI = y[i];
414                            for (int k = 0; k < kWidth; ++k) {
415                                yI[k] -= yJ[k] * rIJ;
416                            }
417                        }
418                    }
419                }
420    
421                return new BlockRealMatrix(n, columns, xBlocks, false);
422            }
423    
424            /** {@inheritDoc} */
425            public RealMatrix getInverse() {
426                return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
427            }
428        }
429    }