SimpleCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.fitting;

  18. import java.util.Collections;
  19. import java.util.Collection;
  20. import java.util.Comparator;
  21. import java.util.List;
  22. import java.util.ArrayList;

  23. import org.apache.commons.math4.legacy.exception.ZeroException;
  24. import org.apache.commons.math4.legacy.exception.OutOfRangeException;
  25. import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
  26. import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
  27. import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
  28. import org.apache.commons.math4.legacy.linear.DiagonalMatrix;

  29. /**
  30.  * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
  31.  *
  32.  * @since 3.4
  33.  */
  34. public class SimpleCurveFitter extends AbstractCurveFitter {
  35.     /** Function to fit. */
  36.     private final ParametricUnivariateFunction function;
  37.     /** Initial guess for the parameters. */
  38.     private final double[] initialGuess;
  39.     /** Parameter guesser. */
  40.     private final ParameterGuesser guesser;
  41.     /** Maximum number of iterations of the optimization algorithm. */
  42.     private final int maxIter;

  43.     /**
  44.      * Constructor used by the factory methods.
  45.      *
  46.      * @param function Function to fit.
  47.      * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
  48.      * be consistent with the number of parameters of the {@code function} to fit.
  49.      * @param guesser Method for providing an initial guess (if {@code initialGuess}
  50.      * is {@code null}).
  51.      * @param maxIter Maximum number of iterations of the optimization algorithm.
  52.      */
  53.     protected SimpleCurveFitter(ParametricUnivariateFunction function,
  54.                                 double[] initialGuess,
  55.                                 ParameterGuesser guesser,
  56.                                 int maxIter) {
  57.         this.function = function;
  58.         this.initialGuess = initialGuess;
  59.         this.guesser = guesser;
  60.         this.maxIter = maxIter;
  61.     }

  62.     /**
  63.      * Creates a curve fitter.
  64.      * The maximum number of iterations of the optimization algorithm is set
  65.      * to {@link Integer#MAX_VALUE}.
  66.      *
  67.      * @param f Function to fit.
  68.      * @param start Initial guess for the parameters.  Cannot be {@code null}.
  69.      * Its length must be consistent with the number of parameters of the
  70.      * function to fit.
  71.      * @return a curve fitter.
  72.      *
  73.      * @see #withStartPoint(double[])
  74.      * @see #withMaxIterations(int)
  75.      */
  76.     public static SimpleCurveFitter create(ParametricUnivariateFunction f,
  77.                                            double[] start) {
  78.         return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE);
  79.     }

  80.     /**
  81.      * Creates a curve fitter.
  82.      * The maximum number of iterations of the optimization algorithm is set
  83.      * to {@link Integer#MAX_VALUE}.
  84.      *
  85.      * @param f Function to fit.
  86.      * @param guesser Method for providing an initial guess.
  87.      * @return a curve fitter.
  88.      *
  89.      * @see #withStartPoint(double[])
  90.      * @see #withMaxIterations(int)
  91.      */
  92.     public static SimpleCurveFitter create(ParametricUnivariateFunction f,
  93.                                            ParameterGuesser guesser) {
  94.         return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
  95.     }

  96.     /**
  97.      * Configure the start point (initial guess).
  98.      * @param newStart new start point (initial guess)
  99.      * @return a new instance.
  100.      */
  101.     public SimpleCurveFitter withStartPoint(double[] newStart) {
  102.         return new SimpleCurveFitter(function,
  103.                                      newStart.clone(),
  104.                                      null,
  105.                                      maxIter);
  106.     }

  107.     /**
  108.      * Configure the maximum number of iterations.
  109.      * @param newMaxIter maximum number of iterations
  110.      * @return a new instance.
  111.      */
  112.     public SimpleCurveFitter withMaxIterations(int newMaxIter) {
  113.         return new SimpleCurveFitter(function,
  114.                                      initialGuess,
  115.                                      guesser,
  116.                                      newMaxIter);
  117.     }

  118.     /** {@inheritDoc} */
  119.     @Override
  120.     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
  121.         // Prepare least-squares problem.
  122.         final int len = observations.size();
  123.         final double[] target  = new double[len];
  124.         final double[] weights = new double[len];

  125.         int count = 0;
  126.         for (WeightedObservedPoint obs : observations) {
  127.             target[count]  = obs.getY();
  128.             weights[count] = obs.getWeight();
  129.             ++count;
  130.         }

  131.         final AbstractCurveFitter.TheoreticalValuesFunction model
  132.             = new AbstractCurveFitter.TheoreticalValuesFunction(function,
  133.                                                                 observations);

  134.         final double[] startPoint = initialGuess != null ?
  135.             initialGuess :
  136.             // Compute estimation.
  137.             guesser.guess(observations);

  138.         // Create an optimizer for fitting the curve to the observed points.
  139.         return new LeastSquaresBuilder().
  140.                 maxEvaluations(Integer.MAX_VALUE).
  141.                 maxIterations(maxIter).
  142.                 start(startPoint).
  143.                 target(target).
  144.                 weight(new DiagonalMatrix(weights)).
  145.                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
  146.                 build();
  147.     }

  148.     /**
  149.      * Guesses the parameters.
  150.      */
  151.     public abstract static class ParameterGuesser {
  152.         /** Comparator. */
  153.         private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
  154.                 /** {@inheritDoc} */
  155.                 @Override
  156.                 public int compare(WeightedObservedPoint p1,
  157.                                    WeightedObservedPoint p2) {
  158.                     if (p1 == null && p2 == null) {
  159.                         return 0;
  160.                     }
  161.                     if (p1 == null) {
  162.                         return -1;
  163.                     }
  164.                     if (p2 == null) {
  165.                         return 1;
  166.                     }
  167.                     int comp = Double.compare(p1.getX(), p2.getX());
  168.                     if (comp != 0) {
  169.                         return comp;
  170.                     }
  171.                     comp = Double.compare(p1.getY(), p2.getY());
  172.                     if (comp != 0) {
  173.                         return comp;
  174.                     }
  175.                     return Double.compare(p1.getWeight(), p2.getWeight());
  176.                 }
  177.             };

  178.         /**
  179.          * Computes an estimation of the parameters.
  180.          *
  181.          * @param obs Observations.
  182.          * @return the guessed parameters.
  183.          */
  184.         public abstract double[] guess(Collection<WeightedObservedPoint> obs);

  185.         /**
  186.          * Sort the observations.
  187.          *
  188.          * @param unsorted Input observations.
  189.          * @return the input observations, sorted.
  190.          */
  191.         protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
  192.             final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
  193.             Collections.sort(observations, CMP);
  194.             return observations;
  195.         }

  196.         /**
  197.          * Finds index of point in specified points with the largest Y.
  198.          *
  199.          * @param points Points to search.
  200.          * @return the index in specified points array.
  201.          */
  202.         protected int findMaxY(WeightedObservedPoint[] points) {
  203.             int maxYIdx = 0;
  204.             for (int i = 1; i < points.length; i++) {
  205.                 if (points[i].getY() > points[maxYIdx].getY()) {
  206.                     maxYIdx = i;
  207.                 }
  208.             }
  209.             return maxYIdx;
  210.         }

  211.         /**
  212.          * Interpolates using the specified points to determine X at the
  213.          * specified Y.
  214.          *
  215.          * @param points Points to use for interpolation.
  216.          * @param startIdx Index within points from which to start the search for
  217.          * interpolation bounds points.
  218.          * @param idxStep Index step for searching interpolation bounds points.
  219.          * @param y Y value for which X should be determined.
  220.          * @return the value of X for the specified Y.
  221.          * @throws ZeroException if {@code idxStep} is 0.
  222.          * @throws OutOfRangeException if specified {@code y} is not within the
  223.          * range of the specified {@code points}.
  224.          */
  225.         protected double interpolateXAtY(WeightedObservedPoint[] points,
  226.                                          int startIdx,
  227.                                          int idxStep,
  228.                                          double y) {
  229.             if (idxStep == 0) {
  230.                 throw new ZeroException();
  231.             }
  232.             final WeightedObservedPoint[] twoPoints
  233.                 = getInterpolationPointsForY(points, startIdx, idxStep, y);
  234.             final WeightedObservedPoint p1 = twoPoints[0];
  235.             final WeightedObservedPoint p2 = twoPoints[1];
  236.             if (p1.getY() == y) {
  237.                 return p1.getX();
  238.             }
  239.             if (p2.getY() == y) {
  240.                 return p2.getX();
  241.             }
  242.             return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
  243.                                 (p2.getY() - p1.getY()));
  244.         }

  245.         /**
  246.          * Gets the two bounding interpolation points from the specified points
  247.          * suitable for determining X at the specified Y.
  248.          *
  249.          * @param points Points to use for interpolation.
  250.          * @param startIdx Index within points from which to start search for
  251.          * interpolation bounds points.
  252.          * @param idxStep Index step for search for interpolation bounds points.
  253.          * @param y Y value for which X should be determined.
  254.          * @return the array containing two points suitable for determining X at
  255.          * the specified Y.
  256.          * @throws ZeroException if {@code idxStep} is 0.
  257.          * @throws OutOfRangeException if specified {@code y} is not within the
  258.          * range of the specified {@code points}.
  259.          */
  260.         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
  261.                                                                    int startIdx,
  262.                                                                    int idxStep,
  263.                                                                    double y) {
  264.             if (idxStep == 0) {
  265.                 throw new ZeroException();
  266.             }
  267.             for (int i = startIdx;
  268.                  idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
  269.                  i += idxStep) {
  270.                 final WeightedObservedPoint p1 = points[i];
  271.                 final WeightedObservedPoint p2 = points[i + idxStep];
  272.                 if (isBetween(y, p1.getY(), p2.getY())) {
  273.                     if (idxStep < 0) {
  274.                         return new WeightedObservedPoint[] { p2, p1 };
  275.                     } else {
  276.                         return new WeightedObservedPoint[] { p1, p2 };
  277.                     }
  278.                 }
  279.             }

  280.             // Boundaries are replaced by dummy values because the raised
  281.             // exception is caught and the message never displayed.
  282.             // TODO: Exceptions should not be used for flow control.
  283.             throw new OutOfRangeException(y,
  284.                                           Double.NEGATIVE_INFINITY,
  285.                                           Double.POSITIVE_INFINITY);
  286.         }

  287.         /**
  288.          * Determines whether a value is between two other values.
  289.          *
  290.          * @param value Value to test whether it is between {@code boundary1}
  291.          * and {@code boundary2}.
  292.          * @param boundary1 One end of the range.
  293.          * @param boundary2 Other end of the range.
  294.          * @return {@code true} if {@code value} is between {@code boundary1} and
  295.          * {@code boundary2} (inclusive), {@code false} otherwise.
  296.          */
  297.         private boolean isBetween(double value,
  298.                                   double boundary1,
  299.                                   double boundary2) {
  300.             return (value >= boundary1 && value <= boundary2) ||
  301.                 (value >= boundary2 && value <= boundary1);
  302.         }
  303.     }
  304. }