001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.linear;
019
020import org.apache.commons.math3.exception.DimensionMismatchException;
021import org.apache.commons.math3.util.FastMath;
022
023
024/**
025 * Calculates the Cholesky decomposition of a matrix.
026 * <p>The Cholesky decomposition of a real symmetric positive-definite
027 * matrix A consists of a lower triangular matrix L with same size such
028 * that: A = LL<sup>T</sup>. In a sense, this is the square root of A.</p>
029 * <p>This class is based on the class with similar name from the
030 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
031 * following changes:</p>
032 * <ul>
033 *   <li>a {@link #getLT() getLT} method has been added,</li>
034 *   <li>the {@code isspd} method has been removed, since the constructor of
035 *   this class throws a {@link NonPositiveDefiniteMatrixException} when a
036 *   matrix cannot be decomposed,</li>
037 *   <li>a {@link #getDeterminant() getDeterminant} method has been added,</li>
038 *   <li>the {@code solve} method has been replaced by a {@link #getSolver()
039 *   getSolver} method and the equivalent method provided by the returned
040 *   {@link DecompositionSolver}.</li>
041 * </ul>
042 *
043 * @see <a href="http://mathworld.wolfram.com/CholeskyDecomposition.html">MathWorld</a>
044 * @see <a href="http://en.wikipedia.org/wiki/Cholesky_decomposition">Wikipedia</a>
045 * @version $Id: CholeskyDecomposition.java 1566017 2014-02-08 14:13:34Z tn $
046 * @since 2.0 (changed to concrete class in 3.0)
047 */
048public class CholeskyDecomposition {
049    /**
050     * Default threshold above which off-diagonal elements are considered too different
051     * and matrix not symmetric.
052     */
053    public static final double DEFAULT_RELATIVE_SYMMETRY_THRESHOLD = 1.0e-15;
054    /**
055     * Default threshold below which diagonal elements are considered null
056     * and matrix not positive definite.
057     */
058    public static final double DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD = 1.0e-10;
059    /** Row-oriented storage for L<sup>T</sup> matrix data. */
060    private double[][] lTData;
061    /** Cached value of L. */
062    private RealMatrix cachedL;
063    /** Cached value of LT. */
064    private RealMatrix cachedLT;
065
066    /**
067     * Calculates the Cholesky decomposition of the given matrix.
068     * <p>
069     * Calling this constructor is equivalent to call {@link
070     * #CholeskyDecomposition(RealMatrix, double, double)} with the
071     * thresholds set to the default values {@link
072     * #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD} and {@link
073     * #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD}
074     * </p>
075     * @param matrix the matrix to decompose
076     * @throws NonSquareMatrixException if the matrix is not square.
077     * @throws NonSymmetricMatrixException if the matrix is not symmetric.
078     * @throws NonPositiveDefiniteMatrixException if the matrix is not
079     * strictly positive definite.
080     * @see #CholeskyDecomposition(RealMatrix, double, double)
081     * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
082     * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
083     */
084    public CholeskyDecomposition(final RealMatrix matrix) {
085        this(matrix, DEFAULT_RELATIVE_SYMMETRY_THRESHOLD,
086             DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD);
087    }
088
089    /**
090     * Calculates the Cholesky decomposition of the given matrix.
091     * @param matrix the matrix to decompose
092     * @param relativeSymmetryThreshold threshold above which off-diagonal
093     * elements are considered too different and matrix not symmetric
094     * @param absolutePositivityThreshold threshold below which diagonal
095     * elements are considered null and matrix not positive definite
096     * @throws NonSquareMatrixException if the matrix is not square.
097     * @throws NonSymmetricMatrixException if the matrix is not symmetric.
098     * @throws NonPositiveDefiniteMatrixException if the matrix is not
099     * strictly positive definite.
100     * @see #CholeskyDecomposition(RealMatrix)
101     * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
102     * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
103     */
104    public CholeskyDecomposition(final RealMatrix matrix,
105                                     final double relativeSymmetryThreshold,
106                                     final double absolutePositivityThreshold) {
107        if (!matrix.isSquare()) {
108            throw new NonSquareMatrixException(matrix.getRowDimension(),
109                                               matrix.getColumnDimension());
110        }
111
112        final int order = matrix.getRowDimension();
113        lTData   = matrix.getData();
114        cachedL  = null;
115        cachedLT = null;
116
117        // check the matrix before transformation
118        for (int i = 0; i < order; ++i) {
119            final double[] lI = lTData[i];
120
121            // check off-diagonal elements (and reset them to 0)
122            for (int j = i + 1; j < order; ++j) {
123                final double[] lJ = lTData[j];
124                final double lIJ = lI[j];
125                final double lJI = lJ[i];
126                final double maxDelta =
127                    relativeSymmetryThreshold * FastMath.max(FastMath.abs(lIJ), FastMath.abs(lJI));
128                if (FastMath.abs(lIJ - lJI) > maxDelta) {
129                    throw new NonSymmetricMatrixException(i, j, relativeSymmetryThreshold);
130                }
131                lJ[i] = 0;
132           }
133        }
134
135        // transform the matrix
136        for (int i = 0; i < order; ++i) {
137
138            final double[] ltI = lTData[i];
139
140            // check diagonal element
141            if (ltI[i] <= absolutePositivityThreshold) {
142                throw new NonPositiveDefiniteMatrixException(ltI[i], i, absolutePositivityThreshold);
143            }
144
145            ltI[i] = FastMath.sqrt(ltI[i]);
146            final double inverse = 1.0 / ltI[i];
147
148            for (int q = order - 1; q > i; --q) {
149                ltI[q] *= inverse;
150                final double[] ltQ = lTData[q];
151                for (int p = q; p < order; ++p) {
152                    ltQ[p] -= ltI[q] * ltI[p];
153                }
154            }
155        }
156    }
157
158    /**
159     * Returns the matrix L of the decomposition.
160     * <p>L is an lower-triangular matrix</p>
161     * @return the L matrix
162     */
163    public RealMatrix getL() {
164        if (cachedL == null) {
165            cachedL = getLT().transpose();
166        }
167        return cachedL;
168    }
169
170    /**
171     * Returns the transpose of the matrix L of the decomposition.
172     * <p>L<sup>T</sup> is an upper-triangular matrix</p>
173     * @return the transpose of the matrix L of the decomposition
174     */
175    public RealMatrix getLT() {
176
177        if (cachedLT == null) {
178            cachedLT = MatrixUtils.createRealMatrix(lTData);
179        }
180
181        // return the cached matrix
182        return cachedLT;
183    }
184
185    /**
186     * Return the determinant of the matrix
187     * @return determinant of the matrix
188     */
189    public double getDeterminant() {
190        double determinant = 1.0;
191        for (int i = 0; i < lTData.length; ++i) {
192            double lTii = lTData[i][i];
193            determinant *= lTii * lTii;
194        }
195        return determinant;
196    }
197
198    /**
199     * Get a solver for finding the A &times; X = B solution in least square sense.
200     * @return a solver
201     */
202    public DecompositionSolver getSolver() {
203        return new Solver(lTData);
204    }
205
206    /** Specialized solver. */
207    private static class Solver implements DecompositionSolver {
208        /** Row-oriented storage for L<sup>T</sup> matrix data. */
209        private final double[][] lTData;
210
211        /**
212         * Build a solver from decomposed matrix.
213         * @param lTData row-oriented storage for L<sup>T</sup> matrix data
214         */
215        private Solver(final double[][] lTData) {
216            this.lTData = lTData;
217        }
218
219        /** {@inheritDoc} */
220        public boolean isNonSingular() {
221            // if we get this far, the matrix was positive definite, hence non-singular
222            return true;
223        }
224
225        /** {@inheritDoc} */
226        public RealVector solve(final RealVector b) {
227            final int m = lTData.length;
228            if (b.getDimension() != m) {
229                throw new DimensionMismatchException(b.getDimension(), m);
230            }
231
232            final double[] x = b.toArray();
233
234            // Solve LY = b
235            for (int j = 0; j < m; j++) {
236                final double[] lJ = lTData[j];
237                x[j] /= lJ[j];
238                final double xJ = x[j];
239                for (int i = j + 1; i < m; i++) {
240                    x[i] -= xJ * lJ[i];
241                }
242            }
243
244            // Solve LTX = Y
245            for (int j = m - 1; j >= 0; j--) {
246                x[j] /= lTData[j][j];
247                final double xJ = x[j];
248                for (int i = 0; i < j; i++) {
249                    x[i] -= xJ * lTData[i][j];
250                }
251            }
252
253            return new ArrayRealVector(x, false);
254        }
255
256        /** {@inheritDoc} */
257        public RealMatrix solve(RealMatrix b) {
258            final int m = lTData.length;
259            if (b.getRowDimension() != m) {
260                throw new DimensionMismatchException(b.getRowDimension(), m);
261            }
262
263            final int nColB = b.getColumnDimension();
264            final double[][] x = b.getData();
265
266            // Solve LY = b
267            for (int j = 0; j < m; j++) {
268                final double[] lJ = lTData[j];
269                final double lJJ = lJ[j];
270                final double[] xJ = x[j];
271                for (int k = 0; k < nColB; ++k) {
272                    xJ[k] /= lJJ;
273                }
274                for (int i = j + 1; i < m; i++) {
275                    final double[] xI = x[i];
276                    final double lJI = lJ[i];
277                    for (int k = 0; k < nColB; ++k) {
278                        xI[k] -= xJ[k] * lJI;
279                    }
280                }
281            }
282
283            // Solve LTX = Y
284            for (int j = m - 1; j >= 0; j--) {
285                final double lJJ = lTData[j][j];
286                final double[] xJ = x[j];
287                for (int k = 0; k < nColB; ++k) {
288                    xJ[k] /= lJJ;
289                }
290                for (int i = 0; i < j; i++) {
291                    final double[] xI = x[i];
292                    final double lIJ = lTData[i][j];
293                    for (int k = 0; k < nColB; ++k) {
294                        xI[k] -= xJ[k] * lIJ;
295                    }
296                }
297            }
298
299            return new Array2DRowRealMatrix(x);
300        }
301
302        /**
303         * Get the inverse of the decomposed matrix.
304         *
305         * @return the inverse matrix.
306         */
307        public RealMatrix getInverse() {
308            return solve(MatrixUtils.createRealIdentityMatrix(lTData.length));
309        }
310    }
311}