001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.optimization.general;
019    
020    import org.apache.commons.math3.exception.ConvergenceException;
021    import org.apache.commons.math3.exception.NullArgumentException;
022    import org.apache.commons.math3.exception.MathInternalError;
023    import org.apache.commons.math3.exception.util.LocalizedFormats;
024    import org.apache.commons.math3.linear.ArrayRealVector;
025    import org.apache.commons.math3.linear.BlockRealMatrix;
026    import org.apache.commons.math3.linear.DecompositionSolver;
027    import org.apache.commons.math3.linear.LUDecomposition;
028    import org.apache.commons.math3.linear.QRDecomposition;
029    import org.apache.commons.math3.linear.RealMatrix;
030    import org.apache.commons.math3.linear.SingularMatrixException;
031    import org.apache.commons.math3.optimization.ConvergenceChecker;
032    import org.apache.commons.math3.optimization.SimpleVectorValueChecker;
033    import org.apache.commons.math3.optimization.PointVectorValuePair;
034    
035    /**
036     * Gauss-Newton least-squares solver.
037     * <p>
038     * This class solve a least-square problem by solving the normal equations
039     * of the linearized problem at each iteration. Either LU decomposition or
040     * QR decomposition can be used to solve the normal equations. LU decomposition
041     * is faster but QR decomposition is more robust for difficult problems.
042     * </p>
043     *
044     * @version $Id: GaussNewtonOptimizer.java 1423687 2012-12-18 21:56:18Z erans $
045     * @deprecated As of 3.1 (to be removed in 4.0).
046     * @since 2.0
047     *
048     */
049    @Deprecated
050    public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
051        /** Indicator for using LU decomposition. */
052        private final boolean useLU;
053    
054        /**
055         * Simple constructor with default settings.
056         * The normal equations will be solved using LU decomposition and the
057         * convergence check is set to a {@link SimpleVectorValueChecker}
058         * with default tolerances.
059         * @deprecated See {@link SimpleVectorValueChecker#SimpleVectorValueChecker()}
060         */
061        @Deprecated
062        public GaussNewtonOptimizer() {
063            this(true);
064        }
065    
066        /**
067         * Simple constructor with default settings.
068         * The normal equations will be solved using LU decomposition.
069         *
070         * @param checker Convergence checker.
071         */
072        public GaussNewtonOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
073            this(true, checker);
074        }
075    
076        /**
077         * Simple constructor with default settings.
078         * The convergence check is set to a {@link SimpleVectorValueChecker}
079         * with default tolerances.
080         *
081         * @param useLU If {@code true}, the normal equations will be solved
082         * using LU decomposition, otherwise they will be solved using QR
083         * decomposition.
084         * @deprecated See {@link SimpleVectorValueChecker#SimpleVectorValueChecker()}
085         */
086        @Deprecated
087        public GaussNewtonOptimizer(final boolean useLU) {
088            this(useLU, new SimpleVectorValueChecker());
089        }
090    
091        /**
092         * @param useLU If {@code true}, the normal equations will be solved
093         * using LU decomposition, otherwise they will be solved using QR
094         * decomposition.
095         * @param checker Convergence checker.
096         */
097        public GaussNewtonOptimizer(final boolean useLU,
098                                    ConvergenceChecker<PointVectorValuePair> checker) {
099            super(checker);
100            this.useLU = useLU;
101        }
102    
103        /** {@inheritDoc} */
104        @Override
105        public PointVectorValuePair doOptimize() {
106            final ConvergenceChecker<PointVectorValuePair> checker
107                = getConvergenceChecker();
108    
109            // Computation will be useless without a checker (see "for-loop").
110            if (checker == null) {
111                throw new NullArgumentException();
112            }
113    
114            final double[] targetValues = getTarget();
115            final int nR = targetValues.length; // Number of observed data.
116    
117            final RealMatrix weightMatrix = getWeight();
118            // Diagonal of the weight matrix.
119            final double[] residualsWeights = new double[nR];
120            for (int i = 0; i < nR; i++) {
121                residualsWeights[i] = weightMatrix.getEntry(i, i);
122            }
123    
124            final double[] currentPoint = getStartPoint();
125            final int nC = currentPoint.length;
126    
127            // iterate until convergence is reached
128            PointVectorValuePair current = null;
129            int iter = 0;
130            for (boolean converged = false; !converged;) {
131                ++iter;
132    
133                // evaluate the objective function and its jacobian
134                PointVectorValuePair previous = current;
135                // Value of the objective function at "currentPoint".
136                final double[] currentObjective = computeObjectiveValue(currentPoint);
137                final double[] currentResiduals = computeResiduals(currentObjective);
138                final RealMatrix weightedJacobian = computeWeightedJacobian(currentPoint);
139                current = new PointVectorValuePair(currentPoint, currentObjective);
140    
141                // build the linear problem
142                final double[]   b = new double[nC];
143                final double[][] a = new double[nC][nC];
144                for (int i = 0; i < nR; ++i) {
145    
146                    final double[] grad   = weightedJacobian.getRow(i);
147                    final double weight   = residualsWeights[i];
148                    final double residual = currentResiduals[i];
149    
150                    // compute the normal equation
151                    final double wr = weight * residual;
152                    for (int j = 0; j < nC; ++j) {
153                        b[j] += wr * grad[j];
154                    }
155    
156                    // build the contribution matrix for measurement i
157                    for (int k = 0; k < nC; ++k) {
158                        double[] ak = a[k];
159                        double wgk = weight * grad[k];
160                        for (int l = 0; l < nC; ++l) {
161                            ak[l] += wgk * grad[l];
162                        }
163                    }
164                }
165    
166                try {
167                    // solve the linearized least squares problem
168                    RealMatrix mA = new BlockRealMatrix(a);
169                    DecompositionSolver solver = useLU ?
170                            new LUDecomposition(mA).getSolver() :
171                            new QRDecomposition(mA).getSolver();
172                    final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
173                    // update the estimated parameters
174                    for (int i = 0; i < nC; ++i) {
175                        currentPoint[i] += dX[i];
176                    }
177                } catch (SingularMatrixException e) {
178                    throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
179                }
180    
181                // Check convergence.
182                if (previous != null) {
183                    converged = checker.converged(iter, previous, current);
184                    if (converged) {
185                        cost = computeCost(currentResiduals);
186                        // Update (deprecated) "point" field.
187                        point = current.getPoint();
188                        return current;
189                    }
190                }
191            }
192            // Must never happen.
193            throw new MathInternalError();
194        }
195    }