1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.optim.nonlinear.scalar.gradient;
19
20 import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
21 import org.apache.commons.math4.legacy.exception.MathInternalError;
22 import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
23 import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
24 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
25 import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
26 import org.apache.commons.math4.legacy.optim.OptimizationData;
27 import org.apache.commons.math4.legacy.optim.PointValuePair;
28 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
29 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GradientMultivariateOptimizer;
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46 public class NonLinearConjugateGradientOptimizer
47 extends GradientMultivariateOptimizer {
48
49 private final Formula updateFormula;
50
51 private final Preconditioner preconditioner;
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71 public enum Formula {
72
73 FLETCHER_REEVES,
74
75 POLAK_RIBIERE
76 }
77
78
79
80
81
82
83
84
85
86
87
88 public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
89 ConvergenceChecker<PointValuePair> checker) {
90 this(updateFormula,
91 checker,
92 new IdentityPreconditioner());
93 }
94
95
96
97
98
99
100
101
102
103
104 public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
105 ConvergenceChecker<PointValuePair> checker,
106 final Preconditioner preconditioner) {
107 super(checker);
108
109 this.updateFormula = updateFormula;
110 this.preconditioner = preconditioner;
111 }
112
113
114
115
116 @Override
117 public PointValuePair optimize(OptimizationData... optData)
118 throws TooManyEvaluationsException {
119
120 return super.optimize(optData);
121 }
122
123
124 @Override
125 protected PointValuePair doOptimize() {
126 final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
127 final double[] point = getStartPoint();
128 final GoalType goal = getGoalType();
129 final MultivariateFunction func = getObjectiveFunction();
130 final int n = point.length;
131 double[] r = computeObjectiveGradient(point);
132 if (goal == GoalType.MINIMIZE) {
133 for (int i = 0; i < n; i++) {
134 r[i] = -r[i];
135 }
136 }
137
138
139 double[] steepestDescent = preconditioner.precondition(point, r);
140 double[] searchDirection = steepestDescent.clone();
141
142 double delta = 0;
143 for (int i = 0; i < n; ++i) {
144 delta += r[i] * searchDirection[i];
145 }
146
147 createLineSearch();
148
149 PointValuePair current = null;
150 while (true) {
151 incrementIterationCount();
152
153 final double objective = func.value(point);
154 PointValuePair previous = current;
155 current = new PointValuePair(point, objective);
156 if (previous != null &&
157 checker.converged(getIterations(), previous, current)) {
158
159 return current;
160 }
161
162 final double step = lineSearch(point, searchDirection).getPoint();
163
164
165 for (int i = 0; i < point.length; ++i) {
166 point[i] += step * searchDirection[i];
167 }
168
169 r = computeObjectiveGradient(point);
170 if (goal == GoalType.MINIMIZE) {
171 for (int i = 0; i < n; ++i) {
172 r[i] = -r[i];
173 }
174 }
175
176
177 final double deltaOld = delta;
178 final double[] newSteepestDescent = preconditioner.precondition(point, r);
179 delta = 0;
180 for (int i = 0; i < n; ++i) {
181 delta += r[i] * newSteepestDescent[i];
182 }
183
184 final double beta;
185 switch (updateFormula) {
186 case FLETCHER_REEVES:
187 beta = delta / deltaOld;
188 break;
189 case POLAK_RIBIERE:
190 double deltaMid = 0;
191 for (int i = 0; i < r.length; ++i) {
192 deltaMid += r[i] * steepestDescent[i];
193 }
194 beta = (delta - deltaMid) / deltaOld;
195 break;
196 default:
197
198 throw new MathInternalError();
199 }
200 steepestDescent = newSteepestDescent;
201
202
203 if (getIterations() % n == 0 ||
204 beta < 0) {
205
206 searchDirection = steepestDescent.clone();
207 } else {
208
209 for (int i = 0; i < n; ++i) {
210 searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
211 }
212 }
213 }
214 }
215
216
217
218
219 @Override
220 protected void parseOptimizationData(OptimizationData... optData) {
221
222 super.parseOptimizationData(optData);
223
224 checkParameters();
225 }
226
227
228 public static class IdentityPreconditioner implements Preconditioner {
229
230 @Override
231 public double[] precondition(double[] variables, double[] r) {
232 return r.clone();
233 }
234 }
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290 private void checkParameters() {
291 if (getLowerBound() != null ||
292 getUpperBound() != null) {
293 throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
294 }
295 }
296 }