1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.linear;
19
20 import java.io.IOException;
21 import java.io.ObjectInputStream;
22 import java.io.ObjectOutputStream;
23 import java.util.Arrays;
24
25 import org.apache.commons.math4.legacy.core.Field;
26 import org.apache.commons.math4.legacy.core.FieldElement;
27 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
28 import org.apache.commons.math4.legacy.exception.MathArithmeticException;
29 import org.apache.commons.math4.legacy.exception.NoDataException;
30 import org.apache.commons.math4.legacy.exception.NullArgumentException;
31 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
32 import org.apache.commons.math4.legacy.exception.OutOfRangeException;
33 import org.apache.commons.math4.legacy.exception.ZeroException;
34 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
35 import org.apache.commons.math4.core.jdkmath.JdkMath;
36 import org.apache.commons.math4.legacy.core.MathArrays;
37 import org.apache.commons.numbers.core.Precision;
38
39
40
41
42
43 public final class MatrixUtils {
44
45
46
47
48
49 public static final RealMatrixFormat DEFAULT_FORMAT = RealMatrixFormat.getInstance();
50
51
52
53
54
55 public static final RealMatrixFormat OCTAVE_FORMAT = new RealMatrixFormat("[", "]", "", "", "; ", ", ");
56
57
58
59
60 private MatrixUtils() {
61 super();
62 }
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77 public static RealMatrix createRealMatrix(final int rows, final int columns) {
78 return (rows * columns <= 4096) ?
79 new Array2DRowRealMatrix(rows, columns) : new BlockRealMatrix(rows, columns);
80 }
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(final Field<T> field,
98 final int rows,
99 final int columns) {
100 return (rows * columns <= 4096) ?
101 new Array2DRowFieldMatrix<>(field, rows, columns) : new BlockFieldMatrix<>(field, rows, columns);
102 }
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124 public static RealMatrix createRealMatrix(double[][] data)
125 throws NullArgumentException, DimensionMismatchException,
126 NoDataException {
127 if (data == null ||
128 data[0] == null) {
129 throw new NullArgumentException();
130 }
131 return (data.length * data[0].length <= 4096) ?
132 new Array2DRowRealMatrix(data) : new BlockRealMatrix(data);
133 }
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(T[][] data)
155 throws DimensionMismatchException, NoDataException, NullArgumentException {
156 if (data == null ||
157 data[0] == null) {
158 throw new NullArgumentException();
159 }
160 return (data.length * data[0].length <= 4096) ?
161 new Array2DRowFieldMatrix<>(data) : new BlockFieldMatrix<>(data);
162 }
163
164
165
166
167
168
169
170
171
172 public static RealMatrix createRealIdentityMatrix(int dimension) {
173 final RealMatrix m = createRealMatrix(dimension, dimension);
174 for (int i = 0; i < dimension; ++i) {
175 m.setEntry(i, i, 1.0);
176 }
177 return m;
178 }
179
180
181
182
183
184
185
186
187
188
189
190 public static <T extends FieldElement<T>> FieldMatrix<T>
191 createFieldIdentityMatrix(final Field<T> field, final int dimension) {
192 final T zero = field.getZero();
193 final T one = field.getOne();
194 final T[][] d = MathArrays.buildArray(field, dimension, dimension);
195 for (int row = 0; row < dimension; row++) {
196 final T[] dRow = d[row];
197 Arrays.fill(dRow, zero);
198 dRow[row] = one;
199 }
200 return new Array2DRowFieldMatrix<>(field, d, false);
201 }
202
203
204
205
206
207
208
209
210
211
212
213 public static DiagonalMatrix createRealDiagonalMatrix(final double[] diagonal) {
214 return new DiagonalMatrix(diagonal, true);
215 }
216
217
218
219
220
221
222
223
224
225
226 public static RealMatrix createRealMatrixWithDiagonal(final double[] diagonal) {
227 final int size = diagonal.length;
228 final RealMatrix m = createRealMatrix(size, size);
229 for (int i = 0; i < size; i++) {
230 m.setEntry(i, i, diagonal[i]);
231 }
232 return m;
233 }
234
235
236
237
238
239
240
241
242
243
244 public static <T extends FieldElement<T>> FieldMatrix<T>
245 createFieldDiagonalMatrix(final T[] diagonal) {
246 final FieldMatrix<T> m =
247 createFieldMatrix(diagonal[0].getField(), diagonal.length, diagonal.length);
248 for (int i = 0; i < diagonal.length; ++i) {
249 m.setEntry(i, i, diagonal[i]);
250 }
251 return m;
252 }
253
254
255
256
257
258
259
260
261
262 public static RealVector createRealVector(double[] data)
263 throws NoDataException, NullArgumentException {
264 if (data == null) {
265 throw new NullArgumentException();
266 }
267 return new ArrayRealVector(data, true);
268 }
269
270
271
272
273
274
275
276
277
278
279
280 public static <T extends FieldElement<T>> FieldVector<T> createFieldVector(final T[] data)
281 throws NoDataException, NullArgumentException, ZeroException {
282 if (data == null) {
283 throw new NullArgumentException();
284 }
285 if (data.length == 0) {
286 throw new ZeroException(LocalizedFormats.VECTOR_MUST_HAVE_AT_LEAST_ONE_ELEMENT);
287 }
288 return new ArrayFieldVector<>(data[0].getField(), data, true);
289 }
290
291
292
293
294
295
296
297
298
299
300 public static RealMatrix createRowRealMatrix(double[] rowData)
301 throws NoDataException, NullArgumentException {
302 if (rowData == null) {
303 throw new NullArgumentException();
304 }
305 final int nCols = rowData.length;
306 final RealMatrix m = createRealMatrix(1, nCols);
307 for (int i = 0; i < nCols; ++i) {
308 m.setEntry(0, i, rowData[i]);
309 }
310 return m;
311 }
312
313
314
315
316
317
318
319
320
321
322
323 public static <T extends FieldElement<T>> FieldMatrix<T>
324 createRowFieldMatrix(final T[] rowData)
325 throws NoDataException, NullArgumentException {
326 if (rowData == null) {
327 throw new NullArgumentException();
328 }
329 final int nCols = rowData.length;
330 if (nCols == 0) {
331 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_COLUMN);
332 }
333 final FieldMatrix<T> m = createFieldMatrix(rowData[0].getField(), 1, nCols);
334 for (int i = 0; i < nCols; ++i) {
335 m.setEntry(0, i, rowData[i]);
336 }
337 return m;
338 }
339
340
341
342
343
344
345
346
347
348
349 public static RealMatrix createColumnRealMatrix(double[] columnData)
350 throws NoDataException, NullArgumentException {
351 if (columnData == null) {
352 throw new NullArgumentException();
353 }
354 final int nRows = columnData.length;
355 final RealMatrix m = createRealMatrix(nRows, 1);
356 for (int i = 0; i < nRows; ++i) {
357 m.setEntry(i, 0, columnData[i]);
358 }
359 return m;
360 }
361
362
363
364
365
366
367
368
369
370
371
372 public static <T extends FieldElement<T>> FieldMatrix<T>
373 createColumnFieldMatrix(final T[] columnData)
374 throws NoDataException, NullArgumentException {
375 if (columnData == null) {
376 throw new NullArgumentException();
377 }
378 final int nRows = columnData.length;
379 if (nRows == 0) {
380 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_ROW);
381 }
382 final FieldMatrix<T> m = createFieldMatrix(columnData[0].getField(), nRows, 1);
383 for (int i = 0; i < nRows; ++i) {
384 m.setEntry(i, 0, columnData[i]);
385 }
386 return m;
387 }
388
389
390
391
392
393
394
395
396
397
398
399
400 private static boolean isSymmetricInternal(RealMatrix matrix,
401 double relativeTolerance,
402 boolean raiseException) {
403 final int rows = matrix.getRowDimension();
404 if (rows != matrix.getColumnDimension()) {
405 if (raiseException) {
406 throw new NonSquareMatrixException(rows, matrix.getColumnDimension());
407 } else {
408 return false;
409 }
410 }
411 for (int i = 0; i < rows; i++) {
412 for (int j = i + 1; j < rows; j++) {
413 final double mij = matrix.getEntry(i, j);
414 final double mji = matrix.getEntry(j, i);
415 if (JdkMath.abs(mij - mji) >
416 JdkMath.max(JdkMath.abs(mij), JdkMath.abs(mji)) * relativeTolerance) {
417 if (raiseException) {
418 throw new NonSymmetricMatrixException(i, j, relativeTolerance);
419 } else {
420 return false;
421 }
422 }
423 }
424 }
425 return true;
426 }
427
428
429
430
431
432
433
434
435
436
437 public static void checkSymmetric(RealMatrix matrix,
438 double eps) {
439 isSymmetricInternal(matrix, eps, true);
440 }
441
442
443
444
445
446
447
448
449
450 public static boolean isSymmetric(RealMatrix matrix,
451 double eps) {
452 return isSymmetricInternal(matrix, eps, false);
453 }
454
455
456
457
458
459
460
461
462
463
464 public static void checkMatrixIndex(final AnyMatrix m,
465 final int row, final int column)
466 throws OutOfRangeException {
467 checkRowIndex(m, row);
468 checkColumnIndex(m, column);
469 }
470
471
472
473
474
475
476
477
478 public static void checkRowIndex(final AnyMatrix m, final int row)
479 throws OutOfRangeException {
480 if (row < 0 ||
481 row >= m.getRowDimension()) {
482 throw new OutOfRangeException(LocalizedFormats.ROW_INDEX,
483 row, 0, m.getRowDimension() - 1);
484 }
485 }
486
487
488
489
490
491
492
493
494 public static void checkColumnIndex(final AnyMatrix m, final int column)
495 throws OutOfRangeException {
496 if (column < 0 || column >= m.getColumnDimension()) {
497 throw new OutOfRangeException(LocalizedFormats.COLUMN_INDEX,
498 column, 0, m.getColumnDimension() - 1);
499 }
500 }
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515 public static void checkSubMatrixIndex(final AnyMatrix m,
516 final int startRow, final int endRow,
517 final int startColumn, final int endColumn)
518 throws NumberIsTooSmallException, OutOfRangeException {
519 checkRowIndex(m, startRow);
520 checkRowIndex(m, endRow);
521 if (endRow < startRow) {
522 throw new NumberIsTooSmallException(LocalizedFormats.INITIAL_ROW_AFTER_FINAL_ROW,
523 endRow, startRow, false);
524 }
525
526 checkColumnIndex(m, startColumn);
527 checkColumnIndex(m, endColumn);
528 if (endColumn < startColumn) {
529 throw new NumberIsTooSmallException(LocalizedFormats.INITIAL_COLUMN_AFTER_FINAL_COLUMN,
530 endColumn, startColumn, false);
531 }
532 }
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547 public static void checkSubMatrixIndex(final AnyMatrix m,
548 final int[] selectedRows,
549 final int[] selectedColumns)
550 throws NoDataException, NullArgumentException, OutOfRangeException {
551 if (selectedRows == null) {
552 throw new NullArgumentException();
553 }
554 if (selectedColumns == null) {
555 throw new NullArgumentException();
556 }
557 if (selectedRows.length == 0) {
558 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_ROW_INDEX_ARRAY);
559 }
560 if (selectedColumns.length == 0) {
561 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_COLUMN_INDEX_ARRAY);
562 }
563
564 for (final int row : selectedRows) {
565 checkRowIndex(m, row);
566 }
567 for (final int column : selectedColumns) {
568 checkColumnIndex(m, column);
569 }
570 }
571
572
573
574
575
576
577
578
579
580 public static void checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right) {
581 left.checkAdd(right);
582 }
583
584
585
586
587
588
589
590
591
592 public static void checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right) {
593 left.checkAdd(right);
594 }
595
596
597
598
599
600
601
602
603
604 public static void checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right) {
605 left.checkMultiply(right);
606 }
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647 public static void serializeRealVector(final RealVector vector,
648 final ObjectOutputStream oos)
649 throws IOException {
650 final int n = vector.getDimension();
651 oos.writeInt(n);
652 for (int i = 0; i < n; ++i) {
653 oos.writeDouble(vector.getEntry(i));
654 }
655 }
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674 public static void deserializeRealVector(final Object instance,
675 final String fieldName,
676 final ObjectInputStream ois)
677 throws ClassNotFoundException, IOException {
678 try {
679
680
681 final int n = ois.readInt();
682 final double[] data = new double[n];
683 for (int i = 0; i < n; ++i) {
684 data[i] = ois.readDouble();
685 }
686
687
688 final RealVector vector = new ArrayRealVector(data, false);
689
690
691 final java.lang.reflect.Field f =
692 instance.getClass().getDeclaredField(fieldName);
693 f.setAccessible(true);
694 f.set(instance, vector);
695 } catch (NoSuchFieldException nsfe) {
696 IOException ioe = new IOException();
697 ioe.initCause(nsfe);
698 throw ioe;
699 } catch (IllegalAccessException iae) {
700 IOException ioe = new IOException();
701 ioe.initCause(iae);
702 throw ioe;
703 }
704 }
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745 public static void serializeRealMatrix(final RealMatrix matrix,
746 final ObjectOutputStream oos)
747 throws IOException {
748 final int n = matrix.getRowDimension();
749 final int m = matrix.getColumnDimension();
750 oos.writeInt(n);
751 oos.writeInt(m);
752 for (int i = 0; i < n; ++i) {
753 for (int j = 0; j < m; ++j) {
754 oos.writeDouble(matrix.getEntry(i, j));
755 }
756 }
757 }
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776 public static void deserializeRealMatrix(final Object instance,
777 final String fieldName,
778 final ObjectInputStream ois)
779 throws ClassNotFoundException, IOException {
780 try {
781
782
783 final int n = ois.readInt();
784 final int m = ois.readInt();
785 final double[][] data = new double[n][m];
786 for (int i = 0; i < n; ++i) {
787 final double[] dataI = data[i];
788 for (int j = 0; j < m; ++j) {
789 dataI[j] = ois.readDouble();
790 }
791 }
792
793
794 final RealMatrix matrix = new Array2DRowRealMatrix(data, false);
795
796
797 final java.lang.reflect.Field f =
798 instance.getClass().getDeclaredField(fieldName);
799 f.setAccessible(true);
800 f.set(instance, matrix);
801 } catch (NoSuchFieldException nsfe) {
802 IOException ioe = new IOException();
803 ioe.initCause(nsfe);
804 throw ioe;
805 } catch (IllegalAccessException iae) {
806 IOException ioe = new IOException();
807 ioe.initCause(iae);
808 throw ioe;
809 }
810 }
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830 public static void solveLowerTriangularSystem(RealMatrix rm, RealVector b)
831 throws DimensionMismatchException, MathArithmeticException,
832 NonSquareMatrixException {
833 if (rm == null || b == null || rm.getRowDimension() != b.getDimension()) {
834 throw new DimensionMismatchException(
835 (rm == null) ? 0 : rm.getRowDimension(),
836 (b == null) ? 0 : b.getDimension());
837 }
838 if( rm.getColumnDimension() != rm.getRowDimension() ){
839 throw new NonSquareMatrixException(rm.getRowDimension(),
840 rm.getColumnDimension());
841 }
842 int rows = rm.getRowDimension();
843 for( int i = 0 ; i < rows ; i++ ){
844 double diag = rm.getEntry(i, i);
845 if( JdkMath.abs(diag) < Precision.SAFE_MIN ){
846 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR);
847 }
848 double bi = b.getEntry(i)/diag;
849 b.setEntry(i, bi );
850 for( int j = i+1; j< rows; j++ ){
851 b.setEntry(j, b.getEntry(j)-bi*rm.getEntry(j,i) );
852 }
853 }
854 }
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875 public static void solveUpperTriangularSystem(RealMatrix rm, RealVector b)
876 throws DimensionMismatchException, MathArithmeticException,
877 NonSquareMatrixException {
878 if (rm == null || b == null || rm.getRowDimension() != b.getDimension()) {
879 throw new DimensionMismatchException(
880 (rm == null) ? 0 : rm.getRowDimension(),
881 (b == null) ? 0 : b.getDimension());
882 }
883 if( rm.getColumnDimension() != rm.getRowDimension() ){
884 throw new NonSquareMatrixException(rm.getRowDimension(),
885 rm.getColumnDimension());
886 }
887 int rows = rm.getRowDimension();
888 for( int i = rows-1 ; i >-1 ; i-- ){
889 double diag = rm.getEntry(i, i);
890 if( JdkMath.abs(diag) < Precision.SAFE_MIN ){
891 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR);
892 }
893 double bi = b.getEntry(i)/diag;
894 b.setEntry(i, bi );
895 for( int j = i-1; j>-1; j-- ){
896 b.setEntry(j, b.getEntry(j)-bi*rm.getEntry(j,i) );
897 }
898 }
899 }
900
901
902
903
904
905
906
907
908
909
910
911
912
913 public static RealMatrix blockInverse(RealMatrix m,
914 int splitIndex) {
915 final int n = m.getRowDimension();
916 if (m.getColumnDimension() != n) {
917 throw new NonSquareMatrixException(m.getRowDimension(),
918 m.getColumnDimension());
919 }
920
921 final int splitIndex1 = splitIndex + 1;
922
923 final RealMatrix a = m.getSubMatrix(0, splitIndex, 0, splitIndex);
924 final RealMatrix b = m.getSubMatrix(0, splitIndex, splitIndex1, n - 1);
925 final RealMatrix c = m.getSubMatrix(splitIndex1, n - 1, 0, splitIndex);
926 final RealMatrix d = m.getSubMatrix(splitIndex1, n - 1, splitIndex1, n - 1);
927
928 final SingularValueDecomposition aDec = new SingularValueDecomposition(a);
929 final DecompositionSolver aSolver = aDec.getSolver();
930 if (!aSolver.isNonSingular()) {
931 throw new SingularMatrixException();
932 }
933 final RealMatrix aInv = aSolver.getInverse();
934
935 final SingularValueDecomposition dDec = new SingularValueDecomposition(d);
936 final DecompositionSolver dSolver = dDec.getSolver();
937 if (!dSolver.isNonSingular()) {
938 throw new SingularMatrixException();
939 }
940 final RealMatrix dInv = dSolver.getInverse();
941
942 final RealMatrix tmp1 = a.subtract(b.multiply(dInv).multiply(c));
943 final SingularValueDecomposition tmp1Dec = new SingularValueDecomposition(tmp1);
944 final DecompositionSolver tmp1Solver = tmp1Dec.getSolver();
945 if (!tmp1Solver.isNonSingular()) {
946 throw new SingularMatrixException();
947 }
948 final RealMatrix result00 = tmp1Solver.getInverse();
949
950 final RealMatrix tmp2 = d.subtract(c.multiply(aInv).multiply(b));
951 final SingularValueDecomposition tmp2Dec = new SingularValueDecomposition(tmp2);
952 final DecompositionSolver tmp2Solver = tmp2Dec.getSolver();
953 if (!tmp2Solver.isNonSingular()) {
954 throw new SingularMatrixException();
955 }
956 final RealMatrix result11 = tmp2Solver.getInverse();
957
958 final RealMatrix result01 = aInv.multiply(b).multiply(result11).scalarMultiply(-1);
959 final RealMatrix result10 = dInv.multiply(c).multiply(result00).scalarMultiply(-1);
960
961 final RealMatrix result = new Array2DRowRealMatrix(n, n);
962 result.setSubMatrix(result00.getData(), 0, 0);
963 result.setSubMatrix(result01.getData(), 0, splitIndex1);
964 result.setSubMatrix(result10.getData(), splitIndex1, 0);
965 result.setSubMatrix(result11.getData(), splitIndex1, splitIndex1);
966
967 return result;
968 }
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986 public static RealMatrix inverse(RealMatrix matrix)
987 throws NullArgumentException, SingularMatrixException, NonSquareMatrixException {
988 return inverse(matrix, 0);
989 }
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005 public static RealMatrix inverse(RealMatrix matrix, double threshold)
1006 throws NullArgumentException, SingularMatrixException, NonSquareMatrixException {
1007
1008 NullArgumentException.check(matrix);
1009
1010 if (!matrix.isSquare()) {
1011 throw new NonSquareMatrixException(matrix.getRowDimension(),
1012 matrix.getColumnDimension());
1013 }
1014
1015 if (matrix instanceof DiagonalMatrix) {
1016 return ((DiagonalMatrix) matrix).inverse(threshold);
1017 } else {
1018 QRDecomposition decomposition = new QRDecomposition(matrix, threshold);
1019 return decomposition.getSolver().getInverse();
1020 }
1021 }
1022 }