001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.fitting; 018 019import java.util.ArrayList; 020import java.util.Collection; 021import java.util.Collections; 022import java.util.Comparator; 023import java.util.List; 024 025import org.apache.commons.math3.analysis.function.Gaussian; 026import org.apache.commons.math3.exception.NotStrictlyPositiveException; 027import org.apache.commons.math3.exception.NullArgumentException; 028import org.apache.commons.math3.exception.NumberIsTooSmallException; 029import org.apache.commons.math3.exception.OutOfRangeException; 030import org.apache.commons.math3.exception.ZeroException; 031import org.apache.commons.math3.exception.util.LocalizedFormats; 032import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder; 033import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem; 034import org.apache.commons.math3.linear.DiagonalMatrix; 035import org.apache.commons.math3.util.FastMath; 036 037/** 038 * Fits points to a {@link 039 * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} 040 * function. 041 * <br/> 042 * The {@link #withStartPoint(double[]) initial guess values} must be passed 043 * in the following order: 044 * <ul> 045 * <li>Normalization</li> 046 * <li>Mean</li> 047 * <li>Sigma</li> 048 * </ul> 049 * The optimal values will be returned in the same order. 050 * 051 * <p> 052 * Usage example: 053 * <pre> 054 * WeightedObservedPoints obs = new WeightedObservedPoints(); 055 * obs.add(4.0254623, 531026.0); 056 * obs.add(4.03128248, 984167.0); 057 * obs.add(4.03839603, 1887233.0); 058 * obs.add(4.04421621, 2687152.0); 059 * obs.add(4.05132976, 3461228.0); 060 * obs.add(4.05326982, 3580526.0); 061 * obs.add(4.05779662, 3439750.0); 062 * obs.add(4.0636168, 2877648.0); 063 * obs.add(4.06943698, 2175960.0); 064 * obs.add(4.07525716, 1447024.0); 065 * obs.add(4.08237071, 717104.0); 066 * obs.add(4.08366408, 620014.0); 067 * double[] parameters = GaussianCurveFitter.create().fit(obs.toList()); 068 * </pre> 069 * 070 * @since 3.3 071 */ 072public class GaussianCurveFitter extends AbstractCurveFitter { 073 /** Parametric function to be fitted. */ 074 private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() { 075 /** {@inheritDoc} */ 076 @Override 077 public double value(double x, double ... p) { 078 double v = Double.POSITIVE_INFINITY; 079 try { 080 v = super.value(x, p); 081 } catch (NotStrictlyPositiveException e) { // NOPMD 082 // Do nothing. 083 } 084 return v; 085 } 086 087 /** {@inheritDoc} */ 088 @Override 089 public double[] gradient(double x, double ... p) { 090 double[] v = { Double.POSITIVE_INFINITY, 091 Double.POSITIVE_INFINITY, 092 Double.POSITIVE_INFINITY }; 093 try { 094 v = super.gradient(x, p); 095 } catch (NotStrictlyPositiveException e) { // NOPMD 096 // Do nothing. 097 } 098 return v; 099 } 100 }; 101 /** Initial guess. */ 102 private final double[] initialGuess; 103 /** Maximum number of iterations of the optimization algorithm. */ 104 private final int maxIter; 105 106 /** 107 * Contructor used by the factory methods. 108 * 109 * @param initialGuess Initial guess. If set to {@code null}, the initial guess 110 * will be estimated using the {@link ParameterGuesser}. 111 * @param maxIter Maximum number of iterations of the optimization algorithm. 112 */ 113 private GaussianCurveFitter(double[] initialGuess, 114 int maxIter) { 115 this.initialGuess = initialGuess; 116 this.maxIter = maxIter; 117 } 118 119 /** 120 * Creates a default curve fitter. 121 * The initial guess for the parameters will be {@link ParameterGuesser} 122 * computed automatically, and the maximum number of iterations of the 123 * optimization algorithm is set to {@link Integer#MAX_VALUE}. 124 * 125 * @return a curve fitter. 126 * 127 * @see #withStartPoint(double[]) 128 * @see #withMaxIterations(int) 129 */ 130 public static GaussianCurveFitter create() { 131 return new GaussianCurveFitter(null, Integer.MAX_VALUE); 132 } 133 134 /** 135 * Configure the start point (initial guess). 136 * @param newStart new start point (initial guess) 137 * @return a new instance. 138 */ 139 public GaussianCurveFitter withStartPoint(double[] newStart) { 140 return new GaussianCurveFitter(newStart.clone(), 141 maxIter); 142 } 143 144 /** 145 * Configure the maximum number of iterations. 146 * @param newMaxIter maximum number of iterations 147 * @return a new instance. 148 */ 149 public GaussianCurveFitter withMaxIterations(int newMaxIter) { 150 return new GaussianCurveFitter(initialGuess, 151 newMaxIter); 152 } 153 154 /** {@inheritDoc} */ 155 @Override 156 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) { 157 158 // Prepare least-squares problem. 159 final int len = observations.size(); 160 final double[] target = new double[len]; 161 final double[] weights = new double[len]; 162 163 int i = 0; 164 for (WeightedObservedPoint obs : observations) { 165 target[i] = obs.getY(); 166 weights[i] = obs.getWeight(); 167 ++i; 168 } 169 170 final AbstractCurveFitter.TheoreticalValuesFunction model = 171 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations); 172 173 final double[] startPoint = initialGuess != null ? 174 initialGuess : 175 // Compute estimation. 176 new ParameterGuesser(observations).guess(); 177 178 // Return a new least squares problem set up to fit a Gaussian curve to the 179 // observed points. 180 return new LeastSquaresBuilder(). 181 maxEvaluations(Integer.MAX_VALUE). 182 maxIterations(maxIter). 183 start(startPoint). 184 target(target). 185 weight(new DiagonalMatrix(weights)). 186 model(model.getModelFunction(), model.getModelFunctionJacobian()). 187 build(); 188 189 } 190 191 /** 192 * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma} 193 * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric} 194 * based on the specified observed points. 195 */ 196 public static class ParameterGuesser { 197 /** Normalization factor. */ 198 private final double norm; 199 /** Mean. */ 200 private final double mean; 201 /** Standard deviation. */ 202 private final double sigma; 203 204 /** 205 * Constructs instance with the specified observed points. 206 * 207 * @param observations Observed points from which to guess the 208 * parameters of the Gaussian. 209 * @throws NullArgumentException if {@code observations} is 210 * {@code null}. 211 * @throws NumberIsTooSmallException if there are less than 3 212 * observations. 213 */ 214 public ParameterGuesser(Collection<WeightedObservedPoint> observations) { 215 if (observations == null) { 216 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); 217 } 218 if (observations.size() < 3) { 219 throw new NumberIsTooSmallException(observations.size(), 3, true); 220 } 221 222 final List<WeightedObservedPoint> sorted = sortObservations(observations); 223 final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0])); 224 225 norm = params[0]; 226 mean = params[1]; 227 sigma = params[2]; 228 } 229 230 /** 231 * Gets an estimation of the parameters. 232 * 233 * @return the guessed parameters, in the following order: 234 * <ul> 235 * <li>Normalization factor</li> 236 * <li>Mean</li> 237 * <li>Standard deviation</li> 238 * </ul> 239 */ 240 public double[] guess() { 241 return new double[] { norm, mean, sigma }; 242 } 243 244 /** 245 * Sort the observations. 246 * 247 * @param unsorted Input observations. 248 * @return the input observations, sorted. 249 */ 250 private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) { 251 final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted); 252 253 final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() { 254 /** {@inheritDoc} */ 255 public int compare(WeightedObservedPoint p1, 256 WeightedObservedPoint p2) { 257 if (p1 == null && p2 == null) { 258 return 0; 259 } 260 if (p1 == null) { 261 return -1; 262 } 263 if (p2 == null) { 264 return 1; 265 } 266 final int cmpX = Double.compare(p1.getX(), p2.getX()); 267 if (cmpX < 0) { 268 return -1; 269 } 270 if (cmpX > 0) { 271 return 1; 272 } 273 final int cmpY = Double.compare(p1.getY(), p2.getY()); 274 if (cmpY < 0) { 275 return -1; 276 } 277 if (cmpY > 0) { 278 return 1; 279 } 280 final int cmpW = Double.compare(p1.getWeight(), p2.getWeight()); 281 if (cmpW < 0) { 282 return -1; 283 } 284 if (cmpW > 0) { 285 return 1; 286 } 287 return 0; 288 } 289 }; 290 291 Collections.sort(observations, cmp); 292 return observations; 293 } 294 295 /** 296 * Guesses the parameters based on the specified observed points. 297 * 298 * @param points Observed points, sorted. 299 * @return the guessed parameters (normalization factor, mean and 300 * sigma). 301 */ 302 private double[] basicGuess(WeightedObservedPoint[] points) { 303 final int maxYIdx = findMaxY(points); 304 final double n = points[maxYIdx].getY(); 305 final double m = points[maxYIdx].getX(); 306 307 double fwhmApprox; 308 try { 309 final double halfY = n + ((m - n) / 2); 310 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY); 311 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY); 312 fwhmApprox = fwhmX2 - fwhmX1; 313 } catch (OutOfRangeException e) { 314 // TODO: Exceptions should not be used for flow control. 315 fwhmApprox = points[points.length - 1].getX() - points[0].getX(); 316 } 317 final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2))); 318 319 return new double[] { n, m, s }; 320 } 321 322 /** 323 * Finds index of point in specified points with the largest Y. 324 * 325 * @param points Points to search. 326 * @return the index in specified points array. 327 */ 328 private int findMaxY(WeightedObservedPoint[] points) { 329 int maxYIdx = 0; 330 for (int i = 1; i < points.length; i++) { 331 if (points[i].getY() > points[maxYIdx].getY()) { 332 maxYIdx = i; 333 } 334 } 335 return maxYIdx; 336 } 337 338 /** 339 * Interpolates using the specified points to determine X at the 340 * specified Y. 341 * 342 * @param points Points to use for interpolation. 343 * @param startIdx Index within points from which to start the search for 344 * interpolation bounds points. 345 * @param idxStep Index step for searching interpolation bounds points. 346 * @param y Y value for which X should be determined. 347 * @return the value of X for the specified Y. 348 * @throws ZeroException if {@code idxStep} is 0. 349 * @throws OutOfRangeException if specified {@code y} is not within the 350 * range of the specified {@code points}. 351 */ 352 private double interpolateXAtY(WeightedObservedPoint[] points, 353 int startIdx, 354 int idxStep, 355 double y) 356 throws OutOfRangeException { 357 if (idxStep == 0) { 358 throw new ZeroException(); 359 } 360 final WeightedObservedPoint[] twoPoints 361 = getInterpolationPointsForY(points, startIdx, idxStep, y); 362 final WeightedObservedPoint p1 = twoPoints[0]; 363 final WeightedObservedPoint p2 = twoPoints[1]; 364 if (p1.getY() == y) { 365 return p1.getX(); 366 } 367 if (p2.getY() == y) { 368 return p2.getX(); 369 } 370 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) / 371 (p2.getY() - p1.getY())); 372 } 373 374 /** 375 * Gets the two bounding interpolation points from the specified points 376 * suitable for determining X at the specified Y. 377 * 378 * @param points Points to use for interpolation. 379 * @param startIdx Index within points from which to start search for 380 * interpolation bounds points. 381 * @param idxStep Index step for search for interpolation bounds points. 382 * @param y Y value for which X should be determined. 383 * @return the array containing two points suitable for determining X at 384 * the specified Y. 385 * @throws ZeroException if {@code idxStep} is 0. 386 * @throws OutOfRangeException if specified {@code y} is not within the 387 * range of the specified {@code points}. 388 */ 389 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points, 390 int startIdx, 391 int idxStep, 392 double y) 393 throws OutOfRangeException { 394 if (idxStep == 0) { 395 throw new ZeroException(); 396 } 397 for (int i = startIdx; 398 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length; 399 i += idxStep) { 400 final WeightedObservedPoint p1 = points[i]; 401 final WeightedObservedPoint p2 = points[i + idxStep]; 402 if (isBetween(y, p1.getY(), p2.getY())) { 403 if (idxStep < 0) { 404 return new WeightedObservedPoint[] { p2, p1 }; 405 } else { 406 return new WeightedObservedPoint[] { p1, p2 }; 407 } 408 } 409 } 410 411 // Boundaries are replaced by dummy values because the raised 412 // exception is caught and the message never displayed. 413 // TODO: Exceptions should not be used for flow control. 414 throw new OutOfRangeException(y, 415 Double.NEGATIVE_INFINITY, 416 Double.POSITIVE_INFINITY); 417 } 418 419 /** 420 * Determines whether a value is between two other values. 421 * 422 * @param value Value to test whether it is between {@code boundary1} 423 * and {@code boundary2}. 424 * @param boundary1 One end of the range. 425 * @param boundary2 Other end of the range. 426 * @return {@code true} if {@code value} is between {@code boundary1} and 427 * {@code boundary2} (inclusive), {@code false} otherwise. 428 */ 429 private boolean isBetween(double value, 430 double boundary1, 431 double boundary2) { 432 return (value >= boundary1 && value <= boundary2) || 433 (value >= boundary2 && value <= boundary1); 434 } 435 } 436}