View Javadoc

1   /*
2    * Copyright 2011 The Apache Software Foundation.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  package org.apache.commons.math.linear;
17  
18  import java.util.Arrays;
19  import org.apache.commons.math.util.MathArrays;
20  import org.apache.commons.math.exception.ConvergenceException;
21  import org.apache.commons.math.exception.DimensionMismatchException;
22  import org.apache.commons.math.exception.util.LocalizedFormats;
23  import org.apache.commons.math.util.FastMath;
24  
25  /**
26   *
27   * @author gregsterijevski
28   */
29  public class PivotingQRDecomposition {
30  
31      private double[][] qr;
32      /** The diagonal elements of R. */
33      private double[] rDiag;
34      /** Cached value of Q. */
35      private RealMatrix cachedQ;
36      /** Cached value of QT. */
37      private RealMatrix cachedQT;
38      /** Cached value of R. */
39      private RealMatrix cachedR;
40      /** Cached value of H. */
41      private RealMatrix cachedH;
42      /** permutation info */
43      private int[] permutation;
44      /** the rank **/
45      private int rank;
46      /** vector of column multipliers */
47      private double[] beta;
48  
49      public boolean isSingular() {
50          return rank != qr[0].length;
51      }
52  
53      public int getRank() {
54          return rank;
55      }
56  
57      public int[] getOrder() {
58          return MathArrays.copyOf(permutation);
59      }
60  
61      public PivotingQRDecomposition(RealMatrix matrix) throws ConvergenceException {
62          this(matrix, 1.0e-16, true);
63      }
64  
65      public PivotingQRDecomposition(RealMatrix matrix, boolean allowPivot) throws ConvergenceException {
66          this(matrix, 1.0e-16, allowPivot);
67      }
68  
69      public PivotingQRDecomposition(RealMatrix matrix, double qrRankingThreshold,
70              boolean allowPivot) throws ConvergenceException {
71          final int rows = matrix.getRowDimension();
72          final int cols = matrix.getColumnDimension();
73          qr = matrix.getData();
74          rDiag = new double[cols];
75          //final double[] norms = new double[cols];
76          this.beta = new double[cols];
77          this.permutation = new int[cols];
78          cachedQ = null;
79          cachedQT = null;
80          cachedR = null;
81          cachedH = null;
82  
83          /*- initialize the permutation vector and calculate the norms */
84          for (int k = 0; k < cols; ++k) {
85              permutation[k] = k;
86          }
87          // transform the matrix column after column
88          for (int k = 0; k < cols; ++k) {
89              // select the column with the greatest norm on active components
90              int nextColumn = -1;
91              double ak2 = Double.NEGATIVE_INFINITY;
92              if (allowPivot) {
93                  for (int i = k; i < cols; ++i) {
94                      double norm2 = 0;
95                      for (int j = k; j < rows; ++j) {
96                          final double aki = qr[j][permutation[i]];
97                          norm2 += aki * aki;
98                      }
99                      if (Double.isInfinite(norm2) || Double.isNaN(norm2)) {
100                         throw new ConvergenceException(LocalizedFormats.UNABLE_TO_PERFORM_QR_DECOMPOSITION_ON_JACOBIAN,
101                                 rows, cols);
102                     }
103                     if (norm2 > ak2) {
104                         nextColumn = i;
105                         ak2 = norm2;
106                     }
107                 }
108             } else {
109                 nextColumn = k;
110                 ak2 = 0.0;
111                 for (int j = k; j < rows; ++j) {
112                     final double aki = qr[j][k];
113                     ak2 += aki * aki;
114                 }
115             }
116             if (ak2 <= qrRankingThreshold) {
117                 rank = k;
118                 for (int i = rank; i < rows; i++) {
119                     for (int j = i + 1; j < cols; j++) {
120                         qr[i][permutation[j]] = 0.0;
121                     }
122                 }
123                 return;
124             }
125             final int pk = permutation[nextColumn];
126             permutation[nextColumn] = permutation[k];
127             permutation[k] = pk;
128 
129             // choose alpha such that Hk.u = alpha ek
130             final double akk = qr[k][pk];
131             final double alpha = (akk > 0) ? -FastMath.sqrt(ak2) : FastMath.sqrt(ak2);
132             final double betak = 1.0 / (ak2 - akk * alpha);
133             beta[pk] = betak;
134 
135             // transform the current column
136             rDiag[pk] = alpha;
137             qr[k][pk] -= alpha;
138 
139             // transform the remaining columns
140             for (int dk = cols - 1 - k; dk > 0; --dk) {
141                 double gamma = 0;
142                 for (int j = k; j < rows; ++j) {
143                     gamma += qr[j][pk] * qr[j][permutation[k + dk]];
144                 }
145                 gamma *= betak;
146                 for (int j = k; j < rows; ++j) {
147                     qr[j][permutation[k + dk]] -= gamma * qr[j][pk];
148                 }
149             }
150         }
151         rank = cols;
152         return;
153     }
154 
155     /**
156      * Returns the matrix Q of the decomposition.
157      * <p>Q is an orthogonal matrix</p>
158      * @return the Q matrix
159      */
160     public RealMatrix getQ() {
161         if (cachedQ == null) {
162             cachedQ = getQT().transpose();
163         }
164         return cachedQ;
165     }
166 
167     /**
168      * Returns the transpose of the matrix Q of the decomposition.
169      * <p>Q is an orthogonal matrix</p>
170      * @return the Q matrix
171      */
172     public RealMatrix getQT() {
173         if (cachedQT == null) {
174 
175             // QT is supposed to be m x m
176             final int m = qr.length;
177             cachedQT = MatrixUtils.createRealMatrix(m, m);
178 
179             /*
180              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
181              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
182              * succession to the result
183              */
184             for (int minor = m - 1; minor >= rank; minor--) {
185                 cachedQT.setEntry(minor, minor, 1.0);
186             }
187 
188             for (int minor = rank - 1; minor >= 0; minor--) {
189                 //final double[] qrtMinor = qrt[minor];
190                 final int p_minor = permutation[minor];
191                 cachedQT.setEntry(minor, minor, 1.0);
192                 //if (qrtMinor[minor] != 0.0) {
193                 for (int col = minor; col < m; col++) {
194                     double alpha = 0.0;
195                     for (int row = minor; row < m; row++) {
196                         alpha -= cachedQT.getEntry(col, row) * qr[row][p_minor];
197                     }
198                     alpha /= rDiag[p_minor] * qr[minor][p_minor];
199                     for (int row = minor; row < m; row++) {
200                         cachedQT.addToEntry(col, row, -alpha * qr[row][p_minor]);
201                     }
202                 }
203                 //}
204             }
205         }
206         // return the cached matrix
207         return cachedQT;
208     }
209 
210     /**
211      * Returns the matrix R of the decomposition.
212      * <p>R is an upper-triangular matrix</p>
213      * @return the R matrix
214      */
215     public RealMatrix getR() {
216         if (cachedR == null) {
217             // R is supposed to be m x n
218             final int n = qr[0].length;
219             final int m = qr.length;
220             cachedR = MatrixUtils.createRealMatrix(m, n);
221             // copy the diagonal from rDiag and the upper triangle of qr
222             for (int row = rank - 1; row >= 0; row--) {
223                 cachedR.setEntry(row, row, rDiag[permutation[row]]);
224                 for (int col = row + 1; col < n; col++) {
225                     cachedR.setEntry(row, col, qr[row][permutation[col]]);
226                 }
227             }
228         }
229         // return the cached matrix
230         return cachedR;
231     }
232 
233     public RealMatrix getH() {
234         if (cachedH == null) {
235             final int n = qr[0].length;
236             final int m = qr.length;
237             cachedH = MatrixUtils.createRealMatrix(m, n);
238             for (int i = 0; i < m; ++i) {
239                 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
240                     final int p_j = permutation[j];
241                     cachedH.setEntry(i, j, qr[i][p_j] / -rDiag[p_j]);
242                 }
243             }
244         }
245         // return the cached matrix
246         return cachedH;
247     }
248 
249     public RealMatrix getPermutationMatrix() {
250         RealMatrix rm = MatrixUtils.createRealMatrix(qr[0].length, qr[0].length);
251         for (int i = 0; i < this.qr[0].length; i++) {
252             rm.setEntry(permutation[i], i, 1.0);
253         }
254         return rm;
255     }
256 
257     public DecompositionSolver getSolver() {
258         return new Solver(qr, rDiag, permutation, rank);
259     }
260 
261     /** Specialized solver. */
262     private static class Solver implements DecompositionSolver {
263 
264         /**
265          * A packed TRANSPOSED representation of the QR decomposition.
266          * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
267          * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
268          * from which an explicit form of Q can be recomputed if desired.</p>
269          */
270         private final double[][] qr;
271         /** The diagonal elements of R. */
272         private final double[] rDiag;
273         /** The rank of the matrix      */
274         private final int rank;
275         /** The permutation matrix      */
276         private final int[] perm;
277 
278         /**
279          * Build a solver from decomposed matrix.
280          * @param qrt packed TRANSPOSED representation of the QR decomposition
281          * @param rDiag diagonal elements of R
282          */
283         private Solver(final double[][] qr, final double[] rDiag, int[] perm, int rank) {
284             this.qr = qr;
285             this.rDiag = rDiag;
286             this.perm = perm;
287             this.rank = rank;
288         }
289 
290         /** {@inheritDoc} */
291         public boolean isNonSingular() {
292             if (qr.length >= qr[0].length) {
293                 return rank == qr[0].length;
294             } else { //qr.length < qr[0].length
295                 return rank == qr.length;
296             }
297         }
298 
299         /** {@inheritDoc} */
300         public RealVector solve(RealVector b) {
301             final int n = qr[0].length;
302             final int m = qr.length;
303             if (b.getDimension() != m) {
304                 throw new DimensionMismatchException(b.getDimension(), m);
305             }
306             if (!isNonSingular()) {
307                 throw new SingularMatrixException();
308             }
309 
310             final double[] x = new double[n];
311             final double[] y = b.toArray();
312 
313             // apply Householder transforms to solve Q.y = b
314             for (int minor = 0; minor < rank; minor++) {
315                 final int m_idx = perm[minor];
316                 double dotProduct = 0;
317                 for (int row = minor; row < m; row++) {
318                     dotProduct += y[row] * qr[row][m_idx];
319                 }
320                 dotProduct /= rDiag[m_idx] * qr[minor][m_idx];
321                 for (int row = minor; row < m; row++) {
322                     y[row] += dotProduct * qr[row][m_idx];
323                 }
324             }
325             // solve triangular system R.x = y
326             for (int row = rank - 1; row >= 0; --row) {
327                 final int m_row = perm[row];
328                 y[row] /= rDiag[m_row];
329                 final double yRow = y[row];
330                 //final double[] qrtRow = qrt[row];
331                 x[perm[row]] = yRow;
332                 for (int i = 0; i < row; i++) {
333                     y[i] -= yRow * qr[i][m_row];
334                 }
335             }
336             return new ArrayRealVector(x, false);
337         }
338 
339         /** {@inheritDoc} */
340         public RealMatrix solve(RealMatrix b) {
341             final int cols = qr[0].length;
342             final int rows = qr.length;
343             if (b.getRowDimension() != rows) {
344                 throw new DimensionMismatchException(b.getRowDimension(), rows);
345             }
346             if (!isNonSingular()) {
347                 throw new SingularMatrixException();
348             }
349 
350             final int columns = b.getColumnDimension();
351             final int blockSize = BlockRealMatrix.BLOCK_SIZE;
352             final int cBlocks = (columns + blockSize - 1) / blockSize;
353             final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(cols, columns);
354             final double[][] y = new double[b.getRowDimension()][blockSize];
355             final double[] alpha = new double[blockSize];
356             //final BlockRealMatrix result = new BlockRealMatrix(cols, columns, xBlocks, false);
357             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
358                 final int kStart = kBlock * blockSize;
359                 final int kEnd = FastMath.min(kStart + blockSize, columns);
360                 final int kWidth = kEnd - kStart;
361                 // get the right hand side vector
362                 b.copySubMatrix(0, rows - 1, kStart, kEnd - 1, y);
363 
364                 // apply Householder transforms to solve Q.y = b
365                 for (int minor = 0; minor < rank; minor++) {
366                     final int m_idx = perm[minor];
367                     final double factor = 1.0 / (rDiag[m_idx] * qr[minor][m_idx]);
368 
369                     Arrays.fill(alpha, 0, kWidth, 0.0);
370                     for (int row = minor; row < rows; ++row) {
371                         final double d = qr[row][m_idx];
372                         final double[] yRow = y[row];
373                         for (int k = 0; k < kWidth; ++k) {
374                             alpha[k] += d * yRow[k];
375                         }
376                     }
377                     for (int k = 0; k < kWidth; ++k) {
378                         alpha[k] *= factor;
379                     }
380 
381                     for (int row = minor; row < rows; ++row) {
382                         final double d = qr[row][m_idx];
383                         final double[] yRow = y[row];
384                         for (int k = 0; k < kWidth; ++k) {
385                             yRow[k] += alpha[k] * d;
386                         }
387                     }
388                 }
389 
390                 // solve triangular system R.x = y
391                 for (int j = rank - 1; j >= 0; --j) {
392                     final int jBlock = perm[j] / blockSize; //which block
393                     final int jStart = jBlock * blockSize;  // idx of top corner of block in my coord
394                     final double factor = 1.0 / rDiag[perm[j]];
395                     final double[] yJ = y[j];
396                     final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
397                     int index = (perm[j] - jStart) * kWidth; //to local (block) coordinates
398                     for (int k = 0; k < kWidth; ++k) {
399                         yJ[k] *= factor;
400                         xBlock[index++] = yJ[k];
401                     }
402                     for (int i = 0; i < j; ++i) {
403                         final double rIJ = qr[i][perm[j]];
404                         final double[] yI = y[i];
405                         for (int k = 0; k < kWidth; ++k) {
406                             yI[k] -= yJ[k] * rIJ;
407                         }
408                     }
409                 }
410             }
411             //return result;
412             return new BlockRealMatrix(cols, columns, xBlocks, false);
413         }
414 
415         /** {@inheritDoc} */
416         public RealMatrix getInverse() {
417             return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
418         }
419     }
420 }