View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.stat.regression;
18  
19  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
20  import org.apache.commons.math4.legacy.exception.InsufficientDataException;
21  import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
22  import org.apache.commons.math4.legacy.exception.NoDataException;
23  import org.apache.commons.math4.legacy.exception.NullArgumentException;
24  import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
25  import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
26  import org.apache.commons.math4.legacy.linear.ArrayRealVector;
27  import org.apache.commons.math4.legacy.linear.NonSquareMatrixException;
28  import org.apache.commons.math4.legacy.linear.RealMatrix;
29  import org.apache.commons.math4.legacy.linear.RealVector;
30  import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
31  import org.apache.commons.math4.core.jdkmath.JdkMath;
32  
33  /**
34   * Abstract base class for implementations of MultipleLinearRegression.
35   * @since 2.0
36   */
37  public abstract class AbstractMultipleLinearRegression implements
38          MultipleLinearRegression {
39  
40      /** X sample data. */
41      private RealMatrix xMatrix;
42  
43      /** Y sample data. */
44      private RealVector yVector;
45  
46      /** Whether or not the regression model includes an intercept.  True means no intercept. */
47      private boolean noIntercept;
48  
49      /**
50       * @return the X sample data.
51       */
52      protected RealMatrix getX() {
53          return xMatrix;
54      }
55  
56      /**
57       * @return the Y sample data.
58       */
59      protected RealVector getY() {
60          return yVector;
61      }
62  
63      /**
64       * @return true if the model has no intercept term; false otherwise
65       * @since 2.2
66       */
67      public boolean isNoIntercept() {
68          return noIntercept;
69      }
70  
71      /**
72       * @param noIntercept true means the model is to be estimated without an intercept term
73       * @since 2.2
74       */
75      public void setNoIntercept(boolean noIntercept) {
76          this.noIntercept = noIntercept;
77      }
78  
79      /**
80       * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
81       * </p>
82       * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
83       * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
84       * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
85       * independent variables, as below:
86       * <pre>
87       *   y   x[0]  x[1]
88       *   --------------
89       *   1     2     3
90       *   4     5     6
91       *   7     8     9
92       * </pre>
93       *
94       * <p>Note that there is no need to add an initial unitary column (column of 1's) when
95       * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
96       * the X matrix will be created without an initial column of "1"s; otherwise this column will
97       * be added.
98       * </p>
99       * <p>Throws IllegalArgumentException if any of the following preconditions fail:
100      * <ul><li><code>data</code> cannot be null</li>
101      * <li><code>data.length = nobs * (nvars + 1)</code></li>
102      * <li>{@code nobs > nvars}</li></ul>
103      *
104      * @param data input data array
105      * @param nobs number of observations (rows)
106      * @param nvars number of independent variables (columns, not counting y)
107      * @throws NullArgumentException if the data array is null
108      * @throws DimensionMismatchException if the length of the data array is not equal
109      * to <code>nobs * (nvars + 1)</code>
110      * @throws InsufficientDataException if <code>nobs</code> is less than
111      * <code>nvars + 1</code>
112      */
113     public void newSampleData(double[] data, int nobs, int nvars) {
114         if (data == null) {
115             throw new NullArgumentException();
116         }
117         if (data.length != nobs * (nvars + 1)) {
118             throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
119         }
120         if (nobs <= nvars) {
121             throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1);
122         }
123         double[] y = new double[nobs];
124         final int cols = noIntercept ? nvars: nvars + 1;
125         double[][] x = new double[nobs][cols];
126         int pointer = 0;
127         for (int i = 0; i < nobs; i++) {
128             y[i] = data[pointer++];
129             if (!noIntercept) {
130                 x[i][0] = 1.0d;
131             }
132             for (int j = noIntercept ? 0 : 1; j < cols; j++) {
133                 x[i][j] = data[pointer++];
134             }
135         }
136         this.xMatrix = new Array2DRowRealMatrix(x);
137         this.yVector = new ArrayRealVector(y);
138     }
139 
140     /**
141      * Loads new y sample data, overriding any previous data.
142      *
143      * @param y the array representing the y sample
144      * @throws NullArgumentException if y is null
145      * @throws NoDataException if y is empty
146      */
147     protected void newYSampleData(double[] y) {
148         if (y == null) {
149             throw new NullArgumentException();
150         }
151         if (y.length == 0) {
152             throw new NoDataException();
153         }
154         this.yVector = new ArrayRealVector(y);
155     }
156 
157     /**
158      * <p>Loads new x sample data, overriding any previous data.
159      * </p>
160      * The input <code>x</code> array should have one row for each sample
161      * observation, with columns corresponding to independent variables.
162      * For example, if <pre>
163      * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
164      * then <code>setXSampleData(x) </code> results in a model with two independent
165      * variables and 3 observations:
166      * <pre>
167      *   x[0]  x[1]
168      *   ----------
169      *     1    2
170      *     3    4
171      *     5    6
172      * </pre>
173      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
174      * specifying a model including an intercept term.
175      * </p>
176      * @param x the rectangular array representing the x sample
177      * @throws NullArgumentException if x is null
178      * @throws NoDataException if x is empty
179      * @throws DimensionMismatchException if x is not rectangular
180      */
181     protected void newXSampleData(double[][] x) {
182         if (x == null) {
183             throw new NullArgumentException();
184         }
185         if (x.length == 0) {
186             throw new NoDataException();
187         }
188         if (noIntercept) {
189             this.xMatrix = new Array2DRowRealMatrix(x, true);
190         } else { // Augment design matrix with initial unitary column
191             final int nVars = x[0].length;
192             final double[][] xAug = new double[x.length][nVars + 1];
193             for (int i = 0; i < x.length; i++) {
194                 if (x[i].length != nVars) {
195                     throw new DimensionMismatchException(x[i].length, nVars);
196                 }
197                 xAug[i][0] = 1.0d;
198                 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
199             }
200             this.xMatrix = new Array2DRowRealMatrix(xAug, false);
201         }
202     }
203 
204     /**
205      * Validates sample data.  Checks that
206      * <ul><li>Neither x nor y is null or empty;</li>
207      * <li>The length (i.e. number of rows) of x equals the length of y</li>
208      * <li>x has at least one more row than it has columns (i.e. there is
209      * sufficient data to estimate regression coefficients for each of the
210      * columns in x plus an intercept.</li>
211      * </ul>
212      *
213      * @param x the [n,k] array representing the x data
214      * @param y the [n,1] array representing the y data
215      * @throws NullArgumentException if {@code x} or {@code y} is null
216      * @throws DimensionMismatchException if {@code x} and {@code y} do not
217      * have the same length
218      * @throws NoDataException if {@code x} or {@code y} are zero-length
219      * @throws MathIllegalArgumentException if the number of rows of {@code x}
220      * is not larger than the number of columns + 1
221      */
222     protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
223         if (x == null || y == null) {
224             throw new NullArgumentException();
225         }
226         if (x.length != y.length) {
227             throw new DimensionMismatchException(y.length, x.length);
228         }
229         if (x.length == 0) {  // Must be no y data either
230             throw new NoDataException();
231         }
232         if (x[0].length + 1 > x.length) {
233             throw new MathIllegalArgumentException(
234                     LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
235                     x.length, x[0].length);
236         }
237     }
238 
239     /**
240      * Validates that the x data and covariance matrix have the same
241      * number of rows and that the covariance matrix is square.
242      *
243      * @param x the [n,k] array representing the x sample
244      * @param covariance the [n,n] array representing the covariance matrix
245      * @throws DimensionMismatchException if the number of rows in x is not equal
246      * to the number of rows in covariance
247      * @throws NonSquareMatrixException if the covariance matrix is not square
248      */
249     protected void validateCovarianceData(double[][] x, double[][] covariance) {
250         if (x.length != covariance.length) {
251             throw new DimensionMismatchException(x.length, covariance.length);
252         }
253         if (covariance.length > 0 && covariance.length != covariance[0].length) {
254             throw new NonSquareMatrixException(covariance.length, covariance[0].length);
255         }
256     }
257 
258     /**
259      * {@inheritDoc}
260      */
261     @Override
262     public double[] estimateRegressionParameters() {
263         RealVector b = calculateBeta();
264         return b.toArray();
265     }
266 
267     /**
268      * {@inheritDoc}
269      */
270     @Override
271     public double[] estimateResiduals() {
272         RealVector b = calculateBeta();
273         RealVector e = yVector.subtract(xMatrix.operate(b));
274         return e.toArray();
275     }
276 
277     /**
278      * {@inheritDoc}
279      */
280     @Override
281     public double[][] estimateRegressionParametersVariance() {
282         return calculateBetaVariance().getData();
283     }
284 
285     /**
286      * {@inheritDoc}
287      */
288     @Override
289     public double[] estimateRegressionParametersStandardErrors() {
290         double[][] betaVariance = estimateRegressionParametersVariance();
291         double sigma = calculateErrorVariance();
292         int length = betaVariance[0].length;
293         double[] result = new double[length];
294         for (int i = 0; i < length; i++) {
295             result[i] = JdkMath.sqrt(sigma * betaVariance[i][i]);
296         }
297         return result;
298     }
299 
300     /**
301      * {@inheritDoc}
302      */
303     @Override
304     public double estimateRegressandVariance() {
305         return calculateYVariance();
306     }
307 
308     /**
309      * Estimates the variance of the error.
310      *
311      * @return estimate of the error variance
312      * @since 2.2
313      */
314     public double estimateErrorVariance() {
315         return calculateErrorVariance();
316     }
317 
318     /**
319      * Estimates the standard error of the regression.
320      *
321      * @return regression standard error
322      * @since 2.2
323      */
324     public double estimateRegressionStandardError() {
325         return JdkMath.sqrt(estimateErrorVariance());
326     }
327 
328     /**
329      * Calculates the beta of multiple linear regression in matrix notation.
330      *
331      * @return beta
332      */
333     protected abstract RealVector calculateBeta();
334 
335     /**
336      * Calculates the beta variance of multiple linear regression in matrix
337      * notation.
338      *
339      * @return beta variance
340      */
341     protected abstract RealMatrix calculateBetaVariance();
342 
343 
344     /**
345      * Calculates the variance of the y values.
346      *
347      * @return Y variance
348      */
349     protected double calculateYVariance() {
350         return new Variance().evaluate(yVector.toArray());
351     }
352 
353     /**
354      * <p>Calculates the variance of the error term.</p>
355      * Uses the formula <pre>
356      * var(u) = u &middot; u / (n - k)
357      * </pre>
358      * where n and k are the row and column dimensions of the design
359      * matrix X.
360      *
361      * @return error variance estimate
362      * @since 2.2
363      */
364     protected double calculateErrorVariance() {
365         RealVector residuals = calculateResiduals();
366         return residuals.dotProduct(residuals) /
367                (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
368     }
369 
370     /**
371      * Calculates the residuals of multiple linear regression in matrix
372      * notation.
373      *
374      * <pre>
375      * u = y - X * b
376      * </pre>
377      *
378      * @return The residuals [n,1] matrix
379      */
380     protected RealVector calculateResiduals() {
381         RealVector b = calculateBeta();
382         return yVector.subtract(xMatrix.operate(b));
383     }
384 }