/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.math4.legacy.linear;
19  
20  import org.apache.commons.math4.legacy.core.Field;
21  import org.apache.commons.math4.legacy.core.FieldElement;
22  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23  import org.apache.commons.math4.legacy.core.MathArrays;
24  
25  /**
26   * Calculates the LUP-decomposition of a square matrix.
27   * <p>The LUP-decomposition of a matrix A consists of three matrices
28   * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
29   * upper triangular and P is a permutation matrix. All matrices are
30   * m&times;m.</p>
31   * <p>Since {@link FieldElement field elements} do not provide an ordering
32   * operator, the permutation matrix is computed here only in order to avoid
33   * a zero pivot element, no attempt is done to get the largest pivot
34   * element.</p>
35   * <p>This class is based on the class with similar name from the
36   * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
37   * <ul>
38   *   <li>a {@link #getP() getP} method has been added,</li>
39   *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
40   *   getDeterminant},</li>
41   *   <li>the {@code getDoublePivot} method has been removed (but the int based
42   *   {@link #getPivot() getPivot} method has been kept),</li>
43   *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
44   *   by a {@link #getSolver() getSolver} method and the equivalent methods
45   *   provided by the returned {@link DecompositionSolver}.</li>
46   * </ul>
47   *
48   * @param <T> the type of the field elements
49   * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
50   * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
51   * @since 2.0 (changed to concrete class in 3.0)
52   */
53  public class FieldLUDecomposition<T extends FieldElement<T>> {
54  
55      /** Field to which the elements belong. */
56      private final Field<T> field;
57  
58      /** Entries of LU decomposition. */
59      private T[][] lu;
60  
61      /** Pivot permutation associated with LU decomposition. */
62      private int[] pivot;
63  
64      /** Parity of the permutation associated with the LU decomposition. */
65      private boolean even;
66  
67      /** Singularity indicator. */
68      private boolean singular;
69  
70      /** Cached value of L. */
71      private FieldMatrix<T> cachedL;
72  
73      /** Cached value of U. */
74      private FieldMatrix<T> cachedU;
75  
76      /** Cached value of P. */
77      private FieldMatrix<T> cachedP;
78  
79      /**
80       * Calculates the LU-decomposition of the given matrix.
81       * @param matrix The matrix to decompose.
82       * @throws NonSquareMatrixException if matrix is not square
83       */
84      public FieldLUDecomposition(FieldMatrix<T> matrix) {
85          if (!matrix.isSquare()) {
86              throw new NonSquareMatrixException(matrix.getRowDimension(),
87                                                 matrix.getColumnDimension());
88          }
89  
90          final int m = matrix.getColumnDimension();
91          field = matrix.getField();
92          lu = matrix.getData();
93          pivot = new int[m];
94          cachedL = null;
95          cachedU = null;
96          cachedP = null;
97  
98          // Initialize permutation array and parity
99          for (int row = 0; row < m; row++) {
100             pivot[row] = row;
101         }
102         even     = true;
103         singular = false;
104 
105         // Loop over columns
106         for (int col = 0; col < m; col++) {
107 
108             T sum = field.getZero();
109 
110             // upper
111             for (int row = 0; row < col; row++) {
112                 final T[] luRow = lu[row];
113                 sum = luRow[col];
114                 for (int i = 0; i < row; i++) {
115                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
116                 }
117                 luRow[col] = sum;
118             }
119 
120             // lower
121             int nonZero = col; // permutation row
122             for (int row = col; row < m; row++) {
123                 final T[] luRow = lu[row];
124                 sum = luRow[col];
125                 for (int i = 0; i < col; i++) {
126                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
127                 }
128                 luRow[col] = sum;
129 
130                 if (lu[nonZero][col].equals(field.getZero())) {
131                     // try to select a better permutation choice
132                     ++nonZero;
133                 }
134             }
135 
136             // Singularity check
137             if (nonZero >= m) {
138                 singular = true;
139                 return;
140             }
141 
142             // Pivot if necessary
143             if (nonZero != col) {
144                 T tmp = field.getZero();
145                 for (int i = 0; i < m; i++) {
146                     tmp = lu[nonZero][i];
147                     lu[nonZero][i] = lu[col][i];
148                     lu[col][i] = tmp;
149                 }
150                 int temp = pivot[nonZero];
151                 pivot[nonZero] = pivot[col];
152                 pivot[col] = temp;
153                 even = !even;
154             }
155 
156             // Divide the lower elements by the "winning" diagonal elt.
157             final T luDiag = lu[col][col];
158             for (int row = col + 1; row < m; row++) {
159                 final T[] luRow = lu[row];
160                 luRow[col] = luRow[col].divide(luDiag);
161             }
162         }
163     }
164 
165     /**
166      * Returns the matrix L of the decomposition.
167      * <p>L is a lower-triangular matrix</p>
168      * @return the L matrix (or null if decomposed matrix is singular)
169      */
170     public FieldMatrix<T> getL() {
171         if (cachedL == null && !singular) {
172             final int m = pivot.length;
173             cachedL = new Array2DRowFieldMatrix<>(field, m, m);
174             for (int i = 0; i < m; ++i) {
175                 final T[] luI = lu[i];
176                 for (int j = 0; j < i; ++j) {
177                     cachedL.setEntry(i, j, luI[j]);
178                 }
179                 cachedL.setEntry(i, i, field.getOne());
180             }
181         }
182         return cachedL;
183     }
184 
185     /**
186      * Returns the matrix U of the decomposition.
187      * <p>U is an upper-triangular matrix</p>
188      * @return the U matrix (or null if decomposed matrix is singular)
189      */
190     public FieldMatrix<T> getU() {
191         if (cachedU == null && !singular) {
192             final int m = pivot.length;
193             cachedU = new Array2DRowFieldMatrix<>(field, m, m);
194             for (int i = 0; i < m; ++i) {
195                 final T[] luI = lu[i];
196                 for (int j = i; j < m; ++j) {
197                     cachedU.setEntry(i, j, luI[j]);
198                 }
199             }
200         }
201         return cachedU;
202     }
203 
204     /**
205      * Returns the P rows permutation matrix.
206      * <p>P is a sparse matrix with exactly one element set to 1.0 in
207      * each row and each column, all other elements being set to 0.0.</p>
208      * <p>The positions of the 1 elements are given by the {@link #getPivot()
209      * pivot permutation vector}.</p>
210      * @return the P rows permutation matrix (or null if decomposed matrix is singular)
211      * @see #getPivot()
212      */
213     public FieldMatrix<T> getP() {
214         if (cachedP == null && !singular) {
215             final int m = pivot.length;
216             cachedP = new Array2DRowFieldMatrix<>(field, m, m);
217             for (int i = 0; i < m; ++i) {
218                 cachedP.setEntry(i, pivot[i], field.getOne());
219             }
220         }
221         return cachedP;
222     }
223 
224     /**
225      * Returns the pivot permutation vector.
226      * @return the pivot permutation vector
227      * @see #getP()
228      */
229     public int[] getPivot() {
230         return pivot.clone();
231     }
232 
233     /**
234      * Return the determinant of the matrix.
235      * @return determinant of the matrix
236      */
237     public T getDeterminant() {
238         if (singular) {
239             return field.getZero();
240         } else {
241             final int m = pivot.length;
242             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
243             for (int i = 0; i < m; i++) {
244                 determinant = determinant.multiply(lu[i][i]);
245             }
246             return determinant;
247         }
248     }
249 
250     /**
251      * Get a solver for finding the A &times; X = B solution in exact linear sense.
252      * @return a solver
253      */
254     public FieldDecompositionSolver<T> getSolver() {
255         return new Solver<>(field, lu, pivot, singular);
256     }
257 
258     /** Specialized solver.
259      * @param <T> the type of the field elements
260      */
261     private static final class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
262 
263         /** Field to which the elements belong. */
264         private final Field<T> field;
265 
266         /** Entries of LU decomposition. */
267         private final T[][] lu;
268 
269         /** Pivot permutation associated with LU decomposition. */
270         private final int[] pivot;
271 
272         /** Singularity indicator. */
273         private final boolean singular;
274 
275         /**
276          * Build a solver from decomposed matrix.
277          * @param field field to which the matrix elements belong
278          * @param lu entries of LU decomposition
279          * @param pivot pivot permutation associated with LU decomposition
280          * @param singular singularity indicator
281          */
282         private Solver(final Field<T> field, final T[][] lu,
283                        final int[] pivot, final boolean singular) {
284             this.field    = field;
285             this.lu       = lu;
286             this.pivot    = pivot;
287             this.singular = singular;
288         }
289 
290         /** {@inheritDoc} */
291         @Override
292         public boolean isNonSingular() {
293             return !singular;
294         }
295 
296         /** {@inheritDoc} */
297         @Override
298         public FieldVector<T> solve(FieldVector<T> b) {
299             if (b instanceof ArrayFieldVector) {
300                 return solve((ArrayFieldVector<T>) b);
301             }
302 
303             final int m = pivot.length;
304             if (b.getDimension() != m) {
305                 throw new DimensionMismatchException(b.getDimension(), m);
306             }
307             if (singular) {
308                 throw new SingularMatrixException();
309             }
310 
311             // Apply permutations to b
312             final T[] bp = MathArrays.buildArray(field, m);
313             for (int row = 0; row < m; row++) {
314                 bp[row] = b.getEntry(pivot[row]);
315             }
316 
317             // Solve LY = b
318             for (int col = 0; col < m; col++) {
319                 final T bpCol = bp[col];
320                 for (int i = col + 1; i < m; i++) {
321                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
322                 }
323             }
324 
325             // Solve UX = Y
326             for (int col = m - 1; col >= 0; col--) {
327                 bp[col] = bp[col].divide(lu[col][col]);
328                 final T bpCol = bp[col];
329                 for (int i = 0; i < col; i++) {
330                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
331                 }
332             }
333 
334             return new ArrayFieldVector<>(field, bp, false);
335         }
336 
337         /** Solve the linear equation A &times; X = B.
338          * <p>The A matrix is implicit here. It is </p>
339          * @param b right-hand side of the equation A &times; X = B
340          * @return a vector X such that A &times; X = B
341          * @throws DimensionMismatchException if the matrices dimensions do not match.
342          * @throws SingularMatrixException if the decomposed matrix is singular.
343          */
344         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
345             final int m = pivot.length;
346             final int length = b.getDimension();
347             if (length != m) {
348                 throw new DimensionMismatchException(length, m);
349             }
350             if (singular) {
351                 throw new SingularMatrixException();
352             }
353 
354             // Apply permutations to b
355             final T[] bp = MathArrays.buildArray(field, m);
356             for (int row = 0; row < m; row++) {
357                 bp[row] = b.getEntry(pivot[row]);
358             }
359 
360             // Solve LY = b
361             for (int col = 0; col < m; col++) {
362                 final T bpCol = bp[col];
363                 for (int i = col + 1; i < m; i++) {
364                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
365                 }
366             }
367 
368             // Solve UX = Y
369             for (int col = m - 1; col >= 0; col--) {
370                 bp[col] = bp[col].divide(lu[col][col]);
371                 final T bpCol = bp[col];
372                 for (int i = 0; i < col; i++) {
373                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
374                 }
375             }
376 
377             return new ArrayFieldVector<>(bp, false);
378         }
379 
380         /** {@inheritDoc} */
381         @Override
382         public FieldMatrix<T> solve(FieldMatrix<T> b) {
383             final int m = pivot.length;
384             if (b.getRowDimension() != m) {
385                 throw new DimensionMismatchException(b.getRowDimension(), m);
386             }
387             if (singular) {
388                 throw new SingularMatrixException();
389             }
390 
391             final int nColB = b.getColumnDimension();
392 
393             // Apply permutations to b
394             final T[][] bp = MathArrays.buildArray(field, m, nColB);
395             for (int row = 0; row < m; row++) {
396                 final T[] bpRow = bp[row];
397                 final int pRow = pivot[row];
398                 for (int col = 0; col < nColB; col++) {
399                     bpRow[col] = b.getEntry(pRow, col);
400                 }
401             }
402 
403             // Solve LY = b
404             for (int col = 0; col < m; col++) {
405                 final T[] bpCol = bp[col];
406                 for (int i = col + 1; i < m; i++) {
407                     final T[] bpI = bp[i];
408                     final T luICol = lu[i][col];
409                     for (int j = 0; j < nColB; j++) {
410                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
411                     }
412                 }
413             }
414 
415             // Solve UX = Y
416             for (int col = m - 1; col >= 0; col--) {
417                 final T[] bpCol = bp[col];
418                 final T luDiag = lu[col][col];
419                 for (int j = 0; j < nColB; j++) {
420                     bpCol[j] = bpCol[j].divide(luDiag);
421                 }
422                 for (int i = 0; i < col; i++) {
423                     final T[] bpI = bp[i];
424                     final T luICol = lu[i][col];
425                     for (int j = 0; j < nColB; j++) {
426                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
427                     }
428                 }
429             }
430 
431             return new Array2DRowFieldMatrix<>(field, bp, false);
432         }
433 
434         /** {@inheritDoc} */
435         @Override
436         public FieldMatrix<T> getInverse() {
437             final int m = pivot.length;
438             final T one = field.getOne();
439             FieldMatrix<T> identity = new Array2DRowFieldMatrix<>(field, m, m);
440             for (int i = 0; i < m; ++i) {
441                 identity.setEntry(i, i, one);
442             }
443             return solve(identity);
444         }
445     }
446 }