001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.apache.commons.math3.linear; 019 020 import java.io.IOException; 021 import java.io.ObjectInputStream; 022 import java.io.ObjectOutputStream; 023 import java.lang.reflect.Array; 024 import java.util.Arrays; 025 026 import org.apache.commons.math3.Field; 027 import org.apache.commons.math3.FieldElement; 028 import org.apache.commons.math3.exception.MathArithmeticException; 029 import org.apache.commons.math3.exception.OutOfRangeException; 030 import org.apache.commons.math3.exception.NoDataException; 031 import org.apache.commons.math3.exception.NumberIsTooSmallException; 032 import org.apache.commons.math3.exception.NullArgumentException; 033 import org.apache.commons.math3.exception.DimensionMismatchException; 034 import org.apache.commons.math3.exception.ZeroException; 035 import org.apache.commons.math3.exception.util.LocalizedFormats; 036 import org.apache.commons.math3.fraction.BigFraction; 037 import org.apache.commons.math3.fraction.Fraction; 038 import org.apache.commons.math3.util.FastMath; 039 import org.apache.commons.math3.util.Precision; 040 041 /** 042 * A collection of static methods that operate on or return matrices. 043 * 044 * @version $Id: MatrixUtils.java 1422313 2012-12-15 18:53:41Z psteitz $ 045 */ 046 public class MatrixUtils { 047 048 /** 049 * The default format for {@link RealMatrix} objects. 050 * @since 3.1 051 */ 052 public static final RealMatrixFormat DEFAULT_FORMAT = RealMatrixFormat.getInstance(); 053 054 /** 055 * A format for {@link RealMatrix} objects compatible with octave. 056 * @since 3.1 057 */ 058 public static final RealMatrixFormat OCTAVE_FORMAT = new RealMatrixFormat("[", "]", "", "", "; ", ", "); 059 060 /** 061 * Private constructor. 062 */ 063 private MatrixUtils() { 064 super(); 065 } 066 067 /** 068 * Returns a {@link RealMatrix} with specified dimensions. 069 * <p>The type of matrix returned depends on the dimension. Below 070 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 071 * square matrix) which can be stored in a 32kB array, a {@link 072 * Array2DRowRealMatrix} instance is built. Above this threshold a {@link 073 * BlockRealMatrix} instance is built.</p> 074 * <p>The matrix elements are all set to 0.0.</p> 075 * @param rows number of rows of the matrix 076 * @param columns number of columns of the matrix 077 * @return RealMatrix with specified dimensions 078 * @see #createRealMatrix(double[][]) 079 */ 080 public static RealMatrix createRealMatrix(final int rows, final int columns) { 081 return (rows * columns <= 4096) ? 082 new Array2DRowRealMatrix(rows, columns) : new BlockRealMatrix(rows, columns); 083 } 084 085 /** 086 * Returns a {@link FieldMatrix} with specified dimensions. 087 * <p>The type of matrix returned depends on the dimension. Below 088 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 089 * square matrix), a {@link FieldMatrix} instance is built. Above 090 * this threshold a {@link BlockFieldMatrix} instance is built.</p> 091 * <p>The matrix elements are all set to field.getZero().</p> 092 * @param <T> the type of the field elements 093 * @param field field to which the matrix elements belong 094 * @param rows number of rows of the matrix 095 * @param columns number of columns of the matrix 096 * @return FieldMatrix with specified dimensions 097 * @see #createFieldMatrix(FieldElement[][]) 098 * @since 2.0 099 */ 100 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(final Field<T> field, 101 final int rows, 102 final int columns) { 103 return (rows * columns <= 4096) ? 104 new Array2DRowFieldMatrix<T>(field, rows, columns) : new BlockFieldMatrix<T>(field, rows, columns); 105 } 106 107 /** 108 * Returns a {@link RealMatrix} whose entries are the the values in the 109 * the input array. 110 * <p>The type of matrix returned depends on the dimension. Below 111 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 112 * square matrix) which can be stored in a 32kB array, a {@link 113 * Array2DRowRealMatrix} instance is built. Above this threshold a {@link 114 * BlockRealMatrix} instance is built.</p> 115 * <p>The input array is copied, not referenced.</p> 116 * 117 * @param data input array 118 * @return RealMatrix containing the values of the array 119 * @throws org.apache.commons.math3.exception.DimensionMismatchException 120 * if {@code data} is not rectangular (not all rows have the same length). 121 * @throws NoDataException if a row or column is empty. 122 * @throws NullArgumentException if either {@code data} or {@code data[0]} 123 * is {@code null}. 124 * @throws DimensionMismatchException if {@code data} is not rectangular. 125 * @see #createRealMatrix(int, int) 126 */ 127 public static RealMatrix createRealMatrix(double[][] data) 128 throws NullArgumentException, DimensionMismatchException, 129 NoDataException { 130 if (data == null || 131 data[0] == null) { 132 throw new NullArgumentException(); 133 } 134 return (data.length * data[0].length <= 4096) ? 135 new Array2DRowRealMatrix(data) : new BlockRealMatrix(data); 136 } 137 138 /** 139 * Returns a {@link FieldMatrix} whose entries are the the values in the 140 * the input array. 141 * <p>The type of matrix returned depends on the dimension. Below 142 * 2<sup>12</sup> elements (i.e. 4096 elements or 64×64 for a 143 * square matrix), a {@link FieldMatrix} instance is built. Above 144 * this threshold a {@link BlockFieldMatrix} instance is built.</p> 145 * <p>The input array is copied, not referenced.</p> 146 * @param <T> the type of the field elements 147 * @param data input array 148 * @return a matrix containing the values of the array. 149 * @throws org.apache.commons.math3.exception.DimensionMismatchException 150 * if {@code data} is not rectangular (not all rows have the same length). 151 * @throws NoDataException if a row or column is empty. 152 * @throws NullArgumentException if either {@code data} or {@code data[0]} 153 * is {@code null}. 154 * @see #createFieldMatrix(Field, int, int) 155 * @since 2.0 156 */ 157 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(T[][] data) 158 throws DimensionMismatchException, NoDataException, NullArgumentException { 159 if (data == null || 160 data[0] == null) { 161 throw new NullArgumentException(); 162 } 163 return (data.length * data[0].length <= 4096) ? 164 new Array2DRowFieldMatrix<T>(data) : new BlockFieldMatrix<T>(data); 165 } 166 167 /** 168 * Returns <code>dimension x dimension</code> identity matrix. 169 * 170 * @param dimension dimension of identity matrix to generate 171 * @return identity matrix 172 * @throws IllegalArgumentException if dimension is not positive 173 * @since 1.1 174 */ 175 public static RealMatrix createRealIdentityMatrix(int dimension) { 176 final RealMatrix m = createRealMatrix(dimension, dimension); 177 for (int i = 0; i < dimension; ++i) { 178 m.setEntry(i, i, 1.0); 179 } 180 return m; 181 } 182 183 /** 184 * Returns <code>dimension x dimension</code> identity matrix. 185 * 186 * @param <T> the type of the field elements 187 * @param field field to which the elements belong 188 * @param dimension dimension of identity matrix to generate 189 * @return identity matrix 190 * @throws IllegalArgumentException if dimension is not positive 191 * @since 2.0 192 */ 193 public static <T extends FieldElement<T>> FieldMatrix<T> 194 createFieldIdentityMatrix(final Field<T> field, final int dimension) { 195 final T zero = field.getZero(); 196 final T one = field.getOne(); 197 @SuppressWarnings("unchecked") 198 final T[][] d = (T[][]) Array.newInstance(field.getRuntimeClass(), new int[] { dimension, dimension }); 199 for (int row = 0; row < dimension; row++) { 200 final T[] dRow = d[row]; 201 Arrays.fill(dRow, zero); 202 dRow[row] = one; 203 } 204 return new Array2DRowFieldMatrix<T>(field, d, false); 205 } 206 207 /** 208 * Returns a diagonal matrix with specified elements. 209 * 210 * @param diagonal diagonal elements of the matrix (the array elements 211 * will be copied) 212 * @return diagonal matrix 213 * @since 2.0 214 */ 215 public static RealMatrix createRealDiagonalMatrix(final double[] diagonal) { 216 final RealMatrix m = createRealMatrix(diagonal.length, diagonal.length); 217 for (int i = 0; i < diagonal.length; ++i) { 218 m.setEntry(i, i, diagonal[i]); 219 } 220 return m; 221 } 222 223 /** 224 * Returns a diagonal matrix with specified elements. 225 * 226 * @param <T> the type of the field elements 227 * @param diagonal diagonal elements of the matrix (the array elements 228 * will be copied) 229 * @return diagonal matrix 230 * @since 2.0 231 */ 232 public static <T extends FieldElement<T>> FieldMatrix<T> 233 createFieldDiagonalMatrix(final T[] diagonal) { 234 final FieldMatrix<T> m = 235 createFieldMatrix(diagonal[0].getField(), diagonal.length, diagonal.length); 236 for (int i = 0; i < diagonal.length; ++i) { 237 m.setEntry(i, i, diagonal[i]); 238 } 239 return m; 240 } 241 242 /** 243 * Creates a {@link RealVector} using the data from the input array. 244 * 245 * @param data the input data 246 * @return a data.length RealVector 247 * @throws NoDataException if {@code data} is empty. 248 * @throws NullArgumentException if {@code data} is {@code null}. 249 */ 250 public static RealVector createRealVector(double[] data) 251 throws NoDataException, NullArgumentException { 252 if (data == null) { 253 throw new NullArgumentException(); 254 } 255 return new ArrayRealVector(data, true); 256 } 257 258 /** 259 * Creates a {@link FieldVector} using the data from the input array. 260 * 261 * @param <T> the type of the field elements 262 * @param data the input data 263 * @return a data.length FieldVector 264 * @throws NoDataException if {@code data} is empty. 265 * @throws NullArgumentException if {@code data} is {@code null}. 266 * @throws ZeroException if {@code data} has 0 elements 267 */ 268 public static <T extends FieldElement<T>> FieldVector<T> createFieldVector(final T[] data) 269 throws NoDataException, NullArgumentException, ZeroException { 270 if (data == null) { 271 throw new NullArgumentException(); 272 } 273 if (data.length == 0) { 274 throw new ZeroException(LocalizedFormats.VECTOR_MUST_HAVE_AT_LEAST_ONE_ELEMENT); 275 } 276 return new ArrayFieldVector<T>(data[0].getField(), data, true); 277 } 278 279 /** 280 * Create a row {@link RealMatrix} using the data from the input 281 * array. 282 * 283 * @param rowData the input row data 284 * @return a 1 x rowData.length RealMatrix 285 * @throws NoDataException if {@code rowData} is empty. 286 * @throws NullArgumentException if {@code rowData} is {@code null}. 287 */ 288 public static RealMatrix createRowRealMatrix(double[] rowData) 289 throws NoDataException, NullArgumentException { 290 if (rowData == null) { 291 throw new NullArgumentException(); 292 } 293 final int nCols = rowData.length; 294 final RealMatrix m = createRealMatrix(1, nCols); 295 for (int i = 0; i < nCols; ++i) { 296 m.setEntry(0, i, rowData[i]); 297 } 298 return m; 299 } 300 301 /** 302 * Create a row {@link FieldMatrix} using the data from the input 303 * array. 304 * 305 * @param <T> the type of the field elements 306 * @param rowData the input row data 307 * @return a 1 x rowData.length FieldMatrix 308 * @throws NoDataException if {@code rowData} is empty. 309 * @throws NullArgumentException if {@code rowData} is {@code null}. 310 */ 311 public static <T extends FieldElement<T>> FieldMatrix<T> 312 createRowFieldMatrix(final T[] rowData) 313 throws NoDataException, NullArgumentException { 314 if (rowData == null) { 315 throw new NullArgumentException(); 316 } 317 final int nCols = rowData.length; 318 if (nCols == 0) { 319 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_COLUMN); 320 } 321 final FieldMatrix<T> m = createFieldMatrix(rowData[0].getField(), 1, nCols); 322 for (int i = 0; i < nCols; ++i) { 323 m.setEntry(0, i, rowData[i]); 324 } 325 return m; 326 } 327 328 /** 329 * Creates a column {@link RealMatrix} using the data from the input 330 * array. 331 * 332 * @param columnData the input column data 333 * @return a columnData x 1 RealMatrix 334 * @throws NoDataException if {@code columnData} is empty. 335 * @throws NullArgumentException if {@code columnData} is {@code null}. 336 */ 337 public static RealMatrix createColumnRealMatrix(double[] columnData) 338 throws NoDataException, NullArgumentException { 339 if (columnData == null) { 340 throw new NullArgumentException(); 341 } 342 final int nRows = columnData.length; 343 final RealMatrix m = createRealMatrix(nRows, 1); 344 for (int i = 0; i < nRows; ++i) { 345 m.setEntry(i, 0, columnData[i]); 346 } 347 return m; 348 } 349 350 /** 351 * Creates a column {@link FieldMatrix} using the data from the input 352 * array. 353 * 354 * @param <T> the type of the field elements 355 * @param columnData the input column data 356 * @return a columnData x 1 FieldMatrix 357 * @throws NoDataException if {@code data} is empty. 358 * @throws NullArgumentException if {@code columnData} is {@code null}. 359 */ 360 public static <T extends FieldElement<T>> FieldMatrix<T> 361 createColumnFieldMatrix(final T[] columnData) 362 throws NoDataException, NullArgumentException { 363 if (columnData == null) { 364 throw new NullArgumentException(); 365 } 366 final int nRows = columnData.length; 367 if (nRows == 0) { 368 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_ROW); 369 } 370 final FieldMatrix<T> m = createFieldMatrix(columnData[0].getField(), nRows, 1); 371 for (int i = 0; i < nRows; ++i) { 372 m.setEntry(i, 0, columnData[i]); 373 } 374 return m; 375 } 376 377 /** 378 * Checks whether a matrix is symmetric, within a given relative tolerance. 379 * 380 * @param matrix Matrix to check. 381 * @param relativeTolerance Tolerance of the symmetry check. 382 * @param raiseException If {@code true}, an exception will be raised if 383 * the matrix is not symmetric. 384 * @return {@code true} if {@code matrix} is symmetric. 385 * @throws NonSquareMatrixException if the matrix is not square. 386 * @throws NonSymmetricMatrixException if the matrix is not symmetric. 387 */ 388 private static boolean isSymmetricInternal(RealMatrix matrix, 389 double relativeTolerance, 390 boolean raiseException) { 391 final int rows = matrix.getRowDimension(); 392 if (rows != matrix.getColumnDimension()) { 393 if (raiseException) { 394 throw new NonSquareMatrixException(rows, matrix.getColumnDimension()); 395 } else { 396 return false; 397 } 398 } 399 for (int i = 0; i < rows; i++) { 400 for (int j = i + 1; j < rows; j++) { 401 final double mij = matrix.getEntry(i, j); 402 final double mji = matrix.getEntry(j, i); 403 if (FastMath.abs(mij - mji) > 404 FastMath.max(FastMath.abs(mij), FastMath.abs(mji)) * relativeTolerance) { 405 if (raiseException) { 406 throw new NonSymmetricMatrixException(i, j, relativeTolerance); 407 } else { 408 return false; 409 } 410 } 411 } 412 } 413 return true; 414 } 415 416 /** 417 * Checks whether a matrix is symmetric. 418 * 419 * @param matrix Matrix to check. 420 * @param eps Relative tolerance. 421 * @throws NonSquareMatrixException if the matrix is not square. 422 * @throws NonSymmetricMatrixException if the matrix is not symmetric. 423 * @since 3.1 424 */ 425 public static void checkSymmetric(RealMatrix matrix, 426 double eps) { 427 isSymmetricInternal(matrix, eps, true); 428 } 429 430 /** 431 * Checks whether a matrix is symmetric. 432 * 433 * @param matrix Matrix to check. 434 * @param eps Relative tolerance. 435 * @return {@code true} if {@code matrix} is symmetric. 436 * @since 3.1 437 */ 438 public static boolean isSymmetric(RealMatrix matrix, 439 double eps) { 440 return isSymmetricInternal(matrix, eps, false); 441 } 442 443 /** 444 * Check if matrix indices are valid. 445 * 446 * @param m Matrix. 447 * @param row Row index to check. 448 * @param column Column index to check. 449 * @throws OutOfRangeException if {@code row} or {@code column} is not 450 * a valid index. 451 */ 452 public static void checkMatrixIndex(final AnyMatrix m, 453 final int row, final int column) 454 throws OutOfRangeException { 455 checkRowIndex(m, row); 456 checkColumnIndex(m, column); 457 } 458 459 /** 460 * Check if a row index is valid. 461 * 462 * @param m Matrix. 463 * @param row Row index to check. 464 * @throws OutOfRangeException if {@code row} is not a valid index. 465 */ 466 public static void checkRowIndex(final AnyMatrix m, final int row) 467 throws OutOfRangeException { 468 if (row < 0 || 469 row >= m.getRowDimension()) { 470 throw new OutOfRangeException(LocalizedFormats.ROW_INDEX, 471 row, 0, m.getRowDimension() - 1); 472 } 473 } 474 475 /** 476 * Check if a column index is valid. 477 * 478 * @param m Matrix. 479 * @param column Column index to check. 480 * @throws OutOfRangeException if {@code column} is not a valid index. 481 */ 482 public static void checkColumnIndex(final AnyMatrix m, final int column) 483 throws OutOfRangeException { 484 if (column < 0 || column >= m.getColumnDimension()) { 485 throw new OutOfRangeException(LocalizedFormats.COLUMN_INDEX, 486 column, 0, m.getColumnDimension() - 1); 487 } 488 } 489 490 /** 491 * Check if submatrix ranges indices are valid. 492 * Rows and columns are indicated counting from 0 to {@code n - 1}. 493 * 494 * @param m Matrix. 495 * @param startRow Initial row index. 496 * @param endRow Final row index. 497 * @param startColumn Initial column index. 498 * @param endColumn Final column index. 499 * @throws OutOfRangeException if the indices are invalid. 500 * @throws NumberIsTooSmallException if {@code endRow < startRow} or 501 * {@code endColumn < startColumn}. 502 */ 503 public static void checkSubMatrixIndex(final AnyMatrix m, 504 final int startRow, final int endRow, 505 final int startColumn, final int endColumn) 506 throws NumberIsTooSmallException, OutOfRangeException { 507 checkRowIndex(m, startRow); 508 checkRowIndex(m, endRow); 509 if (endRow < startRow) { 510 throw new NumberIsTooSmallException(LocalizedFormats.INITIAL_ROW_AFTER_FINAL_ROW, 511 endRow, startRow, false); 512 } 513 514 checkColumnIndex(m, startColumn); 515 checkColumnIndex(m, endColumn); 516 if (endColumn < startColumn) { 517 throw new NumberIsTooSmallException(LocalizedFormats.INITIAL_COLUMN_AFTER_FINAL_COLUMN, 518 endColumn, startColumn, false); 519 } 520 521 522 } 523 524 /** 525 * Check if submatrix ranges indices are valid. 526 * Rows and columns are indicated counting from 0 to n-1. 527 * 528 * @param m Matrix. 529 * @param selectedRows Array of row indices. 530 * @param selectedColumns Array of column indices. 531 * @throws NullArgumentException if {@code selectedRows} or 532 * {@code selectedColumns} are {@code null}. 533 * @throws NoDataException if the row or column selections are empty (zero 534 * length). 535 * @throws OutOfRangeException if row or column selections are not valid. 536 */ 537 public static void checkSubMatrixIndex(final AnyMatrix m, 538 final int[] selectedRows, 539 final int[] selectedColumns) 540 throws NoDataException, NullArgumentException, OutOfRangeException { 541 if (selectedRows == null) { 542 throw new NullArgumentException(); 543 } 544 if (selectedColumns == null) { 545 throw new NullArgumentException(); 546 } 547 if (selectedRows.length == 0) { 548 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_ROW_INDEX_ARRAY); 549 } 550 if (selectedColumns.length == 0) { 551 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_COLUMN_INDEX_ARRAY); 552 } 553 554 for (final int row : selectedRows) { 555 checkRowIndex(m, row); 556 } 557 for (final int column : selectedColumns) { 558 checkColumnIndex(m, column); 559 } 560 } 561 562 /** 563 * Check if matrices are addition compatible. 564 * 565 * @param left Left hand side matrix. 566 * @param right Right hand side matrix. 567 * @throws MatrixDimensionMismatchException if the matrices are not addition 568 * compatible. 569 */ 570 public static void checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right) 571 throws MatrixDimensionMismatchException { 572 if ((left.getRowDimension() != right.getRowDimension()) || 573 (left.getColumnDimension() != right.getColumnDimension())) { 574 throw new MatrixDimensionMismatchException(left.getRowDimension(), left.getColumnDimension(), 575 right.getRowDimension(), right.getColumnDimension()); 576 } 577 } 578 579 /** 580 * Check if matrices are subtraction compatible 581 * 582 * @param left Left hand side matrix. 583 * @param right Right hand side matrix. 584 * @throws MatrixDimensionMismatchException if the matrices are not addition 585 * compatible. 586 */ 587 public static void checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right) 588 throws MatrixDimensionMismatchException { 589 if ((left.getRowDimension() != right.getRowDimension()) || 590 (left.getColumnDimension() != right.getColumnDimension())) { 591 throw new MatrixDimensionMismatchException(left.getRowDimension(), left.getColumnDimension(), 592 right.getRowDimension(), right.getColumnDimension()); 593 } 594 } 595 596 /** 597 * Check if matrices are multiplication compatible 598 * 599 * @param left Left hand side matrix. 600 * @param right Right hand side matrix. 601 * @throws DimensionMismatchException if matrices are not multiplication 602 * compatible. 603 */ 604 public static void checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right) 605 throws DimensionMismatchException { 606 607 if (left.getColumnDimension() != right.getRowDimension()) { 608 throw new DimensionMismatchException(left.getColumnDimension(), 609 right.getRowDimension()); 610 } 611 } 612 613 /** 614 * Convert a {@link FieldMatrix}/{@link Fraction} matrix to a {@link RealMatrix}. 615 * @param m Matrix to convert. 616 * @return the converted matrix. 617 */ 618 public static Array2DRowRealMatrix fractionMatrixToRealMatrix(final FieldMatrix<Fraction> m) { 619 final FractionMatrixConverter converter = new FractionMatrixConverter(); 620 m.walkInOptimizedOrder(converter); 621 return converter.getConvertedMatrix(); 622 } 623 624 /** Converter for {@link FieldMatrix}/{@link Fraction}. */ 625 private static class FractionMatrixConverter extends DefaultFieldMatrixPreservingVisitor<Fraction> { 626 /** Converted array. */ 627 private double[][] data; 628 /** Simple constructor. */ 629 public FractionMatrixConverter() { 630 super(Fraction.ZERO); 631 } 632 633 /** {@inheritDoc} */ 634 @Override 635 public void start(int rows, int columns, 636 int startRow, int endRow, int startColumn, int endColumn) { 637 data = new double[rows][columns]; 638 } 639 640 /** {@inheritDoc} */ 641 @Override 642 public void visit(int row, int column, Fraction value) { 643 data[row][column] = value.doubleValue(); 644 } 645 646 /** 647 * Get the converted matrix. 648 * 649 * @return the converted matrix. 650 */ 651 Array2DRowRealMatrix getConvertedMatrix() { 652 return new Array2DRowRealMatrix(data, false); 653 } 654 655 } 656 657 /** 658 * Convert a {@link FieldMatrix}/{@link BigFraction} matrix to a {@link RealMatrix}. 659 * 660 * @param m Matrix to convert. 661 * @return the converted matrix. 662 */ 663 public static Array2DRowRealMatrix bigFractionMatrixToRealMatrix(final FieldMatrix<BigFraction> m) { 664 final BigFractionMatrixConverter converter = new BigFractionMatrixConverter(); 665 m.walkInOptimizedOrder(converter); 666 return converter.getConvertedMatrix(); 667 } 668 669 /** Converter for {@link FieldMatrix}/{@link BigFraction}. */ 670 private static class BigFractionMatrixConverter extends DefaultFieldMatrixPreservingVisitor<BigFraction> { 671 /** Converted array. */ 672 private double[][] data; 673 /** Simple constructor. */ 674 public BigFractionMatrixConverter() { 675 super(BigFraction.ZERO); 676 } 677 678 /** {@inheritDoc} */ 679 @Override 680 public void start(int rows, int columns, 681 int startRow, int endRow, int startColumn, int endColumn) { 682 data = new double[rows][columns]; 683 } 684 685 /** {@inheritDoc} */ 686 @Override 687 public void visit(int row, int column, BigFraction value) { 688 data[row][column] = value.doubleValue(); 689 } 690 691 /** 692 * Get the converted matrix. 693 * 694 * @return the converted matrix. 695 */ 696 Array2DRowRealMatrix getConvertedMatrix() { 697 return new Array2DRowRealMatrix(data, false); 698 } 699 } 700 701 /** Serialize a {@link RealVector}. 702 * <p> 703 * This method is intended to be called from within a private 704 * <code>writeObject</code> method (after a call to 705 * <code>oos.defaultWriteObject()</code>) in a class that has a 706 * {@link RealVector} field, which should be declared <code>transient</code>. 707 * This way, the default handling does not serialize the vector (the {@link 708 * RealVector} interface is not serializable by default) but this method does 709 * serialize it specifically. 710 * </p> 711 * <p> 712 * The following example shows how a simple class with a name and a real vector 713 * should be written: 714 * <pre><code> 715 * public class NamedVector implements Serializable { 716 * 717 * private final String name; 718 * private final transient RealVector coefficients; 719 * 720 * // omitted constructors, getters ... 721 * 722 * private void writeObject(ObjectOutputStream oos) throws IOException { 723 * oos.defaultWriteObject(); // takes care of name field 724 * MatrixUtils.serializeRealVector(coefficients, oos); 725 * } 726 * 727 * private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 728 * ois.defaultReadObject(); // takes care of name field 729 * MatrixUtils.deserializeRealVector(this, "coefficients", ois); 730 * } 731 * 732 * } 733 * </code></pre> 734 * </p> 735 * 736 * @param vector real vector to serialize 737 * @param oos stream where the real vector should be written 738 * @exception IOException if object cannot be written to stream 739 * @see #deserializeRealVector(Object, String, ObjectInputStream) 740 */ 741 public static void serializeRealVector(final RealVector vector, 742 final ObjectOutputStream oos) 743 throws IOException { 744 final int n = vector.getDimension(); 745 oos.writeInt(n); 746 for (int i = 0; i < n; ++i) { 747 oos.writeDouble(vector.getEntry(i)); 748 } 749 } 750 751 /** Deserialize a {@link RealVector} field in a class. 752 * <p> 753 * This method is intended to be called from within a private 754 * <code>readObject</code> method (after a call to 755 * <code>ois.defaultReadObject()</code>) in a class that has a 756 * {@link RealVector} field, which should be declared <code>transient</code>. 757 * This way, the default handling does not deserialize the vector (the {@link 758 * RealVector} interface is not serializable by default) but this method does 759 * deserialize it specifically. 760 * </p> 761 * @param instance instance in which the field must be set up 762 * @param fieldName name of the field within the class (may be private and final) 763 * @param ois stream from which the real vector should be read 764 * @exception ClassNotFoundException if a class in the stream cannot be found 765 * @exception IOException if object cannot be read from the stream 766 * @see #serializeRealVector(RealVector, ObjectOutputStream) 767 */ 768 public static void deserializeRealVector(final Object instance, 769 final String fieldName, 770 final ObjectInputStream ois) 771 throws ClassNotFoundException, IOException { 772 try { 773 774 // read the vector data 775 final int n = ois.readInt(); 776 final double[] data = new double[n]; 777 for (int i = 0; i < n; ++i) { 778 data[i] = ois.readDouble(); 779 } 780 781 // create the instance 782 final RealVector vector = new ArrayRealVector(data, false); 783 784 // set up the field 785 final java.lang.reflect.Field f = 786 instance.getClass().getDeclaredField(fieldName); 787 f.setAccessible(true); 788 f.set(instance, vector); 789 790 } catch (NoSuchFieldException nsfe) { 791 IOException ioe = new IOException(); 792 ioe.initCause(nsfe); 793 throw ioe; 794 } catch (IllegalAccessException iae) { 795 IOException ioe = new IOException(); 796 ioe.initCause(iae); 797 throw ioe; 798 } 799 800 } 801 802 /** Serialize a {@link RealMatrix}. 803 * <p> 804 * This method is intended to be called from within a private 805 * <code>writeObject</code> method (after a call to 806 * <code>oos.defaultWriteObject()</code>) in a class that has a 807 * {@link RealMatrix} field, which should be declared <code>transient</code>. 808 * This way, the default handling does not serialize the matrix (the {@link 809 * RealMatrix} interface is not serializable by default) but this method does 810 * serialize it specifically. 811 * </p> 812 * <p> 813 * The following example shows how a simple class with a name and a real matrix 814 * should be written: 815 * <pre><code> 816 * public class NamedMatrix implements Serializable { 817 * 818 * private final String name; 819 * private final transient RealMatrix coefficients; 820 * 821 * // omitted constructors, getters ... 822 * 823 * private void writeObject(ObjectOutputStream oos) throws IOException { 824 * oos.defaultWriteObject(); // takes care of name field 825 * MatrixUtils.serializeRealMatrix(coefficients, oos); 826 * } 827 * 828 * private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 829 * ois.defaultReadObject(); // takes care of name field 830 * MatrixUtils.deserializeRealMatrix(this, "coefficients", ois); 831 * } 832 * 833 * } 834 * </code></pre> 835 * </p> 836 * 837 * @param matrix real matrix to serialize 838 * @param oos stream where the real matrix should be written 839 * @exception IOException if object cannot be written to stream 840 * @see #deserializeRealMatrix(Object, String, ObjectInputStream) 841 */ 842 public static void serializeRealMatrix(final RealMatrix matrix, 843 final ObjectOutputStream oos) 844 throws IOException { 845 final int n = matrix.getRowDimension(); 846 final int m = matrix.getColumnDimension(); 847 oos.writeInt(n); 848 oos.writeInt(m); 849 for (int i = 0; i < n; ++i) { 850 for (int j = 0; j < m; ++j) { 851 oos.writeDouble(matrix.getEntry(i, j)); 852 } 853 } 854 } 855 856 /** Deserialize a {@link RealMatrix} field in a class. 857 * <p> 858 * This method is intended to be called from within a private 859 * <code>readObject</code> method (after a call to 860 * <code>ois.defaultReadObject()</code>) in a class that has a 861 * {@link RealMatrix} field, which should be declared <code>transient</code>. 862 * This way, the default handling does not deserialize the matrix (the {@link 863 * RealMatrix} interface is not serializable by default) but this method does 864 * deserialize it specifically. 865 * </p> 866 * @param instance instance in which the field must be set up 867 * @param fieldName name of the field within the class (may be private and final) 868 * @param ois stream from which the real matrix should be read 869 * @exception ClassNotFoundException if a class in the stream cannot be found 870 * @exception IOException if object cannot be read from the stream 871 * @see #serializeRealMatrix(RealMatrix, ObjectOutputStream) 872 */ 873 public static void deserializeRealMatrix(final Object instance, 874 final String fieldName, 875 final ObjectInputStream ois) 876 throws ClassNotFoundException, IOException { 877 try { 878 879 // read the matrix data 880 final int n = ois.readInt(); 881 final int m = ois.readInt(); 882 final double[][] data = new double[n][m]; 883 for (int i = 0; i < n; ++i) { 884 final double[] dataI = data[i]; 885 for (int j = 0; j < m; ++j) { 886 dataI[j] = ois.readDouble(); 887 } 888 } 889 890 // create the instance 891 final RealMatrix matrix = new Array2DRowRealMatrix(data, false); 892 893 // set up the field 894 final java.lang.reflect.Field f = 895 instance.getClass().getDeclaredField(fieldName); 896 f.setAccessible(true); 897 f.set(instance, matrix); 898 899 } catch (NoSuchFieldException nsfe) { 900 IOException ioe = new IOException(); 901 ioe.initCause(nsfe); 902 throw ioe; 903 } catch (IllegalAccessException iae) { 904 IOException ioe = new IOException(); 905 ioe.initCause(iae); 906 throw ioe; 907 } 908 } 909 910 /**Solve a system of composed of a Lower Triangular Matrix 911 * {@link RealMatrix}. 912 * <p> 913 * This method is called to solve systems of equations which are 914 * of the lower triangular form. The matrix {@link RealMatrix} 915 * is assumed, though not checked, to be in lower triangular form. 916 * The vector {@link RealVector} is overwritten with the solution. 917 * The matrix is checked that it is square and its dimensions match 918 * the length of the vector. 919 * </p> 920 * @param rm RealMatrix which is lower triangular 921 * @param b RealVector this is overwritten 922 * @throws DimensionMismatchException if the matrix and vector are not 923 * conformable 924 * @throws NonSquareMatrixException if the matrix {@code rm} is not square 925 * @throws MathArithmeticException if the absolute value of one of the diagonal 926 * coefficient of {@code rm} is lower than {@link Precision#SAFE_MIN} 927 */ 928 public static void solveLowerTriangularSystem(RealMatrix rm, RealVector b) 929 throws DimensionMismatchException, MathArithmeticException, 930 NonSquareMatrixException { 931 if ((rm == null) || (b == null) || ( rm.getRowDimension() != b.getDimension())) { 932 throw new DimensionMismatchException( 933 (rm == null) ? 0 : rm.getRowDimension(), 934 (b == null) ? 0 : b.getDimension()); 935 } 936 if( rm.getColumnDimension() != rm.getRowDimension() ){ 937 throw new NonSquareMatrixException(rm.getRowDimension(), 938 rm.getColumnDimension()); 939 } 940 int rows = rm.getRowDimension(); 941 for( int i = 0 ; i < rows ; i++ ){ 942 double diag = rm.getEntry(i, i); 943 if( FastMath.abs(diag) < Precision.SAFE_MIN ){ 944 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR); 945 } 946 double bi = b.getEntry(i)/diag; 947 b.setEntry(i, bi ); 948 for( int j = i+1; j< rows; j++ ){ 949 b.setEntry(j, b.getEntry(j)-bi*rm.getEntry(j,i) ); 950 } 951 } 952 } 953 954 /** Solver a system composed of an Upper Triangular Matrix 955 * {@link RealMatrix}. 956 * <p> 957 * This method is called to solve systems of equations which are 958 * of the lower triangular form. The matrix {@link RealMatrix} 959 * is assumed, though not checked, to be in upper triangular form. 960 * The vector {@link RealVector} is overwritten with the solution. 961 * The matrix is checked that it is square and its dimensions match 962 * the length of the vector. 963 * </p> 964 * @param rm RealMatrix which is upper triangular 965 * @param b RealVector this is overwritten 966 * @throws DimensionMismatchException if the matrix and vector are not 967 * conformable 968 * @throws NonSquareMatrixException if the matrix {@code rm} is not 969 * square 970 * @throws MathArithmeticException if the absolute value of one of the diagonal 971 * coefficient of {@code rm} is lower than {@link Precision#SAFE_MIN} 972 */ 973 public static void solveUpperTriangularSystem(RealMatrix rm, RealVector b) 974 throws DimensionMismatchException, MathArithmeticException, 975 NonSquareMatrixException { 976 if ((rm == null) || (b == null) || ( rm.getRowDimension() != b.getDimension())) { 977 throw new DimensionMismatchException( 978 (rm == null) ? 0 : rm.getRowDimension(), 979 (b == null) ? 0 : b.getDimension()); 980 } 981 if( rm.getColumnDimension() != rm.getRowDimension() ){ 982 throw new NonSquareMatrixException(rm.getRowDimension(), 983 rm.getColumnDimension()); 984 } 985 int rows = rm.getRowDimension(); 986 for( int i = rows-1 ; i >-1 ; i-- ){ 987 double diag = rm.getEntry(i, i); 988 if( FastMath.abs(diag) < Precision.SAFE_MIN ){ 989 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR); 990 } 991 double bi = b.getEntry(i)/diag; 992 b.setEntry(i, bi ); 993 for( int j = i-1; j>-1; j-- ){ 994 b.setEntry(j, b.getEntry(j)-bi*rm.getEntry(j,i) ); 995 } 996 } 997 } 998 999 /** 1000 * Computes the inverse of the given matrix by splitting it into 1001 * 4 sub-matrices. 1002 * 1003 * @param m Matrix whose inverse must be computed. 1004 * @param splitIndex Index that determines the "split" line and 1005 * column. 1006 * The element corresponding to this index will part of the 1007 * upper-left sub-matrix. 1008 * @return the inverse of {@code m}. 1009 * @throws NonSquareMatrixException if {@code m} is not square. 1010 */ 1011 public static RealMatrix blockInverse(RealMatrix m, 1012 int splitIndex) { 1013 final int n = m.getRowDimension(); 1014 if (m.getColumnDimension() != n) { 1015 throw new NonSquareMatrixException(m.getRowDimension(), 1016 m.getColumnDimension()); 1017 } 1018 1019 final int splitIndex1 = splitIndex + 1; 1020 1021 final RealMatrix a = m.getSubMatrix(0, splitIndex, 0, splitIndex); 1022 final RealMatrix b = m.getSubMatrix(0, splitIndex, splitIndex1, n - 1); 1023 final RealMatrix c = m.getSubMatrix(splitIndex1, n - 1, 0, splitIndex); 1024 final RealMatrix d = m.getSubMatrix(splitIndex1, n - 1, splitIndex1, n - 1); 1025 1026 final SingularValueDecomposition aDec = new SingularValueDecomposition(a); 1027 final RealMatrix aInv = aDec.getSolver().getInverse(); 1028 1029 final SingularValueDecomposition dDec = new SingularValueDecomposition(d); 1030 final RealMatrix dInv = dDec.getSolver().getInverse(); 1031 1032 final RealMatrix tmp1 = a.subtract(b.multiply(dInv).multiply(c)); 1033 final SingularValueDecomposition tmp1Dec = new SingularValueDecomposition(tmp1); 1034 final RealMatrix result00 = tmp1Dec.getSolver().getInverse(); 1035 1036 final RealMatrix tmp2 = d.subtract(c.multiply(aInv).multiply(b)); 1037 final SingularValueDecomposition tmp2Dec = new SingularValueDecomposition(tmp2); 1038 final RealMatrix result11 = tmp2Dec.getSolver().getInverse(); 1039 1040 final RealMatrix result01 = aInv.multiply(b).multiply(result11).scalarMultiply(-1); 1041 final RealMatrix result10 = dInv.multiply(c).multiply(result00).scalarMultiply(-1); 1042 1043 final RealMatrix result = new Array2DRowRealMatrix(n, n); 1044 result.setSubMatrix(result00.getData(), 0, 0); 1045 result.setSubMatrix(result01.getData(), 0, splitIndex1); 1046 result.setSubMatrix(result10.getData(), splitIndex1, 0); 1047 result.setSubMatrix(result11.getData(), splitIndex1, splitIndex1); 1048 1049 return result; 1050 } 1051 }