1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ode;
19
20 import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
21 import org.apache.commons.math4.legacy.ode.sampling.StepHandler;
22 import org.apache.commons.math4.legacy.ode.sampling.StepInterpolator;
23 import org.apache.commons.math4.core.jdkmath.JdkMath;
24
25
26
27
28
29 public class TestProblemHandler
30 implements StepHandler {
31
32
33 private TestProblemAbstract problem;
34
35
36 private double maxValueError;
37 private double maxTimeError;
38
39
40 private double lastError;
41
42
43 private double lastTime;
44
45
46 private ODEIntegrator integrator;
47
48
49 private double expectedStepStart;
50
51
52
53
54
55
56 public TestProblemHandler(TestProblemAbstract problem, ODEIntegrator integrator) {
57 this.problem = problem;
58 this.integrator = integrator;
59 maxValueError = 0;
60 maxTimeError = 0;
61 lastError = 0;
62 expectedStepStart = Double.NaN;
63 }
64
65 @Override
66 public void init(double t0, double[] y0, double t) {
67 maxValueError = 0;
68 maxTimeError = 0;
69 lastError = 0;
70 expectedStepStart = Double.NaN;
71 }
72
73 @Override
74 public void handleStep(StepInterpolator interpolator, boolean isLast) throws MaxCountExceededException {
75
76 double start = integrator.getCurrentStepStart();
77 if (JdkMath.abs((start - problem.getInitialTime()) / integrator.getCurrentSignedStepsize()) > 0.001) {
78
79
80 if (!Double.isNaN(expectedStepStart)) {
81
82
83 double stepError = JdkMath.max(maxTimeError, JdkMath.abs(start - expectedStepStart));
84 for (double eventTime : problem.getTheoreticalEventsTimes()) {
85 stepError = JdkMath.min(stepError, JdkMath.abs(start - eventTime));
86 }
87 maxTimeError = JdkMath.max(maxTimeError, stepError);
88 }
89 expectedStepStart = start + integrator.getCurrentSignedStepsize();
90 }
91
92
93 double pT = interpolator.getPreviousTime();
94 double cT = interpolator.getCurrentTime();
95 double[] errorScale = problem.getErrorScale();
96
97
98 if (isLast) {
99 double[] interpolatedY = interpolator.getInterpolatedState();
100 double[] theoreticalY = problem.computeTheoreticalState(cT);
101 for (int i = 0; i < interpolatedY.length; ++i) {
102 double error = JdkMath.abs(interpolatedY[i] - theoreticalY[i]);
103 lastError = JdkMath.max(error, lastError);
104 }
105 lastTime = cT;
106 }
107
108 for (int k = 0; k <= 20; ++k) {
109
110 double time = pT + (k * (cT - pT)) / 20;
111 interpolator.setInterpolatedTime(time);
112 double[] interpolatedY = interpolator.getInterpolatedState();
113 double[] theoreticalY = problem.computeTheoreticalState(interpolator.getInterpolatedTime());
114
115
116 for (int i = 0; i < interpolatedY.length; ++i) {
117 double error = errorScale[i] * JdkMath.abs(interpolatedY[i] - theoreticalY[i]);
118 maxValueError = JdkMath.max(error, maxValueError);
119 }
120 }
121 }
122
123
124
125
126
127 public double getMaximalValueError() {
128 return maxValueError;
129 }
130
131
132
133
134
135 public double getMaximalTimeError() {
136 return maxTimeError;
137 }
138
139
140 public int getCalls() {
141 return problem.getCalls();
142 }
143
144
145
146
147
148 public double getLastError() {
149 return lastError;
150 }
151
152
153
154
155
156 public double getLastTime() {
157 return lastTime;
158 }
159 }