OLSMultipleLinearRegression.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.stat.regression;

  18. import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
  19. import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
  20. import org.apache.commons.math4.legacy.linear.LUDecomposition;
  21. import org.apache.commons.math4.legacy.linear.QRDecomposition;
  22. import org.apache.commons.math4.legacy.linear.RealMatrix;
  23. import org.apache.commons.math4.legacy.linear.RealVector;
  24. import org.apache.commons.math4.legacy.stat.StatUtils;
  25. import org.apache.commons.math4.legacy.stat.descriptive.moment.SecondMoment;

  26. /**
  27.  * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
  28.  * multiple linear regression model.</p>
  29.  *
  30.  * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
  31.  * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre>
  32.  *
  33.  * <p>To solve the normal equations, this implementation uses QR decomposition
  34.  * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
  35.  * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
  36.  * has rows corresponding to sample observations and columns corresponding to independent
  37.  * variables.  When the model is estimated using an intercept term (i.e. when
  38.  * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
  39.  * matrix includes an initial column identically equal to 1.  We solve the normal equations
  40.  * as follows:
  41.  * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
  42.  * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
  43.  * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
  44.  * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
  45.  * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
  46.  * R b = Q<sup>T</sup> y </code></pre>
  47.  *
  48.  * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
  49.  *
  50.  * @since 2.0
  51.  */
  52. public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

  53.     /** Cached QR decomposition of X matrix. */
  54.     private QRDecomposition qr;

  55.     /** Singularity threshold for QR decomposition. */
  56.     private final double threshold;

  57.     /**
  58.      * Create an empty OLSMultipleLinearRegression instance.
  59.      */
  60.     public OLSMultipleLinearRegression() {
  61.         this(0d);
  62.     }

  63.     /**
  64.      * Create an empty OLSMultipleLinearRegression instance, using the given
  65.      * singularity threshold for the QR decomposition.
  66.      *
  67.      * @param threshold the singularity threshold
  68.      * @since 3.3
  69.      */
  70.     public OLSMultipleLinearRegression(final double threshold) {
  71.         this.threshold = threshold;
  72.     }

  73.     /**
  74.      * Loads model x and y sample data, overriding any previous sample.
  75.      *
  76.      * Computes and caches QR decomposition of the X matrix.
  77.      * @param y the [n,1] array representing the y sample
  78.      * @param x the [n,k] array representing the x sample
  79.      * @throws MathIllegalArgumentException if the x and y array data are not
  80.      *             compatible for the regression
  81.      */
  82.     public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
  83.         validateSampleData(x, y);
  84.         newYSampleData(y);
  85.         newXSampleData(x);
  86.     }

  87.     /**
  88.      * {@inheritDoc}
  89.      * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
  90.      */
  91.     @Override
  92.     public void newSampleData(double[] data, int nobs, int nvars) {
  93.         super.newSampleData(data, nobs, nvars);
  94.         qr = new QRDecomposition(getX(), threshold);
  95.     }

  96.     /**
  97.      * <p>Compute the "hat" matrix.
  98.      * </p>
  99.      * <p>The hat matrix is defined in terms of the design matrix X
  100.      *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
  101.      * </p>
  102.      * <p>The implementation here uses the QR decomposition to compute the
  103.      * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
  104.      * p-dimensional identity matrix augmented by 0's.  This computational
  105.      * formula is from "The Hat Matrix in Regression and ANOVA",
  106.      * David C. Hoaglin and Roy E. Welsch,
  107.      * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
  108.      * </p>
  109.      * <p>Data for the model must have been successfully loaded using one of
  110.      * the {@code newSampleData} methods before invoking this method; otherwise
  111.      * a {@code NullPointerException} will be thrown.</p>
  112.      *
  113.      * @return the hat matrix
  114.      * @throws NullPointerException unless method {@code newSampleData} has been
  115.      * called beforehand.
  116.      */
  117.     public RealMatrix calculateHat() {
  118.         // Create augmented identity matrix
  119.         RealMatrix q = qr.getQ();
  120.         final int p = qr.getR().getColumnDimension();
  121.         final int n = q.getColumnDimension();
  122.         // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3
  123.         Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
  124.         double[][] augIData = augI.getDataRef();
  125.         for (int i = 0; i < n; i++) {
  126.             for (int j =0; j < n; j++) {
  127.                 if (i == j && i < p) {
  128.                     augIData[i][j] = 1d;
  129.                 } else {
  130.                     augIData[i][j] = 0d;
  131.                 }
  132.             }
  133.         }

  134.         // Compute and return Hat matrix
  135.         // No DME advertised - args valid if we get here
  136.         return q.multiply(augI).multiply(q.transpose());
  137.     }

  138.     /**
  139.      * <p>Returns the sum of squared deviations of Y from its mean.</p>
  140.      *
  141.      * <p>If the model has no intercept term, <code>0</code> is used for the
  142.      * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
  143.      *
  144.      * <p>The value returned by this method is the SSTO value used in
  145.      * the {@link #calculateRSquared() R-squared} computation.</p>
  146.      *
  147.      * @return SSTO - the total sum of squares
  148.      * @throws NullPointerException if the sample has not been set
  149.      * @see #isNoIntercept()
  150.      * @since 2.2
  151.      */
  152.     public double calculateTotalSumOfSquares() {
  153.         if (isNoIntercept()) {
  154.             return StatUtils.sumSq(getY().toArray());
  155.         } else {
  156.             return new SecondMoment().evaluate(getY().toArray());
  157.         }
  158.     }

  159.     /**
  160.      * Returns the sum of squared residuals.
  161.      *
  162.      * @return residual sum of squares
  163.      * @since 2.2
  164.      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
  165.      * @throws NullPointerException if the data for the model have not been loaded
  166.      */
  167.     public double calculateResidualSumOfSquares() {
  168.         final RealVector residuals = calculateResiduals();
  169.         // No advertised DME, args are valid
  170.         return residuals.dotProduct(residuals);
  171.     }

  172.     /**
  173.      * Returns the R-Squared statistic, defined by the formula <div style="white-space: pre"><code>
  174.      * R<sup>2</sup> = 1 - SSR / SSTO
  175.      * </code></div>
  176.      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
  177.      * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
  178.      *
  179.      * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
  180.      *
  181.      * @return R-square statistic
  182.      * @throws NullPointerException if the sample has not been set
  183.      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
  184.      * @since 2.2
  185.      */
  186.     public double calculateRSquared() {
  187.         return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
  188.     }

  189.     /**
  190.      * <p>Returns the adjusted R-squared statistic, defined by the formula <div style="white-space: pre"><code>
  191.      * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
  192.      * </code></div>
  193.      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
  194.      * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
  195.      * of observations and p is the number of parameters estimated (including the intercept).
  196.      *
  197.      * <p>If the regression is estimated without an intercept term, what is returned is <pre>
  198.      * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
  199.      * </pre>
  200.      *
  201.      * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
  202.      *
  203.      * @return adjusted R-Squared statistic
  204.      * @throws NullPointerException if the sample has not been set
  205.      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
  206.      * @see #isNoIntercept()
  207.      * @since 2.2
  208.      */
  209.     public double calculateAdjustedRSquared() {
  210.         final double n = getX().getRowDimension();
  211.         if (isNoIntercept()) {
  212.             return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
  213.         } else {
  214.             return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
  215.                 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
  216.         }
  217.     }

  218.     /**
  219.      * {@inheritDoc}
  220.      * <p>This implementation computes and caches the QR decomposition of the X matrix
  221.      * once it is successfully loaded.</p>
  222.      */
  223.     @Override
  224.     protected void newXSampleData(double[][] x) {
  225.         super.newXSampleData(x);
  226.         qr = new QRDecomposition(getX(), threshold);
  227.     }

  228.     /**
  229.      * Calculates the regression coefficients using OLS.
  230.      *
  231.      * <p>Data for the model must have been successfully loaded using one of
  232.      * the {@code newSampleData} methods before invoking this method; otherwise
  233.      * a {@code NullPointerException} will be thrown.</p>
  234.      *
  235.      * @return beta
  236.      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
  237.      * @throws NullPointerException if the data for the model have not been loaded
  238.      */
  239.     @Override
  240.     protected RealVector calculateBeta() {
  241.         return qr.getSolver().solve(getY());
  242.     }

  243.     /**
  244.      * <p>Calculates the variance-covariance matrix of the regression parameters.
  245.      * </p>
  246.      * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
  247.      * </p>
  248.      * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
  249.      * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
  250.      * R included, where p = the length of the beta vector.</p>
  251.      *
  252.      * <p>Data for the model must have been successfully loaded using one of
  253.      * the {@code newSampleData} methods before invoking this method; otherwise
  254.      * a {@code NullPointerException} will be thrown.</p>
  255.      *
  256.      * @return The beta variance-covariance matrix
  257.      * @throws org.apache.commons.math4.legacy.linear.SingularMatrixException if the design matrix is singular
  258.      * @throws NullPointerException if the data for the model have not been loaded
  259.      */
  260.     @Override
  261.     protected RealMatrix calculateBetaVariance() {
  262.         int p = getX().getColumnDimension();
  263.         RealMatrix rAug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
  264.         RealMatrix rInv = new LUDecomposition(rAug).getSolver().getInverse();
  265.         return rInv.multiply(rInv.transpose());
  266.     }
  267. }