DiagonalMatrix.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.linear;

  18. import java.io.Serializable;

  19. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  20. import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
  21. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  22. import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
  23. import org.apache.commons.math4.legacy.exception.OutOfRangeException;
  24. import org.apache.commons.math4.core.jdkmath.JdkMath;
  25. import org.apache.commons.numbers.core.Precision;

  26. /**
  27.  * Implementation of a diagonal matrix.
  28.  *
  29.  * @since 3.1.1
  30.  */
  31. public class DiagonalMatrix extends AbstractRealMatrix
  32.     implements Serializable {
  33.     /** Serializable version identifier. */
  34.     private static final long serialVersionUID = 20121229L;
  35.     /** Entries of the diagonal. */
  36.     private final double[] data;

  37.     /**
  38.      * Creates a matrix with the supplied dimension.
  39.      *
  40.      * @param dimension Number of rows and columns in the new matrix.
  41.      * @throws NotStrictlyPositiveException if the dimension is
  42.      * not positive.
  43.      */
  44.     public DiagonalMatrix(final int dimension)
  45.         throws NotStrictlyPositiveException {
  46.         super(dimension, dimension);
  47.         data = new double[dimension];
  48.     }

  49.     /**
  50.      * Creates a matrix using the input array as the underlying data.
  51.      * <br>
  52.      * The input array is copied, not referenced.
  53.      *
  54.      * @param d Data for the new matrix.
  55.      */
  56.     public DiagonalMatrix(final double[] d) {
  57.         this(d, true);
  58.     }

  59.     /**
  60.      * Creates a matrix using the input array as the underlying data.
  61.      * <br>
  62.      * If an array is created specially in order to be embedded in a
  63.      * this instance and not used directly, the {@code copyArray} may be
  64.      * set to {@code false}.
  65.      * This will prevent the copying and improve performance as no new
  66.      * array will be built and no data will be copied.
  67.      *
  68.      * @param d Data for new matrix.
  69.      * @param copyArray if {@code true}, the input array will be copied,
  70.      * otherwise it will be referenced.
  71.      * @exception NullArgumentException if d is null
  72.      */
  73.     public DiagonalMatrix(final double[] d, final boolean copyArray)
  74.         throws NullArgumentException {
  75.         NullArgumentException.check(d);
  76.         data = copyArray ? d.clone() : d;
  77.     }

  78.     /**
  79.      * {@inheritDoc}
  80.      *
  81.      * @throws DimensionMismatchException if the requested dimensions are not equal.
  82.      */
  83.     @Override
  84.     public RealMatrix createMatrix(final int rowDimension,
  85.                                    final int columnDimension)
  86.         throws NotStrictlyPositiveException,
  87.                DimensionMismatchException {
  88.         if (rowDimension != columnDimension) {
  89.             throw new DimensionMismatchException(rowDimension, columnDimension);
  90.         }

  91.         return new DiagonalMatrix(rowDimension);
  92.     }

  93.     /** {@inheritDoc} */
  94.     @Override
  95.     public RealMatrix copy() {
  96.         return new DiagonalMatrix(data);
  97.     }

  98.     /**
  99.      * Compute the sum of {@code this} and {@code m}.
  100.      *
  101.      * @param m Matrix to be added.
  102.      * @return {@code this + m}.
  103.      * @throws MatrixDimensionMismatchException if {@code m} is not the same
  104.      * size as {@code this}.
  105.      */
  106.     public DiagonalMatrix add(final DiagonalMatrix m)
  107.         throws MatrixDimensionMismatchException {
  108.         // Safety check.
  109.         MatrixUtils.checkAdditionCompatible(this, m);

  110.         final int dim = getRowDimension();
  111.         final double[] outData = new double[dim];
  112.         for (int i = 0; i < dim; i++) {
  113.             outData[i] = data[i] + m.data[i];
  114.         }

  115.         return new DiagonalMatrix(outData, false);
  116.     }

  117.     /**
  118.      * Returns {@code this} minus {@code m}.
  119.      *
  120.      * @param m Matrix to be subtracted.
  121.      * @return {@code this - m}
  122.      * @throws MatrixDimensionMismatchException if {@code m} is not the same
  123.      * size as {@code this}.
  124.      */
  125.     public DiagonalMatrix subtract(final DiagonalMatrix m)
  126.         throws MatrixDimensionMismatchException {
  127.         MatrixUtils.checkSubtractionCompatible(this, m);

  128.         final int dim = getRowDimension();
  129.         final double[] outData = new double[dim];
  130.         for (int i = 0; i < dim; i++) {
  131.             outData[i] = data[i] - m.data[i];
  132.         }

  133.         return new DiagonalMatrix(outData, false);
  134.     }

  135.     /**
  136.      * Returns the result of postmultiplying {@code this} by {@code m}.
  137.      *
  138.      * @param m matrix to postmultiply by
  139.      * @return {@code this * m}
  140.      * @throws DimensionMismatchException if
  141.      * {@code columnDimension(this) != rowDimension(m)}
  142.      */
  143.     public DiagonalMatrix multiply(final DiagonalMatrix m)
  144.         throws DimensionMismatchException {
  145.         MatrixUtils.checkMultiplicationCompatible(this, m);

  146.         final int dim = getRowDimension();
  147.         final double[] outData = new double[dim];
  148.         for (int i = 0; i < dim; i++) {
  149.             outData[i] = data[i] * m.data[i];
  150.         }

  151.         return new DiagonalMatrix(outData, false);
  152.     }

  153.     /**
  154.      * Returns the result of postmultiplying {@code this} by {@code m}.
  155.      *
  156.      * @param m matrix to postmultiply by
  157.      * @return {@code this * m}
  158.      * @throws DimensionMismatchException if
  159.      * {@code columnDimension(this) != rowDimension(m)}
  160.      */
  161.     @Override
  162.     public RealMatrix multiply(final RealMatrix m)
  163.         throws DimensionMismatchException {
  164.         if (m instanceof DiagonalMatrix) {
  165.             return multiply((DiagonalMatrix) m);
  166.         } else {
  167.             MatrixUtils.checkMultiplicationCompatible(this, m);
  168.             final int nRows = m.getRowDimension();
  169.             final int nCols = m.getColumnDimension();
  170.             final double[][] product = new double[nRows][nCols];
  171.             for (int r = 0; r < nRows; r++) {
  172.                 for (int c = 0; c < nCols; c++) {
  173.                     product[r][c] = data[r] * m.getEntry(r, c);
  174.                 }
  175.             }
  176.             return new Array2DRowRealMatrix(product, false);
  177.         }
  178.     }

  179.     /** {@inheritDoc} */
  180.     @Override
  181.     public double[][] getData() {
  182.         final int dim = getRowDimension();
  183.         final double[][] out = new double[dim][dim];

  184.         for (int i = 0; i < dim; i++) {
  185.             out[i][i] = data[i];
  186.         }

  187.         return out;
  188.     }

  189.     /**
  190.      * Gets a reference to the underlying data array.
  191.      *
  192.      * @return 1-dimensional array of entries.
  193.      */
  194.     public double[] getDataRef() {
  195.         return data;
  196.     }

  197.     /** {@inheritDoc} */
  198.     @Override
  199.     public double getEntry(final int row, final int column)
  200.         throws OutOfRangeException {
  201.         MatrixUtils.checkMatrixIndex(this, row, column);
  202.         return row == column ? data[row] : 0;
  203.     }

  204.     /** {@inheritDoc}
  205.      * @throws NumberIsTooLargeException if {@code row != column} and value is non-zero.
  206.      */
  207.     @Override
  208.     public void setEntry(final int row, final int column, final double value)
  209.         throws OutOfRangeException, NumberIsTooLargeException {
  210.         if (row == column) {
  211.             MatrixUtils.checkRowIndex(this, row);
  212.             data[row] = value;
  213.         } else {
  214.             ensureZero(value);
  215.         }
  216.     }

  217.     /** {@inheritDoc}
  218.      * @throws NumberIsTooLargeException if {@code row != column} and increment is non-zero.
  219.      */
  220.     @Override
  221.     public void addToEntry(final int row,
  222.                            final int column,
  223.                            final double increment)
  224.         throws OutOfRangeException, NumberIsTooLargeException {
  225.         if (row == column) {
  226.             MatrixUtils.checkRowIndex(this, row);
  227.             data[row] += increment;
  228.         } else {
  229.             ensureZero(increment);
  230.         }
  231.     }

  232.     /** {@inheritDoc} */
  233.     @Override
  234.     public void multiplyEntry(final int row,
  235.                               final int column,
  236.                               final double factor)
  237.         throws OutOfRangeException {
  238.         // we don't care about non-diagonal elements for multiplication
  239.         if (row == column) {
  240.             MatrixUtils.checkRowIndex(this, row);
  241.             data[row] *= factor;
  242.         }
  243.     }

  244.     /** {@inheritDoc} */
  245.     @Override
  246.     public int getRowDimension() {
  247.         return data.length;
  248.     }

  249.     /** {@inheritDoc} */
  250.     @Override
  251.     public int getColumnDimension() {
  252.         return data.length;
  253.     }

  254.     /** {@inheritDoc} */
  255.     @Override
  256.     public double[] operate(final double[] v)
  257.         throws DimensionMismatchException {
  258.         return multiply(new DiagonalMatrix(v, false)).getDataRef();
  259.     }

  260.     /** {@inheritDoc} */
  261.     @Override
  262.     public double[] preMultiply(final double[] v)
  263.         throws DimensionMismatchException {
  264.         return operate(v);
  265.     }

  266.     /** {@inheritDoc} */
  267.     @Override
  268.     public RealVector preMultiply(final RealVector v) throws DimensionMismatchException {
  269.         final double[] vectorData;
  270.         if (v instanceof ArrayRealVector) {
  271.             vectorData = ((ArrayRealVector) v).getDataRef();
  272.         } else {
  273.             vectorData = v.toArray();
  274.         }
  275.         return MatrixUtils.createRealVector(preMultiply(vectorData));
  276.     }

  277.     /** Ensure a value is zero.
  278.      * @param value value to check
  279.      * @exception NumberIsTooLargeException if value is not zero
  280.      */
  281.     private void ensureZero(final double value) throws NumberIsTooLargeException {
  282.         if (!Precision.equals(0.0, value, 1)) {
  283.             throw new NumberIsTooLargeException(JdkMath.abs(value), 0, true);
  284.         }
  285.     }

  286.     /**
  287.      * Computes the inverse of this diagonal matrix.
  288.      * <p>
  289.      * Note: this method will use a singularity threshold of 0,
  290.      * use {@link #inverse(double)} if a different threshold is needed.
  291.      *
  292.      * @return the inverse of {@code m}
  293.      * @throws SingularMatrixException if the matrix is singular
  294.      * @since 3.3
  295.      */
  296.     public DiagonalMatrix inverse() throws SingularMatrixException {
  297.         return inverse(0);
  298.     }

  299.     /**
  300.      * Computes the inverse of this diagonal matrix.
  301.      *
  302.      * @param threshold Singularity threshold.
  303.      * @return the inverse of {@code m}
  304.      * @throws SingularMatrixException if the matrix is singular
  305.      * @since 3.3
  306.      */
  307.     public DiagonalMatrix inverse(double threshold) throws SingularMatrixException {
  308.         if (isSingular(threshold)) {
  309.             throw new SingularMatrixException();
  310.         }

  311.         final double[] result = new double[data.length];
  312.         for (int i = 0; i < data.length; i++) {
  313.             result[i] = 1.0 / data[i];
  314.         }
  315.         return new DiagonalMatrix(result, false);
  316.     }

  317.     /** Returns whether this diagonal matrix is singular, i.e. any diagonal entry
  318.      * is equal to {@code 0} within the given threshold.
  319.      *
  320.      * @param threshold Singularity threshold.
  321.      * @return {@code true} if the matrix is singular, {@code false} otherwise
  322.      * @since 3.3
  323.      */
  324.     public boolean isSingular(double threshold) {
  325.         for (int i = 0; i < data.length; i++) {
  326.             if (Precision.equals(data[i], 0.0, threshold)) {
  327.                 return true;
  328.             }
  329.         }
  330.         return false;
  331.     }
  332. }