1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ode.nonstiff;
19
20
21 import org.apache.commons.math4.legacy.core.Field;
22 import org.apache.commons.math4.legacy.core.RealFieldElement;
23 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24 import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
25 import org.apache.commons.math4.legacy.exception.NoBracketingException;
26 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
27 import org.apache.commons.math4.legacy.ode.AbstractFieldIntegrator;
28 import org.apache.commons.math4.legacy.ode.FieldEquationsMapper;
29 import org.apache.commons.math4.legacy.ode.FieldExpandableODE;
30 import org.apache.commons.math4.legacy.ode.FirstOrderFieldDifferentialEquations;
31 import org.apache.commons.math4.legacy.ode.FieldODEState;
32 import org.apache.commons.math4.legacy.ode.FieldODEStateAndDerivative;
33 import org.apache.commons.math4.legacy.core.MathArrays;
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59 public abstract class RungeKuttaFieldIntegrator<T extends RealFieldElement<T>>
60 extends AbstractFieldIntegrator<T>
61 implements FieldButcherArrayProvider<T> {
62
63
64 private final T[] c;
65
66
67 private final T[][] a;
68
69
70 private final T[] b;
71
72
73 private final T step;
74
75
76
77
78
79
80
81
82 protected RungeKuttaFieldIntegrator(final Field<T> field, final String name, final T step) {
83 super(field, name);
84 this.c = getC();
85 this.a = getA();
86 this.b = getB();
87 this.step = step.abs();
88 }
89
90
91
92
93
94
95 protected T fraction(final int p, final int q) {
96 return getField().getZero().add(p).divide(q);
97 }
98
99
100
101
102
103
104
105
106
107 protected abstract RungeKuttaFieldStepInterpolator<T> createInterpolator(boolean forward, T[][] yDotK,
108 FieldODEStateAndDerivative<T> globalPreviousState,
109 FieldODEStateAndDerivative<T> globalCurrentState,
110 FieldEquationsMapper<T> mapper);
111
112
113 @Override
114 public FieldODEStateAndDerivative<T> integrate(final FieldExpandableODE<T> equations,
115 final FieldODEState<T> initialState, final T finalTime)
116 throws NumberIsTooSmallException, DimensionMismatchException,
117 MaxCountExceededException, NoBracketingException {
118
119 sanityChecks(initialState, finalTime);
120 final T t0 = initialState.getTime();
121 final T[] y0 = equations.getMapper().mapState(initialState);
122 setStepStart(initIntegration(equations, t0, y0, finalTime));
123 final boolean forward = finalTime.subtract(initialState.getTime()).getReal() > 0;
124
125
126 final int stages = c.length + 1;
127 T[] y = y0;
128 final T[][] yDotK = MathArrays.buildArray(getField(), stages, -1);
129 final T[] yTmp = MathArrays.buildArray(getField(), y0.length);
130
131
132 if (forward) {
133 if (getStepStart().getTime().add(step).subtract(finalTime).getReal() >= 0) {
134 setStepSize(finalTime.subtract(getStepStart().getTime()));
135 } else {
136 setStepSize(step);
137 }
138 } else {
139 if (getStepStart().getTime().subtract(step).subtract(finalTime).getReal() <= 0) {
140 setStepSize(finalTime.subtract(getStepStart().getTime()));
141 } else {
142 setStepSize(step.negate());
143 }
144 }
145
146
147 setIsLastStep(false);
148 do {
149
150
151 y = equations.getMapper().mapState(getStepStart());
152 yDotK[0] = equations.getMapper().mapDerivative(getStepStart());
153
154
155 for (int k = 1; k < stages; ++k) {
156
157 for (int j = 0; j < y0.length; ++j) {
158 T sum = yDotK[0][j].multiply(a[k-1][0]);
159 for (int l = 1; l < k; ++l) {
160 sum = sum.add(yDotK[l][j].multiply(a[k-1][l]));
161 }
162 yTmp[j] = y[j].add(getStepSize().multiply(sum));
163 }
164
165 yDotK[k] = computeDerivatives(getStepStart().getTime().add(getStepSize().multiply(c[k-1])), yTmp);
166 }
167
168
169 for (int j = 0; j < y0.length; ++j) {
170 T sum = yDotK[0][j].multiply(b[0]);
171 for (int l = 1; l < stages; ++l) {
172 sum = sum.add(yDotK[l][j].multiply(b[l]));
173 }
174 yTmp[j] = y[j].add(getStepSize().multiply(sum));
175 }
176 final T stepEnd = getStepStart().getTime().add(getStepSize());
177 final T[] yDotTmp = computeDerivatives(stepEnd, yTmp);
178 final FieldODEStateAndDerivative<T> stateTmp = new FieldODEStateAndDerivative<>(stepEnd, yTmp, yDotTmp);
179
180
181 System.arraycopy(yTmp, 0, y, 0, y0.length);
182 setStepStart(acceptStep(createInterpolator(forward, yDotK, getStepStart(), stateTmp, equations.getMapper()),
183 finalTime));
184
185 if (!isLastStep()) {
186
187
188 final T nextT = getStepStart().getTime().add(getStepSize());
189 final boolean nextIsLast = forward ?
190 (nextT.subtract(finalTime).getReal() >= 0) :
191 (nextT.subtract(finalTime).getReal() <= 0);
192 if (nextIsLast) {
193 setStepSize(finalTime.subtract(getStepStart().getTime()));
194 }
195 }
196 } while (!isLastStep());
197
198 final FieldODEStateAndDerivative<T> finalState = getStepStart();
199 setStepStart(null);
200 setStepSize(null);
201 return finalState;
202 }
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229 public T[] singleStep(final FirstOrderFieldDifferentialEquations<T> equations,
230 final T t0, final T[] y0, final T t) {
231
232
233 final T[] y = y0.clone();
234 final int stages = c.length + 1;
235 final T[][] yDotK = MathArrays.buildArray(getField(), stages, -1);
236 final T[] yTmp = y0.clone();
237
238
239 final T h = t.subtract(t0);
240 yDotK[0] = equations.computeDerivatives(t0, y);
241
242
243 for (int k = 1; k < stages; ++k) {
244
245 for (int j = 0; j < y0.length; ++j) {
246 T sum = yDotK[0][j].multiply(a[k-1][0]);
247 for (int l = 1; l < k; ++l) {
248 sum = sum.add(yDotK[l][j].multiply(a[k-1][l]));
249 }
250 yTmp[j] = y[j].add(h.multiply(sum));
251 }
252
253 yDotK[k] = equations.computeDerivatives(t0.add(h.multiply(c[k-1])), yTmp);
254 }
255
256
257 for (int j = 0; j < y0.length; ++j) {
258 T sum = yDotK[0][j].multiply(b[0]);
259 for (int l = 1; l < stages; ++l) {
260 sum = sum.add(yDotK[l][j].multiply(b[l]));
261 }
262 y[j] = y[j].add(h.multiply(sum));
263 }
264
265 return y;
266 }
267 }