View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.stat.regression;
18  
19  import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
20  import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
21  import org.apache.commons.math4.legacy.linear.LUDecomposition;
22  import org.apache.commons.math4.legacy.linear.QRDecomposition;
23  import org.apache.commons.math4.legacy.linear.RealMatrix;
24  import org.apache.commons.math4.legacy.linear.RealVector;
25  import org.apache.commons.math4.legacy.stat.StatUtils;
26  import org.apache.commons.math4.legacy.stat.descriptive.moment.SecondMoment;
27  
28  /**
29   * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
30   * multiple linear regression model.</p>
31   *
32   * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
33   * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre>
34   *
35   * <p>To solve the normal equations, this implementation uses QR decomposition
36   * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
37   * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
38   * has rows corresponding to sample observations and columns corresponding to independent
39   * variables.  When the model is estimated using an intercept term (i.e. when
40   * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
41   * matrix includes an initial column identically equal to 1.  We solve the normal equations
42   * as follows:
43   * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
44   * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
45   * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
46   * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
47   * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
48   * R b = Q<sup>T</sup> y </code></pre>
49   *
50   * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
51   *
52   * @since 2.0
53   */
54  public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
55  
56      /** Cached QR decomposition of X matrix. */
57      private QRDecomposition qr;
58  
59      /** Singularity threshold for QR decomposition. */
60      private final double threshold;
61  
62      /**
63       * Create an empty OLSMultipleLinearRegression instance.
64       */
65      public OLSMultipleLinearRegression() {
66          this(0d);
67      }
68  
69      /**
70       * Create an empty OLSMultipleLinearRegression instance, using the given
71       * singularity threshold for the QR decomposition.
72       *
73       * @param threshold the singularity threshold
74       * @since 3.3
75       */
76      public OLSMultipleLinearRegression(final double threshold) {
77          this.threshold = threshold;
78      }
79  
80      /**
81       * Loads model x and y sample data, overriding any previous sample.
82       *
83       * Computes and caches QR decomposition of the X matrix.
84       * @param y the [n,1] array representing the y sample
85       * @param x the [n,k] array representing the x sample
86       * @throws MathIllegalArgumentException if the x and y array data are not
87       *             compatible for the regression
88       */
89      public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
90          validateSampleData(x, y);
91          newYSampleData(y);
92          newXSampleData(x);
93      }
94  
95      /**
96       * {@inheritDoc}
97       * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
98       */
99      @Override
100     public void newSampleData(double[] data, int nobs, int nvars) {
101         super.newSampleData(data, nobs, nvars);
102         qr = new QRDecomposition(getX(), threshold);
103     }
104 
105     /**
106      * <p>Compute the "hat" matrix.
107      * </p>
108      * <p>The hat matrix is defined in terms of the design matrix X
109      *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
110      * </p>
111      * <p>The implementation here uses the QR decomposition to compute the
112      * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
113      * p-dimensional identity matrix augmented by 0's.  This computational
114      * formula is from "The Hat Matrix in Regression and ANOVA",
115      * David C. Hoaglin and Roy E. Welsch,
116      * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
117      * </p>
118      * <p>Data for the model must have been successfully loaded using one of
119      * the {@code newSampleData} methods before invoking this method; otherwise
120      * a {@code NullPointerException} will be thrown.</p>
121      *
122      * @return the hat matrix
123      * @throws NullPointerException unless method {@code newSampleData} has been
124      * called beforehand.
125      */
126     public RealMatrix calculateHat() {
127         // Create augmented identity matrix
128         RealMatrix q = qr.getQ();
129         final int p = qr.getR().getColumnDimension();
130         final int n = q.getColumnDimension();
131         // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3
132         Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
133         double[][] augIData = augI.getDataRef();
134         for (int i = 0; i < n; i++) {
135             for (int j =0; j < n; j++) {
136                 if (i == j && i < p) {
137                     augIData[i][j] = 1d;
138                 } else {
139                     augIData[i][j] = 0d;
140                 }
141             }
142         }
143 
144         // Compute and return Hat matrix
145         // No DME advertised - args valid if we get here
146         return q.multiply(augI).multiply(q.transpose());
147     }
148 
149     /**
150      * <p>Returns the sum of squared deviations of Y from its mean.</p>
151      *
152      * <p>If the model has no intercept term, <code>0</code> is used for the
153      * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
154      *
155      * <p>The value returned by this method is the SSTO value used in
156      * the {@link #calculateRSquared() R-squared} computation.</p>
157      *
158      * @return SSTO - the total sum of squares
159      * @throws NullPointerException if the sample has not been set
160      * @see #isNoIntercept()
161      * @since 2.2
162      */
163     public double calculateTotalSumOfSquares() {
164         if (isNoIntercept()) {
165             return StatUtils.sumSq(getY().toArray());
166         } else {
167             return new SecondMoment().evaluate(getY().toArray());
168         }
169     }
170 
171     /**
172      * Returns the sum of squared residuals.
173      *
174      * @return residual sum of squares
175      * @since 2.2
176      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
177      * @throws NullPointerException if the data for the model have not been loaded
178      */
179     public double calculateResidualSumOfSquares() {
180         final RealVector residuals = calculateResiduals();
181         // No advertised DME, args are valid
182         return residuals.dotProduct(residuals);
183     }
184 
185     /**
186      * Returns the R-Squared statistic, defined by the formula <div style="white-space: pre"><code>
187      * R<sup>2</sup> = 1 - SSR / SSTO
188      * </code></div>
189      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
190      * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
191      *
192      * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
193      *
194      * @return R-square statistic
195      * @throws NullPointerException if the sample has not been set
196      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
197      * @since 2.2
198      */
199     public double calculateRSquared() {
200         return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
201     }
202 
203     /**
204      * <p>Returns the adjusted R-squared statistic, defined by the formula <div style="white-space: pre"><code>
205      * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
206      * </code></div>
207      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
208      * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
209      * of observations and p is the number of parameters estimated (including the intercept).
210      *
211      * <p>If the regression is estimated without an intercept term, what is returned is <pre>
212      * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
213      * </pre>
214      *
215      * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
216      *
217      * @return adjusted R-Squared statistic
218      * @throws NullPointerException if the sample has not been set
219      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
220      * @see #isNoIntercept()
221      * @since 2.2
222      */
223     public double calculateAdjustedRSquared() {
224         final double n = getX().getRowDimension();
225         if (isNoIntercept()) {
226             return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
227         } else {
228             return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
229                 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
230         }
231     }
232 
233     /**
234      * {@inheritDoc}
235      * <p>This implementation computes and caches the QR decomposition of the X matrix
236      * once it is successfully loaded.</p>
237      */
238     @Override
239     protected void newXSampleData(double[][] x) {
240         super.newXSampleData(x);
241         qr = new QRDecomposition(getX(), threshold);
242     }
243 
244     /**
245      * Calculates the regression coefficients using OLS.
246      *
247      * <p>Data for the model must have been successfully loaded using one of
248      * the {@code newSampleData} methods before invoking this method; otherwise
249      * a {@code NullPointerException} will be thrown.</p>
250      *
251      * @return beta
252      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
253      * @throws NullPointerException if the data for the model have not been loaded
254      */
255     @Override
256     protected RealVector calculateBeta() {
257         return qr.getSolver().solve(getY());
258     }
259 
260     /**
261      * <p>Calculates the variance-covariance matrix of the regression parameters.
262      * </p>
263      * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
264      * </p>
265      * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
266      * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
267      * R included, where p = the length of the beta vector.</p>
268      *
269      * <p>Data for the model must have been successfully loaded using one of
270      * the {@code newSampleData} methods before invoking this method; otherwise
271      * a {@code NullPointerException} will be thrown.</p>
272      *
273      * @return The beta variance-covariance matrix
274      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
275      * @throws NullPointerException if the data for the model have not been loaded
276      */
277     @Override
278     protected RealMatrix calculateBetaVariance() {
279         int p = getX().getColumnDimension();
280         RealMatrix rAug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
281         RealMatrix rInv = new LUDecomposition(rAug).getSolver().getInverse();
282         return rInv.multiply(rInv.transpose());
283     }
284 }