001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.fitting; 018 019import java.util.Collections; 020import java.util.Collection; 021import java.util.Comparator; 022import java.util.List; 023import java.util.ArrayList; 024 025import org.apache.commons.math4.legacy.exception.ZeroException; 026import org.apache.commons.math4.legacy.exception.OutOfRangeException; 027import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction; 028import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder; 029import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem; 030import org.apache.commons.math4.legacy.linear.DiagonalMatrix; 031 032/** 033 * Fits points to a user-defined {@link ParametricUnivariateFunction function}. 034 * 035 * @since 3.4 036 */ 037public class SimpleCurveFitter extends AbstractCurveFitter { 038 /** Function to fit. */ 039 private final ParametricUnivariateFunction function; 040 /** Initial guess for the parameters. */ 041 private final double[] initialGuess; 042 /** Parameter guesser. */ 043 private final ParameterGuesser guesser; 044 /** Maximum number of iterations of the optimization algorithm. */ 045 private final int maxIter; 046 047 /** 048 * Constructor used by the factory methods. 049 * 050 * @param function Function to fit. 051 * @param initialGuess Initial guess. Cannot be {@code null}. Its length must 052 * be consistent with the number of parameters of the {@code function} to fit. 053 * @param guesser Method for providing an initial guess (if {@code initialGuess} 054 * is {@code null}). 055 * @param maxIter Maximum number of iterations of the optimization algorithm. 056 */ 057 protected SimpleCurveFitter(ParametricUnivariateFunction function, 058 double[] initialGuess, 059 ParameterGuesser guesser, 060 int maxIter) { 061 this.function = function; 062 this.initialGuess = initialGuess; 063 this.guesser = guesser; 064 this.maxIter = maxIter; 065 } 066 067 /** 068 * Creates a curve fitter. 069 * The maximum number of iterations of the optimization algorithm is set 070 * to {@link Integer#MAX_VALUE}. 071 * 072 * @param f Function to fit. 073 * @param start Initial guess for the parameters. Cannot be {@code null}. 074 * Its length must be consistent with the number of parameters of the 075 * function to fit. 076 * @return a curve fitter. 077 * 078 * @see #withStartPoint(double[]) 079 * @see #withMaxIterations(int) 080 */ 081 public static SimpleCurveFitter create(ParametricUnivariateFunction f, 082 double[] start) { 083 return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE); 084 } 085 086 /** 087 * Creates a curve fitter. 088 * The maximum number of iterations of the optimization algorithm is set 089 * to {@link Integer#MAX_VALUE}. 090 * 091 * @param f Function to fit. 092 * @param guesser Method for providing an initial guess. 093 * @return a curve fitter. 094 * 095 * @see #withStartPoint(double[]) 096 * @see #withMaxIterations(int) 097 */ 098 public static SimpleCurveFitter create(ParametricUnivariateFunction f, 099 ParameterGuesser guesser) { 100 return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE); 101 } 102 103 /** 104 * Configure the start point (initial guess). 105 * @param newStart new start point (initial guess) 106 * @return a new instance. 107 */ 108 public SimpleCurveFitter withStartPoint(double[] newStart) { 109 return new SimpleCurveFitter(function, 110 newStart.clone(), 111 null, 112 maxIter); 113 } 114 115 /** 116 * Configure the maximum number of iterations. 117 * @param newMaxIter maximum number of iterations 118 * @return a new instance. 119 */ 120 public SimpleCurveFitter withMaxIterations(int newMaxIter) { 121 return new SimpleCurveFitter(function, 122 initialGuess, 123 guesser, 124 newMaxIter); 125 } 126 127 /** {@inheritDoc} */ 128 @Override 129 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) { 130 // Prepare least-squares problem. 131 final int len = observations.size(); 132 final double[] target = new double[len]; 133 final double[] weights = new double[len]; 134 135 int count = 0; 136 for (WeightedObservedPoint obs : observations) { 137 target[count] = obs.getY(); 138 weights[count] = obs.getWeight(); 139 ++count; 140 } 141 142 final AbstractCurveFitter.TheoreticalValuesFunction model 143 = new AbstractCurveFitter.TheoreticalValuesFunction(function, 144 observations); 145 146 final double[] startPoint = initialGuess != null ? 147 initialGuess : 148 // Compute estimation. 149 guesser.guess(observations); 150 151 // Create an optimizer for fitting the curve to the observed points. 152 return new LeastSquaresBuilder(). 153 maxEvaluations(Integer.MAX_VALUE). 154 maxIterations(maxIter). 155 start(startPoint). 156 target(target). 157 weight(new DiagonalMatrix(weights)). 158 model(model.getModelFunction(), model.getModelFunctionJacobian()). 159 build(); 160 } 161 162 /** 163 * Guesses the parameters. 164 */ 165 public abstract static class ParameterGuesser { 166 /** Comparator. */ 167 private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() { 168 /** {@inheritDoc} */ 169 @Override 170 public int compare(WeightedObservedPoint p1, 171 WeightedObservedPoint p2) { 172 if (p1 == null && p2 == null) { 173 return 0; 174 } 175 if (p1 == null) { 176 return -1; 177 } 178 if (p2 == null) { 179 return 1; 180 } 181 int comp = Double.compare(p1.getX(), p2.getX()); 182 if (comp != 0) { 183 return comp; 184 } 185 comp = Double.compare(p1.getY(), p2.getY()); 186 if (comp != 0) { 187 return comp; 188 } 189 return Double.compare(p1.getWeight(), p2.getWeight()); 190 } 191 }; 192 193 /** 194 * Computes an estimation of the parameters. 195 * 196 * @param obs Observations. 197 * @return the guessed parameters. 198 */ 199 public abstract double[] guess(Collection<WeightedObservedPoint> obs); 200 201 /** 202 * Sort the observations. 203 * 204 * @param unsorted Input observations. 205 * @return the input observations, sorted. 206 */ 207 protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) { 208 final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted); 209 Collections.sort(observations, CMP); 210 return observations; 211 } 212 213 /** 214 * Finds index of point in specified points with the largest Y. 215 * 216 * @param points Points to search. 217 * @return the index in specified points array. 218 */ 219 protected int findMaxY(WeightedObservedPoint[] points) { 220 int maxYIdx = 0; 221 for (int i = 1; i < points.length; i++) { 222 if (points[i].getY() > points[maxYIdx].getY()) { 223 maxYIdx = i; 224 } 225 } 226 return maxYIdx; 227 } 228 229 /** 230 * Interpolates using the specified points to determine X at the 231 * specified Y. 232 * 233 * @param points Points to use for interpolation. 234 * @param startIdx Index within points from which to start the search for 235 * interpolation bounds points. 236 * @param idxStep Index step for searching interpolation bounds points. 237 * @param y Y value for which X should be determined. 238 * @return the value of X for the specified Y. 239 * @throws ZeroException if {@code idxStep} is 0. 240 * @throws OutOfRangeException if specified {@code y} is not within the 241 * range of the specified {@code points}. 242 */ 243 protected double interpolateXAtY(WeightedObservedPoint[] points, 244 int startIdx, 245 int idxStep, 246 double y) { 247 if (idxStep == 0) { 248 throw new ZeroException(); 249 } 250 final WeightedObservedPoint[] twoPoints 251 = getInterpolationPointsForY(points, startIdx, idxStep, y); 252 final WeightedObservedPoint p1 = twoPoints[0]; 253 final WeightedObservedPoint p2 = twoPoints[1]; 254 if (p1.getY() == y) { 255 return p1.getX(); 256 } 257 if (p2.getY() == y) { 258 return p2.getX(); 259 } 260 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) / 261 (p2.getY() - p1.getY())); 262 } 263 264 /** 265 * Gets the two bounding interpolation points from the specified points 266 * suitable for determining X at the specified Y. 267 * 268 * @param points Points to use for interpolation. 269 * @param startIdx Index within points from which to start search for 270 * interpolation bounds points. 271 * @param idxStep Index step for search for interpolation bounds points. 272 * @param y Y value for which X should be determined. 273 * @return the array containing two points suitable for determining X at 274 * the specified Y. 275 * @throws ZeroException if {@code idxStep} is 0. 276 * @throws OutOfRangeException if specified {@code y} is not within the 277 * range of the specified {@code points}. 278 */ 279 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points, 280 int startIdx, 281 int idxStep, 282 double y) { 283 if (idxStep == 0) { 284 throw new ZeroException(); 285 } 286 for (int i = startIdx; 287 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length; 288 i += idxStep) { 289 final WeightedObservedPoint p1 = points[i]; 290 final WeightedObservedPoint p2 = points[i + idxStep]; 291 if (isBetween(y, p1.getY(), p2.getY())) { 292 if (idxStep < 0) { 293 return new WeightedObservedPoint[] { p2, p1 }; 294 } else { 295 return new WeightedObservedPoint[] { p1, p2 }; 296 } 297 } 298 } 299 300 // Boundaries are replaced by dummy values because the raised 301 // exception is caught and the message never displayed. 302 // TODO: Exceptions should not be used for flow control. 303 throw new OutOfRangeException(y, 304 Double.NEGATIVE_INFINITY, 305 Double.POSITIVE_INFINITY); 306 } 307 308 /** 309 * Determines whether a value is between two other values. 310 * 311 * @param value Value to test whether it is between {@code boundary1} 312 * and {@code boundary2}. 313 * @param boundary1 One end of the range. 314 * @param boundary2 Other end of the range. 315 * @return {@code true} if {@code value} is between {@code boundary1} and 316 * {@code boundary2} (inclusive), {@code false} otherwise. 317 */ 318 private boolean isBetween(double value, 319 double boundary1, 320 double boundary2) { 321 return (value >= boundary1 && value <= boundary2) || 322 (value >= boundary2 && value <= boundary1); 323 } 324 } 325}