1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.stat.regression;
18
import java.util.Objects;

import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
import org.apache.commons.math4.legacy.linear.LUDecomposition;
import org.apache.commons.math4.legacy.linear.RealMatrix;
import org.apache.commons.math4.legacy.linear.RealVector;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42 public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
43
44
45 private RealMatrix omega;
46
47
48 private RealMatrix omegaInverse;
49
50
51
52
53
54
55 public void newSampleData(double[] y, double[][] x, double[][] covariance) {
56 validateSampleData(x, y);
57 newYSampleData(y);
58 newXSampleData(x);
59 validateCovarianceData(x, covariance);
60 newCovarianceData(covariance);
61 }
62
63
64
65
66
67
68 protected void newCovarianceData(double[][] omega){
69 this.omega = new Array2DRowRealMatrix(omega);
70 this.omegaInverse = null;
71 }
72
73
74
75
76
77
78 protected RealMatrix getOmegaInverse() {
79 if (omegaInverse == null) {
80 omegaInverse = new LUDecomposition(omega).getSolver().getInverse();
81 }
82 return omegaInverse;
83 }
84
85
86
87
88
89
90
91
92 @Override
93 protected RealVector calculateBeta() {
94 RealMatrix oi = getOmegaInverse();
95 RealMatrix xt = getX().transpose();
96 RealMatrix xtoix = xt.multiply(oi).multiply(getX());
97 RealMatrix inverse = new LUDecomposition(xtoix).getSolver().getInverse();
98 return inverse.multiply(xt).multiply(oi).operate(getY());
99 }
100
101
102
103
104
105
106
107
108 @Override
109 protected RealMatrix calculateBetaVariance() {
110 RealMatrix oi = getOmegaInverse();
111 RealMatrix xtoix = getX().transpose().multiply(oi).multiply(getX());
112 return new LUDecomposition(xtoix).getSolver().getInverse();
113 }
114
115
116
117
118
119
120
121
122
123
124
125
126
127 @Override
128 protected double calculateErrorVariance() {
129 RealVector residuals = calculateResiduals();
130 double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
131 return t / (getX().getRowDimension() - getX().getColumnDimension());
132 }
133 }