View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.linear;
19  
20  import org.apache.commons.math3.Field;
21  import org.apache.commons.math3.FieldElement;
22  import org.apache.commons.math3.exception.DimensionMismatchException;
23  import org.apache.commons.math3.util.MathArrays;
24  
25  /**
26   * Calculates the LUP-decomposition of a square matrix.
27   * <p>The LUP-decomposition of a matrix A consists of three matrices
28   * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
29   * upper triangular and P is a permutation matrix. All matrices are
30   * m&times;m.</p>
31   * <p>Since {@link FieldElement field elements} do not provide an ordering
32   * operator, the permutation matrix is computed here only in order to avoid
33   * a zero pivot element, no attempt is done to get the largest pivot
34   * element.</p>
35   * <p>This class is based on the class with similar name from the
36   * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
37   * <ul>
38   *   <li>a {@link #getP() getP} method has been added,</li>
39   *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
40   *   getDeterminant},</li>
41   *   <li>the {@code getDoublePivot} method has been removed (but the int based
42   *   {@link #getPivot() getPivot} method has been kept),</li>
43   *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
44   *   by a {@link #getSolver() getSolver} method and the equivalent methods
45   *   provided by the returned {@link DecompositionSolver}.</li>
46   * </ul>
47   *
48   * @param <T> the type of the field elements
49   * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
50   * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
51   * @version $Id: FieldLUDecomposition.java 1449528 2013-02-24 19:06:20Z luc $
52   * @since 2.0 (changed to concrete class in 3.0)
53   */
54  public class FieldLUDecomposition<T extends FieldElement<T>> {
55  
56      /** Field to which the elements belong. */
57      private final Field<T> field;
58  
59      /** Entries of LU decomposition. */
60      private T[][] lu;
61  
62      /** Pivot permutation associated with LU decomposition. */
63      private int[] pivot;
64  
65      /** Parity of the permutation associated with the LU decomposition. */
66      private boolean even;
67  
68      /** Singularity indicator. */
69      private boolean singular;
70  
71      /** Cached value of L. */
72      private FieldMatrix<T> cachedL;
73  
74      /** Cached value of U. */
75      private FieldMatrix<T> cachedU;
76  
77      /** Cached value of P. */
78      private FieldMatrix<T> cachedP;
79  
80      /**
81       * Calculates the LU-decomposition of the given matrix.
82       * @param matrix The matrix to decompose.
83       * @throws NonSquareMatrixException if matrix is not square
84       */
85      public FieldLUDecomposition(FieldMatrix<T> matrix) {
86          if (!matrix.isSquare()) {
87              throw new NonSquareMatrixException(matrix.getRowDimension(),
88                                                 matrix.getColumnDimension());
89          }
90  
91          final int m = matrix.getColumnDimension();
92          field = matrix.getField();
93          lu = matrix.getData();
94          pivot = new int[m];
95          cachedL = null;
96          cachedU = null;
97          cachedP = null;
98  
99          // Initialize permutation array and parity
100         for (int row = 0; row < m; row++) {
101             pivot[row] = row;
102         }
103         even     = true;
104         singular = false;
105 
106         // Loop over columns
107         for (int col = 0; col < m; col++) {
108 
109             T sum = field.getZero();
110 
111             // upper
112             for (int row = 0; row < col; row++) {
113                 final T[] luRow = lu[row];
114                 sum = luRow[col];
115                 for (int i = 0; i < row; i++) {
116                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
117                 }
118                 luRow[col] = sum;
119             }
120 
121             // lower
122             int nonZero = col; // permutation row
123             for (int row = col; row < m; row++) {
124                 final T[] luRow = lu[row];
125                 sum = luRow[col];
126                 for (int i = 0; i < col; i++) {
127                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
128                 }
129                 luRow[col] = sum;
130 
131                 if (lu[nonZero][col].equals(field.getZero())) {
132                     // try to select a better permutation choice
133                     ++nonZero;
134                 }
135             }
136 
137             // Singularity check
138             if (nonZero >= m) {
139                 singular = true;
140                 return;
141             }
142 
143             // Pivot if necessary
144             if (nonZero != col) {
145                 T tmp = field.getZero();
146                 for (int i = 0; i < m; i++) {
147                     tmp = lu[nonZero][i];
148                     lu[nonZero][i] = lu[col][i];
149                     lu[col][i] = tmp;
150                 }
151                 int temp = pivot[nonZero];
152                 pivot[nonZero] = pivot[col];
153                 pivot[col] = temp;
154                 even = !even;
155             }
156 
157             // Divide the lower elements by the "winning" diagonal elt.
158             final T luDiag = lu[col][col];
159             for (int row = col + 1; row < m; row++) {
160                 final T[] luRow = lu[row];
161                 luRow[col] = luRow[col].divide(luDiag);
162             }
163         }
164 
165     }
166 
167     /**
168      * Returns the matrix L of the decomposition.
169      * <p>L is a lower-triangular matrix</p>
170      * @return the L matrix (or null if decomposed matrix is singular)
171      */
172     public FieldMatrix<T> getL() {
173         if ((cachedL == null) && !singular) {
174             final int m = pivot.length;
175             cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
176             for (int i = 0; i < m; ++i) {
177                 final T[] luI = lu[i];
178                 for (int j = 0; j < i; ++j) {
179                     cachedL.setEntry(i, j, luI[j]);
180                 }
181                 cachedL.setEntry(i, i, field.getOne());
182             }
183         }
184         return cachedL;
185     }
186 
187     /**
188      * Returns the matrix U of the decomposition.
189      * <p>U is an upper-triangular matrix</p>
190      * @return the U matrix (or null if decomposed matrix is singular)
191      */
192     public FieldMatrix<T> getU() {
193         if ((cachedU == null) && !singular) {
194             final int m = pivot.length;
195             cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
196             for (int i = 0; i < m; ++i) {
197                 final T[] luI = lu[i];
198                 for (int j = i; j < m; ++j) {
199                     cachedU.setEntry(i, j, luI[j]);
200                 }
201             }
202         }
203         return cachedU;
204     }
205 
206     /**
207      * Returns the P rows permutation matrix.
208      * <p>P is a sparse matrix with exactly one element set to 1.0 in
209      * each row and each column, all other elements being set to 0.0.</p>
210      * <p>The positions of the 1 elements are given by the {@link #getPivot()
211      * pivot permutation vector}.</p>
212      * @return the P rows permutation matrix (or null if decomposed matrix is singular)
213      * @see #getPivot()
214      */
215     public FieldMatrix<T> getP() {
216         if ((cachedP == null) && !singular) {
217             final int m = pivot.length;
218             cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
219             for (int i = 0; i < m; ++i) {
220                 cachedP.setEntry(i, pivot[i], field.getOne());
221             }
222         }
223         return cachedP;
224     }
225 
226     /**
227      * Returns the pivot permutation vector.
228      * @return the pivot permutation vector
229      * @see #getP()
230      */
231     public int[] getPivot() {
232         return pivot.clone();
233     }
234 
235     /**
236      * Return the determinant of the matrix.
237      * @return determinant of the matrix
238      */
239     public T getDeterminant() {
240         if (singular) {
241             return field.getZero();
242         } else {
243             final int m = pivot.length;
244             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
245             for (int i = 0; i < m; i++) {
246                 determinant = determinant.multiply(lu[i][i]);
247             }
248             return determinant;
249         }
250     }
251 
252     /**
253      * Get a solver for finding the A &times; X = B solution in exact linear sense.
254      * @return a solver
255      */
256     public FieldDecompositionSolver<T> getSolver() {
257         return new Solver<T>(field, lu, pivot, singular);
258     }
259 
260     /** Specialized solver. */
261     private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
262 
263         /** Field to which the elements belong. */
264         private final Field<T> field;
265 
266         /** Entries of LU decomposition. */
267         private final T[][] lu;
268 
269         /** Pivot permutation associated with LU decomposition. */
270         private final int[] pivot;
271 
272         /** Singularity indicator. */
273         private final boolean singular;
274 
275         /**
276          * Build a solver from decomposed matrix.
277          * @param field field to which the matrix elements belong
278          * @param lu entries of LU decomposition
279          * @param pivot pivot permutation associated with LU decomposition
280          * @param singular singularity indicator
281          */
282         private Solver(final Field<T> field, final T[][] lu,
283                        final int[] pivot, final boolean singular) {
284             this.field    = field;
285             this.lu       = lu;
286             this.pivot    = pivot;
287             this.singular = singular;
288         }
289 
290         /** {@inheritDoc} */
291         public boolean isNonSingular() {
292             return !singular;
293         }
294 
295         /** {@inheritDoc} */
296         public FieldVector<T> solve(FieldVector<T> b) {
297             try {
298                 return solve((ArrayFieldVector<T>) b);
299             } catch (ClassCastException cce) {
300 
301                 final int m = pivot.length;
302                 if (b.getDimension() != m) {
303                     throw new DimensionMismatchException(b.getDimension(), m);
304                 }
305                 if (singular) {
306                     throw new SingularMatrixException();
307                 }
308 
309                 // Apply permutations to b
310                 final T[] bp = MathArrays.buildArray(field, m);
311                 for (int row = 0; row < m; row++) {
312                     bp[row] = b.getEntry(pivot[row]);
313                 }
314 
315                 // Solve LY = b
316                 for (int col = 0; col < m; col++) {
317                     final T bpCol = bp[col];
318                     for (int i = col + 1; i < m; i++) {
319                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
320                     }
321                 }
322 
323                 // Solve UX = Y
324                 for (int col = m - 1; col >= 0; col--) {
325                     bp[col] = bp[col].divide(lu[col][col]);
326                     final T bpCol = bp[col];
327                     for (int i = 0; i < col; i++) {
328                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
329                     }
330                 }
331 
332                 return new ArrayFieldVector<T>(field, bp, false);
333 
334             }
335         }
336 
337         /** Solve the linear equation A &times; X = B.
338          * <p>The A matrix is implicit here. It is </p>
339          * @param b right-hand side of the equation A &times; X = B
340          * @return a vector X such that A &times; X = B
341          * @throws DimensionMismatchException if the matrices dimensions do not match.
342          * @throws SingularMatrixException if the decomposed matrix is singular.
343          */
344         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
345             final int m = pivot.length;
346             final int length = b.getDimension();
347             if (length != m) {
348                 throw new DimensionMismatchException(length, m);
349             }
350             if (singular) {
351                 throw new SingularMatrixException();
352             }
353 
354             // Apply permutations to b
355             final T[] bp = MathArrays.buildArray(field, m);
356             for (int row = 0; row < m; row++) {
357                 bp[row] = b.getEntry(pivot[row]);
358             }
359 
360             // Solve LY = b
361             for (int col = 0; col < m; col++) {
362                 final T bpCol = bp[col];
363                 for (int i = col + 1; i < m; i++) {
364                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
365                 }
366             }
367 
368             // Solve UX = Y
369             for (int col = m - 1; col >= 0; col--) {
370                 bp[col] = bp[col].divide(lu[col][col]);
371                 final T bpCol = bp[col];
372                 for (int i = 0; i < col; i++) {
373                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
374                 }
375             }
376 
377             return new ArrayFieldVector<T>(bp, false);
378         }
379 
380         /** {@inheritDoc} */
381         public FieldMatrix<T> solve(FieldMatrix<T> b) {
382             final int m = pivot.length;
383             if (b.getRowDimension() != m) {
384                 throw new DimensionMismatchException(b.getRowDimension(), m);
385             }
386             if (singular) {
387                 throw new SingularMatrixException();
388             }
389 
390             final int nColB = b.getColumnDimension();
391 
392             // Apply permutations to b
393             final T[][] bp = MathArrays.buildArray(field, m, nColB);
394             for (int row = 0; row < m; row++) {
395                 final T[] bpRow = bp[row];
396                 final int pRow = pivot[row];
397                 for (int col = 0; col < nColB; col++) {
398                     bpRow[col] = b.getEntry(pRow, col);
399                 }
400             }
401 
402             // Solve LY = b
403             for (int col = 0; col < m; col++) {
404                 final T[] bpCol = bp[col];
405                 for (int i = col + 1; i < m; i++) {
406                     final T[] bpI = bp[i];
407                     final T luICol = lu[i][col];
408                     for (int j = 0; j < nColB; j++) {
409                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
410                     }
411                 }
412             }
413 
414             // Solve UX = Y
415             for (int col = m - 1; col >= 0; col--) {
416                 final T[] bpCol = bp[col];
417                 final T luDiag = lu[col][col];
418                 for (int j = 0; j < nColB; j++) {
419                     bpCol[j] = bpCol[j].divide(luDiag);
420                 }
421                 for (int i = 0; i < col; i++) {
422                     final T[] bpI = bp[i];
423                     final T luICol = lu[i][col];
424                     for (int j = 0; j < nColB; j++) {
425                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
426                     }
427                 }
428             }
429 
430             return new Array2DRowFieldMatrix<T>(field, bp, false);
431 
432         }
433 
434         /** {@inheritDoc} */
435         public FieldMatrix<T> getInverse() {
436             final int m = pivot.length;
437             final T one = field.getOne();
438             FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
439             for (int i = 0; i < m; ++i) {
440                 identity.setEntry(i, i, one);
441             }
442             return solve(identity);
443         }
444     }
445 }