DifferentiatorVectorMultivariateJacobianFunction.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.fitting.leastsquares;

  18. import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
  19. import org.apache.commons.math4.legacy.analysis.UnivariateVectorFunction;
  20. import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
  21. import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateVectorFunctionDifferentiator;
  22. import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
  23. import org.apache.commons.math4.legacy.linear.ArrayRealVector;
  24. import org.apache.commons.math4.legacy.linear.RealMatrix;
  25. import org.apache.commons.math4.legacy.linear.RealVector;
  26. import org.apache.commons.math4.legacy.core.Pair;

  27. import java.util.Arrays;

  28. /**
  29.  * A MultivariateJacobianFunction (a thing that requires a derivative)
  30.  * combined with the thing that can find derivatives.
  31.  *
  32.  * Can be used with a LeastSquaresProblem, a LeastSquaresFactory, or a LeastSquaresBuilder.
  33.  */
  34. public class DifferentiatorVectorMultivariateJacobianFunction implements MultivariateJacobianFunction {
  35.     /**
  36.      * The input function to find a jacobian for.
  37.      */
  38.     private final MultivariateVectorFunction function;
  39.     /**
  40.      * The differentiator to use to find the jacobian.
  41.      */
  42.     private final UnivariateVectorFunctionDifferentiator differentiator;

  43.     /**
  44.      * Build the jacobian function using a differentiator.
  45.      *
  46.      * @param function the function to turn into a jacobian
  47.      * @param differentiator the differentiator to find the derivative
  48.      */
  49.     public DifferentiatorVectorMultivariateJacobianFunction(MultivariateVectorFunction function, UnivariateVectorFunctionDifferentiator differentiator) {
  50.         this.function = function;
  51.         this.differentiator = differentiator;
  52.     }

  53.     /** {@inheritDoc} */
  54.     @Override
  55.     public Pair<RealVector, RealMatrix> value(RealVector point) {
  56.         double[] testArray = point.toArray();
  57.         RealVector value = new ArrayRealVector(function.value(testArray));
  58.         RealMatrix jacobian = new Array2DRowRealMatrix(value.getDimension(), point.getDimension());

  59.         for(int column = 0; column < point.getDimension(); column++) {
  60.             final int columnFinal = column;
  61.             double originalPoint = point.getEntry(column);
  62.             double[] partialDerivatives = getPartialDerivative(testPoint -> {

  63.                 testArray[columnFinal] = testPoint;

  64.                 return function.value(testArray);
  65.             }, originalPoint);

  66.             testArray[column] = originalPoint; //set it back

  67.             jacobian.setColumn(column, partialDerivatives);
  68.         }

  69.         return new Pair<>(value, jacobian);
  70.     }

  71.     /**
  72.      * Returns first order derivative for the function passed in using a differentiator.
  73.      * @param univariateVectorFunction the function to differentiate
  74.      * @param atParameterValue the point at which to differentiate it at
  75.      * @return the slopes at that point
  76.      */
  77.     private double[] getPartialDerivative(UnivariateVectorFunction univariateVectorFunction, double atParameterValue) {
  78.         DerivativeStructure[] derivatives = differentiator
  79.                 .differentiate(univariateVectorFunction)
  80.                 .value(new DerivativeStructure(1, 1, 0, atParameterValue));
  81.         return Arrays.stream(derivatives).mapToDouble(derivative -> derivative.getPartialDerivative(1)).toArray();
  82.     }
  83. }