001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.linear;
019    
020    import java.lang.reflect.Array;
021    
022    import org.apache.commons.math3.Field;
023    import org.apache.commons.math3.FieldElement;
024    import org.apache.commons.math3.exception.DimensionMismatchException;
025    
026    /**
027     * Calculates the LUP-decomposition of a square matrix.
028     * <p>The LUP-decomposition of a matrix A consists of three matrices
029     * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
030     * upper triangular and P is a permutation matrix. All matrices are
031     * m&times;m.</p>
032     * <p>Since {@link FieldElement field elements} do not provide an ordering
033     * operator, the permutation matrix is computed here only in order to avoid
034     * a zero pivot element, no attempt is done to get the largest pivot
035     * element.</p>
036     * <p>This class is based on the class with similar name from the
037     * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
038     * <ul>
039     *   <li>a {@link #getP() getP} method has been added,</li>
040     *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
041     *   getDeterminant},</li>
042     *   <li>the {@code getDoublePivot} method has been removed (but the int based
043     *   {@link #getPivot() getPivot} method has been kept),</li>
044     *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
045     *   by a {@link #getSolver() getSolver} method and the equivalent methods
046     *   provided by the returned {@link DecompositionSolver}.</li>
047     * </ul>
048     *
049     * @param <T> the type of the field elements
050     * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
051     * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
052     * @version $Id: FieldLUDecomposition.java 1416643 2012-12-03 19:37:14Z tn $
053     * @since 2.0 (changed to concrete class in 3.0)
054     */
055    public class FieldLUDecomposition<T extends FieldElement<T>> {
056    
057        /** Field to which the elements belong. */
058        private final Field<T> field;
059    
060        /** Entries of LU decomposition. */
061        private T[][] lu;
062    
063        /** Pivot permutation associated with LU decomposition. */
064        private int[] pivot;
065    
066        /** Parity of the permutation associated with the LU decomposition. */
067        private boolean even;
068    
069        /** Singularity indicator. */
070        private boolean singular;
071    
072        /** Cached value of L. */
073        private FieldMatrix<T> cachedL;
074    
075        /** Cached value of U. */
076        private FieldMatrix<T> cachedU;
077    
078        /** Cached value of P. */
079        private FieldMatrix<T> cachedP;
080    
081        /**
082         * Calculates the LU-decomposition of the given matrix.
083         * @param matrix The matrix to decompose.
084         * @throws NonSquareMatrixException if matrix is not square
085         */
086        public FieldLUDecomposition(FieldMatrix<T> matrix) {
087            if (!matrix.isSquare()) {
088                throw new NonSquareMatrixException(matrix.getRowDimension(),
089                                                   matrix.getColumnDimension());
090            }
091    
092            final int m = matrix.getColumnDimension();
093            field = matrix.getField();
094            lu = matrix.getData();
095            pivot = new int[m];
096            cachedL = null;
097            cachedU = null;
098            cachedP = null;
099    
100            // Initialize permutation array and parity
101            for (int row = 0; row < m; row++) {
102                pivot[row] = row;
103            }
104            even     = true;
105            singular = false;
106    
107            // Loop over columns
108            for (int col = 0; col < m; col++) {
109    
110                T sum = field.getZero();
111    
112                // upper
113                for (int row = 0; row < col; row++) {
114                    final T[] luRow = lu[row];
115                    sum = luRow[col];
116                    for (int i = 0; i < row; i++) {
117                        sum = sum.subtract(luRow[i].multiply(lu[i][col]));
118                    }
119                    luRow[col] = sum;
120                }
121    
122                // lower
123                int nonZero = col; // permutation row
124                for (int row = col; row < m; row++) {
125                    final T[] luRow = lu[row];
126                    sum = luRow[col];
127                    for (int i = 0; i < col; i++) {
128                        sum = sum.subtract(luRow[i].multiply(lu[i][col]));
129                    }
130                    luRow[col] = sum;
131    
132                    if (lu[nonZero][col].equals(field.getZero())) {
133                        // try to select a better permutation choice
134                        ++nonZero;
135                    }
136                }
137    
138                // Singularity check
139                if (nonZero >= m) {
140                    singular = true;
141                    return;
142                }
143    
144                // Pivot if necessary
145                if (nonZero != col) {
146                    T tmp = field.getZero();
147                    for (int i = 0; i < m; i++) {
148                        tmp = lu[nonZero][i];
149                        lu[nonZero][i] = lu[col][i];
150                        lu[col][i] = tmp;
151                    }
152                    int temp = pivot[nonZero];
153                    pivot[nonZero] = pivot[col];
154                    pivot[col] = temp;
155                    even = !even;
156                }
157    
158                // Divide the lower elements by the "winning" diagonal elt.
159                final T luDiag = lu[col][col];
160                for (int row = col + 1; row < m; row++) {
161                    final T[] luRow = lu[row];
162                    luRow[col] = luRow[col].divide(luDiag);
163                }
164            }
165    
166        }
167    
168        /**
169         * Returns the matrix L of the decomposition.
170         * <p>L is a lower-triangular matrix</p>
171         * @return the L matrix (or null if decomposed matrix is singular)
172         */
173        public FieldMatrix<T> getL() {
174            if ((cachedL == null) && !singular) {
175                final int m = pivot.length;
176                cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
177                for (int i = 0; i < m; ++i) {
178                    final T[] luI = lu[i];
179                    for (int j = 0; j < i; ++j) {
180                        cachedL.setEntry(i, j, luI[j]);
181                    }
182                    cachedL.setEntry(i, i, field.getOne());
183                }
184            }
185            return cachedL;
186        }
187    
188        /**
189         * Returns the matrix U of the decomposition.
190         * <p>U is an upper-triangular matrix</p>
191         * @return the U matrix (or null if decomposed matrix is singular)
192         */
193        public FieldMatrix<T> getU() {
194            if ((cachedU == null) && !singular) {
195                final int m = pivot.length;
196                cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
197                for (int i = 0; i < m; ++i) {
198                    final T[] luI = lu[i];
199                    for (int j = i; j < m; ++j) {
200                        cachedU.setEntry(i, j, luI[j]);
201                    }
202                }
203            }
204            return cachedU;
205        }
206    
207        /**
208         * Returns the P rows permutation matrix.
209         * <p>P is a sparse matrix with exactly one element set to 1.0 in
210         * each row and each column, all other elements being set to 0.0.</p>
211         * <p>The positions of the 1 elements are given by the {@link #getPivot()
212         * pivot permutation vector}.</p>
213         * @return the P rows permutation matrix (or null if decomposed matrix is singular)
214         * @see #getPivot()
215         */
216        public FieldMatrix<T> getP() {
217            if ((cachedP == null) && !singular) {
218                final int m = pivot.length;
219                cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
220                for (int i = 0; i < m; ++i) {
221                    cachedP.setEntry(i, pivot[i], field.getOne());
222                }
223            }
224            return cachedP;
225        }
226    
227        /**
228         * Returns the pivot permutation vector.
229         * @return the pivot permutation vector
230         * @see #getP()
231         */
232        public int[] getPivot() {
233            return pivot.clone();
234        }
235    
236        /**
237         * Return the determinant of the matrix.
238         * @return determinant of the matrix
239         */
240        public T getDeterminant() {
241            if (singular) {
242                return field.getZero();
243            } else {
244                final int m = pivot.length;
245                T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
246                for (int i = 0; i < m; i++) {
247                    determinant = determinant.multiply(lu[i][i]);
248                }
249                return determinant;
250            }
251        }
252    
253        /**
254         * Get a solver for finding the A &times; X = B solution in exact linear sense.
255         * @return a solver
256         */
257        public FieldDecompositionSolver<T> getSolver() {
258            return new Solver<T>(field, lu, pivot, singular);
259        }
260    
261        /** Specialized solver. */
262        private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
263    
264            /** Field to which the elements belong. */
265            private final Field<T> field;
266    
267            /** Entries of LU decomposition. */
268            private final T[][] lu;
269    
270            /** Pivot permutation associated with LU decomposition. */
271            private final int[] pivot;
272    
273            /** Singularity indicator. */
274            private final boolean singular;
275    
276            /**
277             * Build a solver from decomposed matrix.
278             * @param field field to which the matrix elements belong
279             * @param lu entries of LU decomposition
280             * @param pivot pivot permutation associated with LU decomposition
281             * @param singular singularity indicator
282             */
283            private Solver(final Field<T> field, final T[][] lu,
284                           final int[] pivot, final boolean singular) {
285                this.field    = field;
286                this.lu       = lu;
287                this.pivot    = pivot;
288                this.singular = singular;
289            }
290    
291            /** {@inheritDoc} */
292            public boolean isNonSingular() {
293                return !singular;
294            }
295    
296            /** {@inheritDoc} */
297            public FieldVector<T> solve(FieldVector<T> b) {
298                try {
299                    return solve((ArrayFieldVector<T>) b);
300                } catch (ClassCastException cce) {
301    
302                    final int m = pivot.length;
303                    if (b.getDimension() != m) {
304                        throw new DimensionMismatchException(b.getDimension(), m);
305                    }
306                    if (singular) {
307                        throw new SingularMatrixException();
308                    }
309    
310                    @SuppressWarnings("unchecked") // field is of type T
311                    final T[] bp = (T[]) Array.newInstance(field.getRuntimeClass(), m);
312    
313                    // Apply permutations to b
314                    for (int row = 0; row < m; row++) {
315                        bp[row] = b.getEntry(pivot[row]);
316                    }
317    
318                    // Solve LY = b
319                    for (int col = 0; col < m; col++) {
320                        final T bpCol = bp[col];
321                        for (int i = col + 1; i < m; i++) {
322                            bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
323                        }
324                    }
325    
326                    // Solve UX = Y
327                    for (int col = m - 1; col >= 0; col--) {
328                        bp[col] = bp[col].divide(lu[col][col]);
329                        final T bpCol = bp[col];
330                        for (int i = 0; i < col; i++) {
331                            bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
332                        }
333                    }
334    
335                    return new ArrayFieldVector<T>(field, bp, false);
336    
337                }
338            }
339    
340            /** Solve the linear equation A &times; X = B.
341             * <p>The A matrix is implicit here. It is </p>
342             * @param b right-hand side of the equation A &times; X = B
343             * @return a vector X such that A &times; X = B
344             * @throws DimensionMismatchException if the matrices dimensions do not match.
345             * @throws SingularMatrixException if the decomposed matrix is singular.
346             */
347            public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
348                final int m = pivot.length;
349                final int length = b.getDimension();
350                if (length != m) {
351                    throw new DimensionMismatchException(length, m);
352                }
353                if (singular) {
354                    throw new SingularMatrixException();
355                }
356    
357                @SuppressWarnings("unchecked")
358                // field is of type T
359                final T[] bp = (T[]) Array.newInstance(field.getRuntimeClass(),
360                                                       m);
361    
362                // Apply permutations to b
363                for (int row = 0; row < m; row++) {
364                    bp[row] = b.getEntry(pivot[row]);
365                }
366    
367                // Solve LY = b
368                for (int col = 0; col < m; col++) {
369                    final T bpCol = bp[col];
370                    for (int i = col + 1; i < m; i++) {
371                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
372                    }
373                }
374    
375                // Solve UX = Y
376                for (int col = m - 1; col >= 0; col--) {
377                    bp[col] = bp[col].divide(lu[col][col]);
378                    final T bpCol = bp[col];
379                    for (int i = 0; i < col; i++) {
380                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
381                    }
382                }
383    
384                return new ArrayFieldVector<T>(bp, false);
385            }
386    
387            /** {@inheritDoc} */
388            public FieldMatrix<T> solve(FieldMatrix<T> b) {
389                final int m = pivot.length;
390                if (b.getRowDimension() != m) {
391                    throw new DimensionMismatchException(b.getRowDimension(), m);
392                }
393                if (singular) {
394                    throw new SingularMatrixException();
395                }
396    
397                final int nColB = b.getColumnDimension();
398    
399                // Apply permutations to b
400                @SuppressWarnings("unchecked") // field is of type T
401                final T[][] bp = (T[][]) Array.newInstance(field.getRuntimeClass(), new int[] { m, nColB });
402                for (int row = 0; row < m; row++) {
403                    final T[] bpRow = bp[row];
404                    final int pRow = pivot[row];
405                    for (int col = 0; col < nColB; col++) {
406                        bpRow[col] = b.getEntry(pRow, col);
407                    }
408                }
409    
410                // Solve LY = b
411                for (int col = 0; col < m; col++) {
412                    final T[] bpCol = bp[col];
413                    for (int i = col + 1; i < m; i++) {
414                        final T[] bpI = bp[i];
415                        final T luICol = lu[i][col];
416                        for (int j = 0; j < nColB; j++) {
417                            bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
418                        }
419                    }
420                }
421    
422                // Solve UX = Y
423                for (int col = m - 1; col >= 0; col--) {
424                    final T[] bpCol = bp[col];
425                    final T luDiag = lu[col][col];
426                    for (int j = 0; j < nColB; j++) {
427                        bpCol[j] = bpCol[j].divide(luDiag);
428                    }
429                    for (int i = 0; i < col; i++) {
430                        final T[] bpI = bp[i];
431                        final T luICol = lu[i][col];
432                        for (int j = 0; j < nColB; j++) {
433                            bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
434                        }
435                    }
436                }
437    
438                return new Array2DRowFieldMatrix<T>(field, bp, false);
439    
440            }
441    
442            /** {@inheritDoc} */
443            public FieldMatrix<T> getInverse() {
444                final int m = pivot.length;
445                final T one = field.getOne();
446                FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
447                for (int i = 0; i < m; ++i) {
448                    identity.setEntry(i, i, one);
449                }
450                return solve(identity);
451            }
452        }
453    }