1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.statistics.inference;
18
19 import java.util.function.DoubleUnaryOperator;
20
21 /**
22 * Provide an interval that brackets a local minimum of a function.
23 * This code is based on a Python implementation (from <em>SciPy</em>,
24 * module {@code optimize.py} v0.5).
25 *
26 * <p>This class has been extracted from {@code o.a.c.math4.optim.univariate}
27 * and modified to: remove support for bracketing a maximum; support bounds
28 * on the bracket; correct the sign of the denominator when the magnitude is small;
29 * and return true/false if there is a minimum strictly inside the bounds.
30 *
31 * @since 1.1
32 */
33 class BracketFinder {
34 /** Tolerance to avoid division by zero. */
35 private static final double EPS_MIN = 1e-21;
36 /** Golden section. */
37 private static final double GOLD = 1.6180339887498948482;
38 /** Factor for expanding the interval. */
39 private final double growLimit;
40 /** Number of allowed function evaluations. */
41 private final int maxEvaluations;
42 /** Number of function evaluations performed in the last search. */
43 private int evaluations;
44 /** Lower bound of the bracket. */
45 private double lo;
46 /** Higher bound of the bracket. */
47 private double hi;
48 /** Point inside the bracket. */
49 private double mid;
50 /** Function value at {@link #lo}. */
51 private double fLo;
52 /** Function value at {@link #hi}. */
53 private double fHi;
54 /** Function value at {@link #mid}. */
55 private double fMid;
56
57 /**
58 * Constructor with default values {@code 100, 100000} (see the
59 * {@link #BracketFinder(double,int) other constructor}).
60 */
61 BracketFinder() {
62 this(100, 100000);
63 }
64
65 /**
66 * Create a bracketing interval finder.
67 *
68 * @param growLimit Expanding factor.
69 * @param maxEvaluations Maximum number of evaluations allowed for finding
70 * a bracketing interval.
71 * @throws IllegalArgumentException if the {@code growLimit} or {@code maxEvalutations}
72 * are not strictly positive.
73 */
74 BracketFinder(double growLimit, int maxEvaluations) {
75 Arguments.checkStrictlyPositive(growLimit);
76 Arguments.checkStrictlyPositive(maxEvaluations);
77 this.growLimit = growLimit;
78 this.maxEvaluations = maxEvaluations;
79 }
80
81 /**
82 * Search downhill from the initial points to obtain new points that bracket a local
83 * minimum of the function. Note that the initial points do not have to bracket a minimum.
84 * An exception is raised if a minimum cannot be found within the configured number
85 * of function evaluations.
86 *
87 * <p>The bracket is limited to the provided bounds if they create a positive interval
88 * {@code min < max}. It is possible that the middle of the bracket is at the bounds as
89 * the final bracket is {@code f(mid) <= min(f(lo), f(hi))} and {@code lo <= mid <= hi}.
90 *
91 * <p>No exception is raised if the initial points are not within the bounds; the points
92 * are updated to be within the bounds.
93 *
94 * <p>No exception is raised if the initial points are equal; the bracket will be returned
95 * as a single point {@code lo == mid == hi}.
96 *
97 * @param func Function whose optimum should be bracketed.
98 * @param a Initial point.
99 * @param b Initial point.
100 * @param min Minimum bound of the bracket (inclusive).
101 * @param max Maximum bound of the bracket (inclusive).
102 * @return true if the mid-point is strictly within the final bracket {@code [lo, hi]};
103 * false if there is no local minima.
104 * @throws IllegalStateException if the maximum number of evaluations is exceeded.
105 */
106 boolean search(DoubleUnaryOperator func,
107 double a, double b,
108 double min, double max) {
109 evaluations = 0;
110
111 // Limit the range of x
112 final DoubleUnaryOperator range;
113 if (min < max) {
114 // Limit: min <= x <= max
115 range = x -> {
116 if (x > min) {
117 return x < max ? x : max;
118 }
119 return min;
120 };
121 } else {
122 range = DoubleUnaryOperator.identity();
123 }
124
125 double xA = range.applyAsDouble(a);
126 double xB = range.applyAsDouble(b);
127 double fA = value(func, xA);
128 double fB = value(func, xB);
129 // Ensure fB <= fA
130 if (fA < fB) {
131 double tmp = xA;
132 xA = xB;
133 xB = tmp;
134 tmp = fA;
135 fA = fB;
136 fB = tmp;
137 }
138
139 double xC = range.applyAsDouble(xB + GOLD * (xB - xA));
140 double fC = value(func, xC);
141
142 // Note: When a [min, max] interval is provided and there is no minima then this
143 // loop will terminate when B == C and both are at the min/max bound.
144 while (fC < fB) {
145 final double tmp1 = (xB - xA) * (fB - fC);
146 final double tmp2 = (xB - xC) * (fB - fA);
147
148 final double val = tmp2 - tmp1;
149 // limit magnitude of val to a small value
150 final double denom = 2 * Math.copySign(Math.max(Math.abs(val), EPS_MIN), val);
151
152 double w = range.applyAsDouble(xB - ((xB - xC) * tmp2 - (xB - xA) * tmp1) / denom);
153 final double wLim = range.applyAsDouble(xB + growLimit * (xC - xB));
154
155 double fW;
156 if ((w - xC) * (xB - w) > 0) {
157 // xB < w < xC
158 fW = value(func, w);
159 if (fW < fC) {
160 // minimum in [xB, xC]
161 xA = xB;
162 xB = w;
163 fA = fB;
164 fB = fW;
165 break;
166 } else if (fW > fB) {
167 // minimum in [xA, w]
168 xC = w;
169 fC = fW;
170 break;
171 }
172 // continue downhill
173 w = range.applyAsDouble(xC + GOLD * (xC - xB));
174 fW = value(func, w);
175 } else if ((w - wLim) * (xC - w) > 0) {
176 // xC < w < limit
177 fW = value(func, w);
178 if (fW < fC) {
179 // continue downhill
180 xB = xC;
181 xC = w;
182 w = range.applyAsDouble(xC + GOLD * (xC - xB));
183 fB = fC;
184 fC = fW;
185 fW = value(func, w);
186 }
187 } else if ((w - wLim) * (wLim - xC) >= 0) {
188 // xC <= limit <= w
189 w = wLim;
190 fW = value(func, w);
191 } else {
192 // possibly w == xC; reject w and take a default step
193 w = range.applyAsDouble(xC + GOLD * (xC - xB));
194 fW = value(func, w);
195 }
196
197 xA = xB;
198 fA = fB;
199 xB = xC;
200 fB = fC;
201 xC = w;
202 fC = fW;
203 }
204
205 mid = xB;
206 fMid = fB;
207
208 // Store the bracket: lo <= mid <= hi
209 if (xC < xA) {
210 lo = xC;
211 fLo = fC;
212 hi = xA;
213 fHi = fA;
214 } else {
215 lo = xA;
216 fLo = fA;
217 hi = xC;
218 fHi = fC;
219 }
220
221 return lo < mid && mid < hi;
222 }
223
224 /**
225 * @return the number of evaluations.
226 */
227 int getEvaluations() {
228 return evaluations;
229 }
230
231 /**
232 * @return the lower bound of the bracket.
233 * @see #getFLo()
234 */
235 double getLo() {
236 return lo;
237 }
238
239 /**
240 * Get function value at {@link #getLo()}.
241 * @return function value at {@link #getLo()}
242 */
243 double getFLo() {
244 return fLo;
245 }
246
247 /**
248 * @return the higher bound of the bracket.
249 * @see #getFHi()
250 */
251 double getHi() {
252 return hi;
253 }
254
255 /**
256 * Get function value at {@link #getHi()}.
257 * @return function value at {@link #getHi()}
258 */
259 double getFHi() {
260 return fHi;
261 }
262
263 /**
264 * @return a point in the middle of the bracket.
265 * @see #getFMid()
266 */
267 double getMid() {
268 return mid;
269 }
270
271 /**
272 * Get function value at {@link #getMid()}.
273 * @return function value at {@link #getMid()}
274 */
275 double getFMid() {
276 return fMid;
277 }
278
279 /**
280 * Get the value of the function.
281 *
282 * @param func Function.
283 * @param x Point.
284 * @return the value
285 * @throws IllegalStateException if the maximal number of evaluations is exceeded.
286 */
287 private double value(DoubleUnaryOperator func, double x) {
288 if (evaluations >= maxEvaluations) {
289 throw new IllegalStateException("Too many evaluations: " + evaluations);
290 }
291 evaluations++;
292 return func.applyAsDouble(x);
293 }
294 }