Array2DRowRealMatrix.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.linear;

  18. import java.io.Serializable;

  19. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  20. import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
  21. import org.apache.commons.math4.legacy.exception.NoDataException;
  22. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  23. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;

  24. /**
  25.  * Implementation of {@link RealMatrix} using a {@code double[][]} array to
  26.  * store entries.
  27.  */
  28. public class Array2DRowRealMatrix extends AbstractRealMatrix implements Serializable {
  29.     /** Serializable version identifier. */
  30.     private static final long serialVersionUID = -1067294169172445528L;

  31.     /** Entries of the matrix. */
  32.     private double[][] data;

  33.     /**
  34.      * Creates a matrix with no data.
  35.      */
  36.     public Array2DRowRealMatrix() {}

  37.     /**
  38.      * Create a new RealMatrix with the supplied row and column dimensions.
  39.      *
  40.      * @param rowDimension Number of rows in the new matrix.
  41.      * @param columnDimension Number of columns in the new matrix.
  42.      * @throws org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException
  43.      * if the row or column dimension is not positive.
  44.      */
  45.     public Array2DRowRealMatrix(final int rowDimension,
  46.                                 final int columnDimension) {
  47.         super(rowDimension, columnDimension);
  48.         data = new double[rowDimension][columnDimension];
  49.     }

  50.     /**
  51.      * Create a new {@code RealMatrix} using the input array as the underlying
  52.      * data array.
  53.      * <p>The input array is copied, not referenced. This constructor has
  54.      * the same effect as calling {@link #Array2DRowRealMatrix(double[][], boolean)}
  55.      * with the second argument set to {@code true}.</p>
  56.      *
  57.      * @param d Data for the new matrix.
  58.      * @throws DimensionMismatchException if {@code d} is not rectangular.
  59.      * @throws NoDataException if {@code d} row or column dimension is zero.
  60.      * @throws NullArgumentException if {@code d} is {@code null}.
  61.      * @see #Array2DRowRealMatrix(double[][], boolean)
  62.      */
  63.     public Array2DRowRealMatrix(final double[][] d) {
  64.         copyIn(d);
  65.     }

  66.     /**
  67.      * Create a new RealMatrix using the input array as the underlying
  68.      * data array.
  69.      * If an array is built specially in order to be embedded in a
  70.      * RealMatrix and not used directly, the {@code copyArray} may be
  71.      * set to {@code false}. This will prevent the copying and improve
  72.      * performance as no new array will be built and no data will be copied.
  73.      *
  74.      * @param d Data for new matrix.
  75.      * @param copyArray if {@code true}, the input array will be copied,
  76.      * otherwise it will be referenced.
  77.      * @throws DimensionMismatchException if {@code d} is not rectangular.
  78.      * @throws NoDataException if {@code d} row or column dimension is zero.
  79.      * @throws NullArgumentException if {@code d} is {@code null}.
  80.      * @see #Array2DRowRealMatrix(double[][])
  81.      */
  82.     public Array2DRowRealMatrix(final double[][] d, final boolean copyArray) {
  83.         if (copyArray) {
  84.             copyIn(d);
  85.         } else {
  86.             if (d == null) {
  87.                 throw new NullArgumentException();
  88.             }
  89.             final int nRows = d.length;
  90.             if (nRows == 0) {
  91.                 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_ROW);
  92.             }
  93.             final int nCols = d[0].length;
  94.             if (nCols == 0) {
  95.                 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_COLUMN);
  96.             }
  97.             for (int r = 1; r < nRows; r++) {
  98.                 if (d[r].length != nCols) {
  99.                     throw new DimensionMismatchException(d[r].length, nCols);
  100.                 }
  101.             }
  102.             data = d;
  103.         }
  104.     }

  105.     /**
  106.      * Create a new (column) RealMatrix using {@code v} as the
  107.      * data for the unique column of the created matrix.
  108.      * The input array is copied.
  109.      *
  110.      * @param v Column vector holding data for new matrix.
  111.      */
  112.     public Array2DRowRealMatrix(final double[] v) {
  113.         final int nRows = v.length;
  114.         data = new double[nRows][1];
  115.         for (int row = 0; row < nRows; row++) {
  116.             data[row][0] = v[row];
  117.         }
  118.     }

  119.     /** {@inheritDoc} */
  120.     @Override
  121.     public RealMatrix createMatrix(final int rowDimension,
  122.                                    final int columnDimension) {
  123.         return new Array2DRowRealMatrix(rowDimension, columnDimension);
  124.     }

  125.     /** {@inheritDoc} */
  126.     @Override
  127.     public RealMatrix copy() {
  128.         return new Array2DRowRealMatrix(copyOut(), false);
  129.     }

  130.     /**
  131.      * Compute the sum of {@code this} and {@code m}.
  132.      *
  133.      * @param m Matrix to be added.
  134.      * @return {@code this + m}.
  135.      * @throws MatrixDimensionMismatchException if {@code m} is not the same
  136.      * size as {@code this}.
  137.      */
  138.     public Array2DRowRealMatrix add(final Array2DRowRealMatrix m) {
  139.         // Safety check.
  140.         checkAdd(m);

  141.         final int rowCount    = getRowDimension();
  142.         final int columnCount = getColumnDimension();
  143.         final double[][] outData = new double[rowCount][columnCount];
  144.         for (int row = 0; row < rowCount; row++) {
  145.             final double[] dataRow    = data[row];
  146.             final double[] mRow       = m.data[row];
  147.             final double[] outDataRow = outData[row];
  148.             for (int col = 0; col < columnCount; col++) {
  149.                 outDataRow[col] = dataRow[col] + mRow[col];
  150.             }
  151.         }

  152.         return new Array2DRowRealMatrix(outData, false);
  153.     }

  154.     /**
  155.      * Returns {@code this} minus {@code m}.
  156.      *
  157.      * @param m Matrix to be subtracted.
  158.      * @return {@code this - m}
  159.      * @throws MatrixDimensionMismatchException if {@code m} is not the same
  160.      * size as {@code this}.
  161.      */
  162.     public Array2DRowRealMatrix subtract(final Array2DRowRealMatrix m) {
  163.         checkAdd(m);

  164.         final int rowCount    = getRowDimension();
  165.         final int columnCount = getColumnDimension();
  166.         final double[][] outData = new double[rowCount][columnCount];
  167.         for (int row = 0; row < rowCount; row++) {
  168.             final double[] dataRow    = data[row];
  169.             final double[] mRow       = m.data[row];
  170.             final double[] outDataRow = outData[row];
  171.             for (int col = 0; col < columnCount; col++) {
  172.                 outDataRow[col] = dataRow[col] - mRow[col];
  173.             }
  174.         }

  175.         return new Array2DRowRealMatrix(outData, false);
  176.     }

  177.     /**
  178.      * Returns the result of postmultiplying {@code this} by {@code m}.
  179.      *
  180.      * @param m matrix to postmultiply by
  181.      * @return {@code this * m}
  182.      * @throws DimensionMismatchException if
  183.      * {@code columnDimension(this) != rowDimension(m)}
  184.      */
  185.     public Array2DRowRealMatrix multiply(final Array2DRowRealMatrix m) {
  186.         checkMultiply(m);

  187.         final int nRows = this.getRowDimension();
  188.         final int nCols = m.getColumnDimension();
  189.         final int nSum = this.getColumnDimension();

  190.         final double[][] outData = new double[nRows][nCols];
  191.         // Will hold a column of "m".
  192.         final double[] mCol = new double[nSum];
  193.         final double[][] mData = m.data;

  194.         // Multiply.
  195.         for (int col = 0; col < nCols; col++) {
  196.             // Copy all elements of column "col" of "m" so that
  197.             // will be in contiguous memory.
  198.             for (int mRow = 0; mRow < nSum; mRow++) {
  199.                 mCol[mRow] = mData[mRow][col];
  200.             }

  201.             for (int row = 0; row < nRows; row++) {
  202.                 final double[] dataRow = data[row];
  203.                 double sum = 0;
  204.                 for (int i = 0; i < nSum; i++) {
  205.                     sum += dataRow[i] * mCol[i];
  206.                 }
  207.                 outData[row][col] = sum;
  208.             }
  209.         }

  210.         return new Array2DRowRealMatrix(outData, false);
  211.     }

  212.     /** {@inheritDoc} */
  213.     @Override
  214.     public double[][] getData() {
  215.         return copyOut();
  216.     }

  217.     /**
  218.      * Get a reference to the underlying data array.
  219.      *
  220.      * @return 2-dimensional array of entries.
  221.      */
  222.     public double[][] getDataRef() {
  223.         return data;
  224.     }

  225.     /** {@inheritDoc} */
  226.     @Override
  227.     public void setSubMatrix(final double[][] subMatrix, final int row,
  228.                              final int column) {
  229.         if (data == null) {
  230.             if (row > 0) {
  231.                 throw new MathIllegalStateException(LocalizedFormats.FIRST_ROWS_NOT_INITIALIZED_YET, row);
  232.             }
  233.             if (column > 0) {
  234.                 throw new MathIllegalStateException(LocalizedFormats.FIRST_COLUMNS_NOT_INITIALIZED_YET, column);
  235.             }
  236.             NullArgumentException.check(subMatrix);
  237.             final int nRows = subMatrix.length;
  238.             if (nRows == 0) {
  239.                 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_ROW);
  240.             }

  241.             final int nCols = subMatrix[0].length;
  242.             if (nCols == 0) {
  243.                 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_COLUMN);
  244.             }
  245.             data = new double[subMatrix.length][nCols];
  246.             for (int i = 0; i < data.length; ++i) {
  247.                 if (subMatrix[i].length != nCols) {
  248.                     throw new DimensionMismatchException(subMatrix[i].length, nCols);
  249.                 }
  250.                 System.arraycopy(subMatrix[i], 0, data[i + row], column, nCols);
  251.             }
  252.         } else {
  253.             super.setSubMatrix(subMatrix, row, column);
  254.         }
  255.     }

  256.     /** {@inheritDoc} */
  257.     @Override
  258.     public double getEntry(final int row, final int column) {
  259.         try {
  260.             return data[row][column];
  261.         } catch (IndexOutOfBoundsException e) {
  262.             // throw the exact cause of the exception
  263.             MatrixUtils.checkMatrixIndex(this, row, column);
  264.             // should never happen
  265.             throw e;
  266.         }
  267.     }

  268.     /** {@inheritDoc} */
  269.     @Override
  270.     public void setEntry(final int row, final int column, final double value) {
  271.         MatrixUtils.checkMatrixIndex(this, row, column);
  272.         data[row][column] = value;
  273.     }

  274.     /** {@inheritDoc} */
  275.     @Override
  276.     public void addToEntry(final int row, final int column,
  277.                            final double increment) {
  278.         MatrixUtils.checkMatrixIndex(this, row, column);
  279.         data[row][column] += increment;
  280.     }

  281.     /** {@inheritDoc} */
  282.     @Override
  283.     public void multiplyEntry(final int row, final int column,
  284.                               final double factor) {
  285.         MatrixUtils.checkMatrixIndex(this, row, column);
  286.         data[row][column] *= factor;
  287.     }

  288.     /** {@inheritDoc} */
  289.     @Override
  290.     public int getRowDimension() {
  291.         return (data == null) ? 0 : data.length;
  292.     }

  293.     /** {@inheritDoc} */
  294.     @Override
  295.     public int getColumnDimension() {
  296.         return (data == null || data[0] == null) ? 0 : data[0].length;
  297.     }

  298.     /** {@inheritDoc} */
  299.     @Override
  300.     public double[] operate(final double[] v) {
  301.         final int nRows = this.getRowDimension();
  302.         final int nCols = this.getColumnDimension();
  303.         if (v.length != nCols) {
  304.             throw new DimensionMismatchException(v.length, nCols);
  305.         }
  306.         final double[] out = new double[nRows];
  307.         for (int row = 0; row < nRows; row++) {
  308.             final double[] dataRow = data[row];
  309.             double sum = 0;
  310.             for (int i = 0; i < nCols; i++) {
  311.                 sum += dataRow[i] * v[i];
  312.             }
  313.             out[row] = sum;
  314.         }
  315.         return out;
  316.     }

  317.     /** {@inheritDoc} */
  318.     @Override
  319.     public double[] preMultiply(final double[] v) {
  320.         final int nRows = getRowDimension();
  321.         final int nCols = getColumnDimension();
  322.         if (v.length != nRows) {
  323.             throw new DimensionMismatchException(v.length, nRows);
  324.         }

  325.         final double[] out = new double[nCols];
  326.         for (int col = 0; col < nCols; ++col) {
  327.             double sum = 0;
  328.             for (int i = 0; i < nRows; ++i) {
  329.                 sum += data[i][col] * v[i];
  330.             }
  331.             out[col] = sum;
  332.         }

  333.         return out;
  334.     }

  335.     /** {@inheritDoc} */
  336.     @Override
  337.     public RealMatrix getSubMatrix(final int startRow, final int endRow,
  338.                                    final int startColumn, final int endColumn) {
  339.         MatrixUtils.checkSubMatrixIndex(this, startRow, endRow, startColumn, endColumn);
  340.         final int rowCount = endRow - startRow + 1;
  341.         final int columnCount = endColumn - startColumn + 1;
  342.         final double[][] outData = new double[rowCount][columnCount];
  343.         for (int i = 0; i < rowCount; ++i) {
  344.             System.arraycopy(data[startRow + i], startColumn, outData[i], 0, columnCount);
  345.         }

  346.         Array2DRowRealMatrix subMatrix = new Array2DRowRealMatrix();
  347.         subMatrix.data = outData;
  348.         return subMatrix;
  349.     }

  350.     /** {@inheritDoc} */
  351.     @Override
  352.     public double walkInRowOrder(final RealMatrixChangingVisitor visitor) {
  353.         final int rows    = getRowDimension();
  354.         final int columns = getColumnDimension();
  355.         visitor.start(rows, columns, 0, rows - 1, 0, columns - 1);
  356.         for (int i = 0; i < rows; ++i) {
  357.             final double[] rowI = data[i];
  358.             for (int j = 0; j < columns; ++j) {
  359.                 rowI[j] = visitor.visit(i, j, rowI[j]);
  360.             }
  361.         }
  362.         return visitor.end();
  363.     }

  364.     /** {@inheritDoc} */
  365.     @Override
  366.     public double walkInRowOrder(final RealMatrixPreservingVisitor visitor) {
  367.         final int rows    = getRowDimension();
  368.         final int columns = getColumnDimension();
  369.         visitor.start(rows, columns, 0, rows - 1, 0, columns - 1);
  370.         for (int i = 0; i < rows; ++i) {
  371.             final double[] rowI = data[i];
  372.             for (int j = 0; j < columns; ++j) {
  373.                 visitor.visit(i, j, rowI[j]);
  374.             }
  375.         }
  376.         return visitor.end();
  377.     }

  378.     /** {@inheritDoc} */
  379.     @Override
  380.     public double walkInRowOrder(final RealMatrixChangingVisitor visitor,
  381.                                  final int startRow, final int endRow,
  382.                                  final int startColumn, final int endColumn) {
  383.         MatrixUtils.checkSubMatrixIndex(this, startRow, endRow, startColumn, endColumn);
  384.         visitor.start(getRowDimension(), getColumnDimension(),
  385.                       startRow, endRow, startColumn, endColumn);
  386.         for (int i = startRow; i <= endRow; ++i) {
  387.             final double[] rowI = data[i];
  388.             for (int j = startColumn; j <= endColumn; ++j) {
  389.                 rowI[j] = visitor.visit(i, j, rowI[j]);
  390.             }
  391.         }
  392.         return visitor.end();
  393.     }

  394.     /** {@inheritDoc} */
  395.     @Override
  396.     public double walkInRowOrder(final RealMatrixPreservingVisitor visitor,
  397.                                  final int startRow, final int endRow,
  398.                                  final int startColumn, final int endColumn) {
  399.         MatrixUtils.checkSubMatrixIndex(this, startRow, endRow, startColumn, endColumn);
  400.         visitor.start(getRowDimension(), getColumnDimension(),
  401.                       startRow, endRow, startColumn, endColumn);
  402.         for (int i = startRow; i <= endRow; ++i) {
  403.             final double[] rowI = data[i];
  404.             for (int j = startColumn; j <= endColumn; ++j) {
  405.                 visitor.visit(i, j, rowI[j]);
  406.             }
  407.         }
  408.         return visitor.end();
  409.     }

  410.     /** {@inheritDoc} */
  411.     @Override
  412.     public double walkInColumnOrder(final RealMatrixChangingVisitor visitor) {
  413.         final int rows    = getRowDimension();
  414.         final int columns = getColumnDimension();
  415.         visitor.start(rows, columns, 0, rows - 1, 0, columns - 1);
  416.         for (int j = 0; j < columns; ++j) {
  417.             for (int i = 0; i < rows; ++i) {
  418.                 final double[] rowI = data[i];
  419.                 rowI[j] = visitor.visit(i, j, rowI[j]);
  420.             }
  421.         }
  422.         return visitor.end();
  423.     }

  424.     /** {@inheritDoc} */
  425.     @Override
  426.     public double walkInColumnOrder(final RealMatrixPreservingVisitor visitor) {
  427.         final int rows    = getRowDimension();
  428.         final int columns = getColumnDimension();
  429.         visitor.start(rows, columns, 0, rows - 1, 0, columns - 1);
  430.         for (int j = 0; j < columns; ++j) {
  431.             for (int i = 0; i < rows; ++i) {
  432.                 visitor.visit(i, j, data[i][j]);
  433.             }
  434.         }
  435.         return visitor.end();
  436.     }

  437.     /** {@inheritDoc} */
  438.     @Override
  439.     public double walkInColumnOrder(final RealMatrixChangingVisitor visitor,
  440.                                     final int startRow, final int endRow,
  441.                                     final int startColumn, final int endColumn) {
  442.         MatrixUtils.checkSubMatrixIndex(this, startRow, endRow, startColumn, endColumn);
  443.         visitor.start(getRowDimension(), getColumnDimension(),
  444.                       startRow, endRow, startColumn, endColumn);
  445.         for (int j = startColumn; j <= endColumn; ++j) {
  446.             for (int i = startRow; i <= endRow; ++i) {
  447.                 final double[] rowI = data[i];
  448.                 rowI[j] = visitor.visit(i, j, rowI[j]);
  449.             }
  450.         }
  451.         return visitor.end();
  452.     }

  453.     /** {@inheritDoc} */
  454.     @Override
  455.     public double walkInColumnOrder(final RealMatrixPreservingVisitor visitor,
  456.                                     final int startRow, final int endRow,
  457.                                     final int startColumn, final int endColumn) {
  458.         MatrixUtils.checkSubMatrixIndex(this, startRow, endRow, startColumn, endColumn);
  459.         visitor.start(getRowDimension(), getColumnDimension(),
  460.                       startRow, endRow, startColumn, endColumn);
  461.         for (int j = startColumn; j <= endColumn; ++j) {
  462.             for (int i = startRow; i <= endRow; ++i) {
  463.                 visitor.visit(i, j, data[i][j]);
  464.             }
  465.         }
  466.         return visitor.end();
  467.     }

  468.     /**
  469.      * Get a fresh copy of the underlying data array.
  470.      *
  471.      * @return a copy of the underlying data array.
  472.      */
  473.     private double[][] copyOut() {
  474.         final int nRows = this.getRowDimension();
  475.         final double[][] out = new double[nRows][this.getColumnDimension()];
  476.         // can't copy 2-d array in one shot, otherwise get row references
  477.         for (int i = 0; i < nRows; i++) {
  478.             System.arraycopy(data[i], 0, out[i], 0, data[i].length);
  479.         }
  480.         return out;
  481.     }

  482.     /**
  483.      * Replace data with a fresh copy of the input array.
  484.      *
  485.      * @param in Data to copy.
  486.      * @throws NoDataException if the input array is empty.
  487.      * @throws DimensionMismatchException if the input array is not rectangular.
  488.      * @throws NullArgumentException if the input array is {@code null}.
  489.      */
  490.     private void copyIn(final double[][] in) {
  491.         setSubMatrix(in, 0, 0);
  492.     }

  493.     /** {@inheritDoc} */
  494.     @Override
  495.     public double[] getRow(final int row) {
  496.         MatrixUtils.checkRowIndex(this, row);
  497.         final int nCols = getColumnDimension();
  498.         final double[] out = new double[nCols];
  499.         System.arraycopy(data[row], 0, out, 0, nCols);
  500.         return out;
  501.     }

  502.     /** {@inheritDoc} */
  503.     @Override
  504.     public void setRow(final int row,
  505.                        final double[] array) {
  506.         MatrixUtils.checkRowIndex(this, row);
  507.         final int nCols = getColumnDimension();
  508.         if (array.length != nCols) {
  509.                 throw new MatrixDimensionMismatchException(1, array.length, 1, nCols);
  510.         }
  511.         System.arraycopy(array, 0, data[row], 0, nCols);
  512.     }
  513. }