001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.linear;
019
020import org.apache.commons.math3.exception.DimensionMismatchException;
021import org.apache.commons.math3.util.FastMath;
022
023
024/**
025 * Calculates the Cholesky decomposition of a matrix.
026 * <p>The Cholesky decomposition of a real symmetric positive-definite
027 * matrix A consists of a lower triangular matrix L with same size such
028 * that: A = LL<sup>T</sup>. In a sense, this is the square root of A.</p>
029 * <p>This class is based on the class with similar name from the
030 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
031 * following changes:</p>
032 * <ul>
033 *   <li>a {@link #getLT() getLT} method has been added,</li>
034 *   <li>the {@code isspd} method has been removed, since the constructor of
035 *   this class throws a {@link NonPositiveDefiniteMatrixException} when a
036 *   matrix cannot be decomposed,</li>
037 *   <li>a {@link #getDeterminant() getDeterminant} method has been added,</li>
038 *   <li>the {@code solve} method has been replaced by a {@link #getSolver()
039 *   getSolver} method and the equivalent method provided by the returned
040 *   {@link DecompositionSolver}.</li>
041 * </ul>
042 *
043 * @see <a href="http://mathworld.wolfram.com/CholeskyDecomposition.html">MathWorld</a>
044 * @see <a href="http://en.wikipedia.org/wiki/Cholesky_decomposition">Wikipedia</a>
045 * @since 2.0 (changed to concrete class in 3.0)
046 */
047public class CholeskyDecomposition {
048    /**
049     * Default threshold above which off-diagonal elements are considered too different
050     * and matrix not symmetric.
051     */
052    public static final double DEFAULT_RELATIVE_SYMMETRY_THRESHOLD = 1.0e-15;
053    /**
054     * Default threshold below which diagonal elements are considered null
055     * and matrix not positive definite.
056     */
057    public static final double DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD = 1.0e-10;
058    /** Row-oriented storage for L<sup>T</sup> matrix data. */
059    private double[][] lTData;
060    /** Cached value of L. */
061    private RealMatrix cachedL;
062    /** Cached value of LT. */
063    private RealMatrix cachedLT;
064
065    /**
066     * Calculates the Cholesky decomposition of the given matrix.
067     * <p>
068     * Calling this constructor is equivalent to call {@link
069     * #CholeskyDecomposition(RealMatrix, double, double)} with the
070     * thresholds set to the default values {@link
071     * #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD} and {@link
072     * #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD}
073     * </p>
074     * @param matrix the matrix to decompose
075     * @throws NonSquareMatrixException if the matrix is not square.
076     * @throws NonSymmetricMatrixException if the matrix is not symmetric.
077     * @throws NonPositiveDefiniteMatrixException if the matrix is not
078     * strictly positive definite.
079     * @see #CholeskyDecomposition(RealMatrix, double, double)
080     * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
081     * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
082     */
083    public CholeskyDecomposition(final RealMatrix matrix) {
084        this(matrix, DEFAULT_RELATIVE_SYMMETRY_THRESHOLD,
085             DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD);
086    }
087
088    /**
089     * Calculates the Cholesky decomposition of the given matrix.
090     * @param matrix the matrix to decompose
091     * @param relativeSymmetryThreshold threshold above which off-diagonal
092     * elements are considered too different and matrix not symmetric
093     * @param absolutePositivityThreshold threshold below which diagonal
094     * elements are considered null and matrix not positive definite
095     * @throws NonSquareMatrixException if the matrix is not square.
096     * @throws NonSymmetricMatrixException if the matrix is not symmetric.
097     * @throws NonPositiveDefiniteMatrixException if the matrix is not
098     * strictly positive definite.
099     * @see #CholeskyDecomposition(RealMatrix)
100     * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
101     * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
102     */
103    public CholeskyDecomposition(final RealMatrix matrix,
104                                     final double relativeSymmetryThreshold,
105                                     final double absolutePositivityThreshold) {
106        if (!matrix.isSquare()) {
107            throw new NonSquareMatrixException(matrix.getRowDimension(),
108                                               matrix.getColumnDimension());
109        }
110
111        final int order = matrix.getRowDimension();
112        lTData   = matrix.getData();
113        cachedL  = null;
114        cachedLT = null;
115
116        // check the matrix before transformation
117        for (int i = 0; i < order; ++i) {
118            final double[] lI = lTData[i];
119
120            // check off-diagonal elements (and reset them to 0)
121            for (int j = i + 1; j < order; ++j) {
122                final double[] lJ = lTData[j];
123                final double lIJ = lI[j];
124                final double lJI = lJ[i];
125                final double maxDelta =
126                    relativeSymmetryThreshold * FastMath.max(FastMath.abs(lIJ), FastMath.abs(lJI));
127                if (FastMath.abs(lIJ - lJI) > maxDelta) {
128                    throw new NonSymmetricMatrixException(i, j, relativeSymmetryThreshold);
129                }
130                lJ[i] = 0;
131           }
132        }
133
134        // transform the matrix
135        for (int i = 0; i < order; ++i) {
136
137            final double[] ltI = lTData[i];
138
139            // check diagonal element
140            if (ltI[i] <= absolutePositivityThreshold) {
141                throw new NonPositiveDefiniteMatrixException(ltI[i], i, absolutePositivityThreshold);
142            }
143
144            ltI[i] = FastMath.sqrt(ltI[i]);
145            final double inverse = 1.0 / ltI[i];
146
147            for (int q = order - 1; q > i; --q) {
148                ltI[q] *= inverse;
149                final double[] ltQ = lTData[q];
150                for (int p = q; p < order; ++p) {
151                    ltQ[p] -= ltI[q] * ltI[p];
152                }
153            }
154        }
155    }
156
157    /**
158     * Returns the matrix L of the decomposition.
159     * <p>L is an lower-triangular matrix</p>
160     * @return the L matrix
161     */
162    public RealMatrix getL() {
163        if (cachedL == null) {
164            cachedL = getLT().transpose();
165        }
166        return cachedL;
167    }
168
169    /**
170     * Returns the transpose of the matrix L of the decomposition.
171     * <p>L<sup>T</sup> is an upper-triangular matrix</p>
172     * @return the transpose of the matrix L of the decomposition
173     */
174    public RealMatrix getLT() {
175
176        if (cachedLT == null) {
177            cachedLT = MatrixUtils.createRealMatrix(lTData);
178        }
179
180        // return the cached matrix
181        return cachedLT;
182    }
183
184    /**
185     * Return the determinant of the matrix
186     * @return determinant of the matrix
187     */
188    public double getDeterminant() {
189        double determinant = 1.0;
190        for (int i = 0; i < lTData.length; ++i) {
191            double lTii = lTData[i][i];
192            determinant *= lTii * lTii;
193        }
194        return determinant;
195    }
196
197    /**
198     * Get a solver for finding the A &times; X = B solution in least square sense.
199     * @return a solver
200     */
201    public DecompositionSolver getSolver() {
202        return new Solver(lTData);
203    }
204
205    /** Specialized solver. */
206    private static class Solver implements DecompositionSolver {
207        /** Row-oriented storage for L<sup>T</sup> matrix data. */
208        private final double[][] lTData;
209
210        /**
211         * Build a solver from decomposed matrix.
212         * @param lTData row-oriented storage for L<sup>T</sup> matrix data
213         */
214        private Solver(final double[][] lTData) {
215            this.lTData = lTData;
216        }
217
218        /** {@inheritDoc} */
219        public boolean isNonSingular() {
220            // if we get this far, the matrix was positive definite, hence non-singular
221            return true;
222        }
223
224        /** {@inheritDoc} */
225        public RealVector solve(final RealVector b) {
226            final int m = lTData.length;
227            if (b.getDimension() != m) {
228                throw new DimensionMismatchException(b.getDimension(), m);
229            }
230
231            final double[] x = b.toArray();
232
233            // Solve LY = b
234            for (int j = 0; j < m; j++) {
235                final double[] lJ = lTData[j];
236                x[j] /= lJ[j];
237                final double xJ = x[j];
238                for (int i = j + 1; i < m; i++) {
239                    x[i] -= xJ * lJ[i];
240                }
241            }
242
243            // Solve LTX = Y
244            for (int j = m - 1; j >= 0; j--) {
245                x[j] /= lTData[j][j];
246                final double xJ = x[j];
247                for (int i = 0; i < j; i++) {
248                    x[i] -= xJ * lTData[i][j];
249                }
250            }
251
252            return new ArrayRealVector(x, false);
253        }
254
255        /** {@inheritDoc} */
256        public RealMatrix solve(RealMatrix b) {
257            final int m = lTData.length;
258            if (b.getRowDimension() != m) {
259                throw new DimensionMismatchException(b.getRowDimension(), m);
260            }
261
262            final int nColB = b.getColumnDimension();
263            final double[][] x = b.getData();
264
265            // Solve LY = b
266            for (int j = 0; j < m; j++) {
267                final double[] lJ = lTData[j];
268                final double lJJ = lJ[j];
269                final double[] xJ = x[j];
270                for (int k = 0; k < nColB; ++k) {
271                    xJ[k] /= lJJ;
272                }
273                for (int i = j + 1; i < m; i++) {
274                    final double[] xI = x[i];
275                    final double lJI = lJ[i];
276                    for (int k = 0; k < nColB; ++k) {
277                        xI[k] -= xJ[k] * lJI;
278                    }
279                }
280            }
281
282            // Solve LTX = Y
283            for (int j = m - 1; j >= 0; j--) {
284                final double lJJ = lTData[j][j];
285                final double[] xJ = x[j];
286                for (int k = 0; k < nColB; ++k) {
287                    xJ[k] /= lJJ;
288                }
289                for (int i = 0; i < j; i++) {
290                    final double[] xI = x[i];
291                    final double lIJ = lTData[i][j];
292                    for (int k = 0; k < nColB; ++k) {
293                        xI[k] -= xJ[k] * lIJ;
294                    }
295                }
296            }
297
298            return new Array2DRowRealMatrix(x);
299        }
300
301        /**
302         * Get the inverse of the decomposed matrix.
303         *
304         * @return the inverse matrix.
305         */
306        public RealMatrix getInverse() {
307            return solve(MatrixUtils.createRealIdentityMatrix(lTData.length));
308        }
309    }
310}