/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.math4.legacy.ode.nonstiff;
19  
20  
21  import org.apache.commons.math4.legacy.core.Field;
22  import org.apache.commons.math4.legacy.core.RealFieldElement;
23  import org.apache.commons.math4.legacy.ode.AbstractIntegrator;
24  import org.apache.commons.math4.legacy.ode.EquationsMapper;
25  import org.apache.commons.math4.legacy.ode.ExpandableStatefulODE;
26  import org.apache.commons.math4.legacy.ode.FieldEquationsMapper;
27  import org.apache.commons.math4.legacy.ode.FieldExpandableODE;
28  import org.apache.commons.math4.legacy.ode.FirstOrderFieldDifferentialEquations;
29  import org.apache.commons.math4.legacy.ode.FieldODEStateAndDerivative;
30  import org.apache.commons.math4.legacy.ode.sampling.AbstractFieldStepInterpolator;
31  import org.apache.commons.math4.core.jdkmath.JdkMath;
32  import org.apache.commons.math4.legacy.core.MathArrays;
33  import org.junit.Assert;
34  import org.junit.Test;
35  
36  public abstract class RungeKuttaFieldStepInterpolatorAbstractTest {
37  
38      protected abstract <T extends RealFieldElement<T>> RungeKuttaFieldStepInterpolator<T>
39          createInterpolator(Field<T> field, boolean forward, T[][] yDotK,
40                             FieldODEStateAndDerivative<T> globalPreviousState,
41                             FieldODEStateAndDerivative<T> globalCurrentState,
42                             FieldODEStateAndDerivative<T> softPreviousState,
43                             FieldODEStateAndDerivative<T> softCurrentState,
44                             FieldEquationsMapper<T> mapper);
45  
46      protected abstract <T extends RealFieldElement<T>> FieldButcherArrayProvider<T>
47          createButcherArrayProvider(Field<T> field);
48  
49      @Test
50      public abstract void interpolationAtBounds();
51  
52      protected <T extends RealFieldElement<T>> void doInterpolationAtBounds(final Field<T> field, double epsilon) {
53  
54          RungeKuttaFieldStepInterpolator<T> interpolator = setUpInterpolator(field,
55                                                                              new SinCos<>(field),
56                                                                              0.0, new double[] { 0.0, 1.0 }, 0.125);
57  
58          Assert.assertEquals(0.0, interpolator.getPreviousState().getTime().getReal(), 1.0e-15);
59          for (int i = 0; i < 2; ++i) {
60              Assert.assertEquals(interpolator.getPreviousState().getState()[i].getReal(),
61                                  interpolator.getInterpolatedState(interpolator.getPreviousState().getTime()).getState()[i].getReal(),
62                                  epsilon);
63          }
64          Assert.assertEquals(0.125, interpolator.getCurrentState().getTime().getReal(), 1.0e-15);
65          for (int i = 0; i < 2; ++i) {
66              Assert.assertEquals(interpolator.getCurrentState().getState()[i].getReal(),
67                                  interpolator.getInterpolatedState(interpolator.getCurrentState().getTime()).getState()[i].getReal(),
68                                  epsilon);
69          }
70      }
71  
72      @Test
73      public abstract void interpolationInside();
74  
75      protected <T extends RealFieldElement<T>> void doInterpolationInside(final Field<T> field,
76                                                                           double epsilonSin, double epsilonCos) {
77  
78          RungeKuttaFieldStepInterpolator<T> interpolator = setUpInterpolator(field,
79                                                                              new SinCos<>(field),
80                                                                              0.0, new double[] { 0.0, 1.0 }, 0.0125);
81  
82          int n = 100;
83          double maxErrorSin = 0;
84          double maxErrorCos = 0;
85          for (int i = 0; i <= n; ++i) {
86              T t =     interpolator.getPreviousState().getTime().multiply(n - i).
87                    add(interpolator.getCurrentState().getTime().multiply(i)).
88                    divide(n);
89              FieldODEStateAndDerivative<T> state = interpolator.getInterpolatedState(t);
90              maxErrorSin = JdkMath.max(maxErrorSin, state.getState()[0].subtract(t.sin()).abs().getReal());
91              maxErrorCos = JdkMath.max(maxErrorCos, state.getState()[1].subtract(t.cos()).abs().getReal());
92          }
93          Assert.assertEquals(0.0, maxErrorSin, epsilonSin);
94          Assert.assertEquals(0.0, maxErrorCos, epsilonCos);
95      }
96  
97      @Test
98      public abstract void nonFieldInterpolatorConsistency();
99  
100     protected <T extends RealFieldElement<T>> void doNonFieldInterpolatorConsistency(final Field<T> field,
101                                                                                      double epsilonSin, double epsilonCos,
102                                                                                      double epsilonSinDot, double epsilonCosDot) {
103 
104         FirstOrderFieldDifferentialEquations<T> eqn = new SinCos<>(field);
105         RungeKuttaFieldStepInterpolator<T> fieldInterpolator =
106                         setUpInterpolator(field, eqn, 0.0, new double[] { 0.0, 1.0 }, 0.125);
107         RungeKuttaStepInterpolator regularInterpolator = convertInterpolator(fieldInterpolator, eqn);
108 
109         int n = 100;
110         double maxErrorSin    = 0;
111         double maxErrorCos    = 0;
112         double maxErrorSinDot = 0;
113         double maxErrorCosDot = 0;
114         for (int i = 0; i <= n; ++i) {
115 
116             T t =     fieldInterpolator.getPreviousState().getTime().multiply(n - i).
117                   add(fieldInterpolator.getCurrentState().getTime().multiply(i)).
118                   divide(n);
119 
120             FieldODEStateAndDerivative<T> state = fieldInterpolator.getInterpolatedState(t);
121             T[] fieldY    = state.getState();
122             T[] fieldYDot = state.getDerivative();
123 
124             regularInterpolator.setInterpolatedTime(t.getReal());
125             double[] regularY     = regularInterpolator.getInterpolatedState();
126             double[] regularYDot  = regularInterpolator.getInterpolatedDerivatives();
127 
128             maxErrorSin    = JdkMath.max(maxErrorSin,    fieldY[0].subtract(regularY[0]).abs().getReal());
129             maxErrorCos    = JdkMath.max(maxErrorCos,    fieldY[1].subtract(regularY[1]).abs().getReal());
130             maxErrorSinDot = JdkMath.max(maxErrorSinDot, fieldYDot[0].subtract(regularYDot[0]).abs().getReal());
131             maxErrorCosDot = JdkMath.max(maxErrorCosDot, fieldYDot[1].subtract(regularYDot[1]).abs().getReal());
132         }
133         Assert.assertEquals(0.0, maxErrorSin,    epsilonSin);
134         Assert.assertEquals(0.0, maxErrorCos,    epsilonCos);
135         Assert.assertEquals(0.0, maxErrorSinDot, epsilonSinDot);
136         Assert.assertEquals(0.0, maxErrorCosDot, epsilonCosDot);
137     }
138 
139     private <T extends RealFieldElement<T>>
140     RungeKuttaFieldStepInterpolator<T> setUpInterpolator(final Field<T> field,
141                                                          final FirstOrderFieldDifferentialEquations<T> eqn,
142                                                          final double t0, final double[] y0,
143                                                          final double t1) {
144 
145         // get the Butcher arrays from the field integrator
146         FieldButcherArrayProvider<T> provider = createButcherArrayProvider(field);
147         T[][] a = provider.getA();
148         T[]   b = provider.getB();
149         T[]   c = provider.getC();
150 
151         // store initial state
152         T     t          = field.getZero().add(t0);
153         T[]   fieldY     = MathArrays.buildArray(field, eqn.getDimension());
154         T[][] fieldYDotK = MathArrays.buildArray(field, b.length, -1);
155         for (int i = 0; i < y0.length; ++i) {
156             fieldY[i] = field.getZero().add(y0[i]);
157         }
158         fieldYDotK[0] = eqn.computeDerivatives(t, fieldY);
159         FieldODEStateAndDerivative<T> s0 = new FieldODEStateAndDerivative<>(t, fieldY, fieldYDotK[0]);
160 
161         // perform one integration step, in order to get consistent derivatives
162         T h = field.getZero().add(t1 - t0);
163         for (int k = 0; k < a.length; ++k) {
164             for (int i = 0; i < y0.length; ++i) {
165                 fieldY[i] = field.getZero().add(y0[i]);
166                 for (int s = 0; s <= k; ++s) {
167                     fieldY[i] = fieldY[i].add(h.multiply(a[k][s].multiply(fieldYDotK[s][i])));
168                 }
169             }
170             fieldYDotK[k + 1] = eqn.computeDerivatives(h.multiply(c[k]).add(t0), fieldY);
171         }
172 
173         // store state at step end
174         t = field.getZero().add(t1);
175         for (int i = 0; i < y0.length; ++i) {
176             fieldY[i] = field.getZero().add(y0[i]);
177             for (int s = 0; s < b.length; ++s) {
178                 fieldY[i] = fieldY[i].add(h.multiply(b[s].multiply(fieldYDotK[s][i])));
179             }
180         }
181         FieldODEStateAndDerivative<T> s1 = new FieldODEStateAndDerivative<>(t, fieldY,
182                                                                              eqn.computeDerivatives(t, fieldY));
183 
184         return createInterpolator(field, t1 > t0, fieldYDotK, s0, s1, s0, s1,
185                                   new FieldExpandableODE<>(eqn).getMapper());
186     }
187 
188     private <T extends RealFieldElement<T>>
189     RungeKuttaStepInterpolator convertInterpolator(final RungeKuttaFieldStepInterpolator<T> fieldInterpolator,
190                                                    final FirstOrderFieldDifferentialEquations<T> eqn) {
191 
192         RungeKuttaStepInterpolator regularInterpolator = null;
193         try {
194 
195             String interpolatorName = fieldInterpolator.getClass().getName();
196             String integratorName = interpolatorName.replaceAll("Field", "");
197             @SuppressWarnings("unchecked")
198             Class<RungeKuttaStepInterpolator> clz = (Class<RungeKuttaStepInterpolator>) Class.forName(integratorName);
199             regularInterpolator = clz.newInstance();
200 
201             double[][] yDotArray = null;
202             java.lang.reflect.Field fYD = RungeKuttaFieldStepInterpolator.class.getDeclaredField("yDotK");
203             fYD.setAccessible(true);
204             @SuppressWarnings("unchecked")
205             T[][] fieldYDotk = (T[][]) fYD.get(fieldInterpolator);
206             yDotArray = new double[fieldYDotk.length][];
207             for (int i = 0; i < yDotArray.length; ++i) {
208                 yDotArray[i] = new double[fieldYDotk[i].length];
209                 for (int j = 0; j < yDotArray[i].length; ++j) {
210                     yDotArray[i][j] = fieldYDotk[i][j].getReal();
211                 }
212             }
213             double[] y = new double[yDotArray[0].length];
214 
215             EquationsMapper primaryMapper = null;
216             EquationsMapper[] secondaryMappers = null;
217             java.lang.reflect.Field fMapper = AbstractFieldStepInterpolator.class.getDeclaredField("mapper");
218             fMapper.setAccessible(true);
219             @SuppressWarnings("unchecked")
220             FieldEquationsMapper<T> mapper = (FieldEquationsMapper<T>) fMapper.get(fieldInterpolator);
221             java.lang.reflect.Field fStart = FieldEquationsMapper.class.getDeclaredField("start");
222             fStart.setAccessible(true);
223             int[] start = (int[]) fStart.get(mapper);
224             primaryMapper = new EquationsMapper(start[0], start[1]);
225             secondaryMappers = new EquationsMapper[mapper.getNumberOfEquations() - 1];
226             for (int i = 0; i < secondaryMappers.length; ++i) {
227                 secondaryMappers[i] = new EquationsMapper(start[i + 1], start[i + 2]);
228             }
229 
230             AbstractIntegrator dummyIntegrator = new AbstractIntegrator("dummy") {
231                 @Override
232                 public void integrate(ExpandableStatefulODE equations, double t) {
233                     Assert.fail("this method should not be called");
234                 }
235                 @Override
236                 public void computeDerivatives(final double t, final double[] y, final double[] yDot) {
237                     T fieldT = fieldInterpolator.getCurrentState().getTime().getField().getZero().add(t);
238                     T[] fieldY = MathArrays.buildArray(fieldInterpolator.getCurrentState().getTime().getField(), y.length);
239                     for (int i = 0; i < y.length; ++i) {
240                         fieldY[i] = fieldInterpolator.getCurrentState().getTime().getField().getZero().add(y[i]);
241                     }
242                     T[] fieldYDot = eqn.computeDerivatives(fieldT, fieldY);
243                     for (int i = 0; i < yDot.length; ++i) {
244                         yDot[i] = fieldYDot[i].getReal();
245                     }
246                 }
247             };
248             regularInterpolator.reinitialize(dummyIntegrator, y, yDotArray,
249                                              fieldInterpolator.isForward(),
250                                              primaryMapper, secondaryMappers);
251 
252             T[] fieldPreviousY = fieldInterpolator.getPreviousState().getState();
253             for (int i = 0; i < y.length; ++i) {
254                 y[i] = fieldPreviousY[i].getReal();
255             }
256             regularInterpolator.storeTime(fieldInterpolator.getPreviousState().getTime().getReal());
257 
258             regularInterpolator.shift();
259 
260             T[] fieldCurrentY = fieldInterpolator.getCurrentState().getState();
261             for (int i = 0; i < y.length; ++i) {
262                 y[i] = fieldCurrentY[i].getReal();
263             }
264             regularInterpolator.storeTime(fieldInterpolator.getCurrentState().getTime().getReal());
265         } catch (ClassNotFoundException cnfe) {
266             Assert.fail(cnfe.getLocalizedMessage());
267         } catch (InstantiationException ie) {
268             Assert.fail(ie.getLocalizedMessage());
269         } catch (IllegalAccessException iae) {
270             Assert.fail(iae.getLocalizedMessage());
271         } catch (NoSuchFieldException nsfe) {
272             Assert.fail(nsfe.getLocalizedMessage());
273         } catch (IllegalArgumentException iae) {
274             Assert.fail(iae.getLocalizedMessage());
275         }
276 
277         return regularInterpolator;
278     }
279 
280     private static final class SinCos<T extends RealFieldElement<T>> implements FirstOrderFieldDifferentialEquations<T> {
281         private final Field<T> field;
282         protected SinCos(final Field<T> field) {
283             this.field = field;
284         }
285         @Override
286         public int getDimension() {
287             return 2;
288         }
289         @Override
290         public void init(final T t0, final T[] y0, final T finalTime) {
291         }
292         @Override
293         public T[] computeDerivatives(final T t, final T[] y) {
294             T[] yDot = MathArrays.buildArray(field, 2);
295             yDot[0] = y[1];
296             yDot[1] = y[0].negate();
297             return yDot;
298         }
299     }
300 }