View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.linear;
18  
19  import java.util.Arrays;
20  
21  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
22  import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
23  import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
24  import org.apache.commons.math4.core.jdkmath.JdkMath;
25  import org.junit.Assert;
26  import org.junit.Test;
27  
28  public class ConjugateGradientTest {
29  
30      @Test(expected = NonSquareOperatorException.class)
31      public void testNonSquareOperator() {
32          final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 3);
33          final IterativeLinearSolver solver;
34          solver = new ConjugateGradient(10, 0., false);
35          final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
36          final ArrayRealVector x = new ArrayRealVector(a.getColumnDimension());
37          solver.solve(a, b, x);
38      }
39  
40      @Test(expected = DimensionMismatchException.class)
41      public void testDimensionMismatchRightHandSide() {
42          final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
43          final IterativeLinearSolver solver;
44          solver = new ConjugateGradient(10, 0., false);
45          final ArrayRealVector b = new ArrayRealVector(2);
46          final ArrayRealVector x = new ArrayRealVector(3);
47          solver.solve(a, b, x);
48      }
49  
50      @Test(expected = DimensionMismatchException.class)
51      public void testDimensionMismatchSolution() {
52          final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
53          final IterativeLinearSolver solver;
54          solver = new ConjugateGradient(10, 0., false);
55          final ArrayRealVector b = new ArrayRealVector(3);
56          final ArrayRealVector x = new ArrayRealVector(2);
57          solver.solve(a, b, x);
58      }
59  
60      @Test(expected = NonPositiveDefiniteOperatorException.class)
61      public void testNonPositiveDefiniteLinearOperator() {
62          final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
63          a.setEntry(0, 0, -1.);
64          a.setEntry(0, 1, 2.);
65          a.setEntry(1, 0, 3.);
66          a.setEntry(1, 1, 4.);
67          final IterativeLinearSolver solver;
68          solver = new ConjugateGradient(10, 0., true);
69          final ArrayRealVector b = new ArrayRealVector(2);
70          b.setEntry(0, -1.);
71          b.setEntry(1, -1.);
72          final ArrayRealVector x = new ArrayRealVector(2);
73          solver.solve(a, b, x);
74      }
75  
76      @Test
77      public void testUnpreconditionedSolution() {
78          final int n = 5;
79          final int maxIterations = 100;
80          final RealLinearOperator a = new HilbertMatrix(n);
81          final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
82          final IterativeLinearSolver solver;
83          solver = new ConjugateGradient(maxIterations, 1E-10, true);
84          final RealVector b = new ArrayRealVector(n);
85          for (int j = 0; j < n; j++) {
86              b.set(0.);
87              b.setEntry(j, 1.);
88              final RealVector x = solver.solve(a, b);
89              for (int i = 0; i < n; i++) {
90                  final double actual = x.getEntry(i);
91                  final double expected = ainv.getEntry(i, j);
92                  final double delta = 1E-10 * JdkMath.abs(expected);
93                  final String msg = String.format("entry[%d][%d]", i, j);
94                  Assert.assertEquals(msg, expected, actual, delta);
95              }
96          }
97      }
98  
99      @Test
100     public void testUnpreconditionedInPlaceSolutionWithInitialGuess() {
101         final int n = 5;
102         final int maxIterations = 100;
103         final RealLinearOperator a = new HilbertMatrix(n);
104         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
105         final IterativeLinearSolver solver;
106         solver = new ConjugateGradient(maxIterations, 1E-10, true);
107         final RealVector b = new ArrayRealVector(n);
108         for (int j = 0; j < n; j++) {
109             b.set(0.);
110             b.setEntry(j, 1.);
111             final RealVector x0 = new ArrayRealVector(n);
112             x0.set(1.);
113             final RealVector x = solver.solveInPlace(a, b, x0);
114             Assert.assertSame("x should be a reference to x0", x0, x);
115             for (int i = 0; i < n; i++) {
116                 final double actual = x.getEntry(i);
117                 final double expected = ainv.getEntry(i, j);
118                 final double delta = 1E-10 * JdkMath.abs(expected);
119                 final String msg = String.format("entry[%d][%d)", i, j);
120                 Assert.assertEquals(msg, expected, actual, delta);
121             }
122         }
123     }
124 
125     @Test
126     public void testUnpreconditionedSolutionWithInitialGuess() {
127         final int n = 5;
128         final int maxIterations = 100;
129         final RealLinearOperator a = new HilbertMatrix(n);
130         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
131         final IterativeLinearSolver solver;
132         solver = new ConjugateGradient(maxIterations, 1E-10, true);
133         final RealVector b = new ArrayRealVector(n);
134         for (int j = 0; j < n; j++) {
135             b.set(0.);
136             b.setEntry(j, 1.);
137             final RealVector x0 = new ArrayRealVector(n);
138             x0.set(1.);
139             final RealVector x = solver.solve(a, b, x0);
140             Assert.assertNotSame("x should not be a reference to x0", x0, x);
141             for (int i = 0; i < n; i++) {
142                 final double actual = x.getEntry(i);
143                 final double expected = ainv.getEntry(i, j);
144                 final double delta = 1E-10 * JdkMath.abs(expected);
145                 final String msg = String.format("entry[%d][%d]", i, j);
146                 Assert.assertEquals(msg, expected, actual, delta);
147                 Assert.assertEquals(msg, x0.getEntry(i), 1., Math.ulp(1.));
148             }
149         }
150     }
151 
152     /**
153      * Check whether the estimate of the (updated) residual corresponds to the
154      * exact residual. This fails to be true for a large number of iterations,
155      * due to the loss of orthogonality of the successive search directions.
156      * Therefore, in the present test, the number of iterations is limited.
157      */
158     @Test
159     public void testUnpreconditionedResidual() {
160         final int n = 10;
161         final int maxIterations = n;
162         final RealLinearOperator a = new HilbertMatrix(n);
163         final ConjugateGradient solver;
164         solver = new ConjugateGradient(maxIterations, 1E-15, true);
165         final RealVector r = new ArrayRealVector(n);
166         final RealVector x = new ArrayRealVector(n);
167         final IterationListener listener = new IterationListener() {
168 
169             @Override
170             public void terminationPerformed(final IterationEvent e) {
171                 // Do nothing
172             }
173 
174             @Override
175             public void iterationStarted(final IterationEvent e) {
176                 // Do nothing
177             }
178 
179             @Override
180             public void iterationPerformed(final IterationEvent e) {
181                 final IterativeLinearSolverEvent evt;
182                 evt = (IterativeLinearSolverEvent) e;
183                 RealVector v = evt.getResidual();
184                 r.setSubVector(0, v);
185                 v = evt.getSolution();
186                 x.setSubVector(0, v);
187             }
188 
189             @Override
190             public void initializationPerformed(final IterationEvent e) {
191                 // Do nothing
192             }
193         };
194         solver.getIterationManager().addIterationListener(listener);
195         final RealVector b = new ArrayRealVector(n);
196         for (int j = 0; j < n; j++) {
197             b.set(0.);
198             b.setEntry(j, 1.);
199 
200             boolean caught = false;
201             try {
202                 solver.solve(a, b);
203             } catch (MaxCountExceededException e) {
204                 caught = true;
205                 final RealVector y = a.operate(x);
206                 for (int i = 0; i < n; i++) {
207                     final double actual = b.getEntry(i) - y.getEntry(i);
208                     final double expected = r.getEntry(i);
209                     final double delta = 1E-6 * JdkMath.abs(expected);
210                     final String msg = String
211                         .format("column %d, residual %d", i, j);
212                     Assert.assertEquals(msg, expected, actual, delta);
213                 }
214             }
215             Assert
216                 .assertTrue("MaxCountExceededException should have been caught",
217                             caught);
218         }
219     }
220 
221     @Test(expected = NonSquareOperatorException.class)
222     public void testNonSquarePreconditioner() {
223         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
224         final RealLinearOperator m = new RealLinearOperator() {
225 
226             @Override
227             public RealVector operate(final RealVector x) {
228                 throw new UnsupportedOperationException();
229             }
230 
231             @Override
232             public int getRowDimension() {
233                 return 2;
234             }
235 
236             @Override
237             public int getColumnDimension() {
238                 return 3;
239             }
240         };
241         final PreconditionedIterativeLinearSolver solver;
242         solver = new ConjugateGradient(10, 0d, false);
243         final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
244         solver.solve(a, m, b);
245     }
246 
247     @Test(expected = DimensionMismatchException.class)
248     public void testMismatchedOperatorDimensions() {
249         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
250         final RealLinearOperator m = new RealLinearOperator() {
251 
252             @Override
253             public RealVector operate(final RealVector x) {
254                 throw new UnsupportedOperationException();
255             }
256 
257             @Override
258             public int getRowDimension() {
259                 return 3;
260             }
261 
262             @Override
263             public int getColumnDimension() {
264                 return 3;
265             }
266         };
267         final PreconditionedIterativeLinearSolver solver;
268         solver = new ConjugateGradient(10, 0d, false);
269         final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
270         solver.solve(a, m, b);
271     }
272 
273     @Test(expected = NonPositiveDefiniteOperatorException.class)
274     public void testNonPositiveDefinitePreconditioner() {
275         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
276         a.setEntry(0, 0, 1d);
277         a.setEntry(0, 1, 2d);
278         a.setEntry(1, 0, 3d);
279         a.setEntry(1, 1, 4d);
280         final RealLinearOperator m = new RealLinearOperator() {
281 
282             @Override
283             public RealVector operate(final RealVector x) {
284                 final ArrayRealVector y = new ArrayRealVector(2);
285                 y.setEntry(0, -x.getEntry(0));
286                 y.setEntry(1, x.getEntry(1));
287                 return y;
288             }
289 
290             @Override
291             public int getRowDimension() {
292                 return 2;
293             }
294 
295             @Override
296             public int getColumnDimension() {
297                 return 2;
298             }
299         };
300         final PreconditionedIterativeLinearSolver solver;
301         solver = new ConjugateGradient(10, 0d, true);
302         final ArrayRealVector b = new ArrayRealVector(2);
303         b.setEntry(0, -1d);
304         b.setEntry(1, -1d);
305         solver.solve(a, m, b);
306     }
307 
308     @Test
309     public void testPreconditionedSolution() {
310         final int n = 8;
311         final int maxIterations = 100;
312         final RealLinearOperator a = new HilbertMatrix(n);
313         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
314         final RealLinearOperator m = JacobiPreconditioner.create(a);
315         final PreconditionedIterativeLinearSolver solver;
316         solver = new ConjugateGradient(maxIterations, 1E-15, true);
317         final RealVector b = new ArrayRealVector(n);
318         for (int j = 0; j < n; j++) {
319             b.set(0.);
320             b.setEntry(j, 1.);
321             final RealVector x = solver.solve(a, m, b);
322             for (int i = 0; i < n; i++) {
323                 final double actual = x.getEntry(i);
324                 final double expected = ainv.getEntry(i, j);
325                 final double delta = 1E-6 * JdkMath.abs(expected);
326                 final String msg = String.format("coefficient (%d, %d)", i, j);
327                 Assert.assertEquals(msg, expected, actual, delta);
328             }
329         }
330     }
331 
332     @Test
333     public void testPreconditionedResidual() {
334         final int n = 10;
335         final int maxIterations = n;
336         final RealLinearOperator a = new HilbertMatrix(n);
337         final RealLinearOperator m = JacobiPreconditioner.create(a);
338         final ConjugateGradient solver;
339         solver = new ConjugateGradient(maxIterations, 1E-15, true);
340         final RealVector r = new ArrayRealVector(n);
341         final RealVector x = new ArrayRealVector(n);
342         final IterationListener listener = new IterationListener() {
343 
344             @Override
345             public void terminationPerformed(final IterationEvent e) {
346                 // Do nothing
347             }
348 
349             @Override
350             public void iterationStarted(final IterationEvent e) {
351                 // Do nothing
352             }
353 
354             @Override
355             public void iterationPerformed(final IterationEvent e) {
356                 final IterativeLinearSolverEvent evt;
357                 evt = (IterativeLinearSolverEvent) e;
358                 RealVector v = evt.getResidual();
359                 r.setSubVector(0, v);
360                 v = evt.getSolution();
361                 x.setSubVector(0, v);
362             }
363 
364             @Override
365             public void initializationPerformed(final IterationEvent e) {
366                 // Do nothing
367             }
368         };
369         solver.getIterationManager().addIterationListener(listener);
370         final RealVector b = new ArrayRealVector(n);
371 
372         for (int j = 0; j < n; j++) {
373             b.set(0.);
374             b.setEntry(j, 1.);
375 
376             boolean caught = false;
377             try {
378                 solver.solve(a, m, b);
379             } catch (MaxCountExceededException e) {
380                 caught = true;
381                 final RealVector y = a.operate(x);
382                 for (int i = 0; i < n; i++) {
383                     final double actual = b.getEntry(i) - y.getEntry(i);
384                     final double expected = r.getEntry(i);
385                     final double delta = 1E-6 * JdkMath.abs(expected);
386                     final String msg = String.format("column %d, residual %d", i, j);
387                     Assert.assertEquals(msg, expected, actual, delta);
388                 }
389             }
390             Assert.assertTrue("MaxCountExceededException should have been caught", caught);
391         }
392     }
393 
394     @Test
395     public void testPreconditionedSolution2() {
396         final int n = 100;
397         final int maxIterations = 100000;
398         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(n, n);
399         double daux = 1.;
400         for (int i = 0; i < n; i++) {
401             a.setEntry(i, i, daux);
402             daux *= 1.2;
403             for (int j = i + 1; j < n; j++) {
404                 if (i != j) {
405                     final double value = 1.0;
406                     a.setEntry(i, j, value);
407                     a.setEntry(j, i, value);
408                 }
409             }
410         }
411         final RealLinearOperator m = JacobiPreconditioner.create(a);
412         final PreconditionedIterativeLinearSolver pcg;
413         final IterativeLinearSolver cg;
414         pcg = new ConjugateGradient(maxIterations, 1E-6, true);
415         cg = new ConjugateGradient(maxIterations, 1E-6, true);
416         final RealVector b = new ArrayRealVector(n);
417         final String pattern = "preconditioned gradient (%d iterations) should"
418                                + " have been faster than unpreconditioned (%d iterations)";
419         String msg;
420         for (int j = 0; j < 1; j++) {
421             b.set(0.);
422             b.setEntry(j, 1.);
423             final RealVector px = pcg.solve(a, m, b);
424             final RealVector x = cg.solve(a, b);
425             final int npcg = pcg.getIterationManager().getIterations();
426             final int ncg = cg.getIterationManager().getIterations();
427             msg = String.format(pattern, npcg, ncg);
428             Assert.assertTrue(msg, npcg < ncg);
429             for (int i = 0; i < n; i++) {
430                 msg = String.format("row %d, column %d", i, j);
431                 final double expected = x.getEntry(i);
432                 final double actual = px.getEntry(i);
433                 final double delta = 1E-6 * JdkMath.abs(expected);
434                 Assert.assertEquals(msg, expected, actual, delta);
435             }
436         }
437     }
438 
439     @Test
440     public void testEventManagement() {
441         final int n = 5;
442         final int maxIterations = 100;
443         final RealLinearOperator a = new HilbertMatrix(n);
444         final IterativeLinearSolver solver;
445         /*
446          * count[0] = number of calls to initializationPerformed
447          * count[1] = number of calls to iterationStarted
448          * count[2] = number of calls to iterationPerformed
449          * count[3] = number of calls to terminationPerformed
450          */
451         final int[] count = new int[] {0, 0, 0, 0};
452         final IterationListener listener = new IterationListener() {
453             private void doTestVectorsAreUnmodifiable(final IterationEvent e) {
454                 final IterativeLinearSolverEvent evt;
455                 evt = (IterativeLinearSolverEvent) e;
456                 try {
457                     evt.getResidual().set(0.0);
458                     Assert.fail("r is modifiable");
459                 } catch (MathUnsupportedOperationException exc){
460                     // Expected behavior
461                 }
462                 try {
463                     evt.getRightHandSideVector().set(0.0);
464                     Assert.fail("b is modifiable");
465                 } catch (MathUnsupportedOperationException exc){
466                     // Expected behavior
467                 }
468                 try {
469                     evt.getSolution().set(0.0);
470                     Assert.fail("x is modifiable");
471                 } catch (MathUnsupportedOperationException exc){
472                     // Expected behavior
473                 }
474             }
475 
476             @Override
477             public void initializationPerformed(final IterationEvent e) {
478                 ++count[0];
479                 doTestVectorsAreUnmodifiable(e);
480             }
481 
482             @Override
483             public void iterationPerformed(final IterationEvent e) {
484                 ++count[2];
485                 Assert.assertEquals("iteration performed",
486                     count[2], e.getIterations() - 1);
487                 doTestVectorsAreUnmodifiable(e);
488             }
489 
490             @Override
491             public void iterationStarted(final IterationEvent e) {
492                 ++count[1];
493                 Assert.assertEquals("iteration started",
494                     count[1], e.getIterations() - 1);
495                 doTestVectorsAreUnmodifiable(e);
496             }
497 
498             @Override
499             public void terminationPerformed(final IterationEvent e) {
500                 ++count[3];
501                 doTestVectorsAreUnmodifiable(e);
502             }
503         };
504         solver = new ConjugateGradient(maxIterations, 1E-10, true);
505         solver.getIterationManager().addIterationListener(listener);
506         final RealVector b = new ArrayRealVector(n);
507         for (int j = 0; j < n; j++) {
508             Arrays.fill(count, 0);
509             b.set(0.);
510             b.setEntry(j, 1.);
511             solver.solve(a, b);
512             String msg = String.format("column %d (initialization)", j);
513             Assert.assertEquals(msg, 1, count[0]);
514             msg = String.format("column %d (finalization)", j);
515             Assert.assertEquals(msg, 1, count[3]);
516         }
517     }
518 
519     @Test
520     public void testUnpreconditionedNormOfResidual() {
521         final int n = 5;
522         final int maxIterations = 100;
523         final RealLinearOperator a = new HilbertMatrix(n);
524         final IterativeLinearSolver solver;
525         final IterationListener listener = new IterationListener() {
526 
527             private void doTestNormOfResidual(final IterationEvent e) {
528                 final IterativeLinearSolverEvent evt;
529                 evt = (IterativeLinearSolverEvent) e;
530                 final RealVector x = evt.getSolution();
531                 final RealVector b = evt.getRightHandSideVector();
532                 final RealVector r = b.subtract(a.operate(x));
533                 final double rnorm = r.getNorm();
534                 Assert.assertEquals("iteration performed (residual)",
535                     rnorm, evt.getNormOfResidual(),
536                     JdkMath.max(1E-5 * rnorm, 1E-10));
537             }
538 
539             @Override
540             public void initializationPerformed(final IterationEvent e) {
541                 doTestNormOfResidual(e);
542             }
543 
544             @Override
545             public void iterationPerformed(final IterationEvent e) {
546                 doTestNormOfResidual(e);
547             }
548 
549             @Override
550             public void iterationStarted(final IterationEvent e) {
551                 doTestNormOfResidual(e);
552             }
553 
554             @Override
555             public void terminationPerformed(final IterationEvent e) {
556                 doTestNormOfResidual(e);
557             }
558         };
559         solver = new ConjugateGradient(maxIterations, 1E-10, true);
560         solver.getIterationManager().addIterationListener(listener);
561         final RealVector b = new ArrayRealVector(n);
562         for (int j = 0; j < n; j++) {
563             b.set(0.);
564             b.setEntry(j, 1.);
565             solver.solve(a, b);
566         }
567     }
568 
569     @Test
570     public void testPreconditionedNormOfResidual() {
571         final int n = 5;
572         final int maxIterations = 100;
573         final RealLinearOperator a = new HilbertMatrix(n);
574         final RealLinearOperator m = JacobiPreconditioner.create(a);
575         final PreconditionedIterativeLinearSolver solver;
576         final IterationListener listener = new IterationListener() {
577 
578             private void doTestNormOfResidual(final IterationEvent e) {
579                 final IterativeLinearSolverEvent evt;
580                 evt = (IterativeLinearSolverEvent) e;
581                 final RealVector x = evt.getSolution();
582                 final RealVector b = evt.getRightHandSideVector();
583                 final RealVector r = b.subtract(a.operate(x));
584                 final double rnorm = r.getNorm();
585                 Assert.assertEquals("iteration performed (residual)",
586                     rnorm, evt.getNormOfResidual(),
587                     JdkMath.max(1E-5 * rnorm, 1E-10));
588             }
589 
590             @Override
591             public void initializationPerformed(final IterationEvent e) {
592                 doTestNormOfResidual(e);
593             }
594 
595             @Override
596             public void iterationPerformed(final IterationEvent e) {
597                 doTestNormOfResidual(e);
598             }
599 
600             @Override
601             public void iterationStarted(final IterationEvent e) {
602                 doTestNormOfResidual(e);
603             }
604 
605             @Override
606             public void terminationPerformed(final IterationEvent e) {
607                 doTestNormOfResidual(e);
608             }
609         };
610         solver = new ConjugateGradient(maxIterations, 1E-10, true);
611         solver.getIterationManager().addIterationListener(listener);
612         final RealVector b = new ArrayRealVector(n);
613         for (int j = 0; j < n; j++) {
614             b.set(0.);
615             b.setEntry(j, 1.);
616             solver.solve(a, m, b);
617         }
618     }
619 }