1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.stat.regression;
18
19 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
20 import org.apache.commons.math4.legacy.exception.InsufficientDataException;
21 import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
22 import org.apache.commons.math4.legacy.exception.NoDataException;
23 import org.apache.commons.math4.legacy.exception.NullArgumentException;
24 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
25 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
26 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
27 import org.apache.commons.math4.legacy.linear.NonSquareMatrixException;
28 import org.apache.commons.math4.legacy.linear.RealMatrix;
29 import org.apache.commons.math4.legacy.linear.RealVector;
30 import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
31 import org.apache.commons.math4.core.jdkmath.JdkMath;
32
33 /**
34 * Abstract base class for implementations of MultipleLinearRegression.
35 * @since 2.0
36 */
37 public abstract class AbstractMultipleLinearRegression implements
38 MultipleLinearRegression {
39
40 /** X sample data. */
41 private RealMatrix xMatrix;
42
43 /** Y sample data. */
44 private RealVector yVector;
45
46 /** Whether or not the regression model includes an intercept. True means no intercept. */
47 private boolean noIntercept;
48
49 /**
50 * @return the X sample data.
51 */
52 protected RealMatrix getX() {
53 return xMatrix;
54 }
55
56 /**
57 * @return the Y sample data.
58 */
59 protected RealVector getY() {
60 return yVector;
61 }
62
63 /**
64 * @return true if the model has no intercept term; false otherwise
65 * @since 2.2
66 */
67 public boolean isNoIntercept() {
68 return noIntercept;
69 }
70
71 /**
72 * @param noIntercept true means the model is to be estimated without an intercept term
73 * @since 2.2
74 */
75 public void setNoIntercept(boolean noIntercept) {
76 this.noIntercept = noIntercept;
77 }
78
79 /**
80 * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
81 * </p>
82 * <p>Assumes that rows are concatenated with y values first in each row. For example, an input
83 * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
84 * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
85 * independent variables, as below:
86 * <pre>
87 * y x[0] x[1]
88 * --------------
89 * 1 2 3
90 * 4 5 6
91 * 7 8 9
92 * </pre>
93 *
94 * <p>Note that there is no need to add an initial unitary column (column of 1's) when
95 * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>,
96 * the X matrix will be created without an initial column of "1"s; otherwise this column will
97 * be added.
98 * </p>
99 * <p>Throws IllegalArgumentException if any of the following preconditions fail:
100 * <ul><li><code>data</code> cannot be null</li>
101 * <li><code>data.length = nobs * (nvars + 1)</code></li>
102 * <li>{@code nobs > nvars}</li></ul>
103 *
104 * @param data input data array
105 * @param nobs number of observations (rows)
106 * @param nvars number of independent variables (columns, not counting y)
107 * @throws NullArgumentException if the data array is null
108 * @throws DimensionMismatchException if the length of the data array is not equal
109 * to <code>nobs * (nvars + 1)</code>
110 * @throws InsufficientDataException if <code>nobs</code> is less than
111 * <code>nvars + 1</code>
112 */
113 public void newSampleData(double[] data, int nobs, int nvars) {
114 if (data == null) {
115 throw new NullArgumentException();
116 }
117 if (data.length != nobs * (nvars + 1)) {
118 throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
119 }
120 if (nobs <= nvars) {
121 throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1);
122 }
123 double[] y = new double[nobs];
124 final int cols = noIntercept ? nvars: nvars + 1;
125 double[][] x = new double[nobs][cols];
126 int pointer = 0;
127 for (int i = 0; i < nobs; i++) {
128 y[i] = data[pointer++];
129 if (!noIntercept) {
130 x[i][0] = 1.0d;
131 }
132 for (int j = noIntercept ? 0 : 1; j < cols; j++) {
133 x[i][j] = data[pointer++];
134 }
135 }
136 this.xMatrix = new Array2DRowRealMatrix(x);
137 this.yVector = new ArrayRealVector(y);
138 }
139
140 /**
141 * Loads new y sample data, overriding any previous data.
142 *
143 * @param y the array representing the y sample
144 * @throws NullArgumentException if y is null
145 * @throws NoDataException if y is empty
146 */
147 protected void newYSampleData(double[] y) {
148 if (y == null) {
149 throw new NullArgumentException();
150 }
151 if (y.length == 0) {
152 throw new NoDataException();
153 }
154 this.yVector = new ArrayRealVector(y);
155 }
156
157 /**
158 * <p>Loads new x sample data, overriding any previous data.
159 * </p>
160 * The input <code>x</code> array should have one row for each sample
161 * observation, with columns corresponding to independent variables.
162 * For example, if <pre>
163 * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
164 * then <code>setXSampleData(x) </code> results in a model with two independent
165 * variables and 3 observations:
166 * <pre>
167 * x[0] x[1]
168 * ----------
169 * 1 2
170 * 3 4
171 * 5 6
172 * </pre>
173 * <p>Note that there is no need to add an initial unitary column (column of 1's) when
174 * specifying a model including an intercept term.
175 * </p>
176 * @param x the rectangular array representing the x sample
177 * @throws NullArgumentException if x is null
178 * @throws NoDataException if x is empty
179 * @throws DimensionMismatchException if x is not rectangular
180 */
181 protected void newXSampleData(double[][] x) {
182 if (x == null) {
183 throw new NullArgumentException();
184 }
185 if (x.length == 0) {
186 throw new NoDataException();
187 }
188 if (noIntercept) {
189 this.xMatrix = new Array2DRowRealMatrix(x, true);
190 } else { // Augment design matrix with initial unitary column
191 final int nVars = x[0].length;
192 final double[][] xAug = new double[x.length][nVars + 1];
193 for (int i = 0; i < x.length; i++) {
194 if (x[i].length != nVars) {
195 throw new DimensionMismatchException(x[i].length, nVars);
196 }
197 xAug[i][0] = 1.0d;
198 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
199 }
200 this.xMatrix = new Array2DRowRealMatrix(xAug, false);
201 }
202 }
203
204 /**
205 * Validates sample data. Checks that
206 * <ul><li>Neither x nor y is null or empty;</li>
207 * <li>The length (i.e. number of rows) of x equals the length of y</li>
208 * <li>x has at least one more row than it has columns (i.e. there is
209 * sufficient data to estimate regression coefficients for each of the
210 * columns in x plus an intercept.</li>
211 * </ul>
212 *
213 * @param x the [n,k] array representing the x data
214 * @param y the [n,1] array representing the y data
215 * @throws NullArgumentException if {@code x} or {@code y} is null
216 * @throws DimensionMismatchException if {@code x} and {@code y} do not
217 * have the same length
218 * @throws NoDataException if {@code x} or {@code y} are zero-length
219 * @throws MathIllegalArgumentException if the number of rows of {@code x}
220 * is not larger than the number of columns + 1
221 */
222 protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
223 if (x == null || y == null) {
224 throw new NullArgumentException();
225 }
226 if (x.length != y.length) {
227 throw new DimensionMismatchException(y.length, x.length);
228 }
229 if (x.length == 0) { // Must be no y data either
230 throw new NoDataException();
231 }
232 if (x[0].length + 1 > x.length) {
233 throw new MathIllegalArgumentException(
234 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
235 x.length, x[0].length);
236 }
237 }
238
239 /**
240 * Validates that the x data and covariance matrix have the same
241 * number of rows and that the covariance matrix is square.
242 *
243 * @param x the [n,k] array representing the x sample
244 * @param covariance the [n,n] array representing the covariance matrix
245 * @throws DimensionMismatchException if the number of rows in x is not equal
246 * to the number of rows in covariance
247 * @throws NonSquareMatrixException if the covariance matrix is not square
248 */
249 protected void validateCovarianceData(double[][] x, double[][] covariance) {
250 if (x.length != covariance.length) {
251 throw new DimensionMismatchException(x.length, covariance.length);
252 }
253 if (covariance.length > 0 && covariance.length != covariance[0].length) {
254 throw new NonSquareMatrixException(covariance.length, covariance[0].length);
255 }
256 }
257
258 /**
259 * {@inheritDoc}
260 */
261 @Override
262 public double[] estimateRegressionParameters() {
263 RealVector b = calculateBeta();
264 return b.toArray();
265 }
266
267 /**
268 * {@inheritDoc}
269 */
270 @Override
271 public double[] estimateResiduals() {
272 RealVector b = calculateBeta();
273 RealVector e = yVector.subtract(xMatrix.operate(b));
274 return e.toArray();
275 }
276
277 /**
278 * {@inheritDoc}
279 */
280 @Override
281 public double[][] estimateRegressionParametersVariance() {
282 return calculateBetaVariance().getData();
283 }
284
285 /**
286 * {@inheritDoc}
287 */
288 @Override
289 public double[] estimateRegressionParametersStandardErrors() {
290 double[][] betaVariance = estimateRegressionParametersVariance();
291 double sigma = calculateErrorVariance();
292 int length = betaVariance[0].length;
293 double[] result = new double[length];
294 for (int i = 0; i < length; i++) {
295 result[i] = JdkMath.sqrt(sigma * betaVariance[i][i]);
296 }
297 return result;
298 }
299
300 /**
301 * {@inheritDoc}
302 */
303 @Override
304 public double estimateRegressandVariance() {
305 return calculateYVariance();
306 }
307
308 /**
309 * Estimates the variance of the error.
310 *
311 * @return estimate of the error variance
312 * @since 2.2
313 */
314 public double estimateErrorVariance() {
315 return calculateErrorVariance();
316 }
317
318 /**
319 * Estimates the standard error of the regression.
320 *
321 * @return regression standard error
322 * @since 2.2
323 */
324 public double estimateRegressionStandardError() {
325 return JdkMath.sqrt(estimateErrorVariance());
326 }
327
328 /**
329 * Calculates the beta of multiple linear regression in matrix notation.
330 *
331 * @return beta
332 */
333 protected abstract RealVector calculateBeta();
334
335 /**
336 * Calculates the beta variance of multiple linear regression in matrix
337 * notation.
338 *
339 * @return beta variance
340 */
341 protected abstract RealMatrix calculateBetaVariance();
342
343
344 /**
345 * Calculates the variance of the y values.
346 *
347 * @return Y variance
348 */
349 protected double calculateYVariance() {
350 return new Variance().evaluate(yVector.toArray());
351 }
352
353 /**
354 * <p>Calculates the variance of the error term.</p>
355 * Uses the formula <pre>
356 * var(u) = u · u / (n - k)
357 * </pre>
358 * where n and k are the row and column dimensions of the design
359 * matrix X.
360 *
361 * @return error variance estimate
362 * @since 2.2
363 */
364 protected double calculateErrorVariance() {
365 RealVector residuals = calculateResiduals();
366 return residuals.dotProduct(residuals) /
367 (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
368 }
369
370 /**
371 * Calculates the residuals of multiple linear regression in matrix
372 * notation.
373 *
374 * <pre>
375 * u = y - X * b
376 * </pre>
377 *
378 * @return The residuals [n,1] matrix
379 */
380 protected RealVector calculateResiduals() {
381 RealVector b = calculateBeta();
382 return yVector.subtract(xMatrix.operate(b));
383 }
384 }