001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.apache.commons.math3.optimization.fitting; 019 020 import java.util.Arrays; 021 import java.util.Comparator; 022 023 import org.apache.commons.math3.analysis.function.Gaussian; 024 import org.apache.commons.math3.exception.NullArgumentException; 025 import org.apache.commons.math3.exception.NumberIsTooSmallException; 026 import org.apache.commons.math3.exception.OutOfRangeException; 027 import org.apache.commons.math3.exception.ZeroException; 028 import org.apache.commons.math3.exception.NotStrictlyPositiveException; 029 import org.apache.commons.math3.exception.util.LocalizedFormats; 030 import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer; 031 import org.apache.commons.math3.util.FastMath; 032 033 /** 034 * Fits points to a {@link 035 * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} function. 036 * <p> 037 * Usage example: 038 * <pre> 039 * GaussianFitter fitter = new GaussianFitter( 040 * new LevenbergMarquardtOptimizer()); 041 * fitter.addObservedPoint(4.0254623, 531026.0); 042 * fitter.addObservedPoint(4.03128248, 984167.0); 043 * fitter.addObservedPoint(4.03839603, 1887233.0); 044 * fitter.addObservedPoint(4.04421621, 2687152.0); 045 * fitter.addObservedPoint(4.05132976, 3461228.0); 046 * fitter.addObservedPoint(4.05326982, 3580526.0); 047 * fitter.addObservedPoint(4.05779662, 3439750.0); 048 * fitter.addObservedPoint(4.0636168, 2877648.0); 049 * fitter.addObservedPoint(4.06943698, 2175960.0); 050 * fitter.addObservedPoint(4.07525716, 1447024.0); 051 * fitter.addObservedPoint(4.08237071, 717104.0); 052 * fitter.addObservedPoint(4.08366408, 620014.0); 053 * double[] parameters = fitter.fit(); 054 * </pre> 055 * 056 * @since 2.2 057 * @version $Id: GaussianFitter.java 1422230 2012-12-15 12:11:13Z erans $ 058 * @deprecated As of 3.1 (to be removed in 4.0). 059 */ 060 @Deprecated 061 public class GaussianFitter extends CurveFitter<Gaussian.Parametric> { 062 /** 063 * Constructs an instance using the specified optimizer. 064 * 065 * @param optimizer Optimizer to use for the fitting. 066 */ 067 public GaussianFitter(DifferentiableMultivariateVectorOptimizer optimizer) { 068 super(optimizer); 069 } 070 071 /** 072 * Fits a Gaussian function to the observed points. 073 * 074 * @param initialGuess First guess values in the following order: 075 * <ul> 076 * <li>Norm</li> 077 * <li>Mean</li> 078 * <li>Sigma</li> 079 * </ul> 080 * @return the parameters of the Gaussian function that best fits the 081 * observed points (in the same order as above). 082 * @since 3.0 083 */ 084 public double[] fit(double[] initialGuess) { 085 final Gaussian.Parametric f = new Gaussian.Parametric() { 086 @Override 087 public double value(double x, double ... p) { 088 double v = Double.POSITIVE_INFINITY; 089 try { 090 v = super.value(x, p); 091 } catch (NotStrictlyPositiveException e) { // NOPMD 092 // Do nothing. 093 } 094 return v; 095 } 096 097 @Override 098 public double[] gradient(double x, double ... p) { 099 double[] v = { Double.POSITIVE_INFINITY, 100 Double.POSITIVE_INFINITY, 101 Double.POSITIVE_INFINITY }; 102 try { 103 v = super.gradient(x, p); 104 } catch (NotStrictlyPositiveException e) { // NOPMD 105 // Do nothing. 106 } 107 return v; 108 } 109 }; 110 111 return fit(f, initialGuess); 112 } 113 114 /** 115 * Fits a Gaussian function to the observed points. 116 * 117 * @return the parameters of the Gaussian function that best fits the 118 * observed points (in the same order as above). 119 */ 120 public double[] fit() { 121 final double[] guess = (new ParameterGuesser(getObservations())).guess(); 122 return fit(guess); 123 } 124 125 /** 126 * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma} 127 * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric} 128 * based on the specified observed points. 129 */ 130 public static class ParameterGuesser { 131 /** Normalization factor. */ 132 private final double norm; 133 /** Mean. */ 134 private final double mean; 135 /** Standard deviation. */ 136 private final double sigma; 137 138 /** 139 * Constructs instance with the specified observed points. 140 * 141 * @param observations Observed points from which to guess the 142 * parameters of the Gaussian. 143 * @throws NullArgumentException if {@code observations} is 144 * {@code null}. 145 * @throws NumberIsTooSmallException if there are less than 3 146 * observations. 147 */ 148 public ParameterGuesser(WeightedObservedPoint[] observations) { 149 if (observations == null) { 150 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); 151 } 152 if (observations.length < 3) { 153 throw new NumberIsTooSmallException(observations.length, 3, true); 154 } 155 156 final WeightedObservedPoint[] sorted = sortObservations(observations); 157 final double[] params = basicGuess(sorted); 158 159 norm = params[0]; 160 mean = params[1]; 161 sigma = params[2]; 162 } 163 164 /** 165 * Gets an estimation of the parameters. 166 * 167 * @return the guessed parameters, in the following order: 168 * <ul> 169 * <li>Normalization factor</li> 170 * <li>Mean</li> 171 * <li>Standard deviation</li> 172 * </ul> 173 */ 174 public double[] guess() { 175 return new double[] { norm, mean, sigma }; 176 } 177 178 /** 179 * Sort the observations. 180 * 181 * @param unsorted Input observations. 182 * @return the input observations, sorted. 183 */ 184 private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) { 185 final WeightedObservedPoint[] observations = unsorted.clone(); 186 final Comparator<WeightedObservedPoint> cmp 187 = new Comparator<WeightedObservedPoint>() { 188 public int compare(WeightedObservedPoint p1, 189 WeightedObservedPoint p2) { 190 if (p1 == null && p2 == null) { 191 return 0; 192 } 193 if (p1 == null) { 194 return -1; 195 } 196 if (p2 == null) { 197 return 1; 198 } 199 if (p1.getX() < p2.getX()) { 200 return -1; 201 } 202 if (p1.getX() > p2.getX()) { 203 return 1; 204 } 205 if (p1.getY() < p2.getY()) { 206 return -1; 207 } 208 if (p1.getY() > p2.getY()) { 209 return 1; 210 } 211 if (p1.getWeight() < p2.getWeight()) { 212 return -1; 213 } 214 if (p1.getWeight() > p2.getWeight()) { 215 return 1; 216 } 217 return 0; 218 } 219 }; 220 221 Arrays.sort(observations, cmp); 222 return observations; 223 } 224 225 /** 226 * Guesses the parameters based on the specified observed points. 227 * 228 * @param points Observed points, sorted. 229 * @return the guessed parameters (normalization factor, mean and 230 * sigma). 231 */ 232 private double[] basicGuess(WeightedObservedPoint[] points) { 233 final int maxYIdx = findMaxY(points); 234 final double n = points[maxYIdx].getY(); 235 final double m = points[maxYIdx].getX(); 236 237 double fwhmApprox; 238 try { 239 final double halfY = n + ((m - n) / 2); 240 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY); 241 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY); 242 fwhmApprox = fwhmX2 - fwhmX1; 243 } catch (OutOfRangeException e) { 244 // TODO: Exceptions should not be used for flow control. 245 fwhmApprox = points[points.length - 1].getX() - points[0].getX(); 246 } 247 final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2))); 248 249 return new double[] { n, m, s }; 250 } 251 252 /** 253 * Finds index of point in specified points with the largest Y. 254 * 255 * @param points Points to search. 256 * @return the index in specified points array. 257 */ 258 private int findMaxY(WeightedObservedPoint[] points) { 259 int maxYIdx = 0; 260 for (int i = 1; i < points.length; i++) { 261 if (points[i].getY() > points[maxYIdx].getY()) { 262 maxYIdx = i; 263 } 264 } 265 return maxYIdx; 266 } 267 268 /** 269 * Interpolates using the specified points to determine X at the 270 * specified Y. 271 * 272 * @param points Points to use for interpolation. 273 * @param startIdx Index within points from which to start the search for 274 * interpolation bounds points. 275 * @param idxStep Index step for searching interpolation bounds points. 276 * @param y Y value for which X should be determined. 277 * @return the value of X for the specified Y. 278 * @throws ZeroException if {@code idxStep} is 0. 279 * @throws OutOfRangeException if specified {@code y} is not within the 280 * range of the specified {@code points}. 281 */ 282 private double interpolateXAtY(WeightedObservedPoint[] points, 283 int startIdx, 284 int idxStep, 285 double y) 286 throws OutOfRangeException { 287 if (idxStep == 0) { 288 throw new ZeroException(); 289 } 290 final WeightedObservedPoint[] twoPoints 291 = getInterpolationPointsForY(points, startIdx, idxStep, y); 292 final WeightedObservedPoint p1 = twoPoints[0]; 293 final WeightedObservedPoint p2 = twoPoints[1]; 294 if (p1.getY() == y) { 295 return p1.getX(); 296 } 297 if (p2.getY() == y) { 298 return p2.getX(); 299 } 300 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) / 301 (p2.getY() - p1.getY())); 302 } 303 304 /** 305 * Gets the two bounding interpolation points from the specified points 306 * suitable for determining X at the specified Y. 307 * 308 * @param points Points to use for interpolation. 309 * @param startIdx Index within points from which to start search for 310 * interpolation bounds points. 311 * @param idxStep Index step for search for interpolation bounds points. 312 * @param y Y value for which X should be determined. 313 * @return the array containing two points suitable for determining X at 314 * the specified Y. 315 * @throws ZeroException if {@code idxStep} is 0. 316 * @throws OutOfRangeException if specified {@code y} is not within the 317 * range of the specified {@code points}. 318 */ 319 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points, 320 int startIdx, 321 int idxStep, 322 double y) 323 throws OutOfRangeException { 324 if (idxStep == 0) { 325 throw new ZeroException(); 326 } 327 for (int i = startIdx; 328 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length; 329 i += idxStep) { 330 final WeightedObservedPoint p1 = points[i]; 331 final WeightedObservedPoint p2 = points[i + idxStep]; 332 if (isBetween(y, p1.getY(), p2.getY())) { 333 if (idxStep < 0) { 334 return new WeightedObservedPoint[] { p2, p1 }; 335 } else { 336 return new WeightedObservedPoint[] { p1, p2 }; 337 } 338 } 339 } 340 341 // Boundaries are replaced by dummy values because the raised 342 // exception is caught and the message never displayed. 343 // TODO: Exceptions should not be used for flow control. 344 throw new OutOfRangeException(y, 345 Double.NEGATIVE_INFINITY, 346 Double.POSITIVE_INFINITY); 347 } 348 349 /** 350 * Determines whether a value is between two other values. 351 * 352 * @param value Value to test whether it is between {@code boundary1} 353 * and {@code boundary2}. 354 * @param boundary1 One end of the range. 355 * @param boundary2 Other end of the range. 356 * @return {@code true} if {@code value} is between {@code boundary1} and 357 * {@code boundary2} (inclusive), {@code false} otherwise. 358 */ 359 private boolean isBetween(double value, 360 double boundary1, 361 double boundary2) { 362 return (value >= boundary1 && value <= boundary2) || 363 (value >= boundary2 && value <= boundary1); 364 } 365 } 366 }