1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.filter;
19
20 import org.apache.commons.rng.simple.RandomSource;
21 import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
22 import org.apache.commons.rng.sampling.distribution.GaussianSampler;
23 import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
24 import org.apache.commons.numbers.core.Precision;
25 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
26 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
27 import org.apache.commons.math4.legacy.linear.MatrixDimensionMismatchException;
28 import org.apache.commons.math4.legacy.linear.MatrixUtils;
29 import org.apache.commons.math4.legacy.linear.RealMatrix;
30 import org.apache.commons.math4.legacy.linear.RealVector;
31 import org.apache.commons.math4.core.jdkmath.JdkMath;
32 import org.junit.Assert;
33 import org.junit.Test;
34
35
36
37
38
39 public class KalmanFilterTest {
40
41 @Test(expected=MatrixDimensionMismatchException.class)
42 public void testTransitionMeasurementMatrixMismatch() {
43
44
45
46
47 RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
48
49 RealMatrix B = null;
50
51 RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d, 1d });
52
53 RealMatrix Q = new Array2DRowRealMatrix(new double[] { 0 });
54
55 RealMatrix R = new Array2DRowRealMatrix(new double[] { 0 });
56
57 ProcessModel pm
58 = new DefaultProcessModel(A, B, Q,
59 new ArrayRealVector(new double[] { 0 }), null);
60 MeasurementModel mm = new DefaultMeasurementModel(H, R);
61 new KalmanFilter(pm, mm);
62 Assert.fail("transition and measurement matrix should not be compatible");
63 }
64
65 @Test(expected=MatrixDimensionMismatchException.class)
66 public void testTransitionControlMatrixMismatch() {
67
68
69
70
71 RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
72
73 RealMatrix B = new Array2DRowRealMatrix(new double[] { 1d, 1d });
74
75 RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d });
76
77 RealMatrix Q = new Array2DRowRealMatrix(new double[] { 0 });
78
79 RealMatrix R = new Array2DRowRealMatrix(new double[] { 0 });
80
81 ProcessModel pm
82 = new DefaultProcessModel(A, B, Q,
83 new ArrayRealVector(new double[] { 0 }), null);
84 MeasurementModel mm = new DefaultMeasurementModel(H, R);
85 new KalmanFilter(pm, mm);
86 Assert.fail("transition and control matrix should not be compatible");
87 }
88
89 @Test
90 public void testConstant() {
91
92
93 double constantValue = 10d;
94 double measurementNoise = 0.1d;
95 double processNoise = 1e-5d;
96
97
98 RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
99
100 RealMatrix B = null;
101
102 RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d });
103
104 RealVector x = new ArrayRealVector(new double[] { constantValue });
105
106 RealMatrix Q = new Array2DRowRealMatrix(new double[] { processNoise });
107
108 RealMatrix R = new Array2DRowRealMatrix(new double[] { measurementNoise });
109
110 ProcessModel pm
111 = new DefaultProcessModel(A, B, Q,
112 new ArrayRealVector(new double[] { constantValue }), null);
113 MeasurementModel mm = new DefaultMeasurementModel(H, R);
114 KalmanFilter filter = new KalmanFilter(pm, mm);
115
116 Assert.assertEquals(1, filter.getMeasurementDimension());
117 Assert.assertEquals(1, filter.getStateDimension());
118
119 assertMatrixEquals(Q.getData(), filter.getErrorCovariance());
120
121
122 double[] expectedInitialState = new double[] { constantValue };
123 assertVectorEquals(expectedInitialState, filter.getStateEstimation());
124
125 RealVector pNoise = new ArrayRealVector(1);
126 RealVector mNoise = new ArrayRealVector(1);
127
128 final ContinuousSampler rand = createGaussianSampler(0, 1);
129
130
131 for (int i = 0; i < 60; i++) {
132 filter.predict();
133
134
135 pNoise.setEntry(0, processNoise * rand.sample());
136
137
138 x = A.operate(x).add(pNoise);
139
140
141 mNoise.setEntry(0, measurementNoise * rand.sample());
142
143
144 RealVector z = H.operate(x).add(mNoise);
145
146 filter.correct(z);
147
148
149 double diff = JdkMath.abs(constantValue - filter.getStateEstimation()[0]);
150
151 Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
152 }
153
154
155 Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[0][0],
156 0.02d, 1e-6) < 0);
157 }
158
159 @Test
160 public void testConstantAcceleration() {
161
162
163
164 double dt = 0.1d;
165
166 double measurementNoise = 10d;
167
168 double accelNoise = 0.2d;
169
170
171
172 RealMatrix A = new Array2DRowRealMatrix(new double[][] { { 1, dt }, { 0, 1 } });
173
174
175
176 RealMatrix B = new Array2DRowRealMatrix(
177 new double[][] { { JdkMath.pow(dt, 2d) / 2d }, { dt } });
178
179
180 RealMatrix H = new Array2DRowRealMatrix(new double[][] { { 1d, 0d } });
181
182
183 RealVector x = new ArrayRealVector(new double[] { 0, 0 });
184
185 RealMatrix tmp = new Array2DRowRealMatrix(
186 new double[][] { { JdkMath.pow(dt, 4d) / 4d, JdkMath.pow(dt, 3d) / 2d },
187 { JdkMath.pow(dt, 3d) / 2d, JdkMath.pow(dt, 2d) } });
188
189
190
191 RealMatrix Q = tmp.scalarMultiply(JdkMath.pow(accelNoise, 2));
192
193
194
195 RealMatrix P0 = new Array2DRowRealMatrix(new double[][] { { 1, 1 }, { 1, 1 } });
196
197
198 RealMatrix R = new Array2DRowRealMatrix(
199 new double[] { JdkMath.pow(measurementNoise, 2) });
200
201
202 RealVector u = new ArrayRealVector(new double[] { 0.1d });
203
204 ProcessModel pm = new DefaultProcessModel(A, B, Q, x, P0);
205 MeasurementModel mm = new DefaultMeasurementModel(H, R);
206 KalmanFilter filter = new KalmanFilter(pm, mm);
207
208 Assert.assertEquals(1, filter.getMeasurementDimension());
209 Assert.assertEquals(2, filter.getStateDimension());
210
211 assertMatrixEquals(P0.getData(), filter.getErrorCovariance());
212
213
214 double[] expectedInitialState = new double[] { 0.0, 0.0 };
215 assertVectorEquals(expectedInitialState, filter.getStateEstimation());
216
217 final ContinuousSampler rand = createGaussianSampler(0, 1);
218
219 RealVector tmpPNoise = new ArrayRealVector(
220 new double[] { JdkMath.pow(dt, 2d) / 2d, dt });
221
222
223 for (int i = 0; i < 60; i++) {
224 filter.predict(u);
225
226
227 RealVector pNoise = tmpPNoise.mapMultiply(accelNoise * rand.sample());
228
229
230 x = A.operate(x).add(B.operate(u)).add(pNoise);
231
232
233 double mNoise = measurementNoise * rand.sample();
234
235
236 RealVector z = H.operate(x).mapAdd(mNoise);
237
238 filter.correct(z);
239
240
241 double diff = JdkMath.abs(x.getEntry(0) - filter.getStateEstimation()[0]);
242 Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
243 }
244
245
246 Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[1][1],
247 0.1d, 1e-6) < 0);
248 }
249
250
251
252
253 public static class Cannonball {
254
255 private final double[] gravity = { 0, -9.81 };
256
257 private final double[] velocity;
258 private final double[] location;
259
260 private double timeslice;
261
262 public Cannonball(double timeslice, double angle, double initialVelocity) {
263 this.timeslice = timeslice;
264
265 final double angleInRadians = JdkMath.toRadians(angle);
266 this.velocity = new double[] {
267 initialVelocity * JdkMath.cos(angleInRadians),
268 initialVelocity * JdkMath.sin(angleInRadians)
269 };
270
271 this.location = new double[] { 0, 0 };
272 }
273
274 public double getX() {
275 return location[0];
276 }
277
278 public double getY() {
279 return location[1];
280 }
281
282 public double getXVelocity() {
283 return velocity[0];
284 }
285
286 public double getYVelocity() {
287 return velocity[1];
288 }
289
290 public void step() {
291
292 double[] slicedGravity = gravity.clone();
293 for ( int i = 0; i < slicedGravity.length; i++ ) {
294 slicedGravity[i] *= timeslice;
295 }
296
297
298 double[] slicedVelocity = velocity.clone();
299 for ( int i = 0; i < velocity.length; i++ ) {
300 velocity[i] += slicedGravity[i];
301 slicedVelocity[i] = velocity[i] * timeslice;
302 location[i] += slicedVelocity[i];
303 }
304
305
306 if ( location[1] < 0 ) {
307 location[1] = 0;
308 }
309 }
310 }
311
312 @Test
313 public void testCannonball() {
314
315
316
317 final int iterations = 144;
318
319 final double dt = 0.1d;
320
321 final double measurementNoise = 30d;
322
323 final double initialVelocity = 100;
324
325 final double angle = 45;
326
327 final Cannonball cannonball = new Cannonball(dt, angle, initialVelocity);
328
329 final double speedX = cannonball.getXVelocity();
330 final double speedY = cannonball.getYVelocity();
331
332
333
334
335
336 final RealMatrix A = MatrixUtils.createRealMatrix(new double[][] {
337 { 1, dt, 0, 0 },
338 { 0, 1, 0, 0 },
339 { 0, 0, 1, dt },
340 { 0, 0, 0, 1 }
341 });
342
343
344
345
346
347
348 final RealVector controlVector =
349 MatrixUtils.createRealVector(new double[] { 0, 0, 0.5 * -9.81 * dt * dt, -9.81 * dt } );
350
351
352 final RealMatrix B = MatrixUtils.createRealMatrix(new double[][] {
353 { 0, 0, 0, 0 },
354 { 0, 0, 0, 0 },
355 { 0, 0, 1, 0 },
356 { 0, 0, 0, 1 }
357 });
358
359
360 final RealMatrix H = MatrixUtils.createRealMatrix(new double[][] {
361 { 1, 0, 0, 0 },
362 { 0, 0, 0, 0 },
363 { 0, 0, 1, 0 },
364 { 0, 0, 0, 0 }
365 });
366
367
368 final RealVector initialState = MatrixUtils.createRealVector(new double[] { 0, speedX, 0, speedY } );
369
370
371 final double var = measurementNoise * measurementNoise;
372 final RealMatrix initialErrorCovariance = MatrixUtils.createRealMatrix(new double[][] {
373 { var, 0, 0, 0 },
374 { 0, 1e-3, 0, 0 },
375 { 0, 0, var, 0 },
376 { 0, 0, 0, 1e-3 }
377 });
378
379
380 final RealMatrix Q = MatrixUtils.createRealMatrix(4, 4);
381
382
383 final RealMatrix R = MatrixUtils.createRealMatrix(new double[][] {
384 { var, 0, 0, 0 },
385 { 0, 1e-3, 0, 0 },
386 { 0, 0, var, 0 },
387 { 0, 0, 0, 1e-3 }
388 });
389
390 final ProcessModel pm = new DefaultProcessModel(A, B, Q, initialState, initialErrorCovariance);
391 final MeasurementModel mm = new DefaultMeasurementModel(H, R);
392 final KalmanFilter filter = new KalmanFilter(pm, mm);
393
394 final ContinuousSampler rand = createGaussianSampler(0, measurementNoise);
395
396 for (int i = 0; i < iterations; i++) {
397
398 double x = cannonball.getX();
399 double y = cannonball.getY();
400
401
402 double nx = x + rand.sample();
403 double ny = y + rand.sample();
404
405 cannonball.step();
406
407 filter.predict(controlVector);
408
409 filter.correct(new double[] { nx, 0, ny, 0 } );
410
411
412 double diff = JdkMath.abs(cannonball.getY() - filter.getStateEstimation()[2]);
413 Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
414 }
415
416
417
418 Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[0][0],
419 9, 1e-6) < 0);
420
421 Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[2][2],
422 9, 1e-6) < 0);
423 }
424
425 private void assertVectorEquals(double[] expected, double[] result) {
426 Assert.assertEquals("Wrong number of rows.", expected.length,
427 result.length);
428 for (int i = 0; i < expected.length; i++) {
429 Assert.assertEquals("Wrong value at position [" + i + "]",
430 expected[i], result[i], 1.0e-6);
431 }
432 }
433
434 private void assertMatrixEquals(double[][] expected, double[][] result) {
435 Assert.assertEquals("Wrong number of rows.", expected.length,
436 result.length);
437 for (int i = 0; i < expected.length; i++) {
438 Assert.assertEquals("Wrong number of columns.", expected[i].length,
439 result[i].length);
440 for (int j = 0; j < expected[i].length; j++) {
441 Assert.assertEquals("Wrong value at position [" + i + "," + j
442 + "]", expected[i][j], result[i][j], 1.0e-6);
443 }
444 }
445 }
446
447
448
449
450
451
452 private ContinuousSampler createGaussianSampler(double mu,
453 double sigma) {
454 return GaussianSampler.of(ZigguratNormalizedGaussianSampler.of(RandomSource.JSF_64.create()),
455 mu, sigma);
456 }
457 }