1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.linear;
18
19 import java.io.Serializable;
20
21 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
22 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
23 import org.apache.commons.math4.legacy.exception.NullArgumentException;
24 import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
25 import org.apache.commons.math4.legacy.exception.OutOfRangeException;
26 import org.apache.commons.math4.core.jdkmath.JdkMath;
27 import org.apache.commons.numbers.core.Precision;
28
29
30
31
32
33
34 public class DiagonalMatrix extends AbstractRealMatrix
35 implements Serializable {
36
37 private static final long serialVersionUID = 20121229L;
38
39 private final double[] data;
40
41
42
43
44
45
46
47
48 public DiagonalMatrix(final int dimension)
49 throws NotStrictlyPositiveException {
50 super(dimension, dimension);
51 data = new double[dimension];
52 }
53
54
55
56
57
58
59
60
61 public DiagonalMatrix(final double[] d) {
62 this(d, true);
63 }
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79 public DiagonalMatrix(final double[] d, final boolean copyArray)
80 throws NullArgumentException {
81 NullArgumentException.check(d);
82 data = copyArray ? d.clone() : d;
83 }
84
85
86
87
88
89
90 @Override
91 public RealMatrix createMatrix(final int rowDimension,
92 final int columnDimension)
93 throws NotStrictlyPositiveException,
94 DimensionMismatchException {
95 if (rowDimension != columnDimension) {
96 throw new DimensionMismatchException(rowDimension, columnDimension);
97 }
98
99 return new DiagonalMatrix(rowDimension);
100 }
101
102
103 @Override
104 public RealMatrix copy() {
105 return new DiagonalMatrix(data);
106 }
107
108
109
110
111
112
113
114
115
116 public DiagonalMatrix add(final DiagonalMatrix m)
117 throws MatrixDimensionMismatchException {
118
119 MatrixUtils.checkAdditionCompatible(this, m);
120
121 final int dim = getRowDimension();
122 final double[] outData = new double[dim];
123 for (int i = 0; i < dim; i++) {
124 outData[i] = data[i] + m.data[i];
125 }
126
127 return new DiagonalMatrix(outData, false);
128 }
129
130
131
132
133
134
135
136
137
138 public DiagonalMatrix subtract(final DiagonalMatrix m)
139 throws MatrixDimensionMismatchException {
140 MatrixUtils.checkSubtractionCompatible(this, m);
141
142 final int dim = getRowDimension();
143 final double[] outData = new double[dim];
144 for (int i = 0; i < dim; i++) {
145 outData[i] = data[i] - m.data[i];
146 }
147
148 return new DiagonalMatrix(outData, false);
149 }
150
151
152
153
154
155
156
157
158
159 public DiagonalMatrix multiply(final DiagonalMatrix m)
160 throws DimensionMismatchException {
161 MatrixUtils.checkMultiplicationCompatible(this, m);
162
163 final int dim = getRowDimension();
164 final double[] outData = new double[dim];
165 for (int i = 0; i < dim; i++) {
166 outData[i] = data[i] * m.data[i];
167 }
168
169 return new DiagonalMatrix(outData, false);
170 }
171
172
173
174
175
176
177
178
179
180 @Override
181 public RealMatrix multiply(final RealMatrix m)
182 throws DimensionMismatchException {
183 if (m instanceof DiagonalMatrix) {
184 return multiply((DiagonalMatrix) m);
185 } else {
186 MatrixUtils.checkMultiplicationCompatible(this, m);
187 final int nRows = m.getRowDimension();
188 final int nCols = m.getColumnDimension();
189 final double[][] product = new double[nRows][nCols];
190 for (int r = 0; r < nRows; r++) {
191 for (int c = 0; c < nCols; c++) {
192 product[r][c] = data[r] * m.getEntry(r, c);
193 }
194 }
195 return new Array2DRowRealMatrix(product, false);
196 }
197 }
198
199
200 @Override
201 public double[][] getData() {
202 final int dim = getRowDimension();
203 final double[][] out = new double[dim][dim];
204
205 for (int i = 0; i < dim; i++) {
206 out[i][i] = data[i];
207 }
208
209 return out;
210 }
211
212
213
214
215
216
217 public double[] getDataRef() {
218 return data;
219 }
220
221
222 @Override
223 public double getEntry(final int row, final int column)
224 throws OutOfRangeException {
225 MatrixUtils.checkMatrixIndex(this, row, column);
226 return row == column ? data[row] : 0;
227 }
228
229
230
231
232 @Override
233 public void setEntry(final int row, final int column, final double value)
234 throws OutOfRangeException, NumberIsTooLargeException {
235 if (row == column) {
236 MatrixUtils.checkRowIndex(this, row);
237 data[row] = value;
238 } else {
239 ensureZero(value);
240 }
241 }
242
243
244
245
246 @Override
247 public void addToEntry(final int row,
248 final int column,
249 final double increment)
250 throws OutOfRangeException, NumberIsTooLargeException {
251 if (row == column) {
252 MatrixUtils.checkRowIndex(this, row);
253 data[row] += increment;
254 } else {
255 ensureZero(increment);
256 }
257 }
258
259
260 @Override
261 public void multiplyEntry(final int row,
262 final int column,
263 final double factor)
264 throws OutOfRangeException {
265
266 if (row == column) {
267 MatrixUtils.checkRowIndex(this, row);
268 data[row] *= factor;
269 }
270 }
271
272
273 @Override
274 public int getRowDimension() {
275 return data.length;
276 }
277
278
279 @Override
280 public int getColumnDimension() {
281 return data.length;
282 }
283
284
285 @Override
286 public double[] operate(final double[] v)
287 throws DimensionMismatchException {
288 return multiply(new DiagonalMatrix(v, false)).getDataRef();
289 }
290
291
292 @Override
293 public double[] preMultiply(final double[] v)
294 throws DimensionMismatchException {
295 return operate(v);
296 }
297
298
299 @Override
300 public RealVector preMultiply(final RealVector v) throws DimensionMismatchException {
301 final double[] vectorData;
302 if (v instanceof ArrayRealVector) {
303 vectorData = ((ArrayRealVector) v).getDataRef();
304 } else {
305 vectorData = v.toArray();
306 }
307 return MatrixUtils.createRealVector(preMultiply(vectorData));
308 }
309
310
311
312
313
314 private void ensureZero(final double value) throws NumberIsTooLargeException {
315 if (!Precision.equals(0.0, value, 1)) {
316 throw new NumberIsTooLargeException(JdkMath.abs(value), 0, true);
317 }
318 }
319
320
321
322
323
324
325
326
327
328
329
330 public DiagonalMatrix inverse() throws SingularMatrixException {
331 return inverse(0);
332 }
333
334
335
336
337
338
339
340
341
342 public DiagonalMatrix inverse(double threshold) throws SingularMatrixException {
343 if (isSingular(threshold)) {
344 throw new SingularMatrixException();
345 }
346
347 final double[] result = new double[data.length];
348 for (int i = 0; i < data.length; i++) {
349 result[i] = 1.0 / data[i];
350 }
351 return new DiagonalMatrix(result, false);
352 }
353
354
355
356
357
358
359
360
361 public boolean isSingular(double threshold) {
362 for (int i = 0; i < data.length; i++) {
363 if (Precision.equals(data[i], 0.0, threshold)) {
364 return true;
365 }
366 }
367 return false;
368 }
369 }