View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.fitting;
18  
19  import java.util.Collections;
20  import java.util.Collection;
21  import java.util.Comparator;
22  import java.util.List;
23  import java.util.ArrayList;
24  
25  import org.apache.commons.math4.legacy.exception.ZeroException;
26  import org.apache.commons.math4.legacy.exception.OutOfRangeException;
27  import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
28  import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
29  import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
30  import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
31  
32  /**
33   * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
34   *
35   * @since 3.4
36   */
37  public class SimpleCurveFitter extends AbstractCurveFitter {
38      /** Function to fit. */
39      private final ParametricUnivariateFunction function;
40      /** Initial guess for the parameters. */
41      private final double[] initialGuess;
42      /** Parameter guesser. */
43      private final ParameterGuesser guesser;
44      /** Maximum number of iterations of the optimization algorithm. */
45      private final int maxIter;
46  
47      /**
48       * Constructor used by the factory methods.
49       *
50       * @param function Function to fit.
51       * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
52       * be consistent with the number of parameters of the {@code function} to fit.
53       * @param guesser Method for providing an initial guess (if {@code initialGuess}
54       * is {@code null}).
55       * @param maxIter Maximum number of iterations of the optimization algorithm.
56       */
57      protected SimpleCurveFitter(ParametricUnivariateFunction function,
58                                  double[] initialGuess,
59                                  ParameterGuesser guesser,
60                                  int maxIter) {
61          this.function = function;
62          this.initialGuess = initialGuess;
63          this.guesser = guesser;
64          this.maxIter = maxIter;
65      }
66  
67      /**
68       * Creates a curve fitter.
69       * The maximum number of iterations of the optimization algorithm is set
70       * to {@link Integer#MAX_VALUE}.
71       *
72       * @param f Function to fit.
73       * @param start Initial guess for the parameters.  Cannot be {@code null}.
74       * Its length must be consistent with the number of parameters of the
75       * function to fit.
76       * @return a curve fitter.
77       *
78       * @see #withStartPoint(double[])
79       * @see #withMaxIterations(int)
80       */
81      public static SimpleCurveFitter create(ParametricUnivariateFunction f,
82                                             double[] start) {
83          return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE);
84      }
85  
86      /**
87       * Creates a curve fitter.
88       * The maximum number of iterations of the optimization algorithm is set
89       * to {@link Integer#MAX_VALUE}.
90       *
91       * @param f Function to fit.
92       * @param guesser Method for providing an initial guess.
93       * @return a curve fitter.
94       *
95       * @see #withStartPoint(double[])
96       * @see #withMaxIterations(int)
97       */
98      public static SimpleCurveFitter create(ParametricUnivariateFunction f,
99                                             ParameterGuesser guesser) {
100         return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
101     }
102 
103     /**
104      * Configure the start point (initial guess).
105      * @param newStart new start point (initial guess)
106      * @return a new instance.
107      */
108     public SimpleCurveFitter withStartPoint(double[] newStart) {
109         return new SimpleCurveFitter(function,
110                                      newStart.clone(),
111                                      null,
112                                      maxIter);
113     }
114 
115     /**
116      * Configure the maximum number of iterations.
117      * @param newMaxIter maximum number of iterations
118      * @return a new instance.
119      */
120     public SimpleCurveFitter withMaxIterations(int newMaxIter) {
121         return new SimpleCurveFitter(function,
122                                      initialGuess,
123                                      guesser,
124                                      newMaxIter);
125     }
126 
127     /** {@inheritDoc} */
128     @Override
129     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
130         // Prepare least-squares problem.
131         final int len = observations.size();
132         final double[] target  = new double[len];
133         final double[] weights = new double[len];
134 
135         int count = 0;
136         for (WeightedObservedPoint obs : observations) {
137             target[count]  = obs.getY();
138             weights[count] = obs.getWeight();
139             ++count;
140         }
141 
142         final AbstractCurveFitter.TheoreticalValuesFunction model
143             = new AbstractCurveFitter.TheoreticalValuesFunction(function,
144                                                                 observations);
145 
146         final double[] startPoint = initialGuess != null ?
147             initialGuess :
148             // Compute estimation.
149             guesser.guess(observations);
150 
151         // Create an optimizer for fitting the curve to the observed points.
152         return new LeastSquaresBuilder().
153                 maxEvaluations(Integer.MAX_VALUE).
154                 maxIterations(maxIter).
155                 start(startPoint).
156                 target(target).
157                 weight(new DiagonalMatrix(weights)).
158                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
159                 build();
160     }
161 
162     /**
163      * Guesses the parameters.
164      */
165     public abstract static class ParameterGuesser {
166         /** Comparator. */
167         private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
168                 /** {@inheritDoc} */
169                 @Override
170                 public int compare(WeightedObservedPoint p1,
171                                    WeightedObservedPoint p2) {
172                     if (p1 == null && p2 == null) {
173                         return 0;
174                     }
175                     if (p1 == null) {
176                         return -1;
177                     }
178                     if (p2 == null) {
179                         return 1;
180                     }
181                     int comp = Double.compare(p1.getX(), p2.getX());
182                     if (comp != 0) {
183                         return comp;
184                     }
185                     comp = Double.compare(p1.getY(), p2.getY());
186                     if (comp != 0) {
187                         return comp;
188                     }
189                     return Double.compare(p1.getWeight(), p2.getWeight());
190                 }
191             };
192 
193         /**
194          * Computes an estimation of the parameters.
195          *
196          * @param obs Observations.
197          * @return the guessed parameters.
198          */
199         public abstract double[] guess(Collection<WeightedObservedPoint> obs);
200 
201         /**
202          * Sort the observations.
203          *
204          * @param unsorted Input observations.
205          * @return the input observations, sorted.
206          */
207         protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
208             final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
209             Collections.sort(observations, CMP);
210             return observations;
211         }
212 
213         /**
214          * Finds index of point in specified points with the largest Y.
215          *
216          * @param points Points to search.
217          * @return the index in specified points array.
218          */
219         protected int findMaxY(WeightedObservedPoint[] points) {
220             int maxYIdx = 0;
221             for (int i = 1; i < points.length; i++) {
222                 if (points[i].getY() > points[maxYIdx].getY()) {
223                     maxYIdx = i;
224                 }
225             }
226             return maxYIdx;
227         }
228 
229         /**
230          * Interpolates using the specified points to determine X at the
231          * specified Y.
232          *
233          * @param points Points to use for interpolation.
234          * @param startIdx Index within points from which to start the search for
235          * interpolation bounds points.
236          * @param idxStep Index step for searching interpolation bounds points.
237          * @param y Y value for which X should be determined.
238          * @return the value of X for the specified Y.
239          * @throws ZeroException if {@code idxStep} is 0.
240          * @throws OutOfRangeException if specified {@code y} is not within the
241          * range of the specified {@code points}.
242          */
243         protected double interpolateXAtY(WeightedObservedPoint[] points,
244                                          int startIdx,
245                                          int idxStep,
246                                          double y) {
247             if (idxStep == 0) {
248                 throw new ZeroException();
249             }
250             final WeightedObservedPoint[] twoPoints
251                 = getInterpolationPointsForY(points, startIdx, idxStep, y);
252             final WeightedObservedPoint p1 = twoPoints[0];
253             final WeightedObservedPoint p2 = twoPoints[1];
254             if (p1.getY() == y) {
255                 return p1.getX();
256             }
257             if (p2.getY() == y) {
258                 return p2.getX();
259             }
260             return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
261                                 (p2.getY() - p1.getY()));
262         }
263 
264         /**
265          * Gets the two bounding interpolation points from the specified points
266          * suitable for determining X at the specified Y.
267          *
268          * @param points Points to use for interpolation.
269          * @param startIdx Index within points from which to start search for
270          * interpolation bounds points.
271          * @param idxStep Index step for search for interpolation bounds points.
272          * @param y Y value for which X should be determined.
273          * @return the array containing two points suitable for determining X at
274          * the specified Y.
275          * @throws ZeroException if {@code idxStep} is 0.
276          * @throws OutOfRangeException if specified {@code y} is not within the
277          * range of the specified {@code points}.
278          */
279         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
280                                                                    int startIdx,
281                                                                    int idxStep,
282                                                                    double y) {
283             if (idxStep == 0) {
284                 throw new ZeroException();
285             }
286             for (int i = startIdx;
287                  idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
288                  i += idxStep) {
289                 final WeightedObservedPoint p1 = points[i];
290                 final WeightedObservedPoint p2 = points[i + idxStep];
291                 if (isBetween(y, p1.getY(), p2.getY())) {
292                     if (idxStep < 0) {
293                         return new WeightedObservedPoint[] { p2, p1 };
294                     } else {
295                         return new WeightedObservedPoint[] { p1, p2 };
296                     }
297                 }
298             }
299 
300             // Boundaries are replaced by dummy values because the raised
301             // exception is caught and the message never displayed.
302             // TODO: Exceptions should not be used for flow control.
303             throw new OutOfRangeException(y,
304                                           Double.NEGATIVE_INFINITY,
305                                           Double.POSITIVE_INFINITY);
306         }
307 
308         /**
309          * Determines whether a value is between two other values.
310          *
311          * @param value Value to test whether it is between {@code boundary1}
312          * and {@code boundary2}.
313          * @param boundary1 One end of the range.
314          * @param boundary2 Other end of the range.
315          * @return {@code true} if {@code value} is between {@code boundary1} and
316          * {@code boundary2} (inclusive), {@code false} otherwise.
317          */
318         private boolean isBetween(double value,
319                                   double boundary1,
320                                   double boundary2) {
321             return (value >= boundary1 && value <= boundary2) ||
322                 (value >= boundary2 && value <= boundary1);
323         }
324     }
325 }