RungeKuttaFieldIntegrator.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.ode.nonstiff;


  18. import org.apache.commons.math4.legacy.core.Field;
  19. import org.apache.commons.math4.legacy.core.RealFieldElement;
  20. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  21. import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
  22. import org.apache.commons.math4.legacy.exception.NoBracketingException;
  23. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  24. import org.apache.commons.math4.legacy.ode.AbstractFieldIntegrator;
  25. import org.apache.commons.math4.legacy.ode.FieldEquationsMapper;
  26. import org.apache.commons.math4.legacy.ode.FieldExpandableODE;
  27. import org.apache.commons.math4.legacy.ode.FirstOrderFieldDifferentialEquations;
  28. import org.apache.commons.math4.legacy.ode.FieldODEState;
  29. import org.apache.commons.math4.legacy.ode.FieldODEStateAndDerivative;
  30. import org.apache.commons.math4.legacy.core.MathArrays;

  31. /**
  32.  * This class implements the common part of all fixed step Runge-Kutta
  33.  * integrators for Ordinary Differential Equations.
  34.  *
  35.  * <p>These methods are explicit Runge-Kutta methods, their Butcher
  36.  * arrays are as follows :
  37.  * <pre>
  38.  *    0  |
  39.  *   c2  | a21
  40.  *   c3  | a31  a32
  41.  *   ... |        ...
  42.  *   cs  | as1  as2  ...  ass-1
  43.  *       |--------------------------
  44.  *       |  b1   b2  ...   bs-1  bs
  45.  * </pre>
  46.  *
  47.  * @see EulerFieldIntegrator
  48.  * @see ClassicalRungeKuttaFieldIntegrator
  49.  * @see GillFieldIntegrator
  50.  * @see MidpointFieldIntegrator
  51.  * @param <T> the type of the field elements
  52.  * @since 3.6
  53.  */

  54. public abstract class RungeKuttaFieldIntegrator<T extends RealFieldElement<T>>
  55.     extends AbstractFieldIntegrator<T>
  56.     implements FieldButcherArrayProvider<T> {

  57.     /** Time steps from Butcher array (without the first zero). */
  58.     private final T[] c;

  59.     /** Internal weights from Butcher array (without the first empty row). */
  60.     private final T[][] a;

  61.     /** External weights for the high order method from Butcher array. */
  62.     private final T[] b;

  63.     /** Integration step. */
  64.     private final T step;

  65.     /** Simple constructor.
  66.      * Build a Runge-Kutta integrator with the given
  67.      * step. The default step handler does nothing.
  68.      * @param field field to which the time and state vector elements belong
  69.      * @param name name of the method
  70.      * @param step integration step
  71.      */
  72.     protected RungeKuttaFieldIntegrator(final Field<T> field, final String name, final T step) {
  73.         super(field, name);
  74.         this.c    = getC();
  75.         this.a    = getA();
  76.         this.b    = getB();
  77.         this.step = step.abs();
  78.     }

  79.     /** Create a fraction.
  80.      * @param p numerator
  81.      * @param q denominator
  82.      * @return p/q computed in the instance field
  83.      */
  84.     protected T fraction(final int p, final int q) {
  85.         return getField().getZero().add(p).divide(q);
  86.     }

  87.     /** Create an interpolator.
  88.      * @param forward integration direction indicator
  89.      * @param yDotK slopes at the intermediate points
  90.      * @param globalPreviousState start of the global step
  91.      * @param globalCurrentState end of the global step
  92.      * @param mapper equations mapper for the all equations
  93.      * @return external weights for the high order method from Butcher array
  94.      */
  95.     protected abstract RungeKuttaFieldStepInterpolator<T> createInterpolator(boolean forward, T[][] yDotK,
  96.                                                                              FieldODEStateAndDerivative<T> globalPreviousState,
  97.                                                                              FieldODEStateAndDerivative<T> globalCurrentState,
  98.                                                                              FieldEquationsMapper<T> mapper);

  99.     /** {@inheritDoc} */
  100.     @Override
  101.     public FieldODEStateAndDerivative<T> integrate(final FieldExpandableODE<T> equations,
  102.                                                    final FieldODEState<T> initialState, final T finalTime)
  103.         throws NumberIsTooSmallException, DimensionMismatchException,
  104.         MaxCountExceededException, NoBracketingException {

  105.         sanityChecks(initialState, finalTime);
  106.         final T   t0 = initialState.getTime();
  107.         final T[] y0 = equations.getMapper().mapState(initialState);
  108.         setStepStart(initIntegration(equations, t0, y0, finalTime));
  109.         final boolean forward = finalTime.subtract(initialState.getTime()).getReal() > 0;

  110.         // create some internal working arrays
  111.         final int   stages = c.length + 1;
  112.         T[]         y      = y0;
  113.         final T[][] yDotK  = MathArrays.buildArray(getField(), stages, -1);
  114.         final T[]   yTmp   = MathArrays.buildArray(getField(), y0.length);

  115.         // set up integration control objects
  116.         if (forward) {
  117.             if (getStepStart().getTime().add(step).subtract(finalTime).getReal() >= 0) {
  118.                 setStepSize(finalTime.subtract(getStepStart().getTime()));
  119.             } else {
  120.                 setStepSize(step);
  121.             }
  122.         } else {
  123.             if (getStepStart().getTime().subtract(step).subtract(finalTime).getReal() <= 0) {
  124.                 setStepSize(finalTime.subtract(getStepStart().getTime()));
  125.             } else {
  126.                 setStepSize(step.negate());
  127.             }
  128.         }

  129.         // main integration loop
  130.         setIsLastStep(false);
  131.         do {

  132.             // first stage
  133.             y        = equations.getMapper().mapState(getStepStart());
  134.             yDotK[0] = equations.getMapper().mapDerivative(getStepStart());

  135.             // next stages
  136.             for (int k = 1; k < stages; ++k) {

  137.                 for (int j = 0; j < y0.length; ++j) {
  138.                     T sum = yDotK[0][j].multiply(a[k-1][0]);
  139.                     for (int l = 1; l < k; ++l) {
  140.                         sum = sum.add(yDotK[l][j].multiply(a[k-1][l]));
  141.                     }
  142.                     yTmp[j] = y[j].add(getStepSize().multiply(sum));
  143.                 }

  144.                 yDotK[k] = computeDerivatives(getStepStart().getTime().add(getStepSize().multiply(c[k-1])), yTmp);
  145.             }

  146.             // estimate the state at the end of the step
  147.             for (int j = 0; j < y0.length; ++j) {
  148.                 T sum = yDotK[0][j].multiply(b[0]);
  149.                 for (int l = 1; l < stages; ++l) {
  150.                     sum = sum.add(yDotK[l][j].multiply(b[l]));
  151.                 }
  152.                 yTmp[j] = y[j].add(getStepSize().multiply(sum));
  153.             }
  154.             final T stepEnd   = getStepStart().getTime().add(getStepSize());
  155.             final T[] yDotTmp = computeDerivatives(stepEnd, yTmp);
  156.             final FieldODEStateAndDerivative<T> stateTmp = new FieldODEStateAndDerivative<>(stepEnd, yTmp, yDotTmp);

  157.             // discrete events handling
  158.             System.arraycopy(yTmp, 0, y, 0, y0.length);
  159.             setStepStart(acceptStep(createInterpolator(forward, yDotK, getStepStart(), stateTmp, equations.getMapper()),
  160.                                     finalTime));

  161.             if (!isLastStep()) {

  162.                 // stepsize control for next step
  163.                 final T  nextT      = getStepStart().getTime().add(getStepSize());
  164.                 final boolean nextIsLast = forward ?
  165.                                            (nextT.subtract(finalTime).getReal() >= 0) :
  166.                                            (nextT.subtract(finalTime).getReal() <= 0);
  167.                 if (nextIsLast) {
  168.                     setStepSize(finalTime.subtract(getStepStart().getTime()));
  169.                 }
  170.             }
  171.         } while (!isLastStep());

  172.         final FieldODEStateAndDerivative<T> finalState = getStepStart();
  173.         setStepStart(null);
  174.         setStepSize(null);
  175.         return finalState;
  176.     }

  177.     /** Fast computation of a single step of ODE integration.
  178.      * <p>This method is intended for the limited use case of
  179.      * very fast computation of only one step without using any of the
  180.      * rich features of general integrators that may take some time
  181.      * to set up (i.e. no step handlers, no events handlers, no additional
  182.      * states, no interpolators, no error control, no evaluations count,
  183.      * no sanity checks ...). It handles the strict minimum of computation,
  184.      * so it can be embedded in outer loops.</p>
  185.      * <p>
  186.      * This method is <em>not</em> used at all by the {@link #integrate(FieldExpandableODE,
  187.      * FieldODEState, RealFieldElement)} method. It also completely ignores the step set at
  188.      * construction time, and uses only a single step to go from {@code t0} to {@code t}.
  189.      * </p>
  190.      * <p>
  191.      * As this method does not use any of the state-dependent features of the integrator,
  192.      * it should be reasonably thread-safe <em>if and only if</em> the provided differential
  193.      * equations are themselves thread-safe.
  194.      * </p>
  195.      * @param equations differential equations to integrate
  196.      * @param t0 initial time
  197.      * @param y0 initial value of the state vector at t0
  198.      * @param t target time for the integration
  199.      * (can be set to a value smaller than {@code t0} for backward integration)
  200.      * @return state vector at {@code t}
  201.      */
  202.     public T[] singleStep(final FirstOrderFieldDifferentialEquations<T> equations,
  203.                           final T t0, final T[] y0, final T t) {

  204.         // create some internal working arrays
  205.         final T[] y       = y0.clone();
  206.         final int stages  = c.length + 1;
  207.         final T[][] yDotK = MathArrays.buildArray(getField(), stages, -1);
  208.         final T[] yTmp    = y0.clone();

  209.         // first stage
  210.         final T h = t.subtract(t0);
  211.         yDotK[0] = equations.computeDerivatives(t0, y);

  212.         // next stages
  213.         for (int k = 1; k < stages; ++k) {

  214.             for (int j = 0; j < y0.length; ++j) {
  215.                 T sum = yDotK[0][j].multiply(a[k-1][0]);
  216.                 for (int l = 1; l < k; ++l) {
  217.                     sum = sum.add(yDotK[l][j].multiply(a[k-1][l]));
  218.                 }
  219.                 yTmp[j] = y[j].add(h.multiply(sum));
  220.             }

  221.             yDotK[k] = equations.computeDerivatives(t0.add(h.multiply(c[k-1])), yTmp);
  222.         }

  223.         // estimate the state at the end of the step
  224.         for (int j = 0; j < y0.length; ++j) {
  225.             T sum = yDotK[0][j].multiply(b[0]);
  226.             for (int l = 1; l < stages; ++l) {
  227.                 sum = sum.add(yDotK[l][j].multiply(b[l]));
  228.             }
  229.             y[j] = y[j].add(h.multiply(sum));
  230.         }

  231.         return y;
  232.     }
  233. }