001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.legacy.linear;
019
020import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
021import org.apache.commons.math4.core.jdkmath.JdkMath;
022
023
024/**
025 * Calculates the Cholesky decomposition of a matrix.
026 * <p>The Cholesky decomposition of a real symmetric positive-definite
027 * matrix A consists of a lower triangular matrix L with same size such
028 * that: A = LL<sup>T</sup>. In a sense, this is the square root of A.</p>
029 * <p>This class is based on the class with similar name from the
030 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
031 * following changes:</p>
032 * <ul>
033 *   <li>a {@link #getLT() getLT} method has been added,</li>
034 *   <li>the {@code isspd} method has been removed, since the constructor of
035 *   this class throws a {@link NonPositiveDefiniteMatrixException} when a
036 *   matrix cannot be decomposed,</li>
037 *   <li>a {@link #getDeterminant() getDeterminant} method has been added,</li>
038 *   <li>the {@code solve} method has been replaced by a {@link #getSolver()
039 *   getSolver} method and the equivalent method provided by the returned
040 *   {@link DecompositionSolver}.</li>
041 * </ul>
042 *
043 * @see <a href="http://mathworld.wolfram.com/CholeskyDecomposition.html">MathWorld</a>
044 * @see <a href="http://en.wikipedia.org/wiki/Cholesky_decomposition">Wikipedia</a>
045 * @since 2.0 (changed to concrete class in 3.0)
046 */
047public class CholeskyDecomposition {
048    /**
049     * Default threshold above which off-diagonal elements are considered too different
050     * and matrix not symmetric.
051     */
052    public static final double DEFAULT_RELATIVE_SYMMETRY_THRESHOLD = 1.0e-15;
053    /**
054     * Default threshold below which diagonal elements are considered null
055     * and matrix not positive definite.
056     */
057    public static final double DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD = 1.0e-10;
058    /** Row-oriented storage for L<sup>T</sup> matrix data. */
059    private final double[][] lTData;
060    /** Cached value of L. */
061    private RealMatrix cachedL;
062    /** Cached value of LT. */
063    private RealMatrix cachedLT;
064
065    /**
066     * Calculates the Cholesky decomposition of the given matrix.
067     * <p>
068     * Calling this constructor is equivalent to call {@link
069     * #CholeskyDecomposition(RealMatrix, double, double)} with the
070     * thresholds set to the default values {@link
071     * #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD} and {@link
072     * #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD}
073     * </p>
074     * @param matrix the matrix to decompose
075     * @throws NonSquareMatrixException if the matrix is not square.
076     * @throws NonSymmetricMatrixException if the matrix is not symmetric.
077     * @throws NonPositiveDefiniteMatrixException if the matrix is not
078     * strictly positive definite.
079     * @see #CholeskyDecomposition(RealMatrix, double, double)
080     * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
081     * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
082     */
083    public CholeskyDecomposition(final RealMatrix matrix) {
084        this(matrix, DEFAULT_RELATIVE_SYMMETRY_THRESHOLD,
085             DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD);
086    }
087
088    /**
089     * Calculates the Cholesky decomposition of the given matrix.
090     * @param matrix the matrix to decompose
091     * @param relativeSymmetryThreshold threshold above which off-diagonal
092     * elements are considered too different and matrix not symmetric
093     * @param absolutePositivityThreshold threshold below which diagonal
094     * elements are considered null and matrix not positive definite
095     * @throws NonSquareMatrixException if the matrix is not square.
096     * @throws NonSymmetricMatrixException if the matrix is not symmetric.
097     * @throws NonPositiveDefiniteMatrixException if the matrix is not
098     * strictly positive definite.
099     * @see #CholeskyDecomposition(RealMatrix)
100     * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
101     * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
102     */
103    public CholeskyDecomposition(final RealMatrix matrix,
104                                     final double relativeSymmetryThreshold,
105                                     final double absolutePositivityThreshold) {
106        if (!matrix.isSquare()) {
107            throw new NonSquareMatrixException(matrix.getRowDimension(),
108                                               matrix.getColumnDimension());
109        }
110
111        final int order = matrix.getRowDimension();
112        lTData   = matrix.getData();
113        cachedL  = null;
114        cachedLT = null;
115
116        // check the matrix before transformation
117        for (int i = 0; i < order; ++i) {
118            final double[] lI = lTData[i];
119
120            // check off-diagonal elements (and reset them to 0)
121            for (int j = i + 1; j < order; ++j) {
122                final double[] lJ = lTData[j];
123                final double lIJ = lI[j];
124                final double lJI = lJ[i];
125                final double maxDelta =
126                    relativeSymmetryThreshold * JdkMath.max(JdkMath.abs(lIJ), JdkMath.abs(lJI));
127                if (JdkMath.abs(lIJ - lJI) > maxDelta) {
128                    throw new NonSymmetricMatrixException(i, j, relativeSymmetryThreshold);
129                }
130                lJ[i] = 0;
131           }
132        }
133
134        // transform the matrix
135        for (int i = 0; i < order; ++i) {
136
137            final double[] ltI = lTData[i];
138
139            // check diagonal element
140            if (ltI[i] <= absolutePositivityThreshold) {
141                throw new NonPositiveDefiniteMatrixException(ltI[i], i, absolutePositivityThreshold);
142            }
143
144            ltI[i] = JdkMath.sqrt(ltI[i]);
145            final double inverse = 1.0 / ltI[i];
146
147            for (int q = order - 1; q > i; --q) {
148                ltI[q] *= inverse;
149                final double[] ltQ = lTData[q];
150                for (int p = q; p < order; ++p) {
151                    ltQ[p] -= ltI[q] * ltI[p];
152                }
153            }
154        }
155    }
156
157    /**
158     * Returns the matrix L of the decomposition.
159     * <p>L is an lower-triangular matrix</p>
160     * @return the L matrix
161     */
162    public RealMatrix getL() {
163        if (cachedL == null) {
164            cachedL = getLT().transpose();
165        }
166        return cachedL;
167    }
168
169    /**
170     * Returns the transpose of the matrix L of the decomposition.
171     * <p>L<sup>T</sup> is an upper-triangular matrix</p>
172     * @return the transpose of the matrix L of the decomposition
173     */
174    public RealMatrix getLT() {
175
176        if (cachedLT == null) {
177            cachedLT = MatrixUtils.createRealMatrix(lTData);
178        }
179
180        // return the cached matrix
181        return cachedLT;
182    }
183
184    /**
185     * Return the determinant of the matrix.
186     * @return determinant of the matrix
187     */
188    public double getDeterminant() {
189        double determinant = 1.0;
190        for (int i = 0; i < lTData.length; ++i) {
191            double lTii = lTData[i][i];
192            determinant *= lTii * lTii;
193        }
194        return determinant;
195    }
196
197    /**
198     * Get a solver for finding the A &times; X = B solution in least square sense.
199     * @return a solver
200     */
201    public DecompositionSolver getSolver() {
202        return new Solver(lTData);
203    }
204
205    /** Specialized solver. */
206    private static final class Solver implements DecompositionSolver {
207        /** Row-oriented storage for L<sup>T</sup> matrix data. */
208        private final double[][] lTData;
209
210        /**
211         * Build a solver from decomposed matrix.
212         * @param lTData row-oriented storage for L<sup>T</sup> matrix data
213         */
214        private Solver(final double[][] lTData) {
215            this.lTData = lTData;
216        }
217
218        /** {@inheritDoc} */
219        @Override
220        public boolean isNonSingular() {
221            // if we get this far, the matrix was positive definite, hence non-singular
222            return true;
223        }
224
225        /** {@inheritDoc} */
226        @Override
227        public RealVector solve(final RealVector b) {
228            final int m = lTData.length;
229            if (b.getDimension() != m) {
230                throw new DimensionMismatchException(b.getDimension(), m);
231            }
232
233            final double[] x = b.toArray();
234
235            // Solve LY = b
236            for (int j = 0; j < m; j++) {
237                final double[] lJ = lTData[j];
238                x[j] /= lJ[j];
239                final double xJ = x[j];
240                for (int i = j + 1; i < m; i++) {
241                    x[i] -= xJ * lJ[i];
242                }
243            }
244
245            // Solve LTX = Y
246            for (int j = m - 1; j >= 0; j--) {
247                x[j] /= lTData[j][j];
248                final double xJ = x[j];
249                for (int i = 0; i < j; i++) {
250                    x[i] -= xJ * lTData[i][j];
251                }
252            }
253
254            return new ArrayRealVector(x, false);
255        }
256
257        /** {@inheritDoc} */
258        @Override
259        public RealMatrix solve(RealMatrix b) {
260            final int m = lTData.length;
261            if (b.getRowDimension() != m) {
262                throw new DimensionMismatchException(b.getRowDimension(), m);
263            }
264
265            final int nColB = b.getColumnDimension();
266            final double[][] x = b.getData();
267
268            // Solve LY = b
269            for (int j = 0; j < m; j++) {
270                final double[] lJ = lTData[j];
271                final double lJJ = lJ[j];
272                final double[] xJ = x[j];
273                for (int k = 0; k < nColB; ++k) {
274                    xJ[k] /= lJJ;
275                }
276                for (int i = j + 1; i < m; i++) {
277                    final double[] xI = x[i];
278                    final double lJI = lJ[i];
279                    for (int k = 0; k < nColB; ++k) {
280                        xI[k] -= xJ[k] * lJI;
281                    }
282                }
283            }
284
285            // Solve LTX = Y
286            for (int j = m - 1; j >= 0; j--) {
287                final double lJJ = lTData[j][j];
288                final double[] xJ = x[j];
289                for (int k = 0; k < nColB; ++k) {
290                    xJ[k] /= lJJ;
291                }
292                for (int i = 0; i < j; i++) {
293                    final double[] xI = x[i];
294                    final double lIJ = lTData[i][j];
295                    for (int k = 0; k < nColB; ++k) {
296                        xI[k] -= xJ[k] * lIJ;
297                    }
298                }
299            }
300
301            return new Array2DRowRealMatrix(x);
302        }
303
304        /**
305         * Get the inverse of the decomposed matrix.
306         *
307         * @return the inverse matrix.
308         */
309        @Override
310        public RealMatrix getInverse() {
311            return solve(MatrixUtils.createRealIdentityMatrix(lTData.length));
312        }
313    }
314}