001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.linear;
018
019import java.io.Serializable;
020
021import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
022import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
023import org.apache.commons.math4.legacy.exception.NullArgumentException;
024import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
025import org.apache.commons.math4.legacy.exception.OutOfRangeException;
026import org.apache.commons.math4.core.jdkmath.JdkMath;
027import org.apache.commons.numbers.core.Precision;
028
029/**
030 * Implementation of a diagonal matrix.
031 *
032 * @since 3.1.1
033 */
034public class DiagonalMatrix extends AbstractRealMatrix
035    implements Serializable {
036    /** Serializable version identifier. */
037    private static final long serialVersionUID = 20121229L;
038    /** Entries of the diagonal. */
039    private final double[] data;
040
041    /**
042     * Creates a matrix with the supplied dimension.
043     *
044     * @param dimension Number of rows and columns in the new matrix.
045     * @throws NotStrictlyPositiveException if the dimension is
046     * not positive.
047     */
048    public DiagonalMatrix(final int dimension)
049        throws NotStrictlyPositiveException {
050        super(dimension, dimension);
051        data = new double[dimension];
052    }
053
054    /**
055     * Creates a matrix using the input array as the underlying data.
056     * <br>
057     * The input array is copied, not referenced.
058     *
059     * @param d Data for the new matrix.
060     */
061    public DiagonalMatrix(final double[] d) {
062        this(d, true);
063    }
064
065    /**
066     * Creates a matrix using the input array as the underlying data.
067     * <br>
068     * If an array is created specially in order to be embedded in a
069     * this instance and not used directly, the {@code copyArray} may be
070     * set to {@code false}.
071     * This will prevent the copying and improve performance as no new
072     * array will be built and no data will be copied.
073     *
074     * @param d Data for new matrix.
075     * @param copyArray if {@code true}, the input array will be copied,
076     * otherwise it will be referenced.
077     * @exception NullArgumentException if d is null
078     */
079    public DiagonalMatrix(final double[] d, final boolean copyArray)
080        throws NullArgumentException {
081        NullArgumentException.check(d);
082        data = copyArray ? d.clone() : d;
083    }
084
085    /**
086     * {@inheritDoc}
087     *
088     * @throws DimensionMismatchException if the requested dimensions are not equal.
089     */
090    @Override
091    public RealMatrix createMatrix(final int rowDimension,
092                                   final int columnDimension)
093        throws NotStrictlyPositiveException,
094               DimensionMismatchException {
095        if (rowDimension != columnDimension) {
096            throw new DimensionMismatchException(rowDimension, columnDimension);
097        }
098
099        return new DiagonalMatrix(rowDimension);
100    }
101
102    /** {@inheritDoc} */
103    @Override
104    public RealMatrix copy() {
105        return new DiagonalMatrix(data);
106    }
107
108    /**
109     * Compute the sum of {@code this} and {@code m}.
110     *
111     * @param m Matrix to be added.
112     * @return {@code this + m}.
113     * @throws MatrixDimensionMismatchException if {@code m} is not the same
114     * size as {@code this}.
115     */
116    public DiagonalMatrix add(final DiagonalMatrix m)
117        throws MatrixDimensionMismatchException {
118        // Safety check.
119        MatrixUtils.checkAdditionCompatible(this, m);
120
121        final int dim = getRowDimension();
122        final double[] outData = new double[dim];
123        for (int i = 0; i < dim; i++) {
124            outData[i] = data[i] + m.data[i];
125        }
126
127        return new DiagonalMatrix(outData, false);
128    }
129
130    /**
131     * Returns {@code this} minus {@code m}.
132     *
133     * @param m Matrix to be subtracted.
134     * @return {@code this - m}
135     * @throws MatrixDimensionMismatchException if {@code m} is not the same
136     * size as {@code this}.
137     */
138    public DiagonalMatrix subtract(final DiagonalMatrix m)
139        throws MatrixDimensionMismatchException {
140        MatrixUtils.checkSubtractionCompatible(this, m);
141
142        final int dim = getRowDimension();
143        final double[] outData = new double[dim];
144        for (int i = 0; i < dim; i++) {
145            outData[i] = data[i] - m.data[i];
146        }
147
148        return new DiagonalMatrix(outData, false);
149    }
150
151    /**
152     * Returns the result of postmultiplying {@code this} by {@code m}.
153     *
154     * @param m matrix to postmultiply by
155     * @return {@code this * m}
156     * @throws DimensionMismatchException if
157     * {@code columnDimension(this) != rowDimension(m)}
158     */
159    public DiagonalMatrix multiply(final DiagonalMatrix m)
160        throws DimensionMismatchException {
161        MatrixUtils.checkMultiplicationCompatible(this, m);
162
163        final int dim = getRowDimension();
164        final double[] outData = new double[dim];
165        for (int i = 0; i < dim; i++) {
166            outData[i] = data[i] * m.data[i];
167        }
168
169        return new DiagonalMatrix(outData, false);
170    }
171
172    /**
173     * Returns the result of postmultiplying {@code this} by {@code m}.
174     *
175     * @param m matrix to postmultiply by
176     * @return {@code this * m}
177     * @throws DimensionMismatchException if
178     * {@code columnDimension(this) != rowDimension(m)}
179     */
180    @Override
181    public RealMatrix multiply(final RealMatrix m)
182        throws DimensionMismatchException {
183        if (m instanceof DiagonalMatrix) {
184            return multiply((DiagonalMatrix) m);
185        } else {
186            MatrixUtils.checkMultiplicationCompatible(this, m);
187            final int nRows = m.getRowDimension();
188            final int nCols = m.getColumnDimension();
189            final double[][] product = new double[nRows][nCols];
190            for (int r = 0; r < nRows; r++) {
191                for (int c = 0; c < nCols; c++) {
192                    product[r][c] = data[r] * m.getEntry(r, c);
193                }
194            }
195            return new Array2DRowRealMatrix(product, false);
196        }
197    }
198
199    /** {@inheritDoc} */
200    @Override
201    public double[][] getData() {
202        final int dim = getRowDimension();
203        final double[][] out = new double[dim][dim];
204
205        for (int i = 0; i < dim; i++) {
206            out[i][i] = data[i];
207        }
208
209        return out;
210    }
211
212    /**
213     * Gets a reference to the underlying data array.
214     *
215     * @return 1-dimensional array of entries.
216     */
217    public double[] getDataRef() {
218        return data;
219    }
220
221    /** {@inheritDoc} */
222    @Override
223    public double getEntry(final int row, final int column)
224        throws OutOfRangeException {
225        MatrixUtils.checkMatrixIndex(this, row, column);
226        return row == column ? data[row] : 0;
227    }
228
229    /** {@inheritDoc}
230     * @throws NumberIsTooLargeException if {@code row != column} and value is non-zero.
231     */
232    @Override
233    public void setEntry(final int row, final int column, final double value)
234        throws OutOfRangeException, NumberIsTooLargeException {
235        if (row == column) {
236            MatrixUtils.checkRowIndex(this, row);
237            data[row] = value;
238        } else {
239            ensureZero(value);
240        }
241    }
242
243    /** {@inheritDoc}
244     * @throws NumberIsTooLargeException if {@code row != column} and increment is non-zero.
245     */
246    @Override
247    public void addToEntry(final int row,
248                           final int column,
249                           final double increment)
250        throws OutOfRangeException, NumberIsTooLargeException {
251        if (row == column) {
252            MatrixUtils.checkRowIndex(this, row);
253            data[row] += increment;
254        } else {
255            ensureZero(increment);
256        }
257    }
258
259    /** {@inheritDoc} */
260    @Override
261    public void multiplyEntry(final int row,
262                              final int column,
263                              final double factor)
264        throws OutOfRangeException {
265        // we don't care about non-diagonal elements for multiplication
266        if (row == column) {
267            MatrixUtils.checkRowIndex(this, row);
268            data[row] *= factor;
269        }
270    }
271
272    /** {@inheritDoc} */
273    @Override
274    public int getRowDimension() {
275        return data.length;
276    }
277
278    /** {@inheritDoc} */
279    @Override
280    public int getColumnDimension() {
281        return data.length;
282    }
283
284    /** {@inheritDoc} */
285    @Override
286    public double[] operate(final double[] v)
287        throws DimensionMismatchException {
288        return multiply(new DiagonalMatrix(v, false)).getDataRef();
289    }
290
291    /** {@inheritDoc} */
292    @Override
293    public double[] preMultiply(final double[] v)
294        throws DimensionMismatchException {
295        return operate(v);
296    }
297
298    /** {@inheritDoc} */
299    @Override
300    public RealVector preMultiply(final RealVector v) throws DimensionMismatchException {
301        final double[] vectorData;
302        if (v instanceof ArrayRealVector) {
303            vectorData = ((ArrayRealVector) v).getDataRef();
304        } else {
305            vectorData = v.toArray();
306        }
307        return MatrixUtils.createRealVector(preMultiply(vectorData));
308    }
309
310    /** Ensure a value is zero.
311     * @param value value to check
312     * @exception NumberIsTooLargeException if value is not zero
313     */
314    private void ensureZero(final double value) throws NumberIsTooLargeException {
315        if (!Precision.equals(0.0, value, 1)) {
316            throw new NumberIsTooLargeException(JdkMath.abs(value), 0, true);
317        }
318    }
319
320    /**
321     * Computes the inverse of this diagonal matrix.
322     * <p>
323     * Note: this method will use a singularity threshold of 0,
324     * use {@link #inverse(double)} if a different threshold is needed.
325     *
326     * @return the inverse of {@code m}
327     * @throws SingularMatrixException if the matrix is singular
328     * @since 3.3
329     */
330    public DiagonalMatrix inverse() throws SingularMatrixException {
331        return inverse(0);
332    }
333
334    /**
335     * Computes the inverse of this diagonal matrix.
336     *
337     * @param threshold Singularity threshold.
338     * @return the inverse of {@code m}
339     * @throws SingularMatrixException if the matrix is singular
340     * @since 3.3
341     */
342    public DiagonalMatrix inverse(double threshold) throws SingularMatrixException {
343        if (isSingular(threshold)) {
344            throw new SingularMatrixException();
345        }
346
347        final double[] result = new double[data.length];
348        for (int i = 0; i < data.length; i++) {
349            result[i] = 1.0 / data[i];
350        }
351        return new DiagonalMatrix(result, false);
352    }
353
354    /** Returns whether this diagonal matrix is singular, i.e. any diagonal entry
355     * is equal to {@code 0} within the given threshold.
356     *
357     * @param threshold Singularity threshold.
358     * @return {@code true} if the matrix is singular, {@code false} otherwise
359     * @since 3.3
360     */
361    public boolean isSingular(double threshold) {
362        for (int i = 0; i < data.length; i++) {
363            if (Precision.equals(data[i], 0.0, threshold)) {
364                return true;
365            }
366        }
367        return false;
368    }
369}