001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.fitting.leastsquares; 018 019import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction; 020import org.apache.commons.math4.legacy.analysis.UnivariateVectorFunction; 021import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure; 022import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateVectorFunctionDifferentiator; 023import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix; 024import org.apache.commons.math4.legacy.linear.ArrayRealVector; 025import org.apache.commons.math4.legacy.linear.RealMatrix; 026import org.apache.commons.math4.legacy.linear.RealVector; 027import org.apache.commons.math4.legacy.core.Pair; 028 029import java.util.Arrays; 030 031/** 032 * A MultivariateJacobianFunction (a thing that requires a derivative) 033 * combined with the thing that can find derivatives. 034 * 035 * Can be used with a LeastSquaresProblem, a LeastSquaresFactory, or a LeastSquaresBuilder. 036 */ 037public class DifferentiatorVectorMultivariateJacobianFunction implements MultivariateJacobianFunction { 038 /** 039 * The input function to find a jacobian for. 040 */ 041 private final MultivariateVectorFunction function; 042 /** 043 * The differentiator to use to find the jacobian. 044 */ 045 private final UnivariateVectorFunctionDifferentiator differentiator; 046 047 /** 048 * Build the jacobian function using a differentiator. 049 * 050 * @param function the function to turn into a jacobian 051 * @param differentiator the differentiator to find the derivative 052 */ 053 public DifferentiatorVectorMultivariateJacobianFunction(MultivariateVectorFunction function, UnivariateVectorFunctionDifferentiator differentiator) { 054 this.function = function; 055 this.differentiator = differentiator; 056 } 057 058 /** {@inheritDoc} */ 059 @Override 060 public Pair<RealVector, RealMatrix> value(RealVector point) { 061 double[] testArray = point.toArray(); 062 RealVector value = new ArrayRealVector(function.value(testArray)); 063 RealMatrix jacobian = new Array2DRowRealMatrix(value.getDimension(), point.getDimension()); 064 065 for(int column = 0; column < point.getDimension(); column++) { 066 final int columnFinal = column; 067 double originalPoint = point.getEntry(column); 068 double[] partialDerivatives = getPartialDerivative(testPoint -> { 069 070 testArray[columnFinal] = testPoint; 071 072 return function.value(testArray); 073 }, originalPoint); 074 075 testArray[column] = originalPoint; //set it back 076 077 jacobian.setColumn(column, partialDerivatives); 078 } 079 080 return new Pair<>(value, jacobian); 081 } 082 083 /** 084 * Returns first order derivative for the function passed in using a differentiator. 085 * @param univariateVectorFunction the function to differentiate 086 * @param atParameterValue the point at which to differentiate it at 087 * @return the slopes at that point 088 */ 089 private double[] getPartialDerivative(UnivariateVectorFunction univariateVectorFunction, double atParameterValue) { 090 DerivativeStructure[] derivatives = differentiator 091 .differentiate(univariateVectorFunction) 092 .value(new DerivativeStructure(1, 1, 0, atParameterValue)); 093 return Arrays.stream(derivatives).mapToDouble(derivative -> derivative.getPartialDerivative(1)).toArray(); 094 } 095}