001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.fitting;
018
019import java.util.ArrayList;
020import java.util.Collection;
021import java.util.Collections;
022import java.util.Comparator;
023import java.util.List;
024
025import org.apache.commons.math3.analysis.function.Gaussian;
026import org.apache.commons.math3.exception.NotStrictlyPositiveException;
027import org.apache.commons.math3.exception.NullArgumentException;
028import org.apache.commons.math3.exception.NumberIsTooSmallException;
029import org.apache.commons.math3.exception.OutOfRangeException;
030import org.apache.commons.math3.exception.ZeroException;
031import org.apache.commons.math3.exception.util.LocalizedFormats;
032import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
033import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
034import org.apache.commons.math3.linear.DiagonalMatrix;
035import org.apache.commons.math3.util.FastMath;
036
/**
 * Fits points to a {@link
 * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian}
 * function.
 * <br/>
 * The {@link #withStartPoint(double[]) initial guess values} must be passed
 * in the following order:
 * <ul>
 *  <li>Normalization</li>
 *  <li>Mean</li>
 *  <li>Sigma</li>
 * </ul>
 * The optimal values will be returned in the same order.
 *
 * <p>
 * Usage example:
 * <pre>
 *   WeightedObservedPoints obs = new WeightedObservedPoints();
 *   obs.add(4.0254623,  531026.0);
 *   obs.add(4.03128248, 984167.0);
 *   obs.add(4.03839603, 1887233.0);
 *   obs.add(4.04421621, 2687152.0);
 *   obs.add(4.05132976, 3461228.0);
 *   obs.add(4.05326982, 3580526.0);
 *   obs.add(4.05779662, 3439750.0);
 *   obs.add(4.0636168,  2877648.0);
 *   obs.add(4.06943698, 2175960.0);
 *   obs.add(4.07525716, 1447024.0);
 *   obs.add(4.08237071, 717104.0);
 *   obs.add(4.08366408, 620014.0);
 *   double[] parameters = GaussianCurveFitter.create().fit(obs.toList());
 * </pre>
 *
 * @since 3.3
 */
072public class GaussianCurveFitter extends AbstractCurveFitter {
073    /** Parametric function to be fitted. */
074    private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
075            /** {@inheritDoc} */
076            @Override
077            public double value(double x, double ... p) {
078                double v = Double.POSITIVE_INFINITY;
079                try {
080                    v = super.value(x, p);
081                } catch (NotStrictlyPositiveException e) { // NOPMD
082                    // Do nothing.
083                }
084                return v;
085            }
086
087            /** {@inheritDoc} */
088            @Override
089            public double[] gradient(double x, double ... p) {
090                double[] v = { Double.POSITIVE_INFINITY,
091                               Double.POSITIVE_INFINITY,
092                               Double.POSITIVE_INFINITY };
093                try {
094                    v = super.gradient(x, p);
095                } catch (NotStrictlyPositiveException e) { // NOPMD
096                    // Do nothing.
097                }
098                return v;
099            }
100        };
101    /** Initial guess. */
102    private final double[] initialGuess;
103    /** Maximum number of iterations of the optimization algorithm. */
104    private final int maxIter;
105
106    /**
107     * Contructor used by the factory methods.
108     *
109     * @param initialGuess Initial guess. If set to {@code null}, the initial guess
110     * will be estimated using the {@link ParameterGuesser}.
111     * @param maxIter Maximum number of iterations of the optimization algorithm.
112     */
113    private GaussianCurveFitter(double[] initialGuess,
114                                int maxIter) {
115        this.initialGuess = initialGuess;
116        this.maxIter = maxIter;
117    }
118
119    /**
120     * Creates a default curve fitter.
121     * The initial guess for the parameters will be {@link ParameterGuesser}
122     * computed automatically, and the maximum number of iterations of the
123     * optimization algorithm is set to {@link Integer#MAX_VALUE}.
124     *
125     * @return a curve fitter.
126     *
127     * @see #withStartPoint(double[])
128     * @see #withMaxIterations(int)
129     */
130    public static GaussianCurveFitter create() {
131        return new GaussianCurveFitter(null, Integer.MAX_VALUE);
132    }
133
134    /**
135     * Configure the start point (initial guess).
136     * @param newStart new start point (initial guess)
137     * @return a new instance.
138     */
139    public GaussianCurveFitter withStartPoint(double[] newStart) {
140        return new GaussianCurveFitter(newStart.clone(),
141                                       maxIter);
142    }
143
144    /**
145     * Configure the maximum number of iterations.
146     * @param newMaxIter maximum number of iterations
147     * @return a new instance.
148     */
149    public GaussianCurveFitter withMaxIterations(int newMaxIter) {
150        return new GaussianCurveFitter(initialGuess,
151                                       newMaxIter);
152    }
153
154    /** {@inheritDoc} */
155    @Override
156    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
157
158        // Prepare least-squares problem.
159        final int len = observations.size();
160        final double[] target  = new double[len];
161        final double[] weights = new double[len];
162
163        int i = 0;
164        for (WeightedObservedPoint obs : observations) {
165            target[i]  = obs.getY();
166            weights[i] = obs.getWeight();
167            ++i;
168        }
169
170        final AbstractCurveFitter.TheoreticalValuesFunction model =
171                new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
172
173        final double[] startPoint = initialGuess != null ?
174            initialGuess :
175            // Compute estimation.
176            new ParameterGuesser(observations).guess();
177
178        // Return a new least squares problem set up to fit a Gaussian curve to the
179        // observed points.
180        return new LeastSquaresBuilder().
181                maxEvaluations(Integer.MAX_VALUE).
182                maxIterations(maxIter).
183                start(startPoint).
184                target(target).
185                weight(new DiagonalMatrix(weights)).
186                model(model.getModelFunction(), model.getModelFunctionJacobian()).
187                build();
188
189    }
190
191    /**
192     * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
193     * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
194     * based on the specified observed points.
195     */
196    public static class ParameterGuesser {
197        /** Normalization factor. */
198        private final double norm;
199        /** Mean. */
200        private final double mean;
201        /** Standard deviation. */
202        private final double sigma;
203
204        /**
205         * Constructs instance with the specified observed points.
206         *
207         * @param observations Observed points from which to guess the
208         * parameters of the Gaussian.
209         * @throws NullArgumentException if {@code observations} is
210         * {@code null}.
211         * @throws NumberIsTooSmallException if there are less than 3
212         * observations.
213         */
214        public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
215            if (observations == null) {
216                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
217            }
218            if (observations.size() < 3) {
219                throw new NumberIsTooSmallException(observations.size(), 3, true);
220            }
221
222            final List<WeightedObservedPoint> sorted = sortObservations(observations);
223            final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
224
225            norm = params[0];
226            mean = params[1];
227            sigma = params[2];
228        }
229
230        /**
231         * Gets an estimation of the parameters.
232         *
233         * @return the guessed parameters, in the following order:
234         * <ul>
235         *  <li>Normalization factor</li>
236         *  <li>Mean</li>
237         *  <li>Standard deviation</li>
238         * </ul>
239         */
240        public double[] guess() {
241            return new double[] { norm, mean, sigma };
242        }
243
244        /**
245         * Sort the observations.
246         *
247         * @param unsorted Input observations.
248         * @return the input observations, sorted.
249         */
250        private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
251            final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted);
252
253            final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
254                /** {@inheritDoc} */
255                public int compare(WeightedObservedPoint p1,
256                                   WeightedObservedPoint p2) {
257                    if (p1 == null && p2 == null) {
258                        return 0;
259                    }
260                    if (p1 == null) {
261                        return -1;
262                    }
263                    if (p2 == null) {
264                        return 1;
265                    }
266                    final int cmpX = Double.compare(p1.getX(), p2.getX());
267                    if (cmpX < 0) {
268                        return -1;
269                    }
270                    if (cmpX > 0) {
271                        return 1;
272                    }
273                    final int cmpY = Double.compare(p1.getY(), p2.getY());
274                    if (cmpY < 0) {
275                        return -1;
276                    }
277                    if (cmpY > 0) {
278                        return 1;
279                    }
280                    final int cmpW = Double.compare(p1.getWeight(), p2.getWeight());
281                    if (cmpW < 0) {
282                        return -1;
283                    }
284                    if (cmpW > 0) {
285                        return 1;
286                    }
287                    return 0;
288                }
289            };
290
291            Collections.sort(observations, cmp);
292            return observations;
293        }
294
295        /**
296         * Guesses the parameters based on the specified observed points.
297         *
298         * @param points Observed points, sorted.
299         * @return the guessed parameters (normalization factor, mean and
300         * sigma).
301         */
302        private double[] basicGuess(WeightedObservedPoint[] points) {
303            final int maxYIdx = findMaxY(points);
304            final double n = points[maxYIdx].getY();
305            final double m = points[maxYIdx].getX();
306
307            double fwhmApprox;
308            try {
309                final double halfY = n + ((m - n) / 2);
310                final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
311                final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
312                fwhmApprox = fwhmX2 - fwhmX1;
313            } catch (OutOfRangeException e) {
314                // TODO: Exceptions should not be used for flow control.
315                fwhmApprox = points[points.length - 1].getX() - points[0].getX();
316            }
317            final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
318
319            return new double[] { n, m, s };
320        }
321
322        /**
323         * Finds index of point in specified points with the largest Y.
324         *
325         * @param points Points to search.
326         * @return the index in specified points array.
327         */
328        private int findMaxY(WeightedObservedPoint[] points) {
329            int maxYIdx = 0;
330            for (int i = 1; i < points.length; i++) {
331                if (points[i].getY() > points[maxYIdx].getY()) {
332                    maxYIdx = i;
333                }
334            }
335            return maxYIdx;
336        }
337
338        /**
339         * Interpolates using the specified points to determine X at the
340         * specified Y.
341         *
342         * @param points Points to use for interpolation.
343         * @param startIdx Index within points from which to start the search for
344         * interpolation bounds points.
345         * @param idxStep Index step for searching interpolation bounds points.
346         * @param y Y value for which X should be determined.
347         * @return the value of X for the specified Y.
348         * @throws ZeroException if {@code idxStep} is 0.
349         * @throws OutOfRangeException if specified {@code y} is not within the
350         * range of the specified {@code points}.
351         */
352        private double interpolateXAtY(WeightedObservedPoint[] points,
353                                       int startIdx,
354                                       int idxStep,
355                                       double y)
356            throws OutOfRangeException {
357            if (idxStep == 0) {
358                throw new ZeroException();
359            }
360            final WeightedObservedPoint[] twoPoints
361                = getInterpolationPointsForY(points, startIdx, idxStep, y);
362            final WeightedObservedPoint p1 = twoPoints[0];
363            final WeightedObservedPoint p2 = twoPoints[1];
364            if (p1.getY() == y) {
365                return p1.getX();
366            }
367            if (p2.getY() == y) {
368                return p2.getX();
369            }
370            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
371                                (p2.getY() - p1.getY()));
372        }
373
374        /**
375         * Gets the two bounding interpolation points from the specified points
376         * suitable for determining X at the specified Y.
377         *
378         * @param points Points to use for interpolation.
379         * @param startIdx Index within points from which to start search for
380         * interpolation bounds points.
381         * @param idxStep Index step for search for interpolation bounds points.
382         * @param y Y value for which X should be determined.
383         * @return the array containing two points suitable for determining X at
384         * the specified Y.
385         * @throws ZeroException if {@code idxStep} is 0.
386         * @throws OutOfRangeException if specified {@code y} is not within the
387         * range of the specified {@code points}.
388         */
389        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
390                                                                   int startIdx,
391                                                                   int idxStep,
392                                                                   double y)
393            throws OutOfRangeException {
394            if (idxStep == 0) {
395                throw new ZeroException();
396            }
397            for (int i = startIdx;
398                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
399                 i += idxStep) {
400                final WeightedObservedPoint p1 = points[i];
401                final WeightedObservedPoint p2 = points[i + idxStep];
402                if (isBetween(y, p1.getY(), p2.getY())) {
403                    if (idxStep < 0) {
404                        return new WeightedObservedPoint[] { p2, p1 };
405                    } else {
406                        return new WeightedObservedPoint[] { p1, p2 };
407                    }
408                }
409            }
410
411            // Boundaries are replaced by dummy values because the raised
412            // exception is caught and the message never displayed.
413            // TODO: Exceptions should not be used for flow control.
414            throw new OutOfRangeException(y,
415                                          Double.NEGATIVE_INFINITY,
416                                          Double.POSITIVE_INFINITY);
417        }
418
419        /**
420         * Determines whether a value is between two other values.
421         *
422         * @param value Value to test whether it is between {@code boundary1}
423         * and {@code boundary2}.
424         * @param boundary1 One end of the range.
425         * @param boundary2 Other end of the range.
426         * @return {@code true} if {@code value} is between {@code boundary1} and
427         * {@code boundary2} (inclusive), {@code false} otherwise.
428         */
429        private boolean isBetween(double value,
430                                  double boundary1,
431                                  double boundary2) {
432            return (value >= boundary1 && value <= boundary2) ||
433                (value >= boundary2 && value <= boundary1);
434        }
435    }
436}