001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math3.filter;
018    
019    import org.apache.commons.math3.exception.DimensionMismatchException;
020    import org.apache.commons.math3.exception.NullArgumentException;
021    import org.apache.commons.math3.linear.Array2DRowRealMatrix;
022    import org.apache.commons.math3.linear.ArrayRealVector;
023    import org.apache.commons.math3.linear.CholeskyDecomposition;
024    import org.apache.commons.math3.linear.DecompositionSolver;
025    import org.apache.commons.math3.linear.MatrixDimensionMismatchException;
026    import org.apache.commons.math3.linear.MatrixUtils;
027    import org.apache.commons.math3.linear.NonSquareMatrixException;
028    import org.apache.commons.math3.linear.RealMatrix;
029    import org.apache.commons.math3.linear.RealVector;
030    import org.apache.commons.math3.linear.SingularMatrixException;
031    import org.apache.commons.math3.util.MathUtils;
032    
033    /**
034     * Implementation of a Kalman filter to estimate the state <i>x<sub>k</sub></i>
035     * of a discrete-time controlled process that is governed by the linear
036     * stochastic difference equation:
037     *
038     * <pre>
039     * <i>x<sub>k</sub></i> = <b>A</b><i>x<sub>k-1</sub></i> + <b>B</b><i>u<sub>k-1</sub></i> + <i>w<sub>k-1</sub></i>
040     * </pre>
041     *
042     * with a measurement <i>x<sub>k</sub></i> that is
043     *
044     * <pre>
045     * <i>z<sub>k</sub></i> = <b>H</b><i>x<sub>k</sub></i> + <i>v<sub>k</sub></i>.
046     * </pre>
047     *
048     * <p>
049     * The random variables <i>w<sub>k</sub></i> and <i>v<sub>k</sub></i> represent
050     * the process and measurement noise and are assumed to be independent of each
051     * other and distributed with normal probability (white noise).
052     * <p>
053     * The Kalman filter cycle involves the following steps:
054     * <ol>
055     * <li>predict: project the current state estimate ahead in time</li>
056     * <li>correct: adjust the projected estimate by an actual measurement</li>
057     * </ol>
058     * <p>
059     * The Kalman filter is initialized with a {@link ProcessModel} and a
060     * {@link MeasurementModel}, which contain the corresponding transformation and
061     * noise covariance matrices. The parameter names used in the respective models
062     * correspond to the following names commonly used in the mathematical
063     * literature:
064     * <ul>
065     * <li>A - state transition matrix</li>
066     * <li>B - control input matrix</li>
067     * <li>H - measurement matrix</li>
068     * <li>Q - process noise covariance matrix</li>
069     * <li>R - measurement noise covariance matrix</li>
070     * <li>P - error covariance matrix</li>
071     * </ul>
072     *
073     * @see <a href="http://www.cs.unc.edu/~welch/kalman/">Kalman filter
074     *      resources</a>
075     * @see <a href="http://www.cs.unc.edu/~welch/media/pdf/kalman_intro.pdf">An
076     *      introduction to the Kalman filter by Greg Welch and Gary Bishop</a>
077     * @see <a href="http://academic.csuohio.edu/simond/courses/eec644/kalman.pdf">
078     *      Kalman filter example by Dan Simon</a>
079     * @see ProcessModel
080     * @see MeasurementModel
081     * @since 3.0
082     * @version $Id: KalmanFilter.java 1416643 2012-12-03 19:37:14Z tn $
083     */
084    public class KalmanFilter {
085        /** The process model used by this filter instance. */
086        private final ProcessModel processModel;
087        /** The measurement model used by this filter instance. */
088        private final MeasurementModel measurementModel;
089        /** The transition matrix, equivalent to A. */
090        private RealMatrix transitionMatrix;
091        /** The transposed transition matrix. */
092        private RealMatrix transitionMatrixT;
093        /** The control matrix, equivalent to B. */
094        private RealMatrix controlMatrix;
095        /** The measurement matrix, equivalent to H. */
096        private RealMatrix measurementMatrix;
097        /** The transposed measurement matrix. */
098        private RealMatrix measurementMatrixT;
099        /** The internal state estimation vector, equivalent to x hat. */
100        private RealVector stateEstimation;
101        /** The error covariance matrix, equivalent to P. */
102        private RealMatrix errorCovariance;
103    
104        /**
105         * Creates a new Kalman filter with the given process and measurement models.
106         *
107         * @param process
108         *            the model defining the underlying process dynamics
109         * @param measurement
110         *            the model defining the given measurement characteristics
111         * @throws NullArgumentException
112         *             if any of the given inputs is null (except for the control matrix)
113         * @throws NonSquareMatrixException
114         *             if the transition matrix is non square
115         * @throws DimensionMismatchException
116         *             if the column dimension of the transition matrix does not match the dimension of the
117         *             initial state estimation vector
118         * @throws MatrixDimensionMismatchException
119         *             if the matrix dimensions do not fit together
120         */
121        public KalmanFilter(final ProcessModel process, final MeasurementModel measurement)
122                throws NullArgumentException, NonSquareMatrixException, DimensionMismatchException,
123                       MatrixDimensionMismatchException {
124    
125            MathUtils.checkNotNull(process);
126            MathUtils.checkNotNull(measurement);
127    
128            this.processModel = process;
129            this.measurementModel = measurement;
130    
131            transitionMatrix = processModel.getStateTransitionMatrix();
132            MathUtils.checkNotNull(transitionMatrix);
133            transitionMatrixT = transitionMatrix.transpose();
134    
135            // create an empty matrix if no control matrix was given
136            if (processModel.getControlMatrix() == null) {
137                controlMatrix = new Array2DRowRealMatrix();
138            } else {
139                controlMatrix = processModel.getControlMatrix();
140            }
141    
142            measurementMatrix = measurementModel.getMeasurementMatrix();
143            MathUtils.checkNotNull(measurementMatrix);
144            measurementMatrixT = measurementMatrix.transpose();
145    
146            // check that the process and measurement noise matrices are not null
147            // they will be directly accessed from the model as they may change
148            // over time
149            RealMatrix processNoise = processModel.getProcessNoise();
150            MathUtils.checkNotNull(processNoise);
151            RealMatrix measNoise = measurementModel.getMeasurementNoise();
152            MathUtils.checkNotNull(measNoise);
153    
154            // set the initial state estimate to a zero vector if it is not
155            // available from the process model
156            if (processModel.getInitialStateEstimate() == null) {
157                stateEstimation = new ArrayRealVector(transitionMatrix.getColumnDimension());
158            } else {
159                stateEstimation = processModel.getInitialStateEstimate();
160            }
161    
162            if (transitionMatrix.getColumnDimension() != stateEstimation.getDimension()) {
163                throw new DimensionMismatchException(transitionMatrix.getColumnDimension(),
164                                                     stateEstimation.getDimension());
165            }
166    
167            // initialize the error covariance to the process noise if it is not
168            // available from the process model
169            if (processModel.getInitialErrorCovariance() == null) {
170                errorCovariance = processNoise.copy();
171            } else {
172                errorCovariance = processModel.getInitialErrorCovariance();
173            }
174    
175            // sanity checks, the control matrix B may be null
176    
177            // A must be a square matrix
178            if (!transitionMatrix.isSquare()) {
179                throw new NonSquareMatrixException(
180                        transitionMatrix.getRowDimension(),
181                        transitionMatrix.getColumnDimension());
182            }
183    
184            // row dimension of B must be equal to A
185            if (controlMatrix != null &&
186                controlMatrix.getRowDimension() > 0 &&
187                controlMatrix.getColumnDimension() > 0 &&
188                (controlMatrix.getRowDimension() != transitionMatrix.getRowDimension() ||
189                 controlMatrix.getColumnDimension() != 1)) {
190                throw new MatrixDimensionMismatchException(controlMatrix.getRowDimension(),
191                                                           controlMatrix.getColumnDimension(),
192                                                           transitionMatrix.getRowDimension(), 1);
193            }
194    
195            // Q must be equal to A
196            MatrixUtils.checkAdditionCompatible(transitionMatrix, processNoise);
197    
198            // column dimension of H must be equal to row dimension of A
199            if (measurementMatrix.getColumnDimension() != transitionMatrix.getRowDimension()) {
200                throw new MatrixDimensionMismatchException(measurementMatrix.getRowDimension(),
201                                                           measurementMatrix.getColumnDimension(),
202                                                           measurementMatrix.getRowDimension(),
203                                                           transitionMatrix.getRowDimension());
204            }
205    
206            // row dimension of R must be equal to row dimension of H
207            if (measNoise.getRowDimension() != measurementMatrix.getRowDimension() ||
208                measNoise.getColumnDimension() != 1) {
209                throw new MatrixDimensionMismatchException(measNoise.getRowDimension(),
210                                                           measNoise.getColumnDimension(),
211                                                           measurementMatrix.getRowDimension(), 1);
212            }
213        }
214    
215        /**
216         * Returns the dimension of the state estimation vector.
217         *
218         * @return the state dimension
219         */
220        public int getStateDimension() {
221            return stateEstimation.getDimension();
222        }
223    
224        /**
225         * Returns the dimension of the measurement vector.
226         *
227         * @return the measurement vector dimension
228         */
229        public int getMeasurementDimension() {
230            return measurementMatrix.getRowDimension();
231        }
232    
233        /**
234         * Returns the current state estimation vector.
235         *
236         * @return the state estimation vector
237         */
238        public double[] getStateEstimation() {
239            return stateEstimation.toArray();
240        }
241    
242        /**
243         * Returns a copy of the current state estimation vector.
244         *
245         * @return the state estimation vector
246         */
247        public RealVector getStateEstimationVector() {
248            return stateEstimation.copy();
249        }
250    
251        /**
252         * Returns the current error covariance matrix.
253         *
254         * @return the error covariance matrix
255         */
256        public double[][] getErrorCovariance() {
257            return errorCovariance.getData();
258        }
259    
260        /**
261         * Returns a copy of the current error covariance matrix.
262         *
263         * @return the error covariance matrix
264         */
265        public RealMatrix getErrorCovarianceMatrix() {
266            return errorCovariance.copy();
267        }
268    
269        /**
270         * Predict the internal state estimation one time step ahead.
271         */
272        public void predict() {
273            predict((RealVector) null);
274        }
275    
276        /**
277         * Predict the internal state estimation one time step ahead.
278         *
279         * @param u
280         *            the control vector
281         * @throws DimensionMismatchException
282         *             if the dimension of the control vector does not fit
283         */
284        public void predict(final double[] u) throws DimensionMismatchException {
285            predict(new ArrayRealVector(u));
286        }
287    
288        /**
289         * Predict the internal state estimation one time step ahead.
290         *
291         * @param u
292         *            the control vector
293         * @throws DimensionMismatchException
294         *             if the dimension of the control vector does not match
295         */
296        public void predict(final RealVector u) throws DimensionMismatchException {
297            // sanity checks
298            if (u != null &&
299                u.getDimension() != controlMatrix.getColumnDimension()) {
300                throw new DimensionMismatchException(u.getDimension(),
301                                                     controlMatrix.getColumnDimension());
302            }
303    
304            // project the state estimation ahead (a priori state)
305            // xHat(k)- = A * xHat(k-1) + B * u(k-1)
306            stateEstimation = transitionMatrix.operate(stateEstimation);
307    
308            // add control input if it is available
309            if (u != null) {
310                stateEstimation = stateEstimation.add(controlMatrix.operate(u));
311            }
312    
313            // project the error covariance ahead
314            // P(k)- = A * P(k-1) * A' + Q
315            errorCovariance = transitionMatrix.multiply(errorCovariance)
316                    .multiply(transitionMatrixT)
317                    .add(processModel.getProcessNoise());
318        }
319    
320        /**
321         * Correct the current state estimate with an actual measurement.
322         *
323         * @param z
324         *            the measurement vector
325         * @throws NullArgumentException
326         *             if the measurement vector is {@code null}
327         * @throws DimensionMismatchException
328         *             if the dimension of the measurement vector does not fit
329         * @throws SingularMatrixException
330         *             if the covariance matrix could not be inverted
331         */
332        public void correct(final double[] z)
333                throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
334            correct(new ArrayRealVector(z));
335        }
336    
337        /**
338         * Correct the current state estimate with an actual measurement.
339         *
340         * @param z
341         *            the measurement vector
342         * @throws NullArgumentException
343         *             if the measurement vector is {@code null}
344         * @throws DimensionMismatchException
345         *             if the dimension of the measurement vector does not fit
346         * @throws SingularMatrixException
347         *             if the covariance matrix could not be inverted
348         */
349        public void correct(final RealVector z)
350                throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
351    
352            // sanity checks
353            MathUtils.checkNotNull(z);
354            if (z.getDimension() != measurementMatrix.getRowDimension()) {
355                throw new DimensionMismatchException(z.getDimension(),
356                                                     measurementMatrix.getRowDimension());
357            }
358    
359            // S = H * P(k) - * H' + R
360            RealMatrix s = measurementMatrix.multiply(errorCovariance)
361                .multiply(measurementMatrixT)
362                .add(measurementModel.getMeasurementNoise());
363    
364            // invert S
365            // as the error covariance matrix is a symmetric positive
366            // semi-definite matrix, we can use the cholesky decomposition
367            DecompositionSolver solver = new CholeskyDecomposition(s).getSolver();
368            RealMatrix invertedS = solver.getInverse();
369    
370            // Inn = z(k) - H * xHat(k)-
371            RealVector innovation = z.subtract(measurementMatrix.operate(stateEstimation));
372    
373            // calculate gain matrix
374            // K(k) = P(k)- * H' * (H * P(k)- * H' + R)^-1
375            // K(k) = P(k)- * H' * S^-1
376            RealMatrix kalmanGain = errorCovariance.multiply(measurementMatrixT).multiply(invertedS);
377    
378            // update estimate with measurement z(k)
379            // xHat(k) = xHat(k)- + K * Inn
380            stateEstimation = stateEstimation.add(kalmanGain.operate(innovation));
381    
382            // update covariance of prediction error
383            // P(k) = (I - K * H) * P(k)-
384            RealMatrix identity = MatrixUtils.createRealIdentityMatrix(kalmanGain.getRowDimension());
385            errorCovariance = identity.subtract(kalmanGain.multiply(measurementMatrix)).multiply(errorCovariance);
386        }
387    }