View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.analysis.solvers;
18  
19  import org.apache.commons.math4.legacy.analysis.QuinticFunction;
20  import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
21  import org.apache.commons.math4.legacy.analysis.XMinus5Function;
22  import org.apache.commons.math4.legacy.analysis.function.Sin;
23  import org.apache.commons.math4.legacy.exception.NoBracketingException;
24  import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
25  import org.apache.commons.math4.core.jdkmath.JdkMath;
26  import org.junit.Assert;
27  import org.junit.Test;
28  
29  /**
30   * Base class for root-finding algorithms tests derived from
31   * {@link BaseSecantSolver}.
32   *
33   */
34  public abstract class BaseSecantSolverAbstractTest {
35      /** Returns the solver to use to perform the tests.
36       * @return the solver to use to perform the tests
37       */
38      protected abstract UnivariateSolver getSolver();
39  
40      /** Returns the expected number of evaluations for the
41       * {@link #testQuinticZero} unit test. A value of {@code -1} indicates that
42       * the test should be skipped for that solver.
43       * @return the expected number of evaluations for the
44       * {@link #testQuinticZero} unit test
45       */
46      protected abstract int[] getQuinticEvalCounts();
47  
48      @Test
49      public void testSinZero() {
50          // The sinus function is behaved well around the root at pi. The second
51          // order derivative is zero, which means linear approximating methods
52          // still converge quadratically.
53          UnivariateFunction f = new Sin();
54          double result;
55          UnivariateSolver solver = getSolver();
56  
57          result = solver.solve(100, f, 3, 4);
58          //System.out.println(
59          //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
60          Assert.assertEquals(result, JdkMath.PI, solver.getAbsoluteAccuracy());
61          Assert.assertTrue(solver.getEvaluations() <= 6);
62          result = solver.solve(100, f, 1, 4);
63          //System.out.println(
64          //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
65          Assert.assertEquals(result, JdkMath.PI, solver.getAbsoluteAccuracy());
66          Assert.assertTrue(solver.getEvaluations() <= 7);
67      }
68  
69      @Test
70      public void testQuinticZero() {
71          // The quintic function has zeros at 0, +-0.5 and +-1.
72          // Around the root of 0 the function is well behaved, with a second
73          // derivative of zero a 0.
74          // The other roots are less well to find, in particular the root at 1,
75          // because the function grows fast for x>1.
76          // The function has extrema (first derivative is zero) at 0.27195613
77          // and 0.82221643, intervals containing these values are harder for
78          // the solvers.
79          UnivariateFunction f = new QuinticFunction();
80          double result;
81          UnivariateSolver solver = getSolver();
82          double atol = solver.getAbsoluteAccuracy();
83          int[] counts = getQuinticEvalCounts();
84  
85          // Tests data: initial bounds, and expected solution, per test case.
86          double[][] testsData = {{-0.2,  0.2,  0.0},
87                                  {-0.1,  0.3,  0.0},
88                                  {-0.3,  0.45, 0.0},
89                                  { 0.3,  0.7,  0.5},
90                                  { 0.2,  0.6,  0.5},
91                                  { 0.05, 0.95, 0.5},
92                                  { 0.85, 1.25, 1.0},
93                                  { 0.8,  1.2,  1.0},
94                                  { 0.85, 1.75, 1.0},
95                                  { 0.55, 1.45, 1.0},
96                                  { 0.85, 5.0,  1.0},
97                                 };
98          int maxIter = 500;
99  
100         for(int i = 0; i < testsData.length; i++) {
101             // Skip test, if needed.
102             if (counts[i] == -1) {
103                 continue;
104             }
105 
106             // Compute solution.
107             double[] testData = testsData[i];
108             result = solver.solve(maxIter, f, testData[0], testData[1]);
109             //System.out.println(
110             //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
111 
112             // Check solution.
113             Assert.assertEquals(result, testData[2], atol);
114             Assert.assertTrue(solver.getEvaluations() <= counts[i] + 1);
115         }
116     }
117 
118     @Test
119     public void testRootEndpoints() {
120         UnivariateFunction f = new XMinus5Function();
121         UnivariateSolver solver = getSolver();
122 
123         // End-point is root. This should be a special case in the solver, and
124         // the initial end-point should be returned exactly.
125         double result = solver.solve(100, f, 5.0, 6.0);
126         Assert.assertEquals(5.0, result, 0.0);
127 
128         result = solver.solve(100, f, 4.0, 5.0);
129         Assert.assertEquals(5.0, result, 0.0);
130 
131         result = solver.solve(100, f, 5.0, 6.0, 5.5);
132         Assert.assertEquals(5.0, result, 0.0);
133 
134         result = solver.solve(100, f, 4.0, 5.0, 4.5);
135         Assert.assertEquals(5.0, result, 0.0);
136     }
137 
138     @Test
139     public void testBadEndpoints() {
140         UnivariateFunction f = new Sin();
141         UnivariateSolver solver = getSolver();
142         try {  // bad interval
143             solver.solve(100, f, 1, -1);
144             Assert.fail("Expecting NumberIsTooLargeException - bad interval");
145         } catch (NumberIsTooLargeException ex) {
146             // expected
147         }
148         try {  // no bracket
149             solver.solve(100, f, 1, 1.5);
150             Assert.fail("Expecting NoBracketingException - non-bracketing");
151         } catch (NoBracketingException ex) {
152             // expected
153         }
154         try {  // no bracket
155             solver.solve(100, f, 1, 1.5, 1.2);
156             Assert.fail("Expecting NoBracketingException - non-bracketing");
157         } catch (NoBracketingException ex) {
158             // expected
159         }
160     }
161 
162     @Test
163     public void testSolutionLeftSide() {
164         UnivariateFunction f = new Sin();
165         UnivariateSolver solver = getSolver();
166         double left = -1.5;
167         double right = 0.05;
168         for(int i = 0; i < 10; i++) {
169             // Test whether the allowed solutions are taken into account.
170             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.LEFT_SIDE);
171             if (!Double.isNaN(solution)) {
172                 Assert.assertTrue(solution <= 0.0);
173             }
174 
175             // Prepare for next test.
176             left -= 0.1;
177             right += 0.3;
178         }
179     }
180 
181     @Test
182     public void testSolutionRightSide() {
183         UnivariateFunction f = new Sin();
184         UnivariateSolver solver = getSolver();
185         double left = -1.5;
186         double right = 0.05;
187         for(int i = 0; i < 10; i++) {
188             // Test whether the allowed solutions are taken into account.
189             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.RIGHT_SIDE);
190             if (!Double.isNaN(solution)) {
191                 Assert.assertTrue(solution >= 0.0);
192             }
193 
194             // Prepare for next test.
195             left -= 0.1;
196             right += 0.3;
197         }
198     }
199     @Test
200     public void testSolutionBelowSide() {
201         UnivariateFunction f = new Sin();
202         UnivariateSolver solver = getSolver();
203         double left = -1.5;
204         double right = 0.05;
205         for(int i = 0; i < 10; i++) {
206             // Test whether the allowed solutions are taken into account.
207             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.BELOW_SIDE);
208             if (!Double.isNaN(solution)) {
209                 Assert.assertTrue(f.value(solution) <= 0.0);
210             }
211 
212             // Prepare for next test.
213             left -= 0.1;
214             right += 0.3;
215         }
216     }
217 
218     @Test
219     public void testSolutionAboveSide() {
220         UnivariateFunction f = new Sin();
221         UnivariateSolver solver = getSolver();
222         double left = -1.5;
223         double right = 0.05;
224         for(int i = 0; i < 10; i++) {
225             // Test whether the allowed solutions are taken into account.
226             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.ABOVE_SIDE);
227             if (!Double.isNaN(solution)) {
228                 Assert.assertTrue(f.value(solution) >= 0.0);
229             }
230 
231             // Prepare for next test.
232             left -= 0.1;
233             right += 0.3;
234         }
235     }
236 
237     private double getSolution(UnivariateSolver solver, int maxEval, UnivariateFunction f,
238                                double left, double right, AllowedSolution allowedSolution) {
239         try {
240             @SuppressWarnings("unchecked")
241             BracketedUnivariateSolver<UnivariateFunction> bracketing =
242             (BracketedUnivariateSolver<UnivariateFunction>) solver;
243             return bracketing.solve(100, f, left, right, allowedSolution);
244         } catch (ClassCastException cce) {
245             double baseRoot = solver.solve(maxEval, f, left, right);
246             if (baseRoot <= left || baseRoot >= right) {
247                 // the solution slipped out of interval
248                 return Double.NaN;
249             }
250             PegasusSolver bracketing =
251                     new PegasusSolver(solver.getRelativeAccuracy(), solver.getAbsoluteAccuracy(),
252                                       solver.getFunctionValueAccuracy());
253             return UnivariateSolverUtils.forceSide(maxEval - solver.getEvaluations(),
254                                                        f, bracketing, baseRoot, left, right,
255                                                        allowedSolution);
256         }
257     }
258 }