AbstractMultipleLinearRegression.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.stat.regression;

  18. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  19. import org.apache.commons.math4.legacy.exception.InsufficientDataException;
  20. import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
  21. import org.apache.commons.math4.legacy.exception.NoDataException;
  22. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  23. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  24. import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
  25. import org.apache.commons.math4.legacy.linear.ArrayRealVector;
  26. import org.apache.commons.math4.legacy.linear.NonSquareMatrixException;
  27. import org.apache.commons.math4.legacy.linear.RealMatrix;
  28. import org.apache.commons.math4.legacy.linear.RealVector;
  29. import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
  30. import org.apache.commons.math4.core.jdkmath.JdkMath;

  31. /**
  32.  * Abstract base class for implementations of MultipleLinearRegression.
  33.  * @since 2.0
  34.  */
  35. public abstract class AbstractMultipleLinearRegression implements
  36.         MultipleLinearRegression {

  37.     /** X sample data. */
  38.     private RealMatrix xMatrix;

  39.     /** Y sample data. */
  40.     private RealVector yVector;

  41.     /** Whether or not the regression model includes an intercept.  True means no intercept. */
  42.     private boolean noIntercept;

  43.     /**
  44.      * @return the X sample data.
  45.      */
  46.     protected RealMatrix getX() {
  47.         return xMatrix;
  48.     }

  49.     /**
  50.      * @return the Y sample data.
  51.      */
  52.     protected RealVector getY() {
  53.         return yVector;
  54.     }

  55.     /**
  56.      * @return true if the model has no intercept term; false otherwise
  57.      * @since 2.2
  58.      */
  59.     public boolean isNoIntercept() {
  60.         return noIntercept;
  61.     }

  62.     /**
  63.      * @param noIntercept true means the model is to be estimated without an intercept term
  64.      * @since 2.2
  65.      */
  66.     public void setNoIntercept(boolean noIntercept) {
  67.         this.noIntercept = noIntercept;
  68.     }

  69.     /**
  70.      * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
  71.      * </p>
  72.      * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
  73.      * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
  74.      * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
  75.      * independent variables, as below:
  76.      * <pre>
  77.      *   y   x[0]  x[1]
  78.      *   --------------
  79.      *   1     2     3
  80.      *   4     5     6
  81.      *   7     8     9
  82.      * </pre>
  83.      *
  84.      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
  85.      * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
  86.      * the X matrix will be created without an initial column of "1"s; otherwise this column will
  87.      * be added.
  88.      * </p>
  89.      * <p>Throws IllegalArgumentException if any of the following preconditions fail:
  90.      * <ul><li><code>data</code> cannot be null</li>
  91.      * <li><code>data.length = nobs * (nvars + 1)</code></li>
  92.      * <li>{@code nobs > nvars}</li></ul>
  93.      *
  94.      * @param data input data array
  95.      * @param nobs number of observations (rows)
  96.      * @param nvars number of independent variables (columns, not counting y)
  97.      * @throws NullArgumentException if the data array is null
  98.      * @throws DimensionMismatchException if the length of the data array is not equal
  99.      * to <code>nobs * (nvars + 1)</code>
  100.      * @throws InsufficientDataException if <code>nobs</code> is less than
  101.      * <code>nvars + 1</code>
  102.      */
  103.     public void newSampleData(double[] data, int nobs, int nvars) {
  104.         if (data == null) {
  105.             throw new NullArgumentException();
  106.         }
  107.         if (data.length != nobs * (nvars + 1)) {
  108.             throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
  109.         }
  110.         if (nobs <= nvars) {
  111.             throw new InsufficientDataException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, nobs, nvars + 1);
  112.         }
  113.         double[] y = new double[nobs];
  114.         final int cols = noIntercept ? nvars: nvars + 1;
  115.         double[][] x = new double[nobs][cols];
  116.         int pointer = 0;
  117.         for (int i = 0; i < nobs; i++) {
  118.             y[i] = data[pointer++];
  119.             if (!noIntercept) {
  120.                 x[i][0] = 1.0d;
  121.             }
  122.             for (int j = noIntercept ? 0 : 1; j < cols; j++) {
  123.                 x[i][j] = data[pointer++];
  124.             }
  125.         }
  126.         this.xMatrix = new Array2DRowRealMatrix(x);
  127.         this.yVector = new ArrayRealVector(y);
  128.     }

  129.     /**
  130.      * Loads new y sample data, overriding any previous data.
  131.      *
  132.      * @param y the array representing the y sample
  133.      * @throws NullArgumentException if y is null
  134.      * @throws NoDataException if y is empty
  135.      */
  136.     protected void newYSampleData(double[] y) {
  137.         if (y == null) {
  138.             throw new NullArgumentException();
  139.         }
  140.         if (y.length == 0) {
  141.             throw new NoDataException();
  142.         }
  143.         this.yVector = new ArrayRealVector(y);
  144.     }

  145.     /**
  146.      * <p>Loads new x sample data, overriding any previous data.
  147.      * </p>
  148.      * The input <code>x</code> array should have one row for each sample
  149.      * observation, with columns corresponding to independent variables.
  150.      * For example, if <pre>
  151.      * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
  152.      * then <code>setXSampleData(x) </code> results in a model with two independent
  153.      * variables and 3 observations:
  154.      * <pre>
  155.      *   x[0]  x[1]
  156.      *   ----------
  157.      *     1    2
  158.      *     3    4
  159.      *     5    6
  160.      * </pre>
  161.      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
  162.      * specifying a model including an intercept term.
  163.      * </p>
  164.      * @param x the rectangular array representing the x sample
  165.      * @throws NullArgumentException if x is null
  166.      * @throws NoDataException if x is empty
  167.      * @throws DimensionMismatchException if x is not rectangular
  168.      */
  169.     protected void newXSampleData(double[][] x) {
  170.         if (x == null) {
  171.             throw new NullArgumentException();
  172.         }
  173.         if (x.length == 0) {
  174.             throw new NoDataException();
  175.         }
  176.         if (noIntercept) {
  177.             this.xMatrix = new Array2DRowRealMatrix(x, true);
  178.         } else { // Augment design matrix with initial unitary column
  179.             final int nVars = x[0].length;
  180.             final double[][] xAug = new double[x.length][nVars + 1];
  181.             for (int i = 0; i < x.length; i++) {
  182.                 if (x[i].length != nVars) {
  183.                     throw new DimensionMismatchException(x[i].length, nVars);
  184.                 }
  185.                 xAug[i][0] = 1.0d;
  186.                 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
  187.             }
  188.             this.xMatrix = new Array2DRowRealMatrix(xAug, false);
  189.         }
  190.     }

  191.     /**
  192.      * Validates sample data.  Checks that
  193.      * <ul><li>Neither x nor y is null or empty;</li>
  194.      * <li>The length (i.e. number of rows) of x equals the length of y</li>
  195.      * <li>x has at least one more row than it has columns (i.e. there is
  196.      * sufficient data to estimate regression coefficients for each of the
  197.      * columns in x plus an intercept.</li>
  198.      * </ul>
  199.      *
  200.      * @param x the [n,k] array representing the x data
  201.      * @param y the [n,1] array representing the y data
  202.      * @throws NullArgumentException if {@code x} or {@code y} is null
  203.      * @throws DimensionMismatchException if {@code x} and {@code y} do not
  204.      * have the same length
  205.      * @throws NoDataException if {@code x} or {@code y} are zero-length
  206.      * @throws MathIllegalArgumentException if the number of rows of {@code x}
  207.      * is not larger than the number of columns + 1
  208.      */
  209.     protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
  210.         if (x == null || y == null) {
  211.             throw new NullArgumentException();
  212.         }
  213.         if (x.length != y.length) {
  214.             throw new DimensionMismatchException(y.length, x.length);
  215.         }
  216.         if (x.length == 0) {  // Must be no y data either
  217.             throw new NoDataException();
  218.         }
  219.         if (x[0].length + 1 > x.length) {
  220.             throw new MathIllegalArgumentException(
  221.                     LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
  222.                     x.length, x[0].length);
  223.         }
  224.     }

  225.     /**
  226.      * Validates that the x data and covariance matrix have the same
  227.      * number of rows and that the covariance matrix is square.
  228.      *
  229.      * @param x the [n,k] array representing the x sample
  230.      * @param covariance the [n,n] array representing the covariance matrix
  231.      * @throws DimensionMismatchException if the number of rows in x is not equal
  232.      * to the number of rows in covariance
  233.      * @throws NonSquareMatrixException if the covariance matrix is not square
  234.      */
  235.     protected void validateCovarianceData(double[][] x, double[][] covariance) {
  236.         if (x.length != covariance.length) {
  237.             throw new DimensionMismatchException(x.length, covariance.length);
  238.         }
  239.         if (covariance.length > 0 && covariance.length != covariance[0].length) {
  240.             throw new NonSquareMatrixException(covariance.length, covariance[0].length);
  241.         }
  242.     }

  243.     /**
  244.      * {@inheritDoc}
  245.      */
  246.     @Override
  247.     public double[] estimateRegressionParameters() {
  248.         RealVector b = calculateBeta();
  249.         return b.toArray();
  250.     }

  251.     /**
  252.      * {@inheritDoc}
  253.      */
  254.     @Override
  255.     public double[] estimateResiduals() {
  256.         RealVector b = calculateBeta();
  257.         RealVector e = yVector.subtract(xMatrix.operate(b));
  258.         return e.toArray();
  259.     }

  260.     /**
  261.      * {@inheritDoc}
  262.      */
  263.     @Override
  264.     public double[][] estimateRegressionParametersVariance() {
  265.         return calculateBetaVariance().getData();
  266.     }

  267.     /**
  268.      * {@inheritDoc}
  269.      */
  270.     @Override
  271.     public double[] estimateRegressionParametersStandardErrors() {
  272.         double[][] betaVariance = estimateRegressionParametersVariance();
  273.         double sigma = calculateErrorVariance();
  274.         int length = betaVariance[0].length;
  275.         double[] result = new double[length];
  276.         for (int i = 0; i < length; i++) {
  277.             result[i] = JdkMath.sqrt(sigma * betaVariance[i][i]);
  278.         }
  279.         return result;
  280.     }

  281.     /**
  282.      * {@inheritDoc}
  283.      */
  284.     @Override
  285.     public double estimateRegressandVariance() {
  286.         return calculateYVariance();
  287.     }

  288.     /**
  289.      * Estimates the variance of the error.
  290.      *
  291.      * @return estimate of the error variance
  292.      * @since 2.2
  293.      */
  294.     public double estimateErrorVariance() {
  295.         return calculateErrorVariance();
  296.     }

  297.     /**
  298.      * Estimates the standard error of the regression.
  299.      *
  300.      * @return regression standard error
  301.      * @since 2.2
  302.      */
  303.     public double estimateRegressionStandardError() {
  304.         return JdkMath.sqrt(estimateErrorVariance());
  305.     }

  306.     /**
  307.      * Calculates the beta of multiple linear regression in matrix notation.
  308.      *
  309.      * @return beta
  310.      */
  311.     protected abstract RealVector calculateBeta();

  312.     /**
  313.      * Calculates the beta variance of multiple linear regression in matrix
  314.      * notation.
  315.      *
  316.      * @return beta variance
  317.      */
  318.     protected abstract RealMatrix calculateBetaVariance();


  319.     /**
  320.      * Calculates the variance of the y values.
  321.      *
  322.      * @return Y variance
  323.      */
  324.     protected double calculateYVariance() {
  325.         return new Variance().evaluate(yVector.toArray());
  326.     }

  327.     /**
  328.      * <p>Calculates the variance of the error term.</p>
  329.      * Uses the formula <pre>
  330.      * var(u) = u &middot; u / (n - k)
  331.      * </pre>
  332.      * where n and k are the row and column dimensions of the design
  333.      * matrix X.
  334.      *
  335.      * @return error variance estimate
  336.      * @since 2.2
  337.      */
  338.     protected double calculateErrorVariance() {
  339.         RealVector residuals = calculateResiduals();
  340.         return residuals.dotProduct(residuals) /
  341.                (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
  342.     }

  343.     /**
  344.      * Calculates the residuals of multiple linear regression in matrix
  345.      * notation.
  346.      *
  347.      * <pre>
  348.      * u = y - X * b
  349.      * </pre>
  350.      *
  351.      * @return The residuals [n,1] matrix
  352.      */
  353.     protected RealVector calculateResiduals() {
  354.         RealVector b = calculateBeta();
  355.         return yVector.subtract(xMatrix.operate(b));
  356.     }
  357. }