View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.optimization.fitting;
19  
20  import java.util.Arrays;
21  import java.util.Comparator;
22  
23  import org.apache.commons.math3.analysis.function.Gaussian;
24  import org.apache.commons.math3.exception.NullArgumentException;
25  import org.apache.commons.math3.exception.NumberIsTooSmallException;
26  import org.apache.commons.math3.exception.OutOfRangeException;
27  import org.apache.commons.math3.exception.ZeroException;
28  import org.apache.commons.math3.exception.NotStrictlyPositiveException;
29  import org.apache.commons.math3.exception.util.LocalizedFormats;
30  import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
31  import org.apache.commons.math3.util.FastMath;
32  
33  /**
34   * Fits points to a {@link
35   * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian} function.
36   * <p>
37   * Usage example:
38   * <pre>
39   *   GaussianFitter fitter = new GaussianFitter(
40   *     new LevenbergMarquardtOptimizer());
41   *   fitter.addObservedPoint(4.0254623,  531026.0);
42   *   fitter.addObservedPoint(4.03128248, 984167.0);
43   *   fitter.addObservedPoint(4.03839603, 1887233.0);
44   *   fitter.addObservedPoint(4.04421621, 2687152.0);
45   *   fitter.addObservedPoint(4.05132976, 3461228.0);
46   *   fitter.addObservedPoint(4.05326982, 3580526.0);
47   *   fitter.addObservedPoint(4.05779662, 3439750.0);
48   *   fitter.addObservedPoint(4.0636168,  2877648.0);
49   *   fitter.addObservedPoint(4.06943698, 2175960.0);
50   *   fitter.addObservedPoint(4.07525716, 1447024.0);
51   *   fitter.addObservedPoint(4.08237071, 717104.0);
52   *   fitter.addObservedPoint(4.08366408, 620014.0);
53   *   double[] parameters = fitter.fit();
54   * </pre>
55   *
56   * @since 2.2
57   * @deprecated As of 3.1 (to be removed in 4.0).
58   */
59  @Deprecated
60  public class GaussianFitter extends CurveFitter<Gaussian.Parametric> {
61      /**
62       * Constructs an instance using the specified optimizer.
63       *
64       * @param optimizer Optimizer to use for the fitting.
65       */
66      public GaussianFitter(DifferentiableMultivariateVectorOptimizer optimizer) {
67          super(optimizer);
68      }
69  
70      /**
71       * Fits a Gaussian function to the observed points.
72       *
73       * @param initialGuess First guess values in the following order:
74       * <ul>
75       *  <li>Norm</li>
76       *  <li>Mean</li>
77       *  <li>Sigma</li>
78       * </ul>
79       * @return the parameters of the Gaussian function that best fits the
80       * observed points (in the same order as above).
81       * @since 3.0
82       */
83      public double[] fit(double[] initialGuess) {
84          final Gaussian.Parametric f = new Gaussian.Parametric() {
85                  /** {@inheritDoc} */
86                  @Override
87                  public double value(double x, double ... p) {
88                      double v = Double.POSITIVE_INFINITY;
89                      try {
90                          v = super.value(x, p);
91                      } catch (NotStrictlyPositiveException e) { // NOPMD
92                          // Do nothing.
93                      }
94                      return v;
95                  }
96  
97                  /** {@inheritDoc} */
98                  @Override
99                  public double[] gradient(double x, double ... p) {
100                     double[] v = { Double.POSITIVE_INFINITY,
101                                    Double.POSITIVE_INFINITY,
102                                    Double.POSITIVE_INFINITY };
103                     try {
104                         v = super.gradient(x, p);
105                     } catch (NotStrictlyPositiveException e) { // NOPMD
106                         // Do nothing.
107                     }
108                     return v;
109                 }
110             };
111 
112         return fit(f, initialGuess);
113     }
114 
115     /**
116      * Fits a Gaussian function to the observed points.
117      *
118      * @return the parameters of the Gaussian function that best fits the
119      * observed points (in the same order as above).
120      */
121     public double[] fit() {
122         final double[] guess = (new ParameterGuesser(getObservations())).guess();
123         return fit(guess);
124     }
125 
126     /**
127      * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
128      * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
129      * based on the specified observed points.
130      */
131     public static class ParameterGuesser {
132         /** Normalization factor. */
133         private final double norm;
134         /** Mean. */
135         private final double mean;
136         /** Standard deviation. */
137         private final double sigma;
138 
139         /**
140          * Constructs instance with the specified observed points.
141          *
142          * @param observations Observed points from which to guess the
143          * parameters of the Gaussian.
144          * @throws NullArgumentException if {@code observations} is
145          * {@code null}.
146          * @throws NumberIsTooSmallException if there are less than 3
147          * observations.
148          */
149         public ParameterGuesser(WeightedObservedPoint[] observations) {
150             if (observations == null) {
151                 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
152             }
153             if (observations.length < 3) {
154                 throw new NumberIsTooSmallException(observations.length, 3, true);
155             }
156 
157             final WeightedObservedPoint[] sorted = sortObservations(observations);
158             final double[] params = basicGuess(sorted);
159 
160             norm = params[0];
161             mean = params[1];
162             sigma = params[2];
163         }
164 
165         /**
166          * Gets an estimation of the parameters.
167          *
168          * @return the guessed parameters, in the following order:
169          * <ul>
170          *  <li>Normalization factor</li>
171          *  <li>Mean</li>
172          *  <li>Standard deviation</li>
173          * </ul>
174          */
175         public double[] guess() {
176             return new double[] { norm, mean, sigma };
177         }
178 
179         /**
180          * Sort the observations.
181          *
182          * @param unsorted Input observations.
183          * @return the input observations, sorted.
184          */
185         private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
186             final WeightedObservedPoint[] observations = unsorted.clone();
187             final Comparator<WeightedObservedPoint> cmp
188                 = new Comparator<WeightedObservedPoint>() {
189                 /** {@inheritDoc} */
190                 public int compare(WeightedObservedPoint p1,
191                                    WeightedObservedPoint p2) {
192                     if (p1 == null && p2 == null) {
193                         return 0;
194                     }
195                     if (p1 == null) {
196                         return -1;
197                     }
198                     if (p2 == null) {
199                         return 1;
200                     }
201                     final int cmpX = Double.compare(p1.getX(), p2.getX());
202                     if (cmpX < 0) {
203                         return -1;
204                     }
205                     if (cmpX > 0) {
206                         return 1;
207                     }
208                     final int cmpY = Double.compare(p1.getY(), p2.getY());
209                     if (cmpY < 0) {
210                         return -1;
211                     }
212                     if (cmpY > 0) {
213                         return 1;
214                     }
215                     final int cmpW = Double.compare(p1.getWeight(), p2.getWeight());
216                     if (cmpW < 0) {
217                         return -1;
218                     }
219                     if (cmpW > 0) {
220                         return 1;
221                     }
222                     return 0;
223                 }
224             };
225 
226             Arrays.sort(observations, cmp);
227             return observations;
228         }
229 
230         /**
231          * Guesses the parameters based on the specified observed points.
232          *
233          * @param points Observed points, sorted.
234          * @return the guessed parameters (normalization factor, mean and
235          * sigma).
236          */
237         private double[] basicGuess(WeightedObservedPoint[] points) {
238             final int maxYIdx = findMaxY(points);
239             final double n = points[maxYIdx].getY();
240             final double m = points[maxYIdx].getX();
241 
242             double fwhmApprox;
243             try {
244                 final double halfY = n + ((m - n) / 2);
245                 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
246                 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
247                 fwhmApprox = fwhmX2 - fwhmX1;
248             } catch (OutOfRangeException e) {
249                 // TODO: Exceptions should not be used for flow control.
250                 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
251             }
252             final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
253 
254             return new double[] { n, m, s };
255         }
256 
257         /**
258          * Finds index of point in specified points with the largest Y.
259          *
260          * @param points Points to search.
261          * @return the index in specified points array.
262          */
263         private int findMaxY(WeightedObservedPoint[] points) {
264             int maxYIdx = 0;
265             for (int i = 1; i < points.length; i++) {
266                 if (points[i].getY() > points[maxYIdx].getY()) {
267                     maxYIdx = i;
268                 }
269             }
270             return maxYIdx;
271         }
272 
273         /**
274          * Interpolates using the specified points to determine X at the
275          * specified Y.
276          *
277          * @param points Points to use for interpolation.
278          * @param startIdx Index within points from which to start the search for
279          * interpolation bounds points.
280          * @param idxStep Index step for searching interpolation bounds points.
281          * @param y Y value for which X should be determined.
282          * @return the value of X for the specified Y.
283          * @throws ZeroException if {@code idxStep} is 0.
284          * @throws OutOfRangeException if specified {@code y} is not within the
285          * range of the specified {@code points}.
286          */
287         private double interpolateXAtY(WeightedObservedPoint[] points,
288                                        int startIdx,
289                                        int idxStep,
290                                        double y)
291             throws OutOfRangeException {
292             if (idxStep == 0) {
293                 throw new ZeroException();
294             }
295             final WeightedObservedPoint[] twoPoints
296                 = getInterpolationPointsForY(points, startIdx, idxStep, y);
297             final WeightedObservedPoint p1 = twoPoints[0];
298             final WeightedObservedPoint p2 = twoPoints[1];
299             if (p1.getY() == y) {
300                 return p1.getX();
301             }
302             if (p2.getY() == y) {
303                 return p2.getX();
304             }
305             return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
306                                 (p2.getY() - p1.getY()));
307         }
308 
309         /**
310          * Gets the two bounding interpolation points from the specified points
311          * suitable for determining X at the specified Y.
312          *
313          * @param points Points to use for interpolation.
314          * @param startIdx Index within points from which to start search for
315          * interpolation bounds points.
316          * @param idxStep Index step for search for interpolation bounds points.
317          * @param y Y value for which X should be determined.
318          * @return the array containing two points suitable for determining X at
319          * the specified Y.
320          * @throws ZeroException if {@code idxStep} is 0.
321          * @throws OutOfRangeException if specified {@code y} is not within the
322          * range of the specified {@code points}.
323          */
324         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
325                                                                    int startIdx,
326                                                                    int idxStep,
327                                                                    double y)
328             throws OutOfRangeException {
329             if (idxStep == 0) {
330                 throw new ZeroException();
331             }
332             for (int i = startIdx;
333                  idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
334                  i += idxStep) {
335                 final WeightedObservedPoint p1 = points[i];
336                 final WeightedObservedPoint p2 = points[i + idxStep];
337                 if (isBetween(y, p1.getY(), p2.getY())) {
338                     if (idxStep < 0) {
339                         return new WeightedObservedPoint[] { p2, p1 };
340                     } else {
341                         return new WeightedObservedPoint[] { p1, p2 };
342                     }
343                 }
344             }
345 
346             // Boundaries are replaced by dummy values because the raised
347             // exception is caught and the message never displayed.
348             // TODO: Exceptions should not be used for flow control.
349             throw new OutOfRangeException(y,
350                                           Double.NEGATIVE_INFINITY,
351                                           Double.POSITIVE_INFINITY);
352         }
353 
354         /**
355          * Determines whether a value is between two other values.
356          *
357          * @param value Value to test whether it is between {@code boundary1}
358          * and {@code boundary2}.
359          * @param boundary1 One end of the range.
360          * @param boundary2 Other end of the range.
361          * @return {@code true} if {@code value} is between {@code boundary1} and
362          * {@code boundary2} (inclusive), {@code false} otherwise.
363          */
364         private boolean isBetween(double value,
365                                   double boundary1,
366                                   double boundary2) {
367             return (value >= boundary1 && value <= boundary2) ||
368                 (value >= boundary2 && value <= boundary1);
369         }
370     }
371 }