001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.filter;
018
019import org.apache.commons.math3.exception.DimensionMismatchException;
020import org.apache.commons.math3.exception.NullArgumentException;
021import org.apache.commons.math3.linear.Array2DRowRealMatrix;
022import org.apache.commons.math3.linear.ArrayRealVector;
023import org.apache.commons.math3.linear.MatrixDimensionMismatchException;
024import org.apache.commons.math3.linear.MatrixUtils;
025import org.apache.commons.math3.linear.NonSquareMatrixException;
026import org.apache.commons.math3.linear.RealMatrix;
027import org.apache.commons.math3.linear.RealVector;
028import org.apache.commons.math3.linear.SingularMatrixException;
029import org.apache.commons.math3.util.MathUtils;
030
031/**
032 * Implementation of a Kalman filter to estimate the state <i>x<sub>k</sub></i>
033 * of a discrete-time controlled process that is governed by the linear
034 * stochastic difference equation:
035 *
036 * <pre>
037 * <i>x<sub>k</sub></i> = <b>A</b><i>x<sub>k-1</sub></i> + <b>B</b><i>u<sub>k-1</sub></i> + <i>w<sub>k-1</sub></i>
038 * </pre>
039 *
040 * with a measurement <i>x<sub>k</sub></i> that is
041 *
042 * <pre>
043 * <i>z<sub>k</sub></i> = <b>H</b><i>x<sub>k</sub></i> + <i>v<sub>k</sub></i>.
044 * </pre>
045 *
046 * <p>
047 * The random variables <i>w<sub>k</sub></i> and <i>v<sub>k</sub></i> represent
048 * the process and measurement noise and are assumed to be independent of each
049 * other and distributed with normal probability (white noise).
050 * <p>
051 * The Kalman filter cycle involves the following steps:
052 * <ol>
053 * <li>predict: project the current state estimate ahead in time</li>
054 * <li>correct: adjust the projected estimate by an actual measurement</li>
055 * </ol>
056 * <p>
057 * The Kalman filter is initialized with a {@link ProcessModel} and a
058 * {@link MeasurementModel}, which contain the corresponding transformation and
059 * noise covariance matrices. The parameter names used in the respective models
060 * correspond to the following names commonly used in the mathematical
061 * literature:
062 * <ul>
063 * <li>A - state transition matrix</li>
064 * <li>B - control input matrix</li>
065 * <li>H - measurement matrix</li>
066 * <li>Q - process noise covariance matrix</li>
067 * <li>R - measurement noise covariance matrix</li>
068 * <li>P - error covariance matrix</li>
069 * </ul>
070 *
071 * @see <a href="http://www.cs.unc.edu/~welch/kalman/">Kalman filter
072 *      resources</a>
073 * @see <a href="http://www.cs.unc.edu/~welch/media/pdf/kalman_intro.pdf">An
074 *      introduction to the Kalman filter by Greg Welch and Gary Bishop</a>
075 * @see <a href="http://academic.csuohio.edu/simond/courses/eec644/kalman.pdf">
076 *      Kalman filter example by Dan Simon</a>
077 * @see ProcessModel
078 * @see MeasurementModel
079 * @since 3.0
080 */
081public class KalmanFilter {
082    /** The process model used by this filter instance. */
083    private final ProcessModel processModel;
084    /** The measurement model used by this filter instance. */
085    private final MeasurementModel measurementModel;
086    /** The transition matrix, equivalent to A. */
087    private RealMatrix transitionMatrix;
088    /** The transposed transition matrix. */
089    private RealMatrix transitionMatrixT;
090    /** The control matrix, equivalent to B. */
091    private RealMatrix controlMatrix;
092    /** The measurement matrix, equivalent to H. */
093    private RealMatrix measurementMatrix;
094    /** The transposed measurement matrix. */
095    private RealMatrix measurementMatrixT;
096    /** The internal state estimation vector, equivalent to x hat. */
097    private RealVector stateEstimation;
098    /** The error covariance matrix, equivalent to P. */
099    private RealMatrix errorCovariance;
100
101    /**
102     * Creates a new Kalman filter with the given process and measurement models.
103     *
104     * @param process
105     *            the model defining the underlying process dynamics
106     * @param measurement
107     *            the model defining the given measurement characteristics
108     * @throws NullArgumentException
109     *             if any of the given inputs is null (except for the control matrix)
110     * @throws NonSquareMatrixException
111     *             if the transition matrix is non square
112     * @throws DimensionMismatchException
113     *             if the column dimension of the transition matrix does not match the dimension of the
114     *             initial state estimation vector
115     * @throws MatrixDimensionMismatchException
116     *             if the matrix dimensions do not fit together
117     */
118    public KalmanFilter(final ProcessModel process, final MeasurementModel measurement)
119            throws NullArgumentException, NonSquareMatrixException, DimensionMismatchException,
120                   MatrixDimensionMismatchException {
121
122        MathUtils.checkNotNull(process);
123        MathUtils.checkNotNull(measurement);
124
125        this.processModel = process;
126        this.measurementModel = measurement;
127
128        transitionMatrix = processModel.getStateTransitionMatrix();
129        MathUtils.checkNotNull(transitionMatrix);
130        transitionMatrixT = transitionMatrix.transpose();
131
132        // create an empty matrix if no control matrix was given
133        if (processModel.getControlMatrix() == null) {
134            controlMatrix = new Array2DRowRealMatrix();
135        } else {
136            controlMatrix = processModel.getControlMatrix();
137        }
138
139        measurementMatrix = measurementModel.getMeasurementMatrix();
140        MathUtils.checkNotNull(measurementMatrix);
141        measurementMatrixT = measurementMatrix.transpose();
142
143        // check that the process and measurement noise matrices are not null
144        // they will be directly accessed from the model as they may change
145        // over time
146        RealMatrix processNoise = processModel.getProcessNoise();
147        MathUtils.checkNotNull(processNoise);
148        RealMatrix measNoise = measurementModel.getMeasurementNoise();
149        MathUtils.checkNotNull(measNoise);
150
151        // set the initial state estimate to a zero vector if it is not
152        // available from the process model
153        if (processModel.getInitialStateEstimate() == null) {
154            stateEstimation = new ArrayRealVector(transitionMatrix.getColumnDimension());
155        } else {
156            stateEstimation = processModel.getInitialStateEstimate();
157        }
158
159        if (transitionMatrix.getColumnDimension() != stateEstimation.getDimension()) {
160            throw new DimensionMismatchException(transitionMatrix.getColumnDimension(),
161                                                 stateEstimation.getDimension());
162        }
163
164        // initialize the error covariance to the process noise if it is not
165        // available from the process model
166        if (processModel.getInitialErrorCovariance() == null) {
167            errorCovariance = processNoise.copy();
168        } else {
169            errorCovariance = processModel.getInitialErrorCovariance();
170        }
171
172        // sanity checks, the control matrix B may be null
173
174        // A must be a square matrix
175        if (!transitionMatrix.isSquare()) {
176            throw new NonSquareMatrixException(
177                    transitionMatrix.getRowDimension(),
178                    transitionMatrix.getColumnDimension());
179        }
180
181        // row dimension of B must be equal to A
182        // if no control matrix is available, the row and column dimension will be 0
183        if (controlMatrix != null &&
184            controlMatrix.getRowDimension() > 0 &&
185            controlMatrix.getColumnDimension() > 0 &&
186            controlMatrix.getRowDimension() != transitionMatrix.getRowDimension()) {
187            throw new MatrixDimensionMismatchException(controlMatrix.getRowDimension(),
188                                                       controlMatrix.getColumnDimension(),
189                                                       transitionMatrix.getRowDimension(),
190                                                       controlMatrix.getColumnDimension());
191        }
192
193        // Q must be equal to A
194        MatrixUtils.checkAdditionCompatible(transitionMatrix, processNoise);
195
196        // column dimension of H must be equal to row dimension of A
197        if (measurementMatrix.getColumnDimension() != transitionMatrix.getRowDimension()) {
198            throw new MatrixDimensionMismatchException(measurementMatrix.getRowDimension(),
199                                                       measurementMatrix.getColumnDimension(),
200                                                       measurementMatrix.getRowDimension(),
201                                                       transitionMatrix.getRowDimension());
202        }
203
204        // row dimension of R must be equal to row dimension of H
205        if (measNoise.getRowDimension() != measurementMatrix.getRowDimension()) {
206            throw new MatrixDimensionMismatchException(measNoise.getRowDimension(),
207                                                       measNoise.getColumnDimension(),
208                                                       measurementMatrix.getRowDimension(),
209                                                       measNoise.getColumnDimension());
210        }
211    }
212
213    /**
214     * Returns the dimension of the state estimation vector.
215     *
216     * @return the state dimension
217     */
218    public int getStateDimension() {
219        return stateEstimation.getDimension();
220    }
221
222    /**
223     * Returns the dimension of the measurement vector.
224     *
225     * @return the measurement vector dimension
226     */
227    public int getMeasurementDimension() {
228        return measurementMatrix.getRowDimension();
229    }
230
231    /**
232     * Returns the current state estimation vector.
233     *
234     * @return the state estimation vector
235     */
236    public double[] getStateEstimation() {
237        return stateEstimation.toArray();
238    }
239
240    /**
241     * Returns a copy of the current state estimation vector.
242     *
243     * @return the state estimation vector
244     */
245    public RealVector getStateEstimationVector() {
246        return stateEstimation.copy();
247    }
248
249    /**
250     * Returns the current error covariance matrix.
251     *
252     * @return the error covariance matrix
253     */
254    public double[][] getErrorCovariance() {
255        return errorCovariance.getData();
256    }
257
258    /**
259     * Returns a copy of the current error covariance matrix.
260     *
261     * @return the error covariance matrix
262     */
263    public RealMatrix getErrorCovarianceMatrix() {
264        return errorCovariance.copy();
265    }
266
267    /**
268     * Predict the internal state estimation one time step ahead.
269     */
270    public void predict() {
271        predict((RealVector) null);
272    }
273
274    /**
275     * Predict the internal state estimation one time step ahead.
276     *
277     * @param u
278     *            the control vector
279     * @throws DimensionMismatchException
280     *             if the dimension of the control vector does not fit
281     */
282    public void predict(final double[] u) throws DimensionMismatchException {
283        predict(new ArrayRealVector(u, false));
284    }
285
286    /**
287     * Predict the internal state estimation one time step ahead.
288     *
289     * @param u
290     *            the control vector
291     * @throws DimensionMismatchException
292     *             if the dimension of the control vector does not match
293     */
294    public void predict(final RealVector u) throws DimensionMismatchException {
295        // sanity checks
296        if (u != null &&
297            u.getDimension() != controlMatrix.getColumnDimension()) {
298            throw new DimensionMismatchException(u.getDimension(),
299                                                 controlMatrix.getColumnDimension());
300        }
301
302        // project the state estimation ahead (a priori state)
303        // xHat(k)- = A * xHat(k-1) + B * u(k-1)
304        stateEstimation = transitionMatrix.operate(stateEstimation);
305
306        // add control input if it is available
307        if (u != null) {
308            stateEstimation = stateEstimation.add(controlMatrix.operate(u));
309        }
310
311        // project the error covariance ahead
312        // P(k)- = A * P(k-1) * A' + Q
313        errorCovariance = transitionMatrix.multiply(errorCovariance)
314                .multiply(transitionMatrixT)
315                .add(processModel.getProcessNoise());
316    }
317
318    /**
319     * Correct the current state estimate with an actual measurement.
320     *
321     * @param z
322     *            the measurement vector
323     * @throws NullArgumentException
324     *             if the measurement vector is {@code null}
325     * @throws DimensionMismatchException
326     *             if the dimension of the measurement vector does not fit
327     * @throws SingularMatrixException
328     *             if the covariance matrix could not be inverted
329     */
330    public void correct(final double[] z)
331            throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
332        correct(new ArrayRealVector(z, false));
333    }
334
335    /**
336     * Correct the current state estimate with an actual measurement.
337     *
338     * @param z
339     *            the measurement vector
340     * @throws NullArgumentException
341     *             if the measurement vector is {@code null}
342     * @throws DimensionMismatchException
343     *             if the dimension of the measurement vector does not fit
344     * @throws SingularMatrixException
345     *             if the covariance matrix could not be inverted
346     */
347    public void correct(final RealVector z)
348            throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
349
350        // sanity checks
351        MathUtils.checkNotNull(z);
352        if (z.getDimension() != measurementMatrix.getRowDimension()) {
353            throw new DimensionMismatchException(z.getDimension(),
354                                                 measurementMatrix.getRowDimension());
355        }
356
357        // S = H * P(k) * H' + R
358        RealMatrix s = measurementMatrix.multiply(errorCovariance)
359            .multiply(measurementMatrixT)
360            .add(measurementModel.getMeasurementNoise());
361
362        // invert S
363        RealMatrix invertedS = MatrixUtils.inverse(s);
364
365        // Inn = z(k) - H * xHat(k)-
366        RealVector innovation = z.subtract(measurementMatrix.operate(stateEstimation));
367
368        // calculate gain matrix
369        // K(k) = P(k)- * H' * (H * P(k)- * H' + R)^-1
370        // K(k) = P(k)- * H' * S^-1
371        RealMatrix kalmanGain = errorCovariance.multiply(measurementMatrixT).multiply(invertedS);
372
373        // update estimate with measurement z(k)
374        // xHat(k) = xHat(k)- + K * Inn
375        stateEstimation = stateEstimation.add(kalmanGain.operate(innovation));
376
377        // update covariance of prediction error
378        // P(k) = (I - K * H) * P(k)-
379        RealMatrix identity = MatrixUtils.createRealIdentityMatrix(kalmanGain.getRowDimension());
380        errorCovariance = identity.subtract(kalmanGain.multiply(measurementMatrix)).multiply(errorCovariance);
381    }
382}