001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.analysis.interpolation;
018
019import java.util.Arrays;
020
021import org.apache.commons.math4.legacy.analysis.polynomials.PolynomialSplineFunction;
022import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
023import org.apache.commons.math4.legacy.exception.NoDataException;
024import org.apache.commons.math4.legacy.exception.NonMonotonicSequenceException;
025import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
026import org.apache.commons.math4.legacy.exception.NotPositiveException;
027import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
028import org.apache.commons.math4.legacy.exception.OutOfRangeException;
029import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
030import org.apache.commons.math4.core.jdkmath.JdkMath;
031import org.apache.commons.math4.legacy.core.MathArrays;
032
033/**
034 * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
035 * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
036 * real univariate functions.
037 * <p>
038 * For reference, see
039 * <a href="http://amstat.tandfonline.com/doi/abs/10.1080/01621459.1979.10481038">
040 * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
041 * Scatterplots</a></p>
042 * <p>
043 * This class implements both the loess method and serves as an interpolation
044 * adapter to it, allowing one to build a spline on the obtained loess fit.</p>
045 *
046 * @since 2.0
047 */
048public class LoessInterpolator
049    implements UnivariateInterpolator {
050    /** Default value of the bandwidth parameter. */
051    public static final double DEFAULT_BANDWIDTH = 0.3;
052    /** Default value of the number of robustness iterations. */
053    public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
054    /**
055     * Default value for accuracy.
056     * @since 2.1
057     */
058    public static final double DEFAULT_ACCURACY = 1e-12;
059    /**
060     * The bandwidth parameter: when computing the loess fit at
061     * a particular point, this fraction of source points closest
062     * to the current point is taken into account for computing
063     * a least-squares regression.
064     * <p>
065     * A sensible value is usually 0.25 to 0.5.</p>
066     */
067    private final double bandwidth;
068    /**
069     * The number of robustness iterations parameter: this many
070     * robustness iterations are done.
071     * <p>
072     * A sensible value is usually 0 (just the initial fit without any
073     * robustness iterations) to 4.</p>
074     */
075    private final int robustnessIters;
076    /**
077     * If the median residual at a certain robustness iteration
078     * is less than this amount, no more iterations are done.
079     */
080    private final double accuracy;
081
082    /**
083     * Constructs a new {@link LoessInterpolator}
084     * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
085     * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
086     * and an accuracy of {#link #DEFAULT_ACCURACY}.
087     * See {@link #LoessInterpolator(double, int, double)} for an explanation of
088     * the parameters.
089     */
090    public LoessInterpolator() {
091        this.bandwidth = DEFAULT_BANDWIDTH;
092        this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
093        this.accuracy = DEFAULT_ACCURACY;
094    }
095
096    /**
097     * Construct a new {@link LoessInterpolator}
098     * with given bandwidth and number of robustness iterations.
099     * <p>
100     * Calling this constructor is equivalent to calling {link {@link
101     * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
102     * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
103     * </p>
104     *
105     * @param bandwidth  when computing the loess fit at
106     * a particular point, this fraction of source points closest
107     * to the current point is taken into account for computing
108     * a least-squares regression.
109     * A sensible value is usually 0.25 to 0.5, the default value is
110     * {@link #DEFAULT_BANDWIDTH}.
111     * @param robustnessIters This many robustness iterations are done.
112     * A sensible value is usually 0 (just the initial fit without any
113     * robustness iterations) to 4, the default value is
114     * {@link #DEFAULT_ROBUSTNESS_ITERS}.
115
116     * @see #LoessInterpolator(double, int, double)
117     */
118    public LoessInterpolator(double bandwidth, int robustnessIters) {
119        this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
120    }
121
122    /**
123     * Construct a new {@link LoessInterpolator}
124     * with given bandwidth, number of robustness iterations and accuracy.
125     *
126     * @param bandwidth  when computing the loess fit at
127     * a particular point, this fraction of source points closest
128     * to the current point is taken into account for computing
129     * a least-squares regression.
130     * A sensible value is usually 0.25 to 0.5, the default value is
131     * {@link #DEFAULT_BANDWIDTH}.
132     * @param robustnessIters This many robustness iterations are done.
133     * A sensible value is usually 0 (just the initial fit without any
134     * robustness iterations) to 4, the default value is
135     * {@link #DEFAULT_ROBUSTNESS_ITERS}.
136     * @param accuracy If the median residual at a certain robustness iteration
137     * is less than this amount, no more iterations are done.
138     * @throws OutOfRangeException if bandwidth does not lie in the interval [0,1].
139     * @throws NotPositiveException if {@code robustnessIters} is negative.
140     * @see #LoessInterpolator(double, int)
141     * @since 2.1
142     */
143    public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy)
144        throws OutOfRangeException,
145               NotPositiveException {
146        if (bandwidth < 0 ||
147            bandwidth > 1) {
148            throw new OutOfRangeException(LocalizedFormats.BANDWIDTH, bandwidth, 0, 1);
149        }
150        this.bandwidth = bandwidth;
151        if (robustnessIters < 0) {
152            throw new NotPositiveException(LocalizedFormats.ROBUSTNESS_ITERATIONS, robustnessIters);
153        }
154        this.robustnessIters = robustnessIters;
155        this.accuracy = accuracy;
156    }
157
158    /**
159     * Compute an interpolating function by performing a loess fit
160     * on the data at the original abscissae and then building a cubic spline
161     * with a
162     * {@link org.apache.commons.math4.legacy.analysis.interpolation.SplineInterpolator}
163     * on the resulting fit.
164     *
165     * @param xval the arguments for the interpolation points
166     * @param yval the values for the interpolation points
167     * @return A cubic spline built upon a loess fit to the data at the original abscissae
168     * @throws NonMonotonicSequenceException if {@code xval} not sorted in
169     * strictly increasing order.
170     * @throws DimensionMismatchException if {@code xval} and {@code yval} have
171     * different sizes.
172     * @throws NoDataException if {@code xval} or {@code yval} has zero size.
173     * @throws NotFiniteNumberException if any of the arguments and values are
174     * not finite real numbers.
175     * @throws NumberIsTooSmallException if the bandwidth is too small to
176     * accommodate the size of the input data (i.e. the bandwidth must be
177     * larger than 2/n).
178     */
179    @Override
180    public final PolynomialSplineFunction interpolate(final double[] xval,
181                                                      final double[] yval)
182        throws NonMonotonicSequenceException,
183               DimensionMismatchException,
184               NoDataException,
185               NotFiniteNumberException,
186               NumberIsTooSmallException {
187        return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
188    }
189
190    /**
191     * Compute a weighted loess fit on the data at the original abscissae.
192     *
193     * @param xval Arguments for the interpolation points.
194     * @param yval Values for the interpolation points.
195     * @param weights point weights: coefficients by which the robustness weight
196     * of a point is multiplied.
197     * @return the values of the loess fit at corresponding original abscissae.
198     * @throws NonMonotonicSequenceException if {@code xval} not sorted in
199     * strictly increasing order.
200     * @throws DimensionMismatchException if {@code xval} and {@code yval} have
201     * different sizes.
202     * @throws NoDataException if {@code xval} or {@code yval} has zero size.
203     * @throws NotFiniteNumberException if any of the arguments and values are
204     not finite real numbers.
205     * @throws NumberIsTooSmallException if the bandwidth is too small to
206     * accommodate the size of the input data (i.e. the bandwidth must be
207     * larger than 2/n).
208     * @since 2.1
209     */
210    public final double[] smooth(final double[] xval, final double[] yval,
211                                 final double[] weights)
212        throws NonMonotonicSequenceException,
213               DimensionMismatchException,
214               NoDataException,
215               NotFiniteNumberException,
216               NumberIsTooSmallException {
217        if (xval.length != yval.length) {
218            throw new DimensionMismatchException(xval.length, yval.length);
219        }
220
221        final int n = xval.length;
222
223        if (n == 0) {
224            throw new NoDataException();
225        }
226
227        NotFiniteNumberException.check(xval);
228        NotFiniteNumberException.check(yval);
229        NotFiniteNumberException.check(weights);
230
231        MathArrays.checkOrder(xval);
232
233        if (n == 1) {
234            return new double[]{yval[0]};
235        }
236
237        if (n == 2) {
238            return new double[]{yval[0], yval[1]};
239        }
240
241        int bandwidthInPoints = (int) (bandwidth * n);
242
243        if (bandwidthInPoints < 2) {
244            throw new NumberIsTooSmallException(LocalizedFormats.BANDWIDTH,
245                                                bandwidthInPoints, 2, true);
246        }
247
248        final double[] res = new double[n];
249
250        final double[] residuals = new double[n];
251        final double[] sortedResiduals = new double[n];
252
253        final double[] robustnessWeights = new double[n];
254
255        // Do an initial fit and 'robustnessIters' robustness iterations.
256        // This is equivalent to doing 'robustnessIters+1' robustness iterations
257        // starting with all robustness weights set to 1.
258        Arrays.fill(robustnessWeights, 1);
259
260        for (int iter = 0; iter <= robustnessIters; ++iter) {
261            final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
262            // At each x, compute a local weighted linear regression
263            for (int i = 0; i < n; ++i) {
264                final double x = xval[i];
265
266                // Find out the interval of source points on which
267                // a regression is to be made.
268                if (i > 0) {
269                    updateBandwidthInterval(xval, weights, i, bandwidthInterval);
270                }
271
272                final int ileft = bandwidthInterval[0];
273                final int iright = bandwidthInterval[1];
274
275                // Compute the point of the bandwidth interval that is
276                // farthest from x
277                final int edge;
278                if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
279                    edge = ileft;
280                } else {
281                    edge = iright;
282                }
283
284                // Compute a least-squares linear fit weighted by
285                // the product of robustness weights and the tricube
286                // weight function.
287                // See http://en.wikipedia.org/wiki/Linear_regression
288                // (section "Univariate linear case")
289                // and http://en.wikipedia.org/wiki/Weighted_least_squares
290                // (section "Weighted least squares")
291                double sumWeights = 0;
292                double sumX = 0;
293                double sumXSquared = 0;
294                double sumY = 0;
295                double sumXY = 0;
296                double denom = JdkMath.abs(1.0 / (xval[edge] - x));
297                for (int k = ileft; k <= iright; ++k) {
298                    final double xk   = xval[k];
299                    final double yk   = yval[k];
300                    final double dist = (k < i) ? x - xk : xk - x;
301                    final double w    = tricube(dist * denom) * robustnessWeights[k] * weights[k];
302                    final double xkw  = xk * w;
303                    sumWeights += w;
304                    sumX += xkw;
305                    sumXSquared += xk * xkw;
306                    sumY += yk * w;
307                    sumXY += yk * xkw;
308                }
309
310                final double meanX = sumX / sumWeights;
311                final double meanY = sumY / sumWeights;
312                final double meanXY = sumXY / sumWeights;
313                final double meanXSquared = sumXSquared / sumWeights;
314
315                final double beta;
316                if (JdkMath.sqrt(JdkMath.abs(meanXSquared - meanX * meanX)) < accuracy) {
317                    beta = 0;
318                } else {
319                    beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
320                }
321
322                final double alpha = meanY - beta * meanX;
323
324                res[i] = beta * x + alpha;
325                residuals[i] = JdkMath.abs(yval[i] - res[i]);
326            }
327
328            // No need to recompute the robustness weights at the last
329            // iteration, they won't be needed anymore
330            if (iter == robustnessIters) {
331                break;
332            }
333
334            // Recompute the robustness weights.
335
336            // Find the median residual.
337            // An arraycopy and a sort are completely tractable here,
338            // because the preceding loop is a lot more expensive
339            System.arraycopy(residuals, 0, sortedResiduals, 0, n);
340            Arrays.sort(sortedResiduals);
341            final double medianResidual = sortedResiduals[n / 2];
342
343            if (JdkMath.abs(medianResidual) < accuracy) {
344                break;
345            }
346
347            for (int i = 0; i < n; ++i) {
348                final double arg = residuals[i] / (6 * medianResidual);
349                if (arg >= 1) {
350                    robustnessWeights[i] = 0;
351                } else {
352                    final double w = 1 - arg * arg;
353                    robustnessWeights[i] = w * w;
354                }
355            }
356        }
357
358        return res;
359    }
360
361    /**
362     * Compute a loess fit on the data at the original abscissae.
363     *
364     * @param xval the arguments for the interpolation points
365     * @param yval the values for the interpolation points
366     * @return values of the loess fit at corresponding original abscissae
367     * @throws NonMonotonicSequenceException if {@code xval} not sorted in
368     * strictly increasing order.
369     * @throws DimensionMismatchException if {@code xval} and {@code yval} have
370     * different sizes.
371     * @throws NoDataException if {@code xval} or {@code yval} has zero size.
372     * @throws NotFiniteNumberException if any of the arguments and values are
373     * not finite real numbers.
374     * @throws NumberIsTooSmallException if the bandwidth is too small to
375     * accommodate the size of the input data (i.e. the bandwidth must be
376     * larger than 2/n).
377     */
378    public final double[] smooth(final double[] xval, final double[] yval)
379        throws NonMonotonicSequenceException,
380               DimensionMismatchException,
381               NoDataException,
382               NotFiniteNumberException,
383               NumberIsTooSmallException {
384        if (xval.length != yval.length) {
385            throw new DimensionMismatchException(xval.length, yval.length);
386        }
387
388        final double[] unitWeights = new double[xval.length];
389        Arrays.fill(unitWeights, 1.0);
390
391        return smooth(xval, yval, unitWeights);
392    }
393
394    /**
395     * Given an index interval into xval that embraces a certain number of
396     * points closest to {@code xval[i-1]}, update the interval so that it
397     * embraces the same number of points closest to {@code xval[i]},
398     * ignoring zero weights.
399     *
400     * @param xval Arguments array.
401     * @param weights Weights array.
402     * @param i Index around which the new interval should be computed.
403     * @param bandwidthInterval a two-element array {left, right} such that:
404     * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])}
405     * and
406     * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}.
407     * The array will be updated.
408     */
409    private static void updateBandwidthInterval(final double[] xval,
410                                                final double[] weights,
411                                                final int i,
412                                                final int[] bandwidthInterval) {
413        final int left = bandwidthInterval[0];
414        final int right = bandwidthInterval[1];
415
416        // The right edge should be adjusted if the next point to the right
417        // is closer to xval[i] than the leftmost point of the current interval
418        int nextRight = nextNonzero(weights, right);
419        int nextLeft = left;
420        while (nextRight < xval.length &&
421               xval[nextRight] - xval[i] < xval[i] - xval[nextLeft]) {
422            nextLeft = nextNonzero(weights, bandwidthInterval[0]);
423            bandwidthInterval[0] = nextLeft;
424            bandwidthInterval[1] = nextRight;
425            nextRight = nextNonzero(weights, nextRight);
426        }
427    }
428
429    /**
430     * Return the smallest index {@code j} such that
431     * {@code j > i && (j == weights.length || weights[j] != 0)}.
432     *
433     * @param weights Weights array.
434     * @param i Index from which to start search.
435     * @return the smallest compliant index.
436     */
437    private static int nextNonzero(final double[] weights, final int i) {
438        int j = i + 1;
439        while(j < weights.length && weights[j] == 0) {
440            ++j;
441        }
442        return j;
443    }
444
445    /**
446     * Compute the
447     * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
448     * weight function.
449     *
450     * @param x Argument.
451     * @return <code>(1 - |x|<sup>3</sup>)<sup>3</sup></code> for |x| &lt; 1, 0 otherwise.
452     */
453    private static double tricube(final double x) {
454        final double absX = JdkMath.abs(x);
455        if (absX >= 1.0) {
456            return 0.0;
457        }
458        final double tmp = 1 - absX * absX * absX;
459        return tmp * tmp * tmp;
460    }
461}