1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.optim.univariate;
19
20 import java.util.Arrays;
21 import java.util.Comparator;
22
23 import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
24 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
25 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
26 import org.apache.commons.math4.legacy.optim.MaxEval;
27 import org.apache.commons.math4.legacy.optim.OptimizationData;
28 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
29 import org.apache.commons.rng.UniformRandomProvider;
30
31
32
33
34
35
36
37
38
39
40
41 public class MultiStartUnivariateOptimizer
42 extends UnivariateOptimizer {
43
44 private final UnivariateOptimizer optimizer;
45
46 private int totalEvaluations;
47
48 private final int starts;
49
50 private final UniformRandomProvider generator;
51
52 private UnivariatePointValuePair[] optima;
53
54 private OptimizationData[] optimData;
55
56
57
58
59 private int maxEvalIndex = -1;
60
61
62
63
64 private int searchIntervalIndex = -1;
65
66
67
68
69
70
71
72
73
74
75
76 public MultiStartUnivariateOptimizer(final UnivariateOptimizer optimizer,
77 final int starts,
78 final UniformRandomProvider generator) {
79 super(optimizer.getConvergenceChecker());
80
81 if (starts < 1) {
82 throw new NotStrictlyPositiveException(starts);
83 }
84
85 this.optimizer = optimizer;
86 this.starts = starts;
87 this.generator = generator;
88 }
89
90
91 @Override
92 public int getEvaluations() {
93 return totalEvaluations;
94 }
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118 public UnivariatePointValuePair[] getOptima() {
119 if (optima == null) {
120 throw new MathIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET);
121 }
122 return optima.clone();
123 }
124
125
126
127
128
129
130
131 @Override
132 public UnivariatePointValuePair optimize(OptimizationData... optData) {
133
134 optimData = optData;
135
136 return super.optimize(optData);
137 }
138
139
140 @Override
141 protected UnivariatePointValuePair doOptimize() {
142
143
144
145
146
147 for (int i = 0; i < optimData.length; i++) {
148 if (optimData[i] instanceof MaxEval) {
149 optimData[i] = null;
150 maxEvalIndex = i;
151 continue;
152 }
153 if (optimData[i] instanceof SearchInterval) {
154 optimData[i] = null;
155 searchIntervalIndex = i;
156 continue;
157 }
158 }
159 if (maxEvalIndex == -1) {
160 throw new MathIllegalStateException();
161 }
162 if (searchIntervalIndex == -1) {
163 throw new MathIllegalStateException();
164 }
165
166 RuntimeException lastException = null;
167 optima = new UnivariatePointValuePair[starts];
168 totalEvaluations = 0;
169
170 final int maxEval = getMaxEvaluations();
171 final double min = getMin();
172 final double max = getMax();
173 final double startValue = getStartValue();
174
175
176 for (int i = 0; i < starts; i++) {
177
178 try {
179
180 optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations);
181
182 final double s = (i == 0) ?
183 startValue :
184 min + generator.nextDouble() * (max - min);
185 optimData[searchIntervalIndex] = new SearchInterval(min, max, s);
186
187 optima[i] = optimizer.optimize(optimData);
188 } catch (RuntimeException mue) {
189 lastException = mue;
190 optima[i] = null;
191 }
192
193
194 totalEvaluations += optimizer.getEvaluations();
195 }
196
197 sortPairs(getGoalType());
198
199 if (optima[0] == null) {
200 throw lastException;
201 }
202
203
204 return optima[0];
205 }
206
207
208
209
210
211
212 private void sortPairs(final GoalType goal) {
213 Arrays.sort(optima, new Comparator<UnivariatePointValuePair>() {
214
215 @Override
216 public int compare(final UnivariatePointValuePair o1,
217 final UnivariatePointValuePair o2) {
218 if (o1 == null) {
219 return (o2 == null) ? 0 : 1;
220 } else if (o2 == null) {
221 return -1;
222 }
223 final double v1 = o1.getValue();
224 final double v2 = o2.getValue();
225 return (goal == GoalType.MINIMIZE) ?
226 Double.compare(v1, v2) : Double.compare(v2, v1);
227 }
228 });
229 }
230 }