001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.stat.regression;
018
019import org.apache.commons.math4.exception.DimensionMismatchException;
020import org.apache.commons.math4.exception.InsufficientDataException;
021import org.apache.commons.math4.exception.MathIllegalArgumentException;
022import org.apache.commons.math4.exception.NoDataException;
023import org.apache.commons.math4.exception.NullArgumentException;
024import org.apache.commons.math4.exception.util.LocalizedFormats;
025import org.apache.commons.math4.linear.Array2DRowRealMatrix;
026import org.apache.commons.math4.linear.ArrayRealVector;
027import org.apache.commons.math4.linear.NonSquareMatrixException;
028import org.apache.commons.math4.linear.RealMatrix;
029import org.apache.commons.math4.linear.RealVector;
030import org.apache.commons.math4.stat.descriptive.moment.Variance;
031import org.apache.commons.math4.util.FastMath;
032
033/**
034 * Abstract base class for implementations of MultipleLinearRegression.
035 * @since 2.0
036 */
037public abstract class AbstractMultipleLinearRegression implements
038        MultipleLinearRegression {
039
040    /** X sample data. */
041    private RealMatrix xMatrix;
042
043    /** Y sample data. */
044    private RealVector yVector;
045
046    /** Whether or not the regression model includes an intercept.  True means no intercept. */
047    private boolean noIntercept;
048
049    /**
050     * @return the X sample data.
051     */
052    protected RealMatrix getX() {
053        return xMatrix;
054    }
055
056    /**
057     * @return the Y sample data.
058     */
059    protected RealVector getY() {
060        return yVector;
061    }
062
063    /**
064     * @return true if the model has no intercept term; false otherwise
065     * @since 2.2
066     */
067    public boolean isNoIntercept() {
068        return noIntercept;
069    }
070
071    /**
072     * @param noIntercept true means the model is to be estimated without an intercept term
073     * @since 2.2
074     */
075    public void setNoIntercept(boolean noIntercept) {
076        this.noIntercept = noIntercept;
077    }
078
079    /**
080     * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
081     * </p>
082     * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
083     * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
084     * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
085     * independent variables, as below:
086     * <pre>
087     *   y   x[0]  x[1]
088     *   --------------
089     *   1     2     3
090     *   4     5     6
091     *   7     8     9
092     * </pre>
093     *
094     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
095     * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
096     * the X matrix will be created without an initial column of "1"s; otherwise this column will
097     * be added.
098     * </p>
099     * <p>Throws IllegalArgumentException if any of the following preconditions fail:
100     * <ul><li><code>data</code> cannot be null</li>
101     * <li><code>data.length = nobs * (nvars + 1)</code></li>
102     * <li>{@code nobs > nvars}</li></ul>
103     *
104     * @param data input data array
105     * @param nobs number of observations (rows)
106     * @param nvars number of independent variables (columns, not counting y)
107     * @throws NullArgumentException if the data array is null
108     * @throws DimensionMismatchException if the length of the data array is not equal
109     * to <code>nobs * (nvars + 1)</code>
110     * @throws InsufficientDataException if <code>nobs</code> is less than
111     * <code>nvars + 1</code>
112     */
113    public void newSampleData(double[] data, int nobs, int nvars) {
114        if (data == null) {
115            throw new NullArgumentException();
116        }
117        if (data.length != nobs * (nvars + 1)) {
118            throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
119        }
120        if (nobs <= nvars) {
121            throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1);
122        }
123        double[] y = new double[nobs];
124        final int cols = noIntercept ? nvars: nvars + 1;
125        double[][] x = new double[nobs][cols];
126        int pointer = 0;
127        for (int i = 0; i < nobs; i++) {
128            y[i] = data[pointer++];
129            if (!noIntercept) {
130                x[i][0] = 1.0d;
131            }
132            for (int j = noIntercept ? 0 : 1; j < cols; j++) {
133                x[i][j] = data[pointer++];
134            }
135        }
136        this.xMatrix = new Array2DRowRealMatrix(x);
137        this.yVector = new ArrayRealVector(y);
138    }
139
140    /**
141     * Loads new y sample data, overriding any previous data.
142     *
143     * @param y the array representing the y sample
144     * @throws NullArgumentException if y is null
145     * @throws NoDataException if y is empty
146     */
147    protected void newYSampleData(double[] y) {
148        if (y == null) {
149            throw new NullArgumentException();
150        }
151        if (y.length == 0) {
152            throw new NoDataException();
153        }
154        this.yVector = new ArrayRealVector(y);
155    }
156
157    /**
158     * <p>Loads new x sample data, overriding any previous data.
159     * </p>
160     * The input <code>x</code> array should have one row for each sample
161     * observation, with columns corresponding to independent variables.
162     * For example, if <pre>
163     * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
164     * then <code>setXSampleData(x) </code> results in a model with two independent
165     * variables and 3 observations:
166     * <pre>
167     *   x[0]  x[1]
168     *   ----------
169     *     1    2
170     *     3    4
171     *     5    6
172     * </pre>
173     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
174     * specifying a model including an intercept term.
175     * </p>
176     * @param x the rectangular array representing the x sample
177     * @throws NullArgumentException if x is null
178     * @throws NoDataException if x is empty
179     * @throws DimensionMismatchException if x is not rectangular
180     */
181    protected void newXSampleData(double[][] x) {
182        if (x == null) {
183            throw new NullArgumentException();
184        }
185        if (x.length == 0) {
186            throw new NoDataException();
187        }
188        if (noIntercept) {
189            this.xMatrix = new Array2DRowRealMatrix(x, true);
190        } else { // Augment design matrix with initial unitary column
191            final int nVars = x[0].length;
192            final double[][] xAug = new double[x.length][nVars + 1];
193            for (int i = 0; i < x.length; i++) {
194                if (x[i].length != nVars) {
195                    throw new DimensionMismatchException(x[i].length, nVars);
196                }
197                xAug[i][0] = 1.0d;
198                System.arraycopy(x[i], 0, xAug[i], 1, nVars);
199            }
200            this.xMatrix = new Array2DRowRealMatrix(xAug, false);
201        }
202    }
203
204    /**
205     * Validates sample data.  Checks that
206     * <ul><li>Neither x nor y is null or empty;</li>
207     * <li>The length (i.e. number of rows) of x equals the length of y</li>
208     * <li>x has at least one more row than it has columns (i.e. there is
209     * sufficient data to estimate regression coefficients for each of the
210     * columns in x plus an intercept.</li>
211     * </ul>
212     *
213     * @param x the [n,k] array representing the x data
214     * @param y the [n,1] array representing the y data
215     * @throws NullArgumentException if {@code x} or {@code y} is null
216     * @throws DimensionMismatchException if {@code x} and {@code y} do not
217     * have the same length
218     * @throws NoDataException if {@code x} or {@code y} are zero-length
219     * @throws MathIllegalArgumentException if the number of rows of {@code x}
220     * is not larger than the number of columns + 1
221     */
222    protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
223        if ((x == null) || (y == null)) {
224            throw new NullArgumentException();
225        }
226        if (x.length != y.length) {
227            throw new DimensionMismatchException(y.length, x.length);
228        }
229        if (x.length == 0) {  // Must be no y data either
230            throw new NoDataException();
231        }
232        if (x[0].length + 1 > x.length) {
233            throw new MathIllegalArgumentException(
234                    LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
235                    x.length, x[0].length);
236        }
237    }
238
239    /**
240     * Validates that the x data and covariance matrix have the same
241     * number of rows and that the covariance matrix is square.
242     *
243     * @param x the [n,k] array representing the x sample
244     * @param covariance the [n,n] array representing the covariance matrix
245     * @throws DimensionMismatchException if the number of rows in x is not equal
246     * to the number of rows in covariance
247     * @throws NonSquareMatrixException if the covariance matrix is not square
248     */
249    protected void validateCovarianceData(double[][] x, double[][] covariance) {
250        if (x.length != covariance.length) {
251            throw new DimensionMismatchException(x.length, covariance.length);
252        }
253        if (covariance.length > 0 && covariance.length != covariance[0].length) {
254            throw new NonSquareMatrixException(covariance.length, covariance[0].length);
255        }
256    }
257
258    /**
259     * {@inheritDoc}
260     */
261    @Override
262    public double[] estimateRegressionParameters() {
263        RealVector b = calculateBeta();
264        return b.toArray();
265    }
266
267    /**
268     * {@inheritDoc}
269     */
270    @Override
271    public double[] estimateResiduals() {
272        RealVector b = calculateBeta();
273        RealVector e = yVector.subtract(xMatrix.operate(b));
274        return e.toArray();
275    }
276
277    /**
278     * {@inheritDoc}
279     */
280    @Override
281    public double[][] estimateRegressionParametersVariance() {
282        return calculateBetaVariance().getData();
283    }
284
285    /**
286     * {@inheritDoc}
287     */
288    @Override
289    public double[] estimateRegressionParametersStandardErrors() {
290        double[][] betaVariance = estimateRegressionParametersVariance();
291        double sigma = calculateErrorVariance();
292        int length = betaVariance[0].length;
293        double[] result = new double[length];
294        for (int i = 0; i < length; i++) {
295            result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
296        }
297        return result;
298    }
299
300    /**
301     * {@inheritDoc}
302     */
303    @Override
304    public double estimateRegressandVariance() {
305        return calculateYVariance();
306    }
307
308    /**
309     * Estimates the variance of the error.
310     *
311     * @return estimate of the error variance
312     * @since 2.2
313     */
314    public double estimateErrorVariance() {
315        return calculateErrorVariance();
316
317    }
318
319    /**
320     * Estimates the standard error of the regression.
321     *
322     * @return regression standard error
323     * @since 2.2
324     */
325    public double estimateRegressionStandardError() {
326        return FastMath.sqrt(estimateErrorVariance());
327    }
328
329    /**
330     * Calculates the beta of multiple linear regression in matrix notation.
331     *
332     * @return beta
333     */
334    protected abstract RealVector calculateBeta();
335
336    /**
337     * Calculates the beta variance of multiple linear regression in matrix
338     * notation.
339     *
340     * @return beta variance
341     */
342    protected abstract RealMatrix calculateBetaVariance();
343
344
345    /**
346     * Calculates the variance of the y values.
347     *
348     * @return Y variance
349     */
350    protected double calculateYVariance() {
351        return new Variance().evaluate(yVector.toArray());
352    }
353
354    /**
355     * <p>Calculates the variance of the error term.</p>
356     * Uses the formula <pre>
357     * var(u) = u &middot; u / (n - k)
358     * </pre>
359     * where n and k are the row and column dimensions of the design
360     * matrix X.
361     *
362     * @return error variance estimate
363     * @since 2.2
364     */
365    protected double calculateErrorVariance() {
366        RealVector residuals = calculateResiduals();
367        return residuals.dotProduct(residuals) /
368               (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
369    }
370
371    /**
372     * Calculates the residuals of multiple linear regression in matrix
373     * notation.
374     *
375     * <pre>
376     * u = y - X * b
377     * </pre>
378     *
379     * @return The residuals [n,1] matrix
380     */
381    protected RealVector calculateResiduals() {
382        RealVector b = calculateBeta();
383        return yVector.subtract(xMatrix.operate(b));
384    }
385
386}