1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.stat.regression;
18
19 import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
20 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
21 import org.apache.commons.math4.legacy.linear.LUDecomposition;
22 import org.apache.commons.math4.legacy.linear.QRDecomposition;
23 import org.apache.commons.math4.legacy.linear.RealMatrix;
24 import org.apache.commons.math4.legacy.linear.RealVector;
25 import org.apache.commons.math4.legacy.stat.StatUtils;
26 import org.apache.commons.math4.legacy.stat.descriptive.moment.SecondMoment;
27
28 /**
29 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
30 * multiple linear regression model.</p>
31 *
32 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
33 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre>
34 *
35 * <p>To solve the normal equations, this implementation uses QR decomposition
36 * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
37 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
38 * has rows corresponding to sample observations and columns corresponding to independent
39 * variables. When the model is estimated using an intercept term (i.e. when
40 * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
41 * matrix includes an initial column identically equal to 1. We solve the normal equations
42 * as follows:
43 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
44 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
45 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
46 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
47 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
48 * R b = Q<sup>T</sup> y </code></pre>
49 *
50 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
51 *
52 * @since 2.0
53 */
54 public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
55
56 /** Cached QR decomposition of X matrix. */
57 private QRDecomposition qr;
58
59 /** Singularity threshold for QR decomposition. */
60 private final double threshold;
61
62 /**
63 * Create an empty OLSMultipleLinearRegression instance.
64 */
65 public OLSMultipleLinearRegression() {
66 this(0d);
67 }
68
69 /**
70 * Create an empty OLSMultipleLinearRegression instance, using the given
71 * singularity threshold for the QR decomposition.
72 *
73 * @param threshold the singularity threshold
74 * @since 3.3
75 */
76 public OLSMultipleLinearRegression(final double threshold) {
77 this.threshold = threshold;
78 }
79
80 /**
81 * Loads model x and y sample data, overriding any previous sample.
82 *
83 * Computes and caches QR decomposition of the X matrix.
84 * @param y the [n,1] array representing the y sample
85 * @param x the [n,k] array representing the x sample
86 * @throws MathIllegalArgumentException if the x and y array data are not
87 * compatible for the regression
88 */
89 public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
90 validateSampleData(x, y);
91 newYSampleData(y);
92 newXSampleData(x);
93 }
94
95 /**
96 * {@inheritDoc}
97 * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
98 */
99 @Override
100 public void newSampleData(double[] data, int nobs, int nvars) {
101 super.newSampleData(data, nobs, nvars);
102 qr = new QRDecomposition(getX(), threshold);
103 }
104
105 /**
106 * <p>Compute the "hat" matrix.
107 * </p>
108 * <p>The hat matrix is defined in terms of the design matrix X
109 * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
110 * </p>
111 * <p>The implementation here uses the QR decomposition to compute the
112 * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
113 * p-dimensional identity matrix augmented by 0's. This computational
114 * formula is from "The Hat Matrix in Regression and ANOVA",
115 * David C. Hoaglin and Roy E. Welsch,
116 * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
117 * </p>
118 * <p>Data for the model must have been successfully loaded using one of
119 * the {@code newSampleData} methods before invoking this method; otherwise
120 * a {@code NullPointerException} will be thrown.</p>
121 *
122 * @return the hat matrix
123 * @throws NullPointerException unless method {@code newSampleData} has been
124 * called beforehand.
125 */
126 public RealMatrix calculateHat() {
127 // Create augmented identity matrix
128 RealMatrix q = qr.getQ();
129 final int p = qr.getR().getColumnDimension();
130 final int n = q.getColumnDimension();
131 // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3
132 Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
133 double[][] augIData = augI.getDataRef();
134 for (int i = 0; i < n; i++) {
135 for (int j =0; j < n; j++) {
136 if (i == j && i < p) {
137 augIData[i][j] = 1d;
138 } else {
139 augIData[i][j] = 0d;
140 }
141 }
142 }
143
144 // Compute and return Hat matrix
145 // No DME advertised - args valid if we get here
146 return q.multiply(augI).multiply(q.transpose());
147 }
148
149 /**
150 * <p>Returns the sum of squared deviations of Y from its mean.</p>
151 *
152 * <p>If the model has no intercept term, <code>0</code> is used for the
153 * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
154 *
155 * <p>The value returned by this method is the SSTO value used in
156 * the {@link #calculateRSquared() R-squared} computation.</p>
157 *
158 * @return SSTO - the total sum of squares
159 * @throws NullPointerException if the sample has not been set
160 * @see #isNoIntercept()
161 * @since 2.2
162 */
163 public double calculateTotalSumOfSquares() {
164 if (isNoIntercept()) {
165 return StatUtils.sumSq(getY().toArray());
166 } else {
167 return new SecondMoment().evaluate(getY().toArray());
168 }
169 }
170
171 /**
172 * Returns the sum of squared residuals.
173 *
174 * @return residual sum of squares
175 * @since 2.2
176 * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
177 * @throws NullPointerException if the data for the model have not been loaded
178 */
179 public double calculateResidualSumOfSquares() {
180 final RealVector residuals = calculateResiduals();
181 // No advertised DME, args are valid
182 return residuals.dotProduct(residuals);
183 }
184
185 /**
186 * Returns the R-Squared statistic, defined by the formula <div style="white-space: pre"><code>
187 * R<sup>2</sup> = 1 - SSR / SSTO
188 * </code></div>
189 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
190 * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
191 *
192 * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
193 *
194 * @return R-square statistic
195 * @throws NullPointerException if the sample has not been set
196 * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
197 * @since 2.2
198 */
199 public double calculateRSquared() {
200 return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
201 }
202
203 /**
204 * <p>Returns the adjusted R-squared statistic, defined by the formula <div style="white-space: pre"><code>
205 * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
206 * </code></div>
207 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
208 * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
209 * of observations and p is the number of parameters estimated (including the intercept).
210 *
211 * <p>If the regression is estimated without an intercept term, what is returned is <pre>
212 * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
213 * </pre>
214 *
215 * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
216 *
217 * @return adjusted R-Squared statistic
218 * @throws NullPointerException if the sample has not been set
219 * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
220 * @see #isNoIntercept()
221 * @since 2.2
222 */
223 public double calculateAdjustedRSquared() {
224 final double n = getX().getRowDimension();
225 if (isNoIntercept()) {
226 return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
227 } else {
228 return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
229 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
230 }
231 }
232
233 /**
234 * {@inheritDoc}
235 * <p>This implementation computes and caches the QR decomposition of the X matrix
236 * once it is successfully loaded.</p>
237 */
238 @Override
239 protected void newXSampleData(double[][] x) {
240 super.newXSampleData(x);
241 qr = new QRDecomposition(getX(), threshold);
242 }
243
244 /**
245 * Calculates the regression coefficients using OLS.
246 *
247 * <p>Data for the model must have been successfully loaded using one of
248 * the {@code newSampleData} methods before invoking this method; otherwise
249 * a {@code NullPointerException} will be thrown.</p>
250 *
251 * @return beta
252 * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
253 * @throws NullPointerException if the data for the model have not been loaded
254 */
255 @Override
256 protected RealVector calculateBeta() {
257 return qr.getSolver().solve(getY());
258 }
259
260 /**
261 * <p>Calculates the variance-covariance matrix of the regression parameters.
262 * </p>
263 * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
264 * </p>
265 * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
266 * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
267 * R included, where p = the length of the beta vector.</p>
268 *
269 * <p>Data for the model must have been successfully loaded using one of
270 * the {@code newSampleData} methods before invoking this method; otherwise
271 * a {@code NullPointerException} will be thrown.</p>
272 *
273 * @return The beta variance-covariance matrix
274 * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
275 * @throws NullPointerException if the data for the model have not been loaded
276 */
277 @Override
278 protected RealMatrix calculateBetaVariance() {
279 int p = getX().getColumnDimension();
280 RealMatrix rAug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
281 RealMatrix rInv = new LUDecomposition(rAug).getSolver().getInverse();
282 return rInv.multiply(rInv.transpose());
283 }
284 }