001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.optimization.general;
019    
import org.apache.commons.math.exception.ConvergenceException;
import org.apache.commons.math.exception.NullArgumentException;
import org.apache.commons.math.exception.util.LocalizedFormats;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.DecompositionSolver;
import org.apache.commons.math.linear.LUDecomposition;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.SingularMatrixException;
import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.SimpleVectorialValueChecker;
import org.apache.commons.math.optimization.VectorialPointValuePair;
032    
033    /**
034     * Gauss-Newton least-squares solver.
035     * <p>
036     * This class solve a least-square problem by solving the normal equations
037     * of the linearized problem at each iteration. Either LU decomposition or
038     * QR decomposition can be used to solve the normal equations. LU decomposition
039     * is faster but QR decomposition is more robust for difficult problems.
040     * </p>
041     *
042     * @version $Id: GaussNewtonOptimizer.java 1175100 2011-09-24 04:47:38Z celestin $
043     * @since 2.0
044     *
045     */
046    
047    public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
048        /** Indicator for using LU decomposition. */
049        private final boolean useLU;
050    
051        /**
052         * Simple constructor with default settings.
053         * The normal equations will be solved using LU decomposition and the
054         * convergence check is set to a {@link SimpleVectorialValueChecker}
055         * with default tolerances.
056         */
057        public GaussNewtonOptimizer() {
058            this(true);
059        }
060    
061        /**
062         * Simple constructor with default settings.
063         * The normal equations will be solved using LU decomposition.
064         *
065         * @param checker Convergence checker.
066         */
067        public GaussNewtonOptimizer(ConvergenceChecker<VectorialPointValuePair> checker) {
068            this(true, checker);
069        }
070    
071        /**
072         * Simple constructor with default settings.
073         * The convergence check is set to a {@link SimpleVectorialValueChecker}
074         * with default tolerances.
075         *
076         * @param useLU If {@code true}, the normal equations will be solved
077         * using LU decomposition, otherwise they will be solved using QR
078         * decomposition.
079         */
080        public GaussNewtonOptimizer(final boolean useLU) {
081            this(useLU, new SimpleVectorialValueChecker());
082        }
083    
084        /**
085         * @param useLU If {@code true}, the normal equations will be solved
086         * using LU decomposition, otherwise they will be solved using QR
087         * decomposition.
088         * @param checker Convergence checker.
089         */
090        public GaussNewtonOptimizer(final boolean useLU,
091                                    ConvergenceChecker<VectorialPointValuePair> checker) {
092            super(checker);
093            this.useLU = useLU;
094        }
095    
096        /** {@inheritDoc} */
097        @Override
098        public VectorialPointValuePair doOptimize() {
099    
100            final ConvergenceChecker<VectorialPointValuePair> checker
101                = getConvergenceChecker();
102    
103            // iterate until convergence is reached
104            VectorialPointValuePair current = null;
105            int iter = 0;
106            for (boolean converged = false; !converged;) {
107                ++iter;
108    
109                // evaluate the objective function and its jacobian
110                VectorialPointValuePair previous = current;
111                updateResidualsAndCost();
112                updateJacobian();
113                current = new VectorialPointValuePair(point, objective);
114    
115                final double[] targetValues = getTargetRef();
116                final double[] residualsWeights = getWeightRef();
117    
118                // build the linear problem
119                final double[]   b = new double[cols];
120                final double[][] a = new double[cols][cols];
121                for (int i = 0; i < rows; ++i) {
122    
123                    final double[] grad   = weightedResidualJacobian[i];
124                    final double weight   = residualsWeights[i];
125                    final double residual = objective[i] - targetValues[i];
126    
127                    // compute the normal equation
128                    final double wr = weight * residual;
129                    for (int j = 0; j < cols; ++j) {
130                        b[j] += wr * grad[j];
131                    }
132    
133                    // build the contribution matrix for measurement i
134                    for (int k = 0; k < cols; ++k) {
135                        double[] ak = a[k];
136                        double wgk = weight * grad[k];
137                        for (int l = 0; l < cols; ++l) {
138                            ak[l] += wgk * grad[l];
139                        }
140                    }
141                }
142    
143                try {
144                    // solve the linearized least squares problem
145                    RealMatrix mA = new BlockRealMatrix(a);
146                    DecompositionSolver solver = useLU ?
147                            new LUDecomposition(mA).getSolver() :
148                            new QRDecomposition(mA).getSolver();
149                    final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
150                    // update the estimated parameters
151                    for (int i = 0; i < cols; ++i) {
152                        point[i] += dX[i];
153                    }
154                } catch (SingularMatrixException e) {
155                    throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
156                }
157    
158                // check convergence
159                if (checker != null) {
160                    if (previous != null) {
161                        converged = checker.converged(iter, previous, current);
162                    }
163                }
164            }
165            // we have converged
166            return current;
167        }
168    }