001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.filter;
018
019import org.apache.commons.math3.exception.DimensionMismatchException;
020import org.apache.commons.math3.exception.NullArgumentException;
021import org.apache.commons.math3.linear.Array2DRowRealMatrix;
022import org.apache.commons.math3.linear.ArrayRealVector;
023import org.apache.commons.math3.linear.CholeskyDecomposition;
024import org.apache.commons.math3.linear.DecompositionSolver;
025import org.apache.commons.math3.linear.MatrixDimensionMismatchException;
026import org.apache.commons.math3.linear.MatrixUtils;
027import org.apache.commons.math3.linear.NonSquareMatrixException;
028import org.apache.commons.math3.linear.RealMatrix;
029import org.apache.commons.math3.linear.RealVector;
030import org.apache.commons.math3.linear.SingularMatrixException;
031import org.apache.commons.math3.util.MathUtils;
032
033/**
034 * Implementation of a Kalman filter to estimate the state <i>x<sub>k</sub></i>
035 * of a discrete-time controlled process that is governed by the linear
036 * stochastic difference equation:
037 *
038 * <pre>
039 * <i>x<sub>k</sub></i> = <b>A</b><i>x<sub>k-1</sub></i> + <b>B</b><i>u<sub>k-1</sub></i> + <i>w<sub>k-1</sub></i>
040 * </pre>
041 *
042 * with a measurement <i>z<sub>k</sub></i> that is
043 *
044 * <pre>
045 * <i>z<sub>k</sub></i> = <b>H</b><i>x<sub>k</sub></i> + <i>v<sub>k</sub></i>.
046 * </pre>
047 *
048 * <p>
049 * The random variables <i>w<sub>k</sub></i> and <i>v<sub>k</sub></i> represent
050 * the process and measurement noise and are assumed to be independent of each
051 * other and distributed with normal probability (white noise).
052 * <p>
053 * The Kalman filter cycle involves the following steps:
054 * <ol>
055 * <li>predict: project the current state estimate ahead in time</li>
056 * <li>correct: adjust the projected estimate by an actual measurement</li>
057 * </ol>
058 * <p>
059 * The Kalman filter is initialized with a {@link ProcessModel} and a
060 * {@link MeasurementModel}, which contain the corresponding transformation and
061 * noise covariance matrices. The parameter names used in the respective models
062 * correspond to the following names commonly used in the mathematical
063 * literature:
064 * <ul>
065 * <li>A - state transition matrix</li>
066 * <li>B - control input matrix</li>
067 * <li>H - measurement matrix</li>
068 * <li>Q - process noise covariance matrix</li>
069 * <li>R - measurement noise covariance matrix</li>
070 * <li>P - error covariance matrix</li>
071 * </ul>
072 *
073 * @see <a href="http://www.cs.unc.edu/~welch/kalman/">Kalman filter
074 *      resources</a>
075 * @see <a href="http://www.cs.unc.edu/~welch/media/pdf/kalman_intro.pdf">An
076 *      introduction to the Kalman filter by Greg Welch and Gary Bishop</a>
077 * @see <a href="http://academic.csuohio.edu/simond/courses/eec644/kalman.pdf">
078 *      Kalman filter example by Dan Simon</a>
079 * @see ProcessModel
080 * @see MeasurementModel
081 * @since 3.0
082 * @version $Id: KalmanFilter.java 1531430 2013-10-11 21:39:09Z tn $
083 */
084public class KalmanFilter {
085    /** The process model used by this filter instance. */
086    private final ProcessModel processModel;
087    /** The measurement model used by this filter instance. */
088    private final MeasurementModel measurementModel;
089    /** The transition matrix, equivalent to A. */
090    private RealMatrix transitionMatrix;
091    /** The transposed transition matrix. */
092    private RealMatrix transitionMatrixT;
093    /** The control matrix, equivalent to B. */
094    private RealMatrix controlMatrix;
095    /** The measurement matrix, equivalent to H. */
096    private RealMatrix measurementMatrix;
097    /** The transposed measurement matrix. */
098    private RealMatrix measurementMatrixT;
099    /** The internal state estimation vector, equivalent to x hat. */
100    private RealVector stateEstimation;
101    /** The error covariance matrix, equivalent to P. */
102    private RealMatrix errorCovariance;
103
104    /**
105     * Creates a new Kalman filter with the given process and measurement models.
106     *
107     * @param process
108     *            the model defining the underlying process dynamics
109     * @param measurement
110     *            the model defining the given measurement characteristics
111     * @throws NullArgumentException
112     *             if any of the given inputs is null (except for the control matrix)
113     * @throws NonSquareMatrixException
114     *             if the transition matrix is non square
115     * @throws DimensionMismatchException
116     *             if the column dimension of the transition matrix does not match the dimension of the
117     *             initial state estimation vector
118     * @throws MatrixDimensionMismatchException
119     *             if the matrix dimensions do not fit together
120     */
121    public KalmanFilter(final ProcessModel process, final MeasurementModel measurement)
122            throws NullArgumentException, NonSquareMatrixException, DimensionMismatchException,
123                   MatrixDimensionMismatchException {
124
125        MathUtils.checkNotNull(process);
126        MathUtils.checkNotNull(measurement);
127
128        this.processModel = process;
129        this.measurementModel = measurement;
130
131        transitionMatrix = processModel.getStateTransitionMatrix();
132        MathUtils.checkNotNull(transitionMatrix);
133        transitionMatrixT = transitionMatrix.transpose();
134
135        // create an empty matrix if no control matrix was given
136        if (processModel.getControlMatrix() == null) {
137            controlMatrix = new Array2DRowRealMatrix();
138        } else {
139            controlMatrix = processModel.getControlMatrix();
140        }
141
142        measurementMatrix = measurementModel.getMeasurementMatrix();
143        MathUtils.checkNotNull(measurementMatrix);
144        measurementMatrixT = measurementMatrix.transpose();
145
146        // check that the process and measurement noise matrices are not null
147        // they will be directly accessed from the model as they may change
148        // over time
149        RealMatrix processNoise = processModel.getProcessNoise();
150        MathUtils.checkNotNull(processNoise);
151        RealMatrix measNoise = measurementModel.getMeasurementNoise();
152        MathUtils.checkNotNull(measNoise);
153
154        // set the initial state estimate to a zero vector if it is not
155        // available from the process model
156        if (processModel.getInitialStateEstimate() == null) {
157            stateEstimation = new ArrayRealVector(transitionMatrix.getColumnDimension());
158        } else {
159            stateEstimation = processModel.getInitialStateEstimate();
160        }
161
162        if (transitionMatrix.getColumnDimension() != stateEstimation.getDimension()) {
163            throw new DimensionMismatchException(transitionMatrix.getColumnDimension(),
164                                                 stateEstimation.getDimension());
165        }
166
167        // initialize the error covariance to the process noise if it is not
168        // available from the process model
169        if (processModel.getInitialErrorCovariance() == null) {
170            errorCovariance = processNoise.copy();
171        } else {
172            errorCovariance = processModel.getInitialErrorCovariance();
173        }
174
175        // sanity checks, the control matrix B may be null
176
177        // A must be a square matrix
178        if (!transitionMatrix.isSquare()) {
179            throw new NonSquareMatrixException(
180                    transitionMatrix.getRowDimension(),
181                    transitionMatrix.getColumnDimension());
182        }
183
184        // row dimension of B must be equal to A
185        // if no control matrix is available, the row and column dimension will be 0
186        if (controlMatrix != null &&
187            controlMatrix.getRowDimension() > 0 &&
188            controlMatrix.getColumnDimension() > 0 &&
189            controlMatrix.getRowDimension() != transitionMatrix.getRowDimension()) {
190            throw new MatrixDimensionMismatchException(controlMatrix.getRowDimension(),
191                                                       controlMatrix.getColumnDimension(),
192                                                       transitionMatrix.getRowDimension(),
193                                                       controlMatrix.getColumnDimension());
194        }
195
196        // Q must be equal to A
197        MatrixUtils.checkAdditionCompatible(transitionMatrix, processNoise);
198
199        // column dimension of H must be equal to row dimension of A
200        if (measurementMatrix.getColumnDimension() != transitionMatrix.getRowDimension()) {
201            throw new MatrixDimensionMismatchException(measurementMatrix.getRowDimension(),
202                                                       measurementMatrix.getColumnDimension(),
203                                                       measurementMatrix.getRowDimension(),
204                                                       transitionMatrix.getRowDimension());
205        }
206
207        // row dimension of R must be equal to row dimension of H
208        if (measNoise.getRowDimension() != measurementMatrix.getRowDimension()) {
209            throw new MatrixDimensionMismatchException(measNoise.getRowDimension(),
210                                                       measNoise.getColumnDimension(),
211                                                       measurementMatrix.getRowDimension(),
212                                                       measNoise.getColumnDimension());
213        }
214    }
215
216    /**
217     * Returns the dimension of the state estimation vector.
218     *
219     * @return the state dimension
220     */
221    public int getStateDimension() {
222        return stateEstimation.getDimension();
223    }
224
225    /**
226     * Returns the dimension of the measurement vector.
227     *
228     * @return the measurement vector dimension
229     */
230    public int getMeasurementDimension() {
231        return measurementMatrix.getRowDimension();
232    }
233
234    /**
235     * Returns the current state estimation vector.
236     *
237     * @return the state estimation vector
238     */
239    public double[] getStateEstimation() {
240        return stateEstimation.toArray();
241    }
242
243    /**
244     * Returns a copy of the current state estimation vector.
245     *
246     * @return the state estimation vector
247     */
248    public RealVector getStateEstimationVector() {
249        return stateEstimation.copy();
250    }
251
252    /**
253     * Returns the current error covariance matrix.
254     *
255     * @return the error covariance matrix
256     */
257    public double[][] getErrorCovariance() {
258        return errorCovariance.getData();
259    }
260
261    /**
262     * Returns a copy of the current error covariance matrix.
263     *
264     * @return the error covariance matrix
265     */
266    public RealMatrix getErrorCovarianceMatrix() {
267        return errorCovariance.copy();
268    }
269
270    /**
271     * Predict the internal state estimation one time step ahead.
272     */
273    public void predict() {
274        predict((RealVector) null);
275    }
276
277    /**
278     * Predict the internal state estimation one time step ahead.
279     *
280     * @param u
281     *            the control vector
282     * @throws DimensionMismatchException
283     *             if the dimension of the control vector does not fit
284     */
285    public void predict(final double[] u) throws DimensionMismatchException {
286        predict(new ArrayRealVector(u));
287    }
288
289    /**
290     * Predict the internal state estimation one time step ahead.
291     *
292     * @param u
293     *            the control vector
294     * @throws DimensionMismatchException
295     *             if the dimension of the control vector does not match
296     */
297    public void predict(final RealVector u) throws DimensionMismatchException {
298        // sanity checks
299        if (u != null &&
300            u.getDimension() != controlMatrix.getColumnDimension()) {
301            throw new DimensionMismatchException(u.getDimension(),
302                                                 controlMatrix.getColumnDimension());
303        }
304
305        // project the state estimation ahead (a priori state)
306        // xHat(k)- = A * xHat(k-1) + B * u(k-1)
307        stateEstimation = transitionMatrix.operate(stateEstimation);
308
309        // add control input if it is available
310        if (u != null) {
311            stateEstimation = stateEstimation.add(controlMatrix.operate(u));
312        }
313
314        // project the error covariance ahead
315        // P(k)- = A * P(k-1) * A' + Q
316        errorCovariance = transitionMatrix.multiply(errorCovariance)
317                .multiply(transitionMatrixT)
318                .add(processModel.getProcessNoise());
319    }
320
321    /**
322     * Correct the current state estimate with an actual measurement.
323     *
324     * @param z
325     *            the measurement vector
326     * @throws NullArgumentException
327     *             if the measurement vector is {@code null}
328     * @throws DimensionMismatchException
329     *             if the dimension of the measurement vector does not fit
330     * @throws SingularMatrixException
331     *             if the covariance matrix could not be inverted
332     */
333    public void correct(final double[] z)
334            throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
335        correct(new ArrayRealVector(z));
336    }
337
338    /**
339     * Correct the current state estimate with an actual measurement.
340     *
341     * @param z
342     *            the measurement vector
343     * @throws NullArgumentException
344     *             if the measurement vector is {@code null}
345     * @throws DimensionMismatchException
346     *             if the dimension of the measurement vector does not fit
347     * @throws SingularMatrixException
348     *             if the covariance matrix could not be inverted
349     */
350    public void correct(final RealVector z)
351            throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
352
353        // sanity checks
354        MathUtils.checkNotNull(z);
355        if (z.getDimension() != measurementMatrix.getRowDimension()) {
356            throw new DimensionMismatchException(z.getDimension(),
357                                                 measurementMatrix.getRowDimension());
358        }
359
360        // S = H * P(k) * H' + R
361        RealMatrix s = measurementMatrix.multiply(errorCovariance)
362            .multiply(measurementMatrixT)
363            .add(measurementModel.getMeasurementNoise());
364
365        // invert S
366        // as the error covariance matrix is a symmetric positive
367        // semi-definite matrix, we can use the cholesky decomposition
368        DecompositionSolver solver = new CholeskyDecomposition(s).getSolver();
369        RealMatrix invertedS = solver.getInverse();
370
371        // Inn = z(k) - H * xHat(k)-
372        RealVector innovation = z.subtract(measurementMatrix.operate(stateEstimation));
373
374        // calculate gain matrix
375        // K(k) = P(k)- * H' * (H * P(k)- * H' + R)^-1
376        // K(k) = P(k)- * H' * S^-1
377        RealMatrix kalmanGain = errorCovariance.multiply(measurementMatrixT).multiply(invertedS);
378
379        // update estimate with measurement z(k)
380        // xHat(k) = xHat(k)- + K * Inn
381        stateEstimation = stateEstimation.add(kalmanGain.operate(innovation));
382
383        // update covariance of prediction error
384        // P(k) = (I - K * H) * P(k)-
385        RealMatrix identity = MatrixUtils.createRealIdentityMatrix(kalmanGain.getRowDimension());
386        errorCovariance = identity.subtract(kalmanGain.multiply(measurementMatrix)).multiply(errorCovariance);
387    }
388}