001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.stat.regression;
018
019import org.apache.commons.math3.exception.MathIllegalArgumentException;
020import org.apache.commons.math3.linear.Array2DRowRealMatrix;
021import org.apache.commons.math3.linear.LUDecomposition;
022import org.apache.commons.math3.linear.QRDecomposition;
023import org.apache.commons.math3.linear.RealMatrix;
024import org.apache.commons.math3.linear.RealVector;
025import org.apache.commons.math3.stat.StatUtils;
026import org.apache.commons.math3.stat.descriptive.moment.SecondMoment;
027
028/**
029 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
030 * multiple linear regression model.</p>
031 *
032 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
033 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
034 *
035 * <p>To solve the normal equations, this implementation uses QR decomposition
036 * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
037 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
038 * has rows corresponding to sample observations and columns corresponding to independent
039 * variables.  When the model is estimated using an intercept term (i.e. when
040 * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
041 * matrix includes an initial column identically equal to 1.  We solve the normal equations
042 * as follows:
043 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
044 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
045 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
046 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
047 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
048 * R b = Q<sup>T</sup> y </code></pre></p>
049 *
050 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
051 *
052 * @since 2.0
053 */
054public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
055
056    /** Cached QR decomposition of X matrix */
057    private QRDecomposition qr = null;
058
059    /** Singularity threshold for QR decomposition */
060    private final double threshold;
061
062    /**
063     * Create an empty OLSMultipleLinearRegression instance.
064     */
065    public OLSMultipleLinearRegression() {
066        this(0d);
067    }
068
069    /**
070     * Create an empty OLSMultipleLinearRegression instance, using the given
071     * singularity threshold for the QR decomposition.
072     *
073     * @param threshold the singularity threshold
074     * @since 3.3
075     */
076    public OLSMultipleLinearRegression(final double threshold) {
077        this.threshold = threshold;
078    }
079
080    /**
081     * Loads model x and y sample data, overriding any previous sample.
082     *
083     * Computes and caches QR decomposition of the X matrix.
084     * @param y the [n,1] array representing the y sample
085     * @param x the [n,k] array representing the x sample
086     * @throws MathIllegalArgumentException if the x and y array data are not
087     *             compatible for the regression
088     */
089    public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
090        validateSampleData(x, y);
091        newYSampleData(y);
092        newXSampleData(x);
093    }
094
095    /**
096     * {@inheritDoc}
097     * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
098     */
099    @Override
100    public void newSampleData(double[] data, int nobs, int nvars) {
101        super.newSampleData(data, nobs, nvars);
102        qr = new QRDecomposition(getX(), threshold);
103    }
104
105    /**
106     * <p>Compute the "hat" matrix.
107     * </p>
108     * <p>The hat matrix is defined in terms of the design matrix X
109     *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
110     * </p>
111     * <p>The implementation here uses the QR decomposition to compute the
112     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
113     * p-dimensional identity matrix augmented by 0's.  This computational
114     * formula is from "The Hat Matrix in Regression and ANOVA",
115     * David C. Hoaglin and Roy E. Welsch,
116     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
117     * </p>
118     * <p>Data for the model must have been successfully loaded using one of
119     * the {@code newSampleData} methods before invoking this method; otherwise
120     * a {@code NullPointerException} will be thrown.</p>
121     *
122     * @return the hat matrix
123     * @throws NullPointerException unless method {@code newSampleData} has been
124     * called beforehand.
125     */
126    public RealMatrix calculateHat() {
127        // Create augmented identity matrix
128        RealMatrix Q = qr.getQ();
129        final int p = qr.getR().getColumnDimension();
130        final int n = Q.getColumnDimension();
131        // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3
132        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
133        double[][] augIData = augI.getDataRef();
134        for (int i = 0; i < n; i++) {
135            for (int j =0; j < n; j++) {
136                if (i == j && i < p) {
137                    augIData[i][j] = 1d;
138                } else {
139                    augIData[i][j] = 0d;
140                }
141            }
142        }
143
144        // Compute and return Hat matrix
145        // No DME advertised - args valid if we get here
146        return Q.multiply(augI).multiply(Q.transpose());
147    }
148
149    /**
150     * <p>Returns the sum of squared deviations of Y from its mean.</p>
151     *
152     * <p>If the model has no intercept term, <code>0</code> is used for the
153     * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
154     *
155     * <p>The value returned by this method is the SSTO value used in
156     * the {@link #calculateRSquared() R-squared} computation.</p>
157     *
158     * @return SSTO - the total sum of squares
159     * @throws NullPointerException if the sample has not been set
160     * @see #isNoIntercept()
161     * @since 2.2
162     */
163    public double calculateTotalSumOfSquares() {
164        if (isNoIntercept()) {
165            return StatUtils.sumSq(getY().toArray());
166        } else {
167            return new SecondMoment().evaluate(getY().toArray());
168        }
169    }
170
171    /**
172     * Returns the sum of squared residuals.
173     *
174     * @return residual sum of squares
175     * @since 2.2
176     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
177     * @throws NullPointerException if the data for the model have not been loaded
178     */
179    public double calculateResidualSumOfSquares() {
180        final RealVector residuals = calculateResiduals();
181        // No advertised DME, args are valid
182        return residuals.dotProduct(residuals);
183    }
184
185    /**
186     * Returns the R-Squared statistic, defined by the formula <pre>
187     * R<sup>2</sup> = 1 - SSR / SSTO
188     * </pre>
189     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
190     * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
191     *
192     * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
193     *
194     * @return R-square statistic
195     * @throws NullPointerException if the sample has not been set
196     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
197     * @since 2.2
198     */
199    public double calculateRSquared() {
200        return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
201    }
202
203    /**
204     * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
205     * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
206     * </pre>
207     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
208     * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
209     * of observations and p is the number of parameters estimated (including the intercept).</p>
210     *
211     * <p>If the regression is estimated without an intercept term, what is returned is <pre>
212     * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
213     * </pre></p>
214     *
215     * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
216     *
217     * @return adjusted R-Squared statistic
218     * @throws NullPointerException if the sample has not been set
219     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
220     * @see #isNoIntercept()
221     * @since 2.2
222     */
223    public double calculateAdjustedRSquared() {
224        final double n = getX().getRowDimension();
225        if (isNoIntercept()) {
226            return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
227        } else {
228            return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
229                (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
230        }
231    }
232
233    /**
234     * {@inheritDoc}
235     * <p>This implementation computes and caches the QR decomposition of the X matrix
236     * once it is successfully loaded.</p>
237     */
238    @Override
239    protected void newXSampleData(double[][] x) {
240        super.newXSampleData(x);
241        qr = new QRDecomposition(getX(), threshold);
242    }
243
244    /**
245     * Calculates the regression coefficients using OLS.
246     *
247     * <p>Data for the model must have been successfully loaded using one of
248     * the {@code newSampleData} methods before invoking this method; otherwise
249     * a {@code NullPointerException} will be thrown.</p>
250     *
251     * @return beta
252     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
253     * @throws NullPointerException if the data for the model have not been loaded
254     */
255    @Override
256    protected RealVector calculateBeta() {
257        return qr.getSolver().solve(getY());
258    }
259
260    /**
261     * <p>Calculates the variance-covariance matrix of the regression parameters.
262     * </p>
263     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
264     * </p>
265     * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
266     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
267     * R included, where p = the length of the beta vector.</p>
268     *
269     * <p>Data for the model must have been successfully loaded using one of
270     * the {@code newSampleData} methods before invoking this method; otherwise
271     * a {@code NullPointerException} will be thrown.</p>
272     *
273     * @return The beta variance-covariance matrix
274     * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular
275     * @throws NullPointerException if the data for the model have not been loaded
276     */
277    @Override
278    protected RealMatrix calculateBetaVariance() {
279        int p = getX().getColumnDimension();
280        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
281        RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
282        return Rinv.multiply(Rinv.transpose());
283    }
284
285}