FunctionUtils.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.analysis;

  18. import org.apache.commons.numbers.core.Sum;
  19. import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
  20. import org.apache.commons.math4.legacy.analysis.differentiation.MultivariateDifferentiableFunction;
  21. import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
  22. import org.apache.commons.math4.legacy.analysis.function.Identity;
  23. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  24. import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;

  25. /**
  26.  * Utilities for manipulating function objects.
  27.  *
  28.  * @since 3.0
  29.  */
  30. public final class FunctionUtils {
  31.     /**
  32.      * Class only contains static methods.
  33.      */
  34.     private FunctionUtils() {}

  35.     /**
  36.      * Composes functions.
  37.      * <p>
  38.      * The functions in the argument list are composed sequentially, in the
  39.      * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
  40.      *
  41.      * @param f List of functions.
  42.      * @return the composite function.
  43.      */
  44.     public static UnivariateFunction compose(final UnivariateFunction ... f) {
  45.         return new UnivariateFunction() {
  46.             /** {@inheritDoc} */
  47.             @Override
  48.             public double value(double x) {
  49.                 double r = x;
  50.                 for (int i = f.length - 1; i >= 0; i--) {
  51.                     r = f[i].value(r);
  52.                 }
  53.                 return r;
  54.             }
  55.         };
  56.     }

  57.     /**
  58.      * Composes functions.
  59.      * <p>
  60.      * The functions in the argument list are composed sequentially, in the
  61.      * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
  62.      *
  63.      * @param f List of functions.
  64.      * @return the composite function.
  65.      * @since 3.1
  66.      */
  67.     public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
  68.         return new UnivariateDifferentiableFunction() {

  69.             /** {@inheritDoc} */
  70.             @Override
  71.             public double value(final double t) {
  72.                 double r = t;
  73.                 for (int i = f.length - 1; i >= 0; i--) {
  74.                     r = f[i].value(r);
  75.                 }
  76.                 return r;
  77.             }

  78.             /** {@inheritDoc} */
  79.             @Override
  80.             public DerivativeStructure value(final DerivativeStructure t) {
  81.                 DerivativeStructure r = t;
  82.                 for (int i = f.length - 1; i >= 0; i--) {
  83.                     r = f[i].value(r);
  84.                 }
  85.                 return r;
  86.             }
  87.         };
  88.     }

  89.     /**
  90.      * Adds functions.
  91.      *
  92.      * @param f List of functions.
  93.      * @return a function that computes the sum of the functions.
  94.      */
  95.     public static UnivariateFunction add(final UnivariateFunction ... f) {
  96.         return new UnivariateFunction() {
  97.             /** {@inheritDoc} */
  98.             @Override
  99.             public double value(double x) {
  100.                 double r = f[0].value(x);
  101.                 for (int i = 1; i < f.length; i++) {
  102.                     r += f[i].value(x);
  103.                 }
  104.                 return r;
  105.             }
  106.         };
  107.     }

  108.     /**
  109.      * Adds functions.
  110.      *
  111.      * @param f List of functions.
  112.      * @return a function that computes the sum of the functions.
  113.      * @since 3.1
  114.      */
  115.     public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
  116.         return new UnivariateDifferentiableFunction() {

  117.             /** {@inheritDoc} */
  118.             @Override
  119.             public double value(final double t) {
  120.                 double r = f[0].value(t);
  121.                 for (int i = 1; i < f.length; i++) {
  122.                     r += f[i].value(t);
  123.                 }
  124.                 return r;
  125.             }

  126.             /** {@inheritDoc}
  127.              * @throws DimensionMismatchException if functions are not consistent with each other
  128.              */
  129.             @Override
  130.             public DerivativeStructure value(final DerivativeStructure t)
  131.                 throws DimensionMismatchException {
  132.                 DerivativeStructure r = f[0].value(t);
  133.                 for (int i = 1; i < f.length; i++) {
  134.                     r = r.add(f[i].value(t));
  135.                 }
  136.                 return r;
  137.             }
  138.         };
  139.     }

  140.     /**
  141.      * Multiplies functions.
  142.      *
  143.      * @param f List of functions.
  144.      * @return a function that computes the product of the functions.
  145.      */
  146.     public static UnivariateFunction multiply(final UnivariateFunction ... f) {
  147.         return new UnivariateFunction() {
  148.             /** {@inheritDoc} */
  149.             @Override
  150.             public double value(double x) {
  151.                 double r = f[0].value(x);
  152.                 for (int i = 1; i < f.length; i++) {
  153.                     r *= f[i].value(x);
  154.                 }
  155.                 return r;
  156.             }
  157.         };
  158.     }

  159.     /**
  160.      * Multiplies functions.
  161.      *
  162.      * @param f List of functions.
  163.      * @return a function that computes the product of the functions.
  164.      * @since 3.1
  165.      */
  166.     public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
  167.         return new UnivariateDifferentiableFunction() {

  168.             /** {@inheritDoc} */
  169.             @Override
  170.             public double value(final double t) {
  171.                 double r = f[0].value(t);
  172.                 for (int i = 1; i < f.length; i++) {
  173.                     r  *= f[i].value(t);
  174.                 }
  175.                 return r;
  176.             }

  177.             /** {@inheritDoc} */
  178.             @Override
  179.             public DerivativeStructure value(final DerivativeStructure t) {
  180.                 DerivativeStructure r = f[0].value(t);
  181.                 for (int i = 1; i < f.length; i++) {
  182.                     r = r.multiply(f[i].value(t));
  183.                 }
  184.                 return r;
  185.             }
  186.         };
  187.     }

  188.     /**
  189.      * Returns the univariate function
  190.      * {@code h(x) = combiner(f(x), g(x))}.
  191.      *
  192.      * @param combiner Combiner function.
  193.      * @param f Function.
  194.      * @param g Function.
  195.      * @return the composite function.
  196.      */
  197.     public static UnivariateFunction combine(final BivariateFunction combiner,
  198.                                              final UnivariateFunction f,
  199.                                              final UnivariateFunction g) {
  200.         return new UnivariateFunction() {
  201.             /** {@inheritDoc} */
  202.             @Override
  203.             public double value(double x) {
  204.                 return combiner.value(f.value(x), g.value(x));
  205.             }
  206.         };
  207.     }

  208.     /**
  209.      * Returns a MultivariateFunction h(x[]). Defined by:
  210.      * <pre> <code>
  211.      * h(x[]) = combiner(...combiner(combiner(initialValue,f(x[0])),f(x[1]))...),f(x[x.length-1]))
  212.      * </code></pre>
  213.      *
  214.      * @param combiner Combiner function.
  215.      * @param f Function.
  216.      * @param initialValue Initial value.
  217.      * @return a collector function.
  218.      */
  219.     public static MultivariateFunction collector(final BivariateFunction combiner,
  220.                                                  final UnivariateFunction f,
  221.                                                  final double initialValue) {
  222.         return new MultivariateFunction() {
  223.             /** {@inheritDoc} */
  224.             @Override
  225.             public double value(double[] point) {
  226.                 double result = combiner.value(initialValue, f.value(point[0]));
  227.                 for (int i = 1; i < point.length; i++) {
  228.                     result = combiner.value(result, f.value(point[i]));
  229.                 }
  230.                 return result;
  231.             }
  232.         };
  233.     }

  234.     /**
  235.      * Returns a MultivariateFunction h(x[]). Defined by:
  236.      * <pre> <code>
  237.      * h(x[]) = combiner(...combiner(combiner(initialValue,x[0]),x[1])...),x[x.length-1])
  238.      * </code></pre>
  239.      *
  240.      * @param combiner Combiner function.
  241.      * @param initialValue Initial value.
  242.      * @return a collector function.
  243.      */
  244.     public static MultivariateFunction collector(final BivariateFunction combiner,
  245.                                                  final double initialValue) {
  246.         return collector(combiner, new Identity(), initialValue);
  247.     }

  248.     /**
  249.      * Creates a unary function by fixing the first argument of a binary function.
  250.      *
  251.      * @param f Binary function.
  252.      * @param fixed value to which the first argument of {@code f} is set.
  253.      * @return the unary function h(x) = f(fixed, x)
  254.      */
  255.     public static UnivariateFunction fix1stArgument(final BivariateFunction f,
  256.                                                     final double fixed) {
  257.         return new UnivariateFunction() {
  258.             /** {@inheritDoc} */
  259.             @Override
  260.             public double value(double x) {
  261.                 return f.value(fixed, x);
  262.             }
  263.         };
  264.     }
  265.     /**
  266.      * Creates a unary function by fixing the second argument of a binary function.
  267.      *
  268.      * @param f Binary function.
  269.      * @param fixed value to which the second argument of {@code f} is set.
  270.      * @return the unary function h(x) = f(x, fixed)
  271.      */
  272.     public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
  273.                                                     final double fixed) {
  274.         return new UnivariateFunction() {
  275.             /** {@inheritDoc} */
  276.             @Override
  277.             public double value(double x) {
  278.                 return f.value(x, fixed);
  279.             }
  280.         };
  281.     }

  282.     /** Convert regular functions to {@link UnivariateDifferentiableFunction}.
  283.      * <p>
  284.      * This method handle the case with one free parameter and several derivatives.
  285.      * For the case with several free parameters and only first order derivatives,
  286.      * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
  287.      * There are no direct support for intermediate cases, with several free parameters
  288.      * and order 2 or more derivatives, as is would be difficult to specify all the
  289.      * cross derivatives.
  290.      * </p>
  291.      * <p>
  292.      * Note that the derivatives are expected to be computed only with respect to the
  293.      * raw parameter x of the base function, i.e. they are df/dx, df<sup>2</sup>/dx<sup>2</sup>, ...
  294.      * Even if the built function is later used in a composition like f(sin(t)), the provided
  295.      * derivatives should <em>not</em> apply the composition with sine and its derivatives by
  296.      * themselves. The composition will be done automatically here and the result will properly
  297.      * contain f(sin(t)), df(sin(t))/dt, df<sup>2</sup>(sin(t))/dt<sup>2</sup> despite the
  298.      * provided derivatives functions know nothing about the sine function.
  299.      * </p>
  300.      * @param f base function f(x)
  301.      * @param derivatives derivatives of the base function, in increasing differentiation order
  302.      * @return a differentiable function with value and all specified derivatives
  303.      * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
  304.      * @see #derivative(UnivariateDifferentiableFunction, int)
  305.      */
  306.     public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
  307.                                                                        final UnivariateFunction ... derivatives) {

  308.         return new UnivariateDifferentiableFunction() {

  309.             /** {@inheritDoc} */
  310.             @Override
  311.             public double value(final double x) {
  312.                 return f.value(x);
  313.             }

  314.             /** {@inheritDoc} */
  315.             @Override
  316.             public DerivativeStructure value(final DerivativeStructure x) {
  317.                 if (x.getOrder() > derivatives.length) {
  318.                     throw new NumberIsTooLargeException(x.getOrder(), derivatives.length, true);
  319.                 }
  320.                 final double[] packed = new double[x.getOrder() + 1];
  321.                 packed[0] = f.value(x.getValue());
  322.                 for (int i = 0; i < x.getOrder(); ++i) {
  323.                     packed[i + 1] = derivatives[i].value(x.getValue());
  324.                 }
  325.                 return x.compose(packed);
  326.             }
  327.         };
  328.     }

  329.     /** Convert regular functions to {@link MultivariateDifferentiableFunction}.
  330.      * <p>
  331.      * This method handle the case with several free parameters and only first order derivatives.
  332.      * For the case with one free parameter and several derivatives,
  333.      * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
  334.      * There are no direct support for intermediate cases, with several free parameters
  335.      * and order 2 or more derivatives, as is would be difficult to specify all the
  336.      * cross derivatives.
  337.      * </p>
  338.      * <p>
  339.      * Note that the gradient is expected to be computed only with respect to the
  340.      * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
  341.      * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
  342.      * gradient should <em>not</em> apply the composition with sine or cosine and their derivative by
  343.      * itself. The composition will be done automatically here and the result will properly
  344.      * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt despite the provided derivatives functions
  345.      * know nothing about the sine or cosine functions.
  346.      * </p>
  347.      * @param f base function f(x)
  348.      * @param gradient gradient of the base function
  349.      * @return a differentiable function with value and gradient
  350.      * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
  351.      * @see #derivative(MultivariateDifferentiableFunction, int[])
  352.      */
  353.     public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
  354.                                                                       final MultivariateVectorFunction gradient) {

  355.         return new MultivariateDifferentiableFunction() {

  356.             /** {@inheritDoc} */
  357.             @Override
  358.             public double value(final double[] point) {
  359.                 return f.value(point);
  360.             }

  361.             /** {@inheritDoc} */
  362.             @Override
  363.             public DerivativeStructure value(final DerivativeStructure[] point) {

  364.                 // set up the input parameters
  365.                 final double[] dPoint = new double[point.length];
  366.                 for (int i = 0; i < point.length; ++i) {
  367.                     dPoint[i] = point[i].getValue();
  368.                     if (point[i].getOrder() > 1) {
  369.                         throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
  370.                     }
  371.                 }

  372.                 // evaluate regular functions
  373.                 final double    v = f.value(dPoint);
  374.                 final double[] dv = gradient.value(dPoint);
  375.                 if (dv.length != point.length) {
  376.                     // the gradient function is inconsistent
  377.                     throw new DimensionMismatchException(dv.length, point.length);
  378.                 }

  379.                 // build the combined derivative
  380.                 final int parameters = point[0].getFreeParameters();
  381.                 final double[] partials = new double[point.length];
  382.                 final double[] packed = new double[parameters + 1];
  383.                 packed[0] = v;
  384.                 final int[] orders = new int[parameters];
  385.                 for (int i = 0; i < parameters; ++i) {

  386.                     // we differentiate once with respect to parameter i
  387.                     orders[i] = 1;
  388.                     for (int j = 0; j < point.length; ++j) {
  389.                         partials[j] = point[j].getPartialDerivative(orders);
  390.                     }
  391.                     orders[i] = 0;

  392.                     // compose partial derivatives
  393.                     packed[i + 1] = Sum.ofProducts(dv, partials).getAsDouble();
  394.                 }

  395.                 return new DerivativeStructure(parameters, 1, packed);
  396.             }
  397.         };
  398.     }

  399.     /** Convert an {@link UnivariateDifferentiableFunction} to an
  400.      * {@link UnivariateFunction} computing n<sup>th</sup> order derivative.
  401.      * <p>
  402.      * This converter is only a convenience method. Beware computing only one derivative does
  403.      * not save any computation as the original function will really be called under the hood.
  404.      * The derivative will be extracted from the full {@link DerivativeStructure} result.
  405.      * </p>
  406.      * @param f original function, with value and all its derivatives
  407.      * @param order of the derivative to extract
  408.      * @return function computing the derivative at required order
  409.      * @see #derivative(MultivariateDifferentiableFunction, int[])
  410.      * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
  411.      */
  412.     public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
  413.         return new UnivariateFunction() {

  414.             /** {@inheritDoc} */
  415.             @Override
  416.             public double value(final double x) {
  417.                 final DerivativeStructure dsX = new DerivativeStructure(1, order, 0, x);
  418.                 return f.value(dsX).getPartialDerivative(order);
  419.             }
  420.         };
  421.     }

  422.     /** Convert an {@link MultivariateDifferentiableFunction} to an
  423.      * {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
  424.      * <p>
  425.      * This converter is only a convenience method. Beware computing only one derivative does
  426.      * not save any computation as the original function will really be called under the hood.
  427.      * The derivative will be extracted from the full {@link DerivativeStructure} result.
  428.      * </p>
  429.      * @param f original function, with value and all its derivatives
  430.      * @param orders of the derivative to extract, for each free parameters
  431.      * @return function computing the derivative at required order
  432.      * @see #derivative(UnivariateDifferentiableFunction, int)
  433.      * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
  434.      */
  435.     public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
  436.         return new MultivariateFunction() {

  437.             /** {@inheritDoc} */
  438.             @Override
  439.             public double value(final double[] point) {

  440.                 // the maximum differentiation order is the sum of all orders
  441.                 int sumOrders = 0;
  442.                 for (final int order : orders) {
  443.                     sumOrders += order;
  444.                 }

  445.                 // set up the input parameters
  446.                 final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
  447.                 for (int i = 0; i < point.length; ++i) {
  448.                     dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
  449.                 }

  450.                 return f.value(dsPoint).getPartialDerivative(orders);
  451.             }
  452.         };
  453.     }
  454. }