View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.analysis;
19  
20  import org.apache.commons.numbers.core.Sum;
21  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22  import org.apache.commons.math4.legacy.analysis.differentiation.MultivariateDifferentiableFunction;
23  import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
24  import org.apache.commons.math4.legacy.analysis.function.Identity;
25  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
26  import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
27  
28  /**
29   * Utilities for manipulating function objects.
30   *
31   * @since 3.0
32   */
33  public final class FunctionUtils {
34      /**
35       * Class only contains static methods.
36       */
37      private FunctionUtils() {}
38  
39      /**
40       * Composes functions.
41       * <p>
42       * The functions in the argument list are composed sequentially, in the
43       * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
44       *
45       * @param f List of functions.
46       * @return the composite function.
47       */
48      public static UnivariateFunction compose(final UnivariateFunction ... f) {
49          return new UnivariateFunction() {
50              /** {@inheritDoc} */
51              @Override
52              public double value(double x) {
53                  double r = x;
54                  for (int i = f.length - 1; i >= 0; i--) {
55                      r = f[i].value(r);
56                  }
57                  return r;
58              }
59          };
60      }
61  
62      /**
63       * Composes functions.
64       * <p>
65       * The functions in the argument list are composed sequentially, in the
66       * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
67       *
68       * @param f List of functions.
69       * @return the composite function.
70       * @since 3.1
71       */
72      public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
73          return new UnivariateDifferentiableFunction() {
74  
75              /** {@inheritDoc} */
76              @Override
77              public double value(final double t) {
78                  double r = t;
79                  for (int i = f.length - 1; i >= 0; i--) {
80                      r = f[i].value(r);
81                  }
82                  return r;
83              }
84  
85              /** {@inheritDoc} */
86              @Override
87              public DerivativeStructure value(final DerivativeStructure t) {
88                  DerivativeStructure r = t;
89                  for (int i = f.length - 1; i >= 0; i--) {
90                      r = f[i].value(r);
91                  }
92                  return r;
93              }
94          };
95      }
96  
97      /**
98       * Adds functions.
99       *
100      * @param f List of functions.
101      * @return a function that computes the sum of the functions.
102      */
103     public static UnivariateFunction add(final UnivariateFunction ... f) {
104         return new UnivariateFunction() {
105             /** {@inheritDoc} */
106             @Override
107             public double value(double x) {
108                 double r = f[0].value(x);
109                 for (int i = 1; i < f.length; i++) {
110                     r += f[i].value(x);
111                 }
112                 return r;
113             }
114         };
115     }
116 
117     /**
118      * Adds functions.
119      *
120      * @param f List of functions.
121      * @return a function that computes the sum of the functions.
122      * @since 3.1
123      */
124     public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
125         return new UnivariateDifferentiableFunction() {
126 
127             /** {@inheritDoc} */
128             @Override
129             public double value(final double t) {
130                 double r = f[0].value(t);
131                 for (int i = 1; i < f.length; i++) {
132                     r += f[i].value(t);
133                 }
134                 return r;
135             }
136 
137             /** {@inheritDoc}
138              * @throws DimensionMismatchException if functions are not consistent with each other
139              */
140             @Override
141             public DerivativeStructure value(final DerivativeStructure t)
142                 throws DimensionMismatchException {
143                 DerivativeStructure r = f[0].value(t);
144                 for (int i = 1; i < f.length; i++) {
145                     r = r.add(f[i].value(t));
146                 }
147                 return r;
148             }
149         };
150     }
151 
152     /**
153      * Multiplies functions.
154      *
155      * @param f List of functions.
156      * @return a function that computes the product of the functions.
157      */
158     public static UnivariateFunction multiply(final UnivariateFunction ... f) {
159         return new UnivariateFunction() {
160             /** {@inheritDoc} */
161             @Override
162             public double value(double x) {
163                 double r = f[0].value(x);
164                 for (int i = 1; i < f.length; i++) {
165                     r *= f[i].value(x);
166                 }
167                 return r;
168             }
169         };
170     }
171 
172     /**
173      * Multiplies functions.
174      *
175      * @param f List of functions.
176      * @return a function that computes the product of the functions.
177      * @since 3.1
178      */
179     public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
180         return new UnivariateDifferentiableFunction() {
181 
182             /** {@inheritDoc} */
183             @Override
184             public double value(final double t) {
185                 double r = f[0].value(t);
186                 for (int i = 1; i < f.length; i++) {
187                     r  *= f[i].value(t);
188                 }
189                 return r;
190             }
191 
192             /** {@inheritDoc} */
193             @Override
194             public DerivativeStructure value(final DerivativeStructure t) {
195                 DerivativeStructure r = f[0].value(t);
196                 for (int i = 1; i < f.length; i++) {
197                     r = r.multiply(f[i].value(t));
198                 }
199                 return r;
200             }
201         };
202     }
203 
204     /**
205      * Returns the univariate function
206      * {@code h(x) = combiner(f(x), g(x))}.
207      *
208      * @param combiner Combiner function.
209      * @param f Function.
210      * @param g Function.
211      * @return the composite function.
212      */
213     public static UnivariateFunction combine(final BivariateFunction combiner,
214                                              final UnivariateFunction f,
215                                              final UnivariateFunction g) {
216         return new UnivariateFunction() {
217             /** {@inheritDoc} */
218             @Override
219             public double value(double x) {
220                 return combiner.value(f.value(x), g.value(x));
221             }
222         };
223     }
224 
225     /**
226      * Returns a MultivariateFunction h(x[]). Defined by:
227      * <pre> <code>
228      * h(x[]) = combiner(...combiner(combiner(initialValue,f(x[0])),f(x[1]))...),f(x[x.length-1]))
229      * </code></pre>
230      *
231      * @param combiner Combiner function.
232      * @param f Function.
233      * @param initialValue Initial value.
234      * @return a collector function.
235      */
236     public static MultivariateFunction collector(final BivariateFunction combiner,
237                                                  final UnivariateFunction f,
238                                                  final double initialValue) {
239         return new MultivariateFunction() {
240             /** {@inheritDoc} */
241             @Override
242             public double value(double[] point) {
243                 double result = combiner.value(initialValue, f.value(point[0]));
244                 for (int i = 1; i < point.length; i++) {
245                     result = combiner.value(result, f.value(point[i]));
246                 }
247                 return result;
248             }
249         };
250     }
251 
252     /**
253      * Returns a MultivariateFunction h(x[]). Defined by:
254      * <pre> <code>
255      * h(x[]) = combiner(...combiner(combiner(initialValue,x[0]),x[1])...),x[x.length-1])
256      * </code></pre>
257      *
258      * @param combiner Combiner function.
259      * @param initialValue Initial value.
260      * @return a collector function.
261      */
262     public static MultivariateFunction collector(final BivariateFunction combiner,
263                                                  final double initialValue) {
264         return collector(combiner, new Identity(), initialValue);
265     }
266 
267     /**
268      * Creates a unary function by fixing the first argument of a binary function.
269      *
270      * @param f Binary function.
271      * @param fixed value to which the first argument of {@code f} is set.
272      * @return the unary function h(x) = f(fixed, x)
273      */
274     public static UnivariateFunction fix1stArgument(final BivariateFunction f,
275                                                     final double fixed) {
276         return new UnivariateFunction() {
277             /** {@inheritDoc} */
278             @Override
279             public double value(double x) {
280                 return f.value(fixed, x);
281             }
282         };
283     }
284     /**
285      * Creates a unary function by fixing the second argument of a binary function.
286      *
287      * @param f Binary function.
288      * @param fixed value to which the second argument of {@code f} is set.
289      * @return the unary function h(x) = f(x, fixed)
290      */
291     public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
292                                                     final double fixed) {
293         return new UnivariateFunction() {
294             /** {@inheritDoc} */
295             @Override
296             public double value(double x) {
297                 return f.value(x, fixed);
298             }
299         };
300     }
301 
302     /** Convert regular functions to {@link UnivariateDifferentiableFunction}.
303      * <p>
304      * This method handle the case with one free parameter and several derivatives.
305      * For the case with several free parameters and only first order derivatives,
306      * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
307      * There are no direct support for intermediate cases, with several free parameters
308      * and order 2 or more derivatives, as is would be difficult to specify all the
309      * cross derivatives.
310      * </p>
311      * <p>
312      * Note that the derivatives are expected to be computed only with respect to the
313      * raw parameter x of the base function, i.e. they are df/dx, df<sup>2</sup>/dx<sup>2</sup>, ...
314      * Even if the built function is later used in a composition like f(sin(t)), the provided
315      * derivatives should <em>not</em> apply the composition with sine and its derivatives by
316      * themselves. The composition will be done automatically here and the result will properly
317      * contain f(sin(t)), df(sin(t))/dt, df<sup>2</sup>(sin(t))/dt<sup>2</sup> despite the
318      * provided derivatives functions know nothing about the sine function.
319      * </p>
320      * @param f base function f(x)
321      * @param derivatives derivatives of the base function, in increasing differentiation order
322      * @return a differentiable function with value and all specified derivatives
323      * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
324      * @see #derivative(UnivariateDifferentiableFunction, int)
325      */
326     public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
327                                                                        final UnivariateFunction ... derivatives) {
328 
329         return new UnivariateDifferentiableFunction() {
330 
331             /** {@inheritDoc} */
332             @Override
333             public double value(final double x) {
334                 return f.value(x);
335             }
336 
337             /** {@inheritDoc} */
338             @Override
339             public DerivativeStructure value(final DerivativeStructure x) {
340                 if (x.getOrder() > derivatives.length) {
341                     throw new NumberIsTooLargeException(x.getOrder(), derivatives.length, true);
342                 }
343                 final double[] packed = new double[x.getOrder() + 1];
344                 packed[0] = f.value(x.getValue());
345                 for (int i = 0; i < x.getOrder(); ++i) {
346                     packed[i + 1] = derivatives[i].value(x.getValue());
347                 }
348                 return x.compose(packed);
349             }
350         };
351     }
352 
353     /** Convert regular functions to {@link MultivariateDifferentiableFunction}.
354      * <p>
355      * This method handle the case with several free parameters and only first order derivatives.
356      * For the case with one free parameter and several derivatives,
357      * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
358      * There are no direct support for intermediate cases, with several free parameters
359      * and order 2 or more derivatives, as is would be difficult to specify all the
360      * cross derivatives.
361      * </p>
362      * <p>
363      * Note that the gradient is expected to be computed only with respect to the
364      * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
365      * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
366      * gradient should <em>not</em> apply the composition with sine or cosine and their derivative by
367      * itself. The composition will be done automatically here and the result will properly
368      * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt despite the provided derivatives functions
369      * know nothing about the sine or cosine functions.
370      * </p>
371      * @param f base function f(x)
372      * @param gradient gradient of the base function
373      * @return a differentiable function with value and gradient
374      * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
375      * @see #derivative(MultivariateDifferentiableFunction, int[])
376      */
377     public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
378                                                                       final MultivariateVectorFunction gradient) {
379 
380         return new MultivariateDifferentiableFunction() {
381 
382             /** {@inheritDoc} */
383             @Override
384             public double value(final double[] point) {
385                 return f.value(point);
386             }
387 
388             /** {@inheritDoc} */
389             @Override
390             public DerivativeStructure value(final DerivativeStructure[] point) {
391 
392                 // set up the input parameters
393                 final double[] dPoint = new double[point.length];
394                 for (int i = 0; i < point.length; ++i) {
395                     dPoint[i] = point[i].getValue();
396                     if (point[i].getOrder() > 1) {
397                         throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
398                     }
399                 }
400 
401                 // evaluate regular functions
402                 final double    v = f.value(dPoint);
403                 final double[] dv = gradient.value(dPoint);
404                 if (dv.length != point.length) {
405                     // the gradient function is inconsistent
406                     throw new DimensionMismatchException(dv.length, point.length);
407                 }
408 
409                 // build the combined derivative
410                 final int parameters = point[0].getFreeParameters();
411                 final double[] partials = new double[point.length];
412                 final double[] packed = new double[parameters + 1];
413                 packed[0] = v;
414                 final int[] orders = new int[parameters];
415                 for (int i = 0; i < parameters; ++i) {
416 
417                     // we differentiate once with respect to parameter i
418                     orders[i] = 1;
419                     for (int j = 0; j < point.length; ++j) {
420                         partials[j] = point[j].getPartialDerivative(orders);
421                     }
422                     orders[i] = 0;
423 
424                     // compose partial derivatives
425                     packed[i + 1] = Sum.ofProducts(dv, partials).getAsDouble();
426                 }
427 
428                 return new DerivativeStructure(parameters, 1, packed);
429             }
430         };
431     }
432 
433     /** Convert an {@link UnivariateDifferentiableFunction} to an
434      * {@link UnivariateFunction} computing n<sup>th</sup> order derivative.
435      * <p>
436      * This converter is only a convenience method. Beware computing only one derivative does
437      * not save any computation as the original function will really be called under the hood.
438      * The derivative will be extracted from the full {@link DerivativeStructure} result.
439      * </p>
440      * @param f original function, with value and all its derivatives
441      * @param order of the derivative to extract
442      * @return function computing the derivative at required order
443      * @see #derivative(MultivariateDifferentiableFunction, int[])
444      * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
445      */
446     public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
447         return new UnivariateFunction() {
448 
449             /** {@inheritDoc} */
450             @Override
451             public double value(final double x) {
452                 final DerivativeStructure dsX = new DerivativeStructure(1, order, 0, x);
453                 return f.value(dsX).getPartialDerivative(order);
454             }
455         };
456     }
457 
458     /** Convert an {@link MultivariateDifferentiableFunction} to an
459      * {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
460      * <p>
461      * This converter is only a convenience method. Beware computing only one derivative does
462      * not save any computation as the original function will really be called under the hood.
463      * The derivative will be extracted from the full {@link DerivativeStructure} result.
464      * </p>
465      * @param f original function, with value and all its derivatives
466      * @param orders of the derivative to extract, for each free parameters
467      * @return function computing the derivative at required order
468      * @see #derivative(UnivariateDifferentiableFunction, int)
469      * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
470      */
471     public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
472         return new MultivariateFunction() {
473 
474             /** {@inheritDoc} */
475             @Override
476             public double value(final double[] point) {
477 
478                 // the maximum differentiation order is the sum of all orders
479                 int sumOrders = 0;
480                 for (final int order : orders) {
481                     sumOrders += order;
482                 }
483 
484                 // set up the input parameters
485                 final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
486                 for (int i = 0; i < point.length; ++i) {
487                     dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
488                 }
489 
490                 return f.value(dsPoint).getPartialDerivative(orders);
491             }
492         };
493     }
494 }