/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.legacy.analysis.differentiation;

import java.util.Objects;

import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;

/** Class representing the gradient of a multivariate function.
 * <p>
 * The vectorial components of the function represent the derivatives
 * with respect to each function parameter: component {@code i} of
 * {@link #value(double[])} is the first-order partial derivative of the
 * underlying function with respect to its {@code i}-th parameter.
 * </p>
 * @since 3.1
 */
public class GradientFunction implements MultivariateVectorFunction {

    /** Underlying real-valued function. */
    private final MultivariateDifferentiableFunction f;

    /** Simple constructor.
     * @param f underlying real-valued function
     * @throws NullPointerException if {@code f} is {@code null}
     */
    public GradientFunction(final MultivariateDifferentiableFunction f) {
        // Fail fast at construction instead of on the first call to value(),
        // where the cause of the NullPointerException would be hard to trace.
        this.f = Objects.requireNonNull(f, "f");
    }

    /** Evaluate the gradient of the underlying function at a point.
     * @param point point at which the gradient is evaluated; its length is
     * used as the number of free parameters of the derivative structures
     * @return array whose element {@code i} is the first-order partial
     * derivative of the underlying function with respect to {@code point[i]}
     * @throws NullPointerException if {@code point} is {@code null}
     */
    @Override
    public double[] value(final double[] point) {
        final int n = point.length;

        // set up parameters: coordinate i becomes free parameter i,
        // with derivatives tracked up to order 1
        final DerivativeStructure[] dsX = new DerivativeStructure[n];
        for (int i = 0; i < n; ++i) {
            dsX[i] = new DerivativeStructure(n, 1, i, point[i]);
        }

        // compute the derivatives
        final DerivativeStructure dsY = f.value(dsX);

        // extract the gradient
        final double[] y = new double[n];
        // orders is a reusable multi-index: during step i only entry i is 1,
        // which selects the first derivative with respect to parameter i;
        // the entry is reset to 0 so the array is all zeros for the next step
        final int[] orders = new int[n];
        for (int i = 0; i < n; ++i) {
            orders[i] = 1;
            y[i] = dsY.getPartialDerivative(orders);
            orders[i] = 0;
        }

        return y;
    }
}