001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math3.linear; 019 020import java.io.IOException; 021import java.io.ObjectInputStream; 022import java.io.ObjectOutputStream; 023import java.util.Arrays; 024 025import org.apache.commons.math3.Field; 026import org.apache.commons.math3.FieldElement; 027import org.apache.commons.math3.exception.DimensionMismatchException; 028import org.apache.commons.math3.exception.MathArithmeticException; 029import org.apache.commons.math3.exception.NoDataException; 030import org.apache.commons.math3.exception.NullArgumentException; 031import org.apache.commons.math3.exception.NumberIsTooSmallException; 032import org.apache.commons.math3.exception.OutOfRangeException; 033import org.apache.commons.math3.exception.ZeroException; 034import org.apache.commons.math3.exception.util.LocalizedFormats; 035import org.apache.commons.math3.fraction.BigFraction; 036import org.apache.commons.math3.fraction.Fraction; 037import org.apache.commons.math3.util.FastMath; 038import org.apache.commons.math3.util.MathArrays; 039import org.apache.commons.math3.util.MathUtils; 040import org.apache.commons.math3.util.Precision; 041 042/** 043 * A collection of static methods that operate on or return matrices. 044 * 045 */ 046public class MatrixUtils { 047 048 /** 049 * The default format for {@link RealMatrix} objects. 050 * @since 3.1 051 */ 052 public static final RealMatrixFormat DEFAULT_FORMAT = RealMatrixFormat.getInstance(); 053 054 /** 055 * A format for {@link RealMatrix} objects compatible with octave. 056 * @since 3.1 057 */ 058 public static final RealMatrixFormat OCTAVE_FORMAT = new RealMatrixFormat("[", "]", "", "", "; ", ", "); 059 060 /** 061 * Private constructor. 062 */ 063 private MatrixUtils() { 064 super(); 065 } 066 067 /** 068 * Returns a {@link RealMatrix} with specified dimensions. 069 * <p>The type of matrix returned depends on the dimension. Below 070 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 071 * square matrix) which can be stored in a 32kB array, a {@link 072 * Array2DRowRealMatrix} instance is built. Above this threshold a {@link 073 * BlockRealMatrix} instance is built.</p> 074 * <p>The matrix elements are all set to 0.0.</p> 075 * @param rows number of rows of the matrix 076 * @param columns number of columns of the matrix 077 * @return RealMatrix with specified dimensions 078 * @see #createRealMatrix(double[][]) 079 */ 080 public static RealMatrix createRealMatrix(final int rows, final int columns) { 081 return (rows * columns <= 4096) ? 082 new Array2DRowRealMatrix(rows, columns) : new BlockRealMatrix(rows, columns); 083 } 084 085 /** 086 * Returns a {@link FieldMatrix} with specified dimensions. 087 * <p>The type of matrix returned depends on the dimension. Below 088 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 089 * square matrix), a {@link FieldMatrix} instance is built. Above 090 * this threshold a {@link BlockFieldMatrix} instance is built.</p> 091 * <p>The matrix elements are all set to field.getZero().</p> 092 * @param <T> the type of the field elements 093 * @param field field to which the matrix elements belong 094 * @param rows number of rows of the matrix 095 * @param columns number of columns of the matrix 096 * @return FieldMatrix with specified dimensions 097 * @see #createFieldMatrix(FieldElement[][]) 098 * @since 2.0 099 */ 100 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(final Field<T> field, 101 final int rows, 102 final int columns) { 103 return (rows * columns <= 4096) ? 104 new Array2DRowFieldMatrix<T>(field, rows, columns) : new BlockFieldMatrix<T>(field, rows, columns); 105 } 106 107 /** 108 * Returns a {@link RealMatrix} whose entries are the the values in the 109 * the input array. 110 * <p>The type of matrix returned depends on the dimension. Below 111 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 112 * square matrix) which can be stored in a 32kB array, a {@link 113 * Array2DRowRealMatrix} instance is built. Above this threshold a {@link 114 * BlockRealMatrix} instance is built.</p> 115 * <p>The input array is copied, not referenced.</p> 116 * 117 * @param data input array 118 * @return RealMatrix containing the values of the array 119 * @throws org.apache.commons.math3.exception.DimensionMismatchException 120 * if {@code data} is not rectangular (not all rows have the same length). 121 * @throws NoDataException if a row or column is empty. 122 * @throws NullArgumentException if either {@code data} or {@code data[0]} 123 * is {@code null}. 124 * @throws DimensionMismatchException if {@code data} is not rectangular. 125 * @see #createRealMatrix(int, int) 126 */ 127 public static RealMatrix createRealMatrix(double[][] data) 128 throws NullArgumentException, DimensionMismatchException, 129 NoDataException { 130 if (data == null || 131 data[0] == null) { 132 throw new NullArgumentException(); 133 } 134 return (data.length * data[0].length <= 4096) ? 135 new Array2DRowRealMatrix(data) : new BlockRealMatrix(data); 136 } 137 138 /** 139 * Returns a {@link FieldMatrix} whose entries are the the values in the 140 * the input array. 141 * <p>The type of matrix returned depends on the dimension. Below 142 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 143 * square matrix), a {@link FieldMatrix} instance is built. Above 144 * this threshold a {@link BlockFieldMatrix} instance is built.</p> 145 * <p>The input array is copied, not referenced.</p> 146 * @param <T> the type of the field elements 147 * @param data input array 148 * @return a matrix containing the values of the array. 149 * @throws org.apache.commons.math3.exception.DimensionMismatchException 150 * if {@code data} is not rectangular (not all rows have the same length). 151 * @throws NoDataException if a row or column is empty. 152 * @throws NullArgumentException if either {@code data} or {@code data[0]} 153 * is {@code null}. 154 * @see #createFieldMatrix(Field, int, int) 155 * @since 2.0 156 */ 157 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(T[][] data) 158 throws DimensionMismatchException, NoDataException, NullArgumentException { 159 if (data == null || 160 data[0] == null) { 161 throw new NullArgumentException(); 162 } 163 return (data.length * data[0].length <= 4096) ? 164 new Array2DRowFieldMatrix<T>(data) : new BlockFieldMatrix<T>(data); 165 } 166 167 /** 168 * Returns <code>dimension x dimension</code> identity matrix. 169 * 170 * @param dimension dimension of identity matrix to generate 171 * @return identity matrix 172 * @throws IllegalArgumentException if dimension is not positive 173 * @since 1.1 174 */ 175 public static RealMatrix createRealIdentityMatrix(int dimension) { 176 final RealMatrix m = createRealMatrix(dimension, dimension); 177 for (int i = 0; i < dimension; ++i) { 178 m.setEntry(i, i, 1.0); 179 } 180 return m; 181 } 182 183 /** 184 * Returns <code>dimension x dimension</code> identity matrix. 185 * 186 * @param <T> the type of the field elements 187 * @param field field to which the elements belong 188 * @param dimension dimension of identity matrix to generate 189 * @return identity matrix 190 * @throws IllegalArgumentException if dimension is not positive 191 * @since 2.0 192 */ 193 public static <T extends FieldElement<T>> FieldMatrix<T> 194 createFieldIdentityMatrix(final Field<T> field, final int dimension) { 195 final T zero = field.getZero(); 196 final T one = field.getOne(); 197 final T[][] d = MathArrays.buildArray(field, dimension, dimension); 198 for (int row = 0; row < dimension; row++) { 199 final T[] dRow = d[row]; 200 Arrays.fill(dRow, zero); 201 dRow[row] = one; 202 } 203 return new Array2DRowFieldMatrix<T>(field, d, false); 204 } 205 206 /** 207 * Returns a diagonal matrix with specified elements. 208 * 209 * @param diagonal diagonal elements of the matrix (the array elements 210 * will be copied) 211 * @return diagonal matrix 212 * @since 2.0 213 */ 214 public static RealMatrix createRealDiagonalMatrix(final double[] diagonal) { 215 final RealMatrix m = createRealMatrix(diagonal.length, diagonal.length); 216 for (int i = 0; i < diagonal.length; ++i) { 217 m.setEntry(i, i, diagonal[i]); 218 } 219 return m; 220 } 221 222 /** 223 * Returns a diagonal matrix with specified elements. 224 * 225 * @param <T> the type of the field elements 226 * @param diagonal diagonal elements of the matrix (the array elements 227 * will be copied) 228 * @return diagonal matrix 229 * @since 2.0 230 */ 231 public static <T extends FieldElement<T>> FieldMatrix<T> 232 createFieldDiagonalMatrix(final T[] diagonal) { 233 final FieldMatrix<T> m = 234 createFieldMatrix(diagonal[0].getField(), diagonal.length, diagonal.length); 235 for (int i = 0; i < diagonal.length; ++i) { 236 m.setEntry(i, i, diagonal[i]); 237 } 238 return m; 239 } 240 241 /** 242 * Creates a {@link RealVector} using the data from the input array. 243 * 244 * @param data the input data 245 * @return a data.length RealVector 246 * @throws NoDataException if {@code data} is empty. 247 * @throws NullArgumentException if {@code data} is {@code null}. 248 */ 249 public static RealVector createRealVector(double[] data) 250 throws NoDataException, NullArgumentException { 251 if (data == null) { 252 throw new NullArgumentException(); 253 } 254 return new ArrayRealVector(data, true); 255 } 256 257 /** 258 * Creates a {@link FieldVector} using the data from the input array. 259 * 260 * @param <T> the type of the field elements 261 * @param data the input data 262 * @return a data.length FieldVector 263 * @throws NoDataException if {@code data} is empty. 264 * @throws NullArgumentException if {@code data} is {@code null}. 265 * @throws ZeroException if {@code data} has 0 elements 266 */ 267 public static <T extends FieldElement<T>> FieldVector<T> createFieldVector(final T[] data) 268 throws NoDataException, NullArgumentException, ZeroException { 269 if (data == null) { 270 throw new NullArgumentException(); 271 } 272 if (data.length == 0) { 273 throw new ZeroException(LocalizedFormats.VECTOR_MUST_HAVE_AT_LEAST_ONE_ELEMENT); 274 } 275 return new ArrayFieldVector<T>(data[0].getField(), data, true); 276 } 277 278 /** 279 * Create a row {@link RealMatrix} using the data from the input 280 * array. 281 * 282 * @param rowData the input row data 283 * @return a 1 x rowData.length RealMatrix 284 * @throws NoDataException if {@code rowData} is empty. 285 * @throws NullArgumentException if {@code rowData} is {@code null}. 286 */ 287 public static RealMatrix createRowRealMatrix(double[] rowData) 288 throws NoDataException, NullArgumentException { 289 if (rowData == null) { 290 throw new NullArgumentException(); 291 } 292 final int nCols = rowData.length; 293 final RealMatrix m = createRealMatrix(1, nCols); 294 for (int i = 0; i < nCols; ++i) { 295 m.setEntry(0, i, rowData[i]); 296 } 297 return m; 298 } 299 300 /** 301 * Create a row {@link FieldMatrix} using the data from the input 302 * array. 303 * 304 * @param <T> the type of the field elements 305 * @param rowData the input row data 306 * @return a 1 x rowData.length FieldMatrix 307 * @throws NoDataException if {@code rowData} is empty. 308 * @throws NullArgumentException if {@code rowData} is {@code null}. 309 */ 310 public static <T extends FieldElement<T>> FieldMatrix<T> 311 createRowFieldMatrix(final T[] rowData) 312 throws NoDataException, NullArgumentException { 313 if (rowData == null) { 314 throw new NullArgumentException(); 315 } 316 final int nCols = rowData.length; 317 if (nCols == 0) { 318 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_COLUMN); 319 } 320 final FieldMatrix<T> m = createFieldMatrix(rowData[0].getField(), 1, nCols); 321 for (int i = 0; i < nCols; ++i) { 322 m.setEntry(0, i, rowData[i]); 323 } 324 return m; 325 } 326 327 /** 328 * Creates a column {@link RealMatrix} using the data from the input 329 * array. 330 * 331 * @param columnData the input column data 332 * @return a columnData x 1 RealMatrix 333 * @throws NoDataException if {@code columnData} is empty. 334 * @throws NullArgumentException if {@code columnData} is {@code null}. 335 */ 336 public static RealMatrix createColumnRealMatrix(double[] columnData) 337 throws NoDataException, NullArgumentException { 338 if (columnData == null) { 339 throw new NullArgumentException(); 340 } 341 final int nRows = columnData.length; 342 final RealMatrix m = createRealMatrix(nRows, 1); 343 for (int i = 0; i < nRows; ++i) { 344 m.setEntry(i, 0, columnData[i]); 345 } 346 return m; 347 } 348 349 /** 350 * Creates a column {@link FieldMatrix} using the data from the input 351 * array. 352 * 353 * @param <T> the type of the field elements 354 * @param columnData the input column data 355 * @return a columnData x 1 FieldMatrix 356 * @throws NoDataException if {@code data} is empty. 357 * @throws NullArgumentException if {@code columnData} is {@code null}. 358 */ 359 public static <T extends FieldElement<T>> FieldMatrix<T> 360 createColumnFieldMatrix(final T[] columnData) 361 throws NoDataException, NullArgumentException { 362 if (columnData == null) { 363 throw new NullArgumentException(); 364 } 365 final int nRows = columnData.length; 366 if (nRows == 0) { 367 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_ROW); 368 } 369 final FieldMatrix<T> m = createFieldMatrix(columnData[0].getField(), nRows, 1); 370 for (int i = 0; i < nRows; ++i) { 371 m.setEntry(i, 0, columnData[i]); 372 } 373 return m; 374 } 375 376 /** 377 * Checks whether a matrix is symmetric, within a given relative tolerance. 378 * 379 * @param matrix Matrix to check. 380 * @param relativeTolerance Tolerance of the symmetry check. 381 * @param raiseException If {@code true}, an exception will be raised if 382 * the matrix is not symmetric. 383 * @return {@code true} if {@code matrix} is symmetric. 384 * @throws NonSquareMatrixException if the matrix is not square. 385 * @throws NonSymmetricMatrixException if the matrix is not symmetric. 386 */ 387 private static boolean isSymmetricInternal(RealMatrix matrix, 388 double relativeTolerance, 389 boolean raiseException) { 390 final int rows = matrix.getRowDimension(); 391 if (rows != matrix.getColumnDimension()) { 392 if (raiseException) { 393 throw new NonSquareMatrixException(rows, matrix.getColumnDimension()); 394 } else { 395 return false; 396 } 397 } 398 for (int i = 0; i < rows; i++) { 399 for (int j = i + 1; j < rows; j++) { 400 final double mij = matrix.getEntry(i, j); 401 final double mji = matrix.getEntry(j, i); 402 if (FastMath.abs(mij - mji) > 403 FastMath.max(FastMath.abs(mij), FastMath.abs(mji)) * relativeTolerance) { 404 if (raiseException) { 405 throw new NonSymmetricMatrixException(i, j, relativeTolerance); 406 } else { 407 return false; 408 } 409 } 410 } 411 } 412 return true; 413 } 414 415 /** 416 * Checks whether a matrix is symmetric. 417 * 418 * @param matrix Matrix to check. 419 * @param eps Relative tolerance. 420 * @throws NonSquareMatrixException if the matrix is not square. 421 * @throws NonSymmetricMatrixException if the matrix is not symmetric. 422 * @since 3.1 423 */ 424 public static void checkSymmetric(RealMatrix matrix, 425 double eps) { 426 isSymmetricInternal(matrix, eps, true); 427 } 428 429 /** 430 * Checks whether a matrix is symmetric. 431 * 432 * @param matrix Matrix to check. 433 * @param eps Relative tolerance. 434 * @return {@code true} if {@code matrix} is symmetric. 435 * @since 3.1 436 */ 437 public static boolean isSymmetric(RealMatrix matrix, 438 double eps) { 439 return isSymmetricInternal(matrix, eps, false); 440 } 441 442 /** 443 * Check if matrix indices are valid. 444 * 445 * @param m Matrix. 446 * @param row Row index to check. 447 * @param column Column index to check. 448 * @throws OutOfRangeException if {@code row} or {@code column} is not 449 * a valid index. 450 */ 451 public static void checkMatrixIndex(final AnyMatrix m, 452 final int row, final int column) 453 throws OutOfRangeException { 454 checkRowIndex(m, row); 455 checkColumnIndex(m, column); 456 } 457 458 /** 459 * Check if a row index is valid. 460 * 461 * @param m Matrix. 462 * @param row Row index to check. 463 * @throws OutOfRangeException if {@code row} is not a valid index. 464 */ 465 public static void checkRowIndex(final AnyMatrix m, final int row) 466 throws OutOfRangeException { 467 if (row < 0 || 468 row >= m.getRowDimension()) { 469 throw new OutOfRangeException(LocalizedFormats.ROW_INDEX, 470 row, 0, m.getRowDimension() - 1); 471 } 472 } 473 474 /** 475 * Check if a column index is valid. 476 * 477 * @param m Matrix. 478 * @param column Column index to check. 479 * @throws OutOfRangeException if {@code column} is not a valid index. 480 */ 481 public static void checkColumnIndex(final AnyMatrix m, final int column) 482 throws OutOfRangeException { 483 if (column < 0 || column >= m.getColumnDimension()) { 484 throw new OutOfRangeException(LocalizedFormats.COLUMN_INDEX, 485 column, 0, m.getColumnDimension() - 1); 486 } 487 } 488 489 /** 490 * Check if submatrix ranges indices are valid. 491 * Rows and columns are indicated counting from 0 to {@code n - 1}. 492 * 493 * @param m Matrix. 494 * @param startRow Initial row index. 495 * @param endRow Final row index. 496 * @param startColumn Initial column index. 497 * @param endColumn Final column index. 498 * @throws OutOfRangeException if the indices are invalid. 499 * @throws NumberIsTooSmallException if {@code endRow < startRow} or 500 * {@code endColumn < startColumn}. 501 */ 502 public static void checkSubMatrixIndex(final AnyMatrix m, 503 final int startRow, final int endRow, 504 final int startColumn, final int endColumn) 505 throws NumberIsTooSmallException, OutOfRangeException { 506 checkRowIndex(m, startRow); 507 checkRowIndex(m, endRow); 508 if (endRow < startRow) { 509 throw new NumberIsTooSmallException(LocalizedFormats.INITIAL_ROW_AFTER_FINAL_ROW, 510 endRow, startRow, false); 511 } 512 513 checkColumnIndex(m, startColumn); 514 checkColumnIndex(m, endColumn); 515 if (endColumn < startColumn) { 516 throw new NumberIsTooSmallException(LocalizedFormats.INITIAL_COLUMN_AFTER_FINAL_COLUMN, 517 endColumn, startColumn, false); 518 } 519 520 521 } 522 523 /** 524 * Check if submatrix ranges indices are valid. 525 * Rows and columns are indicated counting from 0 to n-1. 526 * 527 * @param m Matrix. 528 * @param selectedRows Array of row indices. 529 * @param selectedColumns Array of column indices. 530 * @throws NullArgumentException if {@code selectedRows} or 531 * {@code selectedColumns} are {@code null}. 532 * @throws NoDataException if the row or column selections are empty (zero 533 * length). 534 * @throws OutOfRangeException if row or column selections are not valid. 535 */ 536 public static void checkSubMatrixIndex(final AnyMatrix m, 537 final int[] selectedRows, 538 final int[] selectedColumns) 539 throws NoDataException, NullArgumentException, OutOfRangeException { 540 if (selectedRows == null) { 541 throw new NullArgumentException(); 542 } 543 if (selectedColumns == null) { 544 throw new NullArgumentException(); 545 } 546 if (selectedRows.length == 0) { 547 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_ROW_INDEX_ARRAY); 548 } 549 if (selectedColumns.length == 0) { 550 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_COLUMN_INDEX_ARRAY); 551 } 552 553 for (final int row : selectedRows) { 554 checkRowIndex(m, row); 555 } 556 for (final int column : selectedColumns) { 557 checkColumnIndex(m, column); 558 } 559 } 560 561 /** 562 * Check if matrices are addition compatible. 563 * 564 * @param left Left hand side matrix. 565 * @param right Right hand side matrix. 566 * @throws MatrixDimensionMismatchException if the matrices are not addition 567 * compatible. 568 */ 569 public static void checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right) 570 throws MatrixDimensionMismatchException { 571 if ((left.getRowDimension() != right.getRowDimension()) || 572 (left.getColumnDimension() != right.getColumnDimension())) { 573 throw new MatrixDimensionMismatchException(left.getRowDimension(), left.getColumnDimension(), 574 right.getRowDimension(), right.getColumnDimension()); 575 } 576 } 577 578 /** 579 * Check if matrices are subtraction compatible 580 * 581 * @param left Left hand side matrix. 582 * @param right Right hand side matrix. 583 * @throws MatrixDimensionMismatchException if the matrices are not addition 584 * compatible. 585 */ 586 public static void checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right) 587 throws MatrixDimensionMismatchException { 588 if ((left.getRowDimension() != right.getRowDimension()) || 589 (left.getColumnDimension() != right.getColumnDimension())) { 590 throw new MatrixDimensionMismatchException(left.getRowDimension(), left.getColumnDimension(), 591 right.getRowDimension(), right.getColumnDimension()); 592 } 593 } 594 595 /** 596 * Check if matrices are multiplication compatible 597 * 598 * @param left Left hand side matrix. 599 * @param right Right hand side matrix. 600 * @throws DimensionMismatchException if matrices are not multiplication 601 * compatible. 602 */ 603 public static void checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right) 604 throws DimensionMismatchException { 605 606 if (left.getColumnDimension() != right.getRowDimension()) { 607 throw new DimensionMismatchException(left.getColumnDimension(), 608 right.getRowDimension()); 609 } 610 } 611 612 /** 613 * Convert a {@link FieldMatrix}/{@link Fraction} matrix to a {@link RealMatrix}. 614 * @param m Matrix to convert. 615 * @return the converted matrix. 616 */ 617 public static Array2DRowRealMatrix fractionMatrixToRealMatrix(final FieldMatrix<Fraction> m) { 618 final FractionMatrixConverter converter = new FractionMatrixConverter(); 619 m.walkInOptimizedOrder(converter); 620 return converter.getConvertedMatrix(); 621 } 622 623 /** Converter for {@link FieldMatrix}/{@link Fraction}. */ 624 private static class FractionMatrixConverter extends DefaultFieldMatrixPreservingVisitor<Fraction> { 625 /** Converted array. */ 626 private double[][] data; 627 /** Simple constructor. */ 628 FractionMatrixConverter() { 629 super(Fraction.ZERO); 630 } 631 632 /** {@inheritDoc} */ 633 @Override 634 public void start(int rows, int columns, 635 int startRow, int endRow, int startColumn, int endColumn) { 636 data = new double[rows][columns]; 637 } 638 639 /** {@inheritDoc} */ 640 @Override 641 public void visit(int row, int column, Fraction value) { 642 data[row][column] = value.doubleValue(); 643 } 644 645 /** 646 * Get the converted matrix. 647 * 648 * @return the converted matrix. 649 */ 650 Array2DRowRealMatrix getConvertedMatrix() { 651 return new Array2DRowRealMatrix(data, false); 652 } 653 654 } 655 656 /** 657 * Convert a {@link FieldMatrix}/{@link BigFraction} matrix to a {@link RealMatrix}. 658 * 659 * @param m Matrix to convert. 660 * @return the converted matrix. 661 */ 662 public static Array2DRowRealMatrix bigFractionMatrixToRealMatrix(final FieldMatrix<BigFraction> m) { 663 final BigFractionMatrixConverter converter = new BigFractionMatrixConverter(); 664 m.walkInOptimizedOrder(converter); 665 return converter.getConvertedMatrix(); 666 } 667 668 /** Converter for {@link FieldMatrix}/{@link BigFraction}. */ 669 private static class BigFractionMatrixConverter extends DefaultFieldMatrixPreservingVisitor<BigFraction> { 670 /** Converted array. */ 671 private double[][] data; 672 /** Simple constructor. */ 673 BigFractionMatrixConverter() { 674 super(BigFraction.ZERO); 675 } 676 677 /** {@inheritDoc} */ 678 @Override 679 public void start(int rows, int columns, 680 int startRow, int endRow, int startColumn, int endColumn) { 681 data = new double[rows][columns]; 682 } 683 684 /** {@inheritDoc} */ 685 @Override 686 public void visit(int row, int column, BigFraction value) { 687 data[row][column] = value.doubleValue(); 688 } 689 690 /** 691 * Get the converted matrix. 692 * 693 * @return the converted matrix. 694 */ 695 Array2DRowRealMatrix getConvertedMatrix() { 696 return new Array2DRowRealMatrix(data, false); 697 } 698 } 699 700 /** Serialize a {@link RealVector}. 701 * <p> 702 * This method is intended to be called from within a private 703 * <code>writeObject</code> method (after a call to 704 * <code>oos.defaultWriteObject()</code>) in a class that has a 705 * {@link RealVector} field, which should be declared <code>transient</code>. 706 * This way, the default handling does not serialize the vector (the {@link 707 * RealVector} interface is not serializable by default) but this method does 708 * serialize it specifically. 709 * </p> 710 * <p> 711 * The following example shows how a simple class with a name and a real vector 712 * should be written: 713 * <pre><code> 714 * public class NamedVector implements Serializable { 715 * 716 * private final String name; 717 * private final transient RealVector coefficients; 718 * 719 * // omitted constructors, getters ... 720 * 721 * private void writeObject(ObjectOutputStream oos) throws IOException { 722 * oos.defaultWriteObject(); // takes care of name field 723 * MatrixUtils.serializeRealVector(coefficients, oos); 724 * } 725 * 726 * private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 727 * ois.defaultReadObject(); // takes care of name field 728 * MatrixUtils.deserializeRealVector(this, "coefficients", ois); 729 * } 730 * 731 * } 732 * </code></pre> 733 * </p> 734 * 735 * @param vector real vector to serialize 736 * @param oos stream where the real vector should be written 737 * @exception IOException if object cannot be written to stream 738 * @see #deserializeRealVector(Object, String, ObjectInputStream) 739 */ 740 public static void serializeRealVector(final RealVector vector, 741 final ObjectOutputStream oos) 742 throws IOException { 743 final int n = vector.getDimension(); 744 oos.writeInt(n); 745 for (int i = 0; i < n; ++i) { 746 oos.writeDouble(vector.getEntry(i)); 747 } 748 } 749 750 /** Deserialize a {@link RealVector} field in a class. 751 * <p> 752 * This method is intended to be called from within a private 753 * <code>readObject</code> method (after a call to 754 * <code>ois.defaultReadObject()</code>) in a class that has a 755 * {@link RealVector} field, which should be declared <code>transient</code>. 756 * This way, the default handling does not deserialize the vector (the {@link 757 * RealVector} interface is not serializable by default) but this method does 758 * deserialize it specifically. 759 * </p> 760 * @param instance instance in which the field must be set up 761 * @param fieldName name of the field within the class (may be private and final) 762 * @param ois stream from which the real vector should be read 763 * @exception ClassNotFoundException if a class in the stream cannot be found 764 * @exception IOException if object cannot be read from the stream 765 * @see #serializeRealVector(RealVector, ObjectOutputStream) 766 */ 767 public static void deserializeRealVector(final Object instance, 768 final String fieldName, 769 final ObjectInputStream ois) 770 throws ClassNotFoundException, IOException { 771 try { 772 773 // read the vector data 774 final int n = ois.readInt(); 775 final double[] data = new double[n]; 776 for (int i = 0; i < n; ++i) { 777 data[i] = ois.readDouble(); 778 } 779 780 // create the instance 781 final RealVector vector = new ArrayRealVector(data, false); 782 783 // set up the field 784 final java.lang.reflect.Field f = 785 instance.getClass().getDeclaredField(fieldName); 786 f.setAccessible(true); 787 f.set(instance, vector); 788 789 } catch (NoSuchFieldException nsfe) { 790 IOException ioe = new IOException(); 791 ioe.initCause(nsfe); 792 throw ioe; 793 } catch (IllegalAccessException iae) { 794 IOException ioe = new IOException(); 795 ioe.initCause(iae); 796 throw ioe; 797 } 798 799 } 800 801 /** Serialize a {@link RealMatrix}. 802 * <p> 803 * This method is intended to be called from within a private 804 * <code>writeObject</code> method (after a call to 805 * <code>oos.defaultWriteObject()</code>) in a class that has a 806 * {@link RealMatrix} field, which should be declared <code>transient</code>. 807 * This way, the default handling does not serialize the matrix (the {@link 808 * RealMatrix} interface is not serializable by default) but this method does 809 * serialize it specifically. 810 * </p> 811 * <p> 812 * The following example shows how a simple class with a name and a real matrix 813 * should be written: 814 * <pre><code> 815 * public class NamedMatrix implements Serializable { 816 * 817 * private final String name; 818 * private final transient RealMatrix coefficients; 819 * 820 * // omitted constructors, getters ... 821 * 822 * private void writeObject(ObjectOutputStream oos) throws IOException { 823 * oos.defaultWriteObject(); // takes care of name field 824 * MatrixUtils.serializeRealMatrix(coefficients, oos); 825 * } 826 * 827 * private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 828 * ois.defaultReadObject(); // takes care of name field 829 * MatrixUtils.deserializeRealMatrix(this, "coefficients", ois); 830 * } 831 * 832 * } 833 * </code></pre> 834 * </p> 835 * 836 * @param matrix real matrix to serialize 837 * @param oos stream where the real matrix should be written 838 * @exception IOException if object cannot be written to stream 839 * @see #deserializeRealMatrix(Object, String, ObjectInputStream) 840 */ 841 public static void serializeRealMatrix(final RealMatrix matrix, 842 final ObjectOutputStream oos) 843 throws IOException { 844 final int n = matrix.getRowDimension(); 845 final int m = matrix.getColumnDimension(); 846 oos.writeInt(n); 847 oos.writeInt(m); 848 for (int i = 0; i < n; ++i) { 849 for (int j = 0; j < m; ++j) { 850 oos.writeDouble(matrix.getEntry(i, j)); 851 } 852 } 853 } 854 855 /** Deserialize a {@link RealMatrix} field in a class. 856 * <p> 857 * This method is intended to be called from within a private 858 * <code>readObject</code> method (after a call to 859 * <code>ois.defaultReadObject()</code>) in a class that has a 860 * {@link RealMatrix} field, which should be declared <code>transient</code>. 861 * This way, the default handling does not deserialize the matrix (the {@link 862 * RealMatrix} interface is not serializable by default) but this method does 863 * deserialize it specifically. 864 * </p> 865 * @param instance instance in which the field must be set up 866 * @param fieldName name of the field within the class (may be private and final) 867 * @param ois stream from which the real matrix should be read 868 * @exception ClassNotFoundException if a class in the stream cannot be found 869 * @exception IOException if object cannot be read from the stream 870 * @see #serializeRealMatrix(RealMatrix, ObjectOutputStream) 871 */ 872 public static void deserializeRealMatrix(final Object instance, 873 final String fieldName, 874 final ObjectInputStream ois) 875 throws ClassNotFoundException, IOException { 876 try { 877 878 // read the matrix data 879 final int n = ois.readInt(); 880 final int m = ois.readInt(); 881 final double[][] data = new double[n][m]; 882 for (int i = 0; i < n; ++i) { 883 final double[] dataI = data[i]; 884 for (int j = 0; j < m; ++j) { 885 dataI[j] = ois.readDouble(); 886 } 887 } 888 889 // create the instance 890 final RealMatrix matrix = new Array2DRowRealMatrix(data, false); 891 892 // set up the field 893 final java.lang.reflect.Field f = 894 instance.getClass().getDeclaredField(fieldName); 895 f.setAccessible(true); 896 f.set(instance, matrix); 897 898 } catch (NoSuchFieldException nsfe) { 899 IOException ioe = new IOException(); 900 ioe.initCause(nsfe); 901 throw ioe; 902 } catch (IllegalAccessException iae) { 903 IOException ioe = new IOException(); 904 ioe.initCause(iae); 905 throw ioe; 906 } 907 } 908 909 /**Solve a system of composed of a Lower Triangular Matrix 910 * {@link RealMatrix}. 911 * <p> 912 * This method is called to solve systems of equations which are 913 * of the lower triangular form. The matrix {@link RealMatrix} 914 * is assumed, though not checked, to be in lower triangular form. 915 * The vector {@link RealVector} is overwritten with the solution. 916 * The matrix is checked that it is square and its dimensions match 917 * the length of the vector. 918 * </p> 919 * @param rm RealMatrix which is lower triangular 920 * @param b RealVector this is overwritten 921 * @throws DimensionMismatchException if the matrix and vector are not 922 * conformable 923 * @throws NonSquareMatrixException if the matrix {@code rm} is not square 924 * @throws MathArithmeticException if the absolute value of one of the diagonal 925 * coefficient of {@code rm} is lower than {@link Precision#SAFE_MIN} 926 */ 927 public static void solveLowerTriangularSystem(RealMatrix rm, RealVector b) 928 throws DimensionMismatchException, MathArithmeticException, 929 NonSquareMatrixException { 930 if ((rm == null) || (b == null) || ( rm.getRowDimension() != b.getDimension())) { 931 throw new DimensionMismatchException( 932 (rm == null) ? 0 : rm.getRowDimension(), 933 (b == null) ? 0 : b.getDimension()); 934 } 935 if( rm.getColumnDimension() != rm.getRowDimension() ){ 936 throw new NonSquareMatrixException(rm.getRowDimension(), 937 rm.getColumnDimension()); 938 } 939 int rows = rm.getRowDimension(); 940 for( int i = 0 ; i < rows ; i++ ){ 941 double diag = rm.getEntry(i, i); 942 if( FastMath.abs(diag) < Precision.SAFE_MIN ){ 943 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR); 944 } 945 double bi = b.getEntry(i)/diag; 946 b.setEntry(i, bi ); 947 for( int j = i+1; j< rows; j++ ){ 948 b.setEntry(j, b.getEntry(j)-bi*rm.getEntry(j,i) ); 949 } 950 } 951 } 952 953 /** Solver a system composed of an Upper Triangular Matrix 954 * {@link RealMatrix}. 955 * <p> 956 * This method is called to solve systems of equations which are 957 * of the lower triangular form. The matrix {@link RealMatrix} 958 * is assumed, though not checked, to be in upper triangular form. 959 * The vector {@link RealVector} is overwritten with the solution. 960 * The matrix is checked that it is square and its dimensions match 961 * the length of the vector. 962 * </p> 963 * @param rm RealMatrix which is upper triangular 964 * @param b RealVector this is overwritten 965 * @throws DimensionMismatchException if the matrix and vector are not 966 * conformable 967 * @throws NonSquareMatrixException if the matrix {@code rm} is not 968 * square 969 * @throws MathArithmeticException if the absolute value of one of the diagonal 970 * coefficient of {@code rm} is lower than {@link Precision#SAFE_MIN} 971 */ 972 public static void solveUpperTriangularSystem(RealMatrix rm, RealVector b) 973 throws DimensionMismatchException, MathArithmeticException, 974 NonSquareMatrixException { 975 if ((rm == null) || (b == null) || ( rm.getRowDimension() != b.getDimension())) { 976 throw new DimensionMismatchException( 977 (rm == null) ? 0 : rm.getRowDimension(), 978 (b == null) ? 0 : b.getDimension()); 979 } 980 if( rm.getColumnDimension() != rm.getRowDimension() ){ 981 throw new NonSquareMatrixException(rm.getRowDimension(), 982 rm.getColumnDimension()); 983 } 984 int rows = rm.getRowDimension(); 985 for( int i = rows-1 ; i >-1 ; i-- ){ 986 double diag = rm.getEntry(i, i); 987 if( FastMath.abs(diag) < Precision.SAFE_MIN ){ 988 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR); 989 } 990 double bi = b.getEntry(i)/diag; 991 b.setEntry(i, bi ); 992 for( int j = i-1; j>-1; j-- ){ 993 b.setEntry(j, b.getEntry(j)-bi*rm.getEntry(j,i) ); 994 } 995 } 996 } 997 998 /** 999 * Computes the inverse of the given matrix by splitting it into 1000 * 4 sub-matrices. 1001 * 1002 * @param m Matrix whose inverse must be computed. 1003 * @param splitIndex Index that determines the "split" line and 1004 * column. 1005 * The element corresponding to this index will part of the 1006 * upper-left sub-matrix. 1007 * @return the inverse of {@code m}. 1008 * @throws NonSquareMatrixException if {@code m} is not square. 1009 */ 1010 public static RealMatrix blockInverse(RealMatrix m, 1011 int splitIndex) { 1012 final int n = m.getRowDimension(); 1013 if (m.getColumnDimension() != n) { 1014 throw new NonSquareMatrixException(m.getRowDimension(), 1015 m.getColumnDimension()); 1016 } 1017 1018 final int splitIndex1 = splitIndex + 1; 1019 1020 final RealMatrix a = m.getSubMatrix(0, splitIndex, 0, splitIndex); 1021 final RealMatrix b = m.getSubMatrix(0, splitIndex, splitIndex1, n - 1); 1022 final RealMatrix c = m.getSubMatrix(splitIndex1, n - 1, 0, splitIndex); 1023 final RealMatrix d = m.getSubMatrix(splitIndex1, n - 1, splitIndex1, n - 1); 1024 1025 final SingularValueDecomposition aDec = new SingularValueDecomposition(a); 1026 final DecompositionSolver aSolver = aDec.getSolver(); 1027 if (!aSolver.isNonSingular()) { 1028 throw new SingularMatrixException(); 1029 } 1030 final RealMatrix aInv = aSolver.getInverse(); 1031 1032 final SingularValueDecomposition dDec = new SingularValueDecomposition(d); 1033 final DecompositionSolver dSolver = dDec.getSolver(); 1034 if (!dSolver.isNonSingular()) { 1035 throw new SingularMatrixException(); 1036 } 1037 final RealMatrix dInv = dSolver.getInverse(); 1038 1039 final RealMatrix tmp1 = a.subtract(b.multiply(dInv).multiply(c)); 1040 final SingularValueDecomposition tmp1Dec = new SingularValueDecomposition(tmp1); 1041 final DecompositionSolver tmp1Solver = tmp1Dec.getSolver(); 1042 if (!tmp1Solver.isNonSingular()) { 1043 throw new SingularMatrixException(); 1044 } 1045 final RealMatrix result00 = tmp1Solver.getInverse(); 1046 1047 final RealMatrix tmp2 = d.subtract(c.multiply(aInv).multiply(b)); 1048 final SingularValueDecomposition tmp2Dec = new SingularValueDecomposition(tmp2); 1049 final DecompositionSolver tmp2Solver = tmp2Dec.getSolver(); 1050 if (!tmp2Solver.isNonSingular()) { 1051 throw new SingularMatrixException(); 1052 } 1053 final RealMatrix result11 = tmp2Solver.getInverse(); 1054 1055 final RealMatrix result01 = aInv.multiply(b).multiply(result11).scalarMultiply(-1); 1056 final RealMatrix result10 = dInv.multiply(c).multiply(result00).scalarMultiply(-1); 1057 1058 final RealMatrix result = new Array2DRowRealMatrix(n, n); 1059 result.setSubMatrix(result00.getData(), 0, 0); 1060 result.setSubMatrix(result01.getData(), 0, splitIndex1); 1061 result.setSubMatrix(result10.getData(), splitIndex1, 0); 1062 result.setSubMatrix(result11.getData(), splitIndex1, splitIndex1); 1063 1064 return result; 1065 } 1066 1067 /** 1068 * Computes the inverse of the given matrix. 1069 * <p> 1070 * By default, the inverse of the matrix is computed using the QR-decomposition, 1071 * unless a more efficient method can be determined for the input matrix. 1072 * <p> 1073 * Note: this method will use a singularity threshold of 0, 1074 * use {@link #inverse(RealMatrix, double)} if a different threshold is needed. 1075 * 1076 * @param matrix Matrix whose inverse shall be computed 1077 * @return the inverse of {@code matrix} 1078 * @throws NullArgumentException if {@code matrix} is {@code null} 1079 * @throws SingularMatrixException if m is singular 1080 * @throws NonSquareMatrixException if matrix is not square 1081 * @since 3.3 1082 */ 1083 public static RealMatrix inverse(RealMatrix matrix) 1084 throws NullArgumentException, SingularMatrixException, NonSquareMatrixException { 1085 return inverse(matrix, 0); 1086 } 1087 1088 /** 1089 * Computes the inverse of the given matrix. 1090 * <p> 1091 * By default, the inverse of the matrix is computed using the QR-decomposition, 1092 * unless a more efficient method can be determined for the input matrix. 1093 * 1094 * @param matrix Matrix whose inverse shall be computed 1095 * @param threshold Singularity threshold 1096 * @return the inverse of {@code m} 1097 * @throws NullArgumentException if {@code matrix} is {@code null} 1098 * @throws SingularMatrixException if matrix is singular 1099 * @throws NonSquareMatrixException if matrix is not square 1100 * @since 3.3 1101 */ 1102 public static RealMatrix inverse(RealMatrix matrix, double threshold) 1103 throws NullArgumentException, SingularMatrixException, NonSquareMatrixException { 1104 1105 MathUtils.checkNotNull(matrix); 1106 1107 if (!matrix.isSquare()) { 1108 throw new NonSquareMatrixException(matrix.getRowDimension(), 1109 matrix.getColumnDimension()); 1110 } 1111 1112 if (matrix instanceof DiagonalMatrix) { 1113 return ((DiagonalMatrix) matrix).inverse(threshold); 1114 } else { 1115 QRDecomposition decomposition = new QRDecomposition(matrix, threshold); 1116 return decomposition.getSolver().getInverse(); 1117 } 1118 } 1119}