001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math.analysis.interpolation;
018
019 import java.io.Serializable;
020 import java.util.Arrays;
021
022 import org.apache.commons.math.MathException;
023 import org.apache.commons.math.analysis.polynomials.PolynomialSplineFunction;
024
025 /**
026 * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
027 * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
028 * real univariate functions.
029 * <p/>
030 * For reference, see
031 * <a href="http://www.math.tau.ac.il/~yekutiel/MA seminar/Cleveland 1979.pdf">
032 * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
033 * Scatterplots</a>
034 * <p/>
035 * This class implements both the loess method and serves as an interpolation
036 * adapter to it, allowing to build a spline on the obtained loess fit.
037 *
038 * @version $Revision: 825919 $ $Date: 2009-10-16 10:51:55 -0400 (Fri, 16 Oct 2009) $
039 * @since 2.0
040 */
041 public class LoessInterpolator
042 implements UnivariateRealInterpolator, Serializable {
043
044 /** Default value of the bandwidth parameter. */
045 public static final double DEFAULT_BANDWIDTH = 0.3;
046
047 /** Default value of the number of robustness iterations. */
048 public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
049
050 /** Default value for accuracy. */
051 public static final double DEFAULT_ACCURACY = 1e-12;
052
053 /** serializable version identifier. */
054 private static final long serialVersionUID = 5204927143605193821L;
055
056 /**
057 * The bandwidth parameter: when computing the loess fit at
058 * a particular point, this fraction of source points closest
059 * to the current point is taken into account for computing
060 * a least-squares regression.
061 * <p/>
062 * A sensible value is usually 0.25 to 0.5.
063 */
064 private final double bandwidth;
065
066 /**
067 * The number of robustness iterations parameter: this many
068 * robustness iterations are done.
069 * <p/>
070 * A sensible value is usually 0 (just the initial fit without any
071 * robustness iterations) to 4.
072 */
073 private final int robustnessIters;
074
075 /**
076 * If the median residual at a certain robustness iteration
077 * is less than this amount, no more iterations are done.
078 */
079 private final double accuracy;
080
081 /**
082 * Constructs a new {@link LoessInterpolator}
083 * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
084 * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
085 * and an accuracy of {#link #DEFAULT_ACCURACY}.
086 * See {@link #LoessInterpolator(double, int, double)} for an explanation of
087 * the parameters.
088 */
089 public LoessInterpolator() {
090 this.bandwidth = DEFAULT_BANDWIDTH;
091 this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
092 this.accuracy = DEFAULT_ACCURACY;
093 }
094
095 /**
096 * Constructs a new {@link LoessInterpolator}
097 * with given bandwidth and number of robustness iterations.
098 * <p>
099 * Calling this constructor is equivalent to calling {link {@link
100 * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
101 * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
102 * </p>
103 *
104 * @param bandwidth when computing the loess fit at
105 * a particular point, this fraction of source points closest
106 * to the current point is taken into account for computing
107 * a least-squares regression.</br>
108 * A sensible value is usually 0.25 to 0.5, the default value is
109 * {@link #DEFAULT_BANDWIDTH}.
110 * @param robustnessIters This many robustness iterations are done.</br>
111 * A sensible value is usually 0 (just the initial fit without any
112 * robustness iterations) to 4, the default value is
113 * {@link #DEFAULT_ROBUSTNESS_ITERS}.
114 * @throws MathException if bandwidth does not lie in the interval [0,1]
115 * or if robustnessIters is negative.
116 * @see #LoessInterpolator(double, int, double)
117 */
118 public LoessInterpolator(double bandwidth, int robustnessIters) throws MathException {
119 this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
120 }
121
122 /**
123 * Constructs a new {@link LoessInterpolator}
124 * with given bandwidth, number of robustness iterations and accuracy.
125 *
126 * @param bandwidth when computing the loess fit at
127 * a particular point, this fraction of source points closest
128 * to the current point is taken into account for computing
129 * a least-squares regression.</br>
130 * A sensible value is usually 0.25 to 0.5, the default value is
131 * {@link #DEFAULT_BANDWIDTH}.
132 * @param robustnessIters This many robustness iterations are done.</br>
133 * A sensible value is usually 0 (just the initial fit without any
134 * robustness iterations) to 4, the default value is
135 * {@link #DEFAULT_ROBUSTNESS_ITERS}.
136 * @param accuracy If the median residual at a certain robustness iteration
137 * is less than this amount, no more iterations are done.
138 * @throws MathException if bandwidth does not lie in the interval [0,1]
139 * or if robustnessIters is negative.
140 * @see #LoessInterpolator(double, int)
141 * @since 2.1
142 */
143 public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy) throws MathException {
144 if (bandwidth < 0 || bandwidth > 1) {
145 throw new MathException("bandwidth must be in the interval [0,1], but got {0}",
146 bandwidth);
147 }
148 this.bandwidth = bandwidth;
149 if (robustnessIters < 0) {
150 throw new MathException("the number of robustness iterations must " +
151 "be non-negative, but got {0}",
152 robustnessIters);
153 }
154 this.robustnessIters = robustnessIters;
155 this.accuracy = accuracy;
156 }
157
158 /**
159 * Compute an interpolating function by performing a loess fit
160 * on the data at the original abscissae and then building a cubic spline
161 * with a
162 * {@link org.apache.commons.math.analysis.interpolation.SplineInterpolator}
163 * on the resulting fit.
164 *
165 * @param xval the arguments for the interpolation points
166 * @param yval the values for the interpolation points
167 * @return A cubic spline built upon a loess fit to the data at the original abscissae
168 * @throws MathException if some of the following conditions are false:
169 * <ul>
170 * <li> Arguments and values are of the same size that is greater than zero</li>
171 * <li> The arguments are in a strictly increasing order</li>
172 * <li> All arguments and values are finite real numbers</li>
173 * </ul>
174 */
175 public final PolynomialSplineFunction interpolate(
176 final double[] xval, final double[] yval) throws MathException {
177 return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
178 }
179
180 /**
181 * Compute a weighted loess fit on the data at the original abscissae.
182 *
183 * @param xval the arguments for the interpolation points
184 * @param yval the values for the interpolation points
185 * @param weights point weights: coefficients by which the robustness weight of a point is multiplied
186 * @return values of the loess fit at corresponding original abscissae
187 * @throws MathException if some of the following conditions are false:
188 * <ul>
189 * <li> Arguments and values are of the same size that is greater than zero</li>
190 * <li> The arguments are in a strictly increasing order</li>
191 * <li> All arguments and values are finite real numbers</li>
192 * </ul>
193 * @since 2.1
194 */
195 public final double[] smooth(final double[] xval, final double[] yval, final double[] weights)
196 throws MathException {
197 if (xval.length != yval.length) {
198 throw new MathException(
199 "Loess expects the abscissa and ordinate arrays " +
200 "to be of the same size, " +
201 "but got {0} abscisssae and {1} ordinatae",
202 xval.length, yval.length);
203 }
204
205 final int n = xval.length;
206
207 if (n == 0) {
208 throw new MathException("Loess expects at least 1 point");
209 }
210
211 checkAllFiniteReal(xval, "all abscissae must be finite real numbers, but {0}-th is {1}");
212 checkAllFiniteReal(yval, "all ordinatae must be finite real numbers, but {0}-th is {1}");
213 checkAllFiniteReal(weights, "all weights must be finite real numbers, but {0}-th is {1}");
214
215 checkStrictlyIncreasing(xval);
216
217 if (n == 1) {
218 return new double[]{yval[0]};
219 }
220
221 if (n == 2) {
222 return new double[]{yval[0], yval[1]};
223 }
224
225 int bandwidthInPoints = (int) (bandwidth * n);
226
227 if (bandwidthInPoints < 2) {
228 throw new MathException(
229 "the bandwidth must be large enough to " +
230 "accomodate at least 2 points. There are {0} " +
231 " data points, and bandwidth must be at least {1} " +
232 " but it is only {2}",
233 n, 2.0 / n, bandwidth);
234 }
235
236 final double[] res = new double[n];
237
238 final double[] residuals = new double[n];
239 final double[] sortedResiduals = new double[n];
240
241 final double[] robustnessWeights = new double[n];
242
243 // Do an initial fit and 'robustnessIters' robustness iterations.
244 // This is equivalent to doing 'robustnessIters+1' robustness iterations
245 // starting with all robustness weights set to 1.
246 Arrays.fill(robustnessWeights, 1);
247
248 for (int iter = 0; iter <= robustnessIters; ++iter) {
249 final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
250 // At each x, compute a local weighted linear regression
251 for (int i = 0; i < n; ++i) {
252 final double x = xval[i];
253
254 // Find out the interval of source points on which
255 // a regression is to be made.
256 if (i > 0) {
257 updateBandwidthInterval(xval, i, bandwidthInterval);
258 }
259
260 final int ileft = bandwidthInterval[0];
261 final int iright = bandwidthInterval[1];
262
263 // Compute the point of the bandwidth interval that is
264 // farthest from x
265 final int edge;
266 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
267 edge = ileft;
268 } else {
269 edge = iright;
270 }
271
272 // Compute a least-squares linear fit weighted by
273 // the product of robustness weights and the tricube
274 // weight function.
275 // See http://en.wikipedia.org/wiki/Linear_regression
276 // (section "Univariate linear case")
277 // and http://en.wikipedia.org/wiki/Weighted_least_squares
278 // (section "Weighted least squares")
279 double sumWeights = 0;
280 double sumX = 0;
281 double sumXSquared = 0;
282 double sumY = 0;
283 double sumXY = 0;
284 double denom = Math.abs(1.0 / (xval[edge] - x));
285 for (int k = ileft; k <= iright; ++k) {
286 final double xk = xval[k];
287 final double yk = yval[k];
288 final double dist = (k < i) ? x - xk : xk - x;
289 final double w = tricube(dist * denom) * robustnessWeights[k] * weights[k];
290 final double xkw = xk * w;
291 sumWeights += w;
292 sumX += xkw;
293 sumXSquared += xk * xkw;
294 sumY += yk * w;
295 sumXY += yk * xkw;
296 }
297
298 final double meanX = sumX / sumWeights;
299 final double meanY = sumY / sumWeights;
300 final double meanXY = sumXY / sumWeights;
301 final double meanXSquared = sumXSquared / sumWeights;
302
303 final double beta;
304 if (Math.sqrt(Math.abs(meanXSquared - meanX * meanX)) < accuracy) {
305 beta = 0;
306 } else {
307 beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
308 }
309
310 final double alpha = meanY - beta * meanX;
311
312 res[i] = beta * x + alpha;
313 residuals[i] = Math.abs(yval[i] - res[i]);
314 }
315
316 // No need to recompute the robustness weights at the last
317 // iteration, they won't be needed anymore
318 if (iter == robustnessIters) {
319 break;
320 }
321
322 // Recompute the robustness weights.
323
324 // Find the median residual.
325 // An arraycopy and a sort are completely tractable here,
326 // because the preceding loop is a lot more expensive
327 System.arraycopy(residuals, 0, sortedResiduals, 0, n);
328 Arrays.sort(sortedResiduals);
329 final double medianResidual = sortedResiduals[n / 2];
330
331 if (Math.abs(medianResidual) < accuracy) {
332 break;
333 }
334
335 for (int i = 0; i < n; ++i) {
336 final double arg = residuals[i] / (6 * medianResidual);
337 if (arg >= 1) {
338 robustnessWeights[i] = 0;
339 } else {
340 final double w = 1 - arg * arg;
341 robustnessWeights[i] = w * w;
342 }
343 }
344 }
345
346 return res;
347 }
348
349 /**
350 * Compute a loess fit on the data at the original abscissae.
351 *
352 * @param xval the arguments for the interpolation points
353 * @param yval the values for the interpolation points
354 * @return values of the loess fit at corresponding original abscissae
355 * @throws MathException if some of the following conditions are false:
356 * <ul>
357 * <li> Arguments and values are of the same size that is greater than zero</li>
358 * <li> The arguments are in a strictly increasing order</li>
359 * <li> All arguments and values are finite real numbers</li>
360 * </ul>
361 */
362 public final double[] smooth(final double[] xval, final double[] yval)
363 throws MathException {
364
365 final double[] unitWeights = new double[xval.length];
366 Arrays.fill(unitWeights, 1.0);
367
368 return smooth(xval, yval, unitWeights);
369
370 }
371
372
373 /**
374 * Given an index interval into xval that embraces a certain number of
375 * points closest to xval[i-1], update the interval so that it embraces
376 * the same number of points closest to xval[i]
377 *
378 * @param xval arguments array
379 * @param i the index around which the new interval should be computed
380 * @param bandwidthInterval a two-element array {left, right} such that: <p/>
381 * <tt>(left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])</tt>
382 * <p/> and also <p/>
383 * <tt>(right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])</tt>.
384 * The array will be updated.
385 */
386 private static void updateBandwidthInterval(final double[] xval, final int i,
387 final int[] bandwidthInterval) {
388 final int left = bandwidthInterval[0];
389 final int right = bandwidthInterval[1];
390
391 // The right edge should be adjusted if the next point to the right
392 // is closer to xval[i] than the leftmost point of the current interval
393 if (right < xval.length - 1 &&
394 xval[right+1] - xval[i] < xval[i] - xval[left]) {
395 bandwidthInterval[0]++;
396 bandwidthInterval[1]++;
397 }
398 }
399
400 /**
401 * Compute the
402 * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
403 * weight function
404 *
405 * @param x the argument
406 * @return (1-|x|^3)^3
407 */
408 private static double tricube(final double x) {
409 final double tmp = 1 - x * x * x;
410 return tmp * tmp * tmp;
411 }
412
413 /**
414 * Check that all elements of an array are finite real numbers.
415 *
416 * @param values the values array
417 * @param pattern pattern of the error message
418 * @throws MathException if one of the values is not a finite real number
419 */
420 private static void checkAllFiniteReal(final double[] values, final String pattern)
421 throws MathException {
422 for (int i = 0; i < values.length; i++) {
423 final double x = values[i];
424 if (Double.isInfinite(x) || Double.isNaN(x)) {
425 throw new MathException(pattern, i, x);
426 }
427 }
428 }
429
430 /**
431 * Check that elements of the abscissae array are in a strictly
432 * increasing order.
433 *
434 * @param xval the abscissae array
435 * @throws MathException if the abscissae array
436 * is not in a strictly increasing order
437 */
438 private static void checkStrictlyIncreasing(final double[] xval)
439 throws MathException {
440 for (int i = 0; i < xval.length; ++i) {
441 if (i >= 1 && xval[i - 1] >= xval[i]) {
442 throw new MathException(
443 "the abscissae array must be sorted in a strictly " +
444 "increasing order, but the {0}-th element is {1} " +
445 "whereas {2}-th is {3}",
446 i - 1, xval[i - 1], i, xval[i]);
447 }
448 }
449 }
450 }