View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.fitting.leastsquares;
18  
19  import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
20  import org.apache.commons.math4.legacy.analysis.UnivariateVectorFunction;
21  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22  import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateVectorFunctionDifferentiator;
23  import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
24  import org.apache.commons.math4.legacy.linear.ArrayRealVector;
25  import org.apache.commons.math4.legacy.linear.RealMatrix;
26  import org.apache.commons.math4.legacy.linear.RealVector;
27  import org.apache.commons.math4.legacy.core.Pair;
28  
29  import java.util.Arrays;
30  
31  /**
32   * A MultivariateJacobianFunction (a thing that requires a derivative)
33   * combined with the thing that can find derivatives.
34   *
35   * Can be used with a LeastSquaresProblem, a LeastSquaresFactory, or a LeastSquaresBuilder.
36   */
37  public class DifferentiatorVectorMultivariateJacobianFunction implements MultivariateJacobianFunction {
38      /**
39       * The input function to find a jacobian for.
40       */
41      private final MultivariateVectorFunction function;
42      /**
43       * The differentiator to use to find the jacobian.
44       */
45      private final UnivariateVectorFunctionDifferentiator differentiator;
46  
47      /**
48       * Build the jacobian function using a differentiator.
49       *
50       * @param function the function to turn into a jacobian
51       * @param differentiator the differentiator to find the derivative
52       */
53      public DifferentiatorVectorMultivariateJacobianFunction(MultivariateVectorFunction function, UnivariateVectorFunctionDifferentiator differentiator) {
54          this.function = function;
55          this.differentiator = differentiator;
56      }
57  
58      /** {@inheritDoc} */
59      @Override
60      public Pair<RealVector, RealMatrix> value(RealVector point) {
61          double[] testArray = point.toArray();
62          RealVector value = new ArrayRealVector(function.value(testArray));
63          RealMatrix jacobian = new Array2DRowRealMatrix(value.getDimension(), point.getDimension());
64  
65          for(int column = 0; column < point.getDimension(); column++) {
66              final int columnFinal = column;
67              double originalPoint = point.getEntry(column);
68              double[] partialDerivatives = getPartialDerivative(testPoint -> {
69  
70                  testArray[columnFinal] = testPoint;
71  
72                  return function.value(testArray);
73              }, originalPoint);
74  
75              testArray[column] = originalPoint; //set it back
76  
77              jacobian.setColumn(column, partialDerivatives);
78          }
79  
80          return new Pair<>(value, jacobian);
81      }
82  
83      /**
84       * Returns first order derivative for the function passed in using a differentiator.
85       * @param univariateVectorFunction the function to differentiate
86       * @param atParameterValue the point at which to differentiate it at
87       * @return the slopes at that point
88       */
89      private double[] getPartialDerivative(UnivariateVectorFunction univariateVectorFunction, double atParameterValue) {
90          DerivativeStructure[] derivatives = differentiator
91                  .differentiate(univariateVectorFunction)
92                  .value(new DerivativeStructure(1, 1, 0, atParameterValue));
93          return Arrays.stream(derivatives).mapToDouble(derivative -> derivative.getPartialDerivative(1)).toArray();
94      }
95  }