001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.stat.regression; 018 019import org.apache.commons.math3.exception.DimensionMismatchException; 020import org.apache.commons.math3.exception.InsufficientDataException; 021import org.apache.commons.math3.exception.MathIllegalArgumentException; 022import org.apache.commons.math3.exception.NoDataException; 023import org.apache.commons.math3.exception.NullArgumentException; 024import org.apache.commons.math3.exception.util.LocalizedFormats; 025import org.apache.commons.math3.linear.NonSquareMatrixException; 026import org.apache.commons.math3.linear.RealMatrix; 027import org.apache.commons.math3.linear.Array2DRowRealMatrix; 028import org.apache.commons.math3.linear.RealVector; 029import org.apache.commons.math3.linear.ArrayRealVector; 030import org.apache.commons.math3.stat.descriptive.moment.Variance; 031import org.apache.commons.math3.util.FastMath; 032 033/** 034 * Abstract base class for implementations of MultipleLinearRegression. 035 * @since 2.0 036 */ 037public abstract class AbstractMultipleLinearRegression implements 038 MultipleLinearRegression { 039 040 /** X sample data. */ 041 private RealMatrix xMatrix; 042 043 /** Y sample data. */ 044 private RealVector yVector; 045 046 /** Whether or not the regression model includes an intercept. True means no intercept. */ 047 private boolean noIntercept = false; 048 049 /** 050 * @return the X sample data. 051 */ 052 protected RealMatrix getX() { 053 return xMatrix; 054 } 055 056 /** 057 * @return the Y sample data. 058 */ 059 protected RealVector getY() { 060 return yVector; 061 } 062 063 /** 064 * @return true if the model has no intercept term; false otherwise 065 * @since 2.2 066 */ 067 public boolean isNoIntercept() { 068 return noIntercept; 069 } 070 071 /** 072 * @param noIntercept true means the model is to be estimated without an intercept term 073 * @since 2.2 074 */ 075 public void setNoIntercept(boolean noIntercept) { 076 this.noIntercept = noIntercept; 077 } 078 079 /** 080 * <p>Loads model x and y sample data from a flat input array, overriding any previous sample. 081 * </p> 082 * <p>Assumes that rows are concatenated with y values first in each row. For example, an input 083 * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with 084 * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two 085 * independent variables, as below: 086 * <pre> 087 * y x[0] x[1] 088 * -------------- 089 * 1 2 3 090 * 4 5 6 091 * 7 8 9 092 * </pre> 093 * </p> 094 * <p>Note that there is no need to add an initial unitary column (column of 1's) when 095 * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>, 096 * the X matrix will be created without an initial column of "1"s; otherwise this column will 097 * be added. 098 * </p> 099 * <p>Throws IllegalArgumentException if any of the following preconditions fail: 100 * <ul><li><code>data</code> cannot be null</li> 101 * <li><code>data.length = nobs * (nvars + 1)</li> 102 * <li><code>nobs > nvars</code></li></ul> 103 * </p> 104 * 105 * @param data input data array 106 * @param nobs number of observations (rows) 107 * @param nvars number of independent variables (columns, not counting y) 108 * @throws NullArgumentException if the data array is null 109 * @throws DimensionMismatchException if the length of the data array is not equal 110 * to <code>nobs * (nvars + 1)</code> 111 * @throws InsufficientDataException if <code>nobs</code> is less than 112 * <code>nvars + 1</code> 113 */ 114 public void newSampleData(double[] data, int nobs, int nvars) { 115 if (data == null) { 116 throw new NullArgumentException(); 117 } 118 if (data.length != nobs * (nvars + 1)) { 119 throw new DimensionMismatchException(data.length, nobs * (nvars + 1)); 120 } 121 if (nobs <= nvars) { 122 throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1); 123 } 124 double[] y = new double[nobs]; 125 final int cols = noIntercept ? nvars: nvars + 1; 126 double[][] x = new double[nobs][cols]; 127 int pointer = 0; 128 for (int i = 0; i < nobs; i++) { 129 y[i] = data[pointer++]; 130 if (!noIntercept) { 131 x[i][0] = 1.0d; 132 } 133 for (int j = noIntercept ? 0 : 1; j < cols; j++) { 134 x[i][j] = data[pointer++]; 135 } 136 } 137 this.xMatrix = new Array2DRowRealMatrix(x); 138 this.yVector = new ArrayRealVector(y); 139 } 140 141 /** 142 * Loads new y sample data, overriding any previous data. 143 * 144 * @param y the array representing the y sample 145 * @throws NullArgumentException if y is null 146 * @throws NoDataException if y is empty 147 */ 148 protected void newYSampleData(double[] y) { 149 if (y == null) { 150 throw new NullArgumentException(); 151 } 152 if (y.length == 0) { 153 throw new NoDataException(); 154 } 155 this.yVector = new ArrayRealVector(y); 156 } 157 158 /** 159 * <p>Loads new x sample data, overriding any previous data. 160 * </p> 161 * The input <code>x</code> array should have one row for each sample 162 * observation, with columns corresponding to independent variables. 163 * For example, if <pre> 164 * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre> 165 * then <code>setXSampleData(x) </code> results in a model with two independent 166 * variables and 3 observations: 167 * <pre> 168 * x[0] x[1] 169 * ---------- 170 * 1 2 171 * 3 4 172 * 5 6 173 * </pre> 174 * </p> 175 * <p>Note that there is no need to add an initial unitary column (column of 1's) when 176 * specifying a model including an intercept term. 177 * </p> 178 * @param x the rectangular array representing the x sample 179 * @throws NullArgumentException if x is null 180 * @throws NoDataException if x is empty 181 * @throws DimensionMismatchException if x is not rectangular 182 */ 183 protected void newXSampleData(double[][] x) { 184 if (x == null) { 185 throw new NullArgumentException(); 186 } 187 if (x.length == 0) { 188 throw new NoDataException(); 189 } 190 if (noIntercept) { 191 this.xMatrix = new Array2DRowRealMatrix(x, true); 192 } else { // Augment design matrix with initial unitary column 193 final int nVars = x[0].length; 194 final double[][] xAug = new double[x.length][nVars + 1]; 195 for (int i = 0; i < x.length; i++) { 196 if (x[i].length != nVars) { 197 throw new DimensionMismatchException(x[i].length, nVars); 198 } 199 xAug[i][0] = 1.0d; 200 System.arraycopy(x[i], 0, xAug[i], 1, nVars); 201 } 202 this.xMatrix = new Array2DRowRealMatrix(xAug, false); 203 } 204 } 205 206 /** 207 * Validates sample data. Checks that 208 * <ul><li>Neither x nor y is null or empty;</li> 209 * <li>The length (i.e. number of rows) of x equals the length of y</li> 210 * <li>x has at least one more row than it has columns (i.e. there is 211 * sufficient data to estimate regression coefficients for each of the 212 * columns in x plus an intercept.</li> 213 * </ul> 214 * 215 * @param x the [n,k] array representing the x data 216 * @param y the [n,1] array representing the y data 217 * @throws NullArgumentException if {@code x} or {@code y} is null 218 * @throws DimensionMismatchException if {@code x} and {@code y} do not 219 * have the same length 220 * @throws NoDataException if {@code x} or {@code y} are zero-length 221 * @throws MathIllegalArgumentException if the number of rows of {@code x} 222 * is not larger than the number of columns + 1 223 */ 224 protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException { 225 if ((x == null) || (y == null)) { 226 throw new NullArgumentException(); 227 } 228 if (x.length != y.length) { 229 throw new DimensionMismatchException(y.length, x.length); 230 } 231 if (x.length == 0) { // Must be no y data either 232 throw new NoDataException(); 233 } 234 if (x[0].length + 1 > x.length) { 235 throw new MathIllegalArgumentException( 236 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS, 237 x.length, x[0].length); 238 } 239 } 240 241 /** 242 * Validates that the x data and covariance matrix have the same 243 * number of rows and that the covariance matrix is square. 244 * 245 * @param x the [n,k] array representing the x sample 246 * @param covariance the [n,n] array representing the covariance matrix 247 * @throws DimensionMismatchException if the number of rows in x is not equal 248 * to the number of rows in covariance 249 * @throws NonSquareMatrixException if the covariance matrix is not square 250 */ 251 protected void validateCovarianceData(double[][] x, double[][] covariance) { 252 if (x.length != covariance.length) { 253 throw new DimensionMismatchException(x.length, covariance.length); 254 } 255 if (covariance.length > 0 && covariance.length != covariance[0].length) { 256 throw new NonSquareMatrixException(covariance.length, covariance[0].length); 257 } 258 } 259 260 /** 261 * {@inheritDoc} 262 */ 263 public double[] estimateRegressionParameters() { 264 RealVector b = calculateBeta(); 265 return b.toArray(); 266 } 267 268 /** 269 * {@inheritDoc} 270 */ 271 public double[] estimateResiduals() { 272 RealVector b = calculateBeta(); 273 RealVector e = yVector.subtract(xMatrix.operate(b)); 274 return e.toArray(); 275 } 276 277 /** 278 * {@inheritDoc} 279 */ 280 public double[][] estimateRegressionParametersVariance() { 281 return calculateBetaVariance().getData(); 282 } 283 284 /** 285 * {@inheritDoc} 286 */ 287 public double[] estimateRegressionParametersStandardErrors() { 288 double[][] betaVariance = estimateRegressionParametersVariance(); 289 double sigma = calculateErrorVariance(); 290 int length = betaVariance[0].length; 291 double[] result = new double[length]; 292 for (int i = 0; i < length; i++) { 293 result[i] = FastMath.sqrt(sigma * betaVariance[i][i]); 294 } 295 return result; 296 } 297 298 /** 299 * {@inheritDoc} 300 */ 301 public double estimateRegressandVariance() { 302 return calculateYVariance(); 303 } 304 305 /** 306 * Estimates the variance of the error. 307 * 308 * @return estimate of the error variance 309 * @since 2.2 310 */ 311 public double estimateErrorVariance() { 312 return calculateErrorVariance(); 313 314 } 315 316 /** 317 * Estimates the standard error of the regression. 318 * 319 * @return regression standard error 320 * @since 2.2 321 */ 322 public double estimateRegressionStandardError() { 323 return FastMath.sqrt(estimateErrorVariance()); 324 } 325 326 /** 327 * Calculates the beta of multiple linear regression in matrix notation. 328 * 329 * @return beta 330 */ 331 protected abstract RealVector calculateBeta(); 332 333 /** 334 * Calculates the beta variance of multiple linear regression in matrix 335 * notation. 336 * 337 * @return beta variance 338 */ 339 protected abstract RealMatrix calculateBetaVariance(); 340 341 342 /** 343 * Calculates the variance of the y values. 344 * 345 * @return Y variance 346 */ 347 protected double calculateYVariance() { 348 return new Variance().evaluate(yVector.toArray()); 349 } 350 351 /** 352 * <p>Calculates the variance of the error term.</p> 353 * Uses the formula <pre> 354 * var(u) = u · u / (n - k) 355 * </pre> 356 * where n and k are the row and column dimensions of the design 357 * matrix X. 358 * 359 * @return error variance estimate 360 * @since 2.2 361 */ 362 protected double calculateErrorVariance() { 363 RealVector residuals = calculateResiduals(); 364 return residuals.dotProduct(residuals) / 365 (xMatrix.getRowDimension() - xMatrix.getColumnDimension()); 366 } 367 368 /** 369 * Calculates the residuals of multiple linear regression in matrix 370 * notation. 371 * 372 * <pre> 373 * u = y - X * b 374 * </pre> 375 * 376 * @return The residuals [n,1] matrix 377 */ 378 protected RealVector calculateResiduals() { 379 RealVector b = calculateBeta(); 380 return yVector.subtract(xMatrix.operate(b)); 381 } 382 383}