001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.legacy.linear;
019
020import org.apache.commons.math4.legacy.core.Field;
021import org.apache.commons.math4.legacy.core.FieldElement;
022import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
023import org.apache.commons.math4.legacy.core.MathArrays;
024
025/**
026 * Calculates the LUP-decomposition of a square matrix.
027 * <p>The LUP-decomposition of a matrix A consists of three matrices
028 * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
029 * upper triangular and P is a permutation matrix. All matrices are
030 * m&times;m.</p>
031 * <p>Since {@link FieldElement field elements} do not provide an ordering
032 * operator, the permutation matrix is computed here only in order to avoid
033 * a zero pivot element, no attempt is done to get the largest pivot
034 * element.</p>
035 * <p>This class is based on the class with similar name from the
036 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
037 * <ul>
038 *   <li>a {@link #getP() getP} method has been added,</li>
039 *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
040 *   getDeterminant},</li>
041 *   <li>the {@code getDoublePivot} method has been removed (but the int based
042 *   {@link #getPivot() getPivot} method has been kept),</li>
043 *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
044 *   by a {@link #getSolver() getSolver} method and the equivalent methods
045 *   provided by the returned {@link DecompositionSolver}.</li>
046 * </ul>
047 *
048 * @param <T> the type of the field elements
049 * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
050 * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
051 * @since 2.0 (changed to concrete class in 3.0)
052 */
053public class FieldLUDecomposition<T extends FieldElement<T>> {
054
055    /** Field to which the elements belong. */
056    private final Field<T> field;
057
058    /** Entries of LU decomposition. */
059    private T[][] lu;
060
061    /** Pivot permutation associated with LU decomposition. */
062    private int[] pivot;
063
064    /** Parity of the permutation associated with the LU decomposition. */
065    private boolean even;
066
067    /** Singularity indicator. */
068    private boolean singular;
069
070    /** Cached value of L. */
071    private FieldMatrix<T> cachedL;
072
073    /** Cached value of U. */
074    private FieldMatrix<T> cachedU;
075
076    /** Cached value of P. */
077    private FieldMatrix<T> cachedP;
078
079    /**
080     * Calculates the LU-decomposition of the given matrix.
081     * @param matrix The matrix to decompose.
082     * @throws NonSquareMatrixException if matrix is not square
083     */
084    public FieldLUDecomposition(FieldMatrix<T> matrix) {
085        if (!matrix.isSquare()) {
086            throw new NonSquareMatrixException(matrix.getRowDimension(),
087                                               matrix.getColumnDimension());
088        }
089
090        final int m = matrix.getColumnDimension();
091        field = matrix.getField();
092        lu = matrix.getData();
093        pivot = new int[m];
094        cachedL = null;
095        cachedU = null;
096        cachedP = null;
097
098        // Initialize permutation array and parity
099        for (int row = 0; row < m; row++) {
100            pivot[row] = row;
101        }
102        even     = true;
103        singular = false;
104
105        // Loop over columns
106        for (int col = 0; col < m; col++) {
107
108            T sum = field.getZero();
109
110            // upper
111            for (int row = 0; row < col; row++) {
112                final T[] luRow = lu[row];
113                sum = luRow[col];
114                for (int i = 0; i < row; i++) {
115                    sum = sum.subtract(luRow[i].multiply(lu[i][col]));
116                }
117                luRow[col] = sum;
118            }
119
120            // lower
121            int nonZero = col; // permutation row
122            for (int row = col; row < m; row++) {
123                final T[] luRow = lu[row];
124                sum = luRow[col];
125                for (int i = 0; i < col; i++) {
126                    sum = sum.subtract(luRow[i].multiply(lu[i][col]));
127                }
128                luRow[col] = sum;
129
130                if (lu[nonZero][col].equals(field.getZero())) {
131                    // try to select a better permutation choice
132                    ++nonZero;
133                }
134            }
135
136            // Singularity check
137            if (nonZero >= m) {
138                singular = true;
139                return;
140            }
141
142            // Pivot if necessary
143            if (nonZero != col) {
144                T tmp = field.getZero();
145                for (int i = 0; i < m; i++) {
146                    tmp = lu[nonZero][i];
147                    lu[nonZero][i] = lu[col][i];
148                    lu[col][i] = tmp;
149                }
150                int temp = pivot[nonZero];
151                pivot[nonZero] = pivot[col];
152                pivot[col] = temp;
153                even = !even;
154            }
155
156            // Divide the lower elements by the "winning" diagonal elt.
157            final T luDiag = lu[col][col];
158            for (int row = col + 1; row < m; row++) {
159                final T[] luRow = lu[row];
160                luRow[col] = luRow[col].divide(luDiag);
161            }
162        }
163    }
164
165    /**
166     * Returns the matrix L of the decomposition.
167     * <p>L is a lower-triangular matrix</p>
168     * @return the L matrix (or null if decomposed matrix is singular)
169     */
170    public FieldMatrix<T> getL() {
171        if (cachedL == null && !singular) {
172            final int m = pivot.length;
173            cachedL = new Array2DRowFieldMatrix<>(field, m, m);
174            for (int i = 0; i < m; ++i) {
175                final T[] luI = lu[i];
176                for (int j = 0; j < i; ++j) {
177                    cachedL.setEntry(i, j, luI[j]);
178                }
179                cachedL.setEntry(i, i, field.getOne());
180            }
181        }
182        return cachedL;
183    }
184
185    /**
186     * Returns the matrix U of the decomposition.
187     * <p>U is an upper-triangular matrix</p>
188     * @return the U matrix (or null if decomposed matrix is singular)
189     */
190    public FieldMatrix<T> getU() {
191        if (cachedU == null && !singular) {
192            final int m = pivot.length;
193            cachedU = new Array2DRowFieldMatrix<>(field, m, m);
194            for (int i = 0; i < m; ++i) {
195                final T[] luI = lu[i];
196                for (int j = i; j < m; ++j) {
197                    cachedU.setEntry(i, j, luI[j]);
198                }
199            }
200        }
201        return cachedU;
202    }
203
204    /**
205     * Returns the P rows permutation matrix.
206     * <p>P is a sparse matrix with exactly one element set to 1.0 in
207     * each row and each column, all other elements being set to 0.0.</p>
208     * <p>The positions of the 1 elements are given by the {@link #getPivot()
209     * pivot permutation vector}.</p>
210     * @return the P rows permutation matrix (or null if decomposed matrix is singular)
211     * @see #getPivot()
212     */
213    public FieldMatrix<T> getP() {
214        if (cachedP == null && !singular) {
215            final int m = pivot.length;
216            cachedP = new Array2DRowFieldMatrix<>(field, m, m);
217            for (int i = 0; i < m; ++i) {
218                cachedP.setEntry(i, pivot[i], field.getOne());
219            }
220        }
221        return cachedP;
222    }
223
224    /**
225     * Returns the pivot permutation vector.
226     * @return the pivot permutation vector
227     * @see #getP()
228     */
229    public int[] getPivot() {
230        return pivot.clone();
231    }
232
233    /**
234     * Return the determinant of the matrix.
235     * @return determinant of the matrix
236     */
237    public T getDeterminant() {
238        if (singular) {
239            return field.getZero();
240        } else {
241            final int m = pivot.length;
242            T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
243            for (int i = 0; i < m; i++) {
244                determinant = determinant.multiply(lu[i][i]);
245            }
246            return determinant;
247        }
248    }
249
250    /**
251     * Get a solver for finding the A &times; X = B solution in exact linear sense.
252     * @return a solver
253     */
254    public FieldDecompositionSolver<T> getSolver() {
255        return new Solver<>(field, lu, pivot, singular);
256    }
257
258    /** Specialized solver.
259     * @param <T> the type of the field elements
260     */
261    private static final class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
262
263        /** Field to which the elements belong. */
264        private final Field<T> field;
265
266        /** Entries of LU decomposition. */
267        private final T[][] lu;
268
269        /** Pivot permutation associated with LU decomposition. */
270        private final int[] pivot;
271
272        /** Singularity indicator. */
273        private final boolean singular;
274
275        /**
276         * Build a solver from decomposed matrix.
277         * @param field field to which the matrix elements belong
278         * @param lu entries of LU decomposition
279         * @param pivot pivot permutation associated with LU decomposition
280         * @param singular singularity indicator
281         */
282        private Solver(final Field<T> field, final T[][] lu,
283                       final int[] pivot, final boolean singular) {
284            this.field    = field;
285            this.lu       = lu;
286            this.pivot    = pivot;
287            this.singular = singular;
288        }
289
290        /** {@inheritDoc} */
291        @Override
292        public boolean isNonSingular() {
293            return !singular;
294        }
295
296        /** {@inheritDoc} */
297        @Override
298        public FieldVector<T> solve(FieldVector<T> b) {
299            if (b instanceof ArrayFieldVector) {
300                return solve((ArrayFieldVector<T>) b);
301            }
302
303            final int m = pivot.length;
304            if (b.getDimension() != m) {
305                throw new DimensionMismatchException(b.getDimension(), m);
306            }
307            if (singular) {
308                throw new SingularMatrixException();
309            }
310
311            // Apply permutations to b
312            final T[] bp = MathArrays.buildArray(field, m);
313            for (int row = 0; row < m; row++) {
314                bp[row] = b.getEntry(pivot[row]);
315            }
316
317            // Solve LY = b
318            for (int col = 0; col < m; col++) {
319                final T bpCol = bp[col];
320                for (int i = col + 1; i < m; i++) {
321                    bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
322                }
323            }
324
325            // Solve UX = Y
326            for (int col = m - 1; col >= 0; col--) {
327                bp[col] = bp[col].divide(lu[col][col]);
328                final T bpCol = bp[col];
329                for (int i = 0; i < col; i++) {
330                    bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
331                }
332            }
333
334            return new ArrayFieldVector<>(field, bp, false);
335        }
336
337        /** Solve the linear equation A &times; X = B.
338         * <p>The A matrix is implicit here. It is </p>
339         * @param b right-hand side of the equation A &times; X = B
340         * @return a vector X such that A &times; X = B
341         * @throws DimensionMismatchException if the matrices dimensions do not match.
342         * @throws SingularMatrixException if the decomposed matrix is singular.
343         */
344        public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
345            final int m = pivot.length;
346            final int length = b.getDimension();
347            if (length != m) {
348                throw new DimensionMismatchException(length, m);
349            }
350            if (singular) {
351                throw new SingularMatrixException();
352            }
353
354            // Apply permutations to b
355            final T[] bp = MathArrays.buildArray(field, m);
356            for (int row = 0; row < m; row++) {
357                bp[row] = b.getEntry(pivot[row]);
358            }
359
360            // Solve LY = b
361            for (int col = 0; col < m; col++) {
362                final T bpCol = bp[col];
363                for (int i = col + 1; i < m; i++) {
364                    bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
365                }
366            }
367
368            // Solve UX = Y
369            for (int col = m - 1; col >= 0; col--) {
370                bp[col] = bp[col].divide(lu[col][col]);
371                final T bpCol = bp[col];
372                for (int i = 0; i < col; i++) {
373                    bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
374                }
375            }
376
377            return new ArrayFieldVector<>(bp, false);
378        }
379
380        /** {@inheritDoc} */
381        @Override
382        public FieldMatrix<T> solve(FieldMatrix<T> b) {
383            final int m = pivot.length;
384            if (b.getRowDimension() != m) {
385                throw new DimensionMismatchException(b.getRowDimension(), m);
386            }
387            if (singular) {
388                throw new SingularMatrixException();
389            }
390
391            final int nColB = b.getColumnDimension();
392
393            // Apply permutations to b
394            final T[][] bp = MathArrays.buildArray(field, m, nColB);
395            for (int row = 0; row < m; row++) {
396                final T[] bpRow = bp[row];
397                final int pRow = pivot[row];
398                for (int col = 0; col < nColB; col++) {
399                    bpRow[col] = b.getEntry(pRow, col);
400                }
401            }
402
403            // Solve LY = b
404            for (int col = 0; col < m; col++) {
405                final T[] bpCol = bp[col];
406                for (int i = col + 1; i < m; i++) {
407                    final T[] bpI = bp[i];
408                    final T luICol = lu[i][col];
409                    for (int j = 0; j < nColB; j++) {
410                        bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
411                    }
412                }
413            }
414
415            // Solve UX = Y
416            for (int col = m - 1; col >= 0; col--) {
417                final T[] bpCol = bp[col];
418                final T luDiag = lu[col][col];
419                for (int j = 0; j < nColB; j++) {
420                    bpCol[j] = bpCol[j].divide(luDiag);
421                }
422                for (int i = 0; i < col; i++) {
423                    final T[] bpI = bp[i];
424                    final T luICol = lu[i][col];
425                    for (int j = 0; j < nColB; j++) {
426                        bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
427                    }
428                }
429            }
430
431            return new Array2DRowFieldMatrix<>(field, bp, false);
432        }
433
434        /** {@inheritDoc} */
435        @Override
436        public FieldMatrix<T> getInverse() {
437            final int m = pivot.length;
438            final T one = field.getOne();
439            FieldMatrix<T> identity = new Array2DRowFieldMatrix<>(field, m, m);
440            for (int i = 0; i < m; ++i) {
441                identity.setEntry(i, i, one);
442            }
443            return solve(identity);
444        }
445    }
446}