001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ode;
019
import java.io.Serializable;
import java.util.Arrays;

import org.apache.commons.math3.RealFieldElement;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.util.MathArrays;
027
028/**
029 * Class mapping the part of a complete state or derivative that pertains
030 * to a set of differential equations.
031 * <p>
032 * Instances of this class are guaranteed to be immutable.
033 * </p>
034 * @see FieldExpandableODE
035 * @param <T> the type of the field elements
036 * @since 3.6
037 */
038public class FieldEquationsMapper<T extends RealFieldElement<T>> implements Serializable {
039
040    /** Serializable UID. */
041    private static final long serialVersionUID = 20151114L;
042
043    /** Start indices of the components. */
044    private final int[] start;
045
046    /** Create a mapper by adding a new equation to another mapper.
047     * <p>
048     * The new equation will have index {@code mapper.}{@link #getNumberOfEquations()},
049     * or 0 if {@code mapper} is null.
050     * </p>
051     * @param mapper former mapper, with one equation less (null for first equation)
052     * @param dimension dimension of the equation state vector
053     */
054    FieldEquationsMapper(final FieldEquationsMapper<T> mapper, final int dimension) {
055        final int index = (mapper == null) ? 0 : mapper.getNumberOfEquations();
056        this.start = new int[index + 2];
057        if (mapper == null) {
058            start[0] = 0;
059        } else {
060            System.arraycopy(mapper.start, 0, start, 0, index + 1);
061        }
062        start[index + 1] = start[index] + dimension;
063    }
064
065    /** Get the number of equations mapped.
066     * @return number of equations mapped
067     */
068    public int getNumberOfEquations() {
069        return start.length - 1;
070    }
071
072    /** Return the dimension of the complete set of equations.
073     * <p>
074     * The complete set of equations correspond to the primary set plus all secondary sets.
075     * </p>
076     * @return dimension of the complete set of equations
077     */
078    public int getTotalDimension() {
079        return start[start.length - 1];
080    }
081
082    /** Map a state to a complete flat array.
083     * @param state state to map
084     * @return flat array containing the mapped state, including primary and secondary components
085     */
086    public T[] mapState(final FieldODEState<T> state) {
087        final T[] y = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
088        int index = 0;
089        insertEquationData(index, state.getState(), y);
090        while (++index < getNumberOfEquations()) {
091            insertEquationData(index, state.getSecondaryState(index), y);
092        }
093        return y;
094    }
095
096    /** Map a state derivative to a complete flat array.
097     * @param state state to map
098     * @return flat array containing the mapped state derivative, including primary and secondary components
099     */
100    public T[] mapDerivative(final FieldODEStateAndDerivative<T> state) {
101        final T[] yDot = MathArrays.buildArray(state.getTime().getField(), getTotalDimension());
102        int index = 0;
103        insertEquationData(index, state.getDerivative(), yDot);
104        while (++index < getNumberOfEquations()) {
105            insertEquationData(index, state.getSecondaryDerivative(index), yDot);
106        }
107        return yDot;
108    }
109
110    /** Map flat arrays to a state and derivative.
111     * @param t time
112     * @param y state array to map, including primary and secondary components
113     * @param yDot state derivative array to map, including primary and secondary components
114     * @return mapped state
115     * @exception DimensionMismatchException if an array does not match total dimension
116     */
117    public FieldODEStateAndDerivative<T> mapStateAndDerivative(final T t, final T[] y, final T[] yDot)
118        throws DimensionMismatchException {
119
120        if (y.length != getTotalDimension()) {
121            throw new DimensionMismatchException(y.length, getTotalDimension());
122        }
123
124        if (yDot.length != getTotalDimension()) {
125            throw new DimensionMismatchException(yDot.length, getTotalDimension());
126        }
127
128        final int n = getNumberOfEquations();
129        int index = 0;
130        final T[] state      = extractEquationData(index, y);
131        final T[] derivative = extractEquationData(index, yDot);
132        if (n < 2) {
133            return new FieldODEStateAndDerivative<T>(t, state, derivative);
134        } else {
135            final T[][] secondaryState      = MathArrays.buildArray(t.getField(), n - 1, -1);
136            final T[][] secondaryDerivative = MathArrays.buildArray(t.getField(), n - 1, -1);
137            while (++index < getNumberOfEquations()) {
138                secondaryState[index - 1]      = extractEquationData(index, y);
139                secondaryDerivative[index - 1] = extractEquationData(index, yDot);
140            }
141            return new FieldODEStateAndDerivative<T>(t, state, derivative, secondaryState, secondaryDerivative);
142        }
143    }
144
145    /** Extract equation data from a complete state or derivative array.
146     * @param index index of the equation, must be between 0 included and
147     * {@link #getNumberOfEquations()} (excluded)
148     * @param complete complete state or derivative array from which
149     * equation data should be retrieved
150     * @return equation data
151     * @exception MathIllegalArgumentException if index is out of range
152     * @exception DimensionMismatchException if complete state has not enough elements
153     */
154    public T[] extractEquationData(final int index, final T[] complete)
155        throws MathIllegalArgumentException, DimensionMismatchException {
156        checkIndex(index);
157        final int begin     = start[index];
158        final int end       = start[index + 1];
159        if (complete.length < end) {
160            throw new DimensionMismatchException(complete.length, end);
161        }
162        final int dimension = end - begin;
163        final T[] equationData = MathArrays.buildArray(complete[0].getField(), dimension);
164        System.arraycopy(complete, begin, equationData, 0, dimension);
165        return equationData;
166    }
167
168    /** Insert equation data into a complete state or derivative array.
169     * @param index index of the equation, must be between 0 included and
170     * {@link #getNumberOfEquations()} (excluded)
171     * @param equationData equation data to be inserted into the complete array
172     * @param complete placeholder where to put equation data (only the
173     * part corresponding to the equation will be overwritten)
174     * @exception DimensionMismatchException if either array has not enough elements
175     */
176    public void insertEquationData(final int index, T[] equationData, T[] complete)
177        throws DimensionMismatchException {
178        checkIndex(index);
179        final int begin     = start[index];
180        final int end       = start[index + 1];
181        final int dimension = end - begin;
182        if (complete.length < end) {
183            throw new DimensionMismatchException(complete.length, end);
184        }
185        if (equationData.length != dimension) {
186            throw new DimensionMismatchException(equationData.length, dimension);
187        }
188        System.arraycopy(equationData, 0, complete, begin, dimension);
189    }
190
191    /** Check equation index.
192     * @param index index of the equation, must be between 0 included and
193     * {@link #getNumberOfEquations()} (excluded)
194     * @exception MathIllegalArgumentException if index is out of range
195     */
196    private void checkIndex(final int index) throws MathIllegalArgumentException {
197        if (index < 0 || index > start.length - 2) {
198            throw new MathIllegalArgumentException(LocalizedFormats.ARGUMENT_OUTSIDE_DOMAIN,
199                                                   index, 0, start.length - 2);
200        }
201    }
202
203}