001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.linear;
019    
020    import java.lang.reflect.Array;
021    
022    import org.apache.commons.math.Field;
023    import org.apache.commons.math.FieldElement;
024    import org.apache.commons.math.exception.DimensionMismatchException;
025    
026    /**
027     * Calculates the LUP-decomposition of a square matrix.
028     * <p>The LUP-decomposition of a matrix A consists of three matrices
029     * L, U and P that satisfy PA = LU, where L is lower triangular, U is
030     * upper triangular and P is a permutation matrix. All matrices are
031     * m&times;m.</p>
032     * <p>Since {@link FieldElement field elements} do not provide an ordering
033     * operator, the permutation matrix is computed here only in order to avoid
034     * a zero pivot element; no attempt is made to select the largest pivot element.</p>
035     *
036     * @param <T> the type of the field elements
037     * @version $Id: FieldLUDecompositionImpl.java 1141885 2011-07-01 09:16:41Z luc $
038     * @since 2.0
039     */
040    public class FieldLUDecompositionImpl<T extends FieldElement<T>> implements FieldLUDecomposition<T> {
041    
042        /** Field to which the elements belong. */
043        private final Field<T> field;
044    
045        /** Entries of LU decomposition. */
046        private T lu[][];
047    
048        /** Pivot permutation associated with LU decomposition */
049        private int[] pivot;
050    
051        /** Parity of the permutation associated with the LU decomposition */
052        private boolean even;
053    
054        /** Singularity indicator. */
055        private boolean singular;
056    
057        /** Cached value of L. */
058        private FieldMatrix<T> cachedL;
059    
060        /** Cached value of U. */
061        private FieldMatrix<T> cachedU;
062    
063        /** Cached value of P. */
064        private FieldMatrix<T> cachedP;
065    
066        /**
067         * Calculates the LU-decomposition of the given matrix.
068         * @param matrix The matrix to decompose.
069         * @throws NonSquareMatrixException if matrix is not square
070         */
071        public FieldLUDecompositionImpl(FieldMatrix<T> matrix) {
072            if (!matrix.isSquare()) {
073                throw new NonSquareMatrixException(matrix.getRowDimension(),
074                                                   matrix.getColumnDimension());
075            }
076    
077            final int m = matrix.getColumnDimension();
078            field = matrix.getField();
079            lu = matrix.getData();
080            pivot = new int[m];
081            cachedL = null;
082            cachedU = null;
083            cachedP = null;
084    
085            // Initialize permutation array and parity
086            for (int row = 0; row < m; row++) {
087                pivot[row] = row;
088            }
089            even     = true;
090            singular = false;
091    
092            // Loop over columns
093            for (int col = 0; col < m; col++) {
094    
095                T sum = field.getZero();
096    
097                // upper
098                for (int row = 0; row < col; row++) {
099                    final T[] luRow = lu[row];
100                    sum = luRow[col];
101                    for (int i = 0; i < row; i++) {
102                        sum = sum.subtract(luRow[i].multiply(lu[i][col]));
103                    }
104                    luRow[col] = sum;
105                }
106    
107                // lower
108                int nonZero = col; // permutation row
109                for (int row = col; row < m; row++) {
110                    final T[] luRow = lu[row];
111                    sum = luRow[col];
112                    for (int i = 0; i < col; i++) {
113                        sum = sum.subtract(luRow[i].multiply(lu[i][col]));
114                    }
115                    luRow[col] = sum;
116    
117                    if (lu[nonZero][col].equals(field.getZero())) {
118                        // try to select a better permutation choice
119                        ++nonZero;
120                    }
121                }
122    
123                // Singularity check
124                if (nonZero >= m) {
125                    singular = true;
126                    return;
127                }
128    
129                // Pivot if necessary
130                if (nonZero != col) {
131                    T tmp = field.getZero();
132                    for (int i = 0; i < m; i++) {
133                        tmp = lu[nonZero][i];
134                        lu[nonZero][i] = lu[col][i];
135                        lu[col][i] = tmp;
136                    }
137                    int temp = pivot[nonZero];
138                    pivot[nonZero] = pivot[col];
139                    pivot[col] = temp;
140                    even = !even;
141                }
142    
143                // Divide the lower elements by the "winning" diagonal elt.
144                final T luDiag = lu[col][col];
145                for (int row = col + 1; row < m; row++) {
146                    final T[] luRow = lu[row];
147                    luRow[col] = luRow[col].divide(luDiag);
148                }
149            }
150    
151        }
152    
153        /** {@inheritDoc} */
154        public FieldMatrix<T> getL() {
155            if ((cachedL == null) && !singular) {
156                final int m = pivot.length;
157                cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
158                for (int i = 0; i < m; ++i) {
159                    final T[] luI = lu[i];
160                    for (int j = 0; j < i; ++j) {
161                        cachedL.setEntry(i, j, luI[j]);
162                    }
163                    cachedL.setEntry(i, i, field.getOne());
164                }
165            }
166            return cachedL;
167        }
168    
169        /** {@inheritDoc} */
170        public FieldMatrix<T> getU() {
171            if ((cachedU == null) && !singular) {
172                final int m = pivot.length;
173                cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
174                for (int i = 0; i < m; ++i) {
175                    final T[] luI = lu[i];
176                    for (int j = i; j < m; ++j) {
177                        cachedU.setEntry(i, j, luI[j]);
178                    }
179                }
180            }
181            return cachedU;
182        }
183    
184        /** {@inheritDoc} */
185        public FieldMatrix<T> getP() {
186            if ((cachedP == null) && !singular) {
187                final int m = pivot.length;
188                cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
189                for (int i = 0; i < m; ++i) {
190                    cachedP.setEntry(i, pivot[i], field.getOne());
191                }
192            }
193            return cachedP;
194        }
195    
196        /** {@inheritDoc} */
197        public int[] getPivot() {
198            return pivot.clone();
199        }
200    
201        /** {@inheritDoc} */
202        public T getDeterminant() {
203            if (singular) {
204                return field.getZero();
205            } else {
206                final int m = pivot.length;
207                T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
208                for (int i = 0; i < m; i++) {
209                    determinant = determinant.multiply(lu[i][i]);
210                }
211                return determinant;
212            }
213        }
214    
215        /** {@inheritDoc} */
216        public FieldDecompositionSolver<T> getSolver() {
217            return new Solver<T>(field, lu, pivot, singular);
218        }
219    
220        /** Specialized solver. */
221        private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
222    
223            /** Field to which the elements belong. */
224            private final Field<T> field;
225    
226            /** Entries of LU decomposition. */
227            private final T lu[][];
228    
229            /** Pivot permutation associated with LU decomposition. */
230            private final int[] pivot;
231    
232            /** Singularity indicator. */
233            private final boolean singular;
234    
235            /**
236             * Build a solver from decomposed matrix.
237             * @param field field to which the matrix elements belong
238             * @param lu entries of LU decomposition
239             * @param pivot pivot permutation associated with LU decomposition
240             * @param singular singularity indicator
241             */
242            private Solver(final Field<T> field, final T[][] lu,
243                           final int[] pivot, final boolean singular) {
244                this.field    = field;
245                this.lu       = lu;
246                this.pivot    = pivot;
247                this.singular = singular;
248            }
249    
250            /** {@inheritDoc} */
251            public boolean isNonSingular() {
252                return !singular;
253            }
254    
255            /** {@inheritDoc} */
256            public T[] solve(T[] b) {
257                final int m = pivot.length;
258                if (b.length != m) {
259                    throw new DimensionMismatchException(b.length, m);
260                }
261                if (singular) {
262                    throw new SingularMatrixException();
263                }
264    
265                @SuppressWarnings("unchecked") // field is of type T
266                final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
267    
268                // Apply permutations to b
269                for (int row = 0; row < m; row++) {
270                    bp[row] = b[pivot[row]];
271                }
272    
273                // Solve LY = b
274                for (int col = 0; col < m; col++) {
275                    final T bpCol = bp[col];
276                    for (int i = col + 1; i < m; i++) {
277                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
278                    }
279                }
280    
281                // Solve UX = Y
282                for (int col = m - 1; col >= 0; col--) {
283                    bp[col] = bp[col].divide(lu[col][col]);
284                    final T bpCol = bp[col];
285                    for (int i = 0; i < col; i++) {
286                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
287                    }
288                }
289    
290                return bp;
291    
292            }
293    
294            /** {@inheritDoc} */
295            public FieldVector<T> solve(FieldVector<T> b) {
296                try {
297                    return solve((ArrayFieldVector<T>) b);
298                } catch (ClassCastException cce) {
299    
300                    final int m = pivot.length;
301                    if (b.getDimension() != m) {
302                        throw new DimensionMismatchException(b.getDimension(), m);
303                    }
304                    if (singular) {
305                        throw new SingularMatrixException();
306                    }
307    
308                    @SuppressWarnings("unchecked") // field is of type T
309                    final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
310    
311                    // Apply permutations to b
312                    for (int row = 0; row < m; row++) {
313                        bp[row] = b.getEntry(pivot[row]);
314                    }
315    
316                    // Solve LY = b
317                    for (int col = 0; col < m; col++) {
318                        final T bpCol = bp[col];
319                        for (int i = col + 1; i < m; i++) {
320                            bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
321                        }
322                    }
323    
324                    // Solve UX = Y
325                    for (int col = m - 1; col >= 0; col--) {
326                        bp[col] = bp[col].divide(lu[col][col]);
327                        final T bpCol = bp[col];
328                        for (int i = 0; i < col; i++) {
329                            bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
330                        }
331                    }
332    
333                    return new ArrayFieldVector<T>(field, bp, false);
334    
335                }
336            }
337    
338            /** Solve the linear equation A &times; X = B.
339             * <p>The A matrix is implicit here. It is </p>
340             * @param b right-hand side of the equation A &times; X = B
341             * @return a vector X such that A &times; X = B
342             * @throws DimensionMismatchException if the matrices dimensions do not match.
343             * @throws SingularMatrixException if the decomposed matrix is singular.
344             */
345            public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
346                return new ArrayFieldVector<T>(field, solve(b.getDataRef()), false);
347            }
348    
349            /** {@inheritDoc} */
350            public FieldMatrix<T> solve(FieldMatrix<T> b) {
351                final int m = pivot.length;
352                if (b.getRowDimension() != m) {
353                    throw new DimensionMismatchException(b.getRowDimension(), m);
354                }
355                if (singular) {
356                    throw new SingularMatrixException();
357                }
358    
359                final int nColB = b.getColumnDimension();
360    
361                // Apply permutations to b
362                @SuppressWarnings("unchecked") // field is of type T
363                final T[][] bp = (T[][]) Array.newInstance(field.getZero().getClass(), new int[] { m, nColB });
364                for (int row = 0; row < m; row++) {
365                    final T[] bpRow = bp[row];
366                    final int pRow = pivot[row];
367                    for (int col = 0; col < nColB; col++) {
368                        bpRow[col] = b.getEntry(pRow, col);
369                    }
370                }
371    
372                // Solve LY = b
373                for (int col = 0; col < m; col++) {
374                    final T[] bpCol = bp[col];
375                    for (int i = col + 1; i < m; i++) {
376                        final T[] bpI = bp[i];
377                        final T luICol = lu[i][col];
378                        for (int j = 0; j < nColB; j++) {
379                            bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
380                        }
381                    }
382                }
383    
384                // Solve UX = Y
385                for (int col = m - 1; col >= 0; col--) {
386                    final T[] bpCol = bp[col];
387                    final T luDiag = lu[col][col];
388                    for (int j = 0; j < nColB; j++) {
389                        bpCol[j] = bpCol[j].divide(luDiag);
390                    }
391                    for (int i = 0; i < col; i++) {
392                        final T[] bpI = bp[i];
393                        final T luICol = lu[i][col];
394                        for (int j = 0; j < nColB; j++) {
395                            bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
396                        }
397                    }
398                }
399    
400                return new Array2DRowFieldMatrix<T>(field, bp, false);
401    
402            }
403    
404            /** {@inheritDoc} */
405            public FieldMatrix<T> getInverse() {
406                final int m = pivot.length;
407                final T one = field.getOne();
408                FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
409                for (int i = 0; i < m; ++i) {
410                    identity.setEntry(i, i, one);
411                }
412                return solve(identity);
413            }
414        }
415    }