View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.filter;
18  
19  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
20  import org.apache.commons.math4.legacy.exception.NullArgumentException;
21  import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
22  import org.apache.commons.math4.legacy.linear.ArrayRealVector;
23  import org.apache.commons.math4.legacy.linear.CholeskyDecomposition;
24  import org.apache.commons.math4.legacy.linear.MatrixDimensionMismatchException;
25  import org.apache.commons.math4.legacy.linear.MatrixUtils;
26  import org.apache.commons.math4.legacy.linear.NonSquareMatrixException;
27  import org.apache.commons.math4.legacy.linear.RealMatrix;
28  import org.apache.commons.math4.legacy.linear.RealVector;
29  import org.apache.commons.math4.legacy.linear.SingularMatrixException;
30  
31  /**
32   * Implementation of a Kalman filter to estimate the state <i>x<sub>k</sub></i>
33   * of a discrete-time controlled process that is governed by the linear
34   * stochastic difference equation:
35   *
36   * <pre>
37   * <i>x<sub>k</sub></i> = <b>A</b><i>x<sub>k-1</sub></i> + <b>B</b><i>u<sub>k-1</sub></i> + <i>w<sub>k-1</sub></i>
38   * </pre>
39   *
40   * with a measurement <i>x<sub>k</sub></i> that is
41   *
42   * <pre>
43   * <i>z<sub>k</sub></i> = <b>H</b><i>x<sub>k</sub></i> + <i>v<sub>k</sub></i>.
44   * </pre>
45   *
46   * <p>
47   * The random variables <i>w<sub>k</sub></i> and <i>v<sub>k</sub></i> represent
48   * the process and measurement noise and are assumed to be independent of each
49   * other and distributed with normal probability (white noise).
50   * <p>
51   * The Kalman filter cycle involves the following steps:
52   * <ol>
53   * <li>predict: project the current state estimate ahead in time</li>
54   * <li>correct: adjust the projected estimate by an actual measurement</li>
55   * </ol>
56   * <p>
57   * The Kalman filter is initialized with a {@link ProcessModel} and a
58   * {@link MeasurementModel}, which contain the corresponding transformation and
59   * noise covariance matrices. The parameter names used in the respective models
60   * correspond to the following names commonly used in the mathematical
61   * literature:
62   * <ul>
63   * <li>A - state transition matrix</li>
64   * <li>B - control input matrix</li>
65   * <li>H - measurement matrix</li>
66   * <li>Q - process noise covariance matrix</li>
67   * <li>R - measurement noise covariance matrix</li>
68   * <li>P - error covariance matrix</li>
69   * </ul>
70   *
71   * @see <a href="http://www.cs.unc.edu/~welch/kalman/">Kalman filter
72   *      resources</a>
73   * @see <a href="http://www.cs.unc.edu/~welch/media/pdf/kalman_intro.pdf">An
74   *      introduction to the Kalman filter by Greg Welch and Gary Bishop</a>
75   * @see <a href="http://academic.csuohio.edu/simond/courses/eec644/kalman.pdf">
76   *      Kalman filter example by Dan Simon</a>
77   * @see ProcessModel
78   * @see MeasurementModel
79   * @since 3.0
80   */
81  public class KalmanFilter {
82      /** The process model used by this filter instance. */
83      private final ProcessModel processModel;
84      /** The measurement model used by this filter instance. */
85      private final MeasurementModel measurementModel;
86      /** The transition matrix, equivalent to A. */
87      private RealMatrix transitionMatrix;
88      /** The transposed transition matrix. */
89      private RealMatrix transitionMatrixT;
90      /** The control matrix, equivalent to B. */
91      private RealMatrix controlMatrix;
92      /** The measurement matrix, equivalent to H. */
93      private RealMatrix measurementMatrix;
94      /** The transposed measurement matrix. */
95      private RealMatrix measurementMatrixT;
96      /** The internal state estimation vector, equivalent to x hat. */
97      private RealVector stateEstimation;
98      /** The error covariance matrix, equivalent to P. */
99      private RealMatrix errorCovariance;
100 
101     /**
102      * Creates a new Kalman filter with the given process and measurement models.
103      *
104      * @param process
105      *            the model defining the underlying process dynamics
106      * @param measurement
107      *            the model defining the given measurement characteristics
108      * @throws NullArgumentException
109      *             if any of the given inputs is null (except for the control matrix)
110      * @throws NonSquareMatrixException
111      *             if the transition matrix is non square
112      * @throws DimensionMismatchException
113      *             if the column dimension of the transition matrix does not match the dimension of the
114      *             initial state estimation vector
115      * @throws MatrixDimensionMismatchException
116      *             if the matrix dimensions do not fit together
117      */
118     public KalmanFilter(final ProcessModel process, final MeasurementModel measurement)
119             throws NullArgumentException, NonSquareMatrixException, DimensionMismatchException,
120                    MatrixDimensionMismatchException {
121 
122         NullArgumentException.check(process);
123         NullArgumentException.check(measurement);
124 
125         this.processModel = process;
126         this.measurementModel = measurement;
127 
128         transitionMatrix = processModel.getStateTransitionMatrix();
129         NullArgumentException.check(transitionMatrix);
130         transitionMatrixT = transitionMatrix.transpose();
131 
132         // create an empty matrix if no control matrix was given
133         if (processModel.getControlMatrix() == null) {
134             controlMatrix = new Array2DRowRealMatrix();
135         } else {
136             controlMatrix = processModel.getControlMatrix();
137         }
138 
139         measurementMatrix = measurementModel.getMeasurementMatrix();
140         NullArgumentException.check(measurementMatrix);
141         measurementMatrixT = measurementMatrix.transpose();
142 
143         // check that the process and measurement noise matrices are not null
144         // they will be directly accessed from the model as they may change
145         // over time
146         RealMatrix processNoise = processModel.getProcessNoise();
147         NullArgumentException.check(processNoise);
148         RealMatrix measNoise = measurementModel.getMeasurementNoise();
149         NullArgumentException.check(measNoise);
150 
151         // set the initial state estimate to a zero vector if it is not
152         // available from the process model
153         if (processModel.getInitialStateEstimate() == null) {
154             stateEstimation = new ArrayRealVector(transitionMatrix.getColumnDimension());
155         } else {
156             stateEstimation = processModel.getInitialStateEstimate();
157         }
158 
159         if (transitionMatrix.getColumnDimension() != stateEstimation.getDimension()) {
160             throw new DimensionMismatchException(transitionMatrix.getColumnDimension(),
161                                                  stateEstimation.getDimension());
162         }
163 
164         // initialize the error covariance to the process noise if it is not
165         // available from the process model
166         if (processModel.getInitialErrorCovariance() == null) {
167             errorCovariance = processNoise.copy();
168         } else {
169             errorCovariance = processModel.getInitialErrorCovariance();
170         }
171 
172         // sanity checks, the control matrix B may be null
173 
174         // A must be a square matrix
175         if (!transitionMatrix.isSquare()) {
176             throw new NonSquareMatrixException(
177                     transitionMatrix.getRowDimension(),
178                     transitionMatrix.getColumnDimension());
179         }
180 
181         // row dimension of B must be equal to A
182         // if no control matrix is available, the row and column dimension will be 0
183         if (controlMatrix != null &&
184             controlMatrix.getRowDimension() > 0 &&
185             controlMatrix.getColumnDimension() > 0 &&
186             controlMatrix.getRowDimension() != transitionMatrix.getRowDimension()) {
187             throw new MatrixDimensionMismatchException(controlMatrix.getRowDimension(),
188                                                        controlMatrix.getColumnDimension(),
189                                                        transitionMatrix.getRowDimension(),
190                                                        controlMatrix.getColumnDimension());
191         }
192 
193         // Q must be equal to A
194         MatrixUtils.checkAdditionCompatible(transitionMatrix, processNoise);
195 
196         // column dimension of H must be equal to row dimension of A
197         if (measurementMatrix.getColumnDimension() != transitionMatrix.getRowDimension()) {
198             throw new MatrixDimensionMismatchException(measurementMatrix.getRowDimension(),
199                                                        measurementMatrix.getColumnDimension(),
200                                                        measurementMatrix.getRowDimension(),
201                                                        transitionMatrix.getRowDimension());
202         }
203 
204         // row dimension of R must be equal to row dimension of H
205         if (measNoise.getRowDimension() != measurementMatrix.getRowDimension()) {
206             throw new MatrixDimensionMismatchException(measNoise.getRowDimension(),
207                                                        measNoise.getColumnDimension(),
208                                                        measurementMatrix.getRowDimension(),
209                                                        measNoise.getColumnDimension());
210         }
211     }
212 
213     /**
214      * Returns the dimension of the state estimation vector.
215      *
216      * @return the state dimension
217      */
218     public int getStateDimension() {
219         return stateEstimation.getDimension();
220     }
221 
222     /**
223      * Returns the dimension of the measurement vector.
224      *
225      * @return the measurement vector dimension
226      */
227     public int getMeasurementDimension() {
228         return measurementMatrix.getRowDimension();
229     }
230 
231     /**
232      * Returns the current state estimation vector.
233      *
234      * @return the state estimation vector
235      */
236     public double[] getStateEstimation() {
237         return stateEstimation.toArray();
238     }
239 
240     /**
241      * Returns a copy of the current state estimation vector.
242      *
243      * @return the state estimation vector
244      */
245     public RealVector getStateEstimationVector() {
246         return stateEstimation.copy();
247     }
248 
249     /**
250      * Returns the current error covariance matrix.
251      *
252      * @return the error covariance matrix
253      */
254     public double[][] getErrorCovariance() {
255         return errorCovariance.getData();
256     }
257 
258     /**
259      * Returns a copy of the current error covariance matrix.
260      *
261      * @return the error covariance matrix
262      */
263     public RealMatrix getErrorCovarianceMatrix() {
264         return errorCovariance.copy();
265     }
266 
267     /**
268      * Predict the internal state estimation one time step ahead.
269      */
270     public void predict() {
271         predict((RealVector) null);
272     }
273 
274     /**
275      * Predict the internal state estimation one time step ahead.
276      *
277      * @param u
278      *            the control vector
279      * @throws DimensionMismatchException
280      *             if the dimension of the control vector does not fit
281      */
282     public void predict(final double[] u) throws DimensionMismatchException {
283         predict(new ArrayRealVector(u, false));
284     }
285 
286     /**
287      * Predict the internal state estimation one time step ahead.
288      *
289      * @param u
290      *            the control vector
291      * @throws DimensionMismatchException
292      *             if the dimension of the control vector does not match
293      */
294     public void predict(final RealVector u) throws DimensionMismatchException {
295         // sanity checks
296         if (u != null &&
297             u.getDimension() != controlMatrix.getColumnDimension()) {
298             throw new DimensionMismatchException(u.getDimension(),
299                                                  controlMatrix.getColumnDimension());
300         }
301 
302         // project the state estimation ahead (a priori state)
303         // xHat(k)- = A * xHat(k-1) + B * u(k-1)
304         stateEstimation = transitionMatrix.operate(stateEstimation);
305 
306         // add control input if it is available
307         if (u != null) {
308             stateEstimation = stateEstimation.add(controlMatrix.operate(u));
309         }
310 
311         // project the error covariance ahead
312         // P(k)- = A * P(k-1) * A' + Q
313         errorCovariance = transitionMatrix.multiply(errorCovariance)
314                 .multiply(transitionMatrixT)
315                 .add(processModel.getProcessNoise());
316     }
317 
318     /**
319      * Correct the current state estimate with an actual measurement.
320      *
321      * @param z
322      *            the measurement vector
323      * @throws NullArgumentException
324      *             if the measurement vector is {@code null}
325      * @throws DimensionMismatchException
326      *             if the dimension of the measurement vector does not fit
327      * @throws SingularMatrixException
328      *             if the covariance matrix could not be inverted
329      */
330     public void correct(final double[] z)
331             throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
332         correct(new ArrayRealVector(z, false));
333     }
334 
335     /**
336      * Correct the current state estimate with an actual measurement.
337      *
338      * @param z
339      *            the measurement vector
340      * @throws NullArgumentException
341      *             if the measurement vector is {@code null}
342      * @throws DimensionMismatchException
343      *             if the dimension of the measurement vector does not fit
344      * @throws SingularMatrixException
345      *             if the covariance matrix could not be inverted
346      */
347     public void correct(final RealVector z)
348             throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
349 
350         // sanity checks
351         NullArgumentException.check(z);
352         if (z.getDimension() != measurementMatrix.getRowDimension()) {
353             throw new DimensionMismatchException(z.getDimension(),
354                                                  measurementMatrix.getRowDimension());
355         }
356 
357         // S = H * P(k) * H' + R
358         RealMatrix s = measurementMatrix.multiply(errorCovariance)
359             .multiply(measurementMatrixT)
360             .add(measurementModel.getMeasurementNoise());
361 
362         // Inn = z(k) - H * xHat(k)-
363         RealVector innovation = z.subtract(measurementMatrix.operate(stateEstimation));
364 
365         // calculate gain matrix
366         // K(k) = P(k)- * H' * (H * P(k)- * H' + R)^-1
367         // K(k) = P(k)- * H' * S^-1
368 
369         // instead of calculating the inverse of S we can rearrange the formula,
370         // and then solve the linear equation A x X = B with A = S', X = K' and B = (H * P)'
371 
372         // K(k) * S = P(k)- * H'
373         // S' * K(k)' = H * P(k)-'
374         RealMatrix kalmanGain = new CholeskyDecomposition(s).getSolver()
375                 .solve(measurementMatrix.multiply(errorCovariance.transpose()))
376                 .transpose();
377 
378         // update estimate with measurement z(k)
379         // xHat(k) = xHat(k)- + K * Inn
380         stateEstimation = stateEstimation.add(kalmanGain.operate(innovation));
381 
382         // update covariance of prediction error
383         // P(k) = (I - K * H) * P(k)-
384         RealMatrix identity = MatrixUtils.createRealIdentityMatrix(kalmanGain.getRowDimension());
385         errorCovariance = identity.subtract(kalmanGain.multiply(measurementMatrix)).multiply(errorCovariance);
386     }
387 }