001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.fitting;
018
019import java.util.ArrayList;
020import java.util.Collection;
021import java.util.Collections;
022import java.util.Comparator;
023import java.util.List;
024
025import org.apache.commons.math3.analysis.function.Gaussian;
026import org.apache.commons.math3.exception.NotStrictlyPositiveException;
027import org.apache.commons.math3.exception.NullArgumentException;
028import org.apache.commons.math3.exception.NumberIsTooSmallException;
029import org.apache.commons.math3.exception.OutOfRangeException;
030import org.apache.commons.math3.exception.ZeroException;
031import org.apache.commons.math3.exception.util.LocalizedFormats;
032import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
033import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
034import org.apache.commons.math3.linear.DiagonalMatrix;
035import org.apache.commons.math3.util.FastMath;
036
037/**
038 * Fits points to a {@link
039 * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian}
040 * function.
041 * <br/>
042 * The {@link #withStartPoint(double[]) initial guess values} must be passed
043 * in the following order:
044 * <ul>
045 *  <li>Normalization</li>
046 *  <li>Mean</li>
047 *  <li>Sigma</li>
048 * </ul>
049 * The optimal values will be returned in the same order.
050 *
051 * <p>
052 * Usage example:
053 * <pre>
054 *   WeightedObservedPoints obs = new WeightedObservedPoints();
055 *   obs.add(4.0254623,  531026.0);
056 *   obs.add(4.03128248, 984167.0);
057 *   obs.add(4.03839603, 1887233.0);
058 *   obs.add(4.04421621, 2687152.0);
059 *   obs.add(4.05132976, 3461228.0);
060 *   obs.add(4.05326982, 3580526.0);
061 *   obs.add(4.05779662, 3439750.0);
062 *   obs.add(4.0636168,  2877648.0);
063 *   obs.add(4.06943698, 2175960.0);
064 *   obs.add(4.07525716, 1447024.0);
065 *   obs.add(4.08237071, 717104.0);
066 *   obs.add(4.08366408, 620014.0);
067 *   double[] parameters = GaussianCurveFitter.create().fit(obs.toList());
068 * </pre>
069 *
070 * @since 3.3
071 */
072public class GaussianCurveFitter extends AbstractCurveFitter {
073    /** Parametric function to be fitted. */
074    private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
075            @Override
076            public double value(double x, double ... p) {
077                double v = Double.POSITIVE_INFINITY;
078                try {
079                    v = super.value(x, p);
080                } catch (NotStrictlyPositiveException e) { // NOPMD
081                    // Do nothing.
082                }
083                return v;
084            }
085
086            @Override
087            public double[] gradient(double x, double ... p) {
088                double[] v = { Double.POSITIVE_INFINITY,
089                               Double.POSITIVE_INFINITY,
090                               Double.POSITIVE_INFINITY };
091                try {
092                    v = super.gradient(x, p);
093                } catch (NotStrictlyPositiveException e) { // NOPMD
094                    // Do nothing.
095                }
096                return v;
097            }
098        };
099    /** Initial guess. */
100    private final double[] initialGuess;
101    /** Maximum number of iterations of the optimization algorithm. */
102    private final int maxIter;
103
104    /**
105     * Contructor used by the factory methods.
106     *
107     * @param initialGuess Initial guess. If set to {@code null}, the initial guess
108     * will be estimated using the {@link ParameterGuesser}.
109     * @param maxIter Maximum number of iterations of the optimization algorithm.
110     */
111    private GaussianCurveFitter(double[] initialGuess,
112                                int maxIter) {
113        this.initialGuess = initialGuess;
114        this.maxIter = maxIter;
115    }
116
117    /**
118     * Creates a default curve fitter.
119     * The initial guess for the parameters will be {@link ParameterGuesser}
120     * computed automatically, and the maximum number of iterations of the
121     * optimization algorithm is set to {@link Integer#MAX_VALUE}.
122     *
123     * @return a curve fitter.
124     *
125     * @see #withStartPoint(double[])
126     * @see #withMaxIterations(int)
127     */
128    public static GaussianCurveFitter create() {
129        return new GaussianCurveFitter(null, Integer.MAX_VALUE);
130    }
131
132    /**
133     * Configure the start point (initial guess).
134     * @param newStart new start point (initial guess)
135     * @return a new instance.
136     */
137    public GaussianCurveFitter withStartPoint(double[] newStart) {
138        return new GaussianCurveFitter(newStart.clone(),
139                                       maxIter);
140    }
141
142    /**
143     * Configure the maximum number of iterations.
144     * @param newMaxIter maximum number of iterations
145     * @return a new instance.
146     */
147    public GaussianCurveFitter withMaxIterations(int newMaxIter) {
148        return new GaussianCurveFitter(initialGuess,
149                                       newMaxIter);
150    }
151
152    /** {@inheritDoc} */
153    @Override
154    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
155
156        // Prepare least-squares problem.
157        final int len = observations.size();
158        final double[] target  = new double[len];
159        final double[] weights = new double[len];
160
161        int i = 0;
162        for (WeightedObservedPoint obs : observations) {
163            target[i]  = obs.getY();
164            weights[i] = obs.getWeight();
165            ++i;
166        }
167
168        final AbstractCurveFitter.TheoreticalValuesFunction model =
169                new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
170
171        final double[] startPoint = initialGuess != null ?
172            initialGuess :
173            // Compute estimation.
174            new ParameterGuesser(observations).guess();
175
176        // Return a new least squares problem set up to fit a Gaussian curve to the
177        // observed points.
178        return new LeastSquaresBuilder().
179                maxEvaluations(Integer.MAX_VALUE).
180                maxIterations(maxIter).
181                start(startPoint).
182                target(target).
183                weight(new DiagonalMatrix(weights)).
184                model(model.getModelFunction(), model.getModelFunctionJacobian()).
185                build();
186
187    }
188
189    /**
190     * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
191     * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
192     * based on the specified observed points.
193     */
194    public static class ParameterGuesser {
195        /** Normalization factor. */
196        private final double norm;
197        /** Mean. */
198        private final double mean;
199        /** Standard deviation. */
200        private final double sigma;
201
202        /**
203         * Constructs instance with the specified observed points.
204         *
205         * @param observations Observed points from which to guess the
206         * parameters of the Gaussian.
207         * @throws NullArgumentException if {@code observations} is
208         * {@code null}.
209         * @throws NumberIsTooSmallException if there are less than 3
210         * observations.
211         */
212        public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
213            if (observations == null) {
214                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
215            }
216            if (observations.size() < 3) {
217                throw new NumberIsTooSmallException(observations.size(), 3, true);
218            }
219
220            final List<WeightedObservedPoint> sorted = sortObservations(observations);
221            final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
222
223            norm = params[0];
224            mean = params[1];
225            sigma = params[2];
226        }
227
228        /**
229         * Gets an estimation of the parameters.
230         *
231         * @return the guessed parameters, in the following order:
232         * <ul>
233         *  <li>Normalization factor</li>
234         *  <li>Mean</li>
235         *  <li>Standard deviation</li>
236         * </ul>
237         */
238        public double[] guess() {
239            return new double[] { norm, mean, sigma };
240        }
241
242        /**
243         * Sort the observations.
244         *
245         * @param unsorted Input observations.
246         * @return the input observations, sorted.
247         */
248        private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
249            final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted);
250
251            final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
252                public int compare(WeightedObservedPoint p1,
253                                   WeightedObservedPoint p2) {
254                    if (p1 == null && p2 == null) {
255                        return 0;
256                    }
257                    if (p1 == null) {
258                        return -1;
259                    }
260                    if (p2 == null) {
261                        return 1;
262                    }
263                    if (p1.getX() < p2.getX()) {
264                        return -1;
265                    }
266                    if (p1.getX() > p2.getX()) {
267                        return 1;
268                    }
269                    if (p1.getY() < p2.getY()) {
270                        return -1;
271                    }
272                    if (p1.getY() > p2.getY()) {
273                        return 1;
274                    }
275                    if (p1.getWeight() < p2.getWeight()) {
276                        return -1;
277                    }
278                    if (p1.getWeight() > p2.getWeight()) {
279                        return 1;
280                    }
281                    return 0;
282                }
283            };
284
285            Collections.sort(observations, cmp);
286            return observations;
287        }
288
289        /**
290         * Guesses the parameters based on the specified observed points.
291         *
292         * @param points Observed points, sorted.
293         * @return the guessed parameters (normalization factor, mean and
294         * sigma).
295         */
296        private double[] basicGuess(WeightedObservedPoint[] points) {
297            final int maxYIdx = findMaxY(points);
298            final double n = points[maxYIdx].getY();
299            final double m = points[maxYIdx].getX();
300
301            double fwhmApprox;
302            try {
303                final double halfY = n + ((m - n) / 2);
304                final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
305                final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
306                fwhmApprox = fwhmX2 - fwhmX1;
307            } catch (OutOfRangeException e) {
308                // TODO: Exceptions should not be used for flow control.
309                fwhmApprox = points[points.length - 1].getX() - points[0].getX();
310            }
311            final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
312
313            return new double[] { n, m, s };
314        }
315
316        /**
317         * Finds index of point in specified points with the largest Y.
318         *
319         * @param points Points to search.
320         * @return the index in specified points array.
321         */
322        private int findMaxY(WeightedObservedPoint[] points) {
323            int maxYIdx = 0;
324            for (int i = 1; i < points.length; i++) {
325                if (points[i].getY() > points[maxYIdx].getY()) {
326                    maxYIdx = i;
327                }
328            }
329            return maxYIdx;
330        }
331
332        /**
333         * Interpolates using the specified points to determine X at the
334         * specified Y.
335         *
336         * @param points Points to use for interpolation.
337         * @param startIdx Index within points from which to start the search for
338         * interpolation bounds points.
339         * @param idxStep Index step for searching interpolation bounds points.
340         * @param y Y value for which X should be determined.
341         * @return the value of X for the specified Y.
342         * @throws ZeroException if {@code idxStep} is 0.
343         * @throws OutOfRangeException if specified {@code y} is not within the
344         * range of the specified {@code points}.
345         */
346        private double interpolateXAtY(WeightedObservedPoint[] points,
347                                       int startIdx,
348                                       int idxStep,
349                                       double y)
350            throws OutOfRangeException {
351            if (idxStep == 0) {
352                throw new ZeroException();
353            }
354            final WeightedObservedPoint[] twoPoints
355                = getInterpolationPointsForY(points, startIdx, idxStep, y);
356            final WeightedObservedPoint p1 = twoPoints[0];
357            final WeightedObservedPoint p2 = twoPoints[1];
358            if (p1.getY() == y) {
359                return p1.getX();
360            }
361            if (p2.getY() == y) {
362                return p2.getX();
363            }
364            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
365                                (p2.getY() - p1.getY()));
366        }
367
368        /**
369         * Gets the two bounding interpolation points from the specified points
370         * suitable for determining X at the specified Y.
371         *
372         * @param points Points to use for interpolation.
373         * @param startIdx Index within points from which to start search for
374         * interpolation bounds points.
375         * @param idxStep Index step for search for interpolation bounds points.
376         * @param y Y value for which X should be determined.
377         * @return the array containing two points suitable for determining X at
378         * the specified Y.
379         * @throws ZeroException if {@code idxStep} is 0.
380         * @throws OutOfRangeException if specified {@code y} is not within the
381         * range of the specified {@code points}.
382         */
383        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
384                                                                   int startIdx,
385                                                                   int idxStep,
386                                                                   double y)
387            throws OutOfRangeException {
388            if (idxStep == 0) {
389                throw new ZeroException();
390            }
391            for (int i = startIdx;
392                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
393                 i += idxStep) {
394                final WeightedObservedPoint p1 = points[i];
395                final WeightedObservedPoint p2 = points[i + idxStep];
396                if (isBetween(y, p1.getY(), p2.getY())) {
397                    if (idxStep < 0) {
398                        return new WeightedObservedPoint[] { p2, p1 };
399                    } else {
400                        return new WeightedObservedPoint[] { p1, p2 };
401                    }
402                }
403            }
404
405            // Boundaries are replaced by dummy values because the raised
406            // exception is caught and the message never displayed.
407            // TODO: Exceptions should not be used for flow control.
408            throw new OutOfRangeException(y,
409                                          Double.NEGATIVE_INFINITY,
410                                          Double.POSITIVE_INFINITY);
411        }
412
413        /**
414         * Determines whether a value is between two other values.
415         *
416         * @param value Value to test whether it is between {@code boundary1}
417         * and {@code boundary2}.
418         * @param boundary1 One end of the range.
419         * @param boundary2 Other end of the range.
420         * @return {@code true} if {@code value} is between {@code boundary1} and
421         * {@code boundary2} (inclusive), {@code false} otherwise.
422         */
423        private boolean isBetween(double value,
424                                  double boundary1,
425                                  double boundary2) {
426            return (value >= boundary1 && value <= boundary2) ||
427                (value >= boundary2 && value <= boundary1);
428        }
429    }
430}