1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.fitting.leastsquares;
18
19 import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
20 import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
21 import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
22 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
23 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem.Evaluation;
24 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
25 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
26 import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
27 import org.apache.commons.math4.legacy.linear.EigenDecomposition;
28 import org.apache.commons.math4.legacy.linear.RealMatrix;
29 import org.apache.commons.math4.legacy.linear.RealVector;
30 import org.apache.commons.math4.legacy.optim.AbstractOptimizationProblem;
31 import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
32 import org.apache.commons.math4.legacy.optim.PointVectorValuePair;
33 import org.apache.commons.math4.core.jdkmath.JdkMath;
34 import org.apache.commons.math4.legacy.core.IntegerSequence;
35 import org.apache.commons.math4.legacy.core.Pair;
36
37
38
39
40
41
42 public final class LeastSquaresFactory {
43
44
/** Not instantiable: every public member of this class is a static method. */
private LeastSquaresFactory() {
    // Intentionally empty.
}
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65 public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
66 final RealVector observed,
67 final RealVector start,
68 final RealMatrix weight,
69 final ConvergenceChecker<Evaluation> checker,
70 final int maxEvaluations,
71 final int maxIterations,
72 final boolean lazyEvaluation,
73 final ParameterValidator paramValidator) {
74 final LeastSquaresProblem p = new LocalLeastSquaresProblem(model,
75 observed,
76 start,
77 checker,
78 maxEvaluations,
79 maxIterations,
80 lazyEvaluation,
81 paramValidator);
82 if (weight != null) {
83 return weightMatrix(p, weight);
84 } else {
85 return p;
86 }
87 }
88
89
90
91
92
93
94
95
96
97
98
99
100
101 public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
102 final RealVector observed,
103 final RealVector start,
104 final ConvergenceChecker<Evaluation> checker,
105 final int maxEvaluations,
106 final int maxIterations) {
107 return create(model,
108 observed,
109 start,
110 null,
111 checker,
112 maxEvaluations,
113 maxIterations,
114 false,
115 null);
116 }
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131 public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
132 final RealVector observed,
133 final RealVector start,
134 final RealMatrix weight,
135 final ConvergenceChecker<Evaluation> checker,
136 final int maxEvaluations,
137 final int maxIterations) {
138 return weightMatrix(create(model,
139 observed,
140 start,
141 checker,
142 maxEvaluations,
143 maxIterations),
144 weight);
145 }
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166 public static LeastSquaresProblem create(final MultivariateVectorFunction model,
167 final MultivariateMatrixFunction jacobian,
168 final double[] observed,
169 final double[] start,
170 final RealMatrix weight,
171 final ConvergenceChecker<Evaluation> checker,
172 final int maxEvaluations,
173 final int maxIterations) {
174 return create(model(model, jacobian),
175 new ArrayRealVector(observed, false),
176 new ArrayRealVector(start, false),
177 weight,
178 checker,
179 maxEvaluations,
180 maxIterations);
181 }
182
183
184
185
186
187
188
189
190
191 public static LeastSquaresProblem weightMatrix(final LeastSquaresProblem problem,
192 final RealMatrix weights) {
193 final RealMatrix weightSquareRoot = squareRoot(weights);
194 return new LeastSquaresAdapter(problem) {
195
196 @Override
197 public Evaluation evaluate(final RealVector point) {
198 return new DenseWeightedEvaluation(super.evaluate(point), weightSquareRoot);
199 }
200 };
201 }
202
203
204
205
206
207
208
209
210
211 public static LeastSquaresProblem weightDiagonal(final LeastSquaresProblem problem,
212 final RealVector weights) {
213
214 return weightMatrix(problem, new DiagonalMatrix(weights.toArray()));
215 }
216
217
218
219
220
221
222
223
224
225
226 public static LeastSquaresProblem countEvaluations(final LeastSquaresProblem problem,
227 final IntegerSequence.Incrementor counter) {
228 return new LeastSquaresAdapter(problem) {
229
230
231 @Override
232 public Evaluation evaluate(final RealVector point) {
233 counter.increment();
234 return super.evaluate(point);
235 }
236
237
238 };
239 }
240
241
242
243
244
245
246
247
248 public static ConvergenceChecker<Evaluation> evaluationChecker(final ConvergenceChecker<PointVectorValuePair> checker) {
249 return new ConvergenceChecker<Evaluation>() {
250
251 @Override
252 public boolean converged(final int iteration,
253 final Evaluation previous,
254 final Evaluation current) {
255 return checker.converged(
256 iteration,
257 new PointVectorValuePair(
258 previous.getPoint().toArray(),
259 previous.getResiduals().toArray(),
260 false),
261 new PointVectorValuePair(
262 current.getPoint().toArray(),
263 current.getResiduals().toArray(),
264 false)
265 );
266 }
267 };
268 }
269
270
271
272
273
274
275
276 private static RealMatrix squareRoot(final RealMatrix m) {
277 if (m instanceof DiagonalMatrix) {
278 final int dim = m.getRowDimension();
279 final RealMatrix sqrtM = new DiagonalMatrix(dim);
280 for (int i = 0; i < dim; i++) {
281 sqrtM.setEntry(i, i, JdkMath.sqrt(m.getEntry(i, i)));
282 }
283 return sqrtM;
284 } else {
285 final EigenDecomposition dec = new EigenDecomposition(m);
286 return dec.getSquareRoot();
287 }
288 }
289
290
291
292
293
294
295
296
297
298 public static MultivariateJacobianFunction model(final MultivariateVectorFunction value,
299 final MultivariateMatrixFunction jacobian) {
300 return new LocalValueAndJacobianFunction(value, jacobian);
301 }
302
303
304
305
306
307 private static final class LocalValueAndJacobianFunction
308 implements ValueAndJacobianFunction {
309
310 private final MultivariateVectorFunction value;
311
312 private final MultivariateMatrixFunction jacobian;
313
314
315
316
317
318 LocalValueAndJacobianFunction(final MultivariateVectorFunction value,
319 final MultivariateMatrixFunction jacobian) {
320 this.value = value;
321 this.jacobian = jacobian;
322 }
323
324
325 @Override
326 public Pair<RealVector, RealMatrix> value(final RealVector point) {
327
328 final double[] p = point.toArray();
329
330
331 return new Pair<>(computeValue(p), computeJacobian(p));
332 }
333
334
335 @Override
336 public RealVector computeValue(final double[] params) {
337 return new ArrayRealVector(value.value(params), false);
338 }
339
340
341 @Override
342 public RealMatrix computeJacobian(final double[] params) {
343 return new Array2DRowRealMatrix(jacobian.value(params), false);
344 }
345 }
346
347
348
349
350
351
352
353 private static final class LocalLeastSquaresProblem
354 extends AbstractOptimizationProblem<Evaluation>
355 implements LeastSquaresProblem {
356
357
358 private final RealVector target;
359
360 private final MultivariateJacobianFunction model;
361
362 private final RealVector start;
363
364 private final boolean lazyEvaluation;
365
366 private final ParameterValidator paramValidator;
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381 LocalLeastSquaresProblem(final MultivariateJacobianFunction model,
382 final RealVector target,
383 final RealVector start,
384 final ConvergenceChecker<Evaluation> checker,
385 final int maxEvaluations,
386 final int maxIterations,
387 final boolean lazyEvaluation,
388 final ParameterValidator paramValidator) {
389 super(maxEvaluations, maxIterations, checker);
390 this.target = target;
391 this.model = model;
392 this.start = start;
393 this.lazyEvaluation = lazyEvaluation;
394 this.paramValidator = paramValidator;
395
396 if (lazyEvaluation &&
397 !(model instanceof ValueAndJacobianFunction)) {
398
399
400 throw new MathIllegalStateException(LocalizedFormats.INVALID_IMPLEMENTATION,
401 model.getClass().getName());
402 }
403 }
404
405
406 @Override
407 public int getObservationSize() {
408 return target.getDimension();
409 }
410
411
412 @Override
413 public int getParameterSize() {
414 return start.getDimension();
415 }
416
417
418 @Override
419 public RealVector getStart() {
420 return start == null ? null : start.copy();
421 }
422
423
424 @Override
425 public Evaluation evaluate(final RealVector point) {
426
427 final RealVector p = paramValidator == null ?
428 point.copy() :
429 paramValidator.validate(point.copy());
430
431 if (lazyEvaluation) {
432 return new LazyUnweightedEvaluation((ValueAndJacobianFunction) model,
433 target,
434 p);
435 } else {
436
437 final Pair<RealVector, RealMatrix> value = model.value(p);
438 return new UnweightedEvaluation(value.getFirst(),
439 value.getSecond(),
440 target,
441 p);
442 }
443 }
444
445
446
447
448 private static final class UnweightedEvaluation extends AbstractEvaluation {
449
450 private final RealVector point;
451
452 private final RealMatrix jacobian;
453
454 private final RealVector residuals;
455
456
457
458
459
460
461
462
463
464 private UnweightedEvaluation(final RealVector values,
465 final RealMatrix jacobian,
466 final RealVector target,
467 final RealVector point) {
468 super(target.getDimension());
469 this.jacobian = jacobian;
470 this.point = point;
471 this.residuals = target.subtract(values);
472 }
473
474
475 @Override
476 public RealMatrix getJacobian() {
477 return jacobian;
478 }
479
480
481 @Override
482 public RealVector getPoint() {
483 return point;
484 }
485
486
487 @Override
488 public RealVector getResiduals() {
489 return residuals;
490 }
491 }
492
493
494
495
496 private static final class LazyUnweightedEvaluation extends AbstractEvaluation {
497
498 private final RealVector point;
499
500 private final ValueAndJacobianFunction model;
501
502 private final RealVector target;
503
504
505
506
507
508
509
510
511 private LazyUnweightedEvaluation(final ValueAndJacobianFunction model,
512 final RealVector target,
513 final RealVector point) {
514 super(target.getDimension());
515
516 this.model = model;
517 this.point = point;
518 this.target = target;
519 }
520
521
522 @Override
523 public RealMatrix getJacobian() {
524 return model.computeJacobian(point.toArray());
525 }
526
527
528 @Override
529 public RealVector getPoint() {
530 return point;
531 }
532
533
534 @Override
535 public RealVector getResiduals() {
536 return target.subtract(model.computeValue(point.toArray()));
537 }
538 }
539 }
540 }
541