001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.stat.regression;
018
019import org.apache.commons.math3.exception.DimensionMismatchException;
020import org.apache.commons.math3.exception.InsufficientDataException;
021import org.apache.commons.math3.exception.MathIllegalArgumentException;
022import org.apache.commons.math3.exception.NoDataException;
023import org.apache.commons.math3.exception.NullArgumentException;
024import org.apache.commons.math3.exception.util.LocalizedFormats;
025import org.apache.commons.math3.linear.NonSquareMatrixException;
026import org.apache.commons.math3.linear.RealMatrix;
027import org.apache.commons.math3.linear.Array2DRowRealMatrix;
028import org.apache.commons.math3.linear.RealVector;
029import org.apache.commons.math3.linear.ArrayRealVector;
030import org.apache.commons.math3.stat.descriptive.moment.Variance;
031import org.apache.commons.math3.util.FastMath;
032
033/**
034 * Abstract base class for implementations of MultipleLinearRegression.
035 * @version $Id: AbstractMultipleLinearRegression.java 1547633 2013-12-03 23:03:06Z tn $
036 * @since 2.0
037 */
038public abstract class AbstractMultipleLinearRegression implements
039        MultipleLinearRegression {
040
041    /** X sample data. */
042    private RealMatrix xMatrix;
043
044    /** Y sample data. */
045    private RealVector yVector;
046
047    /** Whether or not the regression model includes an intercept.  True means no intercept. */
048    private boolean noIntercept = false;
049
050    /**
051     * @return the X sample data.
052     */
053    protected RealMatrix getX() {
054        return xMatrix;
055    }
056
057    /**
058     * @return the Y sample data.
059     */
060    protected RealVector getY() {
061        return yVector;
062    }
063
064    /**
065     * @return true if the model has no intercept term; false otherwise
066     * @since 2.2
067     */
068    public boolean isNoIntercept() {
069        return noIntercept;
070    }
071
072    /**
073     * @param noIntercept true means the model is to be estimated without an intercept term
074     * @since 2.2
075     */
076    public void setNoIntercept(boolean noIntercept) {
077        this.noIntercept = noIntercept;
078    }
079
080    /**
081     * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
082     * </p>
083     * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
084     * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
085     * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
086     * independent variables, as below:
087     * <pre>
088     *   y   x[0]  x[1]
089     *   --------------
090     *   1     2     3
091     *   4     5     6
092     *   7     8     9
093     * </pre>
094     * </p>
095     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
096     * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
097     * the X matrix will be created without an initial column of "1"s; otherwise this column will
098     * be added.
099     * </p>
100     * <p>Throws IllegalArgumentException if any of the following preconditions fail:
101     * <ul><li><code>data</code> cannot be null</li>
102     * <li><code>data.length = nobs * (nvars + 1)</li>
103     * <li><code>nobs > nvars</code></li></ul>
104     * </p>
105     *
106     * @param data input data array
107     * @param nobs number of observations (rows)
108     * @param nvars number of independent variables (columns, not counting y)
109     * @throws NullArgumentException if the data array is null
110     * @throws DimensionMismatchException if the length of the data array is not equal
111     * to <code>nobs * (nvars + 1)</code>
112     * @throws InsufficientDataException if <code>nobs</code> is less than
113     * <code>nvars + 1</code>
114     */
115    public void newSampleData(double[] data, int nobs, int nvars) {
116        if (data == null) {
117            throw new NullArgumentException();
118        }
119        if (data.length != nobs * (nvars + 1)) {
120            throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
121        }
122        if (nobs <= nvars) {
123            throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1);
124        }
125        double[] y = new double[nobs];
126        final int cols = noIntercept ? nvars: nvars + 1;
127        double[][] x = new double[nobs][cols];
128        int pointer = 0;
129        for (int i = 0; i < nobs; i++) {
130            y[i] = data[pointer++];
131            if (!noIntercept) {
132                x[i][0] = 1.0d;
133            }
134            for (int j = noIntercept ? 0 : 1; j < cols; j++) {
135                x[i][j] = data[pointer++];
136            }
137        }
138        this.xMatrix = new Array2DRowRealMatrix(x);
139        this.yVector = new ArrayRealVector(y);
140    }
141
142    /**
143     * Loads new y sample data, overriding any previous data.
144     *
145     * @param y the array representing the y sample
146     * @throws NullArgumentException if y is null
147     * @throws NoDataException if y is empty
148     */
149    protected void newYSampleData(double[] y) {
150        if (y == null) {
151            throw new NullArgumentException();
152        }
153        if (y.length == 0) {
154            throw new NoDataException();
155        }
156        this.yVector = new ArrayRealVector(y);
157    }
158
159    /**
160     * <p>Loads new x sample data, overriding any previous data.
161     * </p>
162     * The input <code>x</code> array should have one row for each sample
163     * observation, with columns corresponding to independent variables.
164     * For example, if <pre>
165     * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
166     * then <code>setXSampleData(x) </code> results in a model with two independent
167     * variables and 3 observations:
168     * <pre>
169     *   x[0]  x[1]
170     *   ----------
171     *     1    2
172     *     3    4
173     *     5    6
174     * </pre>
175     * </p>
176     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
177     * specifying a model including an intercept term.
178     * </p>
179     * @param x the rectangular array representing the x sample
180     * @throws NullArgumentException if x is null
181     * @throws NoDataException if x is empty
182     * @throws DimensionMismatchException if x is not rectangular
183     */
184    protected void newXSampleData(double[][] x) {
185        if (x == null) {
186            throw new NullArgumentException();
187        }
188        if (x.length == 0) {
189            throw new NoDataException();
190        }
191        if (noIntercept) {
192            this.xMatrix = new Array2DRowRealMatrix(x, true);
193        } else { // Augment design matrix with initial unitary column
194            final int nVars = x[0].length;
195            final double[][] xAug = new double[x.length][nVars + 1];
196            for (int i = 0; i < x.length; i++) {
197                if (x[i].length != nVars) {
198                    throw new DimensionMismatchException(x[i].length, nVars);
199                }
200                xAug[i][0] = 1.0d;
201                System.arraycopy(x[i], 0, xAug[i], 1, nVars);
202            }
203            this.xMatrix = new Array2DRowRealMatrix(xAug, false);
204        }
205    }
206
207    /**
208     * Validates sample data.  Checks that
209     * <ul><li>Neither x nor y is null or empty;</li>
210     * <li>The length (i.e. number of rows) of x equals the length of y</li>
211     * <li>x has at least one more row than it has columns (i.e. there is
212     * sufficient data to estimate regression coefficients for each of the
213     * columns in x plus an intercept.</li>
214     * </ul>
215     *
216     * @param x the [n,k] array representing the x data
217     * @param y the [n,1] array representing the y data
218     * @throws NullArgumentException if {@code x} or {@code y} is null
219     * @throws DimensionMismatchException if {@code x} and {@code y} do not
220     * have the same length
221     * @throws NoDataException if {@code x} or {@code y} are zero-length
222     * @throws MathIllegalArgumentException if the number of rows of {@code x}
223     * is not larger than the number of columns + 1
224     */
225    protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
226        if ((x == null) || (y == null)) {
227            throw new NullArgumentException();
228        }
229        if (x.length != y.length) {
230            throw new DimensionMismatchException(y.length, x.length);
231        }
232        if (x.length == 0) {  // Must be no y data either
233            throw new NoDataException();
234        }
235        if (x[0].length + 1 > x.length) {
236            throw new MathIllegalArgumentException(
237                    LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
238                    x.length, x[0].length);
239        }
240    }
241
242    /**
243     * Validates that the x data and covariance matrix have the same
244     * number of rows and that the covariance matrix is square.
245     *
246     * @param x the [n,k] array representing the x sample
247     * @param covariance the [n,n] array representing the covariance matrix
248     * @throws DimensionMismatchException if the number of rows in x is not equal
249     * to the number of rows in covariance
250     * @throws NonSquareMatrixException if the covariance matrix is not square
251     */
252    protected void validateCovarianceData(double[][] x, double[][] covariance) {
253        if (x.length != covariance.length) {
254            throw new DimensionMismatchException(x.length, covariance.length);
255        }
256        if (covariance.length > 0 && covariance.length != covariance[0].length) {
257            throw new NonSquareMatrixException(covariance.length, covariance[0].length);
258        }
259    }
260
261    /**
262     * {@inheritDoc}
263     */
264    public double[] estimateRegressionParameters() {
265        RealVector b = calculateBeta();
266        return b.toArray();
267    }
268
269    /**
270     * {@inheritDoc}
271     */
272    public double[] estimateResiduals() {
273        RealVector b = calculateBeta();
274        RealVector e = yVector.subtract(xMatrix.operate(b));
275        return e.toArray();
276    }
277
278    /**
279     * {@inheritDoc}
280     */
281    public double[][] estimateRegressionParametersVariance() {
282        return calculateBetaVariance().getData();
283    }
284
285    /**
286     * {@inheritDoc}
287     */
288    public double[] estimateRegressionParametersStandardErrors() {
289        double[][] betaVariance = estimateRegressionParametersVariance();
290        double sigma = calculateErrorVariance();
291        int length = betaVariance[0].length;
292        double[] result = new double[length];
293        for (int i = 0; i < length; i++) {
294            result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
295        }
296        return result;
297    }
298
299    /**
300     * {@inheritDoc}
301     */
302    public double estimateRegressandVariance() {
303        return calculateYVariance();
304    }
305
306    /**
307     * Estimates the variance of the error.
308     *
309     * @return estimate of the error variance
310     * @since 2.2
311     */
312    public double estimateErrorVariance() {
313        return calculateErrorVariance();
314
315    }
316
317    /**
318     * Estimates the standard error of the regression.
319     *
320     * @return regression standard error
321     * @since 2.2
322     */
323    public double estimateRegressionStandardError() {
324        return FastMath.sqrt(estimateErrorVariance());
325    }
326
327    /**
328     * Calculates the beta of multiple linear regression in matrix notation.
329     *
330     * @return beta
331     */
332    protected abstract RealVector calculateBeta();
333
334    /**
335     * Calculates the beta variance of multiple linear regression in matrix
336     * notation.
337     *
338     * @return beta variance
339     */
340    protected abstract RealMatrix calculateBetaVariance();
341
342
343    /**
344     * Calculates the variance of the y values.
345     *
346     * @return Y variance
347     */
348    protected double calculateYVariance() {
349        return new Variance().evaluate(yVector.toArray());
350    }
351
352    /**
353     * <p>Calculates the variance of the error term.</p>
354     * Uses the formula <pre>
355     * var(u) = u &middot; u / (n - k)
356     * </pre>
357     * where n and k are the row and column dimensions of the design
358     * matrix X.
359     *
360     * @return error variance estimate
361     * @since 2.2
362     */
363    protected double calculateErrorVariance() {
364        RealVector residuals = calculateResiduals();
365        return residuals.dotProduct(residuals) /
366               (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
367    }
368
369    /**
370     * Calculates the residuals of multiple linear regression in matrix
371     * notation.
372     *
373     * <pre>
374     * u = y - X * b
375     * </pre>
376     *
377     * @return The residuals [n,1] matrix
378     */
379    protected RealVector calculateResiduals() {
380        RealVector b = calculateBeta();
381        return yVector.subtract(xMatrix.operate(b));
382    }
383
384}