001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.linear;
018
019import java.io.Serializable;
020
021import org.apache.commons.math3.exception.DimensionMismatchException;
022import org.apache.commons.math3.exception.NotStrictlyPositiveException;
023import org.apache.commons.math3.exception.NullArgumentException;
024import org.apache.commons.math3.exception.NumberIsTooLargeException;
025import org.apache.commons.math3.exception.OutOfRangeException;
026import org.apache.commons.math3.util.FastMath;
027import org.apache.commons.math3.util.MathUtils;
028import org.apache.commons.math3.util.Precision;
029
030/**
031 * Implementation of a diagonal matrix.
032 *
033 * @since 3.1.1
034 */
035public class DiagonalMatrix extends AbstractRealMatrix
036    implements Serializable {
037    /** Serializable version identifier. */
038    private static final long serialVersionUID = 20121229L;
039    /** Entries of the diagonal. */
040    private final double[] data;
041
042    /**
043     * Creates a matrix with the supplied dimension.
044     *
045     * @param dimension Number of rows and columns in the new matrix.
046     * @throws NotStrictlyPositiveException if the dimension is
047     * not positive.
048     */
049    public DiagonalMatrix(final int dimension)
050        throws NotStrictlyPositiveException {
051        super(dimension, dimension);
052        data = new double[dimension];
053    }
054
055    /**
056     * Creates a matrix using the input array as the underlying data.
057     * <br/>
058     * The input array is copied, not referenced.
059     *
060     * @param d Data for the new matrix.
061     */
062    public DiagonalMatrix(final double[] d) {
063        this(d, true);
064    }
065
066    /**
067     * Creates a matrix using the input array as the underlying data.
068     * <br/>
069     * If an array is created specially in order to be embedded in a
070     * this instance and not used directly, the {@code copyArray} may be
071     * set to {@code false}.
072     * This will prevent the copying and improve performance as no new
073     * array will be built and no data will be copied.
074     *
075     * @param d Data for new matrix.
076     * @param copyArray if {@code true}, the input array will be copied,
077     * otherwise it will be referenced.
078     * @exception NullArgumentException if d is null
079     */
080    public DiagonalMatrix(final double[] d, final boolean copyArray)
081        throws NullArgumentException {
082        MathUtils.checkNotNull(d);
083        data = copyArray ? d.clone() : d;
084    }
085
086    /**
087     * {@inheritDoc}
088     *
089     * @throws DimensionMismatchException if the requested dimensions are not equal.
090     */
091    @Override
092    public RealMatrix createMatrix(final int rowDimension,
093                                   final int columnDimension)
094        throws NotStrictlyPositiveException,
095               DimensionMismatchException {
096        if (rowDimension != columnDimension) {
097            throw new DimensionMismatchException(rowDimension, columnDimension);
098        }
099
100        return new DiagonalMatrix(rowDimension);
101    }
102
103    /** {@inheritDoc} */
104    @Override
105    public RealMatrix copy() {
106        return new DiagonalMatrix(data);
107    }
108
109    /**
110     * Compute the sum of {@code this} and {@code m}.
111     *
112     * @param m Matrix to be added.
113     * @return {@code this + m}.
114     * @throws MatrixDimensionMismatchException if {@code m} is not the same
115     * size as {@code this}.
116     */
117    public DiagonalMatrix add(final DiagonalMatrix m)
118        throws MatrixDimensionMismatchException {
119        // Safety check.
120        MatrixUtils.checkAdditionCompatible(this, m);
121
122        final int dim = getRowDimension();
123        final double[] outData = new double[dim];
124        for (int i = 0; i < dim; i++) {
125            outData[i] = data[i] + m.data[i];
126        }
127
128        return new DiagonalMatrix(outData, false);
129    }
130
131    /**
132     * Returns {@code this} minus {@code m}.
133     *
134     * @param m Matrix to be subtracted.
135     * @return {@code this - m}
136     * @throws MatrixDimensionMismatchException if {@code m} is not the same
137     * size as {@code this}.
138     */
139    public DiagonalMatrix subtract(final DiagonalMatrix m)
140        throws MatrixDimensionMismatchException {
141        MatrixUtils.checkSubtractionCompatible(this, m);
142
143        final int dim = getRowDimension();
144        final double[] outData = new double[dim];
145        for (int i = 0; i < dim; i++) {
146            outData[i] = data[i] - m.data[i];
147        }
148
149        return new DiagonalMatrix(outData, false);
150    }
151
152    /**
153     * Returns the result of postmultiplying {@code this} by {@code m}.
154     *
155     * @param m matrix to postmultiply by
156     * @return {@code this * m}
157     * @throws DimensionMismatchException if
158     * {@code columnDimension(this) != rowDimension(m)}
159     */
160    public DiagonalMatrix multiply(final DiagonalMatrix m)
161        throws DimensionMismatchException {
162        MatrixUtils.checkMultiplicationCompatible(this, m);
163
164        final int dim = getRowDimension();
165        final double[] outData = new double[dim];
166        for (int i = 0; i < dim; i++) {
167            outData[i] = data[i] * m.data[i];
168        }
169
170        return new DiagonalMatrix(outData, false);
171    }
172
173    /**
174     * Returns the result of postmultiplying {@code this} by {@code m}.
175     *
176     * @param m matrix to postmultiply by
177     * @return {@code this * m}
178     * @throws DimensionMismatchException if
179     * {@code columnDimension(this) != rowDimension(m)}
180     */
181    @Override
182    public RealMatrix multiply(final RealMatrix m)
183        throws DimensionMismatchException {
184        if (m instanceof DiagonalMatrix) {
185            return multiply((DiagonalMatrix) m);
186        } else {
187            MatrixUtils.checkMultiplicationCompatible(this, m);
188            final int nRows = m.getRowDimension();
189            final int nCols = m.getColumnDimension();
190            final double[][] product = new double[nRows][nCols];
191            for (int r = 0; r < nRows; r++) {
192                for (int c = 0; c < nCols; c++) {
193                    product[r][c] = data[r] * m.getEntry(r, c);
194                }
195            }
196            return new Array2DRowRealMatrix(product, false);
197        }
198    }
199
200    /** {@inheritDoc} */
201    @Override
202    public double[][] getData() {
203        final int dim = getRowDimension();
204        final double[][] out = new double[dim][dim];
205
206        for (int i = 0; i < dim; i++) {
207            out[i][i] = data[i];
208        }
209
210        return out;
211    }
212
213    /**
214     * Gets a reference to the underlying data array.
215     *
216     * @return 1-dimensional array of entries.
217     */
218    public double[] getDataRef() {
219        return data;
220    }
221
222    /** {@inheritDoc} */
223    @Override
224    public double getEntry(final int row, final int column)
225        throws OutOfRangeException {
226        MatrixUtils.checkMatrixIndex(this, row, column);
227        return row == column ? data[row] : 0;
228    }
229
230    /** {@inheritDoc}
231     * @throws NumberIsTooLargeException if {@code row != column} and value is non-zero.
232     */
233    @Override
234    public void setEntry(final int row, final int column, final double value)
235        throws OutOfRangeException, NumberIsTooLargeException {
236        if (row == column) {
237            MatrixUtils.checkRowIndex(this, row);
238            data[row] = value;
239        } else {
240            ensureZero(value);
241        }
242    }
243
244    /** {@inheritDoc}
245     * @throws NumberIsTooLargeException if {@code row != column} and increment is non-zero.
246     */
247    @Override
248    public void addToEntry(final int row,
249                           final int column,
250                           final double increment)
251        throws OutOfRangeException, NumberIsTooLargeException {
252        if (row == column) {
253            MatrixUtils.checkRowIndex(this, row);
254            data[row] += increment;
255        } else {
256            ensureZero(increment);
257        }
258    }
259
260    /** {@inheritDoc} */
261    @Override
262    public void multiplyEntry(final int row,
263                              final int column,
264                              final double factor)
265        throws OutOfRangeException {
266        // we don't care about non-diagonal elements for multiplication
267        if (row == column) {
268            MatrixUtils.checkRowIndex(this, row);
269            data[row] *= factor;
270        }
271    }
272
273    /** {@inheritDoc} */
274    @Override
275    public int getRowDimension() {
276        return data.length;
277    }
278
279    /** {@inheritDoc} */
280    @Override
281    public int getColumnDimension() {
282        return data.length;
283    }
284
285    /** {@inheritDoc} */
286    @Override
287    public double[] operate(final double[] v)
288        throws DimensionMismatchException {
289        return multiply(new DiagonalMatrix(v, false)).getDataRef();
290    }
291
292    /** {@inheritDoc} */
293    @Override
294    public double[] preMultiply(final double[] v)
295        throws DimensionMismatchException {
296        return operate(v);
297    }
298
299    /** {@inheritDoc} */
300    @Override
301    public RealVector preMultiply(final RealVector v) throws DimensionMismatchException {
302        final double[] vectorData;
303        if (v instanceof ArrayRealVector) {
304            vectorData = ((ArrayRealVector) v).getDataRef();
305        } else {
306            vectorData = v.toArray();
307        }
308        return MatrixUtils.createRealVector(preMultiply(vectorData));
309    }
310
311    /** Ensure a value is zero.
312     * @param value value to check
313     * @exception NumberIsTooLargeException if value is not zero
314     */
315    private void ensureZero(final double value) throws NumberIsTooLargeException {
316        if (!Precision.equals(0.0, value, 1)) {
317            throw new NumberIsTooLargeException(FastMath.abs(value), 0, true);
318        }
319    }
320
321    /**
322     * Computes the inverse of this diagonal matrix.
323     * <p>
324     * Note: this method will use a singularity threshold of 0,
325     * use {@link #inverse(double)} if a different threshold is needed.
326     *
327     * @return the inverse of {@code m}
328     * @throws SingularMatrixException if the matrix is singular
329     * @since 3.3
330     */
331    public DiagonalMatrix inverse() throws SingularMatrixException {
332        return inverse(0);
333    }
334
335    /**
336     * Computes the inverse of this diagonal matrix.
337     *
338     * @param threshold Singularity threshold.
339     * @return the inverse of {@code m}
340     * @throws SingularMatrixException if the matrix is singular
341     * @since 3.3
342     */
343    public DiagonalMatrix inverse(double threshold) throws SingularMatrixException {
344        if (isSingular(threshold)) {
345            throw new SingularMatrixException();
346        }
347
348        final double[] result = new double[data.length];
349        for (int i = 0; i < data.length; i++) {
350            result[i] = 1.0 / data[i];
351        }
352        return new DiagonalMatrix(result, false);
353    }
354
355    /** Returns whether this diagonal matrix is singular, i.e. any diagonal entry
356     * is equal to {@code 0} within the given threshold.
357     *
358     * @param threshold Singularity threshold.
359     * @return {@code true} if the matrix is singular, {@code false} otherwise
360     * @since 3.3
361     */
362    public boolean isSingular(double threshold) {
363        for (int i = 0; i < data.length; i++) {
364            if (Precision.equals(data[i], 0.0, threshold)) {
365                return true;
366            }
367        }
368        return false;
369    }
370}