View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.analysis.differentiation;
19  
20  import org.apache.commons.math4.legacy.TestUtils;
21  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
22  import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
23  import org.apache.commons.math4.core.jdkmath.JdkMath;
24  import org.junit.Test;
25  
26  
27  /**
28   * Test for class {@link GradientFunction}.
29   */
30  public class GradientFunctionTest {
31  
32      @Test
33      public void test2DDistance() {
34          EuclideanDistance f = new EuclideanDistance();
35          GradientFunction g = new GradientFunction(f);
36          for (double x = -10; x < 10; x += 0.5) {
37              for (double y = -10; y < 10; y += 0.5) {
38                  double[] point = new double[] { x, y };
39                  TestUtils.assertEquals(f.gradient(point), g.value(point), 1.0e-15);
40              }
41          }
42      }
43  
44      @Test
45      public void test3DDistance() {
46          EuclideanDistance f = new EuclideanDistance();
47          GradientFunction g = new GradientFunction(f);
48          for (double x = -10; x < 10; x += 0.5) {
49              for (double y = -10; y < 10; y += 0.5) {
50                  for (double z = -10; z < 10; z += 0.5) {
51                      double[] point = new double[] { x, y, z };
52                      TestUtils.assertEquals(f.gradient(point), g.value(point), 1.0e-15);
53                  }
54              }
55          }
56      }
57  
58      private static final class EuclideanDistance implements MultivariateDifferentiableFunction {
59  
60          @Override
61          public double value(double[] point) {
62              double d2 = 0;
63              for (double x : point) {
64                  d2 += x * x;
65              }
66              return JdkMath.sqrt(d2);
67          }
68  
69          @Override
70          public DerivativeStructure value(DerivativeStructure[] point)
71              throws DimensionMismatchException, MathIllegalArgumentException {
72              DerivativeStructure d2 = point[0].getField().getZero();
73              for (DerivativeStructure x : point) {
74                  d2 = d2.add(x.multiply(x));
75              }
76              return d2.sqrt();
77          }
78  
79          public double[] gradient(double[] point) {
80              double[] gradient = new double[point.length];
81              double d = value(point);
82              for (int i = 0; i < point.length; ++i) {
83                  gradient[i] = point[i] / d;
84              }
85              return gradient;
86          }
87      }
88  }