/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

18  package org.apache.commons.math4.legacy.linear;
19  
20  import java.util.Arrays;
21  
22  import org.apache.commons.math4.core.jdkmath.JdkMath;
23  import org.junit.Test;
24  import org.junit.Assert;
25  
26  public class TriDiagonalTransformerTest {
27  
28      private double[][] testSquare5 = {
29              { 1, 2, 3, 1, 1 },
30              { 2, 1, 1, 3, 1 },
31              { 3, 1, 1, 1, 2 },
32              { 1, 3, 1, 2, 1 },
33              { 1, 1, 2, 1, 3 }
34      };
35  
36      private double[][] testSquare3 = {
37              { 1, 3, 4 },
38              { 3, 2, 2 },
39              { 4, 2, 0 }
40      };
41  
42      @Test
43      public void testNonSquare() {
44          try {
45              new TriDiagonalTransformer(MatrixUtils.createRealMatrix(new double[3][2]));
46              Assert.fail("an exception should have been thrown");
47          } catch (NonSquareMatrixException ime) {
48              // expected behavior
49          }
50      }
51  
52      @Test
53      public void testAEqualQTQt() {
54          checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare5));
55          checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare3));
56      }
57  
58      private void checkAEqualQTQt(RealMatrix matrix) {
59          TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
60          RealMatrix q  = transformer.getQ();
61          RealMatrix qT = transformer.getQT();
62          RealMatrix t  = transformer.getT();
63          double norm = q.multiply(t).multiply(qT).subtract(matrix).getNorm();
64          Assert.assertEquals(0, norm, 4.0e-15);
65      }
66  
67      @Test
68      public void testNoAccessBelowDiagonal() {
69          checkNoAccessBelowDiagonal(testSquare5);
70          checkNoAccessBelowDiagonal(testSquare3);
71      }
72  
73      private void checkNoAccessBelowDiagonal(double[][] data) {
74          double[][] modifiedData = new double[data.length][];
75          for (int i = 0; i < data.length; ++i) {
76              modifiedData[i] = data[i].clone();
77              Arrays.fill(modifiedData[i], 0, i, Double.NaN);
78          }
79          RealMatrix matrix = MatrixUtils.createRealMatrix(modifiedData);
80          TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
81          RealMatrix q  = transformer.getQ();
82          RealMatrix qT = transformer.getQT();
83          RealMatrix t  = transformer.getT();
84          double norm = q.multiply(t).multiply(qT).subtract(MatrixUtils.createRealMatrix(data)).getNorm();
85          Assert.assertEquals(0, norm, 4.0e-15);
86      }
87  
88      @Test
89      public void testQOrthogonal() {
90          checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQ());
91          checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQ());
92      }
93  
94      @Test
95      public void testQTOrthogonal() {
96          checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQT());
97          checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQT());
98      }
99  
100     private void checkOrthogonal(RealMatrix m) {
101         RealMatrix mTm = m.transpose().multiply(m);
102         RealMatrix id  = MatrixUtils.createRealIdentityMatrix(mTm.getRowDimension());
103         Assert.assertEquals(0, mTm.subtract(id).getNorm(), 1.0e-15);
104     }
105 
106     @Test
107     public void testTTriDiagonal() {
108         checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getT());
109         checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getT());
110     }
111 
112     private void checkTriDiagonal(RealMatrix m) {
113         final int rows = m.getRowDimension();
114         final int cols = m.getColumnDimension();
115         for (int i = 0; i < rows; ++i) {
116             for (int j = 0; j < cols; ++j) {
117                 if (i < j - 1 || i > j + 1) {
118                     Assert.assertEquals(0, m.getEntry(i, j), 1.0e-16);
119                 }
120             }
121         }
122     }
123 
124     @Test
125     public void testMatricesValues5() {
126         checkMatricesValues(testSquare5,
127                             new double[][] {
128                                 { 1.0,  0.0,                 0.0,                  0.0,                   0.0 },
129                                 { 0.0, -0.5163977794943222,  0.016748280772542083, 0.839800693771262,     0.16669620021405473 },
130                                 { 0.0, -0.7745966692414833, -0.4354553000860955,  -0.44989322880603355,  -0.08930153582895772 },
131                                 { 0.0, -0.2581988897471611,  0.6364346693566014,  -0.30263204032131164,   0.6608313651342882 },
132                                 { 0.0, -0.2581988897471611,  0.6364346693566009,  -0.027289660803112598, -0.7263191580755246 }
133                             },
134                             new double[] { 1, 4.4, 1.433099579242636, -0.89537362758743, 2.062274048344794 },
135                             new double[] { -JdkMath.sqrt(15), -3.0832882879592476, 0.6082710842351517, 1.1786086405912128 });
136     }
137 
138     @Test
139     public void testMatricesValues3() {
140         checkMatricesValues(testSquare3,
141                             new double[][] {
142                                 {  1.0,  0.0,  0.0 },
143                                 {  0.0, -0.6,  0.8 },
144                                 {  0.0, -0.8, -0.6 },
145                             },
146                             new double[] { 1, 2.64, -0.64 },
147                             new double[] { -5, -1.52 });
148     }
149 
150     private void checkMatricesValues(double[][] matrix, double[][] qRef,
151                                      double[] mainDiagnonal,
152                                      double[] secondaryDiagonal) {
153         TriDiagonalTransformer transformer =
154             new TriDiagonalTransformer(MatrixUtils.createRealMatrix(matrix));
155 
156         // check values against known references
157         RealMatrix q = transformer.getQ();
158         Assert.assertEquals(0, q.subtract(MatrixUtils.createRealMatrix(qRef)).getNorm(), 1.0e-14);
159 
160         RealMatrix t = transformer.getT();
161         double[][] tData = new double[mainDiagnonal.length][mainDiagnonal.length];
162         for (int i = 0; i < mainDiagnonal.length; ++i) {
163             tData[i][i] = mainDiagnonal[i];
164             if (i > 0) {
165                 tData[i][i - 1] = secondaryDiagonal[i - 1];
166             }
167             if (i < secondaryDiagonal.length) {
168                 tData[i][i + 1] = secondaryDiagonal[i];
169             }
170         }
171         Assert.assertEquals(0, t.subtract(MatrixUtils.createRealMatrix(tData)).getNorm(), 1.0e-14);
172 
173         // check the same cached instance is returned the second time
174         Assert.assertSame(q, transformer.getQ());
175         Assert.assertSame(t, transformer.getT());
176     }
177 }