1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.fitting.leastsquares;
18
19 import java.io.BufferedReader;
20 import java.io.IOException;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23
24 import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
25 import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
26
27 /**
28 * This class gives access to the statistical reference datasets provided by the
29 * NIST (available
30 * <a href="http://www.itl.nist.gov/div898/strd/general/dataarchive.html">here</a>).
31 * Instances of this class can be created by invocation of the
32 * {@link StatisticalReferenceDatasetFactory}.
33 */
34 public abstract class StatisticalReferenceDataset {
35 /** The name of this dataset. */
36 private final String name;
37 /** The total number of observations (data points). */
38 private final int numObservations;
39 /** The total number of parameters. */
40 private final int numParameters;
41 /** The total number of starting points for the optimizations. */
42 private final int numStartingPoints;
43 /** The values of the predictor. */
44 private final double[] x;
45 /** The values of the response. */
46 private final double[] y;
47 /**
48 * The starting values. {@code startingValues[j][i]} is the value of the
49 * {@code i}-th parameter in the {@code j}-th set of starting values.
50 */
51 private final double[][] startingValues;
52 /** The certified values of the parameters. */
53 private final double[] a;
54 /** The certified values of the standard deviation of the parameters. */
55 private final double[] sigA;
56 /** The certified value of the residual sum of squares. */
57 private double residualSumOfSquares;
58 /** The least-squares problem. */
59 private final LeastSquaresProblem problem;
60
61 /**
62 * Creates a new instance of this class from the specified data file. The
63 * file must follow the StRD format.
64 *
65 * @param in the data file
66 * @throws IOException if an I/O error occurs
67 */
68 public StatisticalReferenceDataset(final BufferedReader in)
69 throws IOException {
70
71 final ArrayList<String> lines = new ArrayList<>();
72 for (String line = in.readLine(); line != null; line = in.readLine()) {
73 lines.add(line);
74 }
75 int[] index = findLineNumbers("Data", lines);
76 if (index == null) {
77 throw new AssertionError("could not find line indices for data");
78 }
79 this.numObservations = index[1] - index[0] + 1;
80 this.x = new double[this.numObservations];
81 this.y = new double[this.numObservations];
82 for (int i = 0; i < this.numObservations; i++) {
83 final String line = lines.get(index[0] + i - 1);
84 final String[] tokens = line.trim().split(" ++");
85 // Data columns are in reverse order!!!
86 this.y[i] = Double.parseDouble(tokens[0]);
87 this.x[i] = Double.parseDouble(tokens[1]);
88 }
89
90 index = findLineNumbers("Starting Values", lines);
91 if (index == null) {
92 throw new AssertionError(
93 "could not find line indices for starting values");
94 }
95 this.numParameters = index[1] - index[0] + 1;
96
97 double[][] start = null;
98 this.a = new double[numParameters];
99 this.sigA = new double[numParameters];
100 for (int i = 0; i < numParameters; i++) {
101 final String line = lines.get(index[0] + i - 1);
102 final String[] tokens = line.trim().split(" ++");
103 if (start == null) {
104 start = new double[tokens.length - 4][numParameters];
105 }
106 for (int j = 2; j < tokens.length - 2; j++) {
107 start[j - 2][i] = Double.parseDouble(tokens[j]);
108 }
109 this.a[i] = Double.parseDouble(tokens[tokens.length - 2]);
110 this.sigA[i] = Double.parseDouble(tokens[tokens.length - 1]);
111 }
112 if (start == null) {
113 throw new IOException("could not find starting values");
114 }
115 this.numStartingPoints = start.length;
116 this.startingValues = start;
117
118 double dummyDouble = Double.NaN;
119 String dummyString = null;
120 for (String line : lines) {
121 if (line.contains("Dataset Name:")) {
122 dummyString = line
123 .substring(line.indexOf("Dataset Name:") + 13,
124 line.indexOf("(")).trim();
125 }
126 if (line.contains("Residual Sum of Squares")) {
127 final String[] tokens = line.split(" ++");
128 dummyDouble = Double.parseDouble(tokens[4].trim());
129 }
130 }
131 if (Double.isNaN(dummyDouble)) {
132 throw new IOException(
133 "could not find certified value of residual sum of squares");
134 }
135 this.residualSumOfSquares = dummyDouble;
136
137 if (dummyString == null) {
138 throw new IOException("could not find dataset name");
139 }
140 this.name = dummyString;
141
142 this.problem = new LeastSquaresProblem();
143 }
144
145 class LeastSquaresProblem {
146 public MultivariateVectorFunction getModelFunction() {
147 return new MultivariateVectorFunction() {
148 @Override
149 public double[] value(final double[] a) {
150 final int n = getNumObservations();
151 final double[] yhat = new double[n];
152 for (int i = 0; i < n; i++) {
153 yhat[i] = getModelValue(getX(i), a);
154 }
155 return yhat;
156 }
157 };
158 }
159
160 public MultivariateMatrixFunction getModelFunctionJacobian() {
161 return new MultivariateMatrixFunction() {
162 @Override
163 public double[][] value(final double[] a)
164 throws IllegalArgumentException {
165 final int n = getNumObservations();
166 final double[][] j = new double[n][];
167 for (int i = 0; i < n; i++) {
168 j[i] = getModelDerivatives(getX(i), a);
169 }
170 return j;
171 }
172 };
173 }
174 }
175
176 /**
177 * Returns the name of this dataset.
178 *
179 * @return the name of the dataset
180 */
181 public String getName() {
182 return name;
183 }
184
185 /**
186 * Returns the total number of observations (data points).
187 *
188 * @return the number of observations
189 */
190 public int getNumObservations() {
191 return numObservations;
192 }
193
194 /**
195 * Returns a copy of the data arrays. The data is laid out as follows <li>
196 * {@code data[0][i] = x[i]},</li> <li>{@code data[1][i] = y[i]},</li>
197 *
198 * @return the array of data points.
199 */
200 public double[][] getData() {
201 return new double[][] {
202 Arrays.copyOf(x, x.length), Arrays.copyOf(y, y.length)
203 };
204 }
205
206 /**
207 * Returns the x-value of the {@code i}-th data point.
208 *
209 * @param i the index of the data point
210 * @return the x-value
211 */
212 public double getX(final int i) {
213 return x[i];
214 }
215
216 /**
217 * Returns the y-value of the {@code i}-th data point.
218 *
219 * @param i the index of the data point
220 * @return the y-value
221 */
222 public double getY(final int i) {
223 return y[i];
224 }
225
226 /**
227 * Returns the total number of parameters.
228 *
229 * @return the number of parameters
230 */
231 public int getNumParameters() {
232 return numParameters;
233 }
234
235 /**
236 * Returns the certified values of the paramters.
237 *
238 * @return the values of the parameters
239 */
240 public double[] getParameters() {
241 return Arrays.copyOf(a, a.length);
242 }
243
244 /**
245 * Returns the certified value of the {@code i}-th parameter.
246 *
247 * @param i the index of the parameter
248 * @return the value of the parameter
249 */
250 public double getParameter(final int i) {
251 return a[i];
252 }
253
254 /**
255 * Reurns the certified values of the standard deviations of the parameters.
256 *
257 * @return the standard deviations of the parameters
258 */
259 public double[] getParametersStandardDeviations() {
260 return Arrays.copyOf(sigA, sigA.length);
261 }
262
263 /**
264 * Returns the certified value of the standard deviation of the {@code i}-th
265 * parameter.
266 *
267 * @param i the index of the parameter
268 * @return the standard deviation of the parameter
269 */
270 public double getParameterStandardDeviation(final int i) {
271 return sigA[i];
272 }
273
274 /**
275 * Returns the certified value of the residual sum of squares.
276 *
277 * @return the residual sum of squares
278 */
279 public double getResidualSumOfSquares() {
280 return residualSumOfSquares;
281 }
282
283 /**
284 * Returns the total number of starting points (initial guesses for the
285 * optimization process).
286 *
287 * @return the number of starting points
288 */
289 public int getNumStartingPoints() {
290 return numStartingPoints;
291 }
292
293 /**
294 * Returns the {@code i}-th set of initial values of the parameters.
295 *
296 * @param i the index of the starting point
297 * @return the starting point
298 */
299 public double[] getStartingPoint(final int i) {
300 return Arrays.copyOf(startingValues[i], startingValues[i].length);
301 }
302
303 /**
304 * Returns the least-squares problem corresponding to fitting the model to
305 * the specified data.
306 *
307 * @return the least-squares problem
308 */
309 public LeastSquaresProblem getLeastSquaresProblem() {
310 return problem;
311 }
312
313 /**
314 * Returns the value of the model for the specified values of the predictor
315 * variable and the parameters.
316 *
317 * @param x the predictor variable
318 * @param a the parameters
319 * @return the value of the model
320 */
321 public abstract double getModelValue(double x, double[] a);
322
323 /**
324 * Returns the values of the partial derivatives of the model with respect
325 * to the parameters.
326 *
327 * @param x the predictor variable
328 * @param a the parameters
329 * @return the partial derivatives
330 */
331 public abstract double[] getModelDerivatives(double x,
332 double[] a);
333
334 /**
335 * <p>
336 * Parses the specified text lines, and extracts the indices of the first
337 * and last lines of the data defined by the specified {@code key}. This key
338 * must be one of
339 * </p>
340 * <ul>
341 * <li>{@code "Starting Values"},</li>
342 * <li>{@code "Certified Values"},</li>
343 * <li>{@code "Data"}.</li>
344 * </ul>
345 * <p>
346 * In the NIST data files, the line indices are separated by the keywords
347 * {@code "lines"} and {@code "to"}.
348 * </p>
349 *
350 * @param lines the line of text to be parsed
351 * @return an array of two {@code int}s. First value is the index of the
352 * first line, second value is the index of the last line.
353 * {@code null} if the line could not be parsed.
354 */
355 private static int[] findLineNumbers(final String key,
356 final Iterable<String> lines) {
357 for (String text : lines) {
358 boolean flag = text.contains(key) && text.contains("lines") &&
359 text.contains("to") && text.contains(")");
360 if (flag) {
361 final int[] numbers = new int[2];
362 final String from = text.substring(text.indexOf("lines") + 5,
363 text.indexOf("to"));
364 numbers[0] = Integer.parseInt(from.trim());
365 final String to = text.substring(text.indexOf("to") + 2,
366 text.indexOf(")"));
367 numbers[1] = Integer.parseInt(to.trim());
368 return numbers;
369 }
370 }
371 return null;
372 }
373 }