GaussNewtonOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.fitting.leastsquares;

  18. import org.apache.commons.math4.legacy.exception.ConvergenceException;
  19. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  20. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  21. import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem.Evaluation;
  22. import org.apache.commons.math4.legacy.linear.ArrayRealVector;
  23. import org.apache.commons.math4.legacy.linear.CholeskyDecomposition;
  24. import org.apache.commons.math4.legacy.linear.LUDecomposition;
  25. import org.apache.commons.math4.legacy.linear.MatrixUtils;
  26. import org.apache.commons.math4.legacy.linear.NonPositiveDefiniteMatrixException;
  27. import org.apache.commons.math4.legacy.linear.QRDecomposition;
  28. import org.apache.commons.math4.legacy.linear.RealMatrix;
  29. import org.apache.commons.math4.legacy.linear.RealVector;
  30. import org.apache.commons.math4.legacy.linear.SingularMatrixException;
  31. import org.apache.commons.math4.legacy.linear.SingularValueDecomposition;
  32. import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
  33. import org.apache.commons.math4.legacy.core.IntegerSequence;
  34. import org.apache.commons.math4.legacy.core.Pair;

  35. /**
  36.  * Gauss-Newton least-squares solver.
  37.  * <p> This class solve a least-square problem by
  38.  * solving the normal equations of the linearized problem at each iteration. Either LU
  39.  * decomposition or Cholesky decomposition can be used to solve the normal equations,
  40.  * or QR decomposition or SVD decomposition can be used to solve the linear system. LU
  41.  * decomposition is faster but QR decomposition is more robust for difficult problems,
  42.  * and SVD can compute a solution for rank-deficient problems.
  43.  * </p>
  44.  *
  45.  * @since 3.3
  46.  */
  47. public class GaussNewtonOptimizer implements LeastSquaresOptimizer {

  48.     /** The decomposition algorithm to use to solve the normal equations. */
  49.     //TODO move to linear package and expand options?
  50.     public enum Decomposition {
  51.         /**
  52.          * Solve by forming the normal equations (J<sup>T</sup>Jx=J<sup>T</sup>r) and
  53.          * using the {@link LUDecomposition}.
  54.          *
  55.          * <p> Theoretically this method takes mn<sup>2</sup>/2 operations to compute the
  56.          * normal matrix and n<sup>3</sup>/3 operations (m &gt; n) to solve the system using
  57.          * the LU decomposition. </p>
  58.          */
  59.         LU {
  60.             @Override
  61.             protected RealVector solve(final RealMatrix jacobian,
  62.                                        final RealVector residuals) {
  63.                 try {
  64.                     final Pair<RealMatrix, RealVector> normalEquation =
  65.                             computeNormalMatrix(jacobian, residuals);
  66.                     final RealMatrix normal = normalEquation.getFirst();
  67.                     final RealVector jTr = normalEquation.getSecond();
  68.                     return new LUDecomposition(normal, SINGULARITY_THRESHOLD)
  69.                             .getSolver()
  70.                             .solve(jTr);
  71.                 } catch (SingularMatrixException e) {
  72.                     throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM, e);
  73.                 }
  74.             }
  75.         },
  76.         /**
  77.          * Solve the linear least squares problem (Jx=r) using the {@link
  78.          * QRDecomposition}.
  79.          *
  80.          * <p> Theoretically this method takes mn<sup>2</sup> - n<sup>3</sup>/3 operations
  81.          * (m &gt; n) and has better numerical accuracy than any method that forms the normal
  82.          * equations. </p>
  83.          */
  84.         QR {
  85.             @Override
  86.             protected RealVector solve(final RealMatrix jacobian,
  87.                                        final RealVector residuals) {
  88.                 try {
  89.                     return new QRDecomposition(jacobian, SINGULARITY_THRESHOLD)
  90.                             .getSolver()
  91.                             .solve(residuals);
  92.                 } catch (SingularMatrixException e) {
  93.                     throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM, e);
  94.                 }
  95.             }
  96.         },
  97.         /**
  98.          * Solve by forming the normal equations (J<sup>T</sup>Jx=J<sup>T</sup>r) and
  99.          * using the {@link CholeskyDecomposition}.
  100.          *
  101.          * <p> Theoretically this method takes mn<sup>2</sup>/2 operations to compute the
  102.          * normal matrix and n<sup>3</sup>/6 operations (m &gt; n) to solve the system using
  103.          * the Cholesky decomposition. </p>
  104.          */
  105.         CHOLESKY {
  106.             @Override
  107.             protected RealVector solve(final RealMatrix jacobian,
  108.                                        final RealVector residuals) {
  109.                 try {
  110.                     final Pair<RealMatrix, RealVector> normalEquation =
  111.                             computeNormalMatrix(jacobian, residuals);
  112.                     final RealMatrix normal = normalEquation.getFirst();
  113.                     final RealVector jTr = normalEquation.getSecond();
  114.                     return new CholeskyDecomposition(
  115.                             normal, SINGULARITY_THRESHOLD, SINGULARITY_THRESHOLD)
  116.                             .getSolver()
  117.                             .solve(jTr);
  118.                 } catch (NonPositiveDefiniteMatrixException e) {
  119.                     throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM, e);
  120.                 }
  121.             }
  122.         },
  123.         /**
  124.          * Solve the linear least squares problem using the {@link
  125.          * SingularValueDecomposition}.
  126.          *
  127.          * <p> This method is slower, but can provide a solution for rank deficient and
  128.          * nearly singular systems.
  129.          */
  130.         SVD {
  131.             @Override
  132.             protected RealVector solve(final RealMatrix jacobian,
  133.                                        final RealVector residuals) {
  134.                 return new SingularValueDecomposition(jacobian)
  135.                         .getSolver()
  136.                         .solve(residuals);
  137.             }
  138.         };

  139.         /**
  140.          * Solve the linear least squares problem Jx=r.
  141.          *
  142.          * @param jacobian  the Jacobian matrix, J. the number of rows &gt;= the number or
  143.          *                  columns.
  144.          * @param residuals the computed residuals, r.
  145.          * @return the solution x, to the linear least squares problem Jx=r.
  146.          * @throws ConvergenceException if the matrix properties (e.g. singular) do not
  147.          *                              permit a solution.
  148.          */
  149.         protected abstract RealVector solve(RealMatrix jacobian,
  150.                                             RealVector residuals);
  151.     }

  152.     /**
  153.      * The singularity threshold for matrix decompositions. Determines when a {@link
  154.      * ConvergenceException} is thrown. The current value was the default value for {@link
  155.      * LUDecomposition}.
  156.      */
  157.     private static final double SINGULARITY_THRESHOLD = 1e-11;

  158.     /** Indicator for using LU decomposition. */
  159.     private final Decomposition decomposition;

  160.     /**
  161.      * Creates a Gauss Newton optimizer.
  162.      * <p>
  163.      * The default for the algorithm is to solve the normal equations using QR
  164.      * decomposition.
  165.      */
  166.     public GaussNewtonOptimizer() {
  167.         this(Decomposition.QR);
  168.     }

  169.     /**
  170.      * Create a Gauss Newton optimizer that uses the given decomposition algorithm to
  171.      * solve the normal equations.
  172.      *
  173.      * @param decomposition the {@link Decomposition} algorithm.
  174.      */
  175.     public GaussNewtonOptimizer(final Decomposition decomposition) {
  176.         this.decomposition = decomposition;
  177.     }

  178.     /**
  179.      * Get the matrix decomposition algorithm used to solve the normal equations.
  180.      *
  181.      * @return the matrix {@link Decomposition} algoritm.
  182.      */
  183.     public Decomposition getDecomposition() {
  184.         return this.decomposition;
  185.     }

  186.     /**
  187.      * Configure the decomposition algorithm.
  188.      *
  189.      * @param newDecomposition the {@link Decomposition} algorithm to use.
  190.      * @return a new instance.
  191.      */
  192.     public GaussNewtonOptimizer withDecomposition(final Decomposition newDecomposition) {
  193.         return new GaussNewtonOptimizer(newDecomposition);
  194.     }

  195.     /** {@inheritDoc} */
  196.     @Override
  197.     public Optimum optimize(final LeastSquaresProblem lsp) {
  198.         //create local evaluation and iteration counts
  199.         final IntegerSequence.Incrementor evaluationCounter = lsp.getEvaluationCounter();
  200.         final IntegerSequence.Incrementor iterationCounter = lsp.getIterationCounter();
  201.         final ConvergenceChecker<Evaluation> checker
  202.                 = lsp.getConvergenceChecker();

  203.         // Computation will be useless without a checker (see "for-loop").
  204.         if (checker == null) {
  205.             throw new NullArgumentException();
  206.         }

  207.         RealVector currentPoint = lsp.getStart();

  208.         // iterate until convergence is reached
  209.         Evaluation current = null;
  210.         while (true) {
  211.             iterationCounter.increment();

  212.             // evaluate the objective function and its jacobian
  213.             Evaluation previous = current;
  214.             // Value of the objective function at "currentPoint".
  215.             evaluationCounter.increment();
  216.             current = lsp.evaluate(currentPoint);
  217.             final RealVector currentResiduals = current.getResiduals();
  218.             final RealMatrix weightedJacobian = current.getJacobian();
  219.             currentPoint = current.getPoint();

  220.             // Check convergence.
  221.             if (previous != null &&
  222.                 checker.converged(iterationCounter.getCount(), previous, current)) {
  223.                 return new OptimumImpl(current,
  224.                                        evaluationCounter.getCount(),
  225.                                        iterationCounter.getCount());
  226.             }

  227.             // solve the linearized least squares problem
  228.             final RealVector dX = this.decomposition.solve(weightedJacobian, currentResiduals);
  229.             // update the estimated parameters
  230.             currentPoint = currentPoint.add(dX);
  231.         }
  232.     }

  233.     /** {@inheritDoc} */
  234.     @Override
  235.     public String toString() {
  236.         return "GaussNewtonOptimizer{" +
  237.                 "decomposition=" + decomposition +
  238.                 '}';
  239.     }

  240.     /**
  241.      * Compute the normal matrix, J<sup>T</sup>J.
  242.      *
  243.      * @param jacobian  the m by n jacobian matrix, J. Input.
  244.      * @param residuals the m by 1 residual vector, r. Input.
  245.      * @return  the n by n normal matrix and  the n by 1 J<sup>Tr</sup> vector.
  246.      */
  247.     private static Pair<RealMatrix, RealVector> computeNormalMatrix(final RealMatrix jacobian,
  248.                                                                     final RealVector residuals) {
  249.         //since the normal matrix is symmetric, we only need to compute half of it.
  250.         final int nR = jacobian.getRowDimension();
  251.         final int nC = jacobian.getColumnDimension();
  252.         //allocate space for return values
  253.         final RealMatrix normal = MatrixUtils.createRealMatrix(nC, nC);
  254.         final RealVector jTr = new ArrayRealVector(nC);
  255.         //for each measurement
  256.         for (int i = 0; i < nR; ++i) {
  257.             //compute JTr for measurement i
  258.             for (int j = 0; j < nC; j++) {
  259.                 jTr.setEntry(j, jTr.getEntry(j) +
  260.                         residuals.getEntry(i) * jacobian.getEntry(i, j));
  261.             }

  262.             // add the contribution to the normal matrix for measurement i
  263.             for (int k = 0; k < nC; ++k) {
  264.                 //only compute the upper triangular part
  265.                 for (int l = k; l < nC; ++l) {
  266.                     normal.setEntry(k, l, normal.getEntry(k, l) +
  267.                             jacobian.getEntry(i, k) * jacobian.getEntry(i, l));
  268.                 }
  269.             }
  270.         }
  271.         //copy the upper triangular part to the lower triangular part.
  272.         for (int i = 0; i < nC; i++) {
  273.             for (int j = 0; j < i; j++) {
  274.                 normal.setEntry(i, j, normal.getEntry(j, i));
  275.             }
  276.         }
  277.         return new Pair<>(normal, jTr);
  278.     }
  279. }