1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.linear;
18
19 import org.apache.commons.math4.legacy.TestUtils;
20 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
21 import org.apache.commons.math4.legacy.exception.NullArgumentException;
22 import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
23 import org.apache.commons.math4.legacy.exception.OutOfRangeException;
24 import org.apache.commons.numbers.core.Precision;
25 import org.junit.Assert;
26 import org.junit.Test;
27
28
29
30
31 public class DiagonalMatrixTest {
32 @Test
33 public void testConstructor1() {
34 final int dim = 3;
35 final DiagonalMatrix m = new DiagonalMatrix(dim);
36 Assert.assertEquals(dim, m.getRowDimension());
37 Assert.assertEquals(dim, m.getColumnDimension());
38 }
39
40 @Test
41 public void testConstructor2() {
42 final double[] d = { -1.2, 3.4, 5 };
43 final DiagonalMatrix m = new DiagonalMatrix(d);
44 for (int i = 0; i < m.getRowDimension(); i++) {
45 for (int j = 0; j < m.getRowDimension(); j++) {
46 if (i == j) {
47 Assert.assertEquals(d[i], m.getEntry(i, j), 0d);
48 } else {
49 Assert.assertEquals(0d, m.getEntry(i, j), 0d);
50 }
51 }
52 }
53
54
55 d[0] = 0;
56 Assert.assertNotEquals(d[0], m.getEntry(0, 0), 0.0);
57 }
58
59 @Test
60 public void testConstructor3() {
61 final double[] d = { -1.2, 3.4, 5 };
62 final DiagonalMatrix m = new DiagonalMatrix(d, false);
63 for (int i = 0; i < m.getRowDimension(); i++) {
64 for (int j = 0; j < m.getRowDimension(); j++) {
65 if (i == j) {
66 Assert.assertEquals(d[i], m.getEntry(i, j), 0d);
67 } else {
68 Assert.assertEquals(0d, m.getEntry(i, j), 0d);
69 }
70 }
71 }
72
73
74 d[0] = 0;
75 Assert.assertEquals(d[0], m.getEntry(0, 0), 0.0);
76 }
77
78 @Test(expected=DimensionMismatchException.class)
79 public void testCreateError() {
80 final double[] d = { -1.2, 3.4, 5 };
81 final DiagonalMatrix m = new DiagonalMatrix(d, false);
82 m.createMatrix(5, 3);
83 }
84
85 @Test
86 public void testCreate() {
87 final double[] d = { -1.2, 3.4, 5 };
88 final DiagonalMatrix m = new DiagonalMatrix(d, false);
89 final RealMatrix p = m.createMatrix(5, 5);
90 Assert.assertTrue(p instanceof DiagonalMatrix);
91 Assert.assertEquals(5, p.getRowDimension());
92 Assert.assertEquals(5, p.getColumnDimension());
93 }
94
95 @Test
96 public void testCopy() {
97 final double[] d = { -1.2, 3.4, 5 };
98 final DiagonalMatrix m = new DiagonalMatrix(d, false);
99 final DiagonalMatrix p = (DiagonalMatrix) m.copy();
100 for (int i = 0; i < m.getRowDimension(); ++i) {
101 Assert.assertEquals(m.getEntry(i, i), p.getEntry(i, i), 1.0e-20);
102 }
103 }
104
105 @Test
106 public void testGetData() {
107 final double[] data = { -1.2, 3.4, 5 };
108 final int dim = 3;
109 final DiagonalMatrix m = new DiagonalMatrix(dim);
110 for (int i = 0; i < dim; i++) {
111 m.setEntry(i, i, data[i]);
112 }
113
114 final double[][] out = m.getData();
115 Assert.assertEquals(dim, out.length);
116 for (int i = 0; i < m.getRowDimension(); i++) {
117 Assert.assertEquals(dim, out[i].length);
118 for (int j = 0; j < m.getRowDimension(); j++) {
119 if (i == j) {
120 Assert.assertEquals(data[i], out[i][j], 0d);
121 } else {
122 Assert.assertEquals(0d, out[i][j], 0d);
123 }
124 }
125 }
126 }
127
128 @Test
129 public void testAdd() {
130 final double[] data1 = { -1.2, 3.4, 5 };
131 final DiagonalMatrix m1 = new DiagonalMatrix(data1);
132
133 final double[] data2 = { 10.1, 2.3, 45 };
134 final DiagonalMatrix m2 = new DiagonalMatrix(data2);
135
136 final DiagonalMatrix result = m1.add(m2);
137 Assert.assertEquals(m1.getRowDimension(), result.getRowDimension());
138 for (int i = 0; i < result.getRowDimension(); i++) {
139 for (int j = 0; j < result.getRowDimension(); j++) {
140 if (i == j) {
141 Assert.assertEquals(data1[i] + data2[i], result.getEntry(i, j), 0d);
142 } else {
143 Assert.assertEquals(0d, result.getEntry(i, j), 0d);
144 }
145 }
146 }
147 }
148
149 @Test
150 public void testSubtract() {
151 final double[] data1 = { -1.2, 3.4, 5 };
152 final DiagonalMatrix m1 = new DiagonalMatrix(data1);
153
154 final double[] data2 = { 10.1, 2.3, 45 };
155 final DiagonalMatrix m2 = new DiagonalMatrix(data2);
156
157 final DiagonalMatrix result = m1.subtract(m2);
158 Assert.assertEquals(m1.getRowDimension(), result.getRowDimension());
159 for (int i = 0; i < result.getRowDimension(); i++) {
160 for (int j = 0; j < result.getRowDimension(); j++) {
161 if (i == j) {
162 Assert.assertEquals(data1[i] - data2[i], result.getEntry(i, j), 0d);
163 } else {
164 Assert.assertEquals(0d, result.getEntry(i, j), 0d);
165 }
166 }
167 }
168 }
169
170 @Test
171 public void testAddToEntry() {
172 final double[] data = { -1.2, 3.4, 5 };
173 final DiagonalMatrix m = new DiagonalMatrix(data);
174
175 for (int i = 0; i < m.getRowDimension(); i++) {
176 m.addToEntry(i, i, i);
177 Assert.assertEquals(data[i] + i, m.getEntry(i, i), 0d);
178 }
179 }
180
181 @Test
182 public void testMultiplyEntry() {
183 final double[] data = { -1.2, 3.4, 5 };
184 final DiagonalMatrix m = new DiagonalMatrix(data);
185
186 for (int i = 0; i < m.getRowDimension(); i++) {
187 m.multiplyEntry(i, i, i);
188 Assert.assertEquals(data[i] * i, m.getEntry(i, i), 0d);
189 }
190 }
191
192 @Test
193 public void testMultiply1() {
194 final double[] data1 = { -1.2, 3.4, 5 };
195 final DiagonalMatrix m1 = new DiagonalMatrix(data1);
196 final double[] data2 = { 10.1, 2.3, 45 };
197 final DiagonalMatrix m2 = new DiagonalMatrix(data2);
198
199 final DiagonalMatrix result = (DiagonalMatrix) m1.multiply((RealMatrix) m2);
200 Assert.assertEquals(m1.getRowDimension(), result.getRowDimension());
201 for (int i = 0; i < result.getRowDimension(); i++) {
202 for (int j = 0; j < result.getRowDimension(); j++) {
203 if (i == j) {
204 Assert.assertEquals(data1[i] * data2[i], result.getEntry(i, j), 0d);
205 } else {
206 Assert.assertEquals(0d, result.getEntry(i, j), 0d);
207 }
208 }
209 }
210 }
211
212 @Test
213 public void testMultiply2() {
214 final double[] data1 = { -1.2, 3.4, 5 };
215 final DiagonalMatrix diag1 = new DiagonalMatrix(data1);
216
217 final double[][] data2 = { { -1.2, 3.4 },
218 { -5.6, 7.8 },
219 { 9.1, 2.3 } };
220 final RealMatrix dense2 = new Array2DRowRealMatrix(data2);
221 final RealMatrix dense1 = new Array2DRowRealMatrix(diag1.getData());
222
223 final RealMatrix diagResult = diag1.multiply(dense2);
224 final RealMatrix denseResult = dense1.multiply(dense2);
225
226 for (int i = 0; i < dense1.getRowDimension(); i++) {
227 for (int j = 0; j < dense2.getColumnDimension(); j++) {
228 Assert.assertEquals(denseResult.getEntry(i, j),
229 diagResult.getEntry(i, j), 0d);
230 }
231 }
232 }
233
234 @Test
235 public void testOperate() {
236 final double[] data = { -1.2, 3.4, 5 };
237 final DiagonalMatrix diag = new DiagonalMatrix(data);
238 final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
239
240 final double[] v = { 6.7, 890.1, 23.4 };
241 final double[] diagResult = diag.operate(v);
242 final double[] denseResult = dense.operate(v);
243
244 TestUtils.assertEquals(diagResult, denseResult, 0d);
245 }
246
247 @Test
248 public void testPreMultiply() {
249 final double[] data = { -1.2, 3.4, 5 };
250 final DiagonalMatrix diag = new DiagonalMatrix(data);
251 final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
252
253 final double[] v = { 6.7, 890.1, 23.4 };
254 final double[] diagResult = diag.preMultiply(v);
255 final double[] denseResult = dense.preMultiply(v);
256
257 TestUtils.assertEquals(diagResult, denseResult, 0d);
258 }
259
260 @Test
261 public void testPreMultiplyVector() {
262 final double[] data = { -1.2, 3.4, 5 };
263 final DiagonalMatrix diag = new DiagonalMatrix(data);
264 final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
265
266 final double[] v = { 6.7, 890.1, 23.4 };
267 final RealVector vector = MatrixUtils.createRealVector(v);
268 final RealVector diagResult = diag.preMultiply(vector);
269 final RealVector denseResult = dense.preMultiply(vector);
270
271 TestUtils.assertEquals("preMultiply(Vector) returns wrong result", diagResult, denseResult, 0d);
272 }
273
274 @Test(expected=NumberIsTooLargeException.class)
275 public void testSetNonDiagonalEntry() {
276 final DiagonalMatrix diag = new DiagonalMatrix(3);
277 diag.setEntry(1, 2, 3.4);
278 }
279
280 @Test
281 public void testSetNonDiagonalZero() {
282 final DiagonalMatrix diag = new DiagonalMatrix(3);
283 diag.setEntry(1, 2, 0.0);
284 Assert.assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
285 }
286
287 @Test(expected=NumberIsTooLargeException.class)
288 public void testAddNonDiagonalEntry() {
289 final DiagonalMatrix diag = new DiagonalMatrix(3);
290 diag.addToEntry(1, 2, 3.4);
291 }
292
293 @Test
294 public void testAddNonDiagonalZero() {
295 final DiagonalMatrix diag = new DiagonalMatrix(3);
296 diag.addToEntry(1, 2, 0.0);
297 Assert.assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
298 }
299
300 @Test
301 public void testMultiplyNonDiagonalEntry() {
302 final DiagonalMatrix diag = new DiagonalMatrix(3);
303 diag.multiplyEntry(1, 2, 3.4);
304 Assert.assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
305 }
306
307 @Test
308 public void testMultiplyNonDiagonalZero() {
309 final DiagonalMatrix diag = new DiagonalMatrix(3);
310 diag.multiplyEntry(1, 2, 0.0);
311 Assert.assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
312 }
313
314 @Test(expected=OutOfRangeException.class)
315 public void testSetEntryOutOfRange() {
316 final DiagonalMatrix diag = new DiagonalMatrix(3);
317 diag.setEntry(3, 3, 3.4);
318 }
319
320 @Test(expected=NullArgumentException.class)
321 public void testNull() {
322 new DiagonalMatrix(null, false);
323 }
324
325 @Test(expected=NumberIsTooLargeException.class)
326 public void testSetSubMatrixError() {
327 final double[] data = { -1.2, 3.4, 5 };
328 final DiagonalMatrix diag = new DiagonalMatrix(data);
329 diag.setSubMatrix(new double[][] { {1.0, 1.0}, {1.0, 1.0}}, 1, 1);
330 }
331
332 @Test
333 public void testSetSubMatrix() {
334 final double[] data = { -1.2, 3.4, 5 };
335 final DiagonalMatrix diag = new DiagonalMatrix(data);
336 diag.setSubMatrix(new double[][] { {0.0, 5.0, 0.0}, {0.0, 0.0, 6.0}}, 1, 0);
337 Assert.assertEquals(-1.2, diag.getEntry(0, 0), 1.0e-20);
338 Assert.assertEquals( 5.0, diag.getEntry(1, 1), 1.0e-20);
339 Assert.assertEquals( 6.0, diag.getEntry(2, 2), 1.0e-20);
340 }
341
342 @Test(expected=SingularMatrixException.class)
343 public void testInverseError() {
344 final double[] data = { 1, 2, 0 };
345 final DiagonalMatrix diag = new DiagonalMatrix(data);
346 diag.inverse();
347 }
348
349 @Test(expected=SingularMatrixException.class)
350 public void testInverseError2() {
351 final double[] data = { 1, 2, 1e-6 };
352 final DiagonalMatrix diag = new DiagonalMatrix(data);
353 diag.inverse(1e-5);
354 }
355
356 @Test
357 public void testInverse() {
358 final double[] data = { 1, 2, 3 };
359 final DiagonalMatrix m = new DiagonalMatrix(data);
360 final DiagonalMatrix inverse = m.inverse();
361
362 final DiagonalMatrix result = m.multiply(inverse);
363 TestUtils.assertEquals("DiagonalMatrix.inverse() returns wrong result",
364 MatrixUtils.createRealIdentityMatrix(data.length), result, Math.ulp(1d));
365 }
366 }