001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math4.legacy.linear; 019 020import java.util.Arrays; 021 022import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 023import org.apache.commons.math4.core.jdkmath.JdkMath; 024import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; 025 026 027/** 028 * Calculates the QR-decomposition of a matrix. 029 * <p>The QR-decomposition of a matrix A consists of two matrices Q and R 030 * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is 031 * upper triangular. If A is m×n, Q is m×m and R m×n.</p> 032 * <p>This class compute the decomposition using Householder reflectors.</p> 033 * <p>For efficiency purposes, the decomposition in packed form is transposed. 034 * This allows inner loop to iterate inside rows, which is much more cache-efficient 035 * in Java.</p> 036 * <p>This class is based on the class with similar name from the 037 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the 038 * following changes:</p> 039 * <ul> 040 * <li>a {@link #getQT() getQT} method has been added,</li> 041 * <li>the {@code solve} and {@code isFullRank} methods have been replaced 042 * by a {@link #getSolver() getSolver} method and the equivalent methods 043 * provided by the returned {@link DecompositionSolver}.</li> 044 * </ul> 045 * 046 * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a> 047 * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a> 048 * 049 * @since 1.2 (changed to concrete class in 3.0) 050 */ 051public class QRDecomposition { 052 /** 053 * A packed TRANSPOSED representation of the QR decomposition. 054 * <p>The elements BELOW the diagonal are the elements of the UPPER triangular 055 * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors 056 * from which an explicit form of Q can be recomputed if desired.</p> 057 */ 058 private double[][] qrt; 059 /** The diagonal elements of R. */ 060 private double[] rDiag; 061 /** Cached value of Q. */ 062 private RealMatrix cachedQ; 063 /** Cached value of QT. */ 064 private RealMatrix cachedQT; 065 /** Cached value of R. */ 066 private RealMatrix cachedR; 067 /** Cached value of H. */ 068 private RealMatrix cachedH; 069 /** Singularity threshold. */ 070 private final double threshold; 071 072 /** 073 * Calculates the QR-decomposition of the given matrix. 074 * The singularity threshold defaults to zero. 075 * 076 * @param matrix The matrix to decompose. 077 * 078 * @see #QRDecomposition(RealMatrix,double) 079 */ 080 public QRDecomposition(RealMatrix matrix) { 081 this(matrix, 0d); 082 } 083 084 /** 085 * Calculates the QR-decomposition of the given matrix. 086 * 087 * @param matrix The matrix to decompose. 088 * @param threshold Singularity threshold. 089 * The matrix will be considered singular if the absolute value of 090 * any of the diagonal elements of the "R" matrix is smaller than 091 * the threshold. 092 */ 093 public QRDecomposition(RealMatrix matrix, 094 double threshold) { 095 this.threshold = threshold; 096 097 final int m = matrix.getRowDimension(); 098 final int n = matrix.getColumnDimension(); 099 qrt = matrix.transpose().getData(); 100 rDiag = new double[JdkMath.min(m, n)]; 101 cachedQ = null; 102 cachedQT = null; 103 cachedR = null; 104 cachedH = null; 105 106 decompose(qrt); 107 } 108 109 /** Decompose matrix. 110 * @param matrix transposed matrix 111 * @since 3.2 112 */ 113 protected void decompose(double[][] matrix) { 114 for (int minor = 0; minor < JdkMath.min(matrix.length, matrix[0].length); minor++) { 115 performHouseholderReflection(minor, matrix); 116 } 117 } 118 119 /** Perform Householder reflection for a minor A(minor, minor) of A. 120 * @param minor minor index 121 * @param matrix transposed matrix 122 * @since 3.2 123 */ 124 protected void performHouseholderReflection(int minor, double[][] matrix) { 125 126 final double[] qrtMinor = matrix[minor]; 127 128 /* 129 * Let x be the first column of the minor, and a^2 = |x|^2. 130 * x will be in the positions qr[minor][minor] through qr[m][minor]. 131 * The first column of the transformed minor will be (a,0,0,..)' 132 * The sign of a is chosen to be opposite to the sign of the first 133 * component of x. Let's find a: 134 */ 135 double xNormSqr = 0; 136 for (int row = minor; row < qrtMinor.length; row++) { 137 final double c = qrtMinor[row]; 138 xNormSqr += c * c; 139 } 140 final double a = (qrtMinor[minor] > 0) ? -JdkMath.sqrt(xNormSqr) : JdkMath.sqrt(xNormSqr); 141 rDiag[minor] = a; 142 143 if (a != 0.0) { 144 145 /* 146 * Calculate the normalized reflection vector v and transform 147 * the first column. We know the norm of v beforehand: v = x-ae 148 * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> = 149 * a^2+a^2-2a<x,e> = 2a*(a - <x,e>). 150 * Here <x, e> is now qr[minor][minor]. 151 * v = x-ae is stored in the column at qr: 152 */ 153 qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor]) 154 155 /* 156 * Transform the rest of the columns of the minor: 157 * They will be transformed by the matrix H = I-2vv'/|v|^2. 158 * If x is a column vector of the minor, then 159 * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v. 160 * Therefore the transformation is easily calculated by 161 * subtracting the column vector (2<x,v>/|v|^2)v from x. 162 * 163 * Let 2<x,v>/|v|^2 = alpha. From above we have 164 * |v|^2 = -2a*(qr[minor][minor]), so 165 * alpha = -<x,v>/(a*qr[minor][minor]) 166 */ 167 for (int col = minor+1; col < matrix.length; col++) { 168 final double[] qrtCol = matrix[col]; 169 double alpha = 0; 170 for (int row = minor; row < qrtCol.length; row++) { 171 alpha -= qrtCol[row] * qrtMinor[row]; 172 } 173 alpha /= a * qrtMinor[minor]; 174 175 // Subtract the column vector alpha*v from x. 176 for (int row = minor; row < qrtCol.length; row++) { 177 qrtCol[row] -= alpha * qrtMinor[row]; 178 } 179 } 180 } 181 } 182 183 184 /** 185 * Returns the matrix R of the decomposition. 186 * <p>R is an upper-triangular matrix</p> 187 * @return the R matrix 188 */ 189 public RealMatrix getR() { 190 191 if (cachedR == null) { 192 193 // R is supposed to be m x n 194 final int n = qrt.length; 195 final int m = qrt[0].length; 196 double[][] ra = new double[m][n]; 197 // copy the diagonal from rDiag and the upper triangle of qr 198 for (int row = JdkMath.min(m, n) - 1; row >= 0; row--) { 199 ra[row][row] = rDiag[row]; 200 for (int col = row + 1; col < n; col++) { 201 ra[row][col] = qrt[col][row]; 202 } 203 } 204 cachedR = MatrixUtils.createRealMatrix(ra); 205 } 206 207 // return the cached matrix 208 return cachedR; 209 } 210 211 /** 212 * Returns the matrix Q of the decomposition. 213 * <p>Q is an orthogonal matrix</p> 214 * @return the Q matrix 215 */ 216 public RealMatrix getQ() { 217 if (cachedQ == null) { 218 cachedQ = getQT().transpose(); 219 } 220 return cachedQ; 221 } 222 223 /** 224 * Returns the transpose of the matrix Q of the decomposition. 225 * <p>Q is an orthogonal matrix</p> 226 * @return the transpose of the Q matrix, Q<sup>T</sup> 227 */ 228 public RealMatrix getQT() { 229 if (cachedQT == null) { 230 231 // QT is supposed to be m x m 232 final int n = qrt.length; 233 final int m = qrt[0].length; 234 double[][] qta = new double[m][m]; 235 236 /* 237 * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then 238 * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in 239 * succession to the result 240 */ 241 for (int minor = m - 1; minor >= JdkMath.min(m, n); minor--) { 242 qta[minor][minor] = 1.0d; 243 } 244 245 for (int minor = JdkMath.min(m, n)-1; minor >= 0; minor--){ 246 final double[] qrtMinor = qrt[minor]; 247 qta[minor][minor] = 1.0d; 248 if (qrtMinor[minor] != 0.0) { 249 for (int col = minor; col < m; col++) { 250 double alpha = 0; 251 for (int row = minor; row < m; row++) { 252 alpha -= qta[col][row] * qrtMinor[row]; 253 } 254 alpha /= rDiag[minor] * qrtMinor[minor]; 255 256 for (int row = minor; row < m; row++) { 257 qta[col][row] += -alpha * qrtMinor[row]; 258 } 259 } 260 } 261 } 262 cachedQT = MatrixUtils.createRealMatrix(qta); 263 } 264 265 // return the cached matrix 266 return cachedQT; 267 } 268 269 /** 270 * Returns the Householder reflector vectors. 271 * <p>H is a lower trapezoidal matrix whose columns represent 272 * each successive Householder reflector vector. This matrix is used 273 * to compute Q.</p> 274 * @return a matrix containing the Householder reflector vectors 275 */ 276 public RealMatrix getH() { 277 if (cachedH == null) { 278 279 final int n = qrt.length; 280 final int m = qrt[0].length; 281 double[][] ha = new double[m][n]; 282 for (int i = 0; i < m; ++i) { 283 for (int j = 0; j < JdkMath.min(i + 1, n); ++j) { 284 ha[i][j] = qrt[j][i] / -rDiag[j]; 285 } 286 } 287 cachedH = MatrixUtils.createRealMatrix(ha); 288 } 289 290 // return the cached matrix 291 return cachedH; 292 } 293 294 /** 295 * Get a solver for finding the A × X = B solution in least square sense. 296 * <p> 297 * Least Square sense means a solver can be computed for an overdetermined system, 298 * (i.e. a system with more equations than unknowns, which corresponds to a tall A 299 * matrix with more rows than columns). In any case, if the matrix is singular 300 * within the tolerance set at {@link QRDecomposition#QRDecomposition(RealMatrix, 301 * double) construction}, an error will be triggered when 302 * the {@link DecompositionSolver#solve(RealVector) solve} method will be called. 303 * </p> 304 * @return a solver 305 */ 306 public DecompositionSolver getSolver() { 307 return new Solver(qrt, rDiag, threshold); 308 } 309 310 /** Specialized solver. */ 311 private static final class Solver implements DecompositionSolver { 312 /** 313 * A packed TRANSPOSED representation of the QR decomposition. 314 * <p>The elements BELOW the diagonal are the elements of the UPPER triangular 315 * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors 316 * from which an explicit form of Q can be recomputed if desired.</p> 317 */ 318 private final double[][] qrt; 319 /** The diagonal elements of R. */ 320 private final double[] rDiag; 321 /** Singularity threshold. */ 322 private final double threshold; 323 324 /** 325 * Build a solver from decomposed matrix. 326 * 327 * @param qrt Packed TRANSPOSED representation of the QR decomposition. 328 * @param rDiag Diagonal elements of R. 329 * @param threshold Singularity threshold. 330 */ 331 private Solver(final double[][] qrt, 332 final double[] rDiag, 333 final double threshold) { 334 this.qrt = qrt; 335 this.rDiag = rDiag; 336 this.threshold = threshold; 337 } 338 339 /** {@inheritDoc} */ 340 @Override 341 public boolean isNonSingular() { 342 return !checkSingular(rDiag, threshold, false); 343 } 344 345 /** {@inheritDoc} */ 346 @Override 347 public RealVector solve(RealVector b) { 348 final int n = qrt.length; 349 final int m = qrt[0].length; 350 if (b.getDimension() != m) { 351 throw new DimensionMismatchException(b.getDimension(), m); 352 } 353 checkSingular(rDiag, threshold, true); 354 355 final double[] x = new double[n]; 356 final double[] y = b.toArray(); 357 358 // apply Householder transforms to solve Q.y = b 359 for (int minor = 0; minor < JdkMath.min(m, n); minor++) { 360 361 final double[] qrtMinor = qrt[minor]; 362 double dotProduct = 0; 363 for (int row = minor; row < m; row++) { 364 dotProduct += y[row] * qrtMinor[row]; 365 } 366 dotProduct /= rDiag[minor] * qrtMinor[minor]; 367 368 for (int row = minor; row < m; row++) { 369 y[row] += dotProduct * qrtMinor[row]; 370 } 371 } 372 373 // solve triangular system R.x = y 374 for (int row = rDiag.length - 1; row >= 0; --row) { 375 y[row] /= rDiag[row]; 376 final double yRow = y[row]; 377 final double[] qrtRow = qrt[row]; 378 x[row] = yRow; 379 for (int i = 0; i < row; i++) { 380 y[i] -= yRow * qrtRow[i]; 381 } 382 } 383 384 return new ArrayRealVector(x, false); 385 } 386 387 /** {@inheritDoc} */ 388 @Override 389 public RealMatrix solve(RealMatrix b) { 390 final int n = qrt.length; 391 final int m = qrt[0].length; 392 if (b.getRowDimension() != m) { 393 throw new DimensionMismatchException(b.getRowDimension(), m); 394 } 395 checkSingular(rDiag, threshold, true); 396 397 final int columns = b.getColumnDimension(); 398 final int blockSize = BlockRealMatrix.BLOCK_SIZE; 399 final int cBlocks = (columns + blockSize - 1) / blockSize; 400 final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns); 401 final double[][] y = new double[b.getRowDimension()][blockSize]; 402 final double[] alpha = new double[blockSize]; 403 404 for (int kBlock = 0; kBlock < cBlocks; ++kBlock) { 405 final int kStart = kBlock * blockSize; 406 final int kEnd = JdkMath.min(kStart + blockSize, columns); 407 final int kWidth = kEnd - kStart; 408 409 // get the right hand side vector 410 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y); 411 412 // apply Householder transforms to solve Q.y = b 413 for (int minor = 0; minor < JdkMath.min(m, n); minor++) { 414 final double[] qrtMinor = qrt[minor]; 415 final double factor = 1.0 / (rDiag[minor] * qrtMinor[minor]); 416 417 Arrays.fill(alpha, 0, kWidth, 0.0); 418 for (int row = minor; row < m; ++row) { 419 final double d = qrtMinor[row]; 420 final double[] yRow = y[row]; 421 for (int k = 0; k < kWidth; ++k) { 422 alpha[k] += d * yRow[k]; 423 } 424 } 425 for (int k = 0; k < kWidth; ++k) { 426 alpha[k] *= factor; 427 } 428 429 for (int row = minor; row < m; ++row) { 430 final double d = qrtMinor[row]; 431 final double[] yRow = y[row]; 432 for (int k = 0; k < kWidth; ++k) { 433 yRow[k] += alpha[k] * d; 434 } 435 } 436 } 437 438 // solve triangular system R.x = y 439 for (int j = rDiag.length - 1; j >= 0; --j) { 440 final int jBlock = j / blockSize; 441 final int jStart = jBlock * blockSize; 442 final double factor = 1.0 / rDiag[j]; 443 final double[] yJ = y[j]; 444 final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock]; 445 int index = (j - jStart) * kWidth; 446 for (int k = 0; k < kWidth; ++k) { 447 yJ[k] *= factor; 448 xBlock[index++] = yJ[k]; 449 } 450 451 final double[] qrtJ = qrt[j]; 452 for (int i = 0; i < j; ++i) { 453 final double rIJ = qrtJ[i]; 454 final double[] yI = y[i]; 455 for (int k = 0; k < kWidth; ++k) { 456 yI[k] -= yJ[k] * rIJ; 457 } 458 } 459 } 460 } 461 462 return new BlockRealMatrix(n, columns, xBlocks, false); 463 } 464 465 /** 466 * {@inheritDoc} 467 * @throws SingularMatrixException if the decomposed matrix is singular. 468 */ 469 @Override 470 public RealMatrix getInverse() { 471 return solve(MatrixUtils.createRealIdentityMatrix(qrt[0].length)); 472 } 473 474 /** 475 * Check singularity. 476 * 477 * @param diag Diagonal elements of the R matrix. 478 * @param min Singularity threshold. 479 * @param raise Whether to raise a {@link SingularMatrixException} 480 * if any element of the diagonal fails the check. 481 * @return {@code true} if any element of the diagonal is smaller 482 * or equal to {@code min}. 483 * @throws SingularMatrixException if the matrix is singular and 484 * {@code raise} is {@code true}. 485 */ 486 private static boolean checkSingular(double[] diag, 487 double min, 488 boolean raise) { 489 final int len = diag.length; 490 for (int i = 0; i < len; i++) { 491 final double d = diag[i]; 492 if (JdkMath.abs(d) <= min) { 493 if (raise) { 494 final SingularMatrixException e = new SingularMatrixException(); 495 e.getContext().addMessage(LocalizedFormats.NUMBER_TOO_SMALL, d, min); 496 e.getContext().addMessage(LocalizedFormats.INDEX, i); 497 throw e; 498 } else { 499 return true; 500 } 501 } 502 } 503 return false; 504 } 505 } 506}