1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
18
19 import org.apache.commons.math3.exception.ConvergenceException;
20 import org.apache.commons.math3.exception.NullArgumentException;
21 import org.apache.commons.math3.exception.MathInternalError;
22 import org.apache.commons.math3.exception.MathUnsupportedOperationException;
23 import org.apache.commons.math3.exception.util.LocalizedFormats;
24 import org.apache.commons.math3.linear.ArrayRealVector;
25 import org.apache.commons.math3.linear.BlockRealMatrix;
26 import org.apache.commons.math3.linear.DecompositionSolver;
27 import org.apache.commons.math3.linear.LUDecomposition;
28 import org.apache.commons.math3.linear.QRDecomposition;
29 import org.apache.commons.math3.linear.RealMatrix;
30 import org.apache.commons.math3.linear.SingularMatrixException;
31 import org.apache.commons.math3.optim.ConvergenceChecker;
32 import org.apache.commons.math3.optim.PointVectorValuePair;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52 public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
53
54 private final boolean useLU;
55
56
57
58
59
60
61
62 public GaussNewtonOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
63 this(true, checker);
64 }
65
66
67
68
69
70
71
72 public GaussNewtonOptimizer(final boolean useLU,
73 ConvergenceChecker<PointVectorValuePair> checker) {
74 super(checker);
75 this.useLU = useLU;
76 }
77
78
79 @Override
80 public PointVectorValuePair doOptimize() {
81 checkParameters();
82
83 final ConvergenceChecker<PointVectorValuePair> checker
84 = getConvergenceChecker();
85
86
87 if (checker == null) {
88 throw new NullArgumentException();
89 }
90
91 final double[] targetValues = getTarget();
92 final int nR = targetValues.length;
93
94 final RealMatrix weightMatrix = getWeight();
95
96 final double[] residualsWeights = new double[nR];
97 for (int i = 0; i < nR; i++) {
98 residualsWeights[i] = weightMatrix.getEntry(i, i);
99 }
100
101 final double[] currentPoint = getStartPoint();
102 final int nC = currentPoint.length;
103
104
105 PointVectorValuePair current = null;
106 for (boolean converged = false; !converged;) {
107 incrementIterationCount();
108
109
110 PointVectorValuePair previous = current;
111
112 final double[] currentObjective = computeObjectiveValue(currentPoint);
113 final double[] currentResiduals = computeResiduals(currentObjective);
114 final RealMatrix weightedJacobian = computeWeightedJacobian(currentPoint);
115 current = new PointVectorValuePair(currentPoint, currentObjective);
116
117
118 final double[] b = new double[nC];
119 final double[][] a = new double[nC][nC];
120 for (int i = 0; i < nR; ++i) {
121
122 final double[] grad = weightedJacobian.getRow(i);
123 final double weight = residualsWeights[i];
124 final double residual = currentResiduals[i];
125
126
127 final double wr = weight * residual;
128 for (int j = 0; j < nC; ++j) {
129 b[j] += wr * grad[j];
130 }
131
132
133 for (int k = 0; k < nC; ++k) {
134 double[] ak = a[k];
135 double wgk = weight * grad[k];
136 for (int l = 0; l < nC; ++l) {
137 ak[l] += wgk * grad[l];
138 }
139 }
140 }
141
142 try {
143
144 RealMatrix mA = new BlockRealMatrix(a);
145 DecompositionSolver solver = useLU ?
146 new LUDecomposition(mA).getSolver() :
147 new QRDecomposition(mA).getSolver();
148 final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
149
150 for (int i = 0; i < nC; ++i) {
151 currentPoint[i] += dX[i];
152 }
153 } catch (SingularMatrixException e) {
154 throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
155 }
156
157
158 if (previous != null) {
159 converged = checker.converged(getIterations(), previous, current);
160 if (converged) {
161 setCost(computeCost(currentResiduals));
162 return current;
163 }
164 }
165 }
166
167 throw new MathInternalError();
168 }
169
170
171
172
173
174 private void checkParameters() {
175 if (getLowerBound() != null ||
176 getUpperBound() != null) {
177 throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
178 }
179 }
180 }