/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
017package org.apache.commons.math3.stat.regression;
018
019import org.apache.commons.math3.exception.DimensionMismatchException;
020import org.apache.commons.math3.exception.InsufficientDataException;
021import org.apache.commons.math3.exception.MathIllegalArgumentException;
022import org.apache.commons.math3.exception.NoDataException;
023import org.apache.commons.math3.exception.NullArgumentException;
024import org.apache.commons.math3.exception.util.LocalizedFormats;
025import org.apache.commons.math3.linear.NonSquareMatrixException;
026import org.apache.commons.math3.linear.RealMatrix;
027import org.apache.commons.math3.linear.Array2DRowRealMatrix;
028import org.apache.commons.math3.linear.RealVector;
029import org.apache.commons.math3.linear.ArrayRealVector;
030import org.apache.commons.math3.stat.descriptive.moment.Variance;
031import org.apache.commons.math3.util.FastMath;
032
033/**
034 * Abstract base class for implementations of MultipleLinearRegression.
035 * @since 2.0
036 */
037public abstract class AbstractMultipleLinearRegression implements
038        MultipleLinearRegression {
039
040    /** X sample data. */
041    private RealMatrix xMatrix;
042
043    /** Y sample data. */
044    private RealVector yVector;
045
046    /** Whether or not the regression model includes an intercept.  True means no intercept. */
047    private boolean noIntercept = false;
048
049    /**
050     * @return the X sample data.
051     */
052    protected RealMatrix getX() {
053        return xMatrix;
054    }
055
056    /**
057     * @return the Y sample data.
058     */
059    protected RealVector getY() {
060        return yVector;
061    }
062
063    /**
064     * @return true if the model has no intercept term; false otherwise
065     * @since 2.2
066     */
067    public boolean isNoIntercept() {
068        return noIntercept;
069    }
070
071    /**
072     * @param noIntercept true means the model is to be estimated without an intercept term
073     * @since 2.2
074     */
075    public void setNoIntercept(boolean noIntercept) {
076        this.noIntercept = noIntercept;
077    }
078
079    /**
080     * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
081     * </p>
082     * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
083     * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
084     * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
085     * independent variables, as below:
086     * <pre>
087     *   y   x[0]  x[1]
088     *   --------------
089     *   1     2     3
090     *   4     5     6
091     *   7     8     9
092     * </pre>
093     * </p>
094     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
095     * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
096     * the X matrix will be created without an initial column of "1"s; otherwise this column will
097     * be added.
098     * </p>
099     * <p>Throws IllegalArgumentException if any of the following preconditions fail:
100     * <ul><li><code>data</code> cannot be null</li>
101     * <li><code>data.length = nobs * (nvars + 1)</li>
102     * <li><code>nobs > nvars</code></li></ul>
103     * </p>
104     *
105     * @param data input data array
106     * @param nobs number of observations (rows)
107     * @param nvars number of independent variables (columns, not counting y)
108     * @throws NullArgumentException if the data array is null
109     * @throws DimensionMismatchException if the length of the data array is not equal
110     * to <code>nobs * (nvars + 1)</code>
111     * @throws InsufficientDataException if <code>nobs</code> is less than
112     * <code>nvars + 1</code>
113     */
114    public void newSampleData(double[] data, int nobs, int nvars) {
115        if (data == null) {
116            throw new NullArgumentException();
117        }
118        if (data.length != nobs * (nvars + 1)) {
119            throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
120        }
121        if (nobs <= nvars) {
122            throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1);
123        }
124        double[] y = new double[nobs];
125        final int cols = noIntercept ? nvars: nvars + 1;
126        double[][] x = new double[nobs][cols];
127        int pointer = 0;
128        for (int i = 0; i < nobs; i++) {
129            y[i] = data[pointer++];
130            if (!noIntercept) {
131                x[i][0] = 1.0d;
132            }
133            for (int j = noIntercept ? 0 : 1; j < cols; j++) {
134                x[i][j] = data[pointer++];
135            }
136        }
137        this.xMatrix = new Array2DRowRealMatrix(x);
138        this.yVector = new ArrayRealVector(y);
139    }
140
141    /**
142     * Loads new y sample data, overriding any previous data.
143     *
144     * @param y the array representing the y sample
145     * @throws NullArgumentException if y is null
146     * @throws NoDataException if y is empty
147     */
148    protected void newYSampleData(double[] y) {
149        if (y == null) {
150            throw new NullArgumentException();
151        }
152        if (y.length == 0) {
153            throw new NoDataException();
154        }
155        this.yVector = new ArrayRealVector(y);
156    }
157
158    /**
159     * <p>Loads new x sample data, overriding any previous data.
160     * </p>
161     * The input <code>x</code> array should have one row for each sample
162     * observation, with columns corresponding to independent variables.
163     * For example, if <pre>
164     * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
165     * then <code>setXSampleData(x) </code> results in a model with two independent
166     * variables and 3 observations:
167     * <pre>
168     *   x[0]  x[1]
169     *   ----------
170     *     1    2
171     *     3    4
172     *     5    6
173     * </pre>
174     * </p>
175     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
176     * specifying a model including an intercept term.
177     * </p>
178     * @param x the rectangular array representing the x sample
179     * @throws NullArgumentException if x is null
180     * @throws NoDataException if x is empty
181     * @throws DimensionMismatchException if x is not rectangular
182     */
183    protected void newXSampleData(double[][] x) {
184        if (x == null) {
185            throw new NullArgumentException();
186        }
187        if (x.length == 0) {
188            throw new NoDataException();
189        }
190        if (noIntercept) {
191            this.xMatrix = new Array2DRowRealMatrix(x, true);
192        } else { // Augment design matrix with initial unitary column
193            final int nVars = x[0].length;
194            final double[][] xAug = new double[x.length][nVars + 1];
195            for (int i = 0; i < x.length; i++) {
196                if (x[i].length != nVars) {
197                    throw new DimensionMismatchException(x[i].length, nVars);
198                }
199                xAug[i][0] = 1.0d;
200                System.arraycopy(x[i], 0, xAug[i], 1, nVars);
201            }
202            this.xMatrix = new Array2DRowRealMatrix(xAug, false);
203        }
204    }
205
206    /**
207     * Validates sample data.  Checks that
208     * <ul><li>Neither x nor y is null or empty;</li>
209     * <li>The length (i.e. number of rows) of x equals the length of y</li>
210     * <li>x has at least one more row than it has columns (i.e. there is
211     * sufficient data to estimate regression coefficients for each of the
212     * columns in x plus an intercept.</li>
213     * </ul>
214     *
215     * @param x the [n,k] array representing the x data
216     * @param y the [n,1] array representing the y data
217     * @throws NullArgumentException if {@code x} or {@code y} is null
218     * @throws DimensionMismatchException if {@code x} and {@code y} do not
219     * have the same length
220     * @throws NoDataException if {@code x} or {@code y} are zero-length
221     * @throws MathIllegalArgumentException if the number of rows of {@code x}
222     * is not larger than the number of columns + 1
223     */
224    protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
225        if ((x == null) || (y == null)) {
226            throw new NullArgumentException();
227        }
228        if (x.length != y.length) {
229            throw new DimensionMismatchException(y.length, x.length);
230        }
231        if (x.length == 0) {  // Must be no y data either
232            throw new NoDataException();
233        }
234        if (x[0].length + 1 > x.length) {
235            throw new MathIllegalArgumentException(
236                    LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
237                    x.length, x[0].length);
238        }
239    }
240
241    /**
242     * Validates that the x data and covariance matrix have the same
243     * number of rows and that the covariance matrix is square.
244     *
245     * @param x the [n,k] array representing the x sample
246     * @param covariance the [n,n] array representing the covariance matrix
247     * @throws DimensionMismatchException if the number of rows in x is not equal
248     * to the number of rows in covariance
249     * @throws NonSquareMatrixException if the covariance matrix is not square
250     */
251    protected void validateCovarianceData(double[][] x, double[][] covariance) {
252        if (x.length != covariance.length) {
253            throw new DimensionMismatchException(x.length, covariance.length);
254        }
255        if (covariance.length > 0 && covariance.length != covariance[0].length) {
256            throw new NonSquareMatrixException(covariance.length, covariance[0].length);
257        }
258    }
259
260    /**
261     * {@inheritDoc}
262     */
263    public double[] estimateRegressionParameters() {
264        RealVector b = calculateBeta();
265        return b.toArray();
266    }
267
268    /**
269     * {@inheritDoc}
270     */
271    public double[] estimateResiduals() {
272        RealVector b = calculateBeta();
273        RealVector e = yVector.subtract(xMatrix.operate(b));
274        return e.toArray();
275    }
276
277    /**
278     * {@inheritDoc}
279     */
280    public double[][] estimateRegressionParametersVariance() {
281        return calculateBetaVariance().getData();
282    }
283
284    /**
285     * {@inheritDoc}
286     */
287    public double[] estimateRegressionParametersStandardErrors() {
288        double[][] betaVariance = estimateRegressionParametersVariance();
289        double sigma = calculateErrorVariance();
290        int length = betaVariance[0].length;
291        double[] result = new double[length];
292        for (int i = 0; i < length; i++) {
293            result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
294        }
295        return result;
296    }
297
298    /**
299     * {@inheritDoc}
300     */
301    public double estimateRegressandVariance() {
302        return calculateYVariance();
303    }
304
305    /**
306     * Estimates the variance of the error.
307     *
308     * @return estimate of the error variance
309     * @since 2.2
310     */
311    public double estimateErrorVariance() {
312        return calculateErrorVariance();
313
314    }
315
316    /**
317     * Estimates the standard error of the regression.
318     *
319     * @return regression standard error
320     * @since 2.2
321     */
322    public double estimateRegressionStandardError() {
323        return FastMath.sqrt(estimateErrorVariance());
324    }
325
326    /**
327     * Calculates the beta of multiple linear regression in matrix notation.
328     *
329     * @return beta
330     */
331    protected abstract RealVector calculateBeta();
332
333    /**
334     * Calculates the beta variance of multiple linear regression in matrix
335     * notation.
336     *
337     * @return beta variance
338     */
339    protected abstract RealMatrix calculateBetaVariance();
340
341
342    /**
343     * Calculates the variance of the y values.
344     *
345     * @return Y variance
346     */
347    protected double calculateYVariance() {
348        return new Variance().evaluate(yVector.toArray());
349    }
350
351    /**
352     * <p>Calculates the variance of the error term.</p>
353     * Uses the formula <pre>
354     * var(u) = u &middot; u / (n - k)
355     * </pre>
356     * where n and k are the row and column dimensions of the design
357     * matrix X.
358     *
359     * @return error variance estimate
360     * @since 2.2
361     */
362    protected double calculateErrorVariance() {
363        RealVector residuals = calculateResiduals();
364        return residuals.dotProduct(residuals) /
365               (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
366    }
367
368    /**
369     * Calculates the residuals of multiple linear regression in matrix
370     * notation.
371     *
372     * <pre>
373     * u = y - X * b
374     * </pre>
375     *
376     * @return The residuals [n,1] matrix
377     */
378    protected RealVector calculateResiduals() {
379        RealVector b = calculateBeta();
380        return yVector.subtract(xMatrix.operate(b));
381    }
382
383}