1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.stat.regression;
18
19 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
20 import org.apache.commons.math4.legacy.exception.InsufficientDataException;
21 import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
22 import org.apache.commons.math4.legacy.exception.NoDataException;
23 import org.apache.commons.math4.legacy.exception.NullArgumentException;
24 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
25 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
26 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
27 import org.apache.commons.math4.legacy.linear.NonSquareMatrixException;
28 import org.apache.commons.math4.legacy.linear.RealMatrix;
29 import org.apache.commons.math4.legacy.linear.RealVector;
30 import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
31 import org.apache.commons.math4.core.jdkmath.JdkMath;
32
33
34
35
36
37 public abstract class AbstractMultipleLinearRegression implements
38 MultipleLinearRegression {
39
40
41 private RealMatrix xMatrix;
42
43
44 private RealVector yVector;
45
46
47 private boolean noIntercept;
48
49
50
51
52 protected RealMatrix getX() {
53 return xMatrix;
54 }
55
56
57
58
59 protected RealVector getY() {
60 return yVector;
61 }
62
63
64
65
66
67 public boolean isNoIntercept() {
68 return noIntercept;
69 }
70
71
72
73
74
75 public void setNoIntercept(boolean noIntercept) {
76 this.noIntercept = noIntercept;
77 }
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113 public void newSampleData(double[] data, int nobs, int nvars) {
114 if (data == null) {
115 throw new NullArgumentException();
116 }
117 if (data.length != nobs * (nvars + 1)) {
118 throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
119 }
120 if (nobs <= nvars) {
121 throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1);
122 }
123 double[] y = new double[nobs];
124 final int cols = noIntercept ? nvars: nvars + 1;
125 double[][] x = new double[nobs][cols];
126 int pointer = 0;
127 for (int i = 0; i < nobs; i++) {
128 y[i] = data[pointer++];
129 if (!noIntercept) {
130 x[i][0] = 1.0d;
131 }
132 for (int j = noIntercept ? 0 : 1; j < cols; j++) {
133 x[i][j] = data[pointer++];
134 }
135 }
136 this.xMatrix = new Array2DRowRealMatrix(x);
137 this.yVector = new ArrayRealVector(y);
138 }
139
140
141
142
143
144
145
146
147 protected void newYSampleData(double[] y) {
148 if (y == null) {
149 throw new NullArgumentException();
150 }
151 if (y.length == 0) {
152 throw new NoDataException();
153 }
154 this.yVector = new ArrayRealVector(y);
155 }
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181 protected void newXSampleData(double[][] x) {
182 if (x == null) {
183 throw new NullArgumentException();
184 }
185 if (x.length == 0) {
186 throw new NoDataException();
187 }
188 if (noIntercept) {
189 this.xMatrix = new Array2DRowRealMatrix(x, true);
190 } else {
191 final int nVars = x[0].length;
192 final double[][] xAug = new double[x.length][nVars + 1];
193 for (int i = 0; i < x.length; i++) {
194 if (x[i].length != nVars) {
195 throw new DimensionMismatchException(x[i].length, nVars);
196 }
197 xAug[i][0] = 1.0d;
198 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
199 }
200 this.xMatrix = new Array2DRowRealMatrix(xAug, false);
201 }
202 }
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222 protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
223 if (x == null || y == null) {
224 throw new NullArgumentException();
225 }
226 if (x.length != y.length) {
227 throw new DimensionMismatchException(y.length, x.length);
228 }
229 if (x.length == 0) {
230 throw new NoDataException();
231 }
232 if (x[0].length + 1 > x.length) {
233 throw new MathIllegalArgumentException(
234 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
235 x.length, x[0].length);
236 }
237 }
238
239
240
241
242
243
244
245
246
247
248
249 protected void validateCovarianceData(double[][] x, double[][] covariance) {
250 if (x.length != covariance.length) {
251 throw new DimensionMismatchException(x.length, covariance.length);
252 }
253 if (covariance.length > 0 && covariance.length != covariance[0].length) {
254 throw new NonSquareMatrixException(covariance.length, covariance[0].length);
255 }
256 }
257
258
259
260
261 @Override
262 public double[] estimateRegressionParameters() {
263 RealVector b = calculateBeta();
264 return b.toArray();
265 }
266
267
268
269
270 @Override
271 public double[] estimateResiduals() {
272 RealVector b = calculateBeta();
273 RealVector e = yVector.subtract(xMatrix.operate(b));
274 return e.toArray();
275 }
276
277
278
279
280 @Override
281 public double[][] estimateRegressionParametersVariance() {
282 return calculateBetaVariance().getData();
283 }
284
285
286
287
288 @Override
289 public double[] estimateRegressionParametersStandardErrors() {
290 double[][] betaVariance = estimateRegressionParametersVariance();
291 double sigma = calculateErrorVariance();
292 int length = betaVariance[0].length;
293 double[] result = new double[length];
294 for (int i = 0; i < length; i++) {
295 result[i] = JdkMath.sqrt(sigma * betaVariance[i][i]);
296 }
297 return result;
298 }
299
300
301
302
303 @Override
304 public double estimateRegressandVariance() {
305 return calculateYVariance();
306 }
307
308
309
310
311
312
313
314 public double estimateErrorVariance() {
315 return calculateErrorVariance();
316 }
317
318
319
320
321
322
323
324 public double estimateRegressionStandardError() {
325 return JdkMath.sqrt(estimateErrorVariance());
326 }
327
328
329
330
331
332
333 protected abstract RealVector calculateBeta();
334
335
336
337
338
339
340
341 protected abstract RealMatrix calculateBetaVariance();
342
343
344
345
346
347
348
349 protected double calculateYVariance() {
350 return new Variance().evaluate(yVector.toArray());
351 }
352
353
354
355
356
357
358
359
360
361
362
363
364 protected double calculateErrorVariance() {
365 RealVector residuals = calculateResiduals();
366 return residuals.dotProduct(residuals) /
367 (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
368 }
369
370
371
372
373
374
375
376
377
378
379
380 protected RealVector calculateResiduals() {
381 RealVector b = calculateBeta();
382 return yVector.subtract(xMatrix.operate(b));
383 }
384 }