1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  package org.apache.commons.math4.legacy.analysis.solvers;
18  
19  import org.apache.commons.math4.legacy.analysis.QuinticFunction;
20  import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
21  import org.apache.commons.math4.legacy.analysis.XMinus5Function;
22  import org.apache.commons.math4.legacy.analysis.function.Sin;
23  import org.apache.commons.math4.legacy.exception.NoBracketingException;
24  import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
25  import org.apache.commons.math4.core.jdkmath.JdkMath;
26  import org.junit.Assert;
27  import org.junit.Test;
28  
29  
30  
31  
32  
33  
34  public abstract class BaseSecantSolverAbstractTest {
35      
36  
37  
38      protected abstract UnivariateSolver getSolver();
39  
40      
41  
42  
43  
44  
45  
46      protected abstract int[] getQuinticEvalCounts();
47  
48      @Test
49      public void testSinZero() {
50          
51          
52          
53          UnivariateFunction f = new Sin();
54          double result;
55          UnivariateSolver solver = getSolver();
56  
57          result = solver.solve(100, f, 3, 4);
58          
59          
60          Assert.assertEquals(result, JdkMath.PI, solver.getAbsoluteAccuracy());
61          Assert.assertTrue(solver.getEvaluations() <= 6);
62          result = solver.solve(100, f, 1, 4);
63          
64          
65          Assert.assertEquals(result, JdkMath.PI, solver.getAbsoluteAccuracy());
66          Assert.assertTrue(solver.getEvaluations() <= 7);
67      }
68  
69      @Test
70      public void testQuinticZero() {
71          
72          
73          
74          
75          
76          
77          
78          
79          UnivariateFunction f = new QuinticFunction();
80          double result;
81          UnivariateSolver solver = getSolver();
82          double atol = solver.getAbsoluteAccuracy();
83          int[] counts = getQuinticEvalCounts();
84  
85          
86          double[][] testsData = {{-0.2,  0.2,  0.0},
87                                  {-0.1,  0.3,  0.0},
88                                  {-0.3,  0.45, 0.0},
89                                  { 0.3,  0.7,  0.5},
90                                  { 0.2,  0.6,  0.5},
91                                  { 0.05, 0.95, 0.5},
92                                  { 0.85, 1.25, 1.0},
93                                  { 0.8,  1.2,  1.0},
94                                  { 0.85, 1.75, 1.0},
95                                  { 0.55, 1.45, 1.0},
96                                  { 0.85, 5.0,  1.0},
97                                 };
98          int maxIter = 500;
99  
100         for(int i = 0; i < testsData.length; i++) {
101             
102             if (counts[i] == -1) {
103                 continue;
104             }
105 
106             
107             double[] testData = testsData[i];
108             result = solver.solve(maxIter, f, testData[0], testData[1]);
109             
110             
111 
112             
113             Assert.assertEquals(result, testData[2], atol);
114             Assert.assertTrue(solver.getEvaluations() <= counts[i] + 1);
115         }
116     }
117 
118     @Test
119     public void testRootEndpoints() {
120         UnivariateFunction f = new XMinus5Function();
121         UnivariateSolver solver = getSolver();
122 
123         
124         
125         double result = solver.solve(100, f, 5.0, 6.0);
126         Assert.assertEquals(5.0, result, 0.0);
127 
128         result = solver.solve(100, f, 4.0, 5.0);
129         Assert.assertEquals(5.0, result, 0.0);
130 
131         result = solver.solve(100, f, 5.0, 6.0, 5.5);
132         Assert.assertEquals(5.0, result, 0.0);
133 
134         result = solver.solve(100, f, 4.0, 5.0, 4.5);
135         Assert.assertEquals(5.0, result, 0.0);
136     }
137 
138     @Test
139     public void testBadEndpoints() {
140         UnivariateFunction f = new Sin();
141         UnivariateSolver solver = getSolver();
142         try {  
143             solver.solve(100, f, 1, -1);
144             Assert.fail("Expecting NumberIsTooLargeException - bad interval");
145         } catch (NumberIsTooLargeException ex) {
146             
147         }
148         try {  
149             solver.solve(100, f, 1, 1.5);
150             Assert.fail("Expecting NoBracketingException - non-bracketing");
151         } catch (NoBracketingException ex) {
152             
153         }
154         try {  
155             solver.solve(100, f, 1, 1.5, 1.2);
156             Assert.fail("Expecting NoBracketingException - non-bracketing");
157         } catch (NoBracketingException ex) {
158             
159         }
160     }
161 
162     @Test
163     public void testSolutionLeftSide() {
164         UnivariateFunction f = new Sin();
165         UnivariateSolver solver = getSolver();
166         double left = -1.5;
167         double right = 0.05;
168         for(int i = 0; i < 10; i++) {
169             
170             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.LEFT_SIDE);
171             if (!Double.isNaN(solution)) {
172                 Assert.assertTrue(solution <= 0.0);
173             }
174 
175             
176             left -= 0.1;
177             right += 0.3;
178         }
179     }
180 
181     @Test
182     public void testSolutionRightSide() {
183         UnivariateFunction f = new Sin();
184         UnivariateSolver solver = getSolver();
185         double left = -1.5;
186         double right = 0.05;
187         for(int i = 0; i < 10; i++) {
188             
189             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.RIGHT_SIDE);
190             if (!Double.isNaN(solution)) {
191                 Assert.assertTrue(solution >= 0.0);
192             }
193 
194             
195             left -= 0.1;
196             right += 0.3;
197         }
198     }
199     @Test
200     public void testSolutionBelowSide() {
201         UnivariateFunction f = new Sin();
202         UnivariateSolver solver = getSolver();
203         double left = -1.5;
204         double right = 0.05;
205         for(int i = 0; i < 10; i++) {
206             
207             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.BELOW_SIDE);
208             if (!Double.isNaN(solution)) {
209                 Assert.assertTrue(f.value(solution) <= 0.0);
210             }
211 
212             
213             left -= 0.1;
214             right += 0.3;
215         }
216     }
217 
218     @Test
219     public void testSolutionAboveSide() {
220         UnivariateFunction f = new Sin();
221         UnivariateSolver solver = getSolver();
222         double left = -1.5;
223         double right = 0.05;
224         for(int i = 0; i < 10; i++) {
225             
226             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.ABOVE_SIDE);
227             if (!Double.isNaN(solution)) {
228                 Assert.assertTrue(f.value(solution) >= 0.0);
229             }
230 
231             
232             left -= 0.1;
233             right += 0.3;
234         }
235     }
236 
237     private double getSolution(UnivariateSolver solver, int maxEval, UnivariateFunction f,
238                                double left, double right, AllowedSolution allowedSolution) {
239         try {
240             @SuppressWarnings("unchecked")
241             BracketedUnivariateSolver<UnivariateFunction> bracketing =
242             (BracketedUnivariateSolver<UnivariateFunction>) solver;
243             return bracketing.solve(100, f, left, right, allowedSolution);
244         } catch (ClassCastException cce) {
245             double baseRoot = solver.solve(maxEval, f, left, right);
246             if (baseRoot <= left || baseRoot >= right) {
247                 
248                 return Double.NaN;
249             }
250             PegasusSolver bracketing =
251                     new PegasusSolver(solver.getRelativeAccuracy(), solver.getAbsoluteAccuracy(),
252                                       solver.getFunctionValueAccuracy());
253             return UnivariateSolverUtils.forceSide(maxEval - solver.getEvaluations(),
254                                                        f, bracketing, baseRoot, left, right,
255                                                        allowedSolution);
256         }
257     }
258 }