001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.fitting;
018
019import java.util.Collections;
020import java.util.Collection;
021import java.util.Comparator;
022import java.util.List;
023import java.util.ArrayList;
024
025import org.apache.commons.math4.legacy.exception.ZeroException;
026import org.apache.commons.math4.legacy.exception.OutOfRangeException;
027import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
028import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
029import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
030import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
031
032/**
033 * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
034 *
035 * @since 3.4
036 */
037public class SimpleCurveFitter extends AbstractCurveFitter {
038    /** Function to fit. */
039    private final ParametricUnivariateFunction function;
040    /** Initial guess for the parameters. */
041    private final double[] initialGuess;
042    /** Parameter guesser. */
043    private final ParameterGuesser guesser;
044    /** Maximum number of iterations of the optimization algorithm. */
045    private final int maxIter;
046
047    /**
048     * Constructor used by the factory methods.
049     *
050     * @param function Function to fit.
051     * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
052     * be consistent with the number of parameters of the {@code function} to fit.
053     * @param guesser Method for providing an initial guess (if {@code initialGuess}
054     * is {@code null}).
055     * @param maxIter Maximum number of iterations of the optimization algorithm.
056     */
057    protected SimpleCurveFitter(ParametricUnivariateFunction function,
058                                double[] initialGuess,
059                                ParameterGuesser guesser,
060                                int maxIter) {
061        this.function = function;
062        this.initialGuess = initialGuess;
063        this.guesser = guesser;
064        this.maxIter = maxIter;
065    }
066
067    /**
068     * Creates a curve fitter.
069     * The maximum number of iterations of the optimization algorithm is set
070     * to {@link Integer#MAX_VALUE}.
071     *
072     * @param f Function to fit.
073     * @param start Initial guess for the parameters.  Cannot be {@code null}.
074     * Its length must be consistent with the number of parameters of the
075     * function to fit.
076     * @return a curve fitter.
077     *
078     * @see #withStartPoint(double[])
079     * @see #withMaxIterations(int)
080     */
081    public static SimpleCurveFitter create(ParametricUnivariateFunction f,
082                                           double[] start) {
083        return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE);
084    }
085
086    /**
087     * Creates a curve fitter.
088     * The maximum number of iterations of the optimization algorithm is set
089     * to {@link Integer#MAX_VALUE}.
090     *
091     * @param f Function to fit.
092     * @param guesser Method for providing an initial guess.
093     * @return a curve fitter.
094     *
095     * @see #withStartPoint(double[])
096     * @see #withMaxIterations(int)
097     */
098    public static SimpleCurveFitter create(ParametricUnivariateFunction f,
099                                           ParameterGuesser guesser) {
100        return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
101    }
102
103    /**
104     * Configure the start point (initial guess).
105     * @param newStart new start point (initial guess)
106     * @return a new instance.
107     */
108    public SimpleCurveFitter withStartPoint(double[] newStart) {
109        return new SimpleCurveFitter(function,
110                                     newStart.clone(),
111                                     null,
112                                     maxIter);
113    }
114
115    /**
116     * Configure the maximum number of iterations.
117     * @param newMaxIter maximum number of iterations
118     * @return a new instance.
119     */
120    public SimpleCurveFitter withMaxIterations(int newMaxIter) {
121        return new SimpleCurveFitter(function,
122                                     initialGuess,
123                                     guesser,
124                                     newMaxIter);
125    }
126
127    /** {@inheritDoc} */
128    @Override
129    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
130        // Prepare least-squares problem.
131        final int len = observations.size();
132        final double[] target  = new double[len];
133        final double[] weights = new double[len];
134
135        int count = 0;
136        for (WeightedObservedPoint obs : observations) {
137            target[count]  = obs.getY();
138            weights[count] = obs.getWeight();
139            ++count;
140        }
141
142        final AbstractCurveFitter.TheoreticalValuesFunction model
143            = new AbstractCurveFitter.TheoreticalValuesFunction(function,
144                                                                observations);
145
146        final double[] startPoint = initialGuess != null ?
147            initialGuess :
148            // Compute estimation.
149            guesser.guess(observations);
150
151        // Create an optimizer for fitting the curve to the observed points.
152        return new LeastSquaresBuilder().
153                maxEvaluations(Integer.MAX_VALUE).
154                maxIterations(maxIter).
155                start(startPoint).
156                target(target).
157                weight(new DiagonalMatrix(weights)).
158                model(model.getModelFunction(), model.getModelFunctionJacobian()).
159                build();
160    }
161
162    /**
163     * Guesses the parameters.
164     */
165    public abstract static class ParameterGuesser {
166        /** Comparator. */
167        private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
168                /** {@inheritDoc} */
169                @Override
170                public int compare(WeightedObservedPoint p1,
171                                   WeightedObservedPoint p2) {
172                    if (p1 == null && p2 == null) {
173                        return 0;
174                    }
175                    if (p1 == null) {
176                        return -1;
177                    }
178                    if (p2 == null) {
179                        return 1;
180                    }
181                    int comp = Double.compare(p1.getX(), p2.getX());
182                    if (comp != 0) {
183                        return comp;
184                    }
185                    comp = Double.compare(p1.getY(), p2.getY());
186                    if (comp != 0) {
187                        return comp;
188                    }
189                    return Double.compare(p1.getWeight(), p2.getWeight());
190                }
191            };
192
193        /**
194         * Computes an estimation of the parameters.
195         *
196         * @param obs Observations.
197         * @return the guessed parameters.
198         */
199        public abstract double[] guess(Collection<WeightedObservedPoint> obs);
200
201        /**
202         * Sort the observations.
203         *
204         * @param unsorted Input observations.
205         * @return the input observations, sorted.
206         */
207        protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
208            final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
209            Collections.sort(observations, CMP);
210            return observations;
211        }
212
213        /**
214         * Finds index of point in specified points with the largest Y.
215         *
216         * @param points Points to search.
217         * @return the index in specified points array.
218         */
219        protected int findMaxY(WeightedObservedPoint[] points) {
220            int maxYIdx = 0;
221            for (int i = 1; i < points.length; i++) {
222                if (points[i].getY() > points[maxYIdx].getY()) {
223                    maxYIdx = i;
224                }
225            }
226            return maxYIdx;
227        }
228
229        /**
230         * Interpolates using the specified points to determine X at the
231         * specified Y.
232         *
233         * @param points Points to use for interpolation.
234         * @param startIdx Index within points from which to start the search for
235         * interpolation bounds points.
236         * @param idxStep Index step for searching interpolation bounds points.
237         * @param y Y value for which X should be determined.
238         * @return the value of X for the specified Y.
239         * @throws ZeroException if {@code idxStep} is 0.
240         * @throws OutOfRangeException if specified {@code y} is not within the
241         * range of the specified {@code points}.
242         */
243        protected double interpolateXAtY(WeightedObservedPoint[] points,
244                                         int startIdx,
245                                         int idxStep,
246                                         double y) {
247            if (idxStep == 0) {
248                throw new ZeroException();
249            }
250            final WeightedObservedPoint[] twoPoints
251                = getInterpolationPointsForY(points, startIdx, idxStep, y);
252            final WeightedObservedPoint p1 = twoPoints[0];
253            final WeightedObservedPoint p2 = twoPoints[1];
254            if (p1.getY() == y) {
255                return p1.getX();
256            }
257            if (p2.getY() == y) {
258                return p2.getX();
259            }
260            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
261                                (p2.getY() - p1.getY()));
262        }
263
264        /**
265         * Gets the two bounding interpolation points from the specified points
266         * suitable for determining X at the specified Y.
267         *
268         * @param points Points to use for interpolation.
269         * @param startIdx Index within points from which to start search for
270         * interpolation bounds points.
271         * @param idxStep Index step for search for interpolation bounds points.
272         * @param y Y value for which X should be determined.
273         * @return the array containing two points suitable for determining X at
274         * the specified Y.
275         * @throws ZeroException if {@code idxStep} is 0.
276         * @throws OutOfRangeException if specified {@code y} is not within the
277         * range of the specified {@code points}.
278         */
279        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
280                                                                   int startIdx,
281                                                                   int idxStep,
282                                                                   double y) {
283            if (idxStep == 0) {
284                throw new ZeroException();
285            }
286            for (int i = startIdx;
287                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
288                 i += idxStep) {
289                final WeightedObservedPoint p1 = points[i];
290                final WeightedObservedPoint p2 = points[i + idxStep];
291                if (isBetween(y, p1.getY(), p2.getY())) {
292                    if (idxStep < 0) {
293                        return new WeightedObservedPoint[] { p2, p1 };
294                    } else {
295                        return new WeightedObservedPoint[] { p1, p2 };
296                    }
297                }
298            }
299
300            // Boundaries are replaced by dummy values because the raised
301            // exception is caught and the message never displayed.
302            // TODO: Exceptions should not be used for flow control.
303            throw new OutOfRangeException(y,
304                                          Double.NEGATIVE_INFINITY,
305                                          Double.POSITIVE_INFINITY);
306        }
307
308        /**
309         * Determines whether a value is between two other values.
310         *
311         * @param value Value to test whether it is between {@code boundary1}
312         * and {@code boundary2}.
313         * @param boundary1 One end of the range.
314         * @param boundary2 Other end of the range.
315         * @return {@code true} if {@code value} is between {@code boundary1} and
316         * {@code boundary2} (inclusive), {@code false} otherwise.
317         */
318        private boolean isBetween(double value,
319                                  double boundary1,
320                                  double boundary2) {
321            return (value >= boundary1 && value <= boundary2) ||
322                (value >= boundary2 && value <= boundary1);
323        }
324    }
325}