1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.fitting.leastsquares;
18
19 import java.io.BufferedReader;
20 import java.io.IOException;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23
24 import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
25 import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
26
27
28
29
30
31
32
33
34 public abstract class StatisticalReferenceDataset {
35
36 private final String name;
37
38 private final int numObservations;
39
40 private final int numParameters;
41
42 private final int numStartingPoints;
43
44 private final double[] x;
45
46 private final double[] y;
47
48
49
50
51 private final double[][] startingValues;
52
53 private final double[] a;
54
55 private final double[] sigA;
56
57 private double residualSumOfSquares;
58
59 private final LeastSquaresProblem problem;
60
61
62
63
64
65
66
67
68 public StatisticalReferenceDataset(final BufferedReader in)
69 throws IOException {
70
71 final ArrayList<String> lines = new ArrayList<>();
72 for (String line = in.readLine(); line != null; line = in.readLine()) {
73 lines.add(line);
74 }
75 int[] index = findLineNumbers("Data", lines);
76 if (index == null) {
77 throw new AssertionError("could not find line indices for data");
78 }
79 this.numObservations = index[1] - index[0] + 1;
80 this.x = new double[this.numObservations];
81 this.y = new double[this.numObservations];
82 for (int i = 0; i < this.numObservations; i++) {
83 final String line = lines.get(index[0] + i - 1);
84 final String[] tokens = line.trim().split(" ++");
85
86 this.y[i] = Double.parseDouble(tokens[0]);
87 this.x[i] = Double.parseDouble(tokens[1]);
88 }
89
90 index = findLineNumbers("Starting Values", lines);
91 if (index == null) {
92 throw new AssertionError(
93 "could not find line indices for starting values");
94 }
95 this.numParameters = index[1] - index[0] + 1;
96
97 double[][] start = null;
98 this.a = new double[numParameters];
99 this.sigA = new double[numParameters];
100 for (int i = 0; i < numParameters; i++) {
101 final String line = lines.get(index[0] + i - 1);
102 final String[] tokens = line.trim().split(" ++");
103 if (start == null) {
104 start = new double[tokens.length - 4][numParameters];
105 }
106 for (int j = 2; j < tokens.length - 2; j++) {
107 start[j - 2][i] = Double.parseDouble(tokens[j]);
108 }
109 this.a[i] = Double.parseDouble(tokens[tokens.length - 2]);
110 this.sigA[i] = Double.parseDouble(tokens[tokens.length - 1]);
111 }
112 if (start == null) {
113 throw new IOException("could not find starting values");
114 }
115 this.numStartingPoints = start.length;
116 this.startingValues = start;
117
118 double dummyDouble = Double.NaN;
119 String dummyString = null;
120 for (String line : lines) {
121 if (line.contains("Dataset Name:")) {
122 dummyString = line
123 .substring(line.indexOf("Dataset Name:") + 13,
124 line.indexOf("(")).trim();
125 }
126 if (line.contains("Residual Sum of Squares")) {
127 final String[] tokens = line.split(" ++");
128 dummyDouble = Double.parseDouble(tokens[4].trim());
129 }
130 }
131 if (Double.isNaN(dummyDouble)) {
132 throw new IOException(
133 "could not find certified value of residual sum of squares");
134 }
135 this.residualSumOfSquares = dummyDouble;
136
137 if (dummyString == null) {
138 throw new IOException("could not find dataset name");
139 }
140 this.name = dummyString;
141
142 this.problem = new LeastSquaresProblem();
143 }
144
145 class LeastSquaresProblem {
146 public MultivariateVectorFunction getModelFunction() {
147 return new MultivariateVectorFunction() {
148 @Override
149 public double[] value(final double[] a) {
150 final int n = getNumObservations();
151 final double[] yhat = new double[n];
152 for (int i = 0; i < n; i++) {
153 yhat[i] = getModelValue(getX(i), a);
154 }
155 return yhat;
156 }
157 };
158 }
159
160 public MultivariateMatrixFunction getModelFunctionJacobian() {
161 return new MultivariateMatrixFunction() {
162 @Override
163 public double[][] value(final double[] a)
164 throws IllegalArgumentException {
165 final int n = getNumObservations();
166 final double[][] j = new double[n][];
167 for (int i = 0; i < n; i++) {
168 j[i] = getModelDerivatives(getX(i), a);
169 }
170 return j;
171 }
172 };
173 }
174 }
175
176
177
178
179
180
181 public String getName() {
182 return name;
183 }
184
185
186
187
188
189
190 public int getNumObservations() {
191 return numObservations;
192 }
193
194
195
196
197
198
199
200 public double[][] getData() {
201 return new double[][] {
202 Arrays.copyOf(x, x.length), Arrays.copyOf(y, y.length)
203 };
204 }
205
206
207
208
209
210
211
212 public double getX(final int i) {
213 return x[i];
214 }
215
216
217
218
219
220
221
222 public double getY(final int i) {
223 return y[i];
224 }
225
226
227
228
229
230
231 public int getNumParameters() {
232 return numParameters;
233 }
234
235
236
237
238
239
240 public double[] getParameters() {
241 return Arrays.copyOf(a, a.length);
242 }
243
244
245
246
247
248
249
250 public double getParameter(final int i) {
251 return a[i];
252 }
253
254
255
256
257
258
259 public double[] getParametersStandardDeviations() {
260 return Arrays.copyOf(sigA, sigA.length);
261 }
262
263
264
265
266
267
268
269
270 public double getParameterStandardDeviation(final int i) {
271 return sigA[i];
272 }
273
274
275
276
277
278
279 public double getResidualSumOfSquares() {
280 return residualSumOfSquares;
281 }
282
283
284
285
286
287
288
289 public int getNumStartingPoints() {
290 return numStartingPoints;
291 }
292
293
294
295
296
297
298
299 public double[] getStartingPoint(final int i) {
300 return Arrays.copyOf(startingValues[i], startingValues[i].length);
301 }
302
303
304
305
306
307
308
309 public LeastSquaresProblem getLeastSquaresProblem() {
310 return problem;
311 }
312
313
314
315
316
317
318
319
320
321 public abstract double getModelValue(double x, double[] a);
322
323
324
325
326
327
328
329
330
331 public abstract double[] getModelDerivatives(double x,
332 double[] a);
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355 private static int[] findLineNumbers(final String key,
356 final Iterable<String> lines) {
357 for (String text : lines) {
358 boolean flag = text.contains(key) && text.contains("lines") &&
359 text.contains("to") && text.contains(")");
360 if (flag) {
361 final int[] numbers = new int[2];
362 final String from = text.substring(text.indexOf("lines") + 5,
363 text.indexOf("to"));
364 numbers[0] = Integer.parseInt(from.trim());
365 final String to = text.substring(text.indexOf("to") + 2,
366 text.indexOf(")"));
367 numbers[1] = Integer.parseInt(to.trim());
368 return numbers;
369 }
370 }
371 return null;
372 }
373 }