1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.filter;
18
19 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
20 import org.apache.commons.math4.legacy.exception.NullArgumentException;
21 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
22 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
23 import org.apache.commons.math4.legacy.linear.CholeskyDecomposition;
24 import org.apache.commons.math4.legacy.linear.MatrixDimensionMismatchException;
25 import org.apache.commons.math4.legacy.linear.MatrixUtils;
26 import org.apache.commons.math4.legacy.linear.NonSquareMatrixException;
27 import org.apache.commons.math4.legacy.linear.RealMatrix;
28 import org.apache.commons.math4.legacy.linear.RealVector;
29 import org.apache.commons.math4.legacy.linear.SingularMatrixException;
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81 public class KalmanFilter {
82
83 private final ProcessModel processModel;
84
85 private final MeasurementModel measurementModel;
86
87 private RealMatrix transitionMatrix;
88
89 private RealMatrix transitionMatrixT;
90
91 private RealMatrix controlMatrix;
92
93 private RealMatrix measurementMatrix;
94
95 private RealMatrix measurementMatrixT;
96
97 private RealVector stateEstimation;
98
99 private RealMatrix errorCovariance;
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118 public KalmanFilter(final ProcessModel process, final MeasurementModel measurement)
119 throws NullArgumentException, NonSquareMatrixException, DimensionMismatchException,
120 MatrixDimensionMismatchException {
121
122 NullArgumentException.check(process);
123 NullArgumentException.check(measurement);
124
125 this.processModel = process;
126 this.measurementModel = measurement;
127
128 transitionMatrix = processModel.getStateTransitionMatrix();
129 NullArgumentException.check(transitionMatrix);
130 transitionMatrixT = transitionMatrix.transpose();
131
132
133 if (processModel.getControlMatrix() == null) {
134 controlMatrix = new Array2DRowRealMatrix();
135 } else {
136 controlMatrix = processModel.getControlMatrix();
137 }
138
139 measurementMatrix = measurementModel.getMeasurementMatrix();
140 NullArgumentException.check(measurementMatrix);
141 measurementMatrixT = measurementMatrix.transpose();
142
143
144
145
146 RealMatrix processNoise = processModel.getProcessNoise();
147 NullArgumentException.check(processNoise);
148 RealMatrix measNoise = measurementModel.getMeasurementNoise();
149 NullArgumentException.check(measNoise);
150
151
152
153 if (processModel.getInitialStateEstimate() == null) {
154 stateEstimation = new ArrayRealVector(transitionMatrix.getColumnDimension());
155 } else {
156 stateEstimation = processModel.getInitialStateEstimate();
157 }
158
159 if (transitionMatrix.getColumnDimension() != stateEstimation.getDimension()) {
160 throw new DimensionMismatchException(transitionMatrix.getColumnDimension(),
161 stateEstimation.getDimension());
162 }
163
164
165
166 if (processModel.getInitialErrorCovariance() == null) {
167 errorCovariance = processNoise.copy();
168 } else {
169 errorCovariance = processModel.getInitialErrorCovariance();
170 }
171
172
173
174
175 if (!transitionMatrix.isSquare()) {
176 throw new NonSquareMatrixException(
177 transitionMatrix.getRowDimension(),
178 transitionMatrix.getColumnDimension());
179 }
180
181
182
183 if (controlMatrix != null &&
184 controlMatrix.getRowDimension() > 0 &&
185 controlMatrix.getColumnDimension() > 0 &&
186 controlMatrix.getRowDimension() != transitionMatrix.getRowDimension()) {
187 throw new MatrixDimensionMismatchException(controlMatrix.getRowDimension(),
188 controlMatrix.getColumnDimension(),
189 transitionMatrix.getRowDimension(),
190 controlMatrix.getColumnDimension());
191 }
192
193
194 MatrixUtils.checkAdditionCompatible(transitionMatrix, processNoise);
195
196
197 if (measurementMatrix.getColumnDimension() != transitionMatrix.getRowDimension()) {
198 throw new MatrixDimensionMismatchException(measurementMatrix.getRowDimension(),
199 measurementMatrix.getColumnDimension(),
200 measurementMatrix.getRowDimension(),
201 transitionMatrix.getRowDimension());
202 }
203
204
205 if (measNoise.getRowDimension() != measurementMatrix.getRowDimension()) {
206 throw new MatrixDimensionMismatchException(measNoise.getRowDimension(),
207 measNoise.getColumnDimension(),
208 measurementMatrix.getRowDimension(),
209 measNoise.getColumnDimension());
210 }
211 }
212
213
214
215
216
217
218 public int getStateDimension() {
219 return stateEstimation.getDimension();
220 }
221
222
223
224
225
226
227 public int getMeasurementDimension() {
228 return measurementMatrix.getRowDimension();
229 }
230
231
232
233
234
235
236 public double[] getStateEstimation() {
237 return stateEstimation.toArray();
238 }
239
240
241
242
243
244
245 public RealVector getStateEstimationVector() {
246 return stateEstimation.copy();
247 }
248
249
250
251
252
253
254 public double[][] getErrorCovariance() {
255 return errorCovariance.getData();
256 }
257
258
259
260
261
262
263 public RealMatrix getErrorCovarianceMatrix() {
264 return errorCovariance.copy();
265 }
266
267
268
269
270 public void predict() {
271 predict((RealVector) null);
272 }
273
274
275
276
277
278
279
280
281
282 public void predict(final double[] u) throws DimensionMismatchException {
283 predict(new ArrayRealVector(u, false));
284 }
285
286
287
288
289
290
291
292
293
294 public void predict(final RealVector u) throws DimensionMismatchException {
295
296 if (u != null &&
297 u.getDimension() != controlMatrix.getColumnDimension()) {
298 throw new DimensionMismatchException(u.getDimension(),
299 controlMatrix.getColumnDimension());
300 }
301
302
303
304 stateEstimation = transitionMatrix.operate(stateEstimation);
305
306
307 if (u != null) {
308 stateEstimation = stateEstimation.add(controlMatrix.operate(u));
309 }
310
311
312
313 errorCovariance = transitionMatrix.multiply(errorCovariance)
314 .multiply(transitionMatrixT)
315 .add(processModel.getProcessNoise());
316 }
317
318
319
320
321
322
323
324
325
326
327
328
329
330 public void correct(final double[] z)
331 throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
332 correct(new ArrayRealVector(z, false));
333 }
334
335
336
337
338
339
340
341
342
343
344
345
346
347 public void correct(final RealVector z)
348 throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
349
350
351 NullArgumentException.check(z);
352 if (z.getDimension() != measurementMatrix.getRowDimension()) {
353 throw new DimensionMismatchException(z.getDimension(),
354 measurementMatrix.getRowDimension());
355 }
356
357
358 RealMatrix s = measurementMatrix.multiply(errorCovariance)
359 .multiply(measurementMatrixT)
360 .add(measurementModel.getMeasurementNoise());
361
362
363 RealVector innovation = z.subtract(measurementMatrix.operate(stateEstimation));
364
365
366
367
368
369
370
371
372
373
374 RealMatrix kalmanGain = new CholeskyDecomposition(s).getSolver()
375 .solve(measurementMatrix.multiply(errorCovariance.transpose()))
376 .transpose();
377
378
379
380 stateEstimation = stateEstimation.add(kalmanGain.operate(innovation));
381
382
383
384 RealMatrix identity = MatrixUtils.createRealIdentityMatrix(kalmanGain.getRowDimension());
385 errorCovariance = identity.subtract(kalmanGain.multiply(measurementMatrix)).multiply(errorCovariance);
386 }
387 }