View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.optim.nonlinear.scalar;
18  
19  import java.util.function.Supplier;
20  import org.apache.commons.geometry.euclidean.twod.Vector2D;
21  import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
22  import org.apache.commons.math4.legacy.optim.InitialGuess;
23  import org.apache.commons.math4.legacy.optim.MaxEval;
24  import org.apache.commons.math4.legacy.optim.PointValuePair;
25  import org.apache.commons.math4.legacy.optim.SimpleValueChecker;
26  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.gradient.CircleScalar;
27  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer;
28  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv.NelderMeadTransform;
29  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
30  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv.Simplex;
31  import org.apache.commons.rng.UniformRandomProvider;
32  import org.apache.commons.rng.simple.RandomSource;
33  import org.apache.commons.rng.sampling.distribution.GaussianSampler;
34  import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
35  import org.junit.Assert;
36  import org.junit.Test;
37  
38  public class MultiStartMultivariateOptimizerTest {
39      @Test
40      public void testCircleFitting() {
41          CircleScalar circle = new CircleScalar();
42          circle.addPoint( 30.0,  68.0);
43          circle.addPoint( 50.0,  -6.0);
44          circle.addPoint(110.0, -20.0);
45          circle.addPoint( 35.0,  15.0);
46          circle.addPoint( 45.0,  97.0);
47          final GradientMultivariateOptimizer underlying
48              = new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
49                                                        new SimpleValueChecker(1e-10, 1e-10));
50          final Supplier<double[]> generator = gaussianRandom(new double[] { 50, 50 },
51                                                              new double[] { 10, 10 },
52                                                              RandomSource.MT_64.create());
53          int nbStarts = 10;
54          MultiStartMultivariateOptimizer optimizer
55              = new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);
56          PointValuePair optimum
57              = optimizer.optimize(new MaxEval(1000),
58                                   circle.getObjectiveFunction(),
59                                   circle.getObjectiveFunctionGradient(),
60                                   new NelderMeadTransform(),
61                                   GoalType.MINIMIZE,
62                                   new InitialGuess(new double[] { 98.680, 47.345 }),
63                                   new LineSearchTolerance(1e-10, 1e-10, 1));
64          Assert.assertEquals(1000, optimizer.getMaxEvaluations());
65          final PointValuePair[] optima = optimizer.getOptima();
66          Assert.assertEquals(nbStarts, optima.length);
67          for (PointValuePair o : optima) {
68              // Check the results of all intermediate restarts.
69              Vector2D center = Vector2D.of(o.getPointRef()[0], o.getPointRef()[1]);
70              Assert.assertEquals(69.9597, circle.getRadius(center), 1e-3);
71              Assert.assertEquals(96.07535, center.getX(), 1.4e-3);
72              Assert.assertEquals(48.1349, center.getY(), 5e-3);
73          }
74  
75          final int numEval = optimizer.getEvaluations();
76          Assert.assertTrue("exp: n > 700, act: " + numEval, numEval > 700);
77          Assert.assertTrue("exp: n < 950, act: " + numEval, numEval < 950);
78  
79          Assert.assertEquals(3.1267527, optimum.getValue(), 1e-8);
80      }
81  
82      @Test
83      public void testRosenbrock() {
84          Rosenbrock rosenbrock = new Rosenbrock();
85          SimplexOptimizer underlying
86              = new SimplexOptimizer(new SimpleValueChecker(-1, 1e-3));
87          final Simplex simplex = Simplex.of(new double[][] {
88                  { -1.2,  1.0 },
89                  { 0.9, 1.2 } ,
90                  {  3.5, -2.3 }
91              });
92          final Supplier<double[]> generator = gaussianRandom(new double[] { 0, 0 },
93                                                              new double[] { 1, 1 },
94                                                              RandomSource.MT_64.create());
95          int nbStarts = 10;
96          MultiStartMultivariateOptimizer optimizer
97              = new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);
98          PointValuePair optimum
99              = optimizer.optimize(new MaxEval(1100),
100                                  new ObjectiveFunction(rosenbrock),
101                                  GoalType.MINIMIZE,
102                                  simplex,
103                                  new NelderMeadTransform(),
104                                  new InitialGuess(new double[] { -1.2, 1.0 }));
105         Assert.assertEquals(nbStarts, optimizer.getOptima().length);
106 
107         final int numEval = optimizer.getEvaluations();
108         Assert.assertEquals(rosenbrock.getCount(), numEval);
109         Assert.assertTrue("numEval=" + numEval, numEval > 700);
110         Assert.assertTrue("numEval=" + numEval, numEval < 1200);
111         Assert.assertTrue("optimum=" + optimum.getValue(), optimum.getValue() < 5e-5);
112     }
113 
114     private static final class Rosenbrock implements MultivariateFunction {
115         private int count;
116 
117         Rosenbrock() {
118             count = 0;
119         }
120 
121         @Override
122         public double value(double[] x) {
123             ++count;
124             double a = x[1] - x[0] * x[0];
125             double b = 1 - x[0];
126             return 100 * a * a + b * b;
127         }
128 
129         public int getCount() {
130             return count;
131         }
132     }
133 
134     /**
135      * @param mean Means.
136      * @param stdev Standard deviations.
137      * @param rng Underlying RNG.
138      * @return a random array generator where each element is a Gaussian
139      * sampling with the given mean and standard deviation.
140      */
141     private Supplier<double[]> gaussianRandom(final double[] mean,
142                                               final double[] stdev,
143                                               final UniformRandomProvider rng) {
144         final ZigguratNormalizedGaussianSampler normalized = new ZigguratNormalizedGaussianSampler(rng);
145         final GaussianSampler[] samplers = new GaussianSampler[mean.length];
146         for (int i = 0; i < mean.length; i++) {
147             samplers[i] = new GaussianSampler(normalized, mean[i], stdev[i]);
148         }
149 
150         return new Supplier<double[]>() {
151             @Override
152             public double[] get() {
153                 final double[] s = new double[mean.length];
154                 for (int i = 0; i < mean.length; i++) {
155                     s[i] = samplers[i].sample();
156                 }
157                 return s;
158             }
159         };
160     }
161 }