001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math3.optim.nonlinear.vector.jacobian;
018    
019    import org.apache.commons.math3.exception.ConvergenceException;
020    import org.apache.commons.math3.exception.NullArgumentException;
021    import org.apache.commons.math3.exception.MathInternalError;
022    import org.apache.commons.math3.exception.util.LocalizedFormats;
023    import org.apache.commons.math3.linear.ArrayRealVector;
024    import org.apache.commons.math3.linear.BlockRealMatrix;
025    import org.apache.commons.math3.linear.DecompositionSolver;
026    import org.apache.commons.math3.linear.LUDecomposition;
027    import org.apache.commons.math3.linear.QRDecomposition;
028    import org.apache.commons.math3.linear.RealMatrix;
029    import org.apache.commons.math3.linear.SingularMatrixException;
030    import org.apache.commons.math3.optim.ConvergenceChecker;
031    import org.apache.commons.math3.optim.PointVectorValuePair;
032    
033    /**
034     * Gauss-Newton least-squares solver.
035     * <p>
036     * This class solve a least-square problem by solving the normal equations
037     * of the linearized problem at each iteration. Either LU decomposition or
038     * QR decomposition can be used to solve the normal equations. LU decomposition
039     * is faster but QR decomposition is more robust for difficult problems.
040     * </p>
041     *
042     * @version $Id: GaussNewtonOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
043     * @since 2.0
044     *
045     */
046    public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
047        /** Indicator for using LU decomposition. */
048        private final boolean useLU;
049    
050        /**
051         * Simple constructor with default settings.
052         * The normal equations will be solved using LU decomposition.
053         *
054         * @param checker Convergence checker.
055         */
056        public GaussNewtonOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
057            this(true, checker);
058        }
059    
060        /**
061         * @param useLU If {@code true}, the normal equations will be solved
062         * using LU decomposition, otherwise they will be solved using QR
063         * decomposition.
064         * @param checker Convergence checker.
065         */
066        public GaussNewtonOptimizer(final boolean useLU,
067                                    ConvergenceChecker<PointVectorValuePair> checker) {
068            super(checker);
069            this.useLU = useLU;
070        }
071    
072        /** {@inheritDoc} */
073        @Override
074        public PointVectorValuePair doOptimize() {
075            final ConvergenceChecker<PointVectorValuePair> checker
076                = getConvergenceChecker();
077    
078            // Computation will be useless without a checker (see "for-loop").
079            if (checker == null) {
080                throw new NullArgumentException();
081            }
082    
083            final double[] targetValues = getTarget();
084            final int nR = targetValues.length; // Number of observed data.
085    
086            final RealMatrix weightMatrix = getWeight();
087            // Diagonal of the weight matrix.
088            final double[] residualsWeights = new double[nR];
089            for (int i = 0; i < nR; i++) {
090                residualsWeights[i] = weightMatrix.getEntry(i, i);
091            }
092    
093            final double[] currentPoint = getStartPoint();
094            final int nC = currentPoint.length;
095    
096            // iterate until convergence is reached
097            PointVectorValuePair current = null;
098            int iter = 0;
099            for (boolean converged = false; !converged;) {
100                ++iter;
101    
102                // evaluate the objective function and its jacobian
103                PointVectorValuePair previous = current;
104                // Value of the objective function at "currentPoint".
105                final double[] currentObjective = computeObjectiveValue(currentPoint);
106                final double[] currentResiduals = computeResiduals(currentObjective);
107                final RealMatrix weightedJacobian = computeWeightedJacobian(currentPoint);
108                current = new PointVectorValuePair(currentPoint, currentObjective);
109    
110                // build the linear problem
111                final double[]   b = new double[nC];
112                final double[][] a = new double[nC][nC];
113                for (int i = 0; i < nR; ++i) {
114    
115                    final double[] grad   = weightedJacobian.getRow(i);
116                    final double weight   = residualsWeights[i];
117                    final double residual = currentResiduals[i];
118    
119                    // compute the normal equation
120                    final double wr = weight * residual;
121                    for (int j = 0; j < nC; ++j) {
122                        b[j] += wr * grad[j];
123                    }
124    
125                    // build the contribution matrix for measurement i
126                    for (int k = 0; k < nC; ++k) {
127                        double[] ak = a[k];
128                        double wgk = weight * grad[k];
129                        for (int l = 0; l < nC; ++l) {
130                            ak[l] += wgk * grad[l];
131                        }
132                    }
133                }
134    
135                try {
136                    // solve the linearized least squares problem
137                    RealMatrix mA = new BlockRealMatrix(a);
138                    DecompositionSolver solver = useLU ?
139                            new LUDecomposition(mA).getSolver() :
140                            new QRDecomposition(mA).getSolver();
141                    final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
142                    // update the estimated parameters
143                    for (int i = 0; i < nC; ++i) {
144                        currentPoint[i] += dX[i];
145                    }
146                } catch (SingularMatrixException e) {
147                    throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
148                }
149    
150                // Check convergence.
151                if (previous != null) {
152                    converged = checker.converged(iter, previous, current);
153                    if (converged) {
154                        setCost(computeCost(currentResiduals));
155                        return current;
156                    }
157                }
158            }
159            // Must never happen.
160            throw new MathInternalError();
161        }
162    }