1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.linear;
19
20 import java.util.Arrays;
21
22 import org.apache.commons.math4.core.jdkmath.JdkMath;
23 import org.junit.Test;
24 import org.junit.Assert;
25
26 public class TriDiagonalTransformerTest {
27
28 private double[][] testSquare5 = {
29 { 1, 2, 3, 1, 1 },
30 { 2, 1, 1, 3, 1 },
31 { 3, 1, 1, 1, 2 },
32 { 1, 3, 1, 2, 1 },
33 { 1, 1, 2, 1, 3 }
34 };
35
36 private double[][] testSquare3 = {
37 { 1, 3, 4 },
38 { 3, 2, 2 },
39 { 4, 2, 0 }
40 };
41
42 @Test
43 public void testNonSquare() {
44 try {
45 new TriDiagonalTransformer(MatrixUtils.createRealMatrix(new double[3][2]));
46 Assert.fail("an exception should have been thrown");
47 } catch (NonSquareMatrixException ime) {
48
49 }
50 }
51
52 @Test
53 public void testAEqualQTQt() {
54 checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare5));
55 checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare3));
56 }
57
58 private void checkAEqualQTQt(RealMatrix matrix) {
59 TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
60 RealMatrix q = transformer.getQ();
61 RealMatrix qT = transformer.getQT();
62 RealMatrix t = transformer.getT();
63 double norm = q.multiply(t).multiply(qT).subtract(matrix).getNorm();
64 Assert.assertEquals(0, norm, 4.0e-15);
65 }
66
67 @Test
68 public void testNoAccessBelowDiagonal() {
69 checkNoAccessBelowDiagonal(testSquare5);
70 checkNoAccessBelowDiagonal(testSquare3);
71 }
72
73 private void checkNoAccessBelowDiagonal(double[][] data) {
74 double[][] modifiedData = new double[data.length][];
75 for (int i = 0; i < data.length; ++i) {
76 modifiedData[i] = data[i].clone();
77 Arrays.fill(modifiedData[i], 0, i, Double.NaN);
78 }
79 RealMatrix matrix = MatrixUtils.createRealMatrix(modifiedData);
80 TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
81 RealMatrix q = transformer.getQ();
82 RealMatrix qT = transformer.getQT();
83 RealMatrix t = transformer.getT();
84 double norm = q.multiply(t).multiply(qT).subtract(MatrixUtils.createRealMatrix(data)).getNorm();
85 Assert.assertEquals(0, norm, 4.0e-15);
86 }
87
88 @Test
89 public void testQOrthogonal() {
90 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQ());
91 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQ());
92 }
93
94 @Test
95 public void testQTOrthogonal() {
96 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQT());
97 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQT());
98 }
99
100 private void checkOrthogonal(RealMatrix m) {
101 RealMatrix mTm = m.transpose().multiply(m);
102 RealMatrix id = MatrixUtils.createRealIdentityMatrix(mTm.getRowDimension());
103 Assert.assertEquals(0, mTm.subtract(id).getNorm(), 1.0e-15);
104 }
105
106 @Test
107 public void testTTriDiagonal() {
108 checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getT());
109 checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getT());
110 }
111
112 private void checkTriDiagonal(RealMatrix m) {
113 final int rows = m.getRowDimension();
114 final int cols = m.getColumnDimension();
115 for (int i = 0; i < rows; ++i) {
116 for (int j = 0; j < cols; ++j) {
117 if (i < j - 1 || i > j + 1) {
118 Assert.assertEquals(0, m.getEntry(i, j), 1.0e-16);
119 }
120 }
121 }
122 }
123
124 @Test
125 public void testMatricesValues5() {
126 checkMatricesValues(testSquare5,
127 new double[][] {
128 { 1.0, 0.0, 0.0, 0.0, 0.0 },
129 { 0.0, -0.5163977794943222, 0.016748280772542083, 0.839800693771262, 0.16669620021405473 },
130 { 0.0, -0.7745966692414833, -0.4354553000860955, -0.44989322880603355, -0.08930153582895772 },
131 { 0.0, -0.2581988897471611, 0.6364346693566014, -0.30263204032131164, 0.6608313651342882 },
132 { 0.0, -0.2581988897471611, 0.6364346693566009, -0.027289660803112598, -0.7263191580755246 }
133 },
134 new double[] { 1, 4.4, 1.433099579242636, -0.89537362758743, 2.062274048344794 },
135 new double[] { -JdkMath.sqrt(15), -3.0832882879592476, 0.6082710842351517, 1.1786086405912128 });
136 }
137
138 @Test
139 public void testMatricesValues3() {
140 checkMatricesValues(testSquare3,
141 new double[][] {
142 { 1.0, 0.0, 0.0 },
143 { 0.0, -0.6, 0.8 },
144 { 0.0, -0.8, -0.6 },
145 },
146 new double[] { 1, 2.64, -0.64 },
147 new double[] { -5, -1.52 });
148 }
149
150 private void checkMatricesValues(double[][] matrix, double[][] qRef,
151 double[] mainDiagnonal,
152 double[] secondaryDiagonal) {
153 TriDiagonalTransformer transformer =
154 new TriDiagonalTransformer(MatrixUtils.createRealMatrix(matrix));
155
156
157 RealMatrix q = transformer.getQ();
158 Assert.assertEquals(0, q.subtract(MatrixUtils.createRealMatrix(qRef)).getNorm(), 1.0e-14);
159
160 RealMatrix t = transformer.getT();
161 double[][] tData = new double[mainDiagnonal.length][mainDiagnonal.length];
162 for (int i = 0; i < mainDiagnonal.length; ++i) {
163 tData[i][i] = mainDiagnonal[i];
164 if (i > 0) {
165 tData[i][i - 1] = secondaryDiagonal[i - 1];
166 }
167 if (i < secondaryDiagonal.length) {
168 tData[i][i + 1] = secondaryDiagonal[i];
169 }
170 }
171 Assert.assertEquals(0, t.subtract(MatrixUtils.createRealMatrix(tData)).getNorm(), 1.0e-14);
172
173
174 Assert.assertSame(q, transformer.getQ());
175 Assert.assertSame(t, transformer.getT());
176 }
177 }