View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.fitting.leastsquares;
19  
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
import org.apache.commons.math4.legacy.stat.regression.SimpleRegression;
26  
27  /**
28   * Class that models a straight line defined as {@code y = a x + b}.
29   * The parameters of problem are:
30   * <ul>
31   *  <li>{@code a}</li>
32   *  <li>{@code b}</li>
33   * </ul>
34   * The model functions are:
35   * <ul>
36   *  <li>for each pair (a, b), the y-coordinate of the line.</li>
37   * </ul>
38   */
39  class StraightLineProblem {
40      /** Cloud of points assumed to be fitted by a straight line. */
41      private final ArrayList<double[]> points;
42      /** Error (on the y-coordinate of the points). */
43      private final double sigma;
44  
45      /**
46       * @param error Assumed error for the y-coordinate.
47       */
48      StraightLineProblem(double error) {
49          points = new ArrayList<>();
50          sigma = error;
51      }
52  
53      public void addPoint(double px, double py) {
54          points.add(new double[] { px, py });
55      }
56  
57      /**
58       * @return the list of x-coordinates.
59       */
60      public double[] x() {
61          final double[] v = new double[points.size()];
62          for (int i = 0; i < points.size(); i++) {
63              final double[] p = points.get(i);
64              v[i] = p[0]; // x-coordinate.
65          }
66  
67          return v;
68      }
69  
70      /**
71       * @return the list of y-coordinates.
72       */
73      public double[] y() {
74          final double[] v = new double[points.size()];
75          for (int i = 0; i < points.size(); i++) {
76              final double[] p = points.get(i);
77              v[i] = p[1]; // y-coordinate.
78          }
79  
80          return v;
81      }
82  
83      public double[] target() {
84          return y();
85      }
86  
87      public double[] weight() {
88          final double weight = 1 / (sigma * sigma);
89          final double[] w = new double[points.size()];
90          for (int i = 0; i < points.size(); i++) {
91              w[i] = weight;
92          }
93  
94          return w;
95      }
96  
97      public MultivariateVectorFunction getModelFunction() {
98          return new MultivariateVectorFunction() {
99              @Override
100             public double[] value(double[] params) {
101                 final Model line = new Model(params[0], params[1]);
102 
103                 final double[] model = new double[points.size()];
104                 for (int i = 0; i < points.size(); i++) {
105                     final double[] p = points.get(i);
106                     model[i] = line.value(p[0]);
107                 }
108 
109                 return model;
110             }
111         };
112     }
113 
114     public MultivariateMatrixFunction getModelFunctionJacobian() {
115         return new MultivariateMatrixFunction() {
116             @Override
117             public double[][] value(double[] point) {
118                 return jacobian(point);
119             }
120         };
121     }
122 
123     /**
124      * Directly solve the linear problem, using the {@link SimpleRegression}
125      * class.
126      */
127     public double[] solve() {
128         final SimpleRegression regress = new SimpleRegression(true);
129         for (double[] d : points) {
130             regress.addData(d[0], d[1]);
131         }
132 
133         final double[] result = { regress.getSlope(), regress.getIntercept() };
134         return result;
135     }
136 
137     private double[][] jacobian(double[] params) {
138         final double[][] jacobian = new double[points.size()][2];
139 
140         for (int i = 0; i < points.size(); i++) {
141             final double[] p = points.get(i);
142             // Partial derivative wrt "a".
143             jacobian[i][0] = p[0];
144             // Partial derivative wrt "b".
145             jacobian[i][1] = 1;
146         }
147 
148         return jacobian;
149     }
150 
151     /**
152      * Linear function.
153      */
154     public static class Model implements UnivariateFunction {
155         private final double a;
156         private final double b;
157 
158         Model(double a,
159               double b) {
160             this.a = a;
161             this.b = b;
162         }
163 
164         @Override
165         public double value(double x) {
166             return a * x + b;
167         }
168     }
169 }