1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ode.nonstiff;
19
20
21 import org.apache.commons.math4.legacy.core.Field;
22 import org.apache.commons.math4.legacy.core.RealFieldElement;
23 import org.apache.commons.math4.legacy.ode.AbstractIntegrator;
24 import org.apache.commons.math4.legacy.ode.EquationsMapper;
25 import org.apache.commons.math4.legacy.ode.ExpandableStatefulODE;
26 import org.apache.commons.math4.legacy.ode.FieldEquationsMapper;
27 import org.apache.commons.math4.legacy.ode.FieldExpandableODE;
28 import org.apache.commons.math4.legacy.ode.FirstOrderFieldDifferentialEquations;
29 import org.apache.commons.math4.legacy.ode.FieldODEStateAndDerivative;
30 import org.apache.commons.math4.legacy.ode.sampling.AbstractFieldStepInterpolator;
31 import org.apache.commons.math4.core.jdkmath.JdkMath;
32 import org.apache.commons.math4.legacy.core.MathArrays;
33 import org.junit.Assert;
34 import org.junit.Test;
35
36 public abstract class RungeKuttaFieldStepInterpolatorAbstractTest {
37
38 protected abstract <T extends RealFieldElement<T>> RungeKuttaFieldStepInterpolator<T>
39 createInterpolator(Field<T> field, boolean forward, T[][] yDotK,
40 FieldODEStateAndDerivative<T> globalPreviousState,
41 FieldODEStateAndDerivative<T> globalCurrentState,
42 FieldODEStateAndDerivative<T> softPreviousState,
43 FieldODEStateAndDerivative<T> softCurrentState,
44 FieldEquationsMapper<T> mapper);
45
46 protected abstract <T extends RealFieldElement<T>> FieldButcherArrayProvider<T>
47 createButcherArrayProvider(Field<T> field);
48
49 @Test
50 public abstract void interpolationAtBounds();
51
52 protected <T extends RealFieldElement<T>> void doInterpolationAtBounds(final Field<T> field, double epsilon) {
53
54 RungeKuttaFieldStepInterpolator<T> interpolator = setUpInterpolator(field,
55 new SinCos<>(field),
56 0.0, new double[] { 0.0, 1.0 }, 0.125);
57
58 Assert.assertEquals(0.0, interpolator.getPreviousState().getTime().getReal(), 1.0e-15);
59 for (int i = 0; i < 2; ++i) {
60 Assert.assertEquals(interpolator.getPreviousState().getState()[i].getReal(),
61 interpolator.getInterpolatedState(interpolator.getPreviousState().getTime()).getState()[i].getReal(),
62 epsilon);
63 }
64 Assert.assertEquals(0.125, interpolator.getCurrentState().getTime().getReal(), 1.0e-15);
65 for (int i = 0; i < 2; ++i) {
66 Assert.assertEquals(interpolator.getCurrentState().getState()[i].getReal(),
67 interpolator.getInterpolatedState(interpolator.getCurrentState().getTime()).getState()[i].getReal(),
68 epsilon);
69 }
70 }
71
72 @Test
73 public abstract void interpolationInside();
74
75 protected <T extends RealFieldElement<T>> void doInterpolationInside(final Field<T> field,
76 double epsilonSin, double epsilonCos) {
77
78 RungeKuttaFieldStepInterpolator<T> interpolator = setUpInterpolator(field,
79 new SinCos<>(field),
80 0.0, new double[] { 0.0, 1.0 }, 0.0125);
81
82 int n = 100;
83 double maxErrorSin = 0;
84 double maxErrorCos = 0;
85 for (int i = 0; i <= n; ++i) {
86 T t = interpolator.getPreviousState().getTime().multiply(n - i).
87 add(interpolator.getCurrentState().getTime().multiply(i)).
88 divide(n);
89 FieldODEStateAndDerivative<T> state = interpolator.getInterpolatedState(t);
90 maxErrorSin = JdkMath.max(maxErrorSin, state.getState()[0].subtract(t.sin()).abs().getReal());
91 maxErrorCos = JdkMath.max(maxErrorCos, state.getState()[1].subtract(t.cos()).abs().getReal());
92 }
93 Assert.assertEquals(0.0, maxErrorSin, epsilonSin);
94 Assert.assertEquals(0.0, maxErrorCos, epsilonCos);
95 }
96
97 @Test
98 public abstract void nonFieldInterpolatorConsistency();
99
100 protected <T extends RealFieldElement<T>> void doNonFieldInterpolatorConsistency(final Field<T> field,
101 double epsilonSin, double epsilonCos,
102 double epsilonSinDot, double epsilonCosDot) {
103
104 FirstOrderFieldDifferentialEquations<T> eqn = new SinCos<>(field);
105 RungeKuttaFieldStepInterpolator<T> fieldInterpolator =
106 setUpInterpolator(field, eqn, 0.0, new double[] { 0.0, 1.0 }, 0.125);
107 RungeKuttaStepInterpolator regularInterpolator = convertInterpolator(fieldInterpolator, eqn);
108
109 int n = 100;
110 double maxErrorSin = 0;
111 double maxErrorCos = 0;
112 double maxErrorSinDot = 0;
113 double maxErrorCosDot = 0;
114 for (int i = 0; i <= n; ++i) {
115
116 T t = fieldInterpolator.getPreviousState().getTime().multiply(n - i).
117 add(fieldInterpolator.getCurrentState().getTime().multiply(i)).
118 divide(n);
119
120 FieldODEStateAndDerivative<T> state = fieldInterpolator.getInterpolatedState(t);
121 T[] fieldY = state.getState();
122 T[] fieldYDot = state.getDerivative();
123
124 regularInterpolator.setInterpolatedTime(t.getReal());
125 double[] regularY = regularInterpolator.getInterpolatedState();
126 double[] regularYDot = regularInterpolator.getInterpolatedDerivatives();
127
128 maxErrorSin = JdkMath.max(maxErrorSin, fieldY[0].subtract(regularY[0]).abs().getReal());
129 maxErrorCos = JdkMath.max(maxErrorCos, fieldY[1].subtract(regularY[1]).abs().getReal());
130 maxErrorSinDot = JdkMath.max(maxErrorSinDot, fieldYDot[0].subtract(regularYDot[0]).abs().getReal());
131 maxErrorCosDot = JdkMath.max(maxErrorCosDot, fieldYDot[1].subtract(regularYDot[1]).abs().getReal());
132 }
133 Assert.assertEquals(0.0, maxErrorSin, epsilonSin);
134 Assert.assertEquals(0.0, maxErrorCos, epsilonCos);
135 Assert.assertEquals(0.0, maxErrorSinDot, epsilonSinDot);
136 Assert.assertEquals(0.0, maxErrorCosDot, epsilonCosDot);
137 }
138
139 private <T extends RealFieldElement<T>>
140 RungeKuttaFieldStepInterpolator<T> setUpInterpolator(final Field<T> field,
141 final FirstOrderFieldDifferentialEquations<T> eqn,
142 final double t0, final double[] y0,
143 final double t1) {
144
145
146 FieldButcherArrayProvider<T> provider = createButcherArrayProvider(field);
147 T[][] a = provider.getA();
148 T[] b = provider.getB();
149 T[] c = provider.getC();
150
151
152 T t = field.getZero().add(t0);
153 T[] fieldY = MathArrays.buildArray(field, eqn.getDimension());
154 T[][] fieldYDotK = MathArrays.buildArray(field, b.length, -1);
155 for (int i = 0; i < y0.length; ++i) {
156 fieldY[i] = field.getZero().add(y0[i]);
157 }
158 fieldYDotK[0] = eqn.computeDerivatives(t, fieldY);
159 FieldODEStateAndDerivative<T> s0 = new FieldODEStateAndDerivative<>(t, fieldY, fieldYDotK[0]);
160
161
162 T h = field.getZero().add(t1 - t0);
163 for (int k = 0; k < a.length; ++k) {
164 for (int i = 0; i < y0.length; ++i) {
165 fieldY[i] = field.getZero().add(y0[i]);
166 for (int s = 0; s <= k; ++s) {
167 fieldY[i] = fieldY[i].add(h.multiply(a[k][s].multiply(fieldYDotK[s][i])));
168 }
169 }
170 fieldYDotK[k + 1] = eqn.computeDerivatives(h.multiply(c[k]).add(t0), fieldY);
171 }
172
173
174 t = field.getZero().add(t1);
175 for (int i = 0; i < y0.length; ++i) {
176 fieldY[i] = field.getZero().add(y0[i]);
177 for (int s = 0; s < b.length; ++s) {
178 fieldY[i] = fieldY[i].add(h.multiply(b[s].multiply(fieldYDotK[s][i])));
179 }
180 }
181 FieldODEStateAndDerivative<T> s1 = new FieldODEStateAndDerivative<>(t, fieldY,
182 eqn.computeDerivatives(t, fieldY));
183
184 return createInterpolator(field, t1 > t0, fieldYDotK, s0, s1, s0, s1,
185 new FieldExpandableODE<>(eqn).getMapper());
186 }
187
188 private <T extends RealFieldElement<T>>
189 RungeKuttaStepInterpolator convertInterpolator(final RungeKuttaFieldStepInterpolator<T> fieldInterpolator,
190 final FirstOrderFieldDifferentialEquations<T> eqn) {
191
192 RungeKuttaStepInterpolator regularInterpolator = null;
193 try {
194
195 String interpolatorName = fieldInterpolator.getClass().getName();
196 String integratorName = interpolatorName.replaceAll("Field", "");
197 @SuppressWarnings("unchecked")
198 Class<RungeKuttaStepInterpolator> clz = (Class<RungeKuttaStepInterpolator>) Class.forName(integratorName);
199 regularInterpolator = clz.newInstance();
200
201 double[][] yDotArray = null;
202 java.lang.reflect.Field fYD = RungeKuttaFieldStepInterpolator.class.getDeclaredField("yDotK");
203 fYD.setAccessible(true);
204 @SuppressWarnings("unchecked")
205 T[][] fieldYDotk = (T[][]) fYD.get(fieldInterpolator);
206 yDotArray = new double[fieldYDotk.length][];
207 for (int i = 0; i < yDotArray.length; ++i) {
208 yDotArray[i] = new double[fieldYDotk[i].length];
209 for (int j = 0; j < yDotArray[i].length; ++j) {
210 yDotArray[i][j] = fieldYDotk[i][j].getReal();
211 }
212 }
213 double[] y = new double[yDotArray[0].length];
214
215 EquationsMapper primaryMapper = null;
216 EquationsMapper[] secondaryMappers = null;
217 java.lang.reflect.Field fMapper = AbstractFieldStepInterpolator.class.getDeclaredField("mapper");
218 fMapper.setAccessible(true);
219 @SuppressWarnings("unchecked")
220 FieldEquationsMapper<T> mapper = (FieldEquationsMapper<T>) fMapper.get(fieldInterpolator);
221 java.lang.reflect.Field fStart = FieldEquationsMapper.class.getDeclaredField("start");
222 fStart.setAccessible(true);
223 int[] start = (int[]) fStart.get(mapper);
224 primaryMapper = new EquationsMapper(start[0], start[1]);
225 secondaryMappers = new EquationsMapper[mapper.getNumberOfEquations() - 1];
226 for (int i = 0; i < secondaryMappers.length; ++i) {
227 secondaryMappers[i] = new EquationsMapper(start[i + 1], start[i + 2]);
228 }
229
230 AbstractIntegrator dummyIntegrator = new AbstractIntegrator("dummy") {
231 @Override
232 public void integrate(ExpandableStatefulODE equations, double t) {
233 Assert.fail("this method should not be called");
234 }
235 @Override
236 public void computeDerivatives(final double t, final double[] y, final double[] yDot) {
237 T fieldT = fieldInterpolator.getCurrentState().getTime().getField().getZero().add(t);
238 T[] fieldY = MathArrays.buildArray(fieldInterpolator.getCurrentState().getTime().getField(), y.length);
239 for (int i = 0; i < y.length; ++i) {
240 fieldY[i] = fieldInterpolator.getCurrentState().getTime().getField().getZero().add(y[i]);
241 }
242 T[] fieldYDot = eqn.computeDerivatives(fieldT, fieldY);
243 for (int i = 0; i < yDot.length; ++i) {
244 yDot[i] = fieldYDot[i].getReal();
245 }
246 }
247 };
248 regularInterpolator.reinitialize(dummyIntegrator, y, yDotArray,
249 fieldInterpolator.isForward(),
250 primaryMapper, secondaryMappers);
251
252 T[] fieldPreviousY = fieldInterpolator.getPreviousState().getState();
253 for (int i = 0; i < y.length; ++i) {
254 y[i] = fieldPreviousY[i].getReal();
255 }
256 regularInterpolator.storeTime(fieldInterpolator.getPreviousState().getTime().getReal());
257
258 regularInterpolator.shift();
259
260 T[] fieldCurrentY = fieldInterpolator.getCurrentState().getState();
261 for (int i = 0; i < y.length; ++i) {
262 y[i] = fieldCurrentY[i].getReal();
263 }
264 regularInterpolator.storeTime(fieldInterpolator.getCurrentState().getTime().getReal());
265 } catch (ClassNotFoundException cnfe) {
266 Assert.fail(cnfe.getLocalizedMessage());
267 } catch (InstantiationException ie) {
268 Assert.fail(ie.getLocalizedMessage());
269 } catch (IllegalAccessException iae) {
270 Assert.fail(iae.getLocalizedMessage());
271 } catch (NoSuchFieldException nsfe) {
272 Assert.fail(nsfe.getLocalizedMessage());
273 } catch (IllegalArgumentException iae) {
274 Assert.fail(iae.getLocalizedMessage());
275 }
276
277 return regularInterpolator;
278 }
279
280 private static final class SinCos<T extends RealFieldElement<T>> implements FirstOrderFieldDifferentialEquations<T> {
281 private final Field<T> field;
282 protected SinCos(final Field<T> field) {
283 this.field = field;
284 }
285 @Override
286 public int getDimension() {
287 return 2;
288 }
289 @Override
290 public void init(final T t0, final T[] y0, final T finalTime) {
291 }
292 @Override
293 public T[] computeDerivatives(final T t, final T[] y) {
294 T[] yDot = MathArrays.buildArray(field, 2);
295 yDot[0] = y[1];
296 yDot[1] = y[0].negate();
297 return yDot;
298 }
299 }
300 }