1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.linear;
18
19 import java.util.Arrays;
20
21 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
22 import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
23 import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
24 import org.apache.commons.math4.core.jdkmath.JdkMath;
25 import org.junit.Assert;
26 import org.junit.Test;
27
28 public class ConjugateGradientTest {
29
30 @Test(expected = NonSquareOperatorException.class)
31 public void testNonSquareOperator() {
32 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 3);
33 final IterativeLinearSolver solver;
34 solver = new ConjugateGradient(10, 0., false);
35 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
36 final ArrayRealVector x = new ArrayRealVector(a.getColumnDimension());
37 solver.solve(a, b, x);
38 }
39
40 @Test(expected = DimensionMismatchException.class)
41 public void testDimensionMismatchRightHandSide() {
42 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
43 final IterativeLinearSolver solver;
44 solver = new ConjugateGradient(10, 0., false);
45 final ArrayRealVector b = new ArrayRealVector(2);
46 final ArrayRealVector x = new ArrayRealVector(3);
47 solver.solve(a, b, x);
48 }
49
50 @Test(expected = DimensionMismatchException.class)
51 public void testDimensionMismatchSolution() {
52 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
53 final IterativeLinearSolver solver;
54 solver = new ConjugateGradient(10, 0., false);
55 final ArrayRealVector b = new ArrayRealVector(3);
56 final ArrayRealVector x = new ArrayRealVector(2);
57 solver.solve(a, b, x);
58 }
59
60 @Test(expected = NonPositiveDefiniteOperatorException.class)
61 public void testNonPositiveDefiniteLinearOperator() {
62 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
63 a.setEntry(0, 0, -1.);
64 a.setEntry(0, 1, 2.);
65 a.setEntry(1, 0, 3.);
66 a.setEntry(1, 1, 4.);
67 final IterativeLinearSolver solver;
68 solver = new ConjugateGradient(10, 0., true);
69 final ArrayRealVector b = new ArrayRealVector(2);
70 b.setEntry(0, -1.);
71 b.setEntry(1, -1.);
72 final ArrayRealVector x = new ArrayRealVector(2);
73 solver.solve(a, b, x);
74 }
75
76 @Test
77 public void testUnpreconditionedSolution() {
78 final int n = 5;
79 final int maxIterations = 100;
80 final RealLinearOperator a = new HilbertMatrix(n);
81 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
82 final IterativeLinearSolver solver;
83 solver = new ConjugateGradient(maxIterations, 1E-10, true);
84 final RealVector b = new ArrayRealVector(n);
85 for (int j = 0; j < n; j++) {
86 b.set(0.);
87 b.setEntry(j, 1.);
88 final RealVector x = solver.solve(a, b);
89 for (int i = 0; i < n; i++) {
90 final double actual = x.getEntry(i);
91 final double expected = ainv.getEntry(i, j);
92 final double delta = 1E-10 * JdkMath.abs(expected);
93 final String msg = String.format("entry[%d][%d]", i, j);
94 Assert.assertEquals(msg, expected, actual, delta);
95 }
96 }
97 }
98
99 @Test
100 public void testUnpreconditionedInPlaceSolutionWithInitialGuess() {
101 final int n = 5;
102 final int maxIterations = 100;
103 final RealLinearOperator a = new HilbertMatrix(n);
104 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
105 final IterativeLinearSolver solver;
106 solver = new ConjugateGradient(maxIterations, 1E-10, true);
107 final RealVector b = new ArrayRealVector(n);
108 for (int j = 0; j < n; j++) {
109 b.set(0.);
110 b.setEntry(j, 1.);
111 final RealVector x0 = new ArrayRealVector(n);
112 x0.set(1.);
113 final RealVector x = solver.solveInPlace(a, b, x0);
114 Assert.assertSame("x should be a reference to x0", x0, x);
115 for (int i = 0; i < n; i++) {
116 final double actual = x.getEntry(i);
117 final double expected = ainv.getEntry(i, j);
118 final double delta = 1E-10 * JdkMath.abs(expected);
119 final String msg = String.format("entry[%d][%d)", i, j);
120 Assert.assertEquals(msg, expected, actual, delta);
121 }
122 }
123 }
124
125 @Test
126 public void testUnpreconditionedSolutionWithInitialGuess() {
127 final int n = 5;
128 final int maxIterations = 100;
129 final RealLinearOperator a = new HilbertMatrix(n);
130 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
131 final IterativeLinearSolver solver;
132 solver = new ConjugateGradient(maxIterations, 1E-10, true);
133 final RealVector b = new ArrayRealVector(n);
134 for (int j = 0; j < n; j++) {
135 b.set(0.);
136 b.setEntry(j, 1.);
137 final RealVector x0 = new ArrayRealVector(n);
138 x0.set(1.);
139 final RealVector x = solver.solve(a, b, x0);
140 Assert.assertNotSame("x should not be a reference to x0", x0, x);
141 for (int i = 0; i < n; i++) {
142 final double actual = x.getEntry(i);
143 final double expected = ainv.getEntry(i, j);
144 final double delta = 1E-10 * JdkMath.abs(expected);
145 final String msg = String.format("entry[%d][%d]", i, j);
146 Assert.assertEquals(msg, expected, actual, delta);
147 Assert.assertEquals(msg, x0.getEntry(i), 1., Math.ulp(1.));
148 }
149 }
150 }
151
152
153
154
155
156
157
158 @Test
159 public void testUnpreconditionedResidual() {
160 final int n = 10;
161 final int maxIterations = n;
162 final RealLinearOperator a = new HilbertMatrix(n);
163 final ConjugateGradient solver;
164 solver = new ConjugateGradient(maxIterations, 1E-15, true);
165 final RealVector r = new ArrayRealVector(n);
166 final RealVector x = new ArrayRealVector(n);
167 final IterationListener listener = new IterationListener() {
168
169 @Override
170 public void terminationPerformed(final IterationEvent e) {
171
172 }
173
174 @Override
175 public void iterationStarted(final IterationEvent e) {
176
177 }
178
179 @Override
180 public void iterationPerformed(final IterationEvent e) {
181 final IterativeLinearSolverEvent evt;
182 evt = (IterativeLinearSolverEvent) e;
183 RealVector v = evt.getResidual();
184 r.setSubVector(0, v);
185 v = evt.getSolution();
186 x.setSubVector(0, v);
187 }
188
189 @Override
190 public void initializationPerformed(final IterationEvent e) {
191
192 }
193 };
194 solver.getIterationManager().addIterationListener(listener);
195 final RealVector b = new ArrayRealVector(n);
196 for (int j = 0; j < n; j++) {
197 b.set(0.);
198 b.setEntry(j, 1.);
199
200 boolean caught = false;
201 try {
202 solver.solve(a, b);
203 } catch (MaxCountExceededException e) {
204 caught = true;
205 final RealVector y = a.operate(x);
206 for (int i = 0; i < n; i++) {
207 final double actual = b.getEntry(i) - y.getEntry(i);
208 final double expected = r.getEntry(i);
209 final double delta = 1E-6 * JdkMath.abs(expected);
210 final String msg = String
211 .format("column %d, residual %d", i, j);
212 Assert.assertEquals(msg, expected, actual, delta);
213 }
214 }
215 Assert
216 .assertTrue("MaxCountExceededException should have been caught",
217 caught);
218 }
219 }
220
221 @Test(expected = NonSquareOperatorException.class)
222 public void testNonSquarePreconditioner() {
223 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
224 final RealLinearOperator m = new RealLinearOperator() {
225
226 @Override
227 public RealVector operate(final RealVector x) {
228 throw new UnsupportedOperationException();
229 }
230
231 @Override
232 public int getRowDimension() {
233 return 2;
234 }
235
236 @Override
237 public int getColumnDimension() {
238 return 3;
239 }
240 };
241 final PreconditionedIterativeLinearSolver solver;
242 solver = new ConjugateGradient(10, 0d, false);
243 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
244 solver.solve(a, m, b);
245 }
246
247 @Test(expected = DimensionMismatchException.class)
248 public void testMismatchedOperatorDimensions() {
249 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
250 final RealLinearOperator m = new RealLinearOperator() {
251
252 @Override
253 public RealVector operate(final RealVector x) {
254 throw new UnsupportedOperationException();
255 }
256
257 @Override
258 public int getRowDimension() {
259 return 3;
260 }
261
262 @Override
263 public int getColumnDimension() {
264 return 3;
265 }
266 };
267 final PreconditionedIterativeLinearSolver solver;
268 solver = new ConjugateGradient(10, 0d, false);
269 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
270 solver.solve(a, m, b);
271 }
272
273 @Test(expected = NonPositiveDefiniteOperatorException.class)
274 public void testNonPositiveDefinitePreconditioner() {
275 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
276 a.setEntry(0, 0, 1d);
277 a.setEntry(0, 1, 2d);
278 a.setEntry(1, 0, 3d);
279 a.setEntry(1, 1, 4d);
280 final RealLinearOperator m = new RealLinearOperator() {
281
282 @Override
283 public RealVector operate(final RealVector x) {
284 final ArrayRealVector y = new ArrayRealVector(2);
285 y.setEntry(0, -x.getEntry(0));
286 y.setEntry(1, x.getEntry(1));
287 return y;
288 }
289
290 @Override
291 public int getRowDimension() {
292 return 2;
293 }
294
295 @Override
296 public int getColumnDimension() {
297 return 2;
298 }
299 };
300 final PreconditionedIterativeLinearSolver solver;
301 solver = new ConjugateGradient(10, 0d, true);
302 final ArrayRealVector b = new ArrayRealVector(2);
303 b.setEntry(0, -1d);
304 b.setEntry(1, -1d);
305 solver.solve(a, m, b);
306 }
307
308 @Test
309 public void testPreconditionedSolution() {
310 final int n = 8;
311 final int maxIterations = 100;
312 final RealLinearOperator a = new HilbertMatrix(n);
313 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
314 final RealLinearOperator m = JacobiPreconditioner.create(a);
315 final PreconditionedIterativeLinearSolver solver;
316 solver = new ConjugateGradient(maxIterations, 1E-15, true);
317 final RealVector b = new ArrayRealVector(n);
318 for (int j = 0; j < n; j++) {
319 b.set(0.);
320 b.setEntry(j, 1.);
321 final RealVector x = solver.solve(a, m, b);
322 for (int i = 0; i < n; i++) {
323 final double actual = x.getEntry(i);
324 final double expected = ainv.getEntry(i, j);
325 final double delta = 1E-6 * JdkMath.abs(expected);
326 final String msg = String.format("coefficient (%d, %d)", i, j);
327 Assert.assertEquals(msg, expected, actual, delta);
328 }
329 }
330 }
331
332 @Test
333 public void testPreconditionedResidual() {
334 final int n = 10;
335 final int maxIterations = n;
336 final RealLinearOperator a = new HilbertMatrix(n);
337 final RealLinearOperator m = JacobiPreconditioner.create(a);
338 final ConjugateGradient solver;
339 solver = new ConjugateGradient(maxIterations, 1E-15, true);
340 final RealVector r = new ArrayRealVector(n);
341 final RealVector x = new ArrayRealVector(n);
342 final IterationListener listener = new IterationListener() {
343
344 @Override
345 public void terminationPerformed(final IterationEvent e) {
346
347 }
348
349 @Override
350 public void iterationStarted(final IterationEvent e) {
351
352 }
353
354 @Override
355 public void iterationPerformed(final IterationEvent e) {
356 final IterativeLinearSolverEvent evt;
357 evt = (IterativeLinearSolverEvent) e;
358 RealVector v = evt.getResidual();
359 r.setSubVector(0, v);
360 v = evt.getSolution();
361 x.setSubVector(0, v);
362 }
363
364 @Override
365 public void initializationPerformed(final IterationEvent e) {
366
367 }
368 };
369 solver.getIterationManager().addIterationListener(listener);
370 final RealVector b = new ArrayRealVector(n);
371
372 for (int j = 0; j < n; j++) {
373 b.set(0.);
374 b.setEntry(j, 1.);
375
376 boolean caught = false;
377 try {
378 solver.solve(a, m, b);
379 } catch (MaxCountExceededException e) {
380 caught = true;
381 final RealVector y = a.operate(x);
382 for (int i = 0; i < n; i++) {
383 final double actual = b.getEntry(i) - y.getEntry(i);
384 final double expected = r.getEntry(i);
385 final double delta = 1E-6 * JdkMath.abs(expected);
386 final String msg = String.format("column %d, residual %d", i, j);
387 Assert.assertEquals(msg, expected, actual, delta);
388 }
389 }
390 Assert.assertTrue("MaxCountExceededException should have been caught", caught);
391 }
392 }
393
394 @Test
395 public void testPreconditionedSolution2() {
396 final int n = 100;
397 final int maxIterations = 100000;
398 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(n, n);
399 double daux = 1.;
400 for (int i = 0; i < n; i++) {
401 a.setEntry(i, i, daux);
402 daux *= 1.2;
403 for (int j = i + 1; j < n; j++) {
404 if (i != j) {
405 final double value = 1.0;
406 a.setEntry(i, j, value);
407 a.setEntry(j, i, value);
408 }
409 }
410 }
411 final RealLinearOperator m = JacobiPreconditioner.create(a);
412 final PreconditionedIterativeLinearSolver pcg;
413 final IterativeLinearSolver cg;
414 pcg = new ConjugateGradient(maxIterations, 1E-6, true);
415 cg = new ConjugateGradient(maxIterations, 1E-6, true);
416 final RealVector b = new ArrayRealVector(n);
417 final String pattern = "preconditioned gradient (%d iterations) should"
418 + " have been faster than unpreconditioned (%d iterations)";
419 String msg;
420 for (int j = 0; j < 1; j++) {
421 b.set(0.);
422 b.setEntry(j, 1.);
423 final RealVector px = pcg.solve(a, m, b);
424 final RealVector x = cg.solve(a, b);
425 final int npcg = pcg.getIterationManager().getIterations();
426 final int ncg = cg.getIterationManager().getIterations();
427 msg = String.format(pattern, npcg, ncg);
428 Assert.assertTrue(msg, npcg < ncg);
429 for (int i = 0; i < n; i++) {
430 msg = String.format("row %d, column %d", i, j);
431 final double expected = x.getEntry(i);
432 final double actual = px.getEntry(i);
433 final double delta = 1E-6 * JdkMath.abs(expected);
434 Assert.assertEquals(msg, expected, actual, delta);
435 }
436 }
437 }
438
439 @Test
440 public void testEventManagement() {
441 final int n = 5;
442 final int maxIterations = 100;
443 final RealLinearOperator a = new HilbertMatrix(n);
444 final IterativeLinearSolver solver;
445
446
447
448
449
450
451 final int[] count = new int[] {0, 0, 0, 0};
452 final IterationListener listener = new IterationListener() {
453 private void doTestVectorsAreUnmodifiable(final IterationEvent e) {
454 final IterativeLinearSolverEvent evt;
455 evt = (IterativeLinearSolverEvent) e;
456 try {
457 evt.getResidual().set(0.0);
458 Assert.fail("r is modifiable");
459 } catch (MathUnsupportedOperationException exc){
460
461 }
462 try {
463 evt.getRightHandSideVector().set(0.0);
464 Assert.fail("b is modifiable");
465 } catch (MathUnsupportedOperationException exc){
466
467 }
468 try {
469 evt.getSolution().set(0.0);
470 Assert.fail("x is modifiable");
471 } catch (MathUnsupportedOperationException exc){
472
473 }
474 }
475
476 @Override
477 public void initializationPerformed(final IterationEvent e) {
478 ++count[0];
479 doTestVectorsAreUnmodifiable(e);
480 }
481
482 @Override
483 public void iterationPerformed(final IterationEvent e) {
484 ++count[2];
485 Assert.assertEquals("iteration performed",
486 count[2], e.getIterations() - 1);
487 doTestVectorsAreUnmodifiable(e);
488 }
489
490 @Override
491 public void iterationStarted(final IterationEvent e) {
492 ++count[1];
493 Assert.assertEquals("iteration started",
494 count[1], e.getIterations() - 1);
495 doTestVectorsAreUnmodifiable(e);
496 }
497
498 @Override
499 public void terminationPerformed(final IterationEvent e) {
500 ++count[3];
501 doTestVectorsAreUnmodifiable(e);
502 }
503 };
504 solver = new ConjugateGradient(maxIterations, 1E-10, true);
505 solver.getIterationManager().addIterationListener(listener);
506 final RealVector b = new ArrayRealVector(n);
507 for (int j = 0; j < n; j++) {
508 Arrays.fill(count, 0);
509 b.set(0.);
510 b.setEntry(j, 1.);
511 solver.solve(a, b);
512 String msg = String.format("column %d (initialization)", j);
513 Assert.assertEquals(msg, 1, count[0]);
514 msg = String.format("column %d (finalization)", j);
515 Assert.assertEquals(msg, 1, count[3]);
516 }
517 }
518
519 @Test
520 public void testUnpreconditionedNormOfResidual() {
521 final int n = 5;
522 final int maxIterations = 100;
523 final RealLinearOperator a = new HilbertMatrix(n);
524 final IterativeLinearSolver solver;
525 final IterationListener listener = new IterationListener() {
526
527 private void doTestNormOfResidual(final IterationEvent e) {
528 final IterativeLinearSolverEvent evt;
529 evt = (IterativeLinearSolverEvent) e;
530 final RealVector x = evt.getSolution();
531 final RealVector b = evt.getRightHandSideVector();
532 final RealVector r = b.subtract(a.operate(x));
533 final double rnorm = r.getNorm();
534 Assert.assertEquals("iteration performed (residual)",
535 rnorm, evt.getNormOfResidual(),
536 JdkMath.max(1E-5 * rnorm, 1E-10));
537 }
538
539 @Override
540 public void initializationPerformed(final IterationEvent e) {
541 doTestNormOfResidual(e);
542 }
543
544 @Override
545 public void iterationPerformed(final IterationEvent e) {
546 doTestNormOfResidual(e);
547 }
548
549 @Override
550 public void iterationStarted(final IterationEvent e) {
551 doTestNormOfResidual(e);
552 }
553
554 @Override
555 public void terminationPerformed(final IterationEvent e) {
556 doTestNormOfResidual(e);
557 }
558 };
559 solver = new ConjugateGradient(maxIterations, 1E-10, true);
560 solver.getIterationManager().addIterationListener(listener);
561 final RealVector b = new ArrayRealVector(n);
562 for (int j = 0; j < n; j++) {
563 b.set(0.);
564 b.setEntry(j, 1.);
565 solver.solve(a, b);
566 }
567 }
568
569 @Test
570 public void testPreconditionedNormOfResidual() {
571 final int n = 5;
572 final int maxIterations = 100;
573 final RealLinearOperator a = new HilbertMatrix(n);
574 final RealLinearOperator m = JacobiPreconditioner.create(a);
575 final PreconditionedIterativeLinearSolver solver;
576 final IterationListener listener = new IterationListener() {
577
578 private void doTestNormOfResidual(final IterationEvent e) {
579 final IterativeLinearSolverEvent evt;
580 evt = (IterativeLinearSolverEvent) e;
581 final RealVector x = evt.getSolution();
582 final RealVector b = evt.getRightHandSideVector();
583 final RealVector r = b.subtract(a.operate(x));
584 final double rnorm = r.getNorm();
585 Assert.assertEquals("iteration performed (residual)",
586 rnorm, evt.getNormOfResidual(),
587 JdkMath.max(1E-5 * rnorm, 1E-10));
588 }
589
590 @Override
591 public void initializationPerformed(final IterationEvent e) {
592 doTestNormOfResidual(e);
593 }
594
595 @Override
596 public void iterationPerformed(final IterationEvent e) {
597 doTestNormOfResidual(e);
598 }
599
600 @Override
601 public void iterationStarted(final IterationEvent e) {
602 doTestNormOfResidual(e);
603 }
604
605 @Override
606 public void terminationPerformed(final IterationEvent e) {
607 doTestNormOfResidual(e);
608 }
609 };
610 solver = new ConjugateGradient(maxIterations, 1E-10, true);
611 solver.getIterationManager().addIterationListener(listener);
612 final RealVector b = new ArrayRealVector(n);
613 for (int j = 0; j < n; j++) {
614 b.set(0.);
615 b.setEntry(j, 1.);
616 solver.solve(a, m, b);
617 }
618 }
619 }