View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
18  
19  import java.util.Arrays;
20  import java.util.List;
21  import java.util.ArrayList;
22  import java.io.PrintWriter;
23  import java.io.IOException;
24  import java.nio.file.Files;
25  import java.nio.file.Paths;
26  import java.nio.file.StandardOpenOption;
27  import org.junit.jupiter.api.Assertions;
28  import org.junit.jupiter.api.Test;
29  import org.junit.jupiter.api.extension.ParameterContext;
30  import org.junit.jupiter.params.ParameterizedTest;
31  import org.junit.jupiter.params.aggregator.ArgumentsAggregator;
32  import org.junit.jupiter.params.aggregator.ArgumentsAccessor;
33  import org.junit.jupiter.params.aggregator.ArgumentsAggregationException;
34  import org.junit.jupiter.params.aggregator.AggregateWith;
35  import org.junit.jupiter.params.provider.CsvFileSource;
36  import org.apache.commons.rng.UniformRandomProvider;
37  import org.apache.commons.rng.simple.RandomSource;
38  import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
39  import org.apache.commons.rng.sampling.UnitSphereSampler;
40  import org.apache.commons.math4.legacy.core.MathArrays;
41  import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
42  import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
43  import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
44  import org.apache.commons.math4.legacy.optim.InitialGuess;
45  import org.apache.commons.math4.legacy.optim.MaxEval;
46  import org.apache.commons.math4.legacy.optim.PointValuePair;
47  import org.apache.commons.math4.legacy.optim.SimpleBounds;
48  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
49  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.ObjectiveFunction;
50  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.SimulatedAnnealing;
51  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.TestFunction;
52  
53  /**
54   * Tests for {@link SimplexOptimizer simplex-based algorithms}.
55   */
56  public class SimplexOptimizerTest {
57      private static final String NELDER_MEAD_INPUT_FILE = "std_test_func.simplex.nelder_mead.csv";
58      private static final String MULTIDIRECTIONAL_INPUT_FILE = "std_test_func.simplex.multidirectional.csv";
59      private static final String HEDAR_FUKUSHIMA_INPUT_FILE = "std_test_func.simplex.hedar_fukushima.csv";
60  
61      @Test
62      public void testMaxEvaluations() {
63          Assertions.assertThrows(TooManyEvaluationsException.class, () -> {
64                  final int dim = 4;
65                  final SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
66                  optimizer.optimize(new MaxEval(20),
67                                     new ObjectiveFunction(TestFunction.PARABOLA.withDimension(dim)),
68                                     GoalType.MINIMIZE,
69                                     new InitialGuess(new double[] { 3, -1, -3, 1 }),
70                                     Simplex.equalSidesAlongAxes(dim, 1d),
71                                     new NelderMeadTransform());
72              });
73      }
74  
75      @Test
76      public void testBoundsUnsupported() {
77          Assertions.assertThrows(MathUnsupportedOperationException.class, () -> {
78                  final int dim = 2;
79                  final SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
80                  optimizer.optimize(new MaxEval(100),
81                                     new ObjectiveFunction(TestFunction.PARABOLA.withDimension(dim)),
82                                     GoalType.MINIMIZE,
83                                     new InitialGuess(new double[] { -3, 0 }),
84                                     Simplex.alongAxes(new double[] { 0.2, 0.2 }),
85                                     new NelderMeadTransform(),
86                                     new SimpleBounds(new double[] { -5, -1 },
87                                                      new double[] { 5, 1 }));
88              });
89      }
90  
91      @ParameterizedTest
92      @CsvFileSource(resources = NELDER_MEAD_INPUT_FILE)
93      void testFunctionWithNelderMead(@AggregateWith(TaskAggregator.class) Task task) {
94          // task.checkAlongLine(1000);
95          task.run(new NelderMeadTransform());
96      }
97  
98      @ParameterizedTest
99      @CsvFileSource(resources = MULTIDIRECTIONAL_INPUT_FILE)
100     void testFunctionWithMultiDirectional(@AggregateWith(TaskAggregator.class) Task task) {
101         task.run(new MultiDirectionalTransform());
102     }
103 
104     @ParameterizedTest
105     @CsvFileSource(resources = HEDAR_FUKUSHIMA_INPUT_FILE)
106     void testFunctionWithHedarFukushima(@AggregateWith(TaskAggregator.class) Task task) {
107         task.run(new HedarFukushimaTransform());
108     }
109 
110     /**
111      * Optimization task.
112      */
113     public static class Task {
114         /** Function evaluations hard count (debugging). */
115         private static final int FUNC_EVAL_DEBUG = 500000;
116         /** Default convergence criterion. */
117         private static final double CONVERGENCE_CHECK = 1e-9;
118         /** Default cooling factor. */
119         private static final double SA_COOL_FACTOR = 0.7;
120         /** Default acceptance probability at beginning of SA. */
121         private static final double SA_START_PROB = 0.9;
122         /** Default acceptance probability at end of SA. */
123         private static final double SA_END_PROB = 1e-20;
124         /** Function. */
125         private final MultivariateFunction function;
126         /** Initial value. */
127         private final double[] start;
128         /** Optimum. */
129         private final double[] optimum;
130         /** Tolerance. */
131         private final double pointTolerance;
132         /** Allowed function evaluations. */
133         private final int functionEvaluations;
134         /** Side length of initial simplex. */
135         private final double simplexSideLength;
136         /** Whether to perform simulated annealing. */
137         private final boolean withSA;
138         /** File prefix (for saving debugging info). */
139         private final String tracePrefix;
140         /** Indices of simplex points to be saved for debugging. */
141         private final int[] traceIndices;
142 
143         /**
144          * @param function Test function.
145          * @param start Start point.
146          * @param optimum Optimum.
147          * @param pointTolerance Allowed distance between result and
148          * {@code optimum}.
149          * @param functionEvaluations Allowed number of function evaluations.
150          * @param simplexSideLength Side length of initial simplex.
151          * @param withSA Whether to perform simulated annealing.
152          * @param tracePrefix Prefix of the file where to save simplex
153          * transformations during the optimization.
154          * Can be {@code null} (no debugging).
155          * @param traceIndices Indices of simplex points to be saved.
156          * Can be {@code null} (all points are saved).
157          */
158         Task(MultivariateFunction function,
159              double[] start,
160              double[] optimum,
161              double pointTolerance,
162              int functionEvaluations,
163              double simplexSideLength,
164              boolean withSA,
165              String tracePrefix,
166              int[] traceIndices) {
167             this.function = function;
168             this.start = start;
169             this.optimum = optimum;
170             this.pointTolerance = pointTolerance;
171             this.functionEvaluations = functionEvaluations;
172             this.simplexSideLength = simplexSideLength;
173             this.withSA = withSA;
174             this.tracePrefix = tracePrefix;
175             this.traceIndices = traceIndices;
176         }
177 
178         @Override
179         public String toString() {
180             return function.toString();
181         }
182 
183         /**
184          * @param factory Simplex transform factory.
185          */
186         /* package-private */ void run(Simplex.TransformFactory factory) {
187             // Let run with a maximum number of evaluations larger than expected
188             // (as specified by "functionEvaluations") in order to have the unit
189             // test failure message (see assertion below) report the actual number
190             // required by the current code.
191             final int maxEval = Math.max(functionEvaluations, FUNC_EVAL_DEBUG);
192 
193             final String name = function.toString();
194             final int dim = start.length;
195 
196             final SimulatedAnnealing sa;
197             if (withSA) {
198                 final SimulatedAnnealing.CoolingSchedule coolSched =
199                     SimulatedAnnealing.CoolingSchedule.decreasingExponential(SA_COOL_FACTOR);
200 
201                 sa = new SimulatedAnnealing(dim,
202                                             SA_START_PROB,
203                                             SA_END_PROB,
204                                             coolSched,
205                                             RandomSource.KISS.create());
206             } else {
207                 sa = null;
208             }
209 
210             final SimplexOptimizer optim = new SimplexOptimizer(-1, CONVERGENCE_CHECK);
211             if (tracePrefix != null) {
212                 optim.addObserver(createCallback(factory));
213             }
214 
215             final Simplex initialSimplex = Simplex.equalSidesAlongAxes(dim, simplexSideLength);
216             final PointValuePair result =
217                 optim.optimize(new MaxEval(maxEval),
218                                new ObjectiveFunction(function),
219                                GoalType.MINIMIZE,
220                                new InitialGuess(start),
221                                initialSimplex,
222                                factory,
223                                sa);
224 
225             final double[] endPoint = result.getPoint();
226             final double funcValue = result.getValue();
227             final double dist = MathArrays.distance(optimum, endPoint);
228             Assertions.assertEquals(0d, dist, pointTolerance,
229                                     () -> name + ": distance to optimum" +
230                                     " f(" + Arrays.toString(endPoint) + ")=" +
231                                     funcValue);
232 
233             final int nEval = optim.getEvaluations();
234             Assertions.assertTrue(nEval < functionEvaluations,
235                                   () -> name + ": nEval=" + nEval + " < " + functionEvaluations);
236         }
237 
238         /**
239          * @param factory Simplex transform factory.
240          * @return a function to save the simplex's states to file.
241          */
242         private SimplexOptimizer.Observer createCallback(Simplex.TransformFactory factory) {
243             if (tracePrefix == null) {
244                 throw new IllegalArgumentException("Missing file prefix");
245             }
246 
247             final String sep = "__";
248             final String name = tracePrefix + sanitizeBasename(function + sep +
249                                                                Arrays.toString(start) + sep +
250                                                                factory + sep);
251 
252             // Create file; write first data block (optimum) and columns header.
253             try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name)))) {
254                 out.println("# Function: " + function);
255                 out.println("# Transform: " + factory);
256                 out.println("#");
257 
258                 out.println("# Optimum");
259                 for (double c : optimum) {
260                     out.print(c + " ");
261                 }
262                 out.println();
263                 out.println();
264 
265                 out.println("#");
266                 out.print("# <1: evaluations> <2: f(x)> <3: |f(x) - f(optimum)|>");
267                 for (int i = 0; i < start.length; i++) {
268                     out.print(" <" + (i + 4) + ": x[" + i + "]>");
269                 }
270                 out.println();
271             } catch (IOException e) {
272                 Assertions.fail(e.getMessage());
273             }
274 
275             final double fAtOptimum = function.value(optimum);
276 
277             // Return callback function.
278             return (simplex, isInit, numEval) -> {
279                 try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name),
280                                                                                StandardOpenOption.APPEND))) {
281                     if (isInit) {
282                         // Blank line indicating the start of an optimization
283                         // (new data block).
284                         out.println();
285                         out.println("# [init]"); // Initial simplex.
286                     }
287 
288                     final String fieldSep = " ";
289                     // 1 line per simplex point (requested for tracing).
290                     final List<PointValuePair> points = simplex.asList();
291                     for (int index : traceIndices) {
292                         final PointValuePair p = points.get(index);
293                         out.print(numEval + fieldSep +
294                                   p.getValue() + fieldSep +
295                                   Math.abs(p.getValue() - fAtOptimum) + fieldSep);
296 
297                         final double[] coord = p.getPoint();
298                         for (int i = 0; i < coord.length; i++) {
299                             out.print(coord[i] + fieldSep);
300                         }
301                         out.println();
302                     }
303                     // Blank line between simplexes.
304                     out.println();
305                 } catch (IOException e) {
306                     Assertions.fail(e.getMessage());
307                 }
308             };
309         }
310 
311         /**
312          * Asserts that the lowest function value (along a line starting at
313          * {@code start} is reached at the {@code optimum}.
314          *
315          * @param numPoints Number of points at which to evaluate the function.
316          */
317         public void checkAlongLine(int numPoints) {
318             if (tracePrefix != null) {
319                 final String name = tracePrefix + createPlotBasename(function, start, optimum);
320                 try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name)))) {
321                     checkAlongLine(numPoints, out);
322                 } catch (IOException e) {
323                     Assertions.fail(e.getMessage());
324                 }
325             } else {
326                 checkAlongLine(numPoints, null);
327             }
328         }
329 
330         /**
331          * Computes the values of the function along the straight line between
332          * {@link #startPoint} and {@link #optimum} and asserts that the value
333          * at the latter is smaller than at any other points along the line.
334          * <p>
335          * If the {@code output} stream is not {@code null}, two columns are
336          * printed:
337          * <ol>
338          *  <li>parameter in the {@code [0, 1]} interval (0 at {@link #startPoint}
339          *   and 1 at {@link #optimum}),</li>
340          *  <li>function value at {@code t * (optimum - startPoint)}.</li>
341          * </ol>
342          *
343          * @param numPoints Number of points to evaluate between {@link #start}
344          * and {@link #optimum}.
345          * @param output Output stream.
346          */
347         private void checkAlongLine(int numPoints,
348                                     PrintWriter output) {
349             final double delta = 1d / numPoints;
350 
351             final int dim = start.length;
352             final double[] dir = new double[dim];
353             for (int i = 0; i < dim; i++) {
354                 dir[i] = optimum[i] - start[i];
355             }
356 
357             double[] minPoint = null;
358             double minValue = Double.POSITIVE_INFINITY;
359             int count = 0;
360             while (count <= numPoints) {
361                 final double[] p = new double[dim];
362                 final double t = count * delta;
363                 for (int i = 0; i < dim; i++) {
364                     p[i] = start[i] + t * dir[i];
365                 }
366 
367                 final double value = function.value(p);
368                 if (value <= minValue) {
369                     minValue = value;
370                     minPoint = p;
371                 }
372 
373                 if (output != null) {
374                     output.println(t + " " + value);
375                 }
376 
377                 ++count;
378             }
379 
380             final double tol = 1e-15;
381             final double[] point = minPoint;
382             final double value = minValue;
383             Assertions.assertArrayEquals(optimum, minPoint, tol,
384                                          () -> "Minimum: f(" + Arrays.toString(point) + ")=" + value);
385         }
386 
387         /**
388          * Generates a string suitable as a file name.
389          *
390          * @param f Function.
391          * @param start Start point.
392          * @param end End point.
393          * @return a string.
394          */
395         private static String createPlotBasename(MultivariateFunction f,
396                                                  double[] start,
397                                                  double[] end) {
398             final String s = f.toString() + "__" +
399                 Arrays.toString(start) + "__" +
400                 Arrays.toString(end);
401 
402             return sanitizeBasename(s) + ".dat";
403         }
404 
405         /**
406          * Generates a string suitable as a file name:
407          * Brackets and parentheses are removed; space, slash, "=" sign and
408          * comma characters are converted to underscores.
409          *
410          * @param str String.
411          * @return a string.
412          */
413         private static String sanitizeBasename(String str) {
414             final String repl = "_";
415             return str
416                 .replaceAll("\\(", "")
417                 .replaceAll("\\)", "")
418                 .replaceAll("\\[", "")
419                 .replaceAll("\\]", "")
420                 .replaceAll("=", repl)
421                 .replaceAll(",\\s+", repl)
422                 .replaceAll(",", repl)
423                 .replaceAll("\\s", repl)
424                 .replaceAll("/", repl)
425                 .replaceAll("^_+", "")
426                 .replaceAll("_+$", "");
427         }
428     }
429 
430     /**
431      * Helper for preparing a {@link Task}.
432      */
433     public static class TaskAggregator implements ArgumentsAggregator {
434         @Override
435         public Object aggregateArguments(ArgumentsAccessor a,
436                                          ParameterContext context)
437             throws ArgumentsAggregationException {
438 
439             int index = 0; // Argument index.
440 
441             final TestFunction funcGen = a.get(index++, TestFunction.class);
442             final int dim = a.getInteger(index++);
443             final double[] optimum = toArrayOfDoubles(a.getString(index++), dim);
444             final double minRadius = a.getDouble(index++);
445             final double maxRadius = a.getDouble(index++);
446             if (minRadius < 0 ||
447                 maxRadius < 0 ||
448                 minRadius >= maxRadius) {
449                 throw new ArgumentsAggregationException("radii");
450             }
451             final double pointTol = a.getDouble(index++);
452             final int funcEval = a.getInteger(index++);
453             final boolean withSA = a.getBoolean(index++);
454 
455             // Generate a start point within a spherical shell around the optimum.
456             final UniformRandomProvider rng = OptimTestUtils.rng();
457             final double radius = ContinuousUniformSampler.of(rng, minRadius, maxRadius).sample();
458             final double[] start = UnitSphereSampler.of(rng, dim).sample();
459             for (int i = 0; i < dim; i++) {
460                 start[i] *= radius;
461                 start[i] += optimum[i];
462             }
463             // Simplex side.
464             final double sideLength = 0.5 * (maxRadius - minRadius);
465 
466             if (index == a.size()) {
467                 // No more arguments.
468                 return new Task(funcGen.withDimension(dim),
469                                 start,
470                                 optimum,
471                                 pointTol,
472                                 funcEval,
473                                 sideLength,
474                                 withSA,
475                                 null,
476                                 null);
477             } else {
478                 // Debugging configuration.
479                 final String tracePrefix = a.getString(index++);
480                 final int[] spxIndices = tracePrefix == null ?
481                     null :
482                     toSimplexIndices(a.getString(index++), dim);
483 
484                 return new Task(funcGen.withDimension(dim),
485                                 start,
486                                 optimum,
487                                 pointTol,
488                                 funcEval,
489                                 sideLength,
490                                 withSA,
491                                 tracePrefix,
492                                 spxIndices);
493             }
494         }
495 
496         /**
497          * @param str Space-separated list of indices referring to
498          * simplex's points (in the interval {@code [0, dim]}).
499          * The string "LAST" will be converted to index {@code dim}.
500          * The empty string, the string "ALL" and {@code null} will be
501          * converted to all the indices in the interval {@code [0, dim]}.
502          * @param dim Space dimension.
503          * @return the indices (in the order specified in {@code str}).
504          * @throws IllegalArgumentException if an index is out the
505          * {@code [0, dim]} interval.
506          */
507         private static int[] toSimplexIndices(String str,
508                                               int dim) {
509             final List<Integer> list = new ArrayList<>();
510 
511             if (str == null ||
512                 str.isEmpty()) {
513                 for (int i = 0; i <= dim; i++) {
514                     list.add(i);
515                 }
516             } else {
517                 for (String s : str.split("\\s+")) {
518                     if (s.equals("LAST")) {
519                         list.add(dim);
520                     } else if (str.equals("ALL")) {
521                         for (int i = 0; i <= dim; i++) {
522                             list.add(i);
523                         }
524                     } else {
525                         final int index = Integer.valueOf(s);
526                         if (index < 0 ||
527                             index > dim) {
528                             throw new IllegalArgumentException("index: " + index +
529                                                                " (dim=" + dim + ")");
530                         }
531                         list.add(index);
532                     }
533                 }
534             }
535 
536             final int len = list.size();
537             final int[] indices = new int[len];
538             for (int i = 0; i < len; i++) {
539                 indices[i] = list.get(i);
540             }
541 
542             return indices;
543         }
544 
545         /**
546          * @param params Comma-separated list of values.
547          * @param dim Expected number of values.
548          * @return an array of {@code double} values.
549          * @throws ArgumentsAggregationException if the number of values
550          * is not equal to {@code dim}.
551          */
552         private static double[] toArrayOfDoubles(String params,
553                                                  int dim) {
554             final String[] s = params.trim().split("\\s+");
555 
556             if (s.length != dim) {
557                 final String msg = "Expected " + dim + " values: " + Arrays.toString(s);
558                 throw new ArgumentsAggregationException(msg);
559             }
560 
561             final double[] p = new double[dim];
562             for (int i = 0; i < dim; i++) {
563                 p[i] = Double.valueOf(s[i]);
564             }
565 
566             return p;
567         }
568     }
569 }