FieldBracketingNthOrderBrentSolver.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.analysis.solvers;


  18. import org.apache.commons.math4.legacy.core.Field;
  19. import org.apache.commons.math4.legacy.core.RealFieldElement;
  20. import org.apache.commons.math4.legacy.analysis.RealFieldUnivariateFunction;
  21. import org.apache.commons.math4.legacy.exception.MathInternalError;
  22. import org.apache.commons.math4.legacy.exception.NoBracketingException;
  23. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  24. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  25. import org.apache.commons.math4.legacy.core.IntegerSequence;
  26. import org.apache.commons.math4.legacy.core.MathArrays;
  27. import org.apache.commons.numbers.core.Precision;

  28. /**
  29.  * This class implements a modification of the <a
  30.  * href="http://mathworld.wolfram.com/BrentsMethod.html"> Brent algorithm</a>.
  31.  * <p>
  32.  * The changes with respect to the original Brent algorithm are:
  33.  * <ul>
  34.  *   <li>the returned value is chosen in the current interval according
  35.  *   to user specified {@link AllowedSolution}</li>
  36.  *   <li>the maximal order for the invert polynomial root search is
  37.  *   user-specified instead of being invert quadratic only</li>
  38.  * </ul><p>
  39.  * The given interval must bracket the root.</p>
  40.  *
  41.  * @param <T> the type of the field elements
  42.  * @since 3.6
  43.  */
  44. public class FieldBracketingNthOrderBrentSolver<T extends RealFieldElement<T>>
  45.     implements BracketedRealFieldUnivariateSolver<T> {

  46.    /** Maximal aging triggering an attempt to balance the bracketing interval. */
  47.     private static final int MAXIMAL_AGING = 2;

  48.     /** Field to which the elements belong. */
  49.     private final Field<T> field;

  50.     /** Maximal order. */
  51.     private final int maximalOrder;

  52.     /** Function value accuracy. */
  53.     private final T functionValueAccuracy;

  54.     /** Absolute accuracy. */
  55.     private final T absoluteAccuracy;

  56.     /** Relative accuracy. */
  57.     private final T relativeAccuracy;

  58.     /** Evaluations counter. */
  59.     private IntegerSequence.Incrementor evaluations;

  60.     /**
  61.      * Construct a solver.
  62.      *
  63.      * @param relativeAccuracy Relative accuracy.
  64.      * @param absoluteAccuracy Absolute accuracy.
  65.      * @param functionValueAccuracy Function value accuracy.
  66.      * @param maximalOrder maximal order.
  67.      * @exception NumberIsTooSmallException if maximal order is lower than 2
  68.      */
  69.     public FieldBracketingNthOrderBrentSolver(final T relativeAccuracy,
  70.                                               final T absoluteAccuracy,
  71.                                               final T functionValueAccuracy,
  72.                                               final int maximalOrder)
  73.         throws NumberIsTooSmallException {
  74.         if (maximalOrder < 2) {
  75.             throw new NumberIsTooSmallException(maximalOrder, 2, true);
  76.         }
  77.         this.field                 = relativeAccuracy.getField();
  78.         this.maximalOrder          = maximalOrder;
  79.         this.absoluteAccuracy      = absoluteAccuracy;
  80.         this.relativeAccuracy      = relativeAccuracy;
  81.         this.functionValueAccuracy = functionValueAccuracy;
  82.         this.evaluations           = IntegerSequence.Incrementor.create();
  83.     }

  84.     /** Get the maximal order.
  85.      * @return maximal order
  86.      */
  87.     public int getMaximalOrder() {
  88.         return maximalOrder;
  89.     }

  90.     /**
  91.      * Get the maximal number of function evaluations.
  92.      *
  93.      * @return the maximal number of function evaluations.
  94.      */
  95.     @Override
  96.     public int getMaxEvaluations() {
  97.         return evaluations.getMaximalCount();
  98.     }

  99.     /**
  100.      * Get the number of evaluations of the objective function.
  101.      * The number of evaluations corresponds to the last call to the
  102.      * {@code optimize} method. It is 0 if the method has not been
  103.      * called yet.
  104.      *
  105.      * @return the number of evaluations of the objective function.
  106.      */
  107.     @Override
  108.     public int getEvaluations() {
  109.         return evaluations.getCount();
  110.     }

  111.     /**
  112.      * Get the absolute accuracy.
  113.      * @return absolute accuracy
  114.      */
  115.     @Override
  116.     public T getAbsoluteAccuracy() {
  117.         return absoluteAccuracy;
  118.     }

  119.     /**
  120.      * Get the relative accuracy.
  121.      * @return relative accuracy
  122.      */
  123.     @Override
  124.     public T getRelativeAccuracy() {
  125.         return relativeAccuracy;
  126.     }

  127.     /**
  128.      * Get the function accuracy.
  129.      * @return function accuracy
  130.      */
  131.     @Override
  132.     public T getFunctionValueAccuracy() {
  133.         return functionValueAccuracy;
  134.     }

  135.     /**
  136.      * Solve for a zero in the given interval.
  137.      * A solver may require that the interval brackets a single zero root.
  138.      * Solvers that do require bracketing should be able to handle the case
  139.      * where one of the endpoints is itself a root.
  140.      *
  141.      * @param maxEval Maximum number of evaluations.
  142.      * @param f Function to solve.
  143.      * @param min Lower bound for the interval.
  144.      * @param max Upper bound for the interval.
  145.      * @param allowedSolution The kind of solutions that the root-finding algorithm may
  146.      * accept as solutions.
  147.      * @return a value where the function is zero.
  148.      * @exception NullArgumentException if f is null.
  149.      * @exception NoBracketingException if root cannot be bracketed
  150.      */
  151.     @Override
  152.     public T solve(final int maxEval, final RealFieldUnivariateFunction<T> f,
  153.                    final T min, final T max, final AllowedSolution allowedSolution)
  154.         throws NullArgumentException, NoBracketingException {
  155.         return solve(maxEval, f, min, max, min.add(max).divide(2), allowedSolution);
  156.     }

  157.     /**
  158.      * Solve for a zero in the given interval, start at {@code startValue}.
  159.      * A solver may require that the interval brackets a single zero root.
  160.      * Solvers that do require bracketing should be able to handle the case
  161.      * where one of the endpoints is itself a root.
  162.      *
  163.      * @param maxEval Maximum number of evaluations.
  164.      * @param f Function to solve.
  165.      * @param min Lower bound for the interval.
  166.      * @param max Upper bound for the interval.
  167.      * @param startValue Start value to use.
  168.      * @param allowedSolution The kind of solutions that the root-finding algorithm may
  169.      * accept as solutions.
  170.      * @return a value where the function is zero.
  171.      * @exception NullArgumentException if f is null.
  172.      * @exception NoBracketingException if root cannot be bracketed
  173.      */
  174.     @Override
  175.     public T solve(final int maxEval, final RealFieldUnivariateFunction<T> f,
  176.                    final T min, final T max, final T startValue,
  177.                    final AllowedSolution allowedSolution)
  178.         throws NullArgumentException, NoBracketingException {

  179.         // Checks.
  180.         NullArgumentException.check(f);

  181.         // Reset.
  182.         evaluations = evaluations.withMaximalCount(maxEval).withStart(0);
  183.         T zero = field.getZero();
  184.         T nan  = zero.add(Double.NaN);

  185.         // prepare arrays with the first points
  186.         final T[] x = MathArrays.buildArray(field, maximalOrder + 1);
  187.         final T[] y = MathArrays.buildArray(field, maximalOrder + 1);
  188.         x[0] = min;
  189.         x[1] = startValue;
  190.         x[2] = max;

  191.         // evaluate initial guess
  192.         evaluations.increment();
  193.         y[1] = f.value(x[1]);
  194.         if (Precision.equals(y[1].getReal(), 0.0, 1)) {
  195.             // return the initial guess if it is a perfect root.
  196.             return x[1];
  197.         }

  198.         // evaluate first endpoint
  199.         evaluations.increment();
  200.         y[0] = f.value(x[0]);
  201.         if (Precision.equals(y[0].getReal(), 0.0, 1)) {
  202.             // return the first endpoint if it is a perfect root.
  203.             return x[0];
  204.         }

  205.         int nbPoints;
  206.         int signChangeIndex;
  207.         if (y[0].multiply(y[1]).getReal() < 0) {

  208.             // reduce interval if it brackets the root
  209.             nbPoints        = 2;
  210.             signChangeIndex = 1;
  211.         } else {

  212.             // evaluate second endpoint
  213.             evaluations.increment();
  214.             y[2] = f.value(x[2]);
  215.             if (Precision.equals(y[2].getReal(), 0.0, 1)) {
  216.                 // return the second endpoint if it is a perfect root.
  217.                 return x[2];
  218.             }

  219.             if (y[1].multiply(y[2]).getReal() < 0) {
  220.                 // use all computed point as a start sampling array for solving
  221.                 nbPoints        = 3;
  222.                 signChangeIndex = 2;
  223.             } else {
  224.                 throw new NoBracketingException(x[0].getReal(), x[2].getReal(),
  225.                                                 y[0].getReal(), y[2].getReal());
  226.             }
  227.         }

  228.         // prepare a work array for inverse polynomial interpolation
  229.         final T[] tmpX = MathArrays.buildArray(field, x.length);

  230.         // current tightest bracketing of the root
  231.         T xA    = x[signChangeIndex - 1];
  232.         T yA    = y[signChangeIndex - 1];
  233.         T absXA = xA.abs();
  234.         T absYA = yA.abs();
  235.         int agingA   = 0;
  236.         T xB    = x[signChangeIndex];
  237.         T yB    = y[signChangeIndex];
  238.         T absXB = xB.abs();
  239.         T absYB = yB.abs();
  240.         int agingB   = 0;

  241.         // search loop
  242.         while (true) {

  243.             // check convergence of bracketing interval
  244.             T maxX = absXA.subtract(absXB).getReal() < 0 ? absXB : absXA;
  245.             T maxY = absYA.subtract(absYB).getReal() < 0 ? absYB : absYA;
  246.             final T xTol = absoluteAccuracy.add(relativeAccuracy.multiply(maxX));
  247.             if (xB.subtract(xA).subtract(xTol).getReal() <= 0 ||
  248.                 maxY.subtract(functionValueAccuracy).getReal() < 0) {
  249.                 switch (allowedSolution) {
  250.                 case ANY_SIDE :
  251.                     return absYA.subtract(absYB).getReal() < 0 ? xA : xB;
  252.                 case LEFT_SIDE :
  253.                     return xA;
  254.                 case RIGHT_SIDE :
  255.                     return xB;
  256.                 case BELOW_SIDE :
  257.                     return yA.getReal() <= 0 ? xA : xB;
  258.                 case ABOVE_SIDE :
  259.                     return yA.getReal() < 0 ? xB : xA;
  260.                 default :
  261.                     // this should never happen
  262.                     throw new MathInternalError(null);
  263.                 }
  264.             }

  265.             // target for the next evaluation point
  266.             T targetY;
  267.             if (agingA >= MAXIMAL_AGING) {
  268.                 // we keep updating the high bracket, try to compensate this
  269.                 targetY = yB.divide(16).negate();
  270.             } else if (agingB >= MAXIMAL_AGING) {
  271.                 // we keep updating the low bracket, try to compensate this
  272.                 targetY = yA.divide(16).negate();
  273.             } else {
  274.                 // bracketing is balanced, try to find the root itself
  275.                 targetY = zero;
  276.             }

  277.             // make a few attempts to guess a root,
  278.             T nextX;
  279.             int start = 0;
  280.             int end   = nbPoints;
  281.             do {

  282.                 // guess a value for current target, using inverse polynomial interpolation
  283.                 System.arraycopy(x, start, tmpX, start, end - start);
  284.                 nextX = guessX(targetY, tmpX, y, start, end);

  285.                 if (!(nextX.subtract(xA).getReal() > 0 && nextX.subtract(xB).getReal() < 0)) {
  286.                     // the guessed root is not strictly inside of the tightest bracketing interval

  287.                     // the guessed root is either not strictly inside the interval or it
  288.                     // is a NaN (which occurs when some sampling points share the same y)
  289.                     // we try again with a lower interpolation order
  290.                     if (signChangeIndex - start >= end - signChangeIndex) {
  291.                         // we have more points before the sign change, drop the lowest point
  292.                         ++start;
  293.                     } else {
  294.                         // we have more points after sign change, drop the highest point
  295.                         --end;
  296.                     }

  297.                     // we need to do one more attempt
  298.                     nextX = nan;
  299.                 }
  300.             } while (Double.isNaN(nextX.getReal()) && end - start > 1);

  301.             if (Double.isNaN(nextX.getReal())) {
  302.                 // fall back to bisection
  303.                 nextX = xA.add(xB.subtract(xA).divide(2));
  304.                 start = signChangeIndex - 1;
  305.                 end   = signChangeIndex;
  306.             }

  307.             // evaluate the function at the guessed root
  308.             evaluations.increment();
  309.             final T nextY = f.value(nextX);
  310.             if (Precision.equals(nextY.getReal(), 0.0, 1)) {
  311.                 // we have found an exact root, since it is not an approximation
  312.                 // we don't need to bother about the allowed solutions setting
  313.                 return nextX;
  314.             }

  315.             if (nbPoints > 2 && end - start != nbPoints) {

  316.                 // we have been forced to ignore some points to keep bracketing,
  317.                 // they are probably too far from the root, drop them from now on
  318.                 nbPoints = end - start;
  319.                 System.arraycopy(x, start, x, 0, nbPoints);
  320.                 System.arraycopy(y, start, y, 0, nbPoints);
  321.                 signChangeIndex -= start;
  322.             } else  if (nbPoints == x.length) {

  323.                 // we have to drop one point in order to insert the new one
  324.                 nbPoints--;

  325.                 // keep the tightest bracketing interval as centered as possible
  326.                 if (signChangeIndex >= (x.length + 1) / 2) {
  327.                     // we drop the lowest point, we have to shift the arrays and the index
  328.                     System.arraycopy(x, 1, x, 0, nbPoints);
  329.                     System.arraycopy(y, 1, y, 0, nbPoints);
  330.                     --signChangeIndex;
  331.                 }
  332.             }

  333.             // insert the last computed point
  334.             //(by construction, we know it lies inside the tightest bracketing interval)
  335.             System.arraycopy(x, signChangeIndex, x, signChangeIndex + 1, nbPoints - signChangeIndex);
  336.             x[signChangeIndex] = nextX;
  337.             System.arraycopy(y, signChangeIndex, y, signChangeIndex + 1, nbPoints - signChangeIndex);
  338.             y[signChangeIndex] = nextY;
  339.             ++nbPoints;

  340.             // update the bracketing interval
  341.             if (nextY.multiply(yA).getReal() <= 0) {
  342.                 // the sign change occurs before the inserted point
  343.                 xB = nextX;
  344.                 yB = nextY;
  345.                 absYB = yB.abs();
  346.                 ++agingA;
  347.                 agingB = 0;
  348.             } else {
  349.                 // the sign change occurs after the inserted point
  350.                 xA = nextX;
  351.                 yA = nextY;
  352.                 absYA = yA.abs();
  353.                 agingA = 0;
  354.                 ++agingB;

  355.                 // update the sign change index
  356.                 signChangeIndex++;
  357.             }
  358.         }
  359.     }

  360.     /** Guess an x value by n<sup>th</sup> order inverse polynomial interpolation.
  361.      * <p>
  362.      * The x value is guessed by evaluating polynomial Q(y) at y = targetY, where Q
  363.      * is built such that for all considered points (x<sub>i</sub>, y<sub>i</sub>),
  364.      * Q(y<sub>i</sub>) = x<sub>i</sub>.
  365.      * </p>
  366.      * @param targetY target value for y
  367.      * @param x reference points abscissas for interpolation,
  368.      * note that this array <em>is</em> modified during computation
  369.      * @param y reference points ordinates for interpolation
  370.      * @param start start index of the points to consider (inclusive)
  371.      * @param end end index of the points to consider (exclusive)
  372.      * @return guessed root (will be a NaN if two points share the same y)
  373.      */
  374.     private T guessX(final T targetY, final T[] x, final T[] y,
  375.                        final int start, final int end) {

  376.         // compute Q Newton coefficients by divided differences
  377.         for (int i = start; i < end - 1; ++i) {
  378.             final int delta = i + 1 - start;
  379.             for (int j = end - 1; j > i; --j) {
  380.                 x[j] = x[j].subtract(x[j-1]).divide(y[j].subtract(y[j - delta]));
  381.             }
  382.         }

  383.         // evaluate Q(targetY)
  384.         T x0 = field.getZero();
  385.         for (int j = end - 1; j >= start; --j) {
  386.             x0 = x[j].add(x0.multiply(targetY.subtract(y[j])));
  387.         }

  388.         return x0;
  389.     }
  390. }