001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.fitting;
018
019import java.util.ArrayList;
020import java.util.Collection;
021import java.util.Collections;
022import java.util.Comparator;
023import java.util.List;
024
025import org.apache.commons.math3.analysis.function.Gaussian;
026import org.apache.commons.math3.exception.NotStrictlyPositiveException;
027import org.apache.commons.math3.exception.NullArgumentException;
028import org.apache.commons.math3.exception.NumberIsTooSmallException;
029import org.apache.commons.math3.exception.OutOfRangeException;
030import org.apache.commons.math3.exception.ZeroException;
031import org.apache.commons.math3.exception.util.LocalizedFormats;
032import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
033import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
034import org.apache.commons.math3.linear.DiagonalMatrix;
035import org.apache.commons.math3.util.FastMath;
036
037/**
038 * Fits points to a {@link
039 * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian}
040 * function.
041 * <br/>
042 * The {@link #withStartPoint(double[]) initial guess values} must be passed
043 * in the following order:
044 * <ul>
045 *  <li>Normalization</li>
046 *  <li>Mean</li>
047 *  <li>Sigma</li>
048 * </ul>
049 * The optimal values will be returned in the same order.
050 *
051 * <p>
052 * Usage example:
053 * <pre>
054 *   WeightedObservedPoints obs = new WeightedObservedPoints();
055 *   obs.add(4.0254623,  531026.0);
056 *   obs.add(4.03128248, 984167.0);
057 *   obs.add(4.03839603, 1887233.0);
058 *   obs.add(4.04421621, 2687152.0);
059 *   obs.add(4.05132976, 3461228.0);
060 *   obs.add(4.05326982, 3580526.0);
061 *   obs.add(4.05779662, 3439750.0);
062 *   obs.add(4.0636168,  2877648.0);
063 *   obs.add(4.06943698, 2175960.0);
064 *   obs.add(4.07525716, 1447024.0);
065 *   obs.add(4.08237071, 717104.0);
066 *   obs.add(4.08366408, 620014.0);
067 *   double[] parameters = GaussianCurveFitter.create().fit(obs);
068 * </pre>
069 *
070 * @version $Id: GaussianCurveFitter.ParameterGuesser.html 908881 2014-05-15 07:10:28Z luc $
071 * @since 3.3
072 */
073public class GaussianCurveFitter extends AbstractCurveFitter {
074    /** Parametric function to be fitted. */
075    private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
076            @Override
077            public double value(double x, double ... p) {
078                double v = Double.POSITIVE_INFINITY;
079                try {
080                    v = super.value(x, p);
081                } catch (NotStrictlyPositiveException e) { // NOPMD
082                    // Do nothing.
083                }
084                return v;
085            }
086
087            @Override
088            public double[] gradient(double x, double ... p) {
089                double[] v = { Double.POSITIVE_INFINITY,
090                               Double.POSITIVE_INFINITY,
091                               Double.POSITIVE_INFINITY };
092                try {
093                    v = super.gradient(x, p);
094                } catch (NotStrictlyPositiveException e) { // NOPMD
095                    // Do nothing.
096                }
097                return v;
098            }
099        };
100    /** Initial guess. */
101    private final double[] initialGuess;
102    /** Maximum number of iterations of the optimization algorithm. */
103    private final int maxIter;
104
105    /**
106     * Contructor used by the factory methods.
107     *
108     * @param initialGuess Initial guess. If set to {@code null}, the initial guess
109     * will be estimated using the {@link ParameterGuesser}.
110     * @param maxIter Maximum number of iterations of the optimization algorithm.
111     */
112    private GaussianCurveFitter(double[] initialGuess,
113                                int maxIter) {
114        this.initialGuess = initialGuess;
115        this.maxIter = maxIter;
116    }
117
118    /**
119     * Creates a default curve fitter.
120     * The initial guess for the parameters will be {@link ParameterGuesser}
121     * computed automatically, and the maximum number of iterations of the
122     * optimization algorithm is set to {@link Integer#MAX_VALUE}.
123     *
124     * @return a curve fitter.
125     *
126     * @see #withStartPoint(double[])
127     * @see #withMaxIterations(int)
128     */
129    public static GaussianCurveFitter create() {
130        return new GaussianCurveFitter(null, Integer.MAX_VALUE);
131    }
132
133    /**
134     * Configure the start point (initial guess).
135     * @param newStart new start point (initial guess)
136     * @return a new instance.
137     */
138    public GaussianCurveFitter withStartPoint(double[] newStart) {
139        return new GaussianCurveFitter(newStart.clone(),
140                                       maxIter);
141    }
142
143    /**
144     * Configure the maximum number of iterations.
145     * @param newMaxIter maximum number of iterations
146     * @return a new instance.
147     */
148    public GaussianCurveFitter withMaxIterations(int newMaxIter) {
149        return new GaussianCurveFitter(initialGuess,
150                                       newMaxIter);
151    }
152
153    /** {@inheritDoc} */
154    @Override
155    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
156
157        // Prepare least-squares problem.
158        final int len = observations.size();
159        final double[] target  = new double[len];
160        final double[] weights = new double[len];
161
162        int i = 0;
163        for (WeightedObservedPoint obs : observations) {
164            target[i]  = obs.getY();
165            weights[i] = obs.getWeight();
166            ++i;
167        }
168
169        final AbstractCurveFitter.TheoreticalValuesFunction model =
170                new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
171
172        final double[] startPoint = initialGuess != null ?
173            initialGuess :
174            // Compute estimation.
175            new ParameterGuesser(observations).guess();
176
177        // Return a new least squares problem set up to fit a Gaussian curve to the
178        // observed points.
179        return new LeastSquaresBuilder().
180                maxEvaluations(Integer.MAX_VALUE).
181                maxIterations(maxIter).
182                start(startPoint).
183                target(target).
184                weight(new DiagonalMatrix(weights)).
185                model(model.getModelFunction(), model.getModelFunctionJacobian()).
186                build();
187
188    }
189
190    /**
191     * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
192     * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
193     * based on the specified observed points.
194     */
195    public static class ParameterGuesser {
196        /** Normalization factor. */
197        private final double norm;
198        /** Mean. */
199        private final double mean;
200        /** Standard deviation. */
201        private final double sigma;
202
203        /**
204         * Constructs instance with the specified observed points.
205         *
206         * @param observations Observed points from which to guess the
207         * parameters of the Gaussian.
208         * @throws NullArgumentException if {@code observations} is
209         * {@code null}.
210         * @throws NumberIsTooSmallException if there are less than 3
211         * observations.
212         */
213        public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
214            if (observations == null) {
215                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
216            }
217            if (observations.size() < 3) {
218                throw new NumberIsTooSmallException(observations.size(), 3, true);
219            }
220
221            final List<WeightedObservedPoint> sorted = sortObservations(observations);
222            final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
223
224            norm = params[0];
225            mean = params[1];
226            sigma = params[2];
227        }
228
229        /**
230         * Gets an estimation of the parameters.
231         *
232         * @return the guessed parameters, in the following order:
233         * <ul>
234         *  <li>Normalization factor</li>
235         *  <li>Mean</li>
236         *  <li>Standard deviation</li>
237         * </ul>
238         */
239        public double[] guess() {
240            return new double[] { norm, mean, sigma };
241        }
242
243        /**
244         * Sort the observations.
245         *
246         * @param unsorted Input observations.
247         * @return the input observations, sorted.
248         */
249        private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
250            final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted);
251
252            final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
253                public int compare(WeightedObservedPoint p1,
254                                   WeightedObservedPoint p2) {
255                    if (p1 == null && p2 == null) {
256                        return 0;
257                    }
258                    if (p1 == null) {
259                        return -1;
260                    }
261                    if (p2 == null) {
262                        return 1;
263                    }
264                    if (p1.getX() < p2.getX()) {
265                        return -1;
266                    }
267                    if (p1.getX() > p2.getX()) {
268                        return 1;
269                    }
270                    if (p1.getY() < p2.getY()) {
271                        return -1;
272                    }
273                    if (p1.getY() > p2.getY()) {
274                        return 1;
275                    }
276                    if (p1.getWeight() < p2.getWeight()) {
277                        return -1;
278                    }
279                    if (p1.getWeight() > p2.getWeight()) {
280                        return 1;
281                    }
282                    return 0;
283                }
284            };
285
286            Collections.sort(observations, cmp);
287            return observations;
288        }
289
290        /**
291         * Guesses the parameters based on the specified observed points.
292         *
293         * @param points Observed points, sorted.
294         * @return the guessed parameters (normalization factor, mean and
295         * sigma).
296         */
297        private double[] basicGuess(WeightedObservedPoint[] points) {
298            final int maxYIdx = findMaxY(points);
299            final double n = points[maxYIdx].getY();
300            final double m = points[maxYIdx].getX();
301
302            double fwhmApprox;
303            try {
304                final double halfY = n + ((m - n) / 2);
305                final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
306                final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
307                fwhmApprox = fwhmX2 - fwhmX1;
308            } catch (OutOfRangeException e) {
309                // TODO: Exceptions should not be used for flow control.
310                fwhmApprox = points[points.length - 1].getX() - points[0].getX();
311            }
312            final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
313
314            return new double[] { n, m, s };
315        }
316
317        /**
318         * Finds index of point in specified points with the largest Y.
319         *
320         * @param points Points to search.
321         * @return the index in specified points array.
322         */
323        private int findMaxY(WeightedObservedPoint[] points) {
324            int maxYIdx = 0;
325            for (int i = 1; i < points.length; i++) {
326                if (points[i].getY() > points[maxYIdx].getY()) {
327                    maxYIdx = i;
328                }
329            }
330            return maxYIdx;
331        }
332
333        /**
334         * Interpolates using the specified points to determine X at the
335         * specified Y.
336         *
337         * @param points Points to use for interpolation.
338         * @param startIdx Index within points from which to start the search for
339         * interpolation bounds points.
340         * @param idxStep Index step for searching interpolation bounds points.
341         * @param y Y value for which X should be determined.
342         * @return the value of X for the specified Y.
343         * @throws ZeroException if {@code idxStep} is 0.
344         * @throws OutOfRangeException if specified {@code y} is not within the
345         * range of the specified {@code points}.
346         */
347        private double interpolateXAtY(WeightedObservedPoint[] points,
348                                       int startIdx,
349                                       int idxStep,
350                                       double y)
351            throws OutOfRangeException {
352            if (idxStep == 0) {
353                throw new ZeroException();
354            }
355            final WeightedObservedPoint[] twoPoints
356                = getInterpolationPointsForY(points, startIdx, idxStep, y);
357            final WeightedObservedPoint p1 = twoPoints[0];
358            final WeightedObservedPoint p2 = twoPoints[1];
359            if (p1.getY() == y) {
360                return p1.getX();
361            }
362            if (p2.getY() == y) {
363                return p2.getX();
364            }
365            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
366                                (p2.getY() - p1.getY()));
367        }
368
369        /**
370         * Gets the two bounding interpolation points from the specified points
371         * suitable for determining X at the specified Y.
372         *
373         * @param points Points to use for interpolation.
374         * @param startIdx Index within points from which to start search for
375         * interpolation bounds points.
376         * @param idxStep Index step for search for interpolation bounds points.
377         * @param y Y value for which X should be determined.
378         * @return the array containing two points suitable for determining X at
379         * the specified Y.
380         * @throws ZeroException if {@code idxStep} is 0.
381         * @throws OutOfRangeException if specified {@code y} is not within the
382         * range of the specified {@code points}.
383         */
384        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
385                                                                   int startIdx,
386                                                                   int idxStep,
387                                                                   double y)
388            throws OutOfRangeException {
389            if (idxStep == 0) {
390                throw new ZeroException();
391            }
392            for (int i = startIdx;
393                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
394                 i += idxStep) {
395                final WeightedObservedPoint p1 = points[i];
396                final WeightedObservedPoint p2 = points[i + idxStep];
397                if (isBetween(y, p1.getY(), p2.getY())) {
398                    if (idxStep < 0) {
399                        return new WeightedObservedPoint[] { p2, p1 };
400                    } else {
401                        return new WeightedObservedPoint[] { p1, p2 };
402                    }
403                }
404            }
405
406            // Boundaries are replaced by dummy values because the raised
407            // exception is caught and the message never displayed.
408            // TODO: Exceptions should not be used for flow control.
409            throw new OutOfRangeException(y,
410                                          Double.NEGATIVE_INFINITY,
411                                          Double.POSITIVE_INFINITY);
412        }
413
414        /**
415         * Determines whether a value is between two other values.
416         *
417         * @param value Value to test whether it is between {@code boundary1}
418         * and {@code boundary2}.
419         * @param boundary1 One end of the range.
420         * @param boundary2 Other end of the range.
421         * @return {@code true} if {@code value} is between {@code boundary1} and
422         * {@code boundary2} (inclusive), {@code false} otherwise.
423         */
424        private boolean isBetween(double value,
425                                  double boundary1,
426                                  double boundary2) {
427            return (value >= boundary1 && value <= boundary2) ||
428                (value >= boundary2 && value <= boundary1);
429        }
430    }
431}