View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.filter;
19  
20  import org.apache.commons.rng.simple.RandomSource;
21  import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
22  import org.apache.commons.rng.sampling.distribution.GaussianSampler;
23  import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
24  import org.apache.commons.numbers.core.Precision;
25  import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
26  import org.apache.commons.math4.legacy.linear.ArrayRealVector;
27  import org.apache.commons.math4.legacy.linear.MatrixDimensionMismatchException;
28  import org.apache.commons.math4.legacy.linear.MatrixUtils;
29  import org.apache.commons.math4.legacy.linear.RealMatrix;
30  import org.apache.commons.math4.legacy.linear.RealVector;
31  import org.apache.commons.math4.core.jdkmath.JdkMath;
32  import org.junit.Assert;
33  import org.junit.Test;
34  
35  /**
36   * Tests for {@link KalmanFilter}.
37   *
38   */
39  public class KalmanFilterTest {
40  
41      @Test(expected=MatrixDimensionMismatchException.class)
42      public void testTransitionMeasurementMatrixMismatch() {
43  
44          // A and H matrix do not match in dimensions
45  
46          // A = [ 1 ]
47          RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
48          // no control input
49          RealMatrix B = null;
50          // H = [ 1 1 ]
51          RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d, 1d });
52          // Q = [ 0 ]
53          RealMatrix Q = new Array2DRowRealMatrix(new double[] { 0 });
54          // R = [ 0 ]
55          RealMatrix R = new Array2DRowRealMatrix(new double[] { 0 });
56  
57          ProcessModel pm
58              = new DefaultProcessModel(A, B, Q,
59                                        new ArrayRealVector(new double[] { 0 }), null);
60          MeasurementModel mm = new DefaultMeasurementModel(H, R);
61          new KalmanFilter(pm, mm);
62          Assert.fail("transition and measurement matrix should not be compatible");
63      }
64  
65      @Test(expected=MatrixDimensionMismatchException.class)
66      public void testTransitionControlMatrixMismatch() {
67  
68          // A and B matrix do not match in dimensions
69  
70          // A = [ 1 ]
71          RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
72          // B = [ 1 1 ]
73          RealMatrix B = new Array2DRowRealMatrix(new double[] { 1d, 1d });
74          // H = [ 1 ]
75          RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d });
76          // Q = [ 0 ]
77          RealMatrix Q = new Array2DRowRealMatrix(new double[] { 0 });
78          // R = [ 0 ]
79          RealMatrix R = new Array2DRowRealMatrix(new double[] { 0 });
80  
81          ProcessModel pm
82              = new DefaultProcessModel(A, B, Q,
83                                        new ArrayRealVector(new double[] { 0 }), null);
84          MeasurementModel mm = new DefaultMeasurementModel(H, R);
85          new KalmanFilter(pm, mm);
86          Assert.fail("transition and control matrix should not be compatible");
87      }
88  
89      @Test
90      public void testConstant() {
91          // simulates a simple process with a constant state and no control input
92  
93          double constantValue = 10d;
94          double measurementNoise = 0.1d;
95          double processNoise = 1e-5d;
96  
97          // A = [ 1 ]
98          RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
99          // no control input
100         RealMatrix B = null;
101         // H = [ 1 ]
102         RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d });
103         // x = [ 10 ]
104         RealVector x = new ArrayRealVector(new double[] { constantValue });
105         // Q = [ 1e-5 ]
106         RealMatrix Q = new Array2DRowRealMatrix(new double[] { processNoise });
107         // R = [ 0.1 ]
108         RealMatrix R = new Array2DRowRealMatrix(new double[] { measurementNoise });
109 
110         ProcessModel pm
111             = new DefaultProcessModel(A, B, Q,
112                                       new ArrayRealVector(new double[] { constantValue }), null);
113         MeasurementModel mm = new DefaultMeasurementModel(H, R);
114         KalmanFilter filter = new KalmanFilter(pm, mm);
115 
116         Assert.assertEquals(1, filter.getMeasurementDimension());
117         Assert.assertEquals(1, filter.getStateDimension());
118 
119         assertMatrixEquals(Q.getData(), filter.getErrorCovariance());
120 
121         // check the initial state
122         double[] expectedInitialState = new double[] { constantValue };
123         assertVectorEquals(expectedInitialState, filter.getStateEstimation());
124 
125         RealVector pNoise = new ArrayRealVector(1);
126         RealVector mNoise = new ArrayRealVector(1);
127 
128         final ContinuousSampler rand = createGaussianSampler(0, 1);
129 
130         // iterate 60 steps
131         for (int i = 0; i < 60; i++) {
132             filter.predict();
133 
134             // Simulate the process
135             pNoise.setEntry(0, processNoise * rand.sample());
136 
137             // x = A * x + p_noise
138             x = A.operate(x).add(pNoise);
139 
140             // Simulate the measurement
141             mNoise.setEntry(0, measurementNoise * rand.sample());
142 
143             // z = H * x + m_noise
144             RealVector z = H.operate(x).add(mNoise);
145 
146             filter.correct(z);
147 
148             // state estimate shouldn't be larger than measurement noise
149             double diff = JdkMath.abs(constantValue - filter.getStateEstimation()[0]);
150             // System.out.println(diff);
151             Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
152         }
153 
154         // error covariance should be already very low (< 0.02)
155         Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[0][0],
156                                               0.02d, 1e-6) < 0);
157     }
158 
159     @Test
160     public void testConstantAcceleration() {
161         // simulates a vehicle, accelerating at a constant rate (0.1 m/s)
162 
163         // discrete time interval
164         double dt = 0.1d;
165         // position measurement noise (meter)
166         double measurementNoise = 10d;
167         // acceleration noise (meter/sec^2)
168         double accelNoise = 0.2d;
169 
170         // A = [ 1 dt ]
171         //     [ 0  1 ]
172         RealMatrix A = new Array2DRowRealMatrix(new double[][] { { 1, dt }, { 0, 1 } });
173 
174         // B = [ dt^2/2 ]
175         //     [ dt     ]
176         RealMatrix B = new Array2DRowRealMatrix(
177                 new double[][] { { JdkMath.pow(dt, 2d) / 2d }, { dt } });
178 
179         // H = [ 1 0 ]
180         RealMatrix H = new Array2DRowRealMatrix(new double[][] { { 1d, 0d } });
181 
182         // x = [ 0 0 ]
183         RealVector x = new ArrayRealVector(new double[] { 0, 0 });
184 
185         RealMatrix tmp = new Array2DRowRealMatrix(
186                 new double[][] { { JdkMath.pow(dt, 4d) / 4d, JdkMath.pow(dt, 3d) / 2d },
187                                  { JdkMath.pow(dt, 3d) / 2d, JdkMath.pow(dt, 2d) } });
188 
189         // Q = [ dt^4/4 dt^3/2 ]
190         //     [ dt^3/2 dt^2   ]
191         RealMatrix Q = tmp.scalarMultiply(JdkMath.pow(accelNoise, 2));
192 
193         // P0 = [ 1 1 ]
194         //      [ 1 1 ]
195         RealMatrix P0 = new Array2DRowRealMatrix(new double[][] { { 1, 1 }, { 1, 1 } });
196 
197         // R = [ measurementNoise^2 ]
198         RealMatrix R = new Array2DRowRealMatrix(
199                 new double[] { JdkMath.pow(measurementNoise, 2) });
200 
201         // constant control input, increase velocity by 0.1 m/s per cycle
202         RealVector u = new ArrayRealVector(new double[] { 0.1d });
203 
204         ProcessModel pm = new DefaultProcessModel(A, B, Q, x, P0);
205         MeasurementModel mm = new DefaultMeasurementModel(H, R);
206         KalmanFilter filter = new KalmanFilter(pm, mm);
207 
208         Assert.assertEquals(1, filter.getMeasurementDimension());
209         Assert.assertEquals(2, filter.getStateDimension());
210 
211         assertMatrixEquals(P0.getData(), filter.getErrorCovariance());
212 
213         // check the initial state
214         double[] expectedInitialState = new double[] { 0.0, 0.0 };
215         assertVectorEquals(expectedInitialState, filter.getStateEstimation());
216 
217         final ContinuousSampler rand = createGaussianSampler(0, 1);
218 
219         RealVector tmpPNoise = new ArrayRealVector(
220                 new double[] { JdkMath.pow(dt, 2d) / 2d, dt });
221 
222         // iterate 60 steps
223         for (int i = 0; i < 60; i++) {
224             filter.predict(u);
225 
226             // Simulate the process
227             RealVector pNoise = tmpPNoise.mapMultiply(accelNoise * rand.sample());
228 
229             // x = A * x + B * u + pNoise
230             x = A.operate(x).add(B.operate(u)).add(pNoise);
231 
232             // Simulate the measurement
233             double mNoise = measurementNoise * rand.sample();
234 
235             // z = H * x + m_noise
236             RealVector z = H.operate(x).mapAdd(mNoise);
237 
238             filter.correct(z);
239 
240             // state estimate shouldn't be larger than the measurement noise
241             double diff = JdkMath.abs(x.getEntry(0) - filter.getStateEstimation()[0]);
242             Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
243         }
244 
245         // error covariance of the velocity should be already very low (< 0.1)
246         Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[1][1],
247                                               0.1d, 1e-6) < 0);
248     }
249 
250     /**
251      * Represents an idealized Cannonball only taking into account gravity.
252      */
253     public static class Cannonball {
254 
255         private final double[] gravity = { 0, -9.81 };
256 
257         private final double[] velocity;
258         private final double[] location;
259 
260         private double timeslice;
261 
262         public Cannonball(double timeslice, double angle, double initialVelocity) {
263             this.timeslice = timeslice;
264 
265             final double angleInRadians = JdkMath.toRadians(angle);
266             this.velocity = new double[] {
267                     initialVelocity * JdkMath.cos(angleInRadians),
268                     initialVelocity * JdkMath.sin(angleInRadians)
269             };
270 
271             this.location = new double[] { 0, 0 };
272         }
273 
274         public double getX() {
275             return location[0];
276         }
277 
278         public double getY() {
279             return location[1];
280         }
281 
282         public double getXVelocity() {
283             return velocity[0];
284         }
285 
286         public double getYVelocity() {
287             return velocity[1];
288         }
289 
290         public void step() {
291             // break gravitational force into a smaller time slice.
292             double[] slicedGravity = gravity.clone();
293             for ( int i = 0; i < slicedGravity.length; i++ ) {
294                 slicedGravity[i] *= timeslice;
295             }
296 
297             // apply the acceleration to velocity.
298             double[] slicedVelocity = velocity.clone();
299             for ( int i = 0; i < velocity.length; i++ ) {
300                 velocity[i] += slicedGravity[i];
301                 slicedVelocity[i] = velocity[i] * timeslice;
302                 location[i] += slicedVelocity[i];
303             }
304 
305             // cannonballs shouldn't go into the ground.
306             if ( location[1] < 0 ) {
307                 location[1] = 0;
308             }
309         }
310     }
311 
312     @Test
313     public void testCannonball() {
314         // simulates the flight of a cannonball (only taking gravity and initial thrust into account)
315 
316         // number of iterations
317         final int iterations = 144;
318         // discrete time interval
319         final double dt = 0.1d;
320         // position measurement noise (meter)
321         final double measurementNoise = 30d;
322         // the initial velocity of the cannonball
323         final double initialVelocity = 100;
324         // shooting angle
325         final double angle = 45;
326 
327         final Cannonball cannonball = new Cannonball(dt, angle, initialVelocity);
328 
329         final double speedX = cannonball.getXVelocity();
330         final double speedY = cannonball.getYVelocity();
331 
332         // A = [ 1, dt, 0,  0 ]  =>  x(n+1) = x(n) + vx(n)
333         //     [ 0,  1, 0,  0 ]  => vx(n+1) =        vx(n)
334         //     [ 0,  0, 1, dt ]  =>  y(n+1) =              y(n) + vy(n)
335         //     [ 0,  0, 0,  1 ]  => vy(n+1) =                     vy(n)
336         final RealMatrix A = MatrixUtils.createRealMatrix(new double[][] {
337                 { 1, dt, 0,  0 },
338                 { 0,  1, 0,  0 },
339                 { 0,  0, 1, dt },
340                 { 0,  0, 0,  1 }
341         });
342 
343         // The control vector, which adds acceleration to the kinematic equations.
344         // 0          =>  x(n+1) =  x(n+1)
345         // 0          => vx(n+1) = vx(n+1)
346         // -9.81*dt^2 =>  y(n+1) =  y(n+1) - 1/2 * 9.81 * dt^2
347         // -9.81*dt   => vy(n+1) = vy(n+1) - 9.81 * dt
348         final RealVector controlVector =
349                 MatrixUtils.createRealVector(new double[] { 0, 0, 0.5 * -9.81 * dt * dt, -9.81 * dt } );
350 
351         // The control matrix B only expects y and vy, see control vector
352         final RealMatrix B = MatrixUtils.createRealMatrix(new double[][] {
353                 { 0, 0, 0, 0 },
354                 { 0, 0, 0, 0 },
355                 { 0, 0, 1, 0 },
356                 { 0, 0, 0, 1 }
357         });
358 
359         // We only observe the x/y position of the cannonball
360         final RealMatrix H = MatrixUtils.createRealMatrix(new double[][] {
361                 { 1, 0, 0, 0 },
362                 { 0, 0, 0, 0 },
363                 { 0, 0, 1, 0 },
364                 { 0, 0, 0, 0 }
365         });
366 
367         // our guess of the initial state.
368         final RealVector initialState = MatrixUtils.createRealVector(new double[] { 0, speedX, 0, speedY } );
369 
370         // the initial error covariance matrix, the variance = noise^2
371         final double var = measurementNoise * measurementNoise;
372         final RealMatrix initialErrorCovariance = MatrixUtils.createRealMatrix(new double[][] {
373                 { var,    0,   0,    0 },
374                 {   0, 1e-3,   0,    0 },
375                 {   0,    0, var,    0 },
376                 {   0,    0,   0, 1e-3 }
377         });
378 
379         // we assume no process noise -> zero matrix
380         final RealMatrix Q = MatrixUtils.createRealMatrix(4, 4);
381 
382         // the measurement covariance matrix
383         final RealMatrix R = MatrixUtils.createRealMatrix(new double[][] {
384                 { var,    0,   0,    0 },
385                 {   0, 1e-3,   0,    0 },
386                 {   0,    0, var,    0 },
387                 {   0,    0,   0, 1e-3 }
388         });
389 
390         final ProcessModel pm = new DefaultProcessModel(A, B, Q, initialState, initialErrorCovariance);
391         final MeasurementModel mm = new DefaultMeasurementModel(H, R);
392         final KalmanFilter filter = new KalmanFilter(pm, mm);
393 
394         final ContinuousSampler rand = createGaussianSampler(0, measurementNoise);
395 
396         for (int i = 0; i < iterations; i++) {
397             // get the "real" cannonball position
398             double x = cannonball.getX();
399             double y = cannonball.getY();
400 
401             // apply measurement noise to current cannonball position
402             double nx = x + rand.sample();
403             double ny = y + rand.sample();
404 
405             cannonball.step();
406 
407             filter.predict(controlVector);
408             // correct the filter with our measurements
409             filter.correct(new double[] { nx, 0, ny, 0 } );
410 
411             // state estimate shouldn't be larger than the measurement noise
412             double diff = JdkMath.abs(cannonball.getY() - filter.getStateEstimation()[2]);
413             Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
414         }
415 
416         // error covariance of the x/y-position should be already very low (< 3m std dev = 9 variance)
417 
418         Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[0][0],
419                                               9, 1e-6) < 0);
420 
421         Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[2][2],
422                                               9, 1e-6) < 0);
423     }
424 
425     private void assertVectorEquals(double[] expected, double[] result) {
426         Assert.assertEquals("Wrong number of rows.", expected.length,
427                             result.length);
428         for (int i = 0; i < expected.length; i++) {
429             Assert.assertEquals("Wrong value at position [" + i + "]",
430                                 expected[i], result[i], 1.0e-6);
431         }
432     }
433 
434     private void assertMatrixEquals(double[][] expected, double[][] result) {
435         Assert.assertEquals("Wrong number of rows.", expected.length,
436                             result.length);
437         for (int i = 0; i < expected.length; i++) {
438             Assert.assertEquals("Wrong number of columns.", expected[i].length,
439                                 result[i].length);
440             for (int j = 0; j < expected[i].length; j++) {
441                 Assert.assertEquals("Wrong value at position [" + i + "," + j
442                                     + "]", expected[i][j], result[i][j], 1.0e-6);
443             }
444         }
445     }
446 
447     /**
448      * @param mu Mean
449      * @param sigma Standard deviation.
450      * @return a sampler that follows the N(mu,sigma) distribution.
451      */
452     private ContinuousSampler createGaussianSampler(double mu,
453                                                     double sigma) {
454         return GaussianSampler.of(ZigguratNormalizedGaussianSampler.of(RandomSource.JSF_64.create()),
455                                   mu, sigma);
456     }
457 }