1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.apache.commons.math.linear;
17
18 import java.util.Arrays;
19 import org.apache.commons.math.util.MathArrays;
20 import org.apache.commons.math.exception.ConvergenceException;
21 import org.apache.commons.math.exception.DimensionMismatchException;
22 import org.apache.commons.math.exception.util.LocalizedFormats;
23 import org.apache.commons.math.util.FastMath;
24
25
26
27
28
29 public class PivotingQRDecomposition {
30
31 private double[][] qr;
32
33 private double[] rDiag;
34
35 private RealMatrix cachedQ;
36
37 private RealMatrix cachedQT;
38
39 private RealMatrix cachedR;
40
41 private RealMatrix cachedH;
42
43 private int[] permutation;
44
45 private int rank;
46
47 private double[] beta;
48
49 public boolean isSingular() {
50 return rank != qr[0].length;
51 }
52
53 public int getRank() {
54 return rank;
55 }
56
57 public int[] getOrder() {
58 return MathArrays.copyOf(permutation);
59 }
60
61 public PivotingQRDecomposition(RealMatrix matrix) throws ConvergenceException {
62 this(matrix, 1.0e-16, true);
63 }
64
65 public PivotingQRDecomposition(RealMatrix matrix, boolean allowPivot) throws ConvergenceException {
66 this(matrix, 1.0e-16, allowPivot);
67 }
68
69 public PivotingQRDecomposition(RealMatrix matrix, double qrRankingThreshold,
70 boolean allowPivot) throws ConvergenceException {
71 final int rows = matrix.getRowDimension();
72 final int cols = matrix.getColumnDimension();
73 qr = matrix.getData();
74 rDiag = new double[cols];
75
76 this.beta = new double[cols];
77 this.permutation = new int[cols];
78 cachedQ = null;
79 cachedQT = null;
80 cachedR = null;
81 cachedH = null;
82
83
84 for (int k = 0; k < cols; ++k) {
85 permutation[k] = k;
86 }
87
88 for (int k = 0; k < cols; ++k) {
89
90 int nextColumn = -1;
91 double ak2 = Double.NEGATIVE_INFINITY;
92 if (allowPivot) {
93 for (int i = k; i < cols; ++i) {
94 double norm2 = 0;
95 for (int j = k; j < rows; ++j) {
96 final double aki = qr[j][permutation[i]];
97 norm2 += aki * aki;
98 }
99 if (Double.isInfinite(norm2) || Double.isNaN(norm2)) {
100 throw new ConvergenceException(LocalizedFormats.UNABLE_TO_PERFORM_QR_DECOMPOSITION_ON_JACOBIAN,
101 rows, cols);
102 }
103 if (norm2 > ak2) {
104 nextColumn = i;
105 ak2 = norm2;
106 }
107 }
108 } else {
109 nextColumn = k;
110 ak2 = 0.0;
111 for (int j = k; j < rows; ++j) {
112 final double aki = qr[j][k];
113 ak2 += aki * aki;
114 }
115 }
116 if (ak2 <= qrRankingThreshold) {
117 rank = k;
118 for (int i = rank; i < rows; i++) {
119 for (int j = i + 1; j < cols; j++) {
120 qr[i][permutation[j]] = 0.0;
121 }
122 }
123 return;
124 }
125 final int pk = permutation[nextColumn];
126 permutation[nextColumn] = permutation[k];
127 permutation[k] = pk;
128
129
130 final double akk = qr[k][pk];
131 final double alpha = (akk > 0) ? -FastMath.sqrt(ak2) : FastMath.sqrt(ak2);
132 final double betak = 1.0 / (ak2 - akk * alpha);
133 beta[pk] = betak;
134
135
136 rDiag[pk] = alpha;
137 qr[k][pk] -= alpha;
138
139
140 for (int dk = cols - 1 - k; dk > 0; --dk) {
141 double gamma = 0;
142 for (int j = k; j < rows; ++j) {
143 gamma += qr[j][pk] * qr[j][permutation[k + dk]];
144 }
145 gamma *= betak;
146 for (int j = k; j < rows; ++j) {
147 qr[j][permutation[k + dk]] -= gamma * qr[j][pk];
148 }
149 }
150 }
151 rank = cols;
152 return;
153 }
154
155
156
157
158
159
160 public RealMatrix getQ() {
161 if (cachedQ == null) {
162 cachedQ = getQT().transpose();
163 }
164 return cachedQ;
165 }
166
167
168
169
170
171
172 public RealMatrix getQT() {
173 if (cachedQT == null) {
174
175
176 final int m = qr.length;
177 cachedQT = MatrixUtils.createRealMatrix(m, m);
178
179
180
181
182
183
184 for (int minor = m - 1; minor >= rank; minor--) {
185 cachedQT.setEntry(minor, minor, 1.0);
186 }
187
188 for (int minor = rank - 1; minor >= 0; minor--) {
189
190 final int p_minor = permutation[minor];
191 cachedQT.setEntry(minor, minor, 1.0);
192
193 for (int col = minor; col < m; col++) {
194 double alpha = 0.0;
195 for (int row = minor; row < m; row++) {
196 alpha -= cachedQT.getEntry(col, row) * qr[row][p_minor];
197 }
198 alpha /= rDiag[p_minor] * qr[minor][p_minor];
199 for (int row = minor; row < m; row++) {
200 cachedQT.addToEntry(col, row, -alpha * qr[row][p_minor]);
201 }
202 }
203
204 }
205 }
206
207 return cachedQT;
208 }
209
210
211
212
213
214
215 public RealMatrix getR() {
216 if (cachedR == null) {
217
218 final int n = qr[0].length;
219 final int m = qr.length;
220 cachedR = MatrixUtils.createRealMatrix(m, n);
221
222 for (int row = rank - 1; row >= 0; row--) {
223 cachedR.setEntry(row, row, rDiag[permutation[row]]);
224 for (int col = row + 1; col < n; col++) {
225 cachedR.setEntry(row, col, qr[row][permutation[col]]);
226 }
227 }
228 }
229
230 return cachedR;
231 }
232
233 public RealMatrix getH() {
234 if (cachedH == null) {
235 final int n = qr[0].length;
236 final int m = qr.length;
237 cachedH = MatrixUtils.createRealMatrix(m, n);
238 for (int i = 0; i < m; ++i) {
239 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
240 final int p_j = permutation[j];
241 cachedH.setEntry(i, j, qr[i][p_j] / -rDiag[p_j]);
242 }
243 }
244 }
245
246 return cachedH;
247 }
248
249 public RealMatrix getPermutationMatrix() {
250 RealMatrix rm = MatrixUtils.createRealMatrix(qr[0].length, qr[0].length);
251 for (int i = 0; i < this.qr[0].length; i++) {
252 rm.setEntry(permutation[i], i, 1.0);
253 }
254 return rm;
255 }
256
257 public DecompositionSolver getSolver() {
258 return new Solver(qr, rDiag, permutation, rank);
259 }
260
261
262 private static class Solver implements DecompositionSolver {
263
264
265
266
267
268
269
270 private final double[][] qr;
271
272 private final double[] rDiag;
273
274 private final int rank;
275
276 private final int[] perm;
277
278
279
280
281
282
283 private Solver(final double[][] qr, final double[] rDiag, int[] perm, int rank) {
284 this.qr = qr;
285 this.rDiag = rDiag;
286 this.perm = perm;
287 this.rank = rank;
288 }
289
290
291 public boolean isNonSingular() {
292 if (qr.length >= qr[0].length) {
293 return rank == qr[0].length;
294 } else {
295 return rank == qr.length;
296 }
297 }
298
299
300 public RealVector solve(RealVector b) {
301 final int n = qr[0].length;
302 final int m = qr.length;
303 if (b.getDimension() != m) {
304 throw new DimensionMismatchException(b.getDimension(), m);
305 }
306 if (!isNonSingular()) {
307 throw new SingularMatrixException();
308 }
309
310 final double[] x = new double[n];
311 final double[] y = b.toArray();
312
313
314 for (int minor = 0; minor < rank; minor++) {
315 final int m_idx = perm[minor];
316 double dotProduct = 0;
317 for (int row = minor; row < m; row++) {
318 dotProduct += y[row] * qr[row][m_idx];
319 }
320 dotProduct /= rDiag[m_idx] * qr[minor][m_idx];
321 for (int row = minor; row < m; row++) {
322 y[row] += dotProduct * qr[row][m_idx];
323 }
324 }
325
326 for (int row = rank - 1; row >= 0; --row) {
327 final int m_row = perm[row];
328 y[row] /= rDiag[m_row];
329 final double yRow = y[row];
330
331 x[perm[row]] = yRow;
332 for (int i = 0; i < row; i++) {
333 y[i] -= yRow * qr[i][m_row];
334 }
335 }
336 return new ArrayRealVector(x, false);
337 }
338
339
340 public RealMatrix solve(RealMatrix b) {
341 final int cols = qr[0].length;
342 final int rows = qr.length;
343 if (b.getRowDimension() != rows) {
344 throw new DimensionMismatchException(b.getRowDimension(), rows);
345 }
346 if (!isNonSingular()) {
347 throw new SingularMatrixException();
348 }
349
350 final int columns = b.getColumnDimension();
351 final int blockSize = BlockRealMatrix.BLOCK_SIZE;
352 final int cBlocks = (columns + blockSize - 1) / blockSize;
353 final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(cols, columns);
354 final double[][] y = new double[b.getRowDimension()][blockSize];
355 final double[] alpha = new double[blockSize];
356
357 for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
358 final int kStart = kBlock * blockSize;
359 final int kEnd = FastMath.min(kStart + blockSize, columns);
360 final int kWidth = kEnd - kStart;
361
362 b.copySubMatrix(0, rows - 1, kStart, kEnd - 1, y);
363
364
365 for (int minor = 0; minor < rank; minor++) {
366 final int m_idx = perm[minor];
367 final double factor = 1.0 / (rDiag[m_idx] * qr[minor][m_idx]);
368
369 Arrays.fill(alpha, 0, kWidth, 0.0);
370 for (int row = minor; row < rows; ++row) {
371 final double d = qr[row][m_idx];
372 final double[] yRow = y[row];
373 for (int k = 0; k < kWidth; ++k) {
374 alpha[k] += d * yRow[k];
375 }
376 }
377 for (int k = 0; k < kWidth; ++k) {
378 alpha[k] *= factor;
379 }
380
381 for (int row = minor; row < rows; ++row) {
382 final double d = qr[row][m_idx];
383 final double[] yRow = y[row];
384 for (int k = 0; k < kWidth; ++k) {
385 yRow[k] += alpha[k] * d;
386 }
387 }
388 }
389
390
391 for (int j = rank - 1; j >= 0; --j) {
392 final int jBlock = perm[j] / blockSize;
393 final int jStart = jBlock * blockSize;
394 final double factor = 1.0 / rDiag[perm[j]];
395 final double[] yJ = y[j];
396 final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
397 int index = (perm[j] - jStart) * kWidth;
398 for (int k = 0; k < kWidth; ++k) {
399 yJ[k] *= factor;
400 xBlock[index++] = yJ[k];
401 }
402 for (int i = 0; i < j; ++i) {
403 final double rIJ = qr[i][perm[j]];
404 final double[] yI = y[i];
405 for (int k = 0; k < kWidth; ++k) {
406 yI[k] -= yJ[k] * rIJ;
407 }
408 }
409 }
410 }
411
412 return new BlockRealMatrix(cols, columns, xBlocks, false);
413 }
414
415
416 public RealMatrix getInverse() {
417 return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
418 }
419 }
420 }