001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.stat.regression;
018
019import org.apache.commons.math3.linear.LUDecomposition;
020import org.apache.commons.math3.linear.RealMatrix;
021import org.apache.commons.math3.linear.Array2DRowRealMatrix;
022import org.apache.commons.math3.linear.RealVector;
023
024/**
025 * The GLS implementation of multiple linear regression.
026 *
027 * GLS assumes a general covariance matrix Omega of the error
028 * <pre>
029 * u ~ N(0, Omega)
030 * </pre>
031 *
032 * Estimated by GLS,
033 * <pre>
034 * b=(X' Omega^-1 X)^-1X'Omega^-1 y
035 * </pre>
036 * whose variance is
037 * <pre>
038 * Var(b)=(X' Omega^-1 X)^-1
039 * </pre>
040 * @since 2.0
041 */
042public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
043
044    /** Covariance matrix. */
045    private RealMatrix Omega;
046
047    /** Inverse of covariance matrix. */
048    private RealMatrix OmegaInverse;
049
050    /** Replace sample data, overriding any previous sample.
051     * @param y y values of the sample
052     * @param x x values of the sample
053     * @param covariance array representing the covariance matrix
054     */
055    public void newSampleData(double[] y, double[][] x, double[][] covariance) {
056        validateSampleData(x, y);
057        newYSampleData(y);
058        newXSampleData(x);
059        validateCovarianceData(x, covariance);
060        newCovarianceData(covariance);
061    }
062
063    /**
064     * Add the covariance data.
065     *
066     * @param omega the [n,n] array representing the covariance
067     */
068    protected void newCovarianceData(double[][] omega){
069        this.Omega = new Array2DRowRealMatrix(omega);
070        this.OmegaInverse = null;
071    }
072
073    /**
074     * Get the inverse of the covariance.
075     * <p>The inverse of the covariance matrix is lazily evaluated and cached.</p>
076     * @return inverse of the covariance
077     */
078    protected RealMatrix getOmegaInverse() {
079        if (OmegaInverse == null) {
080            OmegaInverse = new LUDecomposition(Omega).getSolver().getInverse();
081        }
082        return OmegaInverse;
083    }
084
085    /**
086     * Calculates beta by GLS.
087     * <pre>
088     *  b=(X' Omega^-1 X)^-1X'Omega^-1 y
089     * </pre>
090     * @return beta
091     */
092    @Override
093    protected RealVector calculateBeta() {
094        RealMatrix OI = getOmegaInverse();
095        RealMatrix XT = getX().transpose();
096        RealMatrix XTOIX = XT.multiply(OI).multiply(getX());
097        RealMatrix inverse = new LUDecomposition(XTOIX).getSolver().getInverse();
098        return inverse.multiply(XT).multiply(OI).operate(getY());
099    }
100
101    /**
102     * Calculates the variance on the beta.
103     * <pre>
104     *  Var(b)=(X' Omega^-1 X)^-1
105     * </pre>
106     * @return The beta variance matrix
107     */
108    @Override
109    protected RealMatrix calculateBetaVariance() {
110        RealMatrix OI = getOmegaInverse();
111        RealMatrix XTOIX = getX().transpose().multiply(OI).multiply(getX());
112        return new LUDecomposition(XTOIX).getSolver().getInverse();
113    }
114
115
116    /**
117     * Calculates the estimated variance of the error term using the formula
118     * <pre>
119     *  Var(u) = Tr(u' Omega^-1 u)/(n-k)
120     * </pre>
121     * where n and k are the row and column dimensions of the design
122     * matrix X.
123     *
124     * @return error variance
125     * @since 2.2
126     */
127    @Override
128    protected double calculateErrorVariance() {
129        RealVector residuals = calculateResiduals();
130        double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
131        return t / (getX().getRowDimension() - getX().getColumnDimension());
132
133    }
134
135}