View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.linear;
19  
20  import org.apache.commons.math3.Field;
21  import org.apache.commons.math3.FieldElement;
22  import org.apache.commons.math3.exception.DimensionMismatchException;
23  import org.apache.commons.math3.util.MathArrays;
24  
25  /**
26   * Calculates the LUP-decomposition of a square matrix.
27   * <p>The LUP-decomposition of a matrix A consists of three matrices
28   * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
29   * upper triangular and P is a permutation matrix. All matrices are
30   * m&times;m.</p>
31   * <p>Since {@link FieldElement field elements} do not provide an ordering
32   * operator, the permutation matrix is computed here only in order to avoid
33   * a zero pivot element, no attempt is done to get the largest pivot
34   * element.</p>
35   * <p>This class is based on the class with similar name from the
36   * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
37   * <ul>
38   *   <li>a {@link #getP() getP} method has been added,</li>
39   *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
40   *   getDeterminant},</li>
41   *   <li>the {@code getDoublePivot} method has been removed (but the int based
42   *   {@link #getPivot() getPivot} method has been kept),</li>
43   *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
44   *   by a {@link #getSolver() getSolver} method and the equivalent methods
45   *   provided by the returned {@link DecompositionSolver}.</li>
46   * </ul>
47   *
48   * @param <T> the type of the field elements
49   * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
50   * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
51   * @since 2.0 (changed to concrete class in 3.0)
52   */
53  public class FieldLUDecomposition<T extends FieldElement<T>> {
54  
55      /** Field to which the elements belong. */
56      private final Field<T> field;
57  
58      /** Entries of LU decomposition. */
59      private T[][] lu;
60  
61      /** Pivot permutation associated with LU decomposition. */
62      private int[] pivot;
63  
64      /** Parity of the permutation associated with the LU decomposition. */
65      private boolean even;
66  
67      /** Singularity indicator. */
68      private boolean singular;
69  
70      /** Cached value of L. */
71      private FieldMatrix<T> cachedL;
72  
73      /** Cached value of U. */
74      private FieldMatrix<T> cachedU;
75  
76      /** Cached value of P. */
77      private FieldMatrix<T> cachedP;
78  
79      /**
80       * Calculates the LU-decomposition of the given matrix.
81       * @param matrix The matrix to decompose.
82       * @throws NonSquareMatrixException if matrix is not square
83       */
84      public FieldLUDecomposition(FieldMatrix<T> matrix) {
85          if (!matrix.isSquare()) {
86              throw new NonSquareMatrixException(matrix.getRowDimension(),
87                                                 matrix.getColumnDimension());
88          }
89  
90          final int m = matrix.getColumnDimension();
91          field = matrix.getField();
92          lu = matrix.getData();
93          pivot = new int[m];
94          cachedL = null;
95          cachedU = null;
96          cachedP = null;
97  
98          // Initialize permutation array and parity
99          for (int row = 0; row < m; row++) {
100             pivot[row] = row;
101         }
102         even     = true;
103         singular = false;
104 
105         // Loop over columns
106         for (int col = 0; col < m; col++) {
107 
108             T sum = field.getZero();
109 
110             // upper
111             for (int row = 0; row < col; row++) {
112                 final T[] luRow = lu[row];
113                 sum = luRow[col];
114                 for (int i = 0; i < row; i++) {
115                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
116                 }
117                 luRow[col] = sum;
118             }
119 
120             // lower
121             int nonZero = col; // permutation row
122             for (int row = col; row < m; row++) {
123                 final T[] luRow = lu[row];
124                 sum = luRow[col];
125                 for (int i = 0; i < col; i++) {
126                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
127                 }
128                 luRow[col] = sum;
129 
130                 if (lu[nonZero][col].equals(field.getZero())) {
131                     // try to select a better permutation choice
132                     ++nonZero;
133                 }
134             }
135 
136             // Singularity check
137             if (nonZero >= m) {
138                 singular = true;
139                 return;
140             }
141 
142             // Pivot if necessary
143             if (nonZero != col) {
144                 T tmp = field.getZero();
145                 for (int i = 0; i < m; i++) {
146                     tmp = lu[nonZero][i];
147                     lu[nonZero][i] = lu[col][i];
148                     lu[col][i] = tmp;
149                 }
150                 int temp = pivot[nonZero];
151                 pivot[nonZero] = pivot[col];
152                 pivot[col] = temp;
153                 even = !even;
154             }
155 
156             // Divide the lower elements by the "winning" diagonal elt.
157             final T luDiag = lu[col][col];
158             for (int row = col + 1; row < m; row++) {
159                 final T[] luRow = lu[row];
160                 luRow[col] = luRow[col].divide(luDiag);
161             }
162         }
163 
164     }
165 
166     /**
167      * Returns the matrix L of the decomposition.
168      * <p>L is a lower-triangular matrix</p>
169      * @return the L matrix (or null if decomposed matrix is singular)
170      */
171     public FieldMatrix<T> getL() {
172         if ((cachedL == null) && !singular) {
173             final int m = pivot.length;
174             cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
175             for (int i = 0; i < m; ++i) {
176                 final T[] luI = lu[i];
177                 for (int j = 0; j < i; ++j) {
178                     cachedL.setEntry(i, j, luI[j]);
179                 }
180                 cachedL.setEntry(i, i, field.getOne());
181             }
182         }
183         return cachedL;
184     }
185 
186     /**
187      * Returns the matrix U of the decomposition.
188      * <p>U is an upper-triangular matrix</p>
189      * @return the U matrix (or null if decomposed matrix is singular)
190      */
191     public FieldMatrix<T> getU() {
192         if ((cachedU == null) && !singular) {
193             final int m = pivot.length;
194             cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
195             for (int i = 0; i < m; ++i) {
196                 final T[] luI = lu[i];
197                 for (int j = i; j < m; ++j) {
198                     cachedU.setEntry(i, j, luI[j]);
199                 }
200             }
201         }
202         return cachedU;
203     }
204 
205     /**
206      * Returns the P rows permutation matrix.
207      * <p>P is a sparse matrix with exactly one element set to 1.0 in
208      * each row and each column, all other elements being set to 0.0.</p>
209      * <p>The positions of the 1 elements are given by the {@link #getPivot()
210      * pivot permutation vector}.</p>
211      * @return the P rows permutation matrix (or null if decomposed matrix is singular)
212      * @see #getPivot()
213      */
214     public FieldMatrix<T> getP() {
215         if ((cachedP == null) && !singular) {
216             final int m = pivot.length;
217             cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
218             for (int i = 0; i < m; ++i) {
219                 cachedP.setEntry(i, pivot[i], field.getOne());
220             }
221         }
222         return cachedP;
223     }
224 
225     /**
226      * Returns the pivot permutation vector.
227      * @return the pivot permutation vector
228      * @see #getP()
229      */
230     public int[] getPivot() {
231         return pivot.clone();
232     }
233 
234     /**
235      * Return the determinant of the matrix.
236      * @return determinant of the matrix
237      */
238     public T getDeterminant() {
239         if (singular) {
240             return field.getZero();
241         } else {
242             final int m = pivot.length;
243             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
244             for (int i = 0; i < m; i++) {
245                 determinant = determinant.multiply(lu[i][i]);
246             }
247             return determinant;
248         }
249     }
250 
251     /**
252      * Get a solver for finding the A &times; X = B solution in exact linear sense.
253      * @return a solver
254      */
255     public FieldDecompositionSolver<T> getSolver() {
256         return new Solver<T>(field, lu, pivot, singular);
257     }
258 
259     /** Specialized solver. */
260     private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
261 
262         /** Field to which the elements belong. */
263         private final Field<T> field;
264 
265         /** Entries of LU decomposition. */
266         private final T[][] lu;
267 
268         /** Pivot permutation associated with LU decomposition. */
269         private final int[] pivot;
270 
271         /** Singularity indicator. */
272         private final boolean singular;
273 
274         /**
275          * Build a solver from decomposed matrix.
276          * @param field field to which the matrix elements belong
277          * @param lu entries of LU decomposition
278          * @param pivot pivot permutation associated with LU decomposition
279          * @param singular singularity indicator
280          */
281         private Solver(final Field<T> field, final T[][] lu,
282                        final int[] pivot, final boolean singular) {
283             this.field    = field;
284             this.lu       = lu;
285             this.pivot    = pivot;
286             this.singular = singular;
287         }
288 
289         /** {@inheritDoc} */
290         public boolean isNonSingular() {
291             return !singular;
292         }
293 
294         /** {@inheritDoc} */
295         public FieldVector<T> solve(FieldVector<T> b) {
296             try {
297                 return solve((ArrayFieldVector<T>) b);
298             } catch (ClassCastException cce) {
299 
300                 final int m = pivot.length;
301                 if (b.getDimension() != m) {
302                     throw new DimensionMismatchException(b.getDimension(), m);
303                 }
304                 if (singular) {
305                     throw new SingularMatrixException();
306                 }
307 
308                 // Apply permutations to b
309                 final T[] bp = MathArrays.buildArray(field, m);
310                 for (int row = 0; row < m; row++) {
311                     bp[row] = b.getEntry(pivot[row]);
312                 }
313 
314                 // Solve LY = b
315                 for (int col = 0; col < m; col++) {
316                     final T bpCol = bp[col];
317                     for (int i = col + 1; i < m; i++) {
318                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
319                     }
320                 }
321 
322                 // Solve UX = Y
323                 for (int col = m - 1; col >= 0; col--) {
324                     bp[col] = bp[col].divide(lu[col][col]);
325                     final T bpCol = bp[col];
326                     for (int i = 0; i < col; i++) {
327                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
328                     }
329                 }
330 
331                 return new ArrayFieldVector<T>(field, bp, false);
332 
333             }
334         }
335 
336         /** Solve the linear equation A &times; X = B.
337          * <p>The A matrix is implicit here. It is </p>
338          * @param b right-hand side of the equation A &times; X = B
339          * @return a vector X such that A &times; X = B
340          * @throws DimensionMismatchException if the matrices dimensions do not match.
341          * @throws SingularMatrixException if the decomposed matrix is singular.
342          */
343         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
344             final int m = pivot.length;
345             final int length = b.getDimension();
346             if (length != m) {
347                 throw new DimensionMismatchException(length, m);
348             }
349             if (singular) {
350                 throw new SingularMatrixException();
351             }
352 
353             // Apply permutations to b
354             final T[] bp = MathArrays.buildArray(field, m);
355             for (int row = 0; row < m; row++) {
356                 bp[row] = b.getEntry(pivot[row]);
357             }
358 
359             // Solve LY = b
360             for (int col = 0; col < m; col++) {
361                 final T bpCol = bp[col];
362                 for (int i = col + 1; i < m; i++) {
363                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
364                 }
365             }
366 
367             // Solve UX = Y
368             for (int col = m - 1; col >= 0; col--) {
369                 bp[col] = bp[col].divide(lu[col][col]);
370                 final T bpCol = bp[col];
371                 for (int i = 0; i < col; i++) {
372                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
373                 }
374             }
375 
376             return new ArrayFieldVector<T>(bp, false);
377         }
378 
379         /** {@inheritDoc} */
380         public FieldMatrix<T> solve(FieldMatrix<T> b) {
381             final int m = pivot.length;
382             if (b.getRowDimension() != m) {
383                 throw new DimensionMismatchException(b.getRowDimension(), m);
384             }
385             if (singular) {
386                 throw new SingularMatrixException();
387             }
388 
389             final int nColB = b.getColumnDimension();
390 
391             // Apply permutations to b
392             final T[][] bp = MathArrays.buildArray(field, m, nColB);
393             for (int row = 0; row < m; row++) {
394                 final T[] bpRow = bp[row];
395                 final int pRow = pivot[row];
396                 for (int col = 0; col < nColB; col++) {
397                     bpRow[col] = b.getEntry(pRow, col);
398                 }
399             }
400 
401             // Solve LY = b
402             for (int col = 0; col < m; col++) {
403                 final T[] bpCol = bp[col];
404                 for (int i = col + 1; i < m; i++) {
405                     final T[] bpI = bp[i];
406                     final T luICol = lu[i][col];
407                     for (int j = 0; j < nColB; j++) {
408                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
409                     }
410                 }
411             }
412 
413             // Solve UX = Y
414             for (int col = m - 1; col >= 0; col--) {
415                 final T[] bpCol = bp[col];
416                 final T luDiag = lu[col][col];
417                 for (int j = 0; j < nColB; j++) {
418                     bpCol[j] = bpCol[j].divide(luDiag);
419                 }
420                 for (int i = 0; i < col; i++) {
421                     final T[] bpI = bp[i];
422                     final T luICol = lu[i][col];
423                     for (int j = 0; j < nColB; j++) {
424                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
425                     }
426                 }
427             }
428 
429             return new Array2DRowFieldMatrix<T>(field, bp, false);
430 
431         }
432 
433         /** {@inheritDoc} */
434         public FieldMatrix<T> getInverse() {
435             final int m = pivot.length;
436             final T one = field.getOne();
437             FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
438             for (int i = 0; i < m; ++i) {
439                 identity.setEntry(i, i, one);
440             }
441             return solve(identity);
442         }
443     }
444 }