001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.stat.regression; 018 019import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 020import org.apache.commons.math4.legacy.exception.InsufficientDataException; 021import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException; 022import org.apache.commons.math4.legacy.exception.NoDataException; 023import org.apache.commons.math4.legacy.exception.NullArgumentException; 024import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; 025import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix; 026import org.apache.commons.math4.legacy.linear.ArrayRealVector; 027import org.apache.commons.math4.legacy.linear.NonSquareMatrixException; 028import org.apache.commons.math4.legacy.linear.RealMatrix; 029import org.apache.commons.math4.legacy.linear.RealVector; 030import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance; 031import org.apache.commons.math4.core.jdkmath.JdkMath; 032 033/** 034 * Abstract base class for implementations of MultipleLinearRegression. 035 * @since 2.0 036 */ 037public abstract class AbstractMultipleLinearRegression implements 038 MultipleLinearRegression { 039 040 /** X sample data. */ 041 private RealMatrix xMatrix; 042 043 /** Y sample data. */ 044 private RealVector yVector; 045 046 /** Whether or not the regression model includes an intercept. True means no intercept. */ 047 private boolean noIntercept; 048 049 /** 050 * @return the X sample data. 051 */ 052 protected RealMatrix getX() { 053 return xMatrix; 054 } 055 056 /** 057 * @return the Y sample data. 058 */ 059 protected RealVector getY() { 060 return yVector; 061 } 062 063 /** 064 * @return true if the model has no intercept term; false otherwise 065 * @since 2.2 066 */ 067 public boolean isNoIntercept() { 068 return noIntercept; 069 } 070 071 /** 072 * @param noIntercept true means the model is to be estimated without an intercept term 073 * @since 2.2 074 */ 075 public void setNoIntercept(boolean noIntercept) { 076 this.noIntercept = noIntercept; 077 } 078 079 /** 080 * <p>Loads model x and y sample data from a flat input array, overriding any previous sample. 081 * </p> 082 * <p>Assumes that rows are concatenated with y values first in each row. For example, an input 083 * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with 084 * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two 085 * independent variables, as below: 086 * <pre> 087 * y x[0] x[1] 088 * -------------- 089 * 1 2 3 090 * 4 5 6 091 * 7 8 9 092 * </pre> 093 * 094 * <p>Note that there is no need to add an initial unitary column (column of 1's) when 095 * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>, 096 * the X matrix will be created without an initial column of "1"s; otherwise this column will 097 * be added. 098 * </p> 099 * <p>Throws IllegalArgumentException if any of the following preconditions fail: 100 * <ul><li><code>data</code> cannot be null</li> 101 * <li><code>data.length = nobs * (nvars + 1)</code></li> 102 * <li>{@code nobs > nvars}</li></ul> 103 * 104 * @param data input data array 105 * @param nobs number of observations (rows) 106 * @param nvars number of independent variables (columns, not counting y) 107 * @throws NullArgumentException if the data array is null 108 * @throws DimensionMismatchException if the length of the data array is not equal 109 * to <code>nobs * (nvars + 1)</code> 110 * @throws InsufficientDataException if <code>nobs</code> is less than 111 * <code>nvars + 1</code> 112 */ 113 public void newSampleData(double[] data, int nobs, int nvars) { 114 if (data == null) { 115 throw new NullArgumentException(); 116 } 117 if (data.length != nobs * (nvars + 1)) { 118 throw new DimensionMismatchException(data.length, nobs * (nvars + 1)); 119 } 120 if (nobs <= nvars) { 121 throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1); 122 } 123 double[] y = new double[nobs]; 124 final int cols = noIntercept ? nvars: nvars + 1; 125 double[][] x = new double[nobs][cols]; 126 int pointer = 0; 127 for (int i = 0; i < nobs; i++) { 128 y[i] = data[pointer++]; 129 if (!noIntercept) { 130 x[i][0] = 1.0d; 131 } 132 for (int j = noIntercept ? 0 : 1; j < cols; j++) { 133 x[i][j] = data[pointer++]; 134 } 135 } 136 this.xMatrix = new Array2DRowRealMatrix(x); 137 this.yVector = new ArrayRealVector(y); 138 } 139 140 /** 141 * Loads new y sample data, overriding any previous data. 142 * 143 * @param y the array representing the y sample 144 * @throws NullArgumentException if y is null 145 * @throws NoDataException if y is empty 146 */ 147 protected void newYSampleData(double[] y) { 148 if (y == null) { 149 throw new NullArgumentException(); 150 } 151 if (y.length == 0) { 152 throw new NoDataException(); 153 } 154 this.yVector = new ArrayRealVector(y); 155 } 156 157 /** 158 * <p>Loads new x sample data, overriding any previous data. 159 * </p> 160 * The input <code>x</code> array should have one row for each sample 161 * observation, with columns corresponding to independent variables. 162 * For example, if <pre> 163 * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre> 164 * then <code>setXSampleData(x) </code> results in a model with two independent 165 * variables and 3 observations: 166 * <pre> 167 * x[0] x[1] 168 * ---------- 169 * 1 2 170 * 3 4 171 * 5 6 172 * </pre> 173 * <p>Note that there is no need to add an initial unitary column (column of 1's) when 174 * specifying a model including an intercept term. 175 * </p> 176 * @param x the rectangular array representing the x sample 177 * @throws NullArgumentException if x is null 178 * @throws NoDataException if x is empty 179 * @throws DimensionMismatchException if x is not rectangular 180 */ 181 protected void newXSampleData(double[][] x) { 182 if (x == null) { 183 throw new NullArgumentException(); 184 } 185 if (x.length == 0) { 186 throw new NoDataException(); 187 } 188 if (noIntercept) { 189 this.xMatrix = new Array2DRowRealMatrix(x, true); 190 } else { // Augment design matrix with initial unitary column 191 final int nVars = x[0].length; 192 final double[][] xAug = new double[x.length][nVars + 1]; 193 for (int i = 0; i < x.length; i++) { 194 if (x[i].length != nVars) { 195 throw new DimensionMismatchException(x[i].length, nVars); 196 } 197 xAug[i][0] = 1.0d; 198 System.arraycopy(x[i], 0, xAug[i], 1, nVars); 199 } 200 this.xMatrix = new Array2DRowRealMatrix(xAug, false); 201 } 202 } 203 204 /** 205 * Validates sample data. Checks that 206 * <ul><li>Neither x nor y is null or empty;</li> 207 * <li>The length (i.e. number of rows) of x equals the length of y</li> 208 * <li>x has at least one more row than it has columns (i.e. there is 209 * sufficient data to estimate regression coefficients for each of the 210 * columns in x plus an intercept.</li> 211 * </ul> 212 * 213 * @param x the [n,k] array representing the x data 214 * @param y the [n,1] array representing the y data 215 * @throws NullArgumentException if {@code x} or {@code y} is null 216 * @throws DimensionMismatchException if {@code x} and {@code y} do not 217 * have the same length 218 * @throws NoDataException if {@code x} or {@code y} are zero-length 219 * @throws MathIllegalArgumentException if the number of rows of {@code x} 220 * is not larger than the number of columns + 1 221 */ 222 protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException { 223 if (x == null || y == null) { 224 throw new NullArgumentException(); 225 } 226 if (x.length != y.length) { 227 throw new DimensionMismatchException(y.length, x.length); 228 } 229 if (x.length == 0) { // Must be no y data either 230 throw new NoDataException(); 231 } 232 if (x[0].length + 1 > x.length) { 233 throw new MathIllegalArgumentException( 234 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS, 235 x.length, x[0].length); 236 } 237 } 238 239 /** 240 * Validates that the x data and covariance matrix have the same 241 * number of rows and that the covariance matrix is square. 242 * 243 * @param x the [n,k] array representing the x sample 244 * @param covariance the [n,n] array representing the covariance matrix 245 * @throws DimensionMismatchException if the number of rows in x is not equal 246 * to the number of rows in covariance 247 * @throws NonSquareMatrixException if the covariance matrix is not square 248 */ 249 protected void validateCovarianceData(double[][] x, double[][] covariance) { 250 if (x.length != covariance.length) { 251 throw new DimensionMismatchException(x.length, covariance.length); 252 } 253 if (covariance.length > 0 && covariance.length != covariance[0].length) { 254 throw new NonSquareMatrixException(covariance.length, covariance[0].length); 255 } 256 } 257 258 /** 259 * {@inheritDoc} 260 */ 261 @Override 262 public double[] estimateRegressionParameters() { 263 RealVector b = calculateBeta(); 264 return b.toArray(); 265 } 266 267 /** 268 * {@inheritDoc} 269 */ 270 @Override 271 public double[] estimateResiduals() { 272 RealVector b = calculateBeta(); 273 RealVector e = yVector.subtract(xMatrix.operate(b)); 274 return e.toArray(); 275 } 276 277 /** 278 * {@inheritDoc} 279 */ 280 @Override 281 public double[][] estimateRegressionParametersVariance() { 282 return calculateBetaVariance().getData(); 283 } 284 285 /** 286 * {@inheritDoc} 287 */ 288 @Override 289 public double[] estimateRegressionParametersStandardErrors() { 290 double[][] betaVariance = estimateRegressionParametersVariance(); 291 double sigma = calculateErrorVariance(); 292 int length = betaVariance[0].length; 293 double[] result = new double[length]; 294 for (int i = 0; i < length; i++) { 295 result[i] = JdkMath.sqrt(sigma * betaVariance[i][i]); 296 } 297 return result; 298 } 299 300 /** 301 * {@inheritDoc} 302 */ 303 @Override 304 public double estimateRegressandVariance() { 305 return calculateYVariance(); 306 } 307 308 /** 309 * Estimates the variance of the error. 310 * 311 * @return estimate of the error variance 312 * @since 2.2 313 */ 314 public double estimateErrorVariance() { 315 return calculateErrorVariance(); 316 } 317 318 /** 319 * Estimates the standard error of the regression. 320 * 321 * @return regression standard error 322 * @since 2.2 323 */ 324 public double estimateRegressionStandardError() { 325 return JdkMath.sqrt(estimateErrorVariance()); 326 } 327 328 /** 329 * Calculates the beta of multiple linear regression in matrix notation. 330 * 331 * @return beta 332 */ 333 protected abstract RealVector calculateBeta(); 334 335 /** 336 * Calculates the beta variance of multiple linear regression in matrix 337 * notation. 338 * 339 * @return beta variance 340 */ 341 protected abstract RealMatrix calculateBetaVariance(); 342 343 344 /** 345 * Calculates the variance of the y values. 346 * 347 * @return Y variance 348 */ 349 protected double calculateYVariance() { 350 return new Variance().evaluate(yVector.toArray()); 351 } 352 353 /** 354 * <p>Calculates the variance of the error term.</p> 355 * Uses the formula <pre> 356 * var(u) = u · u / (n - k) 357 * </pre> 358 * where n and k are the row and column dimensions of the design 359 * matrix X. 360 * 361 * @return error variance estimate 362 * @since 2.2 363 */ 364 protected double calculateErrorVariance() { 365 RealVector residuals = calculateResiduals(); 366 return residuals.dotProduct(residuals) / 367 (xMatrix.getRowDimension() - xMatrix.getColumnDimension()); 368 } 369 370 /** 371 * Calculates the residuals of multiple linear regression in matrix 372 * notation. 373 * 374 * <pre> 375 * u = y - X * b 376 * </pre> 377 * 378 * @return The residuals [n,1] matrix 379 */ 380 protected RealVector calculateResiduals() { 381 RealVector b = calculateBeta(); 382 return yVector.subtract(xMatrix.operate(b)); 383 } 384}