001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.analysis;
019
020import org.apache.commons.numbers.arrays.LinearCombination;
021import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
022import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
023import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
024import org.apache.commons.math4.analysis.function.Identity;
025import org.apache.commons.math4.exception.DimensionMismatchException;
026import org.apache.commons.math4.exception.NotStrictlyPositiveException;
027import org.apache.commons.math4.exception.NumberIsTooLargeException;
028import org.apache.commons.math4.exception.util.LocalizedFormats;
029
030/**
031 * Utilities for manipulating function objects.
032 *
033 * @since 3.0
034 */
035public class FunctionUtils {
036    /**
037     * Class only contains static methods.
038     */
039    private FunctionUtils() {}
040
041    /**
042     * Composes functions.
043     * <p>
044     * The functions in the argument list are composed sequentially, in the
045     * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
046     *
047     * @param f List of functions.
048     * @return the composite function.
049     */
050    public static UnivariateFunction compose(final UnivariateFunction ... f) {
051        return new UnivariateFunction() {
052            /** {@inheritDoc} */
053            @Override
054            public double value(double x) {
055                double r = x;
056                for (int i = f.length - 1; i >= 0; i--) {
057                    r = f[i].value(r);
058                }
059                return r;
060            }
061        };
062    }
063
064    /**
065     * Composes functions.
066     * <p>
067     * The functions in the argument list are composed sequentially, in the
068     * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
069     *
070     * @param f List of functions.
071     * @return the composite function.
072     * @since 3.1
073     */
074    public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
075        return new UnivariateDifferentiableFunction() {
076
077            /** {@inheritDoc} */
078            @Override
079            public double value(final double t) {
080                double r = t;
081                for (int i = f.length - 1; i >= 0; i--) {
082                    r = f[i].value(r);
083                }
084                return r;
085            }
086
087            /** {@inheritDoc} */
088            @Override
089            public DerivativeStructure value(final DerivativeStructure t) {
090                DerivativeStructure r = t;
091                for (int i = f.length - 1; i >= 0; i--) {
092                    r = f[i].value(r);
093                }
094                return r;
095            }
096
097        };
098    }
099
100    /**
101     * Adds functions.
102     *
103     * @param f List of functions.
104     * @return a function that computes the sum of the functions.
105     */
106    public static UnivariateFunction add(final UnivariateFunction ... f) {
107        return new UnivariateFunction() {
108            /** {@inheritDoc} */
109            @Override
110            public double value(double x) {
111                double r = f[0].value(x);
112                for (int i = 1; i < f.length; i++) {
113                    r += f[i].value(x);
114                }
115                return r;
116            }
117        };
118    }
119
120    /**
121     * Adds functions.
122     *
123     * @param f List of functions.
124     * @return a function that computes the sum of the functions.
125     * @since 3.1
126     */
127    public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
128        return new UnivariateDifferentiableFunction() {
129
130            /** {@inheritDoc} */
131            @Override
132            public double value(final double t) {
133                double r = f[0].value(t);
134                for (int i = 1; i < f.length; i++) {
135                    r += f[i].value(t);
136                }
137                return r;
138            }
139
140            /** {@inheritDoc}
141             * @throws DimensionMismatchException if functions are not consistent with each other
142             */
143            @Override
144            public DerivativeStructure value(final DerivativeStructure t)
145                throws DimensionMismatchException {
146                DerivativeStructure r = f[0].value(t);
147                for (int i = 1; i < f.length; i++) {
148                    r = r.add(f[i].value(t));
149                }
150                return r;
151            }
152
153        };
154    }
155
156    /**
157     * Multiplies functions.
158     *
159     * @param f List of functions.
160     * @return a function that computes the product of the functions.
161     */
162    public static UnivariateFunction multiply(final UnivariateFunction ... f) {
163        return new UnivariateFunction() {
164            /** {@inheritDoc} */
165            @Override
166            public double value(double x) {
167                double r = f[0].value(x);
168                for (int i = 1; i < f.length; i++) {
169                    r *= f[i].value(x);
170                }
171                return r;
172            }
173        };
174    }
175
176    /**
177     * Multiplies functions.
178     *
179     * @param f List of functions.
180     * @return a function that computes the product of the functions.
181     * @since 3.1
182     */
183    public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
184        return new UnivariateDifferentiableFunction() {
185
186            /** {@inheritDoc} */
187            @Override
188            public double value(final double t) {
189                double r = f[0].value(t);
190                for (int i = 1; i < f.length; i++) {
191                    r  *= f[i].value(t);
192                }
193                return r;
194            }
195
196            /** {@inheritDoc} */
197            @Override
198            public DerivativeStructure value(final DerivativeStructure t) {
199                DerivativeStructure r = f[0].value(t);
200                for (int i = 1; i < f.length; i++) {
201                    r = r.multiply(f[i].value(t));
202                }
203                return r;
204            }
205
206        };
207    }
208
209    /**
210     * Returns the univariate function
211     * {@code h(x) = combiner(f(x), g(x)).}
212     *
213     * @param combiner Combiner function.
214     * @param f Function.
215     * @param g Function.
216     * @return the composite function.
217     */
218    public static UnivariateFunction combine(final BivariateFunction combiner,
219                                             final UnivariateFunction f,
220                                             final UnivariateFunction g) {
221        return new UnivariateFunction() {
222            /** {@inheritDoc} */
223            @Override
224            public double value(double x) {
225                return combiner.value(f.value(x), g.value(x));
226            }
227        };
228    }
229
230    /**
231     * Returns a MultivariateFunction h(x[]) defined by <pre> <code>
232     * h(x[]) = combiner(...combiner(combiner(initialValue,f(x[0])),f(x[1]))...),f(x[x.length-1]))
233     * </code></pre>
234     *
235     * @param combiner Combiner function.
236     * @param f Function.
237     * @param initialValue Initial value.
238     * @return a collector function.
239     */
240    public static MultivariateFunction collector(final BivariateFunction combiner,
241                                                 final UnivariateFunction f,
242                                                 final double initialValue) {
243        return new MultivariateFunction() {
244            /** {@inheritDoc} */
245            @Override
246            public double value(double[] point) {
247                double result = combiner.value(initialValue, f.value(point[0]));
248                for (int i = 1; i < point.length; i++) {
249                    result = combiner.value(result, f.value(point[i]));
250                }
251                return result;
252            }
253        };
254    }
255
256    /**
257     * Returns a MultivariateFunction h(x[]) defined by <pre> <code>
258     * h(x[]) = combiner(...combiner(combiner(initialValue,x[0]),x[1])...),x[x.length-1])
259     * </code></pre>
260     *
261     * @param combiner Combiner function.
262     * @param initialValue Initial value.
263     * @return a collector function.
264     */
265    public static MultivariateFunction collector(final BivariateFunction combiner,
266                                                 final double initialValue) {
267        return collector(combiner, new Identity(), initialValue);
268    }
269
270    /**
271     * Creates a unary function by fixing the first argument of a binary function.
272     *
273     * @param f Binary function.
274     * @param fixed value to which the first argument of {@code f} is set.
275     * @return the unary function h(x) = f(fixed, x)
276     */
277    public static UnivariateFunction fix1stArgument(final BivariateFunction f,
278                                                    final double fixed) {
279        return new UnivariateFunction() {
280            /** {@inheritDoc} */
281            @Override
282            public double value(double x) {
283                return f.value(fixed, x);
284            }
285        };
286    }
287    /**
288     * Creates a unary function by fixing the second argument of a binary function.
289     *
290     * @param f Binary function.
291     * @param fixed value to which the second argument of {@code f} is set.
292     * @return the unary function h(x) = f(x, fixed)
293     */
294    public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
295                                                    final double fixed) {
296        return new UnivariateFunction() {
297            /** {@inheritDoc} */
298            @Override
299            public double value(double x) {
300                return f.value(x, fixed);
301            }
302        };
303    }
304
305    /**
306     * Samples the specified univariate real function on the specified interval.
307     * <p>
308     * The interval is divided equally into {@code n} sections and sample points
309     * are taken from {@code min} to {@code max - (max - min) / n}; therefore
310     * {@code f} is not sampled at the upper bound {@code max}.</p>
311     *
312     * @param f Function to be sampled
313     * @param min Lower bound of the interval (included).
314     * @param max Upper bound of the interval (excluded).
315     * @param n Number of sample points.
316     * @return the array of samples.
317     * @throws NumberIsTooLargeException if the lower bound {@code min} is
318     * greater than, or equal to the upper bound {@code max}.
319     * @throws NotStrictlyPositiveException if the number of sample points
320     * {@code n} is negative.
321     */
322    public static double[] sample(UnivariateFunction f, double min, double max, int n)
323       throws NumberIsTooLargeException, NotStrictlyPositiveException {
324
325        if (n <= 0) {
326            throw new NotStrictlyPositiveException(
327                    LocalizedFormats.NOT_POSITIVE_NUMBER_OF_SAMPLES,
328                    Integer.valueOf(n));
329        }
330        if (min >= max) {
331            throw new NumberIsTooLargeException(min, max, false);
332        }
333
334        final double[] s = new double[n];
335        final double h = (max - min) / n;
336        for (int i = 0; i < n; i++) {
337            s[i] = f.value(min + i * h);
338        }
339        return s;
340    }
341
342    /** Convert regular functions to {@link UnivariateDifferentiableFunction}.
343     * <p>
344     * This method handle the case with one free parameter and several derivatives.
345     * For the case with several free parameters and only first order derivatives,
346     * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
347     * There are no direct support for intermediate cases, with several free parameters
348     * and order 2 or more derivatives, as is would be difficult to specify all the
349     * cross derivatives.
350     * </p>
351     * <p>
352     * Note that the derivatives are expected to be computed only with respect to the
353     * raw parameter x of the base function, i.e. they are df/dx, df<sup>2</sup>/dx<sup>2</sup>, ...
354     * Even if the built function is later used in a composition like f(sin(t)), the provided
355     * derivatives should <em>not</em> apply the composition with sine and its derivatives by
356     * themselves. The composition will be done automatically here and the result will properly
357     * contain f(sin(t)), df(sin(t))/dt, df<sup>2</sup>(sin(t))/dt<sup>2</sup> despite the
358     * provided derivatives functions know nothing about the sine function.
359     * </p>
360     * @param f base function f(x)
361     * @param derivatives derivatives of the base function, in increasing differentiation order
362     * @return a differentiable function with value and all specified derivatives
363     * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
364     * @see #derivative(UnivariateDifferentiableFunction, int)
365     */
366    public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
367                                                                       final UnivariateFunction ... derivatives) {
368
369        return new UnivariateDifferentiableFunction() {
370
371            /** {@inheritDoc} */
372            @Override
373            public double value(final double x) {
374                return f.value(x);
375            }
376
377            /** {@inheritDoc} */
378            @Override
379            public DerivativeStructure value(final DerivativeStructure x) {
380                if (x.getOrder() > derivatives.length) {
381                    throw new NumberIsTooLargeException(x.getOrder(), derivatives.length, true);
382                }
383                final double[] packed = new double[x.getOrder() + 1];
384                packed[0] = f.value(x.getValue());
385                for (int i = 0; i < x.getOrder(); ++i) {
386                    packed[i + 1] = derivatives[i].value(x.getValue());
387                }
388                return x.compose(packed);
389            }
390
391        };
392
393    }
394
395    /** Convert regular functions to {@link MultivariateDifferentiableFunction}.
396     * <p>
397     * This method handle the case with several free parameters and only first order derivatives.
398     * For the case with one free parameter and several derivatives,
399     * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
400     * There are no direct support for intermediate cases, with several free parameters
401     * and order 2 or more derivatives, as is would be difficult to specify all the
402     * cross derivatives.
403     * </p>
404     * <p>
405     * Note that the gradient is expected to be computed only with respect to the
406     * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
407     * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
408     * gradient should <em>not</em> apply the composition with sine or cosine and their derivative by
409     * itself. The composition will be done automatically here and the result will properly
410     * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt despite the provided derivatives functions
411     * know nothing about the sine or cosine functions.
412     * </p>
413     * @param f base function f(x)
414     * @param gradient gradient of the base function
415     * @return a differentiable function with value and gradient
416     * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
417     * @see #derivative(MultivariateDifferentiableFunction, int[])
418     */
419    public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
420                                                                         final MultivariateVectorFunction gradient) {
421
422        return new MultivariateDifferentiableFunction() {
423
424            /** {@inheritDoc} */
425            @Override
426            public double value(final double[] point) {
427                return f.value(point);
428            }
429
430            /** {@inheritDoc} */
431            @Override
432            public DerivativeStructure value(final DerivativeStructure[] point) {
433
434                // set up the input parameters
435                final double[] dPoint = new double[point.length];
436                for (int i = 0; i < point.length; ++i) {
437                    dPoint[i] = point[i].getValue();
438                    if (point[i].getOrder() > 1) {
439                        throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
440                    }
441                }
442
443                // evaluate regular functions
444                final double    v = f.value(dPoint);
445                final double[] dv = gradient.value(dPoint);
446                if (dv.length != point.length) {
447                    // the gradient function is inconsistent
448                    throw new DimensionMismatchException(dv.length, point.length);
449                }
450
451                // build the combined derivative
452                final int parameters = point[0].getFreeParameters();
453                final double[] partials = new double[point.length];
454                final double[] packed = new double[parameters + 1];
455                packed[0] = v;
456                final int orders[] = new int[parameters];
457                for (int i = 0; i < parameters; ++i) {
458
459                    // we differentiate once with respect to parameter i
460                    orders[i] = 1;
461                    for (int j = 0; j < point.length; ++j) {
462                        partials[j] = point[j].getPartialDerivative(orders);
463                    }
464                    orders[i] = 0;
465
466                    // compose partial derivatives
467                    packed[i + 1] = LinearCombination.value(dv, partials);
468
469                }
470
471                return new DerivativeStructure(parameters, 1, packed);
472
473            }
474
475        };
476
477    }
478
479    /** Convert an {@link UnivariateDifferentiableFunction} to an
480     * {@link UnivariateFunction} computing n<sup>th</sup> order derivative.
481     * <p>
482     * This converter is only a convenience method. Beware computing only one derivative does
483     * not save any computation as the original function will really be called under the hood.
484     * The derivative will be extracted from the full {@link DerivativeStructure} result.
485     * </p>
486     * @param f original function, with value and all its derivatives
487     * @param order of the derivative to extract
488     * @return function computing the derivative at required order
489     * @see #derivative(MultivariateDifferentiableFunction, int[])
490     * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
491     */
492    public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
493        return new UnivariateFunction() {
494
495            /** {@inheritDoc} */
496            @Override
497            public double value(final double x) {
498                final DerivativeStructure dsX = new DerivativeStructure(1, order, 0, x);
499                return f.value(dsX).getPartialDerivative(order);
500            }
501
502        };
503    }
504
505    /** Convert an {@link MultivariateDifferentiableFunction} to an
506     * {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
507     * <p>
508     * This converter is only a convenience method. Beware computing only one derivative does
509     * not save any computation as the original function will really be called under the hood.
510     * The derivative will be extracted from the full {@link DerivativeStructure} result.
511     * </p>
512     * @param f original function, with value and all its derivatives
513     * @param orders of the derivative to extract, for each free parameters
514     * @return function computing the derivative at required order
515     * @see #derivative(UnivariateDifferentiableFunction, int)
516     * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
517     */
518    public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
519        return new MultivariateFunction() {
520
521            /** {@inheritDoc} */
522            @Override
523            public double value(final double[] point) {
524
525                // the maximum differentiation order is the sum of all orders
526                int sumOrders = 0;
527                for (final int order : orders) {
528                    sumOrders += order;
529                }
530
531                // set up the input parameters
532                final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
533                for (int i = 0; i < point.length; ++i) {
534                    dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
535                }
536
537                return f.value(dsPoint).getPartialDerivative(orders);
538
539            }
540
541        };
542    }
543
544}