View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.fitting.leastsquares;
18  
19  import java.io.BufferedReader;
20  import java.io.IOException;
21  import java.util.ArrayList;
22  import java.util.Arrays;
23  
24  import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
25  import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
26  
27  /**
28   * This class gives access to the statistical reference datasets provided by the
29   * NIST (available
30   * <a href="http://www.itl.nist.gov/div898/strd/general/dataarchive.html">here</a>).
31   * Instances of this class can be created by invocation of the
32   * {@link StatisticalReferenceDatasetFactory}.
33   */
34  public abstract class StatisticalReferenceDataset {
35      /** The name of this dataset. */
36      private final String name;
37      /** The total number of observations (data points). */
38      private final int numObservations;
39      /** The total number of parameters. */
40      private final int numParameters;
41      /** The total number of starting points for the optimizations. */
42      private final int numStartingPoints;
43      /** The values of the predictor. */
44      private final double[] x;
45      /** The values of the response. */
46      private final double[] y;
47      /**
48       * The starting values. {@code startingValues[j][i]} is the value of the
49       * {@code i}-th parameter in the {@code j}-th set of starting values.
50       */
51      private final double[][] startingValues;
52      /** The certified values of the parameters. */
53      private final double[] a;
54      /** The certified values of the standard deviation of the parameters. */
55      private final double[] sigA;
56      /** The certified value of the residual sum of squares. */
57      private double residualSumOfSquares;
58      /** The least-squares problem. */
59      private final LeastSquaresProblem problem;
60  
61      /**
62       * Creates a new instance of this class from the specified data file. The
63       * file must follow the StRD format.
64       *
65       * @param in the data file
66       * @throws IOException if an I/O error occurs
67       */
68      public StatisticalReferenceDataset(final BufferedReader in)
69          throws IOException {
70  
71          final ArrayList<String> lines = new ArrayList<>();
72          for (String line = in.readLine(); line != null; line = in.readLine()) {
73              lines.add(line);
74          }
75          int[] index = findLineNumbers("Data", lines);
76          if (index == null) {
77              throw new AssertionError("could not find line indices for data");
78          }
79          this.numObservations = index[1] - index[0] + 1;
80          this.x = new double[this.numObservations];
81          this.y = new double[this.numObservations];
82          for (int i = 0; i < this.numObservations; i++) {
83              final String line = lines.get(index[0] + i - 1);
84              final String[] tokens = line.trim().split(" ++");
85              // Data columns are in reverse order!!!
86              this.y[i] = Double.parseDouble(tokens[0]);
87              this.x[i] = Double.parseDouble(tokens[1]);
88          }
89  
90          index = findLineNumbers("Starting Values", lines);
91          if (index == null) {
92              throw new AssertionError(
93                                       "could not find line indices for starting values");
94          }
95          this.numParameters = index[1] - index[0] + 1;
96  
97          double[][] start = null;
98          this.a = new double[numParameters];
99          this.sigA = new double[numParameters];
100         for (int i = 0; i < numParameters; i++) {
101             final String line = lines.get(index[0] + i - 1);
102             final String[] tokens = line.trim().split(" ++");
103             if (start == null) {
104                 start = new double[tokens.length - 4][numParameters];
105             }
106             for (int j = 2; j < tokens.length - 2; j++) {
107                 start[j - 2][i] = Double.parseDouble(tokens[j]);
108             }
109             this.a[i] = Double.parseDouble(tokens[tokens.length - 2]);
110             this.sigA[i] = Double.parseDouble(tokens[tokens.length - 1]);
111         }
112         if (start == null) {
113             throw new IOException("could not find starting values");
114         }
115         this.numStartingPoints = start.length;
116         this.startingValues = start;
117 
118         double dummyDouble = Double.NaN;
119         String dummyString = null;
120         for (String line : lines) {
121             if (line.contains("Dataset Name:")) {
122                 dummyString = line
123                     .substring(line.indexOf("Dataset Name:") + 13,
124                                line.indexOf("(")).trim();
125             }
126             if (line.contains("Residual Sum of Squares")) {
127                 final String[] tokens = line.split(" ++");
128                 dummyDouble = Double.parseDouble(tokens[4].trim());
129             }
130         }
131         if (Double.isNaN(dummyDouble)) {
132             throw new IOException(
133                                   "could not find certified value of residual sum of squares");
134         }
135         this.residualSumOfSquares = dummyDouble;
136 
137         if (dummyString == null) {
138             throw new IOException("could not find dataset name");
139         }
140         this.name = dummyString;
141 
142         this.problem = new LeastSquaresProblem();
143     }
144 
145     class LeastSquaresProblem {
146         public MultivariateVectorFunction getModelFunction() {
147             return new MultivariateVectorFunction() {
148                 @Override
149                 public double[] value(final double[] a) {
150                     final int n = getNumObservations();
151                     final double[] yhat = new double[n];
152                     for (int i = 0; i < n; i++) {
153                         yhat[i] = getModelValue(getX(i), a);
154                     }
155                     return yhat;
156                 }
157             };
158         }
159 
160         public MultivariateMatrixFunction getModelFunctionJacobian() {
161             return new MultivariateMatrixFunction() {
162                 @Override
163                 public double[][] value(final double[] a)
164                     throws IllegalArgumentException {
165                     final int n = getNumObservations();
166                     final double[][] j = new double[n][];
167                     for (int i = 0; i < n; i++) {
168                         j[i] = getModelDerivatives(getX(i), a);
169                     }
170                     return j;
171                 }
172             };
173         }
174     }
175 
176     /**
177      * Returns the name of this dataset.
178      *
179      * @return the name of the dataset
180      */
181     public String getName() {
182         return name;
183     }
184 
185     /**
186      * Returns the total number of observations (data points).
187      *
188      * @return the number of observations
189      */
190     public int getNumObservations() {
191         return numObservations;
192     }
193 
194     /**
195      * Returns a copy of the data arrays. The data is laid out as follows <li>
196      * {@code data[0][i] = x[i]},</li> <li>{@code data[1][i] = y[i]},</li>
197      *
198      * @return the array of data points.
199      */
200     public double[][] getData() {
201         return new double[][] {
202             Arrays.copyOf(x, x.length), Arrays.copyOf(y, y.length)
203         };
204     }
205 
206     /**
207      * Returns the x-value of the {@code i}-th data point.
208      *
209      * @param i the index of the data point
210      * @return the x-value
211      */
212     public double getX(final int i) {
213         return x[i];
214     }
215 
216     /**
217      * Returns the y-value of the {@code i}-th data point.
218      *
219      * @param i the index of the data point
220      * @return the y-value
221      */
222     public double getY(final int i) {
223         return y[i];
224     }
225 
226     /**
227      * Returns the total number of parameters.
228      *
229      * @return the number of parameters
230      */
231     public int getNumParameters() {
232         return numParameters;
233     }
234 
235     /**
236      * Returns the certified values of the paramters.
237      *
238      * @return the values of the parameters
239      */
240     public double[] getParameters() {
241         return Arrays.copyOf(a, a.length);
242     }
243 
244     /**
245      * Returns the certified value of the {@code i}-th parameter.
246      *
247      * @param i the index of the parameter
248      * @return the value of the parameter
249      */
250     public double getParameter(final int i) {
251         return a[i];
252     }
253 
254     /**
255      * Reurns the certified values of the standard deviations of the parameters.
256      *
257      * @return the standard deviations of the parameters
258      */
259     public double[] getParametersStandardDeviations() {
260         return Arrays.copyOf(sigA, sigA.length);
261     }
262 
263     /**
264      * Returns the certified value of the standard deviation of the {@code i}-th
265      * parameter.
266      *
267      * @param i the index of the parameter
268      * @return the standard deviation of the parameter
269      */
270     public double getParameterStandardDeviation(final int i) {
271         return sigA[i];
272     }
273 
274     /**
275      * Returns the certified value of the residual sum of squares.
276      *
277      * @return the residual sum of squares
278      */
279     public double getResidualSumOfSquares() {
280         return residualSumOfSquares;
281     }
282 
283     /**
284      * Returns the total number of starting points (initial guesses for the
285      * optimization process).
286      *
287      * @return the number of starting points
288      */
289     public int getNumStartingPoints() {
290         return numStartingPoints;
291     }
292 
293     /**
294      * Returns the {@code i}-th set of initial values of the parameters.
295      *
296      * @param i the index of the starting point
297      * @return the starting point
298      */
299     public double[] getStartingPoint(final int i) {
300         return Arrays.copyOf(startingValues[i], startingValues[i].length);
301     }
302 
303     /**
304      * Returns the least-squares problem corresponding to fitting the model to
305      * the specified data.
306      *
307      * @return the least-squares problem
308      */
309     public LeastSquaresProblem getLeastSquaresProblem() {
310         return problem;
311     }
312 
313     /**
314      * Returns the value of the model for the specified values of the predictor
315      * variable and the parameters.
316      *
317      * @param x the predictor variable
318      * @param a the parameters
319      * @return the value of the model
320      */
321     public abstract double getModelValue(double x, double[] a);
322 
323     /**
324      * Returns the values of the partial derivatives of the model with respect
325      * to the parameters.
326      *
327      * @param x the predictor variable
328      * @param a the parameters
329      * @return the partial derivatives
330      */
331     public abstract double[] getModelDerivatives(double x,
332                                                  double[] a);
333 
334     /**
335      * <p>
336      * Parses the specified text lines, and extracts the indices of the first
337      * and last lines of the data defined by the specified {@code key}. This key
338      * must be one of
339      * </p>
340      * <ul>
341      * <li>{@code "Starting Values"},</li>
342      * <li>{@code "Certified Values"},</li>
343      * <li>{@code "Data"}.</li>
344      * </ul>
345      * <p>
346      * In the NIST data files, the line indices are separated by the keywords
347      * {@code "lines"} and {@code "to"}.
348      * </p>
349      *
350      * @param lines the line of text to be parsed
351      * @return an array of two {@code int}s. First value is the index of the
352      *         first line, second value is the index of the last line.
353      *         {@code null} if the line could not be parsed.
354      */
355     private static int[] findLineNumbers(final String key,
356                                          final Iterable<String> lines) {
357         for (String text : lines) {
358             boolean flag = text.contains(key) && text.contains("lines") &&
359                            text.contains("to") && text.contains(")");
360             if (flag) {
361                 final int[] numbers = new int[2];
362                 final String from = text.substring(text.indexOf("lines") + 5,
363                                                    text.indexOf("to"));
364                 numbers[0] = Integer.parseInt(from.trim());
365                 final String to = text.substring(text.indexOf("to") + 2,
366                                                  text.indexOf(")"));
367                 numbers[1] = Integer.parseInt(to.trim());
368                 return numbers;
369             }
370         }
371         return null;
372     }
373 }