LoessInterpolator.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.analysis.interpolation;

  18. import java.util.Arrays;

  19. import org.apache.commons.math4.legacy.analysis.polynomials.PolynomialSplineFunction;
  20. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  21. import org.apache.commons.math4.legacy.exception.NoDataException;
  22. import org.apache.commons.math4.legacy.exception.NonMonotonicSequenceException;
  23. import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
  24. import org.apache.commons.math4.legacy.exception.NotPositiveException;
  25. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  26. import org.apache.commons.math4.legacy.exception.OutOfRangeException;
  27. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  28. import org.apache.commons.math4.core.jdkmath.JdkMath;
  29. import org.apache.commons.math4.legacy.core.MathArrays;

  30. /**
  31.  * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
  32.  * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
  33.  * real univariate functions.
  34.  * <p>
  35.  * For reference, see
  36.  * <a href="http://amstat.tandfonline.com/doi/abs/10.1080/01621459.1979.10481038">
  37.  * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
  38.  * Scatterplots</a></p>
  39.  * <p>
 * This class both implements the loess method and serves as an interpolation
 * adapter to it, allowing one to build a spline on the obtained loess fit.</p>
  42.  *
  43.  * @since 2.0
  44.  */
  45. public class LoessInterpolator
  46.     implements UnivariateInterpolator {
  47.     /** Default value of the bandwidth parameter. */
  48.     public static final double DEFAULT_BANDWIDTH = 0.3;
  49.     /** Default value of the number of robustness iterations. */
  50.     public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
  51.     /**
  52.      * Default value for accuracy.
  53.      * @since 2.1
  54.      */
  55.     public static final double DEFAULT_ACCURACY = 1e-12;
  56.     /**
  57.      * The bandwidth parameter: when computing the loess fit at
  58.      * a particular point, this fraction of source points closest
  59.      * to the current point is taken into account for computing
  60.      * a least-squares regression.
  61.      * <p>
  62.      * A sensible value is usually 0.25 to 0.5.</p>
  63.      */
  64.     private final double bandwidth;
  65.     /**
  66.      * The number of robustness iterations parameter: this many
  67.      * robustness iterations are done.
  68.      * <p>
  69.      * A sensible value is usually 0 (just the initial fit without any
  70.      * robustness iterations) to 4.</p>
  71.      */
  72.     private final int robustnessIters;
  73.     /**
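     * If the median residual at a certain robustness iteration
     * is less than this amount, no more iterations are done.
     * The same threshold is used in the local regression: when the weighted
     * standard deviation of the abscissae in a window is below it, the
     * slope of the local fit is taken to be zero.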
  76.      */
  77.     private final double accuracy;

  78.     /**
  79.      * Constructs a new {@link LoessInterpolator}
  80.      * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
  81.      * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
  82.      * and an accuracy of {#link #DEFAULT_ACCURACY}.
  83.      * See {@link #LoessInterpolator(double, int, double)} for an explanation of
  84.      * the parameters.
  85.      */
  86.     public LoessInterpolator() {
  87.         this.bandwidth = DEFAULT_BANDWIDTH;
  88.         this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
  89.         this.accuracy = DEFAULT_ACCURACY;
  90.     }

  91.     /**
  92.      * Construct a new {@link LoessInterpolator}
  93.      * with given bandwidth and number of robustness iterations.
  94.      * <p>
  95.      * Calling this constructor is equivalent to calling {link {@link
  96.      * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
  97.      * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
  98.      * </p>
  99.      *
  100.      * @param bandwidth  when computing the loess fit at
  101.      * a particular point, this fraction of source points closest
  102.      * to the current point is taken into account for computing
  103.      * a least-squares regression.
  104.      * A sensible value is usually 0.25 to 0.5, the default value is
  105.      * {@link #DEFAULT_BANDWIDTH}.
  106.      * @param robustnessIters This many robustness iterations are done.
  107.      * A sensible value is usually 0 (just the initial fit without any
  108.      * robustness iterations) to 4, the default value is
  109.      * {@link #DEFAULT_ROBUSTNESS_ITERS}.

  110.      * @see #LoessInterpolator(double, int, double)
  111.      */
  112.     public LoessInterpolator(double bandwidth, int robustnessIters) {
  113.         this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
  114.     }

  115.     /**
  116.      * Construct a new {@link LoessInterpolator}
  117.      * with given bandwidth, number of robustness iterations and accuracy.
  118.      *
  119.      * @param bandwidth  when computing the loess fit at
  120.      * a particular point, this fraction of source points closest
  121.      * to the current point is taken into account for computing
  122.      * a least-squares regression.
  123.      * A sensible value is usually 0.25 to 0.5, the default value is
  124.      * {@link #DEFAULT_BANDWIDTH}.
  125.      * @param robustnessIters This many robustness iterations are done.
  126.      * A sensible value is usually 0 (just the initial fit without any
  127.      * robustness iterations) to 4, the default value is
  128.      * {@link #DEFAULT_ROBUSTNESS_ITERS}.
  129.      * @param accuracy If the median residual at a certain robustness iteration
  130.      * is less than this amount, no more iterations are done.
  131.      * @throws OutOfRangeException if bandwidth does not lie in the interval [0,1].
  132.      * @throws NotPositiveException if {@code robustnessIters} is negative.
  133.      * @see #LoessInterpolator(double, int)
  134.      * @since 2.1
  135.      */
  136.     public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy)
  137.         throws OutOfRangeException,
  138.                NotPositiveException {
  139.         if (bandwidth < 0 ||
  140.             bandwidth > 1) {
  141.             throw new OutOfRangeException(LocalizedFormats.BANDWIDTH, bandwidth, 0, 1);
  142.         }
  143.         this.bandwidth = bandwidth;
  144.         if (robustnessIters < 0) {
  145.             throw new NotPositiveException(LocalizedFormats.ROBUSTNESS_ITERATIONS, robustnessIters);
  146.         }
  147.         this.robustnessIters = robustnessIters;
  148.         this.accuracy = accuracy;
  149.     }

  150.     /**
  151.      * Compute an interpolating function by performing a loess fit
  152.      * on the data at the original abscissae and then building a cubic spline
  153.      * with a
  154.      * {@link org.apache.commons.math4.legacy.analysis.interpolation.SplineInterpolator}
  155.      * on the resulting fit.
  156.      *
  157.      * @param xval the arguments for the interpolation points
  158.      * @param yval the values for the interpolation points
  159.      * @return A cubic spline built upon a loess fit to the data at the original abscissae
  160.      * @throws NonMonotonicSequenceException if {@code xval} not sorted in
  161.      * strictly increasing order.
  162.      * @throws DimensionMismatchException if {@code xval} and {@code yval} have
  163.      * different sizes.
  164.      * @throws NoDataException if {@code xval} or {@code yval} has zero size.
  165.      * @throws NotFiniteNumberException if any of the arguments and values are
  166.      * not finite real numbers.
  167.      * @throws NumberIsTooSmallException if the bandwidth is too small to
  168.      * accommodate the size of the input data (i.e. the bandwidth must be
  169.      * larger than 2/n).
  170.      */
  171.     @Override
  172.     public final PolynomialSplineFunction interpolate(final double[] xval,
  173.                                                       final double[] yval)
  174.         throws NonMonotonicSequenceException,
  175.                DimensionMismatchException,
  176.                NoDataException,
  177.                NotFiniteNumberException,
  178.                NumberIsTooSmallException {
  179.         return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
  180.     }

  181.     /**
  182.      * Compute a weighted loess fit on the data at the original abscissae.
  183.      *
  184.      * @param xval Arguments for the interpolation points.
  185.      * @param yval Values for the interpolation points.
  186.      * @param weights point weights: coefficients by which the robustness weight
  187.      * of a point is multiplied.
  188.      * @return the values of the loess fit at corresponding original abscissae.
  189.      * @throws NonMonotonicSequenceException if {@code xval} not sorted in
  190.      * strictly increasing order.
  191.      * @throws DimensionMismatchException if {@code xval} and {@code yval} have
  192.      * different sizes.
  193.      * @throws NoDataException if {@code xval} or {@code yval} has zero size.
  194.      * @throws NotFiniteNumberException if any of the arguments and values are
  195.      not finite real numbers.
  196.      * @throws NumberIsTooSmallException if the bandwidth is too small to
  197.      * accommodate the size of the input data (i.e. the bandwidth must be
  198.      * larger than 2/n).
  199.      * @since 2.1
  200.      */
  201.     public final double[] smooth(final double[] xval, final double[] yval,
  202.                                  final double[] weights)
  203.         throws NonMonotonicSequenceException,
  204.                DimensionMismatchException,
  205.                NoDataException,
  206.                NotFiniteNumberException,
  207.                NumberIsTooSmallException {
  208.         if (xval.length != yval.length) {
  209.             throw new DimensionMismatchException(xval.length, yval.length);
  210.         }

  211.         final int n = xval.length;

  212.         if (n == 0) {
  213.             throw new NoDataException();
  214.         }

  215.         NotFiniteNumberException.check(xval);
  216.         NotFiniteNumberException.check(yval);
  217.         NotFiniteNumberException.check(weights);

  218.         MathArrays.checkOrder(xval);

  219.         if (n == 1) {
  220.             return new double[]{yval[0]};
  221.         }

  222.         if (n == 2) {
  223.             return new double[]{yval[0], yval[1]};
  224.         }

  225.         int bandwidthInPoints = (int) (bandwidth * n);

  226.         if (bandwidthInPoints < 2) {
  227.             throw new NumberIsTooSmallException(LocalizedFormats.BANDWIDTH,
  228.                                                 bandwidthInPoints, 2, true);
  229.         }

  230.         final double[] res = new double[n];

  231.         final double[] residuals = new double[n];
  232.         final double[] sortedResiduals = new double[n];

  233.         final double[] robustnessWeights = new double[n];

  234.         // Do an initial fit and 'robustnessIters' robustness iterations.
  235.         // This is equivalent to doing 'robustnessIters+1' robustness iterations
  236.         // starting with all robustness weights set to 1.
  237.         Arrays.fill(robustnessWeights, 1);

  238.         for (int iter = 0; iter <= robustnessIters; ++iter) {
  239.             final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
  240.             // At each x, compute a local weighted linear regression
  241.             for (int i = 0; i < n; ++i) {
  242.                 final double x = xval[i];

  243.                 // Find out the interval of source points on which
  244.                 // a regression is to be made.
  245.                 if (i > 0) {
  246.                     updateBandwidthInterval(xval, weights, i, bandwidthInterval);
  247.                 }

  248.                 final int ileft = bandwidthInterval[0];
  249.                 final int iright = bandwidthInterval[1];

  250.                 // Compute the point of the bandwidth interval that is
  251.                 // farthest from x
  252.                 final int edge;
  253.                 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
  254.                     edge = ileft;
  255.                 } else {
  256.                     edge = iright;
  257.                 }

  258.                 // Compute a least-squares linear fit weighted by
  259.                 // the product of robustness weights and the tricube
  260.                 // weight function.
  261.                 // See http://en.wikipedia.org/wiki/Linear_regression
  262.                 // (section "Univariate linear case")
  263.                 // and http://en.wikipedia.org/wiki/Weighted_least_squares
  264.                 // (section "Weighted least squares")
  265.                 double sumWeights = 0;
  266.                 double sumX = 0;
  267.                 double sumXSquared = 0;
  268.                 double sumY = 0;
  269.                 double sumXY = 0;
  270.                 double denom = JdkMath.abs(1.0 / (xval[edge] - x));
  271.                 for (int k = ileft; k <= iright; ++k) {
  272.                     final double xk   = xval[k];
  273.                     final double yk   = yval[k];
  274.                     final double dist = (k < i) ? x - xk : xk - x;
  275.                     final double w    = tricube(dist * denom) * robustnessWeights[k] * weights[k];
  276.                     final double xkw  = xk * w;
  277.                     sumWeights += w;
  278.                     sumX += xkw;
  279.                     sumXSquared += xk * xkw;
  280.                     sumY += yk * w;
  281.                     sumXY += yk * xkw;
  282.                 }

  283.                 final double meanX = sumX / sumWeights;
  284.                 final double meanY = sumY / sumWeights;
  285.                 final double meanXY = sumXY / sumWeights;
  286.                 final double meanXSquared = sumXSquared / sumWeights;

  287.                 final double beta;
  288.                 if (JdkMath.sqrt(JdkMath.abs(meanXSquared - meanX * meanX)) < accuracy) {
  289.                     beta = 0;
  290.                 } else {
  291.                     beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
  292.                 }

  293.                 final double alpha = meanY - beta * meanX;

  294.                 res[i] = beta * x + alpha;
  295.                 residuals[i] = JdkMath.abs(yval[i] - res[i]);
  296.             }

  297.             // No need to recompute the robustness weights at the last
  298.             // iteration, they won't be needed anymore
  299.             if (iter == robustnessIters) {
  300.                 break;
  301.             }

  302.             // Recompute the robustness weights.

  303.             // Find the median residual.
  304.             // An arraycopy and a sort are completely tractable here,
  305.             // because the preceding loop is a lot more expensive
  306.             System.arraycopy(residuals, 0, sortedResiduals, 0, n);
  307.             Arrays.sort(sortedResiduals);
  308.             final double medianResidual = sortedResiduals[n / 2];

  309.             if (JdkMath.abs(medianResidual) < accuracy) {
  310.                 break;
  311.             }

  312.             for (int i = 0; i < n; ++i) {
  313.                 final double arg = residuals[i] / (6 * medianResidual);
  314.                 if (arg >= 1) {
  315.                     robustnessWeights[i] = 0;
  316.                 } else {
  317.                     final double w = 1 - arg * arg;
  318.                     robustnessWeights[i] = w * w;
  319.                 }
  320.             }
  321.         }

  322.         return res;
  323.     }

  324.     /**
  325.      * Compute a loess fit on the data at the original abscissae.
  326.      *
  327.      * @param xval the arguments for the interpolation points
  328.      * @param yval the values for the interpolation points
  329.      * @return values of the loess fit at corresponding original abscissae
  330.      * @throws NonMonotonicSequenceException if {@code xval} not sorted in
  331.      * strictly increasing order.
  332.      * @throws DimensionMismatchException if {@code xval} and {@code yval} have
  333.      * different sizes.
  334.      * @throws NoDataException if {@code xval} or {@code yval} has zero size.
  335.      * @throws NotFiniteNumberException if any of the arguments and values are
  336.      * not finite real numbers.
  337.      * @throws NumberIsTooSmallException if the bandwidth is too small to
  338.      * accommodate the size of the input data (i.e. the bandwidth must be
  339.      * larger than 2/n).
  340.      */
  341.     public final double[] smooth(final double[] xval, final double[] yval)
  342.         throws NonMonotonicSequenceException,
  343.                DimensionMismatchException,
  344.                NoDataException,
  345.                NotFiniteNumberException,
  346.                NumberIsTooSmallException {
  347.         if (xval.length != yval.length) {
  348.             throw new DimensionMismatchException(xval.length, yval.length);
  349.         }

  350.         final double[] unitWeights = new double[xval.length];
  351.         Arrays.fill(unitWeights, 1.0);

  352.         return smooth(xval, yval, unitWeights);
  353.     }

  354.     /**
  355.      * Given an index interval into xval that embraces a certain number of
  356.      * points closest to {@code xval[i-1]}, update the interval so that it
  357.      * embraces the same number of points closest to {@code xval[i]},
  358.      * ignoring zero weights.
  359.      *
  360.      * @param xval Arguments array.
  361.      * @param weights Weights array.
  362.      * @param i Index around which the new interval should be computed.
  363.      * @param bandwidthInterval a two-element array {left, right} such that:
  364.      * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])}
  365.      * and
  366.      * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}.
  367.      * The array will be updated.
  368.      */
  369.     private static void updateBandwidthInterval(final double[] xval,
  370.                                                 final double[] weights,
  371.                                                 final int i,
  372.                                                 final int[] bandwidthInterval) {
  373.         final int left = bandwidthInterval[0];
  374.         final int right = bandwidthInterval[1];

  375.         // The right edge should be adjusted if the next point to the right
  376.         // is closer to xval[i] than the leftmost point of the current interval
  377.         int nextRight = nextNonzero(weights, right);
  378.         int nextLeft = left;
  379.         while (nextRight < xval.length &&
  380.                xval[nextRight] - xval[i] < xval[i] - xval[nextLeft]) {
  381.             nextLeft = nextNonzero(weights, bandwidthInterval[0]);
  382.             bandwidthInterval[0] = nextLeft;
  383.             bandwidthInterval[1] = nextRight;
  384.             nextRight = nextNonzero(weights, nextRight);
  385.         }
  386.     }

  387.     /**
  388.      * Return the smallest index {@code j} such that
  389.      * {@code j > i && (j == weights.length || weights[j] != 0)}.
  390.      *
  391.      * @param weights Weights array.
  392.      * @param i Index from which to start search.
  393.      * @return the smallest compliant index.
  394.      */
  395.     private static int nextNonzero(final double[] weights, final int i) {
  396.         int j = i + 1;
  397.         while(j < weights.length && weights[j] == 0) {
  398.             ++j;
  399.         }
  400.         return j;
  401.     }

  402.     /**
  403.      * Compute the
  404.      * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
  405.      * weight function.
  406.      *
  407.      * @param x Argument.
  408.      * @return <code>(1 - |x|<sup>3</sup>)<sup>3</sup></code> for |x| &lt; 1, 0 otherwise.
  409.      */
  410.     private static double tricube(final double x) {
  411.         final double absX = JdkMath.abs(x);
  412.         if (absX >= 1.0) {
  413.             return 0.0;
  414.         }
  415.         final double tmp = 1 - absX * absX * absX;
  416.         return tmp * tmp * tmp;
  417.     }
  418. }