001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.stat.regression; 018 019import org.apache.commons.math3.exception.MathIllegalArgumentException; 020import org.apache.commons.math3.linear.Array2DRowRealMatrix; 021import org.apache.commons.math3.linear.LUDecomposition; 022import org.apache.commons.math3.linear.QRDecomposition; 023import org.apache.commons.math3.linear.RealMatrix; 024import org.apache.commons.math3.linear.RealVector; 025import org.apache.commons.math3.stat.StatUtils; 026import org.apache.commons.math3.stat.descriptive.moment.SecondMoment; 027 028/** 029 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a 030 * multiple linear regression model.</p> 031 * 032 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations: 033 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p> 034 * 035 * <p>To solve the normal equations, this implementation uses QR decomposition 036 * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the 037 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i> 038 * has rows corresponding to sample observations and columns corresponding to independent 039 * variables. When the model is estimated using an intercept term (i.e. when 040 * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code> 041 * matrix includes an initial column identically equal to 1. We solve the normal equations 042 * as follows: 043 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y 044 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y 045 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y 046 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y 047 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y 048 * R b = Q<sup>T</sup> y </code></pre></p> 049 * 050 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p> 051 * 052 * @since 2.0 053 */ 054public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression { 055 056 /** Cached QR decomposition of X matrix */ 057 private QRDecomposition qr = null; 058 059 /** Singularity threshold for QR decomposition */ 060 private final double threshold; 061 062 /** 063 * Create an empty OLSMultipleLinearRegression instance. 064 */ 065 public OLSMultipleLinearRegression() { 066 this(0d); 067 } 068 069 /** 070 * Create an empty OLSMultipleLinearRegression instance, using the given 071 * singularity threshold for the QR decomposition. 072 * 073 * @param threshold the singularity threshold 074 * @since 3.3 075 */ 076 public OLSMultipleLinearRegression(final double threshold) { 077 this.threshold = threshold; 078 } 079 080 /** 081 * Loads model x and y sample data, overriding any previous sample. 082 * 083 * Computes and caches QR decomposition of the X matrix. 084 * @param y the [n,1] array representing the y sample 085 * @param x the [n,k] array representing the x sample 086 * @throws MathIllegalArgumentException if the x and y array data are not 087 * compatible for the regression 088 */ 089 public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException { 090 validateSampleData(x, y); 091 newYSampleData(y); 092 newXSampleData(x); 093 } 094 095 /** 096 * {@inheritDoc} 097 * <p>This implementation computes and caches the QR decomposition of the X matrix.</p> 098 */ 099 @Override 100 public void newSampleData(double[] data, int nobs, int nvars) { 101 super.newSampleData(data, nobs, nvars); 102 qr = new QRDecomposition(getX(), threshold); 103 } 104 105 /** 106 * <p>Compute the "hat" matrix. 107 * </p> 108 * <p>The hat matrix is defined in terms of the design matrix X 109 * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup> 110 * </p> 111 * <p>The implementation here uses the QR decomposition to compute the 112 * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the 113 * p-dimensional identity matrix augmented by 0's. This computational 114 * formula is from "The Hat Matrix in Regression and ANOVA", 115 * David C. Hoaglin and Roy E. Welsch, 116 * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. 117 * </p> 118 * <p>Data for the model must have been successfully loaded using one of 119 * the {@code newSampleData} methods before invoking this method; otherwise 120 * a {@code NullPointerException} will be thrown.</p> 121 * 122 * @return the hat matrix 123 * @throws NullPointerException unless method {@code newSampleData} has been 124 * called beforehand. 125 */ 126 public RealMatrix calculateHat() { 127 // Create augmented identity matrix 128 RealMatrix Q = qr.getQ(); 129 final int p = qr.getR().getColumnDimension(); 130 final int n = Q.getColumnDimension(); 131 // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3 132 Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n); 133 double[][] augIData = augI.getDataRef(); 134 for (int i = 0; i < n; i++) { 135 for (int j =0; j < n; j++) { 136 if (i == j && i < p) { 137 augIData[i][j] = 1d; 138 } else { 139 augIData[i][j] = 0d; 140 } 141 } 142 } 143 144 // Compute and return Hat matrix 145 // No DME advertised - args valid if we get here 146 return Q.multiply(augI).multiply(Q.transpose()); 147 } 148 149 /** 150 * <p>Returns the sum of squared deviations of Y from its mean.</p> 151 * 152 * <p>If the model has no intercept term, <code>0</code> is used for the 153 * mean of Y - i.e., what is returned is the sum of the squared Y values.</p> 154 * 155 * <p>The value returned by this method is the SSTO value used in 156 * the {@link #calculateRSquared() R-squared} computation.</p> 157 * 158 * @return SSTO - the total sum of squares 159 * @throws NullPointerException if the sample has not been set 160 * @see #isNoIntercept() 161 * @since 2.2 162 */ 163 public double calculateTotalSumOfSquares() { 164 if (isNoIntercept()) { 165 return StatUtils.sumSq(getY().toArray()); 166 } else { 167 return new SecondMoment().evaluate(getY().toArray()); 168 } 169 } 170 171 /** 172 * Returns the sum of squared residuals. 173 * 174 * @return residual sum of squares 175 * @since 2.2 176 * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular 177 * @throws NullPointerException if the data for the model have not been loaded 178 */ 179 public double calculateResidualSumOfSquares() { 180 final RealVector residuals = calculateResiduals(); 181 // No advertised DME, args are valid 182 return residuals.dotProduct(residuals); 183 } 184 185 /** 186 * Returns the R-Squared statistic, defined by the formula <pre> 187 * R<sup>2</sup> = 1 - SSR / SSTO 188 * </pre> 189 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals} 190 * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares} 191 * 192 * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p> 193 * 194 * @return R-square statistic 195 * @throws NullPointerException if the sample has not been set 196 * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular 197 * @since 2.2 198 */ 199 public double calculateRSquared() { 200 return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares(); 201 } 202 203 /** 204 * <p>Returns the adjusted R-squared statistic, defined by the formula <pre> 205 * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)] 206 * </pre> 207 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}, 208 * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number 209 * of observations and p is the number of parameters estimated (including the intercept).</p> 210 * 211 * <p>If the regression is estimated without an intercept term, what is returned is <pre> 212 * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code> 213 * </pre></p> 214 * 215 * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p> 216 * 217 * @return adjusted R-Squared statistic 218 * @throws NullPointerException if the sample has not been set 219 * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular 220 * @see #isNoIntercept() 221 * @since 2.2 222 */ 223 public double calculateAdjustedRSquared() { 224 final double n = getX().getRowDimension(); 225 if (isNoIntercept()) { 226 return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension())); 227 } else { 228 return 1 - (calculateResidualSumOfSquares() * (n - 1)) / 229 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension())); 230 } 231 } 232 233 /** 234 * {@inheritDoc} 235 * <p>This implementation computes and caches the QR decomposition of the X matrix 236 * once it is successfully loaded.</p> 237 */ 238 @Override 239 protected void newXSampleData(double[][] x) { 240 super.newXSampleData(x); 241 qr = new QRDecomposition(getX(), threshold); 242 } 243 244 /** 245 * Calculates the regression coefficients using OLS. 246 * 247 * <p>Data for the model must have been successfully loaded using one of 248 * the {@code newSampleData} methods before invoking this method; otherwise 249 * a {@code NullPointerException} will be thrown.</p> 250 * 251 * @return beta 252 * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular 253 * @throws NullPointerException if the data for the model have not been loaded 254 */ 255 @Override 256 protected RealVector calculateBeta() { 257 return qr.getSolver().solve(getY()); 258 } 259 260 /** 261 * <p>Calculates the variance-covariance matrix of the regression parameters. 262 * </p> 263 * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup> 264 * </p> 265 * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup> 266 * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of 267 * R included, where p = the length of the beta vector.</p> 268 * 269 * <p>Data for the model must have been successfully loaded using one of 270 * the {@code newSampleData} methods before invoking this method; otherwise 271 * a {@code NullPointerException} will be thrown.</p> 272 * 273 * @return The beta variance-covariance matrix 274 * @throws org.apache.commons.math3.linear.SingularMatrixException if the design matrix is singular 275 * @throws NullPointerException if the data for the model have not been loaded 276 */ 277 @Override 278 protected RealMatrix calculateBetaVariance() { 279 int p = getX().getColumnDimension(); 280 RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1); 281 RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse(); 282 return Rinv.multiply(Rinv.transpose()); 283 } 284 285}