001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math3.optimization.fitting; 019 020import java.util.Arrays; 021import java.util.Comparator; 022 023import org.apache.commons.math3.analysis.function.Gaussian; 024import org.apache.commons.math3.exception.NullArgumentException; 025import org.apache.commons.math3.exception.NumberIsTooSmallException; 026import org.apache.commons.math3.exception.OutOfRangeException; 027import org.apache.commons.math3.exception.ZeroException; 028import org.apache.commons.math3.exception.NotStrictlyPositiveException; 029import org.apache.commons.math3.exception.util.LocalizedFormats; 030import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer; 031import org.apache.commons.math3.util.FastMath; 032 033/** 034 * Fits points to a {@link 035 * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} function. 036 * <p> 037 * Usage example: 038 * <pre> 039 * GaussianFitter fitter = new GaussianFitter( 040 * new LevenbergMarquardtOptimizer()); 041 * fitter.addObservedPoint(4.0254623, 531026.0); 042 * fitter.addObservedPoint(4.03128248, 984167.0); 043 * fitter.addObservedPoint(4.03839603, 1887233.0); 044 * fitter.addObservedPoint(4.04421621, 2687152.0); 045 * fitter.addObservedPoint(4.05132976, 3461228.0); 046 * fitter.addObservedPoint(4.05326982, 3580526.0); 047 * fitter.addObservedPoint(4.05779662, 3439750.0); 048 * fitter.addObservedPoint(4.0636168, 2877648.0); 049 * fitter.addObservedPoint(4.06943698, 2175960.0); 050 * fitter.addObservedPoint(4.07525716, 1447024.0); 051 * fitter.addObservedPoint(4.08237071, 717104.0); 052 * fitter.addObservedPoint(4.08366408, 620014.0); 053 * double[] parameters = fitter.fit(); 054 * </pre> 055 * 056 * @since 2.2 057 * @deprecated As of 3.1 (to be removed in 4.0). 058 */ 059@Deprecated 060public class GaussianFitter extends CurveFitter<Gaussian.Parametric> { 061 /** 062 * Constructs an instance using the specified optimizer. 063 * 064 * @param optimizer Optimizer to use for the fitting. 065 */ 066 public GaussianFitter(DifferentiableMultivariateVectorOptimizer optimizer) { 067 super(optimizer); 068 } 069 070 /** 071 * Fits a Gaussian function to the observed points. 072 * 073 * @param initialGuess First guess values in the following order: 074 * <ul> 075 * <li>Norm</li> 076 * <li>Mean</li> 077 * <li>Sigma</li> 078 * </ul> 079 * @return the parameters of the Gaussian function that best fits the 080 * observed points (in the same order as above). 081 * @since 3.0 082 */ 083 public double[] fit(double[] initialGuess) { 084 final Gaussian.Parametric f = new Gaussian.Parametric() { 085 /** {@inheritDoc} */ 086 @Override 087 public double value(double x, double ... p) { 088 double v = Double.POSITIVE_INFINITY; 089 try { 090 v = super.value(x, p); 091 } catch (NotStrictlyPositiveException e) { // NOPMD 092 // Do nothing. 093 } 094 return v; 095 } 096 097 /** {@inheritDoc} */ 098 @Override 099 public double[] gradient(double x, double ... p) { 100 double[] v = { Double.POSITIVE_INFINITY, 101 Double.POSITIVE_INFINITY, 102 Double.POSITIVE_INFINITY }; 103 try { 104 v = super.gradient(x, p); 105 } catch (NotStrictlyPositiveException e) { // NOPMD 106 // Do nothing. 107 } 108 return v; 109 } 110 }; 111 112 return fit(f, initialGuess); 113 } 114 115 /** 116 * Fits a Gaussian function to the observed points. 117 * 118 * @return the parameters of the Gaussian function that best fits the 119 * observed points (in the same order as above). 120 */ 121 public double[] fit() { 122 final double[] guess = (new ParameterGuesser(getObservations())).guess(); 123 return fit(guess); 124 } 125 126 /** 127 * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma} 128 * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric} 129 * based on the specified observed points. 130 */ 131 public static class ParameterGuesser { 132 /** Normalization factor. */ 133 private final double norm; 134 /** Mean. */ 135 private final double mean; 136 /** Standard deviation. */ 137 private final double sigma; 138 139 /** 140 * Constructs instance with the specified observed points. 141 * 142 * @param observations Observed points from which to guess the 143 * parameters of the Gaussian. 144 * @throws NullArgumentException if {@code observations} is 145 * {@code null}. 146 * @throws NumberIsTooSmallException if there are less than 3 147 * observations. 148 */ 149 public ParameterGuesser(WeightedObservedPoint[] observations) { 150 if (observations == null) { 151 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); 152 } 153 if (observations.length < 3) { 154 throw new NumberIsTooSmallException(observations.length, 3, true); 155 } 156 157 final WeightedObservedPoint[] sorted = sortObservations(observations); 158 final double[] params = basicGuess(sorted); 159 160 norm = params[0]; 161 mean = params[1]; 162 sigma = params[2]; 163 } 164 165 /** 166 * Gets an estimation of the parameters. 167 * 168 * @return the guessed parameters, in the following order: 169 * <ul> 170 * <li>Normalization factor</li> 171 * <li>Mean</li> 172 * <li>Standard deviation</li> 173 * </ul> 174 */ 175 public double[] guess() { 176 return new double[] { norm, mean, sigma }; 177 } 178 179 /** 180 * Sort the observations. 181 * 182 * @param unsorted Input observations. 183 * @return the input observations, sorted. 184 */ 185 private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) { 186 final WeightedObservedPoint[] observations = unsorted.clone(); 187 final Comparator<WeightedObservedPoint> cmp 188 = new Comparator<WeightedObservedPoint>() { 189 /** {@inheritDoc} */ 190 public int compare(WeightedObservedPoint p1, 191 WeightedObservedPoint p2) { 192 if (p1 == null && p2 == null) { 193 return 0; 194 } 195 if (p1 == null) { 196 return -1; 197 } 198 if (p2 == null) { 199 return 1; 200 } 201 final int cmpX = Double.compare(p1.getX(), p2.getX()); 202 if (cmpX < 0) { 203 return -1; 204 } 205 if (cmpX > 0) { 206 return 1; 207 } 208 final int cmpY = Double.compare(p1.getY(), p2.getY()); 209 if (cmpY < 0) { 210 return -1; 211 } 212 if (cmpY > 0) { 213 return 1; 214 } 215 final int cmpW = Double.compare(p1.getWeight(), p2.getWeight()); 216 if (cmpW < 0) { 217 return -1; 218 } 219 if (cmpW > 0) { 220 return 1; 221 } 222 return 0; 223 } 224 }; 225 226 Arrays.sort(observations, cmp); 227 return observations; 228 } 229 230 /** 231 * Guesses the parameters based on the specified observed points. 232 * 233 * @param points Observed points, sorted. 234 * @return the guessed parameters (normalization factor, mean and 235 * sigma). 236 */ 237 private double[] basicGuess(WeightedObservedPoint[] points) { 238 final int maxYIdx = findMaxY(points); 239 final double n = points[maxYIdx].getY(); 240 final double m = points[maxYIdx].getX(); 241 242 double fwhmApprox; 243 try { 244 final double halfY = n + ((m - n) / 2); 245 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY); 246 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY); 247 fwhmApprox = fwhmX2 - fwhmX1; 248 } catch (OutOfRangeException e) { 249 // TODO: Exceptions should not be used for flow control. 250 fwhmApprox = points[points.length - 1].getX() - points[0].getX(); 251 } 252 final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2))); 253 254 return new double[] { n, m, s }; 255 } 256 257 /** 258 * Finds index of point in specified points with the largest Y. 259 * 260 * @param points Points to search. 261 * @return the index in specified points array. 262 */ 263 private int findMaxY(WeightedObservedPoint[] points) { 264 int maxYIdx = 0; 265 for (int i = 1; i < points.length; i++) { 266 if (points[i].getY() > points[maxYIdx].getY()) { 267 maxYIdx = i; 268 } 269 } 270 return maxYIdx; 271 } 272 273 /** 274 * Interpolates using the specified points to determine X at the 275 * specified Y. 276 * 277 * @param points Points to use for interpolation. 278 * @param startIdx Index within points from which to start the search for 279 * interpolation bounds points. 280 * @param idxStep Index step for searching interpolation bounds points. 281 * @param y Y value for which X should be determined. 282 * @return the value of X for the specified Y. 283 * @throws ZeroException if {@code idxStep} is 0. 284 * @throws OutOfRangeException if specified {@code y} is not within the 285 * range of the specified {@code points}. 286 */ 287 private double interpolateXAtY(WeightedObservedPoint[] points, 288 int startIdx, 289 int idxStep, 290 double y) 291 throws OutOfRangeException { 292 if (idxStep == 0) { 293 throw new ZeroException(); 294 } 295 final WeightedObservedPoint[] twoPoints 296 = getInterpolationPointsForY(points, startIdx, idxStep, y); 297 final WeightedObservedPoint p1 = twoPoints[0]; 298 final WeightedObservedPoint p2 = twoPoints[1]; 299 if (p1.getY() == y) { 300 return p1.getX(); 301 } 302 if (p2.getY() == y) { 303 return p2.getX(); 304 } 305 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) / 306 (p2.getY() - p1.getY())); 307 } 308 309 /** 310 * Gets the two bounding interpolation points from the specified points 311 * suitable for determining X at the specified Y. 312 * 313 * @param points Points to use for interpolation. 314 * @param startIdx Index within points from which to start search for 315 * interpolation bounds points. 316 * @param idxStep Index step for search for interpolation bounds points. 317 * @param y Y value for which X should be determined. 318 * @return the array containing two points suitable for determining X at 319 * the specified Y. 320 * @throws ZeroException if {@code idxStep} is 0. 321 * @throws OutOfRangeException if specified {@code y} is not within the 322 * range of the specified {@code points}. 323 */ 324 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points, 325 int startIdx, 326 int idxStep, 327 double y) 328 throws OutOfRangeException { 329 if (idxStep == 0) { 330 throw new ZeroException(); 331 } 332 for (int i = startIdx; 333 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length; 334 i += idxStep) { 335 final WeightedObservedPoint p1 = points[i]; 336 final WeightedObservedPoint p2 = points[i + idxStep]; 337 if (isBetween(y, p1.getY(), p2.getY())) { 338 if (idxStep < 0) { 339 return new WeightedObservedPoint[] { p2, p1 }; 340 } else { 341 return new WeightedObservedPoint[] { p1, p2 }; 342 } 343 } 344 } 345 346 // Boundaries are replaced by dummy values because the raised 347 // exception is caught and the message never displayed. 348 // TODO: Exceptions should not be used for flow control. 349 throw new OutOfRangeException(y, 350 Double.NEGATIVE_INFINITY, 351 Double.POSITIVE_INFINITY); 352 } 353 354 /** 355 * Determines whether a value is between two other values. 356 * 357 * @param value Value to test whether it is between {@code boundary1} 358 * and {@code boundary2}. 359 * @param boundary1 One end of the range. 360 * @param boundary2 Other end of the range. 361 * @return {@code true} if {@code value} is between {@code boundary1} and 362 * {@code boundary2} (inclusive), {@code false} otherwise. 363 */ 364 private boolean isBetween(double value, 365 double boundary1, 366 double boundary2) { 367 return (value >= boundary1 && value <= boundary2) || 368 (value >= boundary2 && value <= boundary1); 369 } 370 } 371}