001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.analysis.interpolation;
018
019import java.io.Serializable;
020import java.util.Arrays;
021
022import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;
023import org.apache.commons.math3.exception.DimensionMismatchException;
024import org.apache.commons.math3.exception.NoDataException;
025import org.apache.commons.math3.exception.NonMonotonicSequenceException;
026import org.apache.commons.math3.exception.NotFiniteNumberException;
027import org.apache.commons.math3.exception.NotPositiveException;
028import org.apache.commons.math3.exception.NumberIsTooSmallException;
029import org.apache.commons.math3.exception.OutOfRangeException;
030import org.apache.commons.math3.exception.util.LocalizedFormats;
031import org.apache.commons.math3.util.FastMath;
032import org.apache.commons.math3.util.MathArrays;
033import org.apache.commons.math3.util.MathUtils;
034
035/**
036 * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
037 * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
038 * real univariate functions.
039 * <p>
040 * For reference, see
041 * <a href="http://amstat.tandfonline.com/doi/abs/10.1080/01621459.1979.10481038">
042 * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
043 * Scatterplots</a>
044 * </p>
045 * This class implements both the loess method and serves as an interpolation
046 * adapter to it, allowing one to build a spline on the obtained loess fit.
047 *
048 * @since 2.0
049 */
050public class LoessInterpolator
051    implements UnivariateInterpolator, Serializable {
052    /** Default value of the bandwidth parameter. */
053    public static final double DEFAULT_BANDWIDTH = 0.3;
054    /** Default value of the number of robustness iterations. */
055    public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
056    /**
057     * Default value for accuracy.
058     * @since 2.1
059     */
060    public static final double DEFAULT_ACCURACY = 1e-12;
061    /** serializable version identifier. */
062    private static final long serialVersionUID = 5204927143605193821L;
063    /**
064     * The bandwidth parameter: when computing the loess fit at
065     * a particular point, this fraction of source points closest
066     * to the current point is taken into account for computing
067     * a least-squares regression.
068     * <p>
069     * A sensible value is usually 0.25 to 0.5.</p>
070     */
071    private final double bandwidth;
072    /**
073     * The number of robustness iterations parameter: this many
074     * robustness iterations are done.
075     * <p>
076     * A sensible value is usually 0 (just the initial fit without any
077     * robustness iterations) to 4.</p>
078     */
079    private final int robustnessIters;
080    /**
081     * If the median residual at a certain robustness iteration
082     * is less than this amount, no more iterations are done.
083     */
084    private final double accuracy;
085
086    /**
087     * Constructs a new {@link LoessInterpolator}
088     * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
089     * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
090     * and an accuracy of {#link #DEFAULT_ACCURACY}.
091     * See {@link #LoessInterpolator(double, int, double)} for an explanation of
092     * the parameters.
093     */
094    public LoessInterpolator() {
095        this.bandwidth = DEFAULT_BANDWIDTH;
096        this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
097        this.accuracy = DEFAULT_ACCURACY;
098    }
099
100    /**
101     * Construct a new {@link LoessInterpolator}
102     * with given bandwidth and number of robustness iterations.
103     * <p>
104     * Calling this constructor is equivalent to calling {link {@link
105     * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
106     * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
107     * </p>
108     *
109     * @param bandwidth  when computing the loess fit at
110     * a particular point, this fraction of source points closest
111     * to the current point is taken into account for computing
112     * a least-squares regression.
113     * A sensible value is usually 0.25 to 0.5, the default value is
114     * {@link #DEFAULT_BANDWIDTH}.
115     * @param robustnessIters This many robustness iterations are done.
116     * A sensible value is usually 0 (just the initial fit without any
117     * robustness iterations) to 4, the default value is
118     * {@link #DEFAULT_ROBUSTNESS_ITERS}.
119
120     * @see #LoessInterpolator(double, int, double)
121     */
122    public LoessInterpolator(double bandwidth, int robustnessIters) {
123        this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
124    }
125
126    /**
127     * Construct a new {@link LoessInterpolator}
128     * with given bandwidth, number of robustness iterations and accuracy.
129     *
130     * @param bandwidth  when computing the loess fit at
131     * a particular point, this fraction of source points closest
132     * to the current point is taken into account for computing
133     * a least-squares regression.
134     * A sensible value is usually 0.25 to 0.5, the default value is
135     * {@link #DEFAULT_BANDWIDTH}.
136     * @param robustnessIters This many robustness iterations are done.
137     * A sensible value is usually 0 (just the initial fit without any
138     * robustness iterations) to 4, the default value is
139     * {@link #DEFAULT_ROBUSTNESS_ITERS}.
140     * @param accuracy If the median residual at a certain robustness iteration
141     * is less than this amount, no more iterations are done.
142     * @throws OutOfRangeException if bandwidth does not lie in the interval [0,1].
143     * @throws NotPositiveException if {@code robustnessIters} is negative.
144     * @see #LoessInterpolator(double, int)
145     * @since 2.1
146     */
147    public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy)
148        throws OutOfRangeException,
149               NotPositiveException {
150        if (bandwidth < 0 ||
151            bandwidth > 1) {
152            throw new OutOfRangeException(LocalizedFormats.BANDWIDTH, bandwidth, 0, 1);
153        }
154        this.bandwidth = bandwidth;
155        if (robustnessIters < 0) {
156            throw new NotPositiveException(LocalizedFormats.ROBUSTNESS_ITERATIONS, robustnessIters);
157        }
158        this.robustnessIters = robustnessIters;
159        this.accuracy = accuracy;
160    }
161
162    /**
163     * Compute an interpolating function by performing a loess fit
164     * on the data at the original abscissae and then building a cubic spline
165     * with a
166     * {@link org.apache.commons.math3.analysis.interpolation.SplineInterpolator}
167     * on the resulting fit.
168     *
169     * @param xval the arguments for the interpolation points
170     * @param yval the values for the interpolation points
171     * @return A cubic spline built upon a loess fit to the data at the original abscissae
172     * @throws NonMonotonicSequenceException if {@code xval} not sorted in
173     * strictly increasing order.
174     * @throws DimensionMismatchException if {@code xval} and {@code yval} have
175     * different sizes.
176     * @throws NoDataException if {@code xval} or {@code yval} has zero size.
177     * @throws NotFiniteNumberException if any of the arguments and values are
178     * not finite real numbers.
179     * @throws NumberIsTooSmallException if the bandwidth is too small to
180     * accomodate the size of the input data (i.e. the bandwidth must be
181     * larger than 2/n).
182     */
183    public final PolynomialSplineFunction interpolate(final double[] xval,
184                                                      final double[] yval)
185        throws NonMonotonicSequenceException,
186               DimensionMismatchException,
187               NoDataException,
188               NotFiniteNumberException,
189               NumberIsTooSmallException {
190        return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
191    }
192
193    /**
194     * Compute a weighted loess fit on the data at the original abscissae.
195     *
196     * @param xval Arguments for the interpolation points.
197     * @param yval Values for the interpolation points.
198     * @param weights point weights: coefficients by which the robustness weight
199     * of a point is multiplied.
200     * @return the values of the loess fit at corresponding original abscissae.
201     * @throws NonMonotonicSequenceException if {@code xval} not sorted in
202     * strictly increasing order.
203     * @throws DimensionMismatchException if {@code xval} and {@code yval} have
204     * different sizes.
205     * @throws NoDataException if {@code xval} or {@code yval} has zero size.
206     * @throws NotFiniteNumberException if any of the arguments and values are
207     not finite real numbers.
208     * @throws NumberIsTooSmallException if the bandwidth is too small to
209     * accomodate the size of the input data (i.e. the bandwidth must be
210     * larger than 2/n).
211     * @since 2.1
212     */
213    public final double[] smooth(final double[] xval, final double[] yval,
214                                 final double[] weights)
215        throws NonMonotonicSequenceException,
216               DimensionMismatchException,
217               NoDataException,
218               NotFiniteNumberException,
219               NumberIsTooSmallException {
220        if (xval.length != yval.length) {
221            throw new DimensionMismatchException(xval.length, yval.length);
222        }
223
224        final int n = xval.length;
225
226        if (n == 0) {
227            throw new NoDataException();
228        }
229
230        checkAllFiniteReal(xval);
231        checkAllFiniteReal(yval);
232        checkAllFiniteReal(weights);
233
234        MathArrays.checkOrder(xval);
235
236        if (n == 1) {
237            return new double[]{yval[0]};
238        }
239
240        if (n == 2) {
241            return new double[]{yval[0], yval[1]};
242        }
243
244        int bandwidthInPoints = (int) (bandwidth * n);
245
246        if (bandwidthInPoints < 2) {
247            throw new NumberIsTooSmallException(LocalizedFormats.BANDWIDTH,
248                                                bandwidthInPoints, 2, true);
249        }
250
251        final double[] res = new double[n];
252
253        final double[] residuals = new double[n];
254        final double[] sortedResiduals = new double[n];
255
256        final double[] robustnessWeights = new double[n];
257
258        // Do an initial fit and 'robustnessIters' robustness iterations.
259        // This is equivalent to doing 'robustnessIters+1' robustness iterations
260        // starting with all robustness weights set to 1.
261        Arrays.fill(robustnessWeights, 1);
262
263        for (int iter = 0; iter <= robustnessIters; ++iter) {
264            final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
265            // At each x, compute a local weighted linear regression
266            for (int i = 0; i < n; ++i) {
267                final double x = xval[i];
268
269                // Find out the interval of source points on which
270                // a regression is to be made.
271                if (i > 0) {
272                    updateBandwidthInterval(xval, weights, i, bandwidthInterval);
273                }
274
275                final int ileft = bandwidthInterval[0];
276                final int iright = bandwidthInterval[1];
277
278                // Compute the point of the bandwidth interval that is
279                // farthest from x
280                final int edge;
281                if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
282                    edge = ileft;
283                } else {
284                    edge = iright;
285                }
286
287                // Compute a least-squares linear fit weighted by
288                // the product of robustness weights and the tricube
289                // weight function.
290                // See http://en.wikipedia.org/wiki/Linear_regression
291                // (section "Univariate linear case")
292                // and http://en.wikipedia.org/wiki/Weighted_least_squares
293                // (section "Weighted least squares")
294                double sumWeights = 0;
295                double sumX = 0;
296                double sumXSquared = 0;
297                double sumY = 0;
298                double sumXY = 0;
299                double denom = FastMath.abs(1.0 / (xval[edge] - x));
300                for (int k = ileft; k <= iright; ++k) {
301                    final double xk   = xval[k];
302                    final double yk   = yval[k];
303                    final double dist = (k < i) ? x - xk : xk - x;
304                    final double w    = tricube(dist * denom) * robustnessWeights[k] * weights[k];
305                    final double xkw  = xk * w;
306                    sumWeights += w;
307                    sumX += xkw;
308                    sumXSquared += xk * xkw;
309                    sumY += yk * w;
310                    sumXY += yk * xkw;
311                }
312
313                final double meanX = sumX / sumWeights;
314                final double meanY = sumY / sumWeights;
315                final double meanXY = sumXY / sumWeights;
316                final double meanXSquared = sumXSquared / sumWeights;
317
318                final double beta;
319                if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) {
320                    beta = 0;
321                } else {
322                    beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
323                }
324
325                final double alpha = meanY - beta * meanX;
326
327                res[i] = beta * x + alpha;
328                residuals[i] = FastMath.abs(yval[i] - res[i]);
329            }
330
331            // No need to recompute the robustness weights at the last
332            // iteration, they won't be needed anymore
333            if (iter == robustnessIters) {
334                break;
335            }
336
337            // Recompute the robustness weights.
338
339            // Find the median residual.
340            // An arraycopy and a sort are completely tractable here,
341            // because the preceding loop is a lot more expensive
342            System.arraycopy(residuals, 0, sortedResiduals, 0, n);
343            Arrays.sort(sortedResiduals);
344            final double medianResidual = sortedResiduals[n / 2];
345
346            if (FastMath.abs(medianResidual) < accuracy) {
347                break;
348            }
349
350            for (int i = 0; i < n; ++i) {
351                final double arg = residuals[i] / (6 * medianResidual);
352                if (arg >= 1) {
353                    robustnessWeights[i] = 0;
354                } else {
355                    final double w = 1 - arg * arg;
356                    robustnessWeights[i] = w * w;
357                }
358            }
359        }
360
361        return res;
362    }
363
364    /**
365     * Compute a loess fit on the data at the original abscissae.
366     *
367     * @param xval the arguments for the interpolation points
368     * @param yval the values for the interpolation points
369     * @return values of the loess fit at corresponding original abscissae
370     * @throws NonMonotonicSequenceException if {@code xval} not sorted in
371     * strictly increasing order.
372     * @throws DimensionMismatchException if {@code xval} and {@code yval} have
373     * different sizes.
374     * @throws NoDataException if {@code xval} or {@code yval} has zero size.
375     * @throws NotFiniteNumberException if any of the arguments and values are
376     * not finite real numbers.
377     * @throws NumberIsTooSmallException if the bandwidth is too small to
378     * accomodate the size of the input data (i.e. the bandwidth must be
379     * larger than 2/n).
380     */
381    public final double[] smooth(final double[] xval, final double[] yval)
382        throws NonMonotonicSequenceException,
383               DimensionMismatchException,
384               NoDataException,
385               NotFiniteNumberException,
386               NumberIsTooSmallException {
387        if (xval.length != yval.length) {
388            throw new DimensionMismatchException(xval.length, yval.length);
389        }
390
391        final double[] unitWeights = new double[xval.length];
392        Arrays.fill(unitWeights, 1.0);
393
394        return smooth(xval, yval, unitWeights);
395    }
396
397    /**
398     * Given an index interval into xval that embraces a certain number of
399     * points closest to {@code xval[i-1]}, update the interval so that it
400     * embraces the same number of points closest to {@code xval[i]},
401     * ignoring zero weights.
402     *
403     * @param xval Arguments array.
404     * @param weights Weights array.
405     * @param i Index around which the new interval should be computed.
406     * @param bandwidthInterval a two-element array {left, right} such that:
407     * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])}
408     * and
409     * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}.
410     * The array will be updated.
411     */
412    private static void updateBandwidthInterval(final double[] xval, final double[] weights,
413                                                final int i,
414                                                final int[] bandwidthInterval) {
415        final int left = bandwidthInterval[0];
416        final int right = bandwidthInterval[1];
417
418        // The right edge should be adjusted if the next point to the right
419        // is closer to xval[i] than the leftmost point of the current interval
420        int nextRight = nextNonzero(weights, right);
421        if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) {
422            int nextLeft = nextNonzero(weights, bandwidthInterval[0]);
423            bandwidthInterval[0] = nextLeft;
424            bandwidthInterval[1] = nextRight;
425        }
426    }
427
428    /**
429     * Return the smallest index {@code j} such that
430     * {@code j > i && (j == weights.length || weights[j] != 0)}.
431     *
432     * @param weights Weights array.
433     * @param i Index from which to start search.
434     * @return the smallest compliant index.
435     */
436    private static int nextNonzero(final double[] weights, final int i) {
437        int j = i + 1;
438        while(j < weights.length && weights[j] == 0) {
439            ++j;
440        }
441        return j;
442    }
443
444    /**
445     * Compute the
446     * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
447     * weight function
448     *
449     * @param x Argument.
450     * @return <code>(1 - |x|<sup>3</sup>)<sup>3</sup></code> for |x| &lt; 1, 0 otherwise.
451     */
452    private static double tricube(final double x) {
453        final double absX = FastMath.abs(x);
454        if (absX >= 1.0) {
455            return 0.0;
456        }
457        final double tmp = 1 - absX * absX * absX;
458        return tmp * tmp * tmp;
459    }
460
461    /**
462     * Check that all elements of an array are finite real numbers.
463     *
464     * @param values Values array.
465     * @throws org.apache.commons.math3.exception.NotFiniteNumberException
466     * if one of the values is not a finite real number.
467     */
468    private static void checkAllFiniteReal(final double[] values) {
469        for (int i = 0; i < values.length; i++) {
470            MathUtils.checkFinite(values[i]);
471        }
472    }
473}