1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.fitting.leastsquares;
18
19 import org.apache.commons.math4.legacy.exception.ConvergenceException;
20 import org.apache.commons.math4.legacy.exception.NullArgumentException;
21 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
22 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem.Evaluation;
23 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
24 import org.apache.commons.math4.legacy.linear.CholeskyDecomposition;
25 import org.apache.commons.math4.legacy.linear.LUDecomposition;
26 import org.apache.commons.math4.legacy.linear.MatrixUtils;
27 import org.apache.commons.math4.legacy.linear.NonPositiveDefiniteMatrixException;
28 import org.apache.commons.math4.legacy.linear.QRDecomposition;
29 import org.apache.commons.math4.legacy.linear.RealMatrix;
30 import org.apache.commons.math4.legacy.linear.RealVector;
31 import org.apache.commons.math4.legacy.linear.SingularMatrixException;
32 import org.apache.commons.math4.legacy.linear.SingularValueDecomposition;
33 import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
34 import org.apache.commons.math4.legacy.core.IntegerSequence;
35 import org.apache.commons.math4.legacy.core.Pair;
36
37
38
39
40
41
42
43
44
45
46
47
48
49 public class GaussNewtonOptimizer implements LeastSquaresOptimizer {
50
51
52
53 public enum Decomposition {
54
55
56
57
58
59
60
61
62 LU {
63 @Override
64 protected RealVector solve(final RealMatrix jacobian,
65 final RealVector residuals) {
66 try {
67 final Pair<RealMatrix, RealVector> normalEquation =
68 computeNormalMatrix(jacobian, residuals);
69 final RealMatrix normal = normalEquation.getFirst();
70 final RealVector jTr = normalEquation.getSecond();
71 return new LUDecomposition(normal, SINGULARITY_THRESHOLD)
72 .getSolver()
73 .solve(jTr);
74 } catch (SingularMatrixException e) {
75 throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM, e);
76 }
77 }
78 },
79
80
81
82
83
84
85
86
87 QR {
88 @Override
89 protected RealVector solve(final RealMatrix jacobian,
90 final RealVector residuals) {
91 try {
92 return new QRDecomposition(jacobian, SINGULARITY_THRESHOLD)
93 .getSolver()
94 .solve(residuals);
95 } catch (SingularMatrixException e) {
96 throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM, e);
97 }
98 }
99 },
100
101
102
103
104
105
106
107
108 CHOLESKY {
109 @Override
110 protected RealVector solve(final RealMatrix jacobian,
111 final RealVector residuals) {
112 try {
113 final Pair<RealMatrix, RealVector> normalEquation =
114 computeNormalMatrix(jacobian, residuals);
115 final RealMatrix normal = normalEquation.getFirst();
116 final RealVector jTr = normalEquation.getSecond();
117 return new CholeskyDecomposition(
118 normal, SINGULARITY_THRESHOLD, SINGULARITY_THRESHOLD)
119 .getSolver()
120 .solve(jTr);
121 } catch (NonPositiveDefiniteMatrixException e) {
122 throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM, e);
123 }
124 }
125 },
126
127
128
129
130
131
132
133 SVD {
134 @Override
135 protected RealVector solve(final RealMatrix jacobian,
136 final RealVector residuals) {
137 return new SingularValueDecomposition(jacobian)
138 .getSolver()
139 .solve(residuals);
140 }
141 };
142
143
144
145
146
147
148
149
150
151
152
153 protected abstract RealVector solve(RealMatrix jacobian,
154 RealVector residuals);
155 }
156
157
158
159
160
161
162 private static final double SINGULARITY_THRESHOLD = 1e-11;
163
164
165 private final Decomposition decomposition;
166
167
168
169
170
171
172
173 public GaussNewtonOptimizer() {
174 this(Decomposition.QR);
175 }
176
177
178
179
180
181
182
183 public GaussNewtonOptimizer(final Decomposition decomposition) {
184 this.decomposition = decomposition;
185 }
186
187
188
189
190
191
192 public Decomposition getDecomposition() {
193 return this.decomposition;
194 }
195
196
197
198
199
200
201
202 public GaussNewtonOptimizer withDecomposition(final Decomposition newDecomposition) {
203 return new GaussNewtonOptimizer(newDecomposition);
204 }
205
206
207 @Override
208 public Optimum optimize(final LeastSquaresProblem lsp) {
209
210 final IntegerSequence.Incrementor evaluationCounter = lsp.getEvaluationCounter();
211 final IntegerSequence.Incrementor iterationCounter = lsp.getIterationCounter();
212 final ConvergenceChecker<Evaluation> checker
213 = lsp.getConvergenceChecker();
214
215
216 if (checker == null) {
217 throw new NullArgumentException();
218 }
219
220 RealVector currentPoint = lsp.getStart();
221
222
223 Evaluation current = null;
224 while (true) {
225 iterationCounter.increment();
226
227
228 Evaluation previous = current;
229
230 evaluationCounter.increment();
231 current = lsp.evaluate(currentPoint);
232 final RealVector currentResiduals = current.getResiduals();
233 final RealMatrix weightedJacobian = current.getJacobian();
234 currentPoint = current.getPoint();
235
236
237 if (previous != null &&
238 checker.converged(iterationCounter.getCount(), previous, current)) {
239 return new OptimumImpl(current,
240 evaluationCounter.getCount(),
241 iterationCounter.getCount());
242 }
243
244
245 final RealVector dX = this.decomposition.solve(weightedJacobian, currentResiduals);
246
247 currentPoint = currentPoint.add(dX);
248 }
249 }
250
251
252 @Override
253 public String toString() {
254 return "GaussNewtonOptimizer{" +
255 "decomposition=" + decomposition +
256 '}';
257 }
258
259
260
261
262
263
264
265
266 private static Pair<RealMatrix, RealVector> computeNormalMatrix(final RealMatrix jacobian,
267 final RealVector residuals) {
268
269 final int nR = jacobian.getRowDimension();
270 final int nC = jacobian.getColumnDimension();
271
272 final RealMatrix normal = MatrixUtils.createRealMatrix(nC, nC);
273 final RealVector jTr = new ArrayRealVector(nC);
274
275 for (int i = 0; i < nR; ++i) {
276
277 for (int j = 0; j < nC; j++) {
278 jTr.setEntry(j, jTr.getEntry(j) +
279 residuals.getEntry(i) * jacobian.getEntry(i, j));
280 }
281
282
283 for (int k = 0; k < nC; ++k) {
284
285 for (int l = k; l < nC; ++l) {
286 normal.setEntry(k, l, normal.getEntry(k, l) +
287 jacobian.getEntry(i, k) * jacobian.getEntry(i, l));
288 }
289 }
290 }
291
292 for (int i = 0; i < nC; i++) {
293 for (int j = 0; j < i; j++) {
294 normal.setEntry(i, j, normal.getEntry(j, i));
295 }
296 }
297 return new Pair<>(normal, jTr);
298 }
299 }