View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.linear;
19  
20  import org.apache.commons.math3.Field;
21  import org.apache.commons.math3.FieldElement;
22  import org.apache.commons.math3.exception.DimensionMismatchException;
23  import org.apache.commons.math3.util.MathArrays;
24  
25  /**
26   * Calculates the LUP-decomposition of a square matrix.
27   * <p>The LUP-decomposition of a matrix A consists of three matrices
28   * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
29   * upper triangular and P is a permutation matrix. All matrices are
30   * m&times;m.</p>
31   * <p>Since {@link FieldElement field elements} do not provide an ordering
32   * operator, the permutation matrix is computed here only in order to avoid
33   * a zero pivot element, no attempt is done to get the largest pivot
34   * element.</p>
35   * <p>This class is based on the class with similar name from the
36   * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
37   * <ul>
38   *   <li>a {@link #getP() getP} method has been added,</li>
39   *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
40   *   getDeterminant},</li>
41   *   <li>the {@code getDoublePivot} method has been removed (but the int based
42   *   {@link #getPivot() getPivot} method has been kept),</li>
43   *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
44   *   by a {@link #getSolver() getSolver} method and the equivalent methods
45   *   provided by the returned {@link DecompositionSolver}.</li>
46   * </ul>
47   *
48   * @param <T> the type of the field elements
49   * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
50   * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
51   * @since 2.0 (changed to concrete class in 3.0)
52   */
53  public class FieldLUDecomposition<T extends FieldElement<T>> {
54  
55      /** Field to which the elements belong. */
56      private final Field<T> field;
57  
58      /** Entries of LU decomposition. */
59      private T[][] lu;
60  
61      /** Pivot permutation associated with LU decomposition. */
62      private int[] pivot;
63  
64      /** Parity of the permutation associated with the LU decomposition. */
65      private boolean even;
66  
67      /** Singularity indicator. */
68      private boolean singular;
69  
70      /** Cached value of L. */
71      private FieldMatrix<T> cachedL;
72  
73      /** Cached value of U. */
74      private FieldMatrix<T> cachedU;
75  
76      /** Cached value of P. */
77      private FieldMatrix<T> cachedP;
78  
79      /**
80       * Calculates the LU-decomposition of the given matrix.
81       * @param matrix The matrix to decompose.
82       * @throws NonSquareMatrixException if matrix is not square
83       */
84      public FieldLUDecomposition(FieldMatrix<T> matrix) {
85          if (!matrix.isSquare()) {
86              throw new NonSquareMatrixException(matrix.getRowDimension(),
87                                                 matrix.getColumnDimension());
88          }
89  
90          final int m = matrix.getColumnDimension();
91          field = matrix.getField();
92          lu = matrix.getData();
93          pivot = new int[m];
94          cachedL = null;
95          cachedU = null;
96          cachedP = null;
97  
98          // Initialize permutation array and parity
99          for (int row = 0; row < m; row++) {
100             pivot[row] = row;
101         }
102         even     = true;
103         singular = false;
104 
105         // Loop over columns
106         for (int col = 0; col < m; col++) {
107 
108             T sum = field.getZero();
109 
110             // upper
111             for (int row = 0; row < col; row++) {
112                 final T[] luRow = lu[row];
113                 sum = luRow[col];
114                 for (int i = 0; i < row; i++) {
115                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
116                 }
117                 luRow[col] = sum;
118             }
119 
120             // lower
121             int nonZero = col; // permutation row
122             for (int row = col; row < m; row++) {
123                 final T[] luRow = lu[row];
124                 sum = luRow[col];
125                 for (int i = 0; i < col; i++) {
126                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
127                 }
128                 luRow[col] = sum;
129 
130                 if (lu[nonZero][col].equals(field.getZero())) {
131                     // try to select a better permutation choice
132                     ++nonZero;
133                 }
134             }
135 
136             // Singularity check
137             if (nonZero >= m) {
138                 singular = true;
139                 return;
140             }
141 
142             // Pivot if necessary
143             if (nonZero != col) {
144                 T tmp = field.getZero();
145                 for (int i = 0; i < m; i++) {
146                     tmp = lu[nonZero][i];
147                     lu[nonZero][i] = lu[col][i];
148                     lu[col][i] = tmp;
149                 }
150                 int temp = pivot[nonZero];
151                 pivot[nonZero] = pivot[col];
152                 pivot[col] = temp;
153                 even = !even;
154             }
155 
156             // Divide the lower elements by the "winning" diagonal elt.
157             final T luDiag = lu[col][col];
158             for (int row = col + 1; row < m; row++) {
159                 final T[] luRow = lu[row];
160                 luRow[col] = luRow[col].divide(luDiag);
161             }
162         }
163 
164     }
165 
166     /**
167      * Returns the matrix L of the decomposition.
168      * <p>L is a lower-triangular matrix</p>
169      * @return the L matrix (or null if decomposed matrix is singular)
170      */
171     public FieldMatrix<T> getL() {
172         if ((cachedL == null) && !singular) {
173             final int m = pivot.length;
174             cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
175             for (int i = 0; i < m; ++i) {
176                 final T[] luI = lu[i];
177                 for (int j = 0; j < i; ++j) {
178                     cachedL.setEntry(i, j, luI[j]);
179                 }
180                 cachedL.setEntry(i, i, field.getOne());
181             }
182         }
183         return cachedL;
184     }
185 
186     /**
187      * Returns the matrix U of the decomposition.
188      * <p>U is an upper-triangular matrix</p>
189      * @return the U matrix (or null if decomposed matrix is singular)
190      */
191     public FieldMatrix<T> getU() {
192         if ((cachedU == null) && !singular) {
193             final int m = pivot.length;
194             cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
195             for (int i = 0; i < m; ++i) {
196                 final T[] luI = lu[i];
197                 for (int j = i; j < m; ++j) {
198                     cachedU.setEntry(i, j, luI[j]);
199                 }
200             }
201         }
202         return cachedU;
203     }
204 
205     /**
206      * Returns the P rows permutation matrix.
207      * <p>P is a sparse matrix with exactly one element set to 1.0 in
208      * each row and each column, all other elements being set to 0.0.</p>
209      * <p>The positions of the 1 elements are given by the {@link #getPivot()
210      * pivot permutation vector}.</p>
211      * @return the P rows permutation matrix (or null if decomposed matrix is singular)
212      * @see #getPivot()
213      */
214     public FieldMatrix<T> getP() {
215         if ((cachedP == null) && !singular) {
216             final int m = pivot.length;
217             cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
218             for (int i = 0; i < m; ++i) {
219                 cachedP.setEntry(i, pivot[i], field.getOne());
220             }
221         }
222         return cachedP;
223     }
224 
225     /**
226      * Returns the pivot permutation vector.
227      * @return the pivot permutation vector
228      * @see #getP()
229      */
230     public int[] getPivot() {
231         return pivot.clone();
232     }
233 
234     /**
235      * Return the determinant of the matrix.
236      * @return determinant of the matrix
237      */
238     public T getDeterminant() {
239         if (singular) {
240             return field.getZero();
241         } else {
242             final int m = pivot.length;
243             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
244             for (int i = 0; i < m; i++) {
245                 determinant = determinant.multiply(lu[i][i]);
246             }
247             return determinant;
248         }
249     }
250 
251     /**
252      * Get a solver for finding the A &times; X = B solution in exact linear sense.
253      * @return a solver
254      */
255     public FieldDecompositionSolver<T> getSolver() {
256         return new Solver<T>(field, lu, pivot, singular);
257     }
258 
259     /** Specialized solver.
260      * @param <T> the type of the field elements
261      */
262     private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
263 
264         /** Field to which the elements belong. */
265         private final Field<T> field;
266 
267         /** Entries of LU decomposition. */
268         private final T[][] lu;
269 
270         /** Pivot permutation associated with LU decomposition. */
271         private final int[] pivot;
272 
273         /** Singularity indicator. */
274         private final boolean singular;
275 
276         /**
277          * Build a solver from decomposed matrix.
278          * @param field field to which the matrix elements belong
279          * @param lu entries of LU decomposition
280          * @param pivot pivot permutation associated with LU decomposition
281          * @param singular singularity indicator
282          */
283         private Solver(final Field<T> field, final T[][] lu,
284                        final int[] pivot, final boolean singular) {
285             this.field    = field;
286             this.lu       = lu;
287             this.pivot    = pivot;
288             this.singular = singular;
289         }
290 
291         /** {@inheritDoc} */
292         public boolean isNonSingular() {
293             return !singular;
294         }
295 
296         /** {@inheritDoc} */
297         public FieldVector<T> solve(FieldVector<T> b) {
298             try {
299                 return solve((ArrayFieldVector<T>) b);
300             } catch (ClassCastException cce) {
301 
302                 final int m = pivot.length;
303                 if (b.getDimension() != m) {
304                     throw new DimensionMismatchException(b.getDimension(), m);
305                 }
306                 if (singular) {
307                     throw new SingularMatrixException();
308                 }
309 
310                 // Apply permutations to b
311                 final T[] bp = MathArrays.buildArray(field, m);
312                 for (int row = 0; row < m; row++) {
313                     bp[row] = b.getEntry(pivot[row]);
314                 }
315 
316                 // Solve LY = b
317                 for (int col = 0; col < m; col++) {
318                     final T bpCol = bp[col];
319                     for (int i = col + 1; i < m; i++) {
320                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
321                     }
322                 }
323 
324                 // Solve UX = Y
325                 for (int col = m - 1; col >= 0; col--) {
326                     bp[col] = bp[col].divide(lu[col][col]);
327                     final T bpCol = bp[col];
328                     for (int i = 0; i < col; i++) {
329                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
330                     }
331                 }
332 
333                 return new ArrayFieldVector<T>(field, bp, false);
334 
335             }
336         }
337 
338         /** Solve the linear equation A &times; X = B.
339          * <p>The A matrix is implicit here. It is </p>
340          * @param b right-hand side of the equation A &times; X = B
341          * @return a vector X such that A &times; X = B
342          * @throws DimensionMismatchException if the matrices dimensions do not match.
343          * @throws SingularMatrixException if the decomposed matrix is singular.
344          */
345         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
346             final int m = pivot.length;
347             final int length = b.getDimension();
348             if (length != m) {
349                 throw new DimensionMismatchException(length, m);
350             }
351             if (singular) {
352                 throw new SingularMatrixException();
353             }
354 
355             // Apply permutations to b
356             final T[] bp = MathArrays.buildArray(field, m);
357             for (int row = 0; row < m; row++) {
358                 bp[row] = b.getEntry(pivot[row]);
359             }
360 
361             // Solve LY = b
362             for (int col = 0; col < m; col++) {
363                 final T bpCol = bp[col];
364                 for (int i = col + 1; i < m; i++) {
365                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
366                 }
367             }
368 
369             // Solve UX = Y
370             for (int col = m - 1; col >= 0; col--) {
371                 bp[col] = bp[col].divide(lu[col][col]);
372                 final T bpCol = bp[col];
373                 for (int i = 0; i < col; i++) {
374                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
375                 }
376             }
377 
378             return new ArrayFieldVector<T>(bp, false);
379         }
380 
381         /** {@inheritDoc} */
382         public FieldMatrix<T> solve(FieldMatrix<T> b) {
383             final int m = pivot.length;
384             if (b.getRowDimension() != m) {
385                 throw new DimensionMismatchException(b.getRowDimension(), m);
386             }
387             if (singular) {
388                 throw new SingularMatrixException();
389             }
390 
391             final int nColB = b.getColumnDimension();
392 
393             // Apply permutations to b
394             final T[][] bp = MathArrays.buildArray(field, m, nColB);
395             for (int row = 0; row < m; row++) {
396                 final T[] bpRow = bp[row];
397                 final int pRow = pivot[row];
398                 for (int col = 0; col < nColB; col++) {
399                     bpRow[col] = b.getEntry(pRow, col);
400                 }
401             }
402 
403             // Solve LY = b
404             for (int col = 0; col < m; col++) {
405                 final T[] bpCol = bp[col];
406                 for (int i = col + 1; i < m; i++) {
407                     final T[] bpI = bp[i];
408                     final T luICol = lu[i][col];
409                     for (int j = 0; j < nColB; j++) {
410                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
411                     }
412                 }
413             }
414 
415             // Solve UX = Y
416             for (int col = m - 1; col >= 0; col--) {
417                 final T[] bpCol = bp[col];
418                 final T luDiag = lu[col][col];
419                 for (int j = 0; j < nColB; j++) {
420                     bpCol[j] = bpCol[j].divide(luDiag);
421                 }
422                 for (int i = 0; i < col; i++) {
423                     final T[] bpI = bp[i];
424                     final T luICol = lu[i][col];
425                     for (int j = 0; j < nColB; j++) {
426                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
427                     }
428                 }
429             }
430 
431             return new Array2DRowFieldMatrix<T>(field, bp, false);
432 
433         }
434 
435         /** {@inheritDoc} */
436         public FieldMatrix<T> getInverse() {
437             final int m = pivot.length;
438             final T one = field.getOne();
439             FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
440             for (int i = 0; i < m; ++i) {
441                 identity.setEntry(i, i, one);
442             }
443             return solve(identity);
444         }
445     }
446 }