001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.stat.regression;
018    
019    import org.apache.commons.math.linear.Array2DRowRealMatrix;
020    import org.apache.commons.math.linear.LUDecomposition;
021    import org.apache.commons.math.linear.QRDecomposition;
022    import org.apache.commons.math.linear.RealMatrix;
023    import org.apache.commons.math.linear.RealVector;
024    import org.apache.commons.math.stat.StatUtils;
025    import org.apache.commons.math.stat.descriptive.moment.SecondMoment;
026    
027    /**
028     * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
029     * multiple linear regression model.</p>
030     *
031     * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
032     * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
033     *
034     * <p>To solve the normal equations, this implementation uses QR decomposition
035     * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
036     * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
037     * has rows corresponding to sample observations and columns corresponding to independent
038     * variables.  When the model is estimated using an intercept term (i.e. when
039     * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
040     * matrix includes an initial column identically equal to 1.  We solve the normal equations
041     * as follows:
042     * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
043     * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
044     * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
045     * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
046     * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
047     * R b = Q<sup>T</sup> y </code></pre></p>
048     *
049     * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
050     *
051     * @version $Id: OLSMultipleLinearRegression.java 1175100 2011-09-24 04:47:38Z celestin $
052     * @since 2.0
053     */
054    public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
055    
056        /** Cached QR decomposition of X matrix */
057        private QRDecomposition qr = null;
058    
059        /**
060         * Loads model x and y sample data, overriding any previous sample.
061         *
062         * Computes and caches QR decomposition of the X matrix.
063         * @param y the [n,1] array representing the y sample
064         * @param x the [n,k] array representing the x sample
065         * @throws IllegalArgumentException if the x and y array data are not
066         *             compatible for the regression
067         */
068        public void newSampleData(double[] y, double[][] x) {
069            validateSampleData(x, y);
070            newYSampleData(y);
071            newXSampleData(x);
072        }
073    
074        /**
075         * {@inheritDoc}
076         * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
077         */
078        @Override
079        public void newSampleData(double[] data, int nobs, int nvars) {
080            super.newSampleData(data, nobs, nvars);
081            qr = new QRDecomposition(X);
082        }
083    
084        /**
085         * <p>Compute the "hat" matrix.
086         * </p>
087         * <p>The hat matrix is defined in terms of the design matrix X
088         *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
089         * </p>
090         * <p>The implementation here uses the QR decomposition to compute the
091         * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
092         * p-dimensional identity matrix augmented by 0's.  This computational
093         * formula is from "The Hat Matrix in Regression and ANOVA",
094         * David C. Hoaglin and Roy E. Welsch,
095         * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
096         *
097         * @return the hat matrix
098         */
099        public RealMatrix calculateHat() {
100            // Create augmented identity matrix
101            RealMatrix Q = qr.getQ();
102            final int p = qr.getR().getColumnDimension();
103            final int n = Q.getColumnDimension();
104            Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
105            double[][] augIData = augI.getDataRef();
106            for (int i = 0; i < n; i++) {
107                for (int j =0; j < n; j++) {
108                    if (i == j && i < p) {
109                        augIData[i][j] = 1d;
110                    } else {
111                        augIData[i][j] = 0d;
112                    }
113                }
114            }
115    
116            // Compute and return Hat matrix
117            return Q.multiply(augI).multiply(Q.transpose());
118        }
119    
120        /**
121         * <p>Returns the sum of squared deviations of Y from its mean.</p>
122         *
123         * <p>If the model has no intercept term, <code>0</code> is used for the
124         * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
125         *
126         * <p>The value returned by this method is the SSTO value used in
127         * the {@link #calculateRSquared() R-squared} computation.</p>
128         *
129         * @return SSTO - the total sum of squares
130         * @see #isNoIntercept()
131         * @since 2.2
132         */
133        public double calculateTotalSumOfSquares() {
134            if (isNoIntercept()) {
135                return StatUtils.sumSq(Y.toArray());
136            } else {
137                return new SecondMoment().evaluate(Y.toArray());
138            }
139        }
140    
141        /**
142         * Returns the sum of squared residuals.
143         *
144         * @return residual sum of squares
145         * @since 2.2
146         */
147        public double calculateResidualSumOfSquares() {
148            final RealVector residuals = calculateResiduals();
149            return residuals.dotProduct(residuals);
150        }
151    
152        /**
153         * Returns the R-Squared statistic, defined by the formula <pre>
154         * R<sup>2</sup> = 1 - SSR / SSTO
155         * </pre>
156         * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
157         * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
158         *
159         * @return R-square statistic
160         * @since 2.2
161         */
162        public double calculateRSquared() {
163            return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
164        }
165    
166        /**
167         * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
168         * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
169         * </pre>
170         * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
171         * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
172         * of observations and p is the number of parameters estimated (including the intercept).</p>
173         *
174         * <p>If the regression is estimated without an intercept term, what is returned is <pre>
175         * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
176         * </pre></p>
177         *
178         * @return adjusted R-Squared statistic
179         * @see #isNoIntercept()
180         * @since 2.2
181         */
182        public double calculateAdjustedRSquared() {
183            final double n = X.getRowDimension();
184            if (isNoIntercept()) {
185                return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension()));
186            } else {
187                return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
188                    (calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
189            }
190        }
191    
192        /**
193         * {@inheritDoc}
194         * <p>This implementation computes and caches the QR decomposition of the X matrix
195         * once it is successfully loaded.</p>
196         */
197        @Override
198        protected void newXSampleData(double[][] x) {
199            super.newXSampleData(x);
200            qr = new QRDecomposition(X);
201        }
202    
203        /**
204         * Calculates the regression coefficients using OLS.
205         *
206         * @return beta
207         */
208        @Override
209        protected RealVector calculateBeta() {
210            return qr.getSolver().solve(Y);
211        }
212    
213        /**
214         * <p>Calculates the variance-covariance matrix of the regression parameters.
215         * </p>
216         * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
217         * </p>
218         * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
219         * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
220         * R included, where p = the length of the beta vector.</p>
221         *
222         * @return The beta variance-covariance matrix
223         */
224        @Override
225        protected RealMatrix calculateBetaVariance() {
226            int p = X.getColumnDimension();
227            RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
228            RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
229            return Rinv.multiply(Rinv.transpose());
230        }
231    
232    }