1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.analysis;
19
20 import org.apache.commons.numbers.core.Sum;
21 import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22 import org.apache.commons.math4.legacy.analysis.differentiation.MultivariateDifferentiableFunction;
23 import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
24 import org.apache.commons.math4.legacy.analysis.function.Identity;
25 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
26 import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
27
28
29
30
31
32
33 public final class FunctionUtils {
34
35
36
37 private FunctionUtils() {}
38
39
40
41
42
43
44
45
46
47
48 public static UnivariateFunction compose(final UnivariateFunction ... f) {
49 return new UnivariateFunction() {
50
51 @Override
52 public double value(double x) {
53 double r = x;
54 for (int i = f.length - 1; i >= 0; i--) {
55 r = f[i].value(r);
56 }
57 return r;
58 }
59 };
60 }
61
62
63
64
65
66
67
68
69
70
71
72 public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
73 return new UnivariateDifferentiableFunction() {
74
75
76 @Override
77 public double value(final double t) {
78 double r = t;
79 for (int i = f.length - 1; i >= 0; i--) {
80 r = f[i].value(r);
81 }
82 return r;
83 }
84
85
86 @Override
87 public DerivativeStructure value(final DerivativeStructure t) {
88 DerivativeStructure r = t;
89 for (int i = f.length - 1; i >= 0; i--) {
90 r = f[i].value(r);
91 }
92 return r;
93 }
94 };
95 }
96
97
98
99
100
101
102
103 public static UnivariateFunction add(final UnivariateFunction ... f) {
104 return new UnivariateFunction() {
105
106 @Override
107 public double value(double x) {
108 double r = f[0].value(x);
109 for (int i = 1; i < f.length; i++) {
110 r += f[i].value(x);
111 }
112 return r;
113 }
114 };
115 }
116
117
118
119
120
121
122
123
124 public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
125 return new UnivariateDifferentiableFunction() {
126
127
128 @Override
129 public double value(final double t) {
130 double r = f[0].value(t);
131 for (int i = 1; i < f.length; i++) {
132 r += f[i].value(t);
133 }
134 return r;
135 }
136
137
138
139
140 @Override
141 public DerivativeStructure value(final DerivativeStructure t)
142 throws DimensionMismatchException {
143 DerivativeStructure r = f[0].value(t);
144 for (int i = 1; i < f.length; i++) {
145 r = r.add(f[i].value(t));
146 }
147 return r;
148 }
149 };
150 }
151
152
153
154
155
156
157
158 public static UnivariateFunction multiply(final UnivariateFunction ... f) {
159 return new UnivariateFunction() {
160
161 @Override
162 public double value(double x) {
163 double r = f[0].value(x);
164 for (int i = 1; i < f.length; i++) {
165 r *= f[i].value(x);
166 }
167 return r;
168 }
169 };
170 }
171
172
173
174
175
176
177
178
179 public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
180 return new UnivariateDifferentiableFunction() {
181
182
183 @Override
184 public double value(final double t) {
185 double r = f[0].value(t);
186 for (int i = 1; i < f.length; i++) {
187 r *= f[i].value(t);
188 }
189 return r;
190 }
191
192
193 @Override
194 public DerivativeStructure value(final DerivativeStructure t) {
195 DerivativeStructure r = f[0].value(t);
196 for (int i = 1; i < f.length; i++) {
197 r = r.multiply(f[i].value(t));
198 }
199 return r;
200 }
201 };
202 }
203
204
205
206
207
208
209
210
211
212
213 public static UnivariateFunction combine(final BivariateFunction combiner,
214 final UnivariateFunction f,
215 final UnivariateFunction g) {
216 return new UnivariateFunction() {
217
218 @Override
219 public double value(double x) {
220 return combiner.value(f.value(x), g.value(x));
221 }
222 };
223 }
224
225
226
227
228
229
230
231
232
233
234
235
236 public static MultivariateFunction collector(final BivariateFunction combiner,
237 final UnivariateFunction f,
238 final double initialValue) {
239 return new MultivariateFunction() {
240
241 @Override
242 public double value(double[] point) {
243 double result = combiner.value(initialValue, f.value(point[0]));
244 for (int i = 1; i < point.length; i++) {
245 result = combiner.value(result, f.value(point[i]));
246 }
247 return result;
248 }
249 };
250 }
251
252
253
254
255
256
257
258
259
260
261
262 public static MultivariateFunction collector(final BivariateFunction combiner,
263 final double initialValue) {
264 return collector(combiner, new Identity(), initialValue);
265 }
266
267
268
269
270
271
272
273
274 public static UnivariateFunction fix1stArgument(final BivariateFunction f,
275 final double fixed) {
276 return new UnivariateFunction() {
277
278 @Override
279 public double value(double x) {
280 return f.value(fixed, x);
281 }
282 };
283 }
284
285
286
287
288
289
290
291 public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
292 final double fixed) {
293 return new UnivariateFunction() {
294
295 @Override
296 public double value(double x) {
297 return f.value(x, fixed);
298 }
299 };
300 }
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326 public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
327 final UnivariateFunction ... derivatives) {
328
329 return new UnivariateDifferentiableFunction() {
330
331
332 @Override
333 public double value(final double x) {
334 return f.value(x);
335 }
336
337
338 @Override
339 public DerivativeStructure value(final DerivativeStructure x) {
340 if (x.getOrder() > derivatives.length) {
341 throw new NumberIsTooLargeException(x.getOrder(), derivatives.length, true);
342 }
343 final double[] packed = new double[x.getOrder() + 1];
344 packed[0] = f.value(x.getValue());
345 for (int i = 0; i < x.getOrder(); ++i) {
346 packed[i + 1] = derivatives[i].value(x.getValue());
347 }
348 return x.compose(packed);
349 }
350 };
351 }
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377 public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
378 final MultivariateVectorFunction gradient) {
379
380 return new MultivariateDifferentiableFunction() {
381
382
383 @Override
384 public double value(final double[] point) {
385 return f.value(point);
386 }
387
388
389 @Override
390 public DerivativeStructure value(final DerivativeStructure[] point) {
391
392
393 final double[] dPoint = new double[point.length];
394 for (int i = 0; i < point.length; ++i) {
395 dPoint[i] = point[i].getValue();
396 if (point[i].getOrder() > 1) {
397 throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
398 }
399 }
400
401
402 final double v = f.value(dPoint);
403 final double[] dv = gradient.value(dPoint);
404 if (dv.length != point.length) {
405
406 throw new DimensionMismatchException(dv.length, point.length);
407 }
408
409
410 final int parameters = point[0].getFreeParameters();
411 final double[] partials = new double[point.length];
412 final double[] packed = new double[parameters + 1];
413 packed[0] = v;
414 final int[] orders = new int[parameters];
415 for (int i = 0; i < parameters; ++i) {
416
417
418 orders[i] = 1;
419 for (int j = 0; j < point.length; ++j) {
420 partials[j] = point[j].getPartialDerivative(orders);
421 }
422 orders[i] = 0;
423
424
425 packed[i + 1] = Sum.ofProducts(dv, partials).getAsDouble();
426 }
427
428 return new DerivativeStructure(parameters, 1, packed);
429 }
430 };
431 }
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446 public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
447 return new UnivariateFunction() {
448
449
450 @Override
451 public double value(final double x) {
452 final DerivativeStructure dsX = new DerivativeStructure(1, order, 0, x);
453 return f.value(dsX).getPartialDerivative(order);
454 }
455 };
456 }
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471 public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
472 return new MultivariateFunction() {
473
474
475 @Override
476 public double value(final double[] point) {
477
478
479 int sumOrders = 0;
480 for (final int order : orders) {
481 sumOrders += order;
482 }
483
484
485 final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
486 for (int i = 0; i < point.length; ++i) {
487 dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
488 }
489
490 return f.value(dsPoint).getPartialDerivative(orders);
491 }
492 };
493 }
494 }