1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  package org.apache.commons.math4.legacy.fitting.leastsquares;
19  
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
import org.apache.commons.math4.legacy.stat.regression.SimpleRegression;
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  class StraightLineProblem {
40      
41      private final ArrayList<double[]> points;
42      
43      private final double sigma;
44  
45      
46  
47  
48      StraightLineProblem(double error) {
49          points = new ArrayList<>();
50          sigma = error;
51      }
52  
53      public void addPoint(double px, double py) {
54          points.add(new double[] { px, py });
55      }
56  
57      
58  
59  
60      public double[] x() {
61          final double[] v = new double[points.size()];
62          for (int i = 0; i < points.size(); i++) {
63              final double[] p = points.get(i);
64              v[i] = p[0]; 
65          }
66  
67          return v;
68      }
69  
70      
71  
72  
73      public double[] y() {
74          final double[] v = new double[points.size()];
75          for (int i = 0; i < points.size(); i++) {
76              final double[] p = points.get(i);
77              v[i] = p[1]; 
78          }
79  
80          return v;
81      }
82  
83      public double[] target() {
84          return y();
85      }
86  
87      public double[] weight() {
88          final double weight = 1 / (sigma * sigma);
89          final double[] w = new double[points.size()];
90          for (int i = 0; i < points.size(); i++) {
91              w[i] = weight;
92          }
93  
94          return w;
95      }
96  
97      public MultivariateVectorFunction getModelFunction() {
98          return new MultivariateVectorFunction() {
99              @Override
100             public double[] value(double[] params) {
101                 final Model line = new Model(params[0], params[1]);
102 
103                 final double[] model = new double[points.size()];
104                 for (int i = 0; i < points.size(); i++) {
105                     final double[] p = points.get(i);
106                     model[i] = line.value(p[0]);
107                 }
108 
109                 return model;
110             }
111         };
112     }
113 
114     public MultivariateMatrixFunction getModelFunctionJacobian() {
115         return new MultivariateMatrixFunction() {
116             @Override
117             public double[][] value(double[] point) {
118                 return jacobian(point);
119             }
120         };
121     }
122 
123     
124 
125 
126 
127     public double[] solve() {
128         final SimpleRegression regress = new SimpleRegression(true);
129         for (double[] d : points) {
130             regress.addData(d[0], d[1]);
131         }
132 
133         final double[] result = { regress.getSlope(), regress.getIntercept() };
134         return result;
135     }
136 
137     private double[][] jacobian(double[] params) {
138         final double[][] jacobian = new double[points.size()][2];
139 
140         for (int i = 0; i < points.size(); i++) {
141             final double[] p = points.get(i);
142             
143             jacobian[i][0] = p[0];
144             
145             jacobian[i][1] = 1;
146         }
147 
148         return jacobian;
149     }
150 
151     
152 
153 
154     public static class Model implements UnivariateFunction {
155         private final double a;
156         private final double b;
157 
158         Model(double a,
159               double b) {
160             this.a = a;
161             this.b = b;
162         }
163 
164         @Override
165         public double value(double x) {
166             return a * x + b;
167         }
168     }
169 }