FieldEquationsMapper.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.ode;

import java.io.Serializable;
import java.util.Arrays;

import org.apache.commons.math4.legacy.core.RealFieldElement;
import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.core.MathArrays;

  24. /**
  25.  * Class mapping the part of a complete state or derivative that pertains
  26.  * to a set of differential equations.
  27.  * <p>
  28.  * Instances of this class are guaranteed to be immutable.
  29.  * </p>
  30.  * @see FieldExpandableODE
  31.  * @param <T> the type of the field elements
  32.  * @since 3.6
  33.  */
  34. public class FieldEquationsMapper<T extends RealFieldElement<T>> implements Serializable {

  35.     /** Serializable UID. */
  36.     private static final long serialVersionUID = 20151114L;

  37.     /** Start indices of the components. */
  38.     private final int[] start;

  39.     /** Create a mapper by adding a new equation to another mapper.
  40.      * <p>
  41.      * The new equation will have index {@code mapper.}{@link #getNumberOfEquations()},
  42.      * or 0 if {@code mapper} is null.
  43.      * </p>
  44.      * @param mapper former mapper, with one equation less (null for first equation)
  45.      * @param dimension dimension of the equation state vector
  46.      */
  47.     FieldEquationsMapper(final FieldEquationsMapper<T> mapper, final int dimension) {
  48.         final int index = (mapper == null) ? 0 : mapper.getNumberOfEquations();
  49.         this.start = new int[index + 2];
  50.         if (mapper == null) {
  51.             start[0] = 0;
  52.         } else {
  53.             System.arraycopy(mapper.start, 0, start, 0, index + 1);
  54.         }
  55.         start[index + 1] = start[index] + dimension;
  56.     }

  57.     /** Get the number of equations mapped.
  58.      * @return number of equations mapped
  59.      */
  60.     public int getNumberOfEquations() {
  61.         return start.length - 1;
  62.     }

  63.     /** Return the dimension of the complete set of equations.
  64.      * <p>
  65.      * The complete set of equations correspond to the primary set plus all secondary sets.
  66.      * </p>
  67.      * @return dimension of the complete set of equations
  68.      */
  69.     public int getTotalDimension() {
  70.         return start[start.length - 1];
  71.     }

  72.     /** Map a state to a complete flat array.
  73.      * @param state state to map
  74.      * @return flat array containing the mapped state, including primary and secondary components
  75.      */
  76.     public T[] mapState(final FieldODEState<T> state) {
  77.         final T[] y = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
  78.         int index = 0;
  79.         insertEquationData(index, state.getState(), y);
  80.         while (++index < getNumberOfEquations()) {
  81.             insertEquationData(index, state.getSecondaryState(index), y);
  82.         }
  83.         return y;
  84.     }

  85.     /** Map a state derivative to a complete flat array.
  86.      * @param state state to map
  87.      * @return flat array containing the mapped state derivative, including primary and secondary components
  88.      */
  89.     public T[] mapDerivative(final FieldODEStateAndDerivative<T> state) {
  90.         final T[] yDot = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
  91.         int index = 0;
  92.         insertEquationData(index, state.getDerivative(), yDot);
  93.         while (++index < getNumberOfEquations()) {
  94.             insertEquationData(index, state.getSecondaryDerivative(index), yDot);
  95.         }
  96.         return yDot;
  97.     }

  98.     /** Map flat arrays to a state and derivative.
  99.      * @param t time
  100.      * @param y state array to map, including primary and secondary components
  101.      * @param yDot state derivative array to map, including primary and secondary components
  102.      * @return mapped state
  103.      * @exception DimensionMismatchException if an array does not match total dimension
  104.      */
  105.     public FieldODEStateAndDerivative<T> mapStateAndDerivative(final T t, final T[] y, final T[] yDot)
  106.         throws DimensionMismatchException {

  107.         if (y.length != getTotalDimension()) {
  108.             throw new DimensionMismatchException(y.length, getTotalDimension());
  109.         }

  110.         if (yDot.length != getTotalDimension()) {
  111.             throw new DimensionMismatchException(yDot.length, getTotalDimension());
  112.         }

  113.         final int n = getNumberOfEquations();
  114.         int index = 0;
  115.         final T[] state      = extractEquationData(index, y);
  116.         final T[] derivative = extractEquationData(index, yDot);
  117.         if (n < 2) {
  118.             return new FieldODEStateAndDerivative<>(t, state, derivative);
  119.         } else {
  120.             final T[][] secondaryState      = MathArrays.buildArray(t.getField(), n - 1, -1);
  121.             final T[][] secondaryDerivative = MathArrays.buildArray(t.getField(), n - 1, -1);
  122.             while (++index < getNumberOfEquations()) {
  123.                 secondaryState[index - 1]      = extractEquationData(index, y);
  124.                 secondaryDerivative[index - 1] = extractEquationData(index, yDot);
  125.             }
  126.             return new FieldODEStateAndDerivative<>(t, state, derivative, secondaryState, secondaryDerivative);
  127.         }
  128.     }

  129.     /** Extract equation data from a complete state or derivative array.
  130.      * @param index index of the equation, must be between 0 included and
  131.      * {@link #getNumberOfEquations()} (excluded)
  132.      * @param complete complete state or derivative array from which
  133.      * equation data should be retrieved
  134.      * @return equation data
  135.      * @exception MathIllegalArgumentException if index is out of range
  136.      * @exception DimensionMismatchException if complete state has not enough elements
  137.      */
  138.     public T[] extractEquationData(final int index, final T[] complete)
  139.         throws MathIllegalArgumentException, DimensionMismatchException {
  140.         checkIndex(index);
  141.         final int begin     = start[index];
  142.         final int end       = start[index + 1];
  143.         if (complete.length < end) {
  144.             throw new DimensionMismatchException(complete.length, end);
  145.         }
  146.         final int dimension = end - begin;
  147.         final T[] equationData = MathArrays.buildArray(complete[0].getField(), dimension);
  148.         System.arraycopy(complete, begin, equationData, 0, dimension);
  149.         return equationData;
  150.     }

  151.     /** Insert equation data into a complete state or derivative array.
  152.      * @param index index of the equation, must be between 0 included and
  153.      * {@link #getNumberOfEquations()} (excluded)
  154.      * @param equationData equation data to be inserted into the complete array
  155.      * @param complete placeholder where to put equation data (only the
  156.      * part corresponding to the equation will be overwritten)
  157.      * @exception DimensionMismatchException if either array has not enough elements
  158.      */
  159.     public void insertEquationData(final int index, T[] equationData, T[] complete)
  160.         throws DimensionMismatchException {
  161.         checkIndex(index);
  162.         final int begin     = start[index];
  163.         final int end       = start[index + 1];
  164.         final int dimension = end - begin;
  165.         if (complete.length < end) {
  166.             throw new DimensionMismatchException(complete.length, end);
  167.         }
  168.         if (equationData.length != dimension) {
  169.             throw new DimensionMismatchException(equationData.length, dimension);
  170.         }
  171.         System.arraycopy(equationData, 0, complete, begin, dimension);
  172.     }

  173.     /** Check equation index.
  174.      * @param index index of the equation, must be between 0 included and
  175.      * {@link #getNumberOfEquations()} (excluded)
  176.      * @exception MathIllegalArgumentException if index is out of range
  177.      */
  178.     private void checkIndex(final int index) throws MathIllegalArgumentException {
  179.         if (index < 0 || index > start.length - 2) {
  180.             throw new MathIllegalArgumentException(LocalizedFormats.ARGUMENT_OUTSIDE_DOMAIN,
  181.                                                    index, 0, start.length - 2);
  182.         }
  183.     }
  184. }