View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.linear;
19  
20  import java.util.Random;
21  
22  import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
23  import org.junit.Test;
24  import org.junit.Assert;
25  
26  public class QRSolverTest {
27      private double[][] testData3x3NonSingular = {
28              { 12, -51,   4 },
29              {  6, 167, -68 },
30              { -4,  24, -41 }
31      };
32  
33      private double[][] testData3x3Singular = {
34              { 1, 2,  2 },
35              { 2, 4,  6 },
36              { 4, 8, 12 }
37      };
38  
39      private double[][] testData3x4 = {
40              { 12, -51,   4, 1 },
41              {  6, 167, -68, 2 },
42              { -4,  24, -41, 3 }
43      };
44  
45      private double[][] testData4x3 = {
46              { 12, -51,   4 },
47              {  6, 167, -68 },
48              { -4,  24, -41 },
49              { -5,  34,   7 }
50      };
51  
52      /** test rank */
53      @Test
54      public void testRank() {
55          DecompositionSolver solver =
56              new QRDecomposition(MatrixUtils.createRealMatrix(testData3x3NonSingular)).getSolver();
57          Assert.assertTrue(solver.isNonSingular());
58  
59          solver = new QRDecomposition(MatrixUtils.createRealMatrix(testData3x3Singular)).getSolver();
60          Assert.assertFalse(solver.isNonSingular());
61  
62          solver = new QRDecomposition(MatrixUtils.createRealMatrix(testData3x4)).getSolver();
63          Assert.assertTrue(solver.isNonSingular());
64  
65          solver = new QRDecomposition(MatrixUtils.createRealMatrix(testData4x3)).getSolver();
66          Assert.assertTrue(solver.isNonSingular());
67      }
68  
69      /** test solve dimension errors */
70      @Test
71      public void testSolveDimensionErrors() {
72          DecompositionSolver solver =
73              new QRDecomposition(MatrixUtils.createRealMatrix(testData3x3NonSingular)).getSolver();
74          RealMatrix b = MatrixUtils.createRealMatrix(new double[2][2]);
75          try {
76              solver.solve(b);
77              Assert.fail("an exception should have been thrown");
78          } catch (MathIllegalArgumentException iae) {
79              // expected behavior
80          }
81          try {
82              solver.solve(b.getColumnVector(0));
83              Assert.fail("an exception should have been thrown");
84          } catch (MathIllegalArgumentException iae) {
85              // expected behavior
86          }
87      }
88  
89      /** test solve rank errors */
90      @Test
91      public void testSolveRankErrors() {
92          DecompositionSolver solver =
93              new QRDecomposition(MatrixUtils.createRealMatrix(testData3x3Singular)).getSolver();
94          RealMatrix b = MatrixUtils.createRealMatrix(new double[3][2]);
95          try {
96              solver.solve(b);
97              Assert.fail("an exception should have been thrown");
98          } catch (SingularMatrixException iae) {
99              // expected behavior
100         }
101         try {
102             solver.solve(b.getColumnVector(0));
103             Assert.fail("an exception should have been thrown");
104         } catch (SingularMatrixException iae) {
105             // expected behavior
106         }
107     }
108 
109     /** test solve */
110     @Test
111     public void testSolve() {
112         QRDecomposition decomposition =
113             new QRDecomposition(MatrixUtils.createRealMatrix(testData3x3NonSingular));
114         DecompositionSolver solver = decomposition.getSolver();
115         RealMatrix b = MatrixUtils.createRealMatrix(new double[][] {
116                 { -102, 12250 }, { 544, 24500 }, { 167, -36750 }
117         });
118         RealMatrix xRef = MatrixUtils.createRealMatrix(new double[][] {
119                 { 1, 2515 }, { 2, 422 }, { -3, 898 }
120         });
121 
122         // using RealMatrix
123         Assert.assertEquals(0, solver.solve(b).subtract(xRef).getNorm(), 2.0e-16 * xRef.getNorm());
124 
125         // using ArrayRealVector
126         for (int i = 0; i < b.getColumnDimension(); ++i) {
127             final RealVector x = solver.solve(b.getColumnVector(i));
128             final double error = x.subtract(xRef.getColumnVector(i)).getNorm();
129             Assert.assertEquals(0, error, 3.0e-16 * xRef.getColumnVector(i).getNorm());
130         }
131 
132         // using RealVector with an alternate implementation
133         for (int i = 0; i < b.getColumnDimension(); ++i) {
134             ArrayRealVectorTest.RealVectorTestImpl v =
135                 new ArrayRealVectorTest.RealVectorTestImpl(b.getColumn(i));
136             final RealVector x = solver.solve(v);
137             final double error = x.subtract(xRef.getColumnVector(i)).getNorm();
138             Assert.assertEquals(0, error, 3.0e-16 * xRef.getColumnVector(i).getNorm());
139         }
140     }
141 
142     @Test
143     public void testOverdetermined() {
144         final Random r    = new Random(5559252868205245L);
145         int          p    = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
146         int          q    = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
147         RealMatrix   a    = createTestMatrix(r, p, q);
148         RealMatrix   xRef = createTestMatrix(r, q, BlockRealMatrix.BLOCK_SIZE + 3);
149 
150         // build a perturbed system: A.X + noise = B
151         RealMatrix b = a.multiply(xRef);
152         final double noise = 0.001;
153         b.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
154             @Override
155             public double visit(int row, int column, double value) {
156                 return value * (1.0 + noise * (2 * r.nextDouble() - 1));
157             }
158         });
159 
160         // despite perturbation, the least square solution should be pretty good
161         RealMatrix x = new QRDecomposition(a).getSolver().solve(b);
162         Assert.assertEquals(0, x.subtract(xRef).getNorm(), 0.01 * noise * p * q);
163     }
164 
165     @Test
166     public void testUnderdetermined() {
167         final Random r    = new Random(42185006424567123L);
168         int          p    = (5 * BlockRealMatrix.BLOCK_SIZE) / 4;
169         int          q    = (7 * BlockRealMatrix.BLOCK_SIZE) / 4;
170         RealMatrix   a    = createTestMatrix(r, p, q);
171         RealMatrix   xRef = createTestMatrix(r, q, BlockRealMatrix.BLOCK_SIZE + 3);
172         RealMatrix   b    = a.multiply(xRef);
173         RealMatrix   x = new QRDecomposition(a).getSolver().solve(b);
174 
175         // too many equations, the system cannot be solved at all
176         Assert.assertTrue(x.subtract(xRef).getNorm() / (p * q) > 0.01);
177 
178         // the last unknown should have been set to 0
179         Assert.assertEquals(0.0, x.getSubMatrix(p, q - 1, 0, x.getColumnDimension() - 1).getNorm(), 0);
180     }
181 
182     private RealMatrix createTestMatrix(final Random r, final int rows, final int columns) {
183         RealMatrix m = MatrixUtils.createRealMatrix(rows, columns);
184         m.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
185                 @Override
186                     public double visit(int row, int column, double value) {
187                     return 2.0 * r.nextDouble() - 1.0;
188                 }
189             });
190         return m;
191     }
192 }