001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math4.legacy.linear; 019 020import java.io.IOException; 021import java.io.ObjectInputStream; 022import java.io.ObjectOutputStream; 023import java.util.Arrays; 024 025import org.apache.commons.math4.legacy.core.Field; 026import org.apache.commons.math4.legacy.core.FieldElement; 027import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 028import org.apache.commons.math4.legacy.exception.MathArithmeticException; 029import org.apache.commons.math4.legacy.exception.NoDataException; 030import org.apache.commons.math4.legacy.exception.NullArgumentException; 031import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; 032import org.apache.commons.math4.legacy.exception.OutOfRangeException; 033import org.apache.commons.math4.legacy.exception.ZeroException; 034import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; 035import org.apache.commons.math4.core.jdkmath.JdkMath; 036import org.apache.commons.math4.legacy.core.MathArrays; 037import org.apache.commons.numbers.core.Precision; 038 039/** 040 * A collection of static methods that operate on or return matrices. 041 * 042 */ 043public final class MatrixUtils { 044 045 /** 046 * The default format for {@link RealMatrix} objects. 047 * @since 3.1 048 */ 049 public static final RealMatrixFormat DEFAULT_FORMAT = RealMatrixFormat.getInstance(); 050 051 /** 052 * A format for {@link RealMatrix} objects compatible with octave. 053 * @since 3.1 054 */ 055 public static final RealMatrixFormat OCTAVE_FORMAT = new RealMatrixFormat("[", "]", "", "", "; ", ", "); 056 057 /** 058 * Private constructor. 059 */ 060 private MatrixUtils() { 061 super(); 062 } 063 064 /** 065 * Returns a {@link RealMatrix} with specified dimensions. 066 * <p>The type of matrix returned depends on the dimension. Below 067 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 068 * square matrix) which can be stored in a 32kB array, a {@link 069 * Array2DRowRealMatrix} instance is built. Above this threshold a {@link 070 * BlockRealMatrix} instance is built.</p> 071 * <p>The matrix elements are all set to 0.0.</p> 072 * @param rows number of rows of the matrix 073 * @param columns number of columns of the matrix 074 * @return RealMatrix with specified dimensions 075 * @see #createRealMatrix(double[][]) 076 */ 077 public static RealMatrix createRealMatrix(final int rows, final int columns) { 078 return (rows * columns <= 4096) ? 079 new Array2DRowRealMatrix(rows, columns) : new BlockRealMatrix(rows, columns); 080 } 081 082 /** 083 * Returns a {@link FieldMatrix} with specified dimensions. 084 * <p>The type of matrix returned depends on the dimension. Below 085 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 086 * square matrix), a {@link FieldMatrix} instance is built. Above 087 * this threshold a {@link BlockFieldMatrix} instance is built.</p> 088 * <p>The matrix elements are all set to field.getZero().</p> 089 * @param <T> the type of the field elements 090 * @param field field to which the matrix elements belong 091 * @param rows number of rows of the matrix 092 * @param columns number of columns of the matrix 093 * @return FieldMatrix with specified dimensions 094 * @see #createFieldMatrix(FieldElement[][]) 095 * @since 2.0 096 */ 097 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(final Field<T> field, 098 final int rows, 099 final int columns) { 100 return (rows * columns <= 4096) ? 101 new Array2DRowFieldMatrix<>(field, rows, columns) : new BlockFieldMatrix<>(field, rows, columns); 102 } 103 104 /** 105 * Returns a {@link RealMatrix} whose entries are the values in the 106 * the input array. 107 * <p>The type of matrix returned depends on the dimension. Below 108 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 109 * square matrix) which can be stored in a 32kB array, a {@link 110 * Array2DRowRealMatrix} instance is built. Above this threshold a {@link 111 * BlockRealMatrix} instance is built.</p> 112 * <p>The input array is copied, not referenced.</p> 113 * 114 * @param data input array 115 * @return RealMatrix containing the values of the array 116 * @throws org.apache.commons.math4.legacy.exception.DimensionMismatchException 117 * if {@code data} is not rectangular (not all rows have the same length). 118 * @throws NoDataException if a row or column is empty. 119 * @throws NullArgumentException if either {@code data} or {@code data[0]} 120 * is {@code null}. 121 * @throws DimensionMismatchException if {@code data} is not rectangular. 122 * @see #createRealMatrix(int, int) 123 */ 124 public static RealMatrix createRealMatrix(double[][] data) 125 throws NullArgumentException, DimensionMismatchException, 126 NoDataException { 127 if (data == null || 128 data[0] == null) { 129 throw new NullArgumentException(); 130 } 131 return (data.length * data[0].length <= 4096) ? 132 new Array2DRowRealMatrix(data) : new BlockRealMatrix(data); 133 } 134 135 /** 136 * Returns a {@link FieldMatrix} whose entries are the values in the 137 * the input array. 138 * <p>The type of matrix returned depends on the dimension. Below 139 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 140 * square matrix), a {@link FieldMatrix} instance is built. Above 141 * this threshold a {@link BlockFieldMatrix} instance is built.</p> 142 * <p>The input array is copied, not referenced.</p> 143 * @param <T> the type of the field elements 144 * @param data input array 145 * @return a matrix containing the values of the array. 146 * @throws org.apache.commons.math4.legacy.exception.DimensionMismatchException 147 * if {@code data} is not rectangular (not all rows have the same length). 148 * @throws NoDataException if a row or column is empty. 149 * @throws NullArgumentException if either {@code data} or {@code data[0]} 150 * is {@code null}. 151 * @see #createFieldMatrix(Field, int, int) 152 * @since 2.0 153 */ 154 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(T[][] data) 155 throws DimensionMismatchException, NoDataException, NullArgumentException { 156 if (data == null || 157 data[0] == null) { 158 throw new NullArgumentException(); 159 } 160 return (data.length * data[0].length <= 4096) ? 161 new Array2DRowFieldMatrix<>(data) : new BlockFieldMatrix<>(data); 162 } 163 164 /** 165 * Returns <code>dimension x dimension</code> identity matrix. 166 * 167 * @param dimension dimension of identity matrix to generate 168 * @return identity matrix 169 * @throws IllegalArgumentException if dimension is not positive 170 * @since 1.1 171 */ 172 public static RealMatrix createRealIdentityMatrix(int dimension) { 173 final RealMatrix m = createRealMatrix(dimension, dimension); 174 for (int i = 0; i < dimension; ++i) { 175 m.setEntry(i, i, 1.0); 176 } 177 return m; 178 } 179 180 /** 181 * Returns <code>dimension x dimension</code> identity matrix. 182 * 183 * @param <T> the type of the field elements 184 * @param field field to which the elements belong 185 * @param dimension dimension of identity matrix to generate 186 * @return identity matrix 187 * @throws IllegalArgumentException if dimension is not positive 188 * @since 2.0 189 */ 190 public static <T extends FieldElement<T>> FieldMatrix<T> 191 createFieldIdentityMatrix(final Field<T> field, final int dimension) { 192 final T zero = field.getZero(); 193 final T one = field.getOne(); 194 final T[][] d = MathArrays.buildArray(field, dimension, dimension); 195 for (int row = 0; row < dimension; row++) { 196 final T[] dRow = d[row]; 197 Arrays.fill(dRow, zero); 198 dRow[row] = one; 199 } 200 return new Array2DRowFieldMatrix<>(field, d, false); 201 } 202 203 /** 204 * Creates a diagonal matrix with the specified diagonal elements. 205 * 206 * @param diagonal Diagonal elements of the matrix. 207 * The array elements will be copied. 208 * @return a diagonal matrix instance. 209 * 210 * @see #createRealMatrixWithDiagonal(double[]) 211 * @since 2.0 212 */ 213 public static DiagonalMatrix createRealDiagonalMatrix(final double[] diagonal) { 214 return new DiagonalMatrix(diagonal, true); 215 } 216 217 /** 218 * Creates a dense matrix with the specified diagonal elements. 219 * 220 * @param diagonal Diagonal elements of the matrix. 221 * @return a matrix instance. 222 * 223 * @see #createRealDiagonalMatrix(double[]) 224 * @since 4.0 225 */ 226 public static RealMatrix createRealMatrixWithDiagonal(final double[] diagonal) { 227 final int size = diagonal.length; 228 final RealMatrix m = createRealMatrix(size, size); 229 for (int i = 0; i < size; i++) { 230 m.setEntry(i, i, diagonal[i]); 231 } 232 return m; 233 } 234 235 /** 236 * Returns a diagonal matrix with specified elements. 237 * 238 * @param <T> the type of the field elements 239 * @param diagonal diagonal elements of the matrix (the array elements 240 * will be copied) 241 * @return diagonal matrix 242 * @since 2.0 243 */ 244 public static <T extends FieldElement<T>> FieldMatrix<T> 245 createFieldDiagonalMatrix(final T[] diagonal) { 246 final FieldMatrix<T> m = 247 createFieldMatrix(diagonal[0].getField(), diagonal.length, diagonal.length); 248 for (int i = 0; i < diagonal.length; ++i) { 249 m.setEntry(i, i, diagonal[i]); 250 } 251 return m; 252 } 253 254 /** 255 * Creates a {@link RealVector} using the data from the input array. 256 * 257 * @param data the input data 258 * @return a data.length RealVector 259 * @throws NoDataException if {@code data} is empty. 260 * @throws NullArgumentException if {@code data} is {@code null}. 261 */ 262 public static RealVector createRealVector(double[] data) 263 throws NoDataException, NullArgumentException { 264 if (data == null) { 265 throw new NullArgumentException(); 266 } 267 return new ArrayRealVector(data, true); 268 } 269 270 /** 271 * Creates a {@link FieldVector} using the data from the input array. 272 * 273 * @param <T> the type of the field elements 274 * @param data the input data 275 * @return a data.length FieldVector 276 * @throws NoDataException if {@code data} is empty. 277 * @throws NullArgumentException if {@code data} is {@code null}. 278 * @throws ZeroException if {@code data} has 0 elements 279 */ 280 public static <T extends FieldElement<T>> FieldVector<T> createFieldVector(final T[] data) 281 throws NoDataException, NullArgumentException, ZeroException { 282 if (data == null) { 283 throw new NullArgumentException(); 284 } 285 if (data.length == 0) { 286 throw new ZeroException(LocalizedFormats.VECTOR_MUST_HAVE_AT_LEAST_ONE_ELEMENT); 287 } 288 return new ArrayFieldVector<>(data[0].getField(), data, true); 289 } 290 291 /** 292 * Create a row {@link RealMatrix} using the data from the input 293 * array. 294 * 295 * @param rowData the input row data 296 * @return a 1 x rowData.length RealMatrix 297 * @throws NoDataException if {@code rowData} is empty. 298 * @throws NullArgumentException if {@code rowData} is {@code null}. 299 */ 300 public static RealMatrix createRowRealMatrix(double[] rowData) 301 throws NoDataException, NullArgumentException { 302 if (rowData == null) { 303 throw new NullArgumentException(); 304 } 305 final int nCols = rowData.length; 306 final RealMatrix m = createRealMatrix(1, nCols); 307 for (int i = 0; i < nCols; ++i) { 308 m.setEntry(0, i, rowData[i]); 309 } 310 return m; 311 } 312 313 /** 314 * Create a row {@link FieldMatrix} using the data from the input 315 * array. 316 * 317 * @param <T> the type of the field elements 318 * @param rowData the input row data 319 * @return a 1 x rowData.length FieldMatrix 320 * @throws NoDataException if {@code rowData} is empty. 321 * @throws NullArgumentException if {@code rowData} is {@code null}. 322 */ 323 public static <T extends FieldElement<T>> FieldMatrix<T> 324 createRowFieldMatrix(final T[] rowData) 325 throws NoDataException, NullArgumentException { 326 if (rowData == null) { 327 throw new NullArgumentException(); 328 } 329 final int nCols = rowData.length; 330 if (nCols == 0) { 331 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_COLUMN); 332 } 333 final FieldMatrix<T> m = createFieldMatrix(rowData[0].getField(), 1, nCols); 334 for (int i = 0; i < nCols; ++i) { 335 m.setEntry(0, i, rowData[i]); 336 } 337 return m; 338 } 339 340 /** 341 * Creates a column {@link RealMatrix} using the data from the input 342 * array. 343 * 344 * @param columnData the input column data 345 * @return a columnData x 1 RealMatrix 346 * @throws NoDataException if {@code columnData} is empty. 347 * @throws NullArgumentException if {@code columnData} is {@code null}. 348 */ 349 public static RealMatrix createColumnRealMatrix(double[] columnData) 350 throws NoDataException, NullArgumentException { 351 if (columnData == null) { 352 throw new NullArgumentException(); 353 } 354 final int nRows = columnData.length; 355 final RealMatrix m = createRealMatrix(nRows, 1); 356 for (int i = 0; i < nRows; ++i) { 357 m.setEntry(i, 0, columnData[i]); 358 } 359 return m; 360 } 361 362 /** 363 * Creates a column {@link FieldMatrix} using the data from the input 364 * array. 365 * 366 * @param <T> the type of the field elements 367 * @param columnData the input column data 368 * @return a columnData x 1 FieldMatrix 369 * @throws NoDataException if {@code data} is empty. 370 * @throws NullArgumentException if {@code columnData} is {@code null}. 371 */ 372 public static <T extends FieldElement<T>> FieldMatrix<T> 373 createColumnFieldMatrix(final T[] columnData) 374 throws NoDataException, NullArgumentException { 375 if (columnData == null) { 376 throw new NullArgumentException(); 377 } 378 final int nRows = columnData.length; 379 if (nRows == 0) { 380 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_ROW); 381 } 382 final FieldMatrix<T> m = createFieldMatrix(columnData[0].getField(), nRows, 1); 383 for (int i = 0; i < nRows; ++i) { 384 m.setEntry(i, 0, columnData[i]); 385 } 386 return m; 387 } 388 389 /** 390 * Checks whether a matrix is symmetric, within a given relative tolerance. 391 * 392 * @param matrix Matrix to check. 393 * @param relativeTolerance Tolerance of the symmetry check. 394 * @param raiseException If {@code true}, an exception will be raised if 395 * the matrix is not symmetric. 396 * @return {@code true} if {@code matrix} is symmetric. 397 * @throws NonSquareMatrixException if the matrix is not square. 398 * @throws NonSymmetricMatrixException if the matrix is not symmetric. 399 */ 400 private static boolean isSymmetricInternal(RealMatrix matrix, 401 double relativeTolerance, 402 boolean raiseException) { 403 final int rows = matrix.getRowDimension(); 404 if (rows != matrix.getColumnDimension()) { 405 if (raiseException) { 406 throw new NonSquareMatrixException(rows, matrix.getColumnDimension()); 407 } else { 408 return false; 409 } 410 } 411 for (int i = 0; i < rows; i++) { 412 for (int j = i + 1; j < rows; j++) { 413 final double mij = matrix.getEntry(i, j); 414 final double mji = matrix.getEntry(j, i); 415 if (JdkMath.abs(mij - mji) > 416 JdkMath.max(JdkMath.abs(mij), JdkMath.abs(mji)) * relativeTolerance) { 417 if (raiseException) { 418 throw new NonSymmetricMatrixException(i, j, relativeTolerance); 419 } else { 420 return false; 421 } 422 } 423 } 424 } 425 return true; 426 } 427 428 /** 429 * Checks whether a matrix is symmetric. 430 * 431 * @param matrix Matrix to check. 432 * @param eps Relative tolerance. 433 * @throws NonSquareMatrixException if the matrix is not square. 434 * @throws NonSymmetricMatrixException if the matrix is not symmetric. 435 * @since 3.1 436 */ 437 public static void checkSymmetric(RealMatrix matrix, 438 double eps) { 439 isSymmetricInternal(matrix, eps, true); 440 } 441 442 /** 443 * Checks whether a matrix is symmetric. 444 * 445 * @param matrix Matrix to check. 446 * @param eps Relative tolerance. 447 * @return {@code true} if {@code matrix} is symmetric. 448 * @since 3.1 449 */ 450 public static boolean isSymmetric(RealMatrix matrix, 451 double eps) { 452 return isSymmetricInternal(matrix, eps, false); 453 } 454 455 /** 456 * Check if matrix indices are valid. 457 * 458 * @param m Matrix. 459 * @param row Row index to check. 460 * @param column Column index to check. 461 * @throws OutOfRangeException if {@code row} or {@code column} is not 462 * a valid index. 463 */ 464 public static void checkMatrixIndex(final AnyMatrix m, 465 final int row, final int column) 466 throws OutOfRangeException { 467 checkRowIndex(m, row); 468 checkColumnIndex(m, column); 469 } 470 471 /** 472 * Check if a row index is valid. 473 * 474 * @param m Matrix. 475 * @param row Row index to check. 476 * @throws OutOfRangeException if {@code row} is not a valid index. 477 */ 478 public static void checkRowIndex(final AnyMatrix m, final int row) 479 throws OutOfRangeException { 480 if (row < 0 || 481 row >= m.getRowDimension()) { 482 throw new OutOfRangeException(LocalizedFormats.ROW_INDEX, 483 row, 0, m.getRowDimension() - 1); 484 } 485 } 486 487 /** 488 * Check if a column index is valid. 489 * 490 * @param m Matrix. 491 * @param column Column index to check. 492 * @throws OutOfRangeException if {@code column} is not a valid index. 493 */ 494 public static void checkColumnIndex(final AnyMatrix m, final int column) 495 throws OutOfRangeException { 496 if (column < 0 || column >= m.getColumnDimension()) { 497 throw new OutOfRangeException(LocalizedFormats.COLUMN_INDEX, 498 column, 0, m.getColumnDimension() - 1); 499 } 500 } 501 502 /** 503 * Check if submatrix ranges indices are valid. 504 * Rows and columns are indicated counting from 0 to {@code n - 1}. 505 * 506 * @param m Matrix. 507 * @param startRow Initial row index. 508 * @param endRow Final row index. 509 * @param startColumn Initial column index. 510 * @param endColumn Final column index. 511 * @throws OutOfRangeException if the indices are invalid. 512 * @throws NumberIsTooSmallException if {@code endRow < startRow} or 513 * {@code endColumn < startColumn}. 514 */ 515 public static void checkSubMatrixIndex(final AnyMatrix m, 516 final int startRow, final int endRow, 517 final int startColumn, final int endColumn) 518 throws NumberIsTooSmallException, OutOfRangeException { 519 checkRowIndex(m, startRow); 520 checkRowIndex(m, endRow); 521 if (endRow < startRow) { 522 throw new NumberIsTooSmallException(LocalizedFormats.INITIAL_ROW_AFTER_FINAL_ROW, 523 endRow, startRow, false); 524 } 525 526 checkColumnIndex(m, startColumn); 527 checkColumnIndex(m, endColumn); 528 if (endColumn < startColumn) { 529 throw new NumberIsTooSmallException(LocalizedFormats.INITIAL_COLUMN_AFTER_FINAL_COLUMN, 530 endColumn, startColumn, false); 531 } 532 } 533 534 /** 535 * Check if submatrix ranges indices are valid. 536 * Rows and columns are indicated counting from 0 to n-1. 537 * 538 * @param m Matrix. 539 * @param selectedRows Array of row indices. 540 * @param selectedColumns Array of column indices. 541 * @throws NullArgumentException if {@code selectedRows} or 542 * {@code selectedColumns} are {@code null}. 543 * @throws NoDataException if the row or column selections are empty (zero 544 * length). 545 * @throws OutOfRangeException if row or column selections are not valid. 546 */ 547 public static void checkSubMatrixIndex(final AnyMatrix m, 548 final int[] selectedRows, 549 final int[] selectedColumns) 550 throws NoDataException, NullArgumentException, OutOfRangeException { 551 if (selectedRows == null) { 552 throw new NullArgumentException(); 553 } 554 if (selectedColumns == null) { 555 throw new NullArgumentException(); 556 } 557 if (selectedRows.length == 0) { 558 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_ROW_INDEX_ARRAY); 559 } 560 if (selectedColumns.length == 0) { 561 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_COLUMN_INDEX_ARRAY); 562 } 563 564 for (final int row : selectedRows) { 565 checkRowIndex(m, row); 566 } 567 for (final int column : selectedColumns) { 568 checkColumnIndex(m, column); 569 } 570 } 571 572 /** 573 * Check if matrices are addition compatible. 574 * 575 * @param left Left hand side matrix. 576 * @param right Right hand side matrix. 577 * @throws MatrixDimensionMismatchException if the matrices are not addition 578 * compatible. 579 */ 580 public static void checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right) { 581 left.checkAdd(right); 582 } 583 584 /** 585 * Check if matrices are subtraction compatible. 586 * 587 * @param left Left hand side matrix. 588 * @param right Right hand side matrix. 589 * @throws MatrixDimensionMismatchException if the matrices are not addition 590 * compatible. 591 */ 592 public static void checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right) { 593 left.checkAdd(right); 594 } 595 596 /** 597 * Check if matrices are multiplication compatible. 598 * 599 * @param left Left hand side matrix. 600 * @param right Right hand side matrix. 601 * @throws DimensionMismatchException if matrices are not multiplication 602 * compatible. 603 */ 604 public static void checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right) { 605 left.checkMultiply(right); 606 } 607 608 /** Serialize a {@link RealVector}. 609 * <p> 610 * This method is intended to be called from within a private 611 * <code>writeObject</code> method (after a call to 612 * <code>oos.defaultWriteObject()</code>) in a class that has a 613 * {@link RealVector} field, which should be declared <code>transient</code>. 614 * This way, the default handling does not serialize the vector (the {@link 615 * RealVector} interface is not serializable by default) but this method does 616 * serialize it specifically. 617 * </p> 618 * <p> 619 * The following example shows how a simple class with a name and a real vector 620 * should be written: 621 * <pre><code> 622 * public class NamedVector implements Serializable { 623 * 624 * private final String name; 625 * private final transient RealVector coefficients; 626 * 627 * // omitted constructors, getters ... 628 * 629 * private void writeObject(ObjectOutputStream oos) throws IOException { 630 * oos.defaultWriteObject(); // takes care of name field 631 * MatrixUtils.serializeRealVector(coefficients, oos); 632 * } 633 * 634 * private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 635 * ois.defaultReadObject(); // takes care of name field 636 * MatrixUtils.deserializeRealVector(this, "coefficients", ois); 637 * } 638 * 639 * } 640 * </code></pre> 641 * 642 * @param vector real vector to serialize 643 * @param oos stream where the real vector should be written 644 * @exception IOException if object cannot be written to stream 645 * @see #deserializeRealVector(Object, String, ObjectInputStream) 646 */ 647 public static void serializeRealVector(final RealVector vector, 648 final ObjectOutputStream oos) 649 throws IOException { 650 final int n = vector.getDimension(); 651 oos.writeInt(n); 652 for (int i = 0; i < n; ++i) { 653 oos.writeDouble(vector.getEntry(i)); 654 } 655 } 656 657 /** Deserialize a {@link RealVector} field in a class. 658 * <p> 659 * This method is intended to be called from within a private 660 * <code>readObject</code> method (after a call to 661 * <code>ois.defaultReadObject()</code>) in a class that has a 662 * {@link RealVector} field, which should be declared <code>transient</code>. 663 * This way, the default handling does not deserialize the vector (the {@link 664 * RealVector} interface is not serializable by default) but this method does 665 * deserialize it specifically. 666 * </p> 667 * @param instance instance in which the field must be set up 668 * @param fieldName name of the field within the class (may be private and final) 669 * @param ois stream from which the real vector should be read 670 * @exception ClassNotFoundException if a class in the stream cannot be found 671 * @exception IOException if object cannot be read from the stream 672 * @see #serializeRealVector(RealVector, ObjectOutputStream) 673 */ 674 public static void deserializeRealVector(final Object instance, 675 final String fieldName, 676 final ObjectInputStream ois) 677 throws ClassNotFoundException, IOException { 678 try { 679 680 // read the vector data 681 final int n = ois.readInt(); 682 final double[] data = new double[n]; 683 for (int i = 0; i < n; ++i) { 684 data[i] = ois.readDouble(); 685 } 686 687 // create the instance 688 final RealVector vector = new ArrayRealVector(data, false); 689 690 // set up the field 691 final java.lang.reflect.Field f = 692 instance.getClass().getDeclaredField(fieldName); 693 f.setAccessible(true); 694 f.set(instance, vector); 695 } catch (NoSuchFieldException nsfe) { 696 IOException ioe = new IOException(); 697 ioe.initCause(nsfe); 698 throw ioe; 699 } catch (IllegalAccessException iae) { 700 IOException ioe = new IOException(); 701 ioe.initCause(iae); 702 throw ioe; 703 } 704 } 705 706 /** Serialize a {@link RealMatrix}. 707 * <p> 708 * This method is intended to be called from within a private 709 * <code>writeObject</code> method (after a call to 710 * <code>oos.defaultWriteObject()</code>) in a class that has a 711 * {@link RealMatrix} field, which should be declared <code>transient</code>. 712 * This way, the default handling does not serialize the matrix (the {@link 713 * RealMatrix} interface is not serializable by default) but this method does 714 * serialize it specifically. 715 * </p> 716 * <p> 717 * The following example shows how a simple class with a name and a real matrix 718 * should be written: 719 * <pre><code> 720 * public class NamedMatrix implements Serializable { 721 * 722 * private final String name; 723 * private final transient RealMatrix coefficients; 724 * 725 * // omitted constructors, getters ... 726 * 727 * private void writeObject(ObjectOutputStream oos) throws IOException { 728 * oos.defaultWriteObject(); // takes care of name field 729 * MatrixUtils.serializeRealMatrix(coefficients, oos); 730 * } 731 * 732 * private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 733 * ois.defaultReadObject(); // takes care of name field 734 * MatrixUtils.deserializeRealMatrix(this, "coefficients", ois); 735 * } 736 * 737 * } 738 * </code></pre> 739 * 740 * @param matrix real matrix to serialize 741 * @param oos stream where the real matrix should be written 742 * @exception IOException if object cannot be written to stream 743 * @see #deserializeRealMatrix(Object, String, ObjectInputStream) 744 */ 745 public static void serializeRealMatrix(final RealMatrix matrix, 746 final ObjectOutputStream oos) 747 throws IOException { 748 final int n = matrix.getRowDimension(); 749 final int m = matrix.getColumnDimension(); 750 oos.writeInt(n); 751 oos.writeInt(m); 752 for (int i = 0; i < n; ++i) { 753 for (int j = 0; j < m; ++j) { 754 oos.writeDouble(matrix.getEntry(i, j)); 755 } 756 } 757 } 758 759 /** Deserialize a {@link RealMatrix} field in a class. 760 * <p> 761 * This method is intended to be called from within a private 762 * <code>readObject</code> method (after a call to 763 * <code>ois.defaultReadObject()</code>) in a class that has a 764 * {@link RealMatrix} field, which should be declared <code>transient</code>. 765 * This way, the default handling does not deserialize the matrix (the {@link 766 * RealMatrix} interface is not serializable by default) but this method does 767 * deserialize it specifically. 768 * </p> 769 * @param instance instance in which the field must be set up 770 * @param fieldName name of the field within the class (may be private and final) 771 * @param ois stream from which the real matrix should be read 772 * @exception ClassNotFoundException if a class in the stream cannot be found 773 * @exception IOException if object cannot be read from the stream 774 * @see #serializeRealMatrix(RealMatrix, ObjectOutputStream) 775 */ 776 public static void deserializeRealMatrix(final Object instance, 777 final String fieldName, 778 final ObjectInputStream ois) 779 throws ClassNotFoundException, IOException { 780 try { 781 782 // read the matrix data 783 final int n = ois.readInt(); 784 final int m = ois.readInt(); 785 final double[][] data = new double[n][m]; 786 for (int i = 0; i < n; ++i) { 787 final double[] dataI = data[i]; 788 for (int j = 0; j < m; ++j) { 789 dataI[j] = ois.readDouble(); 790 } 791 } 792 793 // create the instance 794 final RealMatrix matrix = new Array2DRowRealMatrix(data, false); 795 796 // set up the field 797 final java.lang.reflect.Field f = 798 instance.getClass().getDeclaredField(fieldName); 799 f.setAccessible(true); 800 f.set(instance, matrix); 801 } catch (NoSuchFieldException nsfe) { 802 IOException ioe = new IOException(); 803 ioe.initCause(nsfe); 804 throw ioe; 805 } catch (IllegalAccessException iae) { 806 IOException ioe = new IOException(); 807 ioe.initCause(iae); 808 throw ioe; 809 } 810 } 811 812 /**Solve a system of composed of a Lower Triangular Matrix 813 * {@link RealMatrix}. 814 * <p> 815 * This method is called to solve systems of equations which are 816 * of the lower triangular form. The matrix {@link RealMatrix} 817 * is assumed, though not checked, to be in lower triangular form. 818 * The vector {@link RealVector} is overwritten with the solution. 819 * The matrix is checked that it is square and its dimensions match 820 * the length of the vector. 821 * </p> 822 * @param rm RealMatrix which is lower triangular 823 * @param b RealVector this is overwritten 824 * @throws DimensionMismatchException if the matrix and vector are not 825 * conformable 826 * @throws NonSquareMatrixException if the matrix {@code rm} is not square 827 * @throws MathArithmeticException if the absolute value of one of the diagonal 828 * coefficient of {@code rm} is lower than {@link Precision#SAFE_MIN} 829 */ 830 public static void solveLowerTriangularSystem(RealMatrix rm, RealVector b) 831 throws DimensionMismatchException, MathArithmeticException, 832 NonSquareMatrixException { 833 if (rm == null || b == null || rm.getRowDimension() != b.getDimension()) { 834 throw new DimensionMismatchException( 835 (rm == null) ? 0 : rm.getRowDimension(), 836 (b == null) ? 0 : b.getDimension()); 837 } 838 if( rm.getColumnDimension() != rm.getRowDimension() ){ 839 throw new NonSquareMatrixException(rm.getRowDimension(), 840 rm.getColumnDimension()); 841 } 842 int rows = rm.getRowDimension(); 843 for( int i = 0 ; i < rows ; i++ ){ 844 double diag = rm.getEntry(i, i); 845 if( JdkMath.abs(diag) < Precision.SAFE_MIN ){ 846 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR); 847 } 848 double bi = b.getEntry(i)/diag; 849 b.setEntry(i, bi ); 850 for( int j = i+1; j< rows; j++ ){ 851 b.setEntry(j, b.getEntry(j)-bi*rm.getEntry(j,i) ); 852 } 853 } 854 } 855 856 /** Solver a system composed of an Upper Triangular Matrix 857 * {@link RealMatrix}. 858 * <p> 859 * This method is called to solve systems of equations which are 860 * of the lower triangular form. The matrix {@link RealMatrix} 861 * is assumed, though not checked, to be in upper triangular form. 862 * The vector {@link RealVector} is overwritten with the solution. 863 * The matrix is checked that it is square and its dimensions match 864 * the length of the vector. 865 * </p> 866 * @param rm RealMatrix which is upper triangular 867 * @param b RealVector this is overwritten 868 * @throws DimensionMismatchException if the matrix and vector are not 869 * conformable 870 * @throws NonSquareMatrixException if the matrix {@code rm} is not 871 * square 872 * @throws MathArithmeticException if the absolute value of one of the diagonal 873 * coefficient of {@code rm} is lower than {@link Precision#SAFE_MIN} 874 */ 875 public static void solveUpperTriangularSystem(RealMatrix rm, RealVector b) 876 throws DimensionMismatchException, MathArithmeticException, 877 NonSquareMatrixException { 878 if (rm == null || b == null || rm.getRowDimension() != b.getDimension()) { 879 throw new DimensionMismatchException( 880 (rm == null) ? 0 : rm.getRowDimension(), 881 (b == null) ? 0 : b.getDimension()); 882 } 883 if( rm.getColumnDimension() != rm.getRowDimension() ){ 884 throw new NonSquareMatrixException(rm.getRowDimension(), 885 rm.getColumnDimension()); 886 } 887 int rows = rm.getRowDimension(); 888 for( int i = rows-1 ; i >-1 ; i-- ){ 889 double diag = rm.getEntry(i, i); 890 if( JdkMath.abs(diag) < Precision.SAFE_MIN ){ 891 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR); 892 } 893 double bi = b.getEntry(i)/diag; 894 b.setEntry(i, bi ); 895 for( int j = i-1; j>-1; j-- ){ 896 b.setEntry(j, b.getEntry(j)-bi*rm.getEntry(j,i) ); 897 } 898 } 899 } 900 901 /** 902 * Computes the inverse of the given matrix by splitting it into 903 * 4 sub-matrices. 904 * 905 * @param m Matrix whose inverse must be computed. 906 * @param splitIndex Index that determines the "split" line and 907 * column. 908 * The element corresponding to this index will part of the 909 * upper-left sub-matrix. 910 * @return the inverse of {@code m}. 911 * @throws NonSquareMatrixException if {@code m} is not square. 912 */ 913 public static RealMatrix blockInverse(RealMatrix m, 914 int splitIndex) { 915 final int n = m.getRowDimension(); 916 if (m.getColumnDimension() != n) { 917 throw new NonSquareMatrixException(m.getRowDimension(), 918 m.getColumnDimension()); 919 } 920 921 final int splitIndex1 = splitIndex + 1; 922 923 final RealMatrix a = m.getSubMatrix(0, splitIndex, 0, splitIndex); 924 final RealMatrix b = m.getSubMatrix(0, splitIndex, splitIndex1, n - 1); 925 final RealMatrix c = m.getSubMatrix(splitIndex1, n - 1, 0, splitIndex); 926 final RealMatrix d = m.getSubMatrix(splitIndex1, n - 1, splitIndex1, n - 1); 927 928 final SingularValueDecomposition aDec = new SingularValueDecomposition(a); 929 final DecompositionSolver aSolver = aDec.getSolver(); 930 if (!aSolver.isNonSingular()) { 931 throw new SingularMatrixException(); 932 } 933 final RealMatrix aInv = aSolver.getInverse(); 934 935 final SingularValueDecomposition dDec = new SingularValueDecomposition(d); 936 final DecompositionSolver dSolver = dDec.getSolver(); 937 if (!dSolver.isNonSingular()) { 938 throw new SingularMatrixException(); 939 } 940 final RealMatrix dInv = dSolver.getInverse(); 941 942 final RealMatrix tmp1 = a.subtract(b.multiply(dInv).multiply(c)); 943 final SingularValueDecomposition tmp1Dec = new SingularValueDecomposition(tmp1); 944 final DecompositionSolver tmp1Solver = tmp1Dec.getSolver(); 945 if (!tmp1Solver.isNonSingular()) { 946 throw new SingularMatrixException(); 947 } 948 final RealMatrix result00 = tmp1Solver.getInverse(); 949 950 final RealMatrix tmp2 = d.subtract(c.multiply(aInv).multiply(b)); 951 final SingularValueDecomposition tmp2Dec = new SingularValueDecomposition(tmp2); 952 final DecompositionSolver tmp2Solver = tmp2Dec.getSolver(); 953 if (!tmp2Solver.isNonSingular()) { 954 throw new SingularMatrixException(); 955 } 956 final RealMatrix result11 = tmp2Solver.getInverse(); 957 958 final RealMatrix result01 = aInv.multiply(b).multiply(result11).scalarMultiply(-1); 959 final RealMatrix result10 = dInv.multiply(c).multiply(result00).scalarMultiply(-1); 960 961 final RealMatrix result = new Array2DRowRealMatrix(n, n); 962 result.setSubMatrix(result00.getData(), 0, 0); 963 result.setSubMatrix(result01.getData(), 0, splitIndex1); 964 result.setSubMatrix(result10.getData(), splitIndex1, 0); 965 result.setSubMatrix(result11.getData(), splitIndex1, splitIndex1); 966 967 return result; 968 } 969 970 /** 971 * Computes the inverse of the given matrix. 972 * <p> 973 * By default, the inverse of the matrix is computed using the QR-decomposition, 974 * unless a more efficient method can be determined for the input matrix. 975 * <p> 976 * Note: this method will use a singularity threshold of 0, 977 * use {@link #inverse(RealMatrix, double)} if a different threshold is needed. 978 * 979 * @param matrix Matrix whose inverse shall be computed 980 * @return the inverse of {@code matrix} 981 * @throws NullArgumentException if {@code matrix} is {@code null} 982 * @throws SingularMatrixException if m is singular 983 * @throws NonSquareMatrixException if matrix is not square 984 * @since 3.3 985 */ 986 public static RealMatrix inverse(RealMatrix matrix) 987 throws NullArgumentException, SingularMatrixException, NonSquareMatrixException { 988 return inverse(matrix, 0); 989 } 990 991 /** 992 * Computes the inverse of the given matrix. 993 * <p> 994 * By default, the inverse of the matrix is computed using the QR-decomposition, 995 * unless a more efficient method can be determined for the input matrix. 996 * 997 * @param matrix Matrix whose inverse shall be computed 998 * @param threshold Singularity threshold 999 * @return the inverse of {@code m} 1000 * @throws NullArgumentException if {@code matrix} is {@code null} 1001 * @throws SingularMatrixException if matrix is singular 1002 * @throws NonSquareMatrixException if matrix is not square 1003 * @since 3.3 1004 */ 1005 public static RealMatrix inverse(RealMatrix matrix, double threshold) 1006 throws NullArgumentException, SingularMatrixException, NonSquareMatrixException { 1007 1008 NullArgumentException.check(matrix); 1009 1010 if (!matrix.isSquare()) { 1011 throw new NonSquareMatrixException(matrix.getRowDimension(), 1012 matrix.getColumnDimension()); 1013 } 1014 1015 if (matrix instanceof DiagonalMatrix) { 1016 return ((DiagonalMatrix) matrix).inverse(threshold); 1017 } else { 1018 QRDecomposition decomposition = new QRDecomposition(matrix, threshold); 1019 return decomposition.getSolver().getInverse(); 1020 } 1021 } 1022}