View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization.fitting;
19  
20  import java.util.Arrays;
21  import java.util.Comparator;
22  
23  import org.apache.commons.math.analysis.function.Gaussian;
24  import org.apache.commons.math.analysis.ParametricUnivariateRealFunction;
25  import org.apache.commons.math.exception.NullArgumentException;
26  import org.apache.commons.math.exception.NumberIsTooSmallException;
27  import org.apache.commons.math.exception.OutOfRangeException;
28  import org.apache.commons.math.exception.ZeroException;
29  import org.apache.commons.math.exception.NotStrictlyPositiveException;
30  import org.apache.commons.math.exception.util.LocalizedFormats;
31  import org.apache.commons.math.optimization.DifferentiableMultivariateVectorialOptimizer;
32  import org.apache.commons.math.optimization.fitting.CurveFitter;
33  import org.apache.commons.math.optimization.fitting.WeightedObservedPoint;
34  
35  /**
36   * Fits points to a {@link
37   * org.apache.commons.math.analysis.function.Gaussian.Parametric Gaussian} function.
38   * <p>
39   * Usage example:
40   * <pre>
41   *   GaussianFitter fitter = new GaussianFitter(
42   *     new LevenbergMarquardtOptimizer());
43   *   fitter.addObservedPoint(4.0254623,  531026.0);
44   *   fitter.addObservedPoint(4.03128248, 984167.0);
45   *   fitter.addObservedPoint(4.03839603, 1887233.0);
46   *   fitter.addObservedPoint(4.04421621, 2687152.0);
47   *   fitter.addObservedPoint(4.05132976, 3461228.0);
48   *   fitter.addObservedPoint(4.05326982, 3580526.0);
49   *   fitter.addObservedPoint(4.05779662, 3439750.0);
50   *   fitter.addObservedPoint(4.0636168,  2877648.0);
51   *   fitter.addObservedPoint(4.06943698, 2175960.0);
52   *   fitter.addObservedPoint(4.07525716, 1447024.0);
53   *   fitter.addObservedPoint(4.08237071, 717104.0);
54   *   fitter.addObservedPoint(4.08366408, 620014.0);
55   *   double[] parameters = fitter.fit();
56   * </pre>
57   *
58   * @since 2.2
59   * @version $Id: GaussianFitter.java 1179928 2011-10-07 03:20:39Z psteitz $
60   */
61  public class GaussianFitter extends CurveFitter {
62      /**
63       * Constructs an instance using the specified optimizer.
64       *
65       * @param optimizer Optimizer to use for the fitting.
66       */
67      public GaussianFitter(DifferentiableMultivariateVectorialOptimizer optimizer) {
68          super(optimizer);
69      }
70  
71      /**
72       * Fits a Gaussian function to the observed points.
73       *
74       * @param initialGuess First guess values in the following order:
75       * <ul>
76       *  <li>Norm</li>
77       *  <li>Mean</li>
78       *  <li>Sigma</li>
79       * </ul>
80       * @return the parameters of the Gaussian function that best fits the
81       * observed points (in the same order as above).
82       * @since 3.0
83       */
84      public double[] fit(double[] initialGuess) {
85          final ParametricUnivariateRealFunction f = new ParametricUnivariateRealFunction() {
86                  private final ParametricUnivariateRealFunction g = new Gaussian.Parametric();
87  
88                  public double value(double x, double ... p) {
89                      double v = Double.POSITIVE_INFINITY;
90                      try {
91                          v = g.value(x, p);
92                      } catch (NotStrictlyPositiveException e) {
93                          // Do nothing.
94                      }
95                      return v;
96                  }
97  
98                  public double[] gradient(double x, double ... p) {
99                      double[] v = { Double.POSITIVE_INFINITY,
100                                    Double.POSITIVE_INFINITY,
101                                    Double.POSITIVE_INFINITY };
102                     try {
103                         v = g.gradient(x, p);
104                     } catch (NotStrictlyPositiveException e) {
105                         // Do nothing.
106                     }
107                     return v;
108                 }
109             };
110 
111         return fit(f, initialGuess);
112     }
113 
114     /**
115      * Fits a Gaussian function to the observed points.
116      *
117      * @return the parameters of the Gaussian function that best fits the
118      * observed points (in the same order as above).
119      */
120     public double[] fit() {
121         final double[] guess = (new ParameterGuesser(getObservations())).guess();
122         return fit(guess);
123     }
124 
125     /**
126      * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
127      * of a {@link org.apache.commons.math.analysis.function.Gaussian.Parametric}
128      * based on the specified observed points.
129      */
130     public static class ParameterGuesser {
131         /** Observed points. */
132         private final WeightedObservedPoint[] observations;
133         /** Resulting guessed parameters. */
134         private double[] parameters;
135 
136         /**
137          * Constructs instance with the specified observed points.
138          *
139          * @param observations observed points upon which should base guess
140          */
141         public ParameterGuesser(WeightedObservedPoint[] observations) {
142             if (observations == null) {
143                 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
144             }
145             if (observations.length < 3) {
146                 throw new NumberIsTooSmallException(observations.length, 3, true);
147             }
148             this.observations = observations.clone();
149         }
150 
151         /**
152          * Guesses the parameters based on the observed points.
153          *
154          * @return the guessed parameters: norm, mean and sigma.
155          */
156         public double[] guess() {
157             if (parameters == null) {
158                 parameters = basicGuess(observations);
159             }
160             return parameters.clone();
161         }
162 
163         /**
164          * Guesses the parameters based on the specified observed points.
165          *
166          * @param points Observed points upon which should base guess.
167          * @return the guessed parameters: norm, mean and sigma.
168          */
169         private double[] basicGuess(WeightedObservedPoint[] points) {
170             Arrays.sort(points, createWeightedObservedPointComparator());
171             double[] params = new double[3];
172 
173             int maxYIdx = findMaxY(points);
174             params[0] = points[maxYIdx].getY();
175             params[1] = points[maxYIdx].getX();
176 
177             double fwhmApprox;
178             try {
179                 double halfY = params[0] + ((params[1] - params[0]) / 2.0);
180                 double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
181                 double fwhmX2 = interpolateXAtY(points, maxYIdx, +1, halfY);
182                 fwhmApprox = fwhmX2 - fwhmX1;
183             } catch (OutOfRangeException e) {
184                 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
185             }
186             params[2] = fwhmApprox / (2.0 * Math.sqrt(2.0 * Math.log(2.0)));
187 
188             return params;
189         }
190 
191         /**
192          * Finds index of point in specified points with the largest Y.
193          *
194          * @param points Points to search.
195          * @return the index in specified points array.
196          */
197         private int findMaxY(WeightedObservedPoint[] points) {
198             int maxYIdx = 0;
199             for (int i = 1; i < points.length; i++) {
200                 if (points[i].getY() > points[maxYIdx].getY()) {
201                     maxYIdx = i;
202                 }
203             }
204             return maxYIdx;
205         }
206 
207         /**
208          * Interpolates using the specified points to determine X at the
209          * specified Y.
210          *
211          * @param points Points to use for interpolation.
212          * @param startIdx Index within points from which to start search for
213          *  interpolation bounds points.
214          * @param idxStep Index step for search for interpolation bounds points.
215          * @param y Y value for which X should be determined.
216          * @return the value of X at the specified Y.
217          * @throws ZeroException if {@code idxStep} is 0.
218          * @throws OutOfRangeException if specified {@code y} is not within the
219          * range of the specified {@code points}.
220          */
221         private double interpolateXAtY(WeightedObservedPoint[] points,
222                                        int startIdx, int idxStep, double y)
223             throws OutOfRangeException {
224             if (idxStep == 0) {
225                 throw new ZeroException();
226             }
227             WeightedObservedPoint[] twoPoints = getInterpolationPointsForY(points, startIdx, idxStep, y);
228             WeightedObservedPoint pointA = twoPoints[0];
229             WeightedObservedPoint pointB = twoPoints[1];
230             if (pointA.getY() == y) {
231                 return pointA.getX();
232             }
233             if (pointB.getY() == y) {
234                 return pointB.getX();
235             }
236             return pointA.getX() +
237                    (((y - pointA.getY()) * (pointB.getX() - pointA.getX())) /
238                     (pointB.getY() - pointA.getY()));
239         }
240 
241         /**
242          * Gets the two bounding interpolation points from the specified points
243          * suitable for determining X at the specified Y.
244          *
245          * @param points Points to use for interpolation.
246          * @param startIdx Index within points from which to start search for
247          * interpolation bounds points.
248          * @param idxStep Index step for search for interpolation bounds points.
249          * @param y Y value for which X should be determined.
250          * @return the array containing two points suitable for determining X at
251          * the specified Y.
252          * @throws ZeroException if {@code idxStep} is 0.
253          * @throws OutOfRangeException if specified {@code y} is not within the
254          * range of the specified {@code points}.
255          */
256         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
257                                                                    int startIdx, int idxStep, double y)
258             throws OutOfRangeException {
259             if (idxStep == 0) {
260                 throw new ZeroException();
261             }
262             for (int i = startIdx;
263                  (idxStep < 0) ? (i + idxStep >= 0) : (i + idxStep < points.length);
264                  i += idxStep) {
265                 if (isBetween(y, points[i].getY(), points[i + idxStep].getY())) {
266                     return (idxStep < 0) ?
267                            new WeightedObservedPoint[] { points[i + idxStep], points[i] } :
268                            new WeightedObservedPoint[] { points[i], points[i + idxStep] };
269                 }
270             }
271 
272             double minY = Double.POSITIVE_INFINITY;
273             double maxY = Double.NEGATIVE_INFINITY;
274             for (final WeightedObservedPoint point : points) {
275                 minY = Math.min(minY, point.getY());
276                 maxY = Math.max(maxY, point.getY());
277             }
278             throw new OutOfRangeException(y, minY, maxY);
279         }
280 
281         /**
282          * Determines whether a value is between two other values.
283          *
284          * @param value Value to determine whether is between {@code boundary1}
285          * and {@code boundary2}.
286          * @param boundary1 One end of the range.
287          * @param boundary2 Other end of the range.
288          * @return {@code true} if {@code value} is between {@code boundary1} and
289          * {@code boundary2} (inclusive), {@code false} otherwise.
290          */
291         private boolean isBetween(double value, double boundary1, double boundary2) {
292             return (value >= boundary1 && value <= boundary2) ||
293                    (value >= boundary2 && value <= boundary1);
294         }
295 
296         /**
297          * Factory method creating {@code Comparator} for comparing
298          * {@code WeightedObservedPoint} instances.
299          *
300          * @return the new {@code Comparator} instance.
301          */
302         private Comparator<WeightedObservedPoint> createWeightedObservedPointComparator() {
303             return new Comparator<WeightedObservedPoint>() {
304                 public int compare(WeightedObservedPoint p1, WeightedObservedPoint p2) {
305                     if (p1 == null && p2 == null) {
306                         return 0;
307                     }
308                     if (p1 == null) {
309                         return -1;
310                     }
311                     if (p2 == null) {
312                         return 1;
313                     }
314                     if (p1.getX() < p2.getX()) {
315                         return -1;
316                     }
317                     if (p1.getX() > p2.getX()) {
318                         return 1;
319                     }
320                     if (p1.getY() < p2.getY()) {
321                         return -1;
322                     }
323                     if (p1.getY() > p2.getY()) {
324                         return 1;
325                     }
326                     if (p1.getWeight() < p2.getWeight()) {
327                         return -1;
328                     }
329                     if (p1.getWeight() > p2.getWeight()) {
330                         return 1;
331                     }
332                     return 0;
333                 }
334             };
335         }
336     }
337 }