001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.filter;
018
019import org.apache.commons.math3.exception.DimensionMismatchException;
020import org.apache.commons.math3.exception.NullArgumentException;
021import org.apache.commons.math3.linear.Array2DRowRealMatrix;
022import org.apache.commons.math3.linear.ArrayRealVector;
023import org.apache.commons.math3.linear.CholeskyDecomposition;
024import org.apache.commons.math3.linear.MatrixDimensionMismatchException;
025import org.apache.commons.math3.linear.MatrixUtils;
026import org.apache.commons.math3.linear.NonSquareMatrixException;
027import org.apache.commons.math3.linear.RealMatrix;
028import org.apache.commons.math3.linear.RealVector;
029import org.apache.commons.math3.linear.SingularMatrixException;
030import org.apache.commons.math3.util.MathUtils;
031
032/**
033 * Implementation of a Kalman filter to estimate the state <i>x<sub>k</sub></i>
034 * of a discrete-time controlled process that is governed by the linear
035 * stochastic difference equation:
036 *
037 * <pre>
038 * <i>x<sub>k</sub></i> = <b>A</b><i>x<sub>k-1</sub></i> + <b>B</b><i>u<sub>k-1</sub></i> + <i>w<sub>k-1</sub></i>
039 * </pre>
040 *
041 * with a measurement <i>x<sub>k</sub></i> that is
042 *
043 * <pre>
044 * <i>z<sub>k</sub></i> = <b>H</b><i>x<sub>k</sub></i> + <i>v<sub>k</sub></i>.
045 * </pre>
046 *
047 * <p>
048 * The random variables <i>w<sub>k</sub></i> and <i>v<sub>k</sub></i> represent
049 * the process and measurement noise and are assumed to be independent of each
050 * other and distributed with normal probability (white noise).
051 * <p>
052 * The Kalman filter cycle involves the following steps:
053 * <ol>
054 * <li>predict: project the current state estimate ahead in time</li>
055 * <li>correct: adjust the projected estimate by an actual measurement</li>
056 * </ol>
057 * <p>
058 * The Kalman filter is initialized with a {@link ProcessModel} and a
059 * {@link MeasurementModel}, which contain the corresponding transformation and
060 * noise covariance matrices. The parameter names used in the respective models
061 * correspond to the following names commonly used in the mathematical
062 * literature:
063 * <ul>
064 * <li>A - state transition matrix</li>
065 * <li>B - control input matrix</li>
066 * <li>H - measurement matrix</li>
067 * <li>Q - process noise covariance matrix</li>
068 * <li>R - measurement noise covariance matrix</li>
069 * <li>P - error covariance matrix</li>
070 * </ul>
071 *
072 * @see <a href="http://www.cs.unc.edu/~welch/kalman/">Kalman filter
073 *      resources</a>
074 * @see <a href="http://www.cs.unc.edu/~welch/media/pdf/kalman_intro.pdf">An
075 *      introduction to the Kalman filter by Greg Welch and Gary Bishop</a>
076 * @see <a href="http://academic.csuohio.edu/simond/courses/eec644/kalman.pdf">
077 *      Kalman filter example by Dan Simon</a>
078 * @see ProcessModel
079 * @see MeasurementModel
080 * @since 3.0
081 */
082public class KalmanFilter {
083    /** The process model used by this filter instance. */
084    private final ProcessModel processModel;
085    /** The measurement model used by this filter instance. */
086    private final MeasurementModel measurementModel;
087    /** The transition matrix, equivalent to A. */
088    private RealMatrix transitionMatrix;
089    /** The transposed transition matrix. */
090    private RealMatrix transitionMatrixT;
091    /** The control matrix, equivalent to B. */
092    private RealMatrix controlMatrix;
093    /** The measurement matrix, equivalent to H. */
094    private RealMatrix measurementMatrix;
095    /** The transposed measurement matrix. */
096    private RealMatrix measurementMatrixT;
097    /** The internal state estimation vector, equivalent to x hat. */
098    private RealVector stateEstimation;
099    /** The error covariance matrix, equivalent to P. */
100    private RealMatrix errorCovariance;
101
102    /**
103     * Creates a new Kalman filter with the given process and measurement models.
104     *
105     * @param process
106     *            the model defining the underlying process dynamics
107     * @param measurement
108     *            the model defining the given measurement characteristics
109     * @throws NullArgumentException
110     *             if any of the given inputs is null (except for the control matrix)
111     * @throws NonSquareMatrixException
112     *             if the transition matrix is non square
113     * @throws DimensionMismatchException
114     *             if the column dimension of the transition matrix does not match the dimension of the
115     *             initial state estimation vector
116     * @throws MatrixDimensionMismatchException
117     *             if the matrix dimensions do not fit together
118     */
119    public KalmanFilter(final ProcessModel process, final MeasurementModel measurement)
120            throws NullArgumentException, NonSquareMatrixException, DimensionMismatchException,
121                   MatrixDimensionMismatchException {
122
123        MathUtils.checkNotNull(process);
124        MathUtils.checkNotNull(measurement);
125
126        this.processModel = process;
127        this.measurementModel = measurement;
128
129        transitionMatrix = processModel.getStateTransitionMatrix();
130        MathUtils.checkNotNull(transitionMatrix);
131        transitionMatrixT = transitionMatrix.transpose();
132
133        // create an empty matrix if no control matrix was given
134        if (processModel.getControlMatrix() == null) {
135            controlMatrix = new Array2DRowRealMatrix();
136        } else {
137            controlMatrix = processModel.getControlMatrix();
138        }
139
140        measurementMatrix = measurementModel.getMeasurementMatrix();
141        MathUtils.checkNotNull(measurementMatrix);
142        measurementMatrixT = measurementMatrix.transpose();
143
144        // check that the process and measurement noise matrices are not null
145        // they will be directly accessed from the model as they may change
146        // over time
147        RealMatrix processNoise = processModel.getProcessNoise();
148        MathUtils.checkNotNull(processNoise);
149        RealMatrix measNoise = measurementModel.getMeasurementNoise();
150        MathUtils.checkNotNull(measNoise);
151
152        // set the initial state estimate to a zero vector if it is not
153        // available from the process model
154        if (processModel.getInitialStateEstimate() == null) {
155            stateEstimation = new ArrayRealVector(transitionMatrix.getColumnDimension());
156        } else {
157            stateEstimation = processModel.getInitialStateEstimate();
158        }
159
160        if (transitionMatrix.getColumnDimension() != stateEstimation.getDimension()) {
161            throw new DimensionMismatchException(transitionMatrix.getColumnDimension(),
162                                                 stateEstimation.getDimension());
163        }
164
165        // initialize the error covariance to the process noise if it is not
166        // available from the process model
167        if (processModel.getInitialErrorCovariance() == null) {
168            errorCovariance = processNoise.copy();
169        } else {
170            errorCovariance = processModel.getInitialErrorCovariance();
171        }
172
173        // sanity checks, the control matrix B may be null
174
175        // A must be a square matrix
176        if (!transitionMatrix.isSquare()) {
177            throw new NonSquareMatrixException(
178                    transitionMatrix.getRowDimension(),
179                    transitionMatrix.getColumnDimension());
180        }
181
182        // row dimension of B must be equal to A
183        // if no control matrix is available, the row and column dimension will be 0
184        if (controlMatrix != null &&
185            controlMatrix.getRowDimension() > 0 &&
186            controlMatrix.getColumnDimension() > 0 &&
187            controlMatrix.getRowDimension() != transitionMatrix.getRowDimension()) {
188            throw new MatrixDimensionMismatchException(controlMatrix.getRowDimension(),
189                                                       controlMatrix.getColumnDimension(),
190                                                       transitionMatrix.getRowDimension(),
191                                                       controlMatrix.getColumnDimension());
192        }
193
194        // Q must be equal to A
195        MatrixUtils.checkAdditionCompatible(transitionMatrix, processNoise);
196
197        // column dimension of H must be equal to row dimension of A
198        if (measurementMatrix.getColumnDimension() != transitionMatrix.getRowDimension()) {
199            throw new MatrixDimensionMismatchException(measurementMatrix.getRowDimension(),
200                                                       measurementMatrix.getColumnDimension(),
201                                                       measurementMatrix.getRowDimension(),
202                                                       transitionMatrix.getRowDimension());
203        }
204
205        // row dimension of R must be equal to row dimension of H
206        if (measNoise.getRowDimension() != measurementMatrix.getRowDimension()) {
207            throw new MatrixDimensionMismatchException(measNoise.getRowDimension(),
208                                                       measNoise.getColumnDimension(),
209                                                       measurementMatrix.getRowDimension(),
210                                                       measNoise.getColumnDimension());
211        }
212    }
213
214    /**
215     * Returns the dimension of the state estimation vector.
216     *
217     * @return the state dimension
218     */
219    public int getStateDimension() {
220        return stateEstimation.getDimension();
221    }
222
223    /**
224     * Returns the dimension of the measurement vector.
225     *
226     * @return the measurement vector dimension
227     */
228    public int getMeasurementDimension() {
229        return measurementMatrix.getRowDimension();
230    }
231
232    /**
233     * Returns the current state estimation vector.
234     *
235     * @return the state estimation vector
236     */
237    public double[] getStateEstimation() {
238        return stateEstimation.toArray();
239    }
240
241    /**
242     * Returns a copy of the current state estimation vector.
243     *
244     * @return the state estimation vector
245     */
246    public RealVector getStateEstimationVector() {
247        return stateEstimation.copy();
248    }
249
250    /**
251     * Returns the current error covariance matrix.
252     *
253     * @return the error covariance matrix
254     */
255    public double[][] getErrorCovariance() {
256        return errorCovariance.getData();
257    }
258
259    /**
260     * Returns a copy of the current error covariance matrix.
261     *
262     * @return the error covariance matrix
263     */
264    public RealMatrix getErrorCovarianceMatrix() {
265        return errorCovariance.copy();
266    }
267
268    /**
269     * Predict the internal state estimation one time step ahead.
270     */
271    public void predict() {
272        predict((RealVector) null);
273    }
274
275    /**
276     * Predict the internal state estimation one time step ahead.
277     *
278     * @param u
279     *            the control vector
280     * @throws DimensionMismatchException
281     *             if the dimension of the control vector does not fit
282     */
283    public void predict(final double[] u) throws DimensionMismatchException {
284        predict(new ArrayRealVector(u, false));
285    }
286
287    /**
288     * Predict the internal state estimation one time step ahead.
289     *
290     * @param u
291     *            the control vector
292     * @throws DimensionMismatchException
293     *             if the dimension of the control vector does not match
294     */
295    public void predict(final RealVector u) throws DimensionMismatchException {
296        // sanity checks
297        if (u != null &&
298            u.getDimension() != controlMatrix.getColumnDimension()) {
299            throw new DimensionMismatchException(u.getDimension(),
300                                                 controlMatrix.getColumnDimension());
301        }
302
303        // project the state estimation ahead (a priori state)
304        // xHat(k)- = A * xHat(k-1) + B * u(k-1)
305        stateEstimation = transitionMatrix.operate(stateEstimation);
306
307        // add control input if it is available
308        if (u != null) {
309            stateEstimation = stateEstimation.add(controlMatrix.operate(u));
310        }
311
312        // project the error covariance ahead
313        // P(k)- = A * P(k-1) * A' + Q
314        errorCovariance = transitionMatrix.multiply(errorCovariance)
315                .multiply(transitionMatrixT)
316                .add(processModel.getProcessNoise());
317    }
318
319    /**
320     * Correct the current state estimate with an actual measurement.
321     *
322     * @param z
323     *            the measurement vector
324     * @throws NullArgumentException
325     *             if the measurement vector is {@code null}
326     * @throws DimensionMismatchException
327     *             if the dimension of the measurement vector does not fit
328     * @throws SingularMatrixException
329     *             if the covariance matrix could not be inverted
330     */
331    public void correct(final double[] z)
332            throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
333        correct(new ArrayRealVector(z, false));
334    }
335
336    /**
337     * Correct the current state estimate with an actual measurement.
338     *
339     * @param z
340     *            the measurement vector
341     * @throws NullArgumentException
342     *             if the measurement vector is {@code null}
343     * @throws DimensionMismatchException
344     *             if the dimension of the measurement vector does not fit
345     * @throws SingularMatrixException
346     *             if the covariance matrix could not be inverted
347     */
348    public void correct(final RealVector z)
349            throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
350
351        // sanity checks
352        MathUtils.checkNotNull(z);
353        if (z.getDimension() != measurementMatrix.getRowDimension()) {
354            throw new DimensionMismatchException(z.getDimension(),
355                                                 measurementMatrix.getRowDimension());
356        }
357
358        // S = H * P(k) * H' + R
359        RealMatrix s = measurementMatrix.multiply(errorCovariance)
360            .multiply(measurementMatrixT)
361            .add(measurementModel.getMeasurementNoise());
362
363        // Inn = z(k) - H * xHat(k)-
364        RealVector innovation = z.subtract(measurementMatrix.operate(stateEstimation));
365
366        // calculate gain matrix
367        // K(k) = P(k)- * H' * (H * P(k)- * H' + R)^-1
368        // K(k) = P(k)- * H' * S^-1
369
370        // instead of calculating the inverse of S we can rearrange the formula,
371        // and then solve the linear equation A x X = B with A = S', X = K' and B = (H * P)'
372
373        // K(k) * S = P(k)- * H'
374        // S' * K(k)' = H * P(k)-'
375        RealMatrix kalmanGain = new CholeskyDecomposition(s).getSolver()
376                .solve(measurementMatrix.multiply(errorCovariance.transpose()))
377                .transpose();
378
379        // update estimate with measurement z(k)
380        // xHat(k) = xHat(k)- + K * Inn
381        stateEstimation = stateEstimation.add(kalmanGain.operate(innovation));
382
383        // update covariance of prediction error
384        // P(k) = (I - K * H) * P(k)-
385        RealMatrix identity = MatrixUtils.createRealIdentityMatrix(kalmanGain.getRowDimension());
386        errorCovariance = identity.subtract(kalmanGain.multiply(measurementMatrix)).multiply(errorCovariance);
387    }
388}