View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.ode;
19  
20  
21  import org.apache.commons.math4.legacy.core.Field;
22  import org.apache.commons.math4.legacy.core.RealFieldElement;
23  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24  import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
25  import org.apache.commons.math4.legacy.ode.nonstiff.Decimal64Field;
26  import org.apache.commons.math4.legacy.core.MathArrays;
27  import org.junit.Assert;
28  import org.junit.Test;
29  
30  public class FieldExpandableODETest {
31  
32      @Test
33      public void testOnlyMainEquation() {
34          doTestOnlyMainEquation(Decimal64Field.getInstance());
35      }
36  
37      private <T extends RealFieldElement<T>> void doTestOnlyMainEquation(final Field<T> field) {
38          FirstOrderFieldDifferentialEquations<T> main = new Linear<>(field, 3, 0);
39          FieldExpandableODE<T> equation = new FieldExpandableODE<>(main);
40          Assert.assertEquals(main.getDimension(), equation.getMapper().getTotalDimension());
41          Assert.assertEquals(1, equation.getMapper().getNumberOfEquations());
42          T t0 = field.getZero().add(10);
43          T t  = field.getZero().add(100);
44          T[] complete    = MathArrays.buildArray(field, equation.getMapper().getTotalDimension());
45          for (int i = 0; i < complete.length; ++i) {
46              complete[i] = field.getZero().add(i);
47          }
48          equation.init(t0, complete, t);
49          T[] completeDot = equation.computeDerivatives(t0, complete);
50          FieldODEStateAndDerivative<T> state = equation.getMapper().mapStateAndDerivative(t0, complete, completeDot);
51          Assert.assertEquals(0, state.getNumberOfSecondaryStates());
52          T[] mainState    = state.getState();
53          T[] mainStateDot = state.getDerivative();
54          Assert.assertEquals(main.getDimension(), mainState.length);
55          for (int i = 0; i < main.getDimension(); ++i) {
56              Assert.assertEquals(i, mainState[i].getReal(),   1.0e-15);
57              Assert.assertEquals(i, mainStateDot[i].getReal(), 1.0e-15);
58              Assert.assertEquals(i, completeDot[i].getReal(),  1.0e-15);
59          }
60      }
61  
62      @Test
63      public void testMainAndSecondary() {
64          doTestMainAndSecondary(Decimal64Field.getInstance());
65      }
66  
67      private <T extends RealFieldElement<T>> void doTestMainAndSecondary(final Field<T> field) {
68  
69          FirstOrderFieldDifferentialEquations<T> main = new Linear<>(field, 3, 0);
70          FieldExpandableODE<T> equation = new FieldExpandableODE<>(main);
71          FieldSecondaryEquations<T> secondary1 = new Linear<>(field, 3, main.getDimension());
72          int i1 = equation.addSecondaryEquations(secondary1);
73          FieldSecondaryEquations<T> secondary2 = new Linear<>(field, 5, main.getDimension() + secondary1.getDimension());
74          int i2 = equation.addSecondaryEquations(secondary2);
75          Assert.assertEquals(main.getDimension() + secondary1.getDimension() + secondary2.getDimension(),
76                              equation.getMapper().getTotalDimension());
77          Assert.assertEquals(3, equation.getMapper().getNumberOfEquations());
78          Assert.assertEquals(1, i1);
79          Assert.assertEquals(2, i2);
80  
81          T t0 = field.getZero().add(10);
82          T t  = field.getZero().add(100);
83          T[] complete    = MathArrays.buildArray(field, equation.getMapper().getTotalDimension());
84          for (int i = 0; i < complete.length; ++i) {
85              complete[i] = field.getZero().add(i);
86          }
87          equation.init(t0, complete, t);
88          T[] completeDot = equation.computeDerivatives(t0, complete);
89  
90          T[] mainState    = equation.getMapper().extractEquationData(0,  complete);
91          T[] mainStateDot = equation.getMapper().extractEquationData(0,  completeDot);
92          Assert.assertEquals(main.getDimension(), mainState.length);
93          for (int i = 0; i < main.getDimension(); ++i) {
94              Assert.assertEquals(i, mainState[i].getReal(),   1.0e-15);
95              Assert.assertEquals(i, mainStateDot[i].getReal(), 1.0e-15);
96              Assert.assertEquals(i, completeDot[i].getReal(),  1.0e-15);
97          }
98  
99          T[] secondaryState1    = equation.getMapper().extractEquationData(i1,  complete);
100         T[] secondaryState1Dot = equation.getMapper().extractEquationData(i1,  completeDot);
101         Assert.assertEquals(secondary1.getDimension(), secondaryState1.length);
102         for (int i = 0; i < secondary1.getDimension(); ++i) {
103             Assert.assertEquals(i + main.getDimension(), secondaryState1[i].getReal(),   1.0e-15);
104             Assert.assertEquals(-i, secondaryState1Dot[i].getReal(), 1.0e-15);
105             Assert.assertEquals(-i, completeDot[i + main.getDimension()].getReal(),  1.0e-15);
106         }
107 
108         T[] secondaryState2    = equation.getMapper().extractEquationData(i2,  complete);
109         T[] secondaryState2Dot = equation.getMapper().extractEquationData(i2,  completeDot);
110         Assert.assertEquals(secondary2.getDimension(), secondaryState2.length);
111         for (int i = 0; i < secondary2.getDimension(); ++i) {
112             Assert.assertEquals(i + main.getDimension() + secondary1.getDimension(), secondaryState2[i].getReal(),   1.0e-15);
113             Assert.assertEquals(-i, secondaryState2Dot[i].getReal(), 1.0e-15);
114             Assert.assertEquals(-i, completeDot[i + main.getDimension() + secondary1.getDimension()].getReal(),  1.0e-15);
115         }
116     }
117 
118     @Test
119     public void testMap() {
120         doTestMap(Decimal64Field.getInstance());
121     }
122 
123     private <T extends RealFieldElement<T>> void doTestMap(final Field<T> field) {
124 
125         FirstOrderFieldDifferentialEquations<T> main = new Linear<>(field, 3, 0);
126         FieldExpandableODE<T> equation = new FieldExpandableODE<>(main);
127         FieldSecondaryEquations<T> secondary1 = new Linear<>(field, 3, main.getDimension());
128         int i1 = equation.addSecondaryEquations(secondary1);
129         FieldSecondaryEquations<T> secondary2 = new Linear<>(field, 5, main.getDimension() + secondary1.getDimension());
130         int i2 = equation.addSecondaryEquations(secondary2);
131         Assert.assertEquals(main.getDimension() + secondary1.getDimension() + secondary2.getDimension(),
132                             equation.getMapper().getTotalDimension());
133         Assert.assertEquals(3, equation.getMapper().getNumberOfEquations());
134         Assert.assertEquals(1, i1);
135         Assert.assertEquals(2, i2);
136 
137         T t0 = field.getZero().add(10);
138         T t  = field.getZero().add(100);
139         T[] complete    = MathArrays.buildArray(field, equation.getMapper().getTotalDimension());
140         for (int i = 0; i < complete.length; ++i) {
141             complete[i] = field.getZero().add(i);
142         }
143         equation.init(t0, complete, t);
144         T[] completeDot = equation.computeDerivatives(t0, complete);
145 
146         try {
147             equation.getMapper().mapStateAndDerivative(t0, MathArrays.buildArray(field, complete.length + 1), completeDot);
148             Assert.fail("an exception should have been thrown");
149         } catch (DimensionMismatchException dme) {
150             // expected
151         }
152         try {
153             equation.getMapper().mapStateAndDerivative(t0, complete, MathArrays.buildArray(field, completeDot.length + 1));
154             Assert.fail("an exception should have been thrown");
155         } catch (DimensionMismatchException dme) {
156             // expected
157         }
158         FieldODEStateAndDerivative<T> state = equation.getMapper().mapStateAndDerivative(t0, complete, completeDot);
159         Assert.assertEquals(2, state.getNumberOfSecondaryStates());
160         Assert.assertEquals(main.getDimension(),       state.getSecondaryStateDimension(0));
161         Assert.assertEquals(secondary1.getDimension(), state.getSecondaryStateDimension(i1));
162         Assert.assertEquals(secondary2.getDimension(), state.getSecondaryStateDimension(i2));
163 
164         T[] mainState             = state.getState();
165         T[] mainStateDot          = state.getDerivative();
166         T[] mainStateAlternate    = state.getSecondaryState(0);
167         T[] mainStateDotAlternate = state.getSecondaryDerivative(0);
168         Assert.assertEquals(main.getDimension(), mainState.length);
169         for (int i = 0; i < main.getDimension(); ++i) {
170             Assert.assertEquals(i, mainState[i].getReal(),             1.0e-15);
171             Assert.assertEquals(i, mainStateDot[i].getReal(),          1.0e-15);
172             Assert.assertEquals(i, mainStateAlternate[i].getReal(),    1.0e-15);
173             Assert.assertEquals(i, mainStateDotAlternate[i].getReal(), 1.0e-15);
174             Assert.assertEquals(i, completeDot[i].getReal(),           1.0e-15);
175         }
176 
177         T[] secondaryState1    = state.getSecondaryState(i1);
178         T[] secondaryState1Dot = state.getSecondaryDerivative(i1);
179         Assert.assertEquals(secondary1.getDimension(), secondaryState1.length);
180         for (int i = 0; i < secondary1.getDimension(); ++i) {
181             Assert.assertEquals(i + main.getDimension(), secondaryState1[i].getReal(),   1.0e-15);
182             Assert.assertEquals(-i, secondaryState1Dot[i].getReal(), 1.0e-15);
183             Assert.assertEquals(-i, completeDot[i + main.getDimension()].getReal(),  1.0e-15);
184         }
185 
186         T[] secondaryState2    = state.getSecondaryState(i2);
187         T[] secondaryState2Dot = state.getSecondaryDerivative(i2);
188         Assert.assertEquals(secondary2.getDimension(), secondaryState2.length);
189         for (int i = 0; i < secondary2.getDimension(); ++i) {
190             Assert.assertEquals(i + main.getDimension() + secondary1.getDimension(), secondaryState2[i].getReal(),   1.0e-15);
191             Assert.assertEquals(-i, secondaryState2Dot[i].getReal(), 1.0e-15);
192             Assert.assertEquals(-i, completeDot[i + main.getDimension() + secondary1.getDimension()].getReal(),  1.0e-15);
193         }
194 
195         T[] remappedState = equation.getMapper().mapState(state);
196         T[] remappedDerivative = equation.getMapper().mapDerivative(state);
197         Assert.assertEquals(equation.getMapper().getTotalDimension(), remappedState.length);
198         Assert.assertEquals(equation.getMapper().getTotalDimension(), remappedDerivative.length);
199         for (int i = 0; i < remappedState.length; ++i) {
200             Assert.assertEquals(complete[i].getReal(),    remappedState[i].getReal(),      1.0e-15);
201             Assert.assertEquals(completeDot[i].getReal(), remappedDerivative[i].getReal(), 1.0e-15);
202         }
203     }
204 
205     @Test(expected=DimensionMismatchException.class)
206     public void testExtractDimensionMismatch() {
207         doTestExtractDimensionMismatch(Decimal64Field.getInstance());
208     }
209 
210     private <T extends RealFieldElement<T>> void doTestExtractDimensionMismatch(final Field<T> field)
211         throws DimensionMismatchException {
212 
213         FirstOrderFieldDifferentialEquations<T> main = new Linear<>(field, 3, 0);
214         FieldExpandableODE<T> equation = new FieldExpandableODE<>(main);
215         FieldSecondaryEquations<T> secondary1 = new Linear<>(field, 3, main.getDimension());
216         int i1 = equation.addSecondaryEquations(secondary1);
217         T[] tooShort    = MathArrays.buildArray(field, main.getDimension());
218         equation.getMapper().extractEquationData(i1, tooShort);
219     }
220 
221     @Test(expected=DimensionMismatchException.class)
222     public void testInsertTooShortComplete() {
223         doTestInsertTooShortComplete(Decimal64Field.getInstance());
224     }
225 
226     private <T extends RealFieldElement<T>> void doTestInsertTooShortComplete(final Field<T> field)
227         throws DimensionMismatchException {
228 
229         FirstOrderFieldDifferentialEquations<T> main = new Linear<>(field, 3, 0);
230         FieldExpandableODE<T> equation = new FieldExpandableODE<>(main);
231         FieldSecondaryEquations<T> secondary1 = new Linear<>(field, 3, main.getDimension());
232         int i1 = equation.addSecondaryEquations(secondary1);
233         T[] equationData = MathArrays.buildArray(field, secondary1.getDimension());
234         T[] tooShort     = MathArrays.buildArray(field, main.getDimension());
235         equation.getMapper().insertEquationData(i1, equationData, tooShort);
236     }
237 
238     @Test(expected=DimensionMismatchException.class)
239     public void testInsertWrongEquationData() {
240         doTestInsertWrongEquationData(Decimal64Field.getInstance());
241     }
242 
243     private <T extends RealFieldElement<T>> void doTestInsertWrongEquationData(final Field<T> field)
244         throws DimensionMismatchException {
245 
246         FirstOrderFieldDifferentialEquations<T> main = new Linear<>(field, 3, 0);
247         FieldExpandableODE<T> equation = new FieldExpandableODE<>(main);
248         FieldSecondaryEquations<T> secondary1 = new Linear<>(field, 3, main.getDimension());
249         int i1 = equation.addSecondaryEquations(secondary1);
250         T[] wrongEquationData = MathArrays.buildArray(field, secondary1.getDimension() + 1);
251         T[] complete          = MathArrays.buildArray(field, equation.getMapper().getTotalDimension());
252         equation.getMapper().insertEquationData(i1, wrongEquationData, complete);
253     }
254 
255     @Test(expected=MathIllegalArgumentException.class)
256     public void testNegativeIndex() {
257         doTestNegativeIndex(Decimal64Field.getInstance());
258     }
259 
260     private <T extends RealFieldElement<T>> void doTestNegativeIndex(final Field<T> field)
261         throws MathIllegalArgumentException {
262 
263         FirstOrderFieldDifferentialEquations<T> main = new Linear<>(field, 3, 0);
264         FieldExpandableODE<T> equation = new FieldExpandableODE<>(main);
265         T[] complete = MathArrays.buildArray(field, equation.getMapper().getTotalDimension());
266         equation.getMapper().extractEquationData(-1, complete);
267     }
268 
269     @Test(expected=MathIllegalArgumentException.class)
270     public void testTooLargeIndex() {
271         doTestTooLargeIndex(Decimal64Field.getInstance());
272     }
273 
274     private <T extends RealFieldElement<T>> void doTestTooLargeIndex(final Field<T> field)
275         throws MathIllegalArgumentException {
276 
277         FirstOrderFieldDifferentialEquations<T> main = new Linear<>(field, 3, 0);
278         FieldExpandableODE<T> equation = new FieldExpandableODE<>(main);
279         T[] complete = MathArrays.buildArray(field, equation.getMapper().getTotalDimension());
280         equation.getMapper().extractEquationData(+1, complete);
281     }
282 
283     private static final class  Linear<T extends RealFieldElement<T>>
284         implements  FirstOrderFieldDifferentialEquations<T>, FieldSecondaryEquations<T> {
285 
286         private final Field<T> field;
287         private final int dimension;
288         private final int start;
289 
290         private Linear(final Field<T> field, final int dimension, final int start) {
291             this.field     = field;
292             this.dimension = dimension;
293             this.start     = start;
294         }
295 
296         @Override
297         public int getDimension() {
298             return dimension;
299         }
300 
301         @Override
302         public void init(final T t0, final T[] y0, final T finalTime) {
303             Assert.assertEquals(dimension, y0.length);
304             Assert.assertEquals(10.0,  t0.getReal(), 1.0e-15);
305             Assert.assertEquals(100.0, finalTime.getReal(), 1.0e-15);
306             for (int i = 0; i < y0.length; ++i) {
307                 Assert.assertEquals(i, y0[i].getReal(), 1.0e-15);
308             }
309         }
310 
311         @Override
312         public T[] computeDerivatives(final T t, final T[] y) {
313             final T[] yDot = MathArrays.buildArray(field, dimension);
314             for (int i = 0; i < dimension; ++i) {
315                 yDot[i] = field.getZero().add(i);
316             }
317             return yDot;
318         }
319 
320         @Override
321         public void init(final T t0, final T[] primary0, final T[] secondary0, final T finalTime) {
322             Assert.assertEquals(dimension, secondary0.length);
323             Assert.assertEquals(10.0,  t0.getReal(), 1.0e-15);
324             Assert.assertEquals(100.0, finalTime.getReal(), 1.0e-15);
325             for (int i = 0; i < primary0.length; ++i) {
326                 Assert.assertEquals(i, primary0[i].getReal(), 1.0e-15);
327             }
328             for (int i = 0; i < secondary0.length; ++i) {
329                 Assert.assertEquals(start + i, secondary0[i].getReal(), 1.0e-15);
330             }
331         }
332 
333         @Override
334         public T[] computeDerivatives(final T t, final T[] primary, final T[] primaryDot, final T[] secondary) {
335             final T[] secondaryDot = MathArrays.buildArray(field, dimension);
336             for (int i = 0; i < dimension; ++i) {
337                 secondaryDot[i] = field.getZero().subtract(i);
338             }
339             return secondaryDot;
340         }
341     }
342 }