/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.fitting;

import java.util.Collection;

import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.linear.DiagonalMatrix;

/**
 * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
 * <p>
 * An instance holds the function to fit, the initial guess for its parameters
 * and the maximum number of iterations. The {@code withXxx} methods do not
 * modify the instance they are called on; each returns a new instance with
 * the corresponding setting replaced.
 * </p>
 *
 * @since 3.4
 */
public class SimpleCurveFitter extends AbstractCurveFitter {
    /** Function to fit. */
    private final ParametricUnivariateFunction function;
    /** Initial guess for the parameters. */
    private final double[] initialGuess;
    /** Maximum number of iterations of the optimization algorithm. */
    private final int maxIter;

    /**
     * Constructor used by the factory methods.
     * The arguments are stored as given; callers within this class are
     * responsible for copying the initial guess when needed.
     *
     * @param function Function to fit.
     * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
     * be consistent with the number of parameters of the {@code function} to fit.
     * @param maxIter Maximum number of iterations of the optimization algorithm.
     */
    private SimpleCurveFitter(ParametricUnivariateFunction function,
                              double[] initialGuess,
                              int maxIter) {
        this.function = function;
        this.initialGuess = initialGuess;
        this.maxIter = maxIter;
    }

    /**
     * Creates a curve fitter.
     * The maximum number of iterations of the optimization algorithm is set
     * to {@link Integer#MAX_VALUE}.
     * The {@code start} array is copied, so that subsequent modifications of
     * the caller's array do not affect the returned fitter.
     *
     * @param f Function to fit.
     * @param start Initial guess for the parameters.  Cannot be {@code null}.
     * Its length must be consistent with the number of parameters of the
     * function to fit.
     * @return a curve fitter.
     *
     * @see #withStartPoint(double[])
     * @see #withMaxIterations(int)
     */
    public static SimpleCurveFitter create(ParametricUnivariateFunction f,
                                           double[] start) {
        // Defensive copy, consistent with withStartPoint(double[]): the fitter
        // must not share its initial guess with the caller's mutable array.
        // A null reference is forwarded unchanged so that this copy does not
        // alter when or how a missing guess is rejected.
        final double[] guess = start == null ? null : start.clone();
        return new SimpleCurveFitter(f, guess, Integer.MAX_VALUE);
    }

    /**
     * Configure the start point (initial guess).
     * The given array is copied.
     *
     * @param newStart new start point (initial guess)
     * @return a new instance.
     */
    public SimpleCurveFitter withStartPoint(double[] newStart) {
        return new SimpleCurveFitter(function,
                                     newStart.clone(),
                                     maxIter);
    }

    /**
     * Configure the maximum number of iterations.
     *
     * @param newMaxIter maximum number of iterations
     * @return a new instance.
     */
    public SimpleCurveFitter withMaxIterations(int newMaxIter) {
        // The initial guess array is shared between the two instances; it is
        // never written to by this class.
        return new SimpleCurveFitter(function,
                                     initialGuess,
                                     newMaxIter);
    }

    /** {@inheritDoc} */
    @Override
    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
        // Prepare least-squares problem: one target value and one weight
        // per observation, in the iteration order of the collection.
        final int len = observations.size();
        final double[] target  = new double[len];
        final double[] weights = new double[len];

        int count = 0;
        for (WeightedObservedPoint obs : observations) {
            target[count]  = obs.getY();
            weights[count] = obs.getWeight();
            ++count;
        }

        final AbstractCurveFitter.TheoreticalValuesFunction model
            = new AbstractCurveFitter.TheoreticalValuesFunction(function,
                                                                observations);

        // Create an optimizer for fitting the curve to the observed points.
        // Only the iteration count is configurable on this fitter; the
        // evaluation count is always set to Integer.MAX_VALUE.
        return new LeastSquaresBuilder().
                maxEvaluations(Integer.MAX_VALUE).
                maxIterations(maxIter).
                start(initialGuess).
                target(target).
                weight(new DiagonalMatrix(weights)).
                model(model.getModelFunction(), model.getModelFunctionJacobian()).
                build();
    }
}