EmbeddedRungeKuttaFieldIntegrator.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.ode.nonstiff;

  18. import org.apache.commons.math4.legacy.core.Field;
  19. import org.apache.commons.math4.legacy.core.RealFieldElement;
  20. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  21. import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
  22. import org.apache.commons.math4.legacy.exception.NoBracketingException;
  23. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  24. import org.apache.commons.math4.legacy.ode.FieldEquationsMapper;
  25. import org.apache.commons.math4.legacy.ode.FieldExpandableODE;
  26. import org.apache.commons.math4.legacy.ode.FieldODEState;
  27. import org.apache.commons.math4.legacy.ode.FieldODEStateAndDerivative;
  28. import org.apache.commons.math4.legacy.core.MathArrays;

  29. /**
  30.  * This class implements the common part of all embedded Runge-Kutta
  31.  * integrators for Ordinary Differential Equations.
  32.  *
  33.  * <p>These methods are embedded explicit Runge-Kutta methods with two
  34.  * sets of coefficients allowing to estimate the error, their Butcher
  35.  * arrays are as follows :
  36.  * <pre>
  37.  *    0  |
  38.  *   c2  | a21
  39.  *   c3  | a31  a32
  40.  *   ... |        ...
  41.  *   cs  | as1  as2  ...  ass-1
  42.  *       |--------------------------
  43.  *       |  b1   b2  ...   bs-1  bs
  44.  *       |  b'1  b'2 ...   b's-1 b's
  45.  * </pre>
  46.  *
  47.  * <p>In fact, we rather use the array defined by ej = bj - b'j to
  48.  * compute directly the error rather than computing two estimates and
  49.  * then comparing them.</p>
  50.  *
  51.  * <p>Some methods are qualified as <i>fsal</i> (first same as last)
  52.  * methods. This means the last evaluation of the derivatives in one
  53.  * step is the same as the first in the next step. Then, this
  54.  * evaluation can be reused from one step to the next one and the cost
  55.  * of such a method is really s-1 evaluations despite the method still
  56.  * has s stages. This behaviour is true only for successful steps, if
  57.  * the step is rejected after the error estimation phase, no
  58.  * evaluation is saved. For an <i>fsal</i> method, we have cs = 1 and
  59.  * asi = bi for all i.</p>
  60.  *
  61.  * @param <T> the type of the field elements
  62.  * @since 3.6
  63.  */

  64. public abstract class EmbeddedRungeKuttaFieldIntegrator<T extends RealFieldElement<T>>
  65.     extends AdaptiveStepsizeFieldIntegrator<T>
  66.     implements FieldButcherArrayProvider<T> {

  67.     /** Index of the pre-computed derivative for <i>fsal</i> methods. */
  68.     private final int fsal;

  69.     /** Time steps from Butcher array (without the first zero). */
  70.     private final T[] c;

  71.     /** Internal weights from Butcher array (without the first empty row). */
  72.     private final T[][] a;

  73.     /** External weights for the high order method from Butcher array. */
  74.     private final T[] b;

  75.     /** Stepsize control exponent. */
  76.     private final T exp;

  77.     /** Safety factor for stepsize control. */
  78.     private T safety;

  79.     /** Minimal reduction factor for stepsize control. */
  80.     private T minReduction;

  81.     /** Maximal growth factor for stepsize control. */
  82.     private T maxGrowth;

  83.     /** Build a Runge-Kutta integrator with the given Butcher array.
  84.      * @param field field to which the time and state vector elements belong
  85.      * @param name name of the method
  86.      * @param fsal index of the pre-computed derivative for <i>fsal</i> methods
  87.      * or -1 if method is not <i>fsal</i>
  88.      * @param minStep minimal step (sign is irrelevant, regardless of
  89.      * integration direction, forward or backward), the last step can
  90.      * be smaller than this
  91.      * @param maxStep maximal step (sign is irrelevant, regardless of
  92.      * integration direction, forward or backward), the last step can
  93.      * be smaller than this
  94.      * @param scalAbsoluteTolerance allowed absolute error
  95.      * @param scalRelativeTolerance allowed relative error
  96.      */
  97.     protected EmbeddedRungeKuttaFieldIntegrator(final Field<T> field, final String name, final int fsal,
  98.                                                 final double minStep, final double maxStep,
  99.                                                 final double scalAbsoluteTolerance,
  100.                                                 final double scalRelativeTolerance) {

  101.         super(field, name, minStep, maxStep, scalAbsoluteTolerance, scalRelativeTolerance);

  102.         this.fsal = fsal;
  103.         this.c    = getC();
  104.         this.a    = getA();
  105.         this.b    = getB();

  106.         exp = field.getOne().divide(-getOrder());

  107.         // set the default values of the algorithm control parameters
  108.         setSafety(field.getZero().add(0.9));
  109.         setMinReduction(field.getZero().add(0.2));
  110.         setMaxGrowth(field.getZero().add(10.0));
  111.     }

  112.     /** Build a Runge-Kutta integrator with the given Butcher array.
  113.      * @param field field to which the time and state vector elements belong
  114.      * @param name name of the method
  115.      * @param fsal index of the pre-computed derivative for <i>fsal</i> methods
  116.      * or -1 if method is not <i>fsal</i>
  117.      * @param minStep minimal step (must be positive even for backward
  118.      * integration), the last step can be smaller than this
  119.      * @param maxStep maximal step (must be positive even for backward
  120.      * integration)
  121.      * @param vecAbsoluteTolerance allowed absolute error
  122.      * @param vecRelativeTolerance allowed relative error
  123.      */
  124.     protected EmbeddedRungeKuttaFieldIntegrator(final Field<T> field, final String name, final int fsal,
  125.                                                 final double   minStep, final double maxStep,
  126.                                                 final double[] vecAbsoluteTolerance,
  127.                                                 final double[] vecRelativeTolerance) {

  128.         super(field, name, minStep, maxStep, vecAbsoluteTolerance, vecRelativeTolerance);

  129.         this.fsal = fsal;
  130.         this.c    = getC();
  131.         this.a    = getA();
  132.         this.b    = getB();

  133.         exp = field.getOne().divide(-getOrder());

  134.         // set the default values of the algorithm control parameters
  135.         setSafety(field.getZero().add(0.9));
  136.         setMinReduction(field.getZero().add(0.2));
  137.         setMaxGrowth(field.getZero().add(10.0));
  138.     }

  139.     /** Create a fraction.
  140.      * @param p numerator
  141.      * @param q denominator
  142.      * @return p/q computed in the instance field
  143.      */
  144.     protected T fraction(final int p, final int q) {
  145.         return getField().getOne().multiply(p).divide(q);
  146.     }

  147.     /** Create a fraction.
  148.      * @param p numerator
  149.      * @param q denominator
  150.      * @return p/q computed in the instance field
  151.      */
  152.     protected T fraction(final double p, final double q) {
  153.         return getField().getOne().multiply(p).divide(q);
  154.     }

  155.     /** Create an interpolator.
  156.      * @param forward integration direction indicator
  157.      * @param yDotK slopes at the intermediate points
  158.      * @param globalPreviousState start of the global step
  159.      * @param globalCurrentState end of the global step
  160.      * @param mapper equations mapper for the all equations
  161.      * @return external weights for the high order method from Butcher array
  162.      */
  163.     protected abstract RungeKuttaFieldStepInterpolator<T> createInterpolator(boolean forward, T[][] yDotK,
  164.                                                                              FieldODEStateAndDerivative<T> globalPreviousState,
  165.                                                                              FieldODEStateAndDerivative<T> globalCurrentState,
  166.                                                                              FieldEquationsMapper<T> mapper);
  167.     /** Get the order of the method.
  168.      * @return order of the method
  169.      */
  170.     public abstract int getOrder();

  171.     /** Get the safety factor for stepsize control.
  172.      * @return safety factor
  173.      */
  174.     public T getSafety() {
  175.         return safety;
  176.     }

  177.     /** Set the safety factor for stepsize control.
  178.      * @param safety safety factor
  179.      */
  180.     public void setSafety(final T safety) {
  181.         this.safety = safety;
  182.     }

  183.     /** {@inheritDoc} */
  184.     @Override
  185.     public FieldODEStateAndDerivative<T> integrate(final FieldExpandableODE<T> equations,
  186.                                                    final FieldODEState<T> initialState, final T finalTime)
  187.         throws NumberIsTooSmallException, DimensionMismatchException,
  188.         MaxCountExceededException, NoBracketingException {

  189.         sanityChecks(initialState, finalTime);
  190.         final T   t0 = initialState.getTime();
  191.         final T[] y0 = equations.getMapper().mapState(initialState);
  192.         setStepStart(initIntegration(equations, t0, y0, finalTime));
  193.         final boolean forward = finalTime.subtract(initialState.getTime()).getReal() > 0;

  194.         // create some internal working arrays
  195.         final int   stages = c.length + 1;
  196.         T[]         y      = y0;
  197.         final T[][] yDotK  = MathArrays.buildArray(getField(), stages, -1);
  198.         final T[]   yTmp   = MathArrays.buildArray(getField(), y0.length);

  199.         // set up integration control objects
  200.         T  hNew           = getField().getZero();
  201.         boolean firstTime = true;

  202.         // main integration loop
  203.         setIsLastStep(false);
  204.         do {

  205.             // iterate over step size, ensuring local normalized error is smaller than 1
  206.             T error = getField().getZero().add(10);
  207.             while (error.subtract(1.0).getReal() >= 0) {

  208.                 // first stage
  209.                 y        = equations.getMapper().mapState(getStepStart());
  210.                 yDotK[0] = equations.getMapper().mapDerivative(getStepStart());

  211.                 if (firstTime) {
  212.                     final T[] scale = MathArrays.buildArray(getField(), mainSetDimension);
  213.                     if (vecAbsoluteTolerance == null) {
  214.                         for (int i = 0; i < scale.length; ++i) {
  215.                             scale[i] = y[i].abs().multiply(scalRelativeTolerance).add(scalAbsoluteTolerance);
  216.                         }
  217.                     } else {
  218.                         for (int i = 0; i < scale.length; ++i) {
  219.                             scale[i] = y[i].abs().multiply(vecRelativeTolerance[i]).add(vecAbsoluteTolerance[i]);
  220.                         }
  221.                     }
  222.                     hNew = initializeStep(forward, getOrder(), scale, getStepStart(), equations.getMapper());
  223.                     firstTime = false;
  224.                 }

  225.                 setStepSize(hNew);
  226.                 if (forward) {
  227.                     if (getStepStart().getTime().add(getStepSize()).subtract(finalTime).getReal() >= 0) {
  228.                         setStepSize(finalTime.subtract(getStepStart().getTime()));
  229.                     }
  230.                 } else {
  231.                     if (getStepStart().getTime().add(getStepSize()).subtract(finalTime).getReal() <= 0) {
  232.                         setStepSize(finalTime.subtract(getStepStart().getTime()));
  233.                     }
  234.                 }

  235.                 // next stages
  236.                 for (int k = 1; k < stages; ++k) {

  237.                     for (int j = 0; j < y0.length; ++j) {
  238.                         T sum = yDotK[0][j].multiply(a[k-1][0]);
  239.                         for (int l = 1; l < k; ++l) {
  240.                             sum = sum.add(yDotK[l][j].multiply(a[k-1][l]));
  241.                         }
  242.                         yTmp[j] = y[j].add(getStepSize().multiply(sum));
  243.                     }

  244.                     yDotK[k] = computeDerivatives(getStepStart().getTime().add(getStepSize().multiply(c[k-1])), yTmp);
  245.                 }

  246.                 // estimate the state at the end of the step
  247.                 for (int j = 0; j < y0.length; ++j) {
  248.                     T sum    = yDotK[0][j].multiply(b[0]);
  249.                     for (int l = 1; l < stages; ++l) {
  250.                         sum = sum.add(yDotK[l][j].multiply(b[l]));
  251.                     }
  252.                     yTmp[j] = y[j].add(getStepSize().multiply(sum));
  253.                 }

  254.                 // estimate the error at the end of the step
  255.                 error = estimateError(yDotK, y, yTmp, getStepSize());
  256.                 if (error.subtract(1.0).getReal() >= 0) {
  257.                     // reject the step and attempt to reduce error by stepsize control
  258.                     final T factor = RealFieldElement.min(maxGrowth,
  259.                                                    RealFieldElement.max(minReduction, safety.multiply(error.pow(exp))));
  260.                     hNew = filterStep(getStepSize().multiply(factor), forward, false);
  261.                 }
  262.             }
  263.             final T   stepEnd = getStepStart().getTime().add(getStepSize());
  264.             final T[] yDotTmp = (fsal >= 0) ? yDotK[fsal] : computeDerivatives(stepEnd, yTmp);
  265.             final FieldODEStateAndDerivative<T> stateTmp = new FieldODEStateAndDerivative<>(stepEnd, yTmp, yDotTmp);

  266.             // local error is small enough: accept the step, trigger events and step handlers
  267.             System.arraycopy(yTmp, 0, y, 0, y0.length);
  268.             setStepStart(acceptStep(createInterpolator(forward, yDotK, getStepStart(), stateTmp, equations.getMapper()),
  269.                                     finalTime));

  270.             if (!isLastStep()) {

  271.                 // stepsize control for next step
  272.                 final T factor = RealFieldElement.min(maxGrowth,
  273.                                                RealFieldElement.max(minReduction, safety.multiply(error.pow(exp))));
  274.                 final T  scaledH    = getStepSize().multiply(factor);
  275.                 final T  nextT      = getStepStart().getTime().add(scaledH);
  276.                 final boolean nextIsLast = forward ?
  277.                                            nextT.subtract(finalTime).getReal() >= 0 :
  278.                                            nextT.subtract(finalTime).getReal() <= 0;
  279.                 hNew = filterStep(scaledH, forward, nextIsLast);

  280.                 final T  filteredNextT      = getStepStart().getTime().add(hNew);
  281.                 final boolean filteredNextIsLast = forward ?
  282.                                                    filteredNextT.subtract(finalTime).getReal() >= 0 :
  283.                                                    filteredNextT.subtract(finalTime).getReal() <= 0;
  284.                 if (filteredNextIsLast) {
  285.                     hNew = finalTime.subtract(getStepStart().getTime());
  286.                 }
  287.             }
  288.         } while (!isLastStep());

  289.         final FieldODEStateAndDerivative<T> finalState = getStepStart();
  290.         resetInternalState();
  291.         return finalState;
  292.     }

  293.     /** Get the minimal reduction factor for stepsize control.
  294.      * @return minimal reduction factor
  295.      */
  296.     public T getMinReduction() {
  297.         return minReduction;
  298.     }

  299.     /** Set the minimal reduction factor for stepsize control.
  300.      * @param minReduction minimal reduction factor
  301.      */
  302.     public void setMinReduction(final T minReduction) {
  303.         this.minReduction = minReduction;
  304.     }

  305.     /** Get the maximal growth factor for stepsize control.
  306.      * @return maximal growth factor
  307.      */
  308.     public T getMaxGrowth() {
  309.         return maxGrowth;
  310.     }

  311.     /** Set the maximal growth factor for stepsize control.
  312.      * @param maxGrowth maximal growth factor
  313.      */
  314.     public void setMaxGrowth(final T maxGrowth) {
  315.         this.maxGrowth = maxGrowth;
  316.     }

  317.     /** Compute the error ratio.
  318.      * @param yDotK derivatives computed during the first stages
  319.      * @param y0 estimate of the step at the start of the step
  320.      * @param y1 estimate of the step at the end of the step
  321.      * @param h  current step
  322.      * @return error ratio, greater than 1 if step should be rejected
  323.      */
  324.     protected abstract T estimateError(T[][] yDotK, T[] y0, T[] y1, T h);
  325. }