001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math3.analysis.differentiation;
018    
019    import org.apache.commons.math3.analysis.MultivariateVectorFunction;
020    
021    /** Class representing the gradient of a multivariate function.
022     * <p>
023     * The vectorial components of the function represent the derivatives
024     * with respect to each function parameters.
025     * </p>
026     * @version $Id: GradientFunction.java 1455194 2013-03-11 15:45:54Z luc $
027     * @since 3.1
028     */
029    public class GradientFunction implements MultivariateVectorFunction {
030    
031        /** Underlying real-valued function. */
032        private final MultivariateDifferentiableFunction f;
033    
034        /** Simple constructor.
035         * @param f underlying real-valued function
036         */
037        public GradientFunction(final MultivariateDifferentiableFunction f) {
038            this.f = f;
039        }
040    
041        /** {@inheritDoc} */
042        public double[] value(double[] point) {
043    
044            // set up parameters
045            final DerivativeStructure[] dsX = new DerivativeStructure[point.length];
046            for (int i = 0; i < point.length; ++i) {
047                dsX[i] = new DerivativeStructure(point.length, 1, i, point[i]);
048            }
049    
050            // compute the derivatives
051            final DerivativeStructure dsY = f.value(dsX);
052    
053            // extract the gradient
054            final double[] y = new double[point.length];
055            final int[] orders = new int[point.length];
056            for (int i = 0; i < point.length; ++i) {
057                orders[i] = 1;
058                y[i] = dsY.getPartialDerivative(orders);
059                orders[i] = 0;
060            }
061    
062            return y;
063    
064        }
065    
066    }