/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
017    
018    package org.apache.commons.math.linear;
019    
020    import org.apache.commons.math.exception.DimensionMismatchException;
021    import org.apache.commons.math.util.FastMath;
022    
023    /**
024     * Calculates the LUP-decomposition of a square matrix.
025     * <p>The LUP-decomposition of a matrix A consists of three matrices
026     * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
027     * upper triangular and P is a permutation matrix. All matrices are
028     * m&times;m.</p>
029     * <p>As shown by the presence of the P matrix, this decomposition is
030     * implemented using partial pivoting.</p>
031     *
032     * @version $Id: LUDecompositionImpl.java 1131229 2011-06-03 20:49:25Z luc $
033     * @since 2.0
034     */
035    public class LUDecompositionImpl implements LUDecomposition {
036        /** Default bound to determine effective singularity in LU decomposition */
037        private static final double DEFAULT_TOO_SMALL = 10E-12;
038        /** Entries of LU decomposition. */
039        private double lu[][];
040        /** Pivot permutation associated with LU decomposition */
041        private int[] pivot;
042        /** Parity of the permutation associated with the LU decomposition */
043        private boolean even;
044        /** Singularity indicator. */
045        private boolean singular;
046        /** Cached value of L. */
047        private RealMatrix cachedL;
048        /** Cached value of U. */
049        private RealMatrix cachedU;
050        /** Cached value of P. */
051        private RealMatrix cachedP;
052    
053        /**
054         * Calculates the LU-decomposition of the given matrix.
055         * @param matrix Matrix to decompose.
056         * @throws NonSquareMatrixException if matrix is not square.
057         */
058        public LUDecompositionImpl(RealMatrix matrix) {
059            this(matrix, DEFAULT_TOO_SMALL);
060        }
061    
062        /**
063         * Calculates the LU-decomposition of the given matrix.
064         * @param matrix The matrix to decompose.
065         * @param singularityThreshold threshold (based on partial row norm)
066         * under which a matrix is considered singular
067         * @throws NonSquareMatrixException if matrix is not square
068         */
069        public LUDecompositionImpl(RealMatrix matrix, double singularityThreshold) {
070            if (!matrix.isSquare()) {
071                throw new NonSquareMatrixException(matrix.getRowDimension(),
072                                                   matrix.getColumnDimension());
073            }
074    
075            final int m = matrix.getColumnDimension();
076            lu = matrix.getData();
077            pivot = new int[m];
078            cachedL = null;
079            cachedU = null;
080            cachedP = null;
081    
082            // Initialize permutation array and parity
083            for (int row = 0; row < m; row++) {
084                pivot[row] = row;
085            }
086            even     = true;
087            singular = false;
088    
089            // Loop over columns
090            for (int col = 0; col < m; col++) {
091    
092                double sum = 0;
093    
094                // upper
095                for (int row = 0; row < col; row++) {
096                    final double[] luRow = lu[row];
097                    sum = luRow[col];
098                    for (int i = 0; i < row; i++) {
099                        sum -= luRow[i] * lu[i][col];
100                    }
101                    luRow[col] = sum;
102                }
103    
104                // lower
105                int max = col; // permutation row
106                double largest = Double.NEGATIVE_INFINITY;
107                for (int row = col; row < m; row++) {
108                    final double[] luRow = lu[row];
109                    sum = luRow[col];
110                    for (int i = 0; i < col; i++) {
111                        sum -= luRow[i] * lu[i][col];
112                    }
113                    luRow[col] = sum;
114    
115                    // maintain best permutation choice
116                    if (FastMath.abs(sum) > largest) {
117                        largest = FastMath.abs(sum);
118                        max = row;
119                    }
120                }
121    
122                // Singularity check
123                if (FastMath.abs(lu[max][col]) < singularityThreshold) {
124                    singular = true;
125                    return;
126                }
127    
128                // Pivot if necessary
129                if (max != col) {
130                    double tmp = 0;
131                    final double[] luMax = lu[max];
132                    final double[] luCol = lu[col];
133                    for (int i = 0; i < m; i++) {
134                        tmp = luMax[i];
135                        luMax[i] = luCol[i];
136                        luCol[i] = tmp;
137                    }
138                    int temp = pivot[max];
139                    pivot[max] = pivot[col];
140                    pivot[col] = temp;
141                    even = !even;
142                }
143    
144                // Divide the lower elements by the "winning" diagonal elt.
145                final double luDiag = lu[col][col];
146                for (int row = col + 1; row < m; row++) {
147                    lu[row][col] /= luDiag;
148                }
149            }
150        }
151    
152        /** {@inheritDoc} */
153        public RealMatrix getL() {
154            if ((cachedL == null) && !singular) {
155                final int m = pivot.length;
156                cachedL = MatrixUtils.createRealMatrix(m, m);
157                for (int i = 0; i < m; ++i) {
158                    final double[] luI = lu[i];
159                    for (int j = 0; j < i; ++j) {
160                        cachedL.setEntry(i, j, luI[j]);
161                    }
162                    cachedL.setEntry(i, i, 1.0);
163                }
164            }
165            return cachedL;
166        }
167    
168        /** {@inheritDoc} */
169        public RealMatrix getU() {
170            if ((cachedU == null) && !singular) {
171                final int m = pivot.length;
172                cachedU = MatrixUtils.createRealMatrix(m, m);
173                for (int i = 0; i < m; ++i) {
174                    final double[] luI = lu[i];
175                    for (int j = i; j < m; ++j) {
176                        cachedU.setEntry(i, j, luI[j]);
177                    }
178                }
179            }
180            return cachedU;
181        }
182    
183        /** {@inheritDoc} */
184        public RealMatrix getP() {
185            if ((cachedP == null) && !singular) {
186                final int m = pivot.length;
187                cachedP = MatrixUtils.createRealMatrix(m, m);
188                for (int i = 0; i < m; ++i) {
189                    cachedP.setEntry(i, pivot[i], 1.0);
190                }
191            }
192            return cachedP;
193        }
194    
195        /** {@inheritDoc} */
196        public int[] getPivot() {
197            return pivot.clone();
198        }
199    
200        /** {@inheritDoc} */
201        public double getDeterminant() {
202            if (singular) {
203                return 0;
204            } else {
205                final int m = pivot.length;
206                double determinant = even ? 1 : -1;
207                for (int i = 0; i < m; i++) {
208                    determinant *= lu[i][i];
209                }
210                return determinant;
211            }
212        }
213    
214        /** {@inheritDoc} */
215        public DecompositionSolver getSolver() {
216            return new Solver(lu, pivot, singular);
217        }
218    
219        /** Specialized solver. */
220        private static class Solver implements DecompositionSolver {
221    
222            /** Entries of LU decomposition. */
223            private final double lu[][];
224    
225            /** Pivot permutation associated with LU decomposition. */
226            private final int[] pivot;
227    
228            /** Singularity indicator. */
229            private final boolean singular;
230    
231            /**
232             * Build a solver from decomposed matrix.
233             * @param lu entries of LU decomposition
234             * @param pivot pivot permutation associated with LU decomposition
235             * @param singular singularity indicator
236             */
237            private Solver(final double[][] lu, final int[] pivot, final boolean singular) {
238                this.lu       = lu;
239                this.pivot    = pivot;
240                this.singular = singular;
241            }
242    
243            /** {@inheritDoc} */
244            public boolean isNonSingular() {
245                return !singular;
246            }
247    
248            /** {@inheritDoc} */
249            public double[] solve(double[] b) {
250                final int m = pivot.length;
251                if (b.length != m) {
252                    throw new DimensionMismatchException(b.length, m);
253                }
254                if (singular) {
255                    throw new SingularMatrixException();
256                }
257    
258                final double[] bp = new double[m];
259    
260                // Apply permutations to b
261                for (int row = 0; row < m; row++) {
262                    bp[row] = b[pivot[row]];
263                }
264    
265                // Solve LY = b
266                for (int col = 0; col < m; col++) {
267                    final double bpCol = bp[col];
268                    for (int i = col + 1; i < m; i++) {
269                        bp[i] -= bpCol * lu[i][col];
270                    }
271                }
272    
273                // Solve UX = Y
274                for (int col = m - 1; col >= 0; col--) {
275                    bp[col] /= lu[col][col];
276                    final double bpCol = bp[col];
277                    for (int i = 0; i < col; i++) {
278                        bp[i] -= bpCol * lu[i][col];
279                    }
280                }
281    
282                return bp;
283            }
284    
285            /** {@inheritDoc} */
286            public RealVector solve(RealVector b) {
287                try {
288                    return solve((ArrayRealVector) b);
289                } catch (ClassCastException cce) {
290    
291                    final int m = pivot.length;
292                    if (b.getDimension() != m) {
293                        throw new DimensionMismatchException(b.getDimension(), m);
294                    }
295                    if (singular) {
296                        throw new SingularMatrixException();
297                    }
298    
299                    final double[] bp = new double[m];
300    
301                    // Apply permutations to b
302                    for (int row = 0; row < m; row++) {
303                        bp[row] = b.getEntry(pivot[row]);
304                    }
305    
306                    // Solve LY = b
307                    for (int col = 0; col < m; col++) {
308                        final double bpCol = bp[col];
309                        for (int i = col + 1; i < m; i++) {
310                            bp[i] -= bpCol * lu[i][col];
311                        }
312                    }
313    
314                    // Solve UX = Y
315                    for (int col = m - 1; col >= 0; col--) {
316                        bp[col] /= lu[col][col];
317                        final double bpCol = bp[col];
318                        for (int i = 0; i < col; i++) {
319                            bp[i] -= bpCol * lu[i][col];
320                        }
321                    }
322    
323                    return new ArrayRealVector(bp, false);
324                }
325            }
326    
327            /** Solve the linear equation A &times; X = B.
328             * <p>The A matrix is implicit here. It is </p>
329             * @param b right-hand side of the equation A &times; X = B
330             * @return a vector X such that A &times; X = B
331             * @throws DimensionMismatchException if the matrices dimensions
332             * do not match.
333             * @throws SingularMatrixException if decomposed matrix is singular.
334             */
335            public ArrayRealVector solve(ArrayRealVector b) {
336                return new ArrayRealVector(solve(b.getDataRef()), false);
337            }
338    
339            /** {@inheritDoc} */
340            public double[][] solve(double[][] b) {
341    
342                final int m = pivot.length;
343                if (b.length != m) {
344                    throw new DimensionMismatchException(b.length, m);
345                }
346                if (singular) {
347                    throw new SingularMatrixException();
348                }
349    
350                final int nColB = b[0].length;
351    
352                // Apply permutations to b
353                final double[][] bp = new double[m][nColB];
354                for (int row = 0; row < m; row++) {
355                    final double[] bpRow = bp[row];
356                    final int pRow = pivot[row];
357                    for (int col = 0; col < nColB; col++) {
358                        bpRow[col] = b[pRow][col];
359                    }
360                }
361    
362                // Solve LY = b
363                for (int col = 0; col < m; col++) {
364                    final double[] bpCol = bp[col];
365                    for (int i = col + 1; i < m; i++) {
366                        final double[] bpI = bp[i];
367                        final double luICol = lu[i][col];
368                        for (int j = 0; j < nColB; j++) {
369                            bpI[j] -= bpCol[j] * luICol;
370                        }
371                    }
372                }
373    
374                // Solve UX = Y
375                for (int col = m - 1; col >= 0; col--) {
376                    final double[] bpCol = bp[col];
377                    final double luDiag = lu[col][col];
378                    for (int j = 0; j < nColB; j++) {
379                        bpCol[j] /= luDiag;
380                    }
381                    for (int i = 0; i < col; i++) {
382                        final double[] bpI = bp[i];
383                        final double luICol = lu[i][col];
384                        for (int j = 0; j < nColB; j++) {
385                            bpI[j] -= bpCol[j] * luICol;
386                        }
387                    }
388                }
389    
390                return bp;
391    
392            }
393    
394            /** {@inheritDoc} */
395            public RealMatrix solve(RealMatrix b) {
396                return new Array2DRowRealMatrix(solve(b.getData()), false);
397            }
398    
399            /** {@inheritDoc} */
400            public RealMatrix getInverse() {
401                return solve(MatrixUtils.createRealIdentityMatrix(pivot.length));
402            }
403        }
404    }