001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.legacy.analysis;
019
020import org.apache.commons.numbers.core.Sum;
021import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
022import org.apache.commons.math4.legacy.analysis.differentiation.MultivariateDifferentiableFunction;
023import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
024import org.apache.commons.math4.legacy.analysis.function.Identity;
025import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
026import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
027
028/**
029 * Utilities for manipulating function objects.
030 *
031 * @since 3.0
032 */
033public final class FunctionUtils {
034    /**
035     * Class only contains static methods.
036     */
037    private FunctionUtils() {}
038
039    /**
040     * Composes functions.
041     * <p>
042     * The functions in the argument list are composed sequentially, in the
043     * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
044     *
045     * @param f List of functions.
046     * @return the composite function.
047     */
048    public static UnivariateFunction compose(final UnivariateFunction ... f) {
049        return new UnivariateFunction() {
050            /** {@inheritDoc} */
051            @Override
052            public double value(double x) {
053                double r = x;
054                for (int i = f.length - 1; i >= 0; i--) {
055                    r = f[i].value(r);
056                }
057                return r;
058            }
059        };
060    }
061
062    /**
063     * Composes functions.
064     * <p>
065     * The functions in the argument list are composed sequentially, in the
066     * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
067     *
068     * @param f List of functions.
069     * @return the composite function.
070     * @since 3.1
071     */
072    public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
073        return new UnivariateDifferentiableFunction() {
074
075            /** {@inheritDoc} */
076            @Override
077            public double value(final double t) {
078                double r = t;
079                for (int i = f.length - 1; i >= 0; i--) {
080                    r = f[i].value(r);
081                }
082                return r;
083            }
084
085            /** {@inheritDoc} */
086            @Override
087            public DerivativeStructure value(final DerivativeStructure t) {
088                DerivativeStructure r = t;
089                for (int i = f.length - 1; i >= 0; i--) {
090                    r = f[i].value(r);
091                }
092                return r;
093            }
094        };
095    }
096
097    /**
098     * Adds functions.
099     *
100     * @param f List of functions.
101     * @return a function that computes the sum of the functions.
102     */
103    public static UnivariateFunction add(final UnivariateFunction ... f) {
104        return new UnivariateFunction() {
105            /** {@inheritDoc} */
106            @Override
107            public double value(double x) {
108                double r = f[0].value(x);
109                for (int i = 1; i < f.length; i++) {
110                    r += f[i].value(x);
111                }
112                return r;
113            }
114        };
115    }
116
117    /**
118     * Adds functions.
119     *
120     * @param f List of functions.
121     * @return a function that computes the sum of the functions.
122     * @since 3.1
123     */
124    public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
125        return new UnivariateDifferentiableFunction() {
126
127            /** {@inheritDoc} */
128            @Override
129            public double value(final double t) {
130                double r = f[0].value(t);
131                for (int i = 1; i < f.length; i++) {
132                    r += f[i].value(t);
133                }
134                return r;
135            }
136
137            /** {@inheritDoc}
138             * @throws DimensionMismatchException if functions are not consistent with each other
139             */
140            @Override
141            public DerivativeStructure value(final DerivativeStructure t)
142                throws DimensionMismatchException {
143                DerivativeStructure r = f[0].value(t);
144                for (int i = 1; i < f.length; i++) {
145                    r = r.add(f[i].value(t));
146                }
147                return r;
148            }
149        };
150    }
151
152    /**
153     * Multiplies functions.
154     *
155     * @param f List of functions.
156     * @return a function that computes the product of the functions.
157     */
158    public static UnivariateFunction multiply(final UnivariateFunction ... f) {
159        return new UnivariateFunction() {
160            /** {@inheritDoc} */
161            @Override
162            public double value(double x) {
163                double r = f[0].value(x);
164                for (int i = 1; i < f.length; i++) {
165                    r *= f[i].value(x);
166                }
167                return r;
168            }
169        };
170    }
171
172    /**
173     * Multiplies functions.
174     *
175     * @param f List of functions.
176     * @return a function that computes the product of the functions.
177     * @since 3.1
178     */
179    public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
180        return new UnivariateDifferentiableFunction() {
181
182            /** {@inheritDoc} */
183            @Override
184            public double value(final double t) {
185                double r = f[0].value(t);
186                for (int i = 1; i < f.length; i++) {
187                    r  *= f[i].value(t);
188                }
189                return r;
190            }
191
192            /** {@inheritDoc} */
193            @Override
194            public DerivativeStructure value(final DerivativeStructure t) {
195                DerivativeStructure r = f[0].value(t);
196                for (int i = 1; i < f.length; i++) {
197                    r = r.multiply(f[i].value(t));
198                }
199                return r;
200            }
201        };
202    }
203
204    /**
205     * Returns the univariate function
206     * {@code h(x) = combiner(f(x), g(x))}.
207     *
208     * @param combiner Combiner function.
209     * @param f Function.
210     * @param g Function.
211     * @return the composite function.
212     */
213    public static UnivariateFunction combine(final BivariateFunction combiner,
214                                             final UnivariateFunction f,
215                                             final UnivariateFunction g) {
216        return new UnivariateFunction() {
217            /** {@inheritDoc} */
218            @Override
219            public double value(double x) {
220                return combiner.value(f.value(x), g.value(x));
221            }
222        };
223    }
224
225    /**
226     * Returns a MultivariateFunction h(x[]). Defined by:
227     * <pre> <code>
228     * h(x[]) = combiner(...combiner(combiner(initialValue,f(x[0])),f(x[1]))...),f(x[x.length-1]))
229     * </code></pre>
230     *
231     * @param combiner Combiner function.
232     * @param f Function.
233     * @param initialValue Initial value.
234     * @return a collector function.
235     */
236    public static MultivariateFunction collector(final BivariateFunction combiner,
237                                                 final UnivariateFunction f,
238                                                 final double initialValue) {
239        return new MultivariateFunction() {
240            /** {@inheritDoc} */
241            @Override
242            public double value(double[] point) {
243                double result = combiner.value(initialValue, f.value(point[0]));
244                for (int i = 1; i < point.length; i++) {
245                    result = combiner.value(result, f.value(point[i]));
246                }
247                return result;
248            }
249        };
250    }
251
252    /**
253     * Returns a MultivariateFunction h(x[]). Defined by:
254     * <pre> <code>
255     * h(x[]) = combiner(...combiner(combiner(initialValue,x[0]),x[1])...),x[x.length-1])
256     * </code></pre>
257     *
258     * @param combiner Combiner function.
259     * @param initialValue Initial value.
260     * @return a collector function.
261     */
262    public static MultivariateFunction collector(final BivariateFunction combiner,
263                                                 final double initialValue) {
264        return collector(combiner, new Identity(), initialValue);
265    }
266
267    /**
268     * Creates a unary function by fixing the first argument of a binary function.
269     *
270     * @param f Binary function.
271     * @param fixed value to which the first argument of {@code f} is set.
272     * @return the unary function h(x) = f(fixed, x)
273     */
274    public static UnivariateFunction fix1stArgument(final BivariateFunction f,
275                                                    final double fixed) {
276        return new UnivariateFunction() {
277            /** {@inheritDoc} */
278            @Override
279            public double value(double x) {
280                return f.value(fixed, x);
281            }
282        };
283    }
284    /**
285     * Creates a unary function by fixing the second argument of a binary function.
286     *
287     * @param f Binary function.
288     * @param fixed value to which the second argument of {@code f} is set.
289     * @return the unary function h(x) = f(x, fixed)
290     */
291    public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
292                                                    final double fixed) {
293        return new UnivariateFunction() {
294            /** {@inheritDoc} */
295            @Override
296            public double value(double x) {
297                return f.value(x, fixed);
298            }
299        };
300    }
301
302    /** Convert regular functions to {@link UnivariateDifferentiableFunction}.
303     * <p>
304     * This method handle the case with one free parameter and several derivatives.
305     * For the case with several free parameters and only first order derivatives,
306     * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
307     * There are no direct support for intermediate cases, with several free parameters
308     * and order 2 or more derivatives, as is would be difficult to specify all the
309     * cross derivatives.
310     * </p>
311     * <p>
312     * Note that the derivatives are expected to be computed only with respect to the
313     * raw parameter x of the base function, i.e. they are df/dx, df<sup>2</sup>/dx<sup>2</sup>, ...
314     * Even if the built function is later used in a composition like f(sin(t)), the provided
315     * derivatives should <em>not</em> apply the composition with sine and its derivatives by
316     * themselves. The composition will be done automatically here and the result will properly
317     * contain f(sin(t)), df(sin(t))/dt, df<sup>2</sup>(sin(t))/dt<sup>2</sup> despite the
318     * provided derivatives functions know nothing about the sine function.
319     * </p>
320     * @param f base function f(x)
321     * @param derivatives derivatives of the base function, in increasing differentiation order
322     * @return a differentiable function with value and all specified derivatives
323     * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
324     * @see #derivative(UnivariateDifferentiableFunction, int)
325     */
326    public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
327                                                                       final UnivariateFunction ... derivatives) {
328
329        return new UnivariateDifferentiableFunction() {
330
331            /** {@inheritDoc} */
332            @Override
333            public double value(final double x) {
334                return f.value(x);
335            }
336
337            /** {@inheritDoc} */
338            @Override
339            public DerivativeStructure value(final DerivativeStructure x) {
340                if (x.getOrder() > derivatives.length) {
341                    throw new NumberIsTooLargeException(x.getOrder(), derivatives.length, true);
342                }
343                final double[] packed = new double[x.getOrder() + 1];
344                packed[0] = f.value(x.getValue());
345                for (int i = 0; i < x.getOrder(); ++i) {
346                    packed[i + 1] = derivatives[i].value(x.getValue());
347                }
348                return x.compose(packed);
349            }
350        };
351    }
352
353    /** Convert regular functions to {@link MultivariateDifferentiableFunction}.
354     * <p>
355     * This method handle the case with several free parameters and only first order derivatives.
356     * For the case with one free parameter and several derivatives,
357     * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
358     * There are no direct support for intermediate cases, with several free parameters
359     * and order 2 or more derivatives, as is would be difficult to specify all the
360     * cross derivatives.
361     * </p>
362     * <p>
363     * Note that the gradient is expected to be computed only with respect to the
364     * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
365     * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
366     * gradient should <em>not</em> apply the composition with sine or cosine and their derivative by
367     * itself. The composition will be done automatically here and the result will properly
368     * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt despite the provided derivatives functions
369     * know nothing about the sine or cosine functions.
370     * </p>
371     * @param f base function f(x)
372     * @param gradient gradient of the base function
373     * @return a differentiable function with value and gradient
374     * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
375     * @see #derivative(MultivariateDifferentiableFunction, int[])
376     */
377    public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
378                                                                      final MultivariateVectorFunction gradient) {
379
380        return new MultivariateDifferentiableFunction() {
381
382            /** {@inheritDoc} */
383            @Override
384            public double value(final double[] point) {
385                return f.value(point);
386            }
387
388            /** {@inheritDoc} */
389            @Override
390            public DerivativeStructure value(final DerivativeStructure[] point) {
391
392                // set up the input parameters
393                final double[] dPoint = new double[point.length];
394                for (int i = 0; i < point.length; ++i) {
395                    dPoint[i] = point[i].getValue();
396                    if (point[i].getOrder() > 1) {
397                        throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
398                    }
399                }
400
401                // evaluate regular functions
402                final double    v = f.value(dPoint);
403                final double[] dv = gradient.value(dPoint);
404                if (dv.length != point.length) {
405                    // the gradient function is inconsistent
406                    throw new DimensionMismatchException(dv.length, point.length);
407                }
408
409                // build the combined derivative
410                final int parameters = point[0].getFreeParameters();
411                final double[] partials = new double[point.length];
412                final double[] packed = new double[parameters + 1];
413                packed[0] = v;
414                final int[] orders = new int[parameters];
415                for (int i = 0; i < parameters; ++i) {
416
417                    // we differentiate once with respect to parameter i
418                    orders[i] = 1;
419                    for (int j = 0; j < point.length; ++j) {
420                        partials[j] = point[j].getPartialDerivative(orders);
421                    }
422                    orders[i] = 0;
423
424                    // compose partial derivatives
425                    packed[i + 1] = Sum.ofProducts(dv, partials).getAsDouble();
426                }
427
428                return new DerivativeStructure(parameters, 1, packed);
429            }
430        };
431    }
432
433    /** Convert an {@link UnivariateDifferentiableFunction} to an
434     * {@link UnivariateFunction} computing n<sup>th</sup> order derivative.
435     * <p>
436     * This converter is only a convenience method. Beware computing only one derivative does
437     * not save any computation as the original function will really be called under the hood.
438     * The derivative will be extracted from the full {@link DerivativeStructure} result.
439     * </p>
440     * @param f original function, with value and all its derivatives
441     * @param order of the derivative to extract
442     * @return function computing the derivative at required order
443     * @see #derivative(MultivariateDifferentiableFunction, int[])
444     * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
445     */
446    public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
447        return new UnivariateFunction() {
448
449            /** {@inheritDoc} */
450            @Override
451            public double value(final double x) {
452                final DerivativeStructure dsX = new DerivativeStructure(1, order, 0, x);
453                return f.value(dsX).getPartialDerivative(order);
454            }
455        };
456    }
457
458    /** Convert an {@link MultivariateDifferentiableFunction} to an
459     * {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
460     * <p>
461     * This converter is only a convenience method. Beware computing only one derivative does
462     * not save any computation as the original function will really be called under the hood.
463     * The derivative will be extracted from the full {@link DerivativeStructure} result.
464     * </p>
465     * @param f original function, with value and all its derivatives
466     * @param orders of the derivative to extract, for each free parameters
467     * @return function computing the derivative at required order
468     * @see #derivative(UnivariateDifferentiableFunction, int)
469     * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
470     */
471    public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
472        return new MultivariateFunction() {
473
474            /** {@inheritDoc} */
475            @Override
476            public double value(final double[] point) {
477
478                // the maximum differentiation order is the sum of all orders
479                int sumOrders = 0;
480                for (final int order : orders) {
481                    sumOrders += order;
482                }
483
484                // set up the input parameters
485                final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
486                for (int i = 0; i < point.length; ++i) {
487                    dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
488                }
489
490                return f.value(dsPoint).getPartialDerivative(orders);
491            }
492        };
493    }
494}