View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.fitting;
18  
19  import java.util.List;
20  import java.util.Collection;
21  
22  import org.apache.commons.math4.legacy.analysis.function.Gaussian;
23  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
24  import org.apache.commons.math4.legacy.exception.NullArgumentException;
25  import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
26  import org.apache.commons.math4.legacy.exception.OutOfRangeException;
27  import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
28  import org.apache.commons.math4.core.jdkmath.JdkMath;
29  
30  /**
31   * Fits points to a {@link
32   * org.apache.commons.math4.legacy.analysis.function.Gaussian.Parametric Gaussian}
33   * function.
34   * <br>
35   * The {@link #withStartPoint(double[]) initial guess values} must be passed
36   * in the following order:
37   * <ul>
38   *  <li>Normalization</li>
39   *  <li>Mean</li>
40   *  <li>Sigma</li>
41   * </ul>
42   * The optimal values will be returned in the same order.
43   *
44   * <p>
45   * Usage example:
46   * <pre>
47   *   WeightedObservedPoints obs = new WeightedObservedPoints();
48   *   obs.add(4.0254623,  531026.0);
49   *   obs.add(4.03128248, 984167.0);
50   *   obs.add(4.03839603, 1887233.0);
51   *   obs.add(4.04421621, 2687152.0);
52   *   obs.add(4.05132976, 3461228.0);
53   *   obs.add(4.05326982, 3580526.0);
54   *   obs.add(4.05779662, 3439750.0);
55   *   obs.add(4.0636168,  2877648.0);
56   *   obs.add(4.06943698, 2175960.0);
57   *   obs.add(4.07525716, 1447024.0);
58   *   obs.add(4.08237071, 717104.0);
59   *   obs.add(4.08366408, 620014.0);
60   *   double[] parameters = GaussianCurveFitter.create().fit(obs.toList());
61   * </pre>
62   *
63   * @since 3.3
64   */
65  public final class GaussianCurveFitter extends SimpleCurveFitter {
66      /** Parametric function to be fitted. */
67      private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
68              /** {@inheritDoc} */
69              @Override
70              public double value(double x, double ... p) {
71                  double v = Double.POSITIVE_INFINITY;
72                  try {
73                      v = super.value(x, p);
74                  } catch (NotStrictlyPositiveException e) { // NOPMD
75                      // Do nothing.
76                  }
77                  return v;
78              }
79  
80              /** {@inheritDoc} */
81              @Override
82              public double[] gradient(double x, double ... p) {
83                  double[] v = { Double.POSITIVE_INFINITY,
84                                 Double.POSITIVE_INFINITY,
85                                 Double.POSITIVE_INFINITY };
86                  try {
87                      v = super.gradient(x, p);
88                  } catch (NotStrictlyPositiveException e) { // NOPMD
89                      // Do nothing.
90                  }
91                  return v;
92              }
93          };
94  
95      /**
96       * Constructor used by the factory methods.
97       *
98       * @param initialGuess Initial guess. If set to {@code null}, the initial guess
99       * will be estimated using the {@link ParameterGuesser}.
100      * @param maxIter Maximum number of iterations of the optimization algorithm.
101      */
102     private GaussianCurveFitter(double[] initialGuess,
103                                 int maxIter) {
104         super(FUNCTION, initialGuess, new ParameterGuesser(), maxIter);
105     }
106 
107     /**
108      * Creates a default curve fitter.
109      * The initial guess for the parameters will be {@link ParameterGuesser}
110      * computed automatically, and the maximum number of iterations of the
111      * optimization algorithm is set to {@link Integer#MAX_VALUE}.
112      *
113      * @return a curve fitter.
114      *
115      * @see #withStartPoint(double[])
116      * @see #withMaxIterations(int)
117      */
118     public static GaussianCurveFitter create() {
119         return new GaussianCurveFitter(null, Integer.MAX_VALUE);
120     }
121 
122     /**
123      * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
124      * of a {@link org.apache.commons.math4.legacy.analysis.function.Gaussian.Parametric}
125      * based on the specified observed points.
126      */
127     public static class ParameterGuesser extends SimpleCurveFitter.ParameterGuesser {
128         /**
129          * {@inheritDoc}
130          *
131          * @return the guessed parameters, in the following order:
132          * <ul>
133          *  <li>Normalization factor</li>
134          *  <li>Mean</li>
135          *  <li>Standard deviation</li>
136          * </ul>
137          * @throws NullArgumentException if {@code observations} is
138          * {@code null}.
139          * @throws NumberIsTooSmallException if there are less than 3
140          * observations.
141          */
142         @Override
143         public double[] guess(Collection<WeightedObservedPoint> observations) {
144             if (observations == null) {
145                 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
146             }
147             if (observations.size() < 3) {
148                 throw new NumberIsTooSmallException(observations.size(), 3, true);
149             }
150 
151             final List<WeightedObservedPoint> sorted = sortObservations(observations);
152             return basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
153         }
154 
155         /**
156          * Guesses the parameters based on the specified observed points.
157          *
158          * @param points Observed points, sorted.
159          * @return the guessed parameters (normalization factor, mean and
160          * sigma).
161          */
162         private double[] basicGuess(WeightedObservedPoint[] points) {
163             final int maxYIdx = findMaxY(points);
164             final double n = points[maxYIdx].getY();
165 
166             double fwhmApprox;
167             try {
168                 final double halfY = 0.5 * n;
169                 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
170                 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
171                 fwhmApprox = fwhmX2 - fwhmX1;
172             } catch (OutOfRangeException e) {
173                 // TODO: Exceptions should not be used for flow control.
174                 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
175             }
176             final double s = fwhmApprox / (2 * JdkMath.sqrt(2 * JdkMath.log(2)));
177 
178             return new double[] { n, points[maxYIdx].getX(), s };
179         }
180     }
181 }