001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math.optimization.general;
019
020 import org.apache.commons.math.exception.ConvergenceException;
021 import org.apache.commons.math.exception.util.LocalizedFormats;
022 import org.apache.commons.math.linear.ArrayRealVector;
023 import org.apache.commons.math.linear.BlockRealMatrix;
024 import org.apache.commons.math.linear.DecompositionSolver;
025 import org.apache.commons.math.linear.LUDecomposition;
026 import org.apache.commons.math.linear.QRDecomposition;
027 import org.apache.commons.math.linear.RealMatrix;
028 import org.apache.commons.math.linear.SingularMatrixException;
029 import org.apache.commons.math.optimization.ConvergenceChecker;
030 import org.apache.commons.math.optimization.SimpleVectorialValueChecker;
031 import org.apache.commons.math.optimization.VectorialPointValuePair;
032
033 /**
034 * Gauss-Newton least-squares solver.
035 * <p>
036 * This class solves a least-squares problem by solving the normal equations
037 * of the linearized problem at each iteration. Either LU decomposition or
038 * QR decomposition can be used to solve the normal equations. LU decomposition
039 * is faster but QR decomposition is more robust for difficult problems.
040 * </p>
041 *
042 * @version $Id: GaussNewtonOptimizer.java 1175100 2011-09-24 04:47:38Z celestin $
043 * @since 2.0
044 *
045 */
046
047 public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
048 /** Indicator for using LU decomposition. */
049 private final boolean useLU;
050
051 /**
052 * Simple constructor with default settings.
053 * The normal equations will be solved using LU decomposition and the
054 * convergence check is set to a {@link SimpleVectorialValueChecker}
055 * with default tolerances.
056 */
057 public GaussNewtonOptimizer() {
058 this(true);
059 }
060
061 /**
062 * Simple constructor with default settings.
063 * The normal equations will be solved using LU decomposition.
064 *
065 * @param checker Convergence checker.
066 */
067 public GaussNewtonOptimizer(ConvergenceChecker<VectorialPointValuePair> checker) {
068 this(true, checker);
069 }
070
071 /**
072 * Simple constructor with default settings.
073 * The convergence check is set to a {@link SimpleVectorialValueChecker}
074 * with default tolerances.
075 *
076 * @param useLU If {@code true}, the normal equations will be solved
077 * using LU decomposition, otherwise they will be solved using QR
078 * decomposition.
079 */
080 public GaussNewtonOptimizer(final boolean useLU) {
081 this(useLU, new SimpleVectorialValueChecker());
082 }
083
084 /**
085 * @param useLU If {@code true}, the normal equations will be solved
086 * using LU decomposition, otherwise they will be solved using QR
087 * decomposition.
088 * @param checker Convergence checker.
089 */
090 public GaussNewtonOptimizer(final boolean useLU,
091 ConvergenceChecker<VectorialPointValuePair> checker) {
092 super(checker);
093 this.useLU = useLU;
094 }
095
096 /** {@inheritDoc} */
097 @Override
098 public VectorialPointValuePair doOptimize() {
099
100 final ConvergenceChecker<VectorialPointValuePair> checker
101 = getConvergenceChecker();
102
103 // iterate until convergence is reached
104 VectorialPointValuePair current = null;
105 int iter = 0;
106 for (boolean converged = false; !converged;) {
107 ++iter;
108
109 // evaluate the objective function and its jacobian
110 VectorialPointValuePair previous = current;
111 updateResidualsAndCost();
112 updateJacobian();
113 current = new VectorialPointValuePair(point, objective);
114
115 final double[] targetValues = getTargetRef();
116 final double[] residualsWeights = getWeightRef();
117
118 // build the linear problem
119 final double[] b = new double[cols];
120 final double[][] a = new double[cols][cols];
121 for (int i = 0; i < rows; ++i) {
122
123 final double[] grad = weightedResidualJacobian[i];
124 final double weight = residualsWeights[i];
125 final double residual = objective[i] - targetValues[i];
126
127 // compute the normal equation
128 final double wr = weight * residual;
129 for (int j = 0; j < cols; ++j) {
130 b[j] += wr * grad[j];
131 }
132
133 // build the contribution matrix for measurement i
134 for (int k = 0; k < cols; ++k) {
135 double[] ak = a[k];
136 double wgk = weight * grad[k];
137 for (int l = 0; l < cols; ++l) {
138 ak[l] += wgk * grad[l];
139 }
140 }
141 }
142
143 try {
144 // solve the linearized least squares problem
145 RealMatrix mA = new BlockRealMatrix(a);
146 DecompositionSolver solver = useLU ?
147 new LUDecomposition(mA).getSolver() :
148 new QRDecomposition(mA).getSolver();
149 final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
150 // update the estimated parameters
151 for (int i = 0; i < cols; ++i) {
152 point[i] += dX[i];
153 }
154 } catch (SingularMatrixException e) {
155 throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
156 }
157
158 // check convergence
159 if (checker != null) {
160 if (previous != null) {
161 converged = checker.converged(iter, previous, current);
162 }
163 }
164 }
165 // we have converged
166 return current;
167 }
168 }