KalmanFilter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.filter;

  18. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  19. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  20. import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
  21. import org.apache.commons.math4.legacy.linear.ArrayRealVector;
  22. import org.apache.commons.math4.legacy.linear.CholeskyDecomposition;
  23. import org.apache.commons.math4.legacy.linear.MatrixDimensionMismatchException;
  24. import org.apache.commons.math4.legacy.linear.MatrixUtils;
  25. import org.apache.commons.math4.legacy.linear.NonSquareMatrixException;
  26. import org.apache.commons.math4.legacy.linear.RealMatrix;
  27. import org.apache.commons.math4.legacy.linear.RealVector;
  28. import org.apache.commons.math4.legacy.linear.SingularMatrixException;

  29. /**
  30.  * Implementation of a Kalman filter to estimate the state <i>x<sub>k</sub></i>
  31.  * of a discrete-time controlled process that is governed by the linear
  32.  * stochastic difference equation:
  33.  *
  34.  * <pre>
  35.  * <i>x<sub>k</sub></i> = <b>A</b><i>x<sub>k-1</sub></i> + <b>B</b><i>u<sub>k-1</sub></i> + <i>w<sub>k-1</sub></i>
  36.  * </pre>
  37.  *
  38.  * with a measurement <i>x<sub>k</sub></i> that is
  39.  *
  40.  * <pre>
  41.  * <i>z<sub>k</sub></i> = <b>H</b><i>x<sub>k</sub></i> + <i>v<sub>k</sub></i>.
  42.  * </pre>
  43.  *
  44.  * <p>
  45.  * The random variables <i>w<sub>k</sub></i> and <i>v<sub>k</sub></i> represent
  46.  * the process and measurement noise and are assumed to be independent of each
  47.  * other and distributed with normal probability (white noise).
  48.  * <p>
  49.  * The Kalman filter cycle involves the following steps:
  50.  * <ol>
  51.  * <li>predict: project the current state estimate ahead in time</li>
  52.  * <li>correct: adjust the projected estimate by an actual measurement</li>
  53.  * </ol>
  54.  * <p>
  55.  * The Kalman filter is initialized with a {@link ProcessModel} and a
  56.  * {@link MeasurementModel}, which contain the corresponding transformation and
  57.  * noise covariance matrices. The parameter names used in the respective models
  58.  * correspond to the following names commonly used in the mathematical
  59.  * literature:
  60.  * <ul>
  61.  * <li>A - state transition matrix</li>
  62.  * <li>B - control input matrix</li>
  63.  * <li>H - measurement matrix</li>
  64.  * <li>Q - process noise covariance matrix</li>
  65.  * <li>R - measurement noise covariance matrix</li>
  66.  * <li>P - error covariance matrix</li>
  67.  * </ul>
  68.  *
  69.  * @see <a href="http://www.cs.unc.edu/~welch/kalman/">Kalman filter
  70.  *      resources</a>
  71.  * @see <a href="http://www.cs.unc.edu/~welch/media/pdf/kalman_intro.pdf">An
  72.  *      introduction to the Kalman filter by Greg Welch and Gary Bishop</a>
  73.  * @see <a href="http://academic.csuohio.edu/simond/courses/eec644/kalman.pdf">
  74.  *      Kalman filter example by Dan Simon</a>
  75.  * @see ProcessModel
  76.  * @see MeasurementModel
  77.  * @since 3.0
  78.  */
  79. public class KalmanFilter {
  80.     /** The process model used by this filter instance. */
  81.     private final ProcessModel processModel;
  82.     /** The measurement model used by this filter instance. */
  83.     private final MeasurementModel measurementModel;
  84.     /** The transition matrix, equivalent to A. */
  85.     private RealMatrix transitionMatrix;
  86.     /** The transposed transition matrix. */
  87.     private RealMatrix transitionMatrixT;
  88.     /** The control matrix, equivalent to B. */
  89.     private RealMatrix controlMatrix;
  90.     /** The measurement matrix, equivalent to H. */
  91.     private RealMatrix measurementMatrix;
  92.     /** The transposed measurement matrix. */
  93.     private RealMatrix measurementMatrixT;
  94.     /** The internal state estimation vector, equivalent to x hat. */
  95.     private RealVector stateEstimation;
  96.     /** The error covariance matrix, equivalent to P. */
  97.     private RealMatrix errorCovariance;

  98.     /**
  99.      * Creates a new Kalman filter with the given process and measurement models.
  100.      *
  101.      * @param process
  102.      *            the model defining the underlying process dynamics
  103.      * @param measurement
  104.      *            the model defining the given measurement characteristics
  105.      * @throws NullArgumentException
  106.      *             if any of the given inputs is null (except for the control matrix)
  107.      * @throws NonSquareMatrixException
  108.      *             if the transition matrix is non square
  109.      * @throws DimensionMismatchException
  110.      *             if the column dimension of the transition matrix does not match the dimension of the
  111.      *             initial state estimation vector
  112.      * @throws MatrixDimensionMismatchException
  113.      *             if the matrix dimensions do not fit together
  114.      */
  115.     public KalmanFilter(final ProcessModel process, final MeasurementModel measurement)
  116.             throws NullArgumentException, NonSquareMatrixException, DimensionMismatchException,
  117.                    MatrixDimensionMismatchException {

  118.         NullArgumentException.check(process);
  119.         NullArgumentException.check(measurement);

  120.         this.processModel = process;
  121.         this.measurementModel = measurement;

  122.         transitionMatrix = processModel.getStateTransitionMatrix();
  123.         NullArgumentException.check(transitionMatrix);
  124.         transitionMatrixT = transitionMatrix.transpose();

  125.         // create an empty matrix if no control matrix was given
  126.         if (processModel.getControlMatrix() == null) {
  127.             controlMatrix = new Array2DRowRealMatrix();
  128.         } else {
  129.             controlMatrix = processModel.getControlMatrix();
  130.         }

  131.         measurementMatrix = measurementModel.getMeasurementMatrix();
  132.         NullArgumentException.check(measurementMatrix);
  133.         measurementMatrixT = measurementMatrix.transpose();

  134.         // check that the process and measurement noise matrices are not null
  135.         // they will be directly accessed from the model as they may change
  136.         // over time
  137.         RealMatrix processNoise = processModel.getProcessNoise();
  138.         NullArgumentException.check(processNoise);
  139.         RealMatrix measNoise = measurementModel.getMeasurementNoise();
  140.         NullArgumentException.check(measNoise);

  141.         // set the initial state estimate to a zero vector if it is not
  142.         // available from the process model
  143.         if (processModel.getInitialStateEstimate() == null) {
  144.             stateEstimation = new ArrayRealVector(transitionMatrix.getColumnDimension());
  145.         } else {
  146.             stateEstimation = processModel.getInitialStateEstimate();
  147.         }

  148.         if (transitionMatrix.getColumnDimension() != stateEstimation.getDimension()) {
  149.             throw new DimensionMismatchException(transitionMatrix.getColumnDimension(),
  150.                                                  stateEstimation.getDimension());
  151.         }

  152.         // initialize the error covariance to the process noise if it is not
  153.         // available from the process model
  154.         if (processModel.getInitialErrorCovariance() == null) {
  155.             errorCovariance = processNoise.copy();
  156.         } else {
  157.             errorCovariance = processModel.getInitialErrorCovariance();
  158.         }

  159.         // sanity checks, the control matrix B may be null

  160.         // A must be a square matrix
  161.         if (!transitionMatrix.isSquare()) {
  162.             throw new NonSquareMatrixException(
  163.                     transitionMatrix.getRowDimension(),
  164.                     transitionMatrix.getColumnDimension());
  165.         }

  166.         // row dimension of B must be equal to A
  167.         // if no control matrix is available, the row and column dimension will be 0
  168.         if (controlMatrix != null &&
  169.             controlMatrix.getRowDimension() > 0 &&
  170.             controlMatrix.getColumnDimension() > 0 &&
  171.             controlMatrix.getRowDimension() != transitionMatrix.getRowDimension()) {
  172.             throw new MatrixDimensionMismatchException(controlMatrix.getRowDimension(),
  173.                                                        controlMatrix.getColumnDimension(),
  174.                                                        transitionMatrix.getRowDimension(),
  175.                                                        controlMatrix.getColumnDimension());
  176.         }

  177.         // Q must be equal to A
  178.         MatrixUtils.checkAdditionCompatible(transitionMatrix, processNoise);

  179.         // column dimension of H must be equal to row dimension of A
  180.         if (measurementMatrix.getColumnDimension() != transitionMatrix.getRowDimension()) {
  181.             throw new MatrixDimensionMismatchException(measurementMatrix.getRowDimension(),
  182.                                                        measurementMatrix.getColumnDimension(),
  183.                                                        measurementMatrix.getRowDimension(),
  184.                                                        transitionMatrix.getRowDimension());
  185.         }

  186.         // row dimension of R must be equal to row dimension of H
  187.         if (measNoise.getRowDimension() != measurementMatrix.getRowDimension()) {
  188.             throw new MatrixDimensionMismatchException(measNoise.getRowDimension(),
  189.                                                        measNoise.getColumnDimension(),
  190.                                                        measurementMatrix.getRowDimension(),
  191.                                                        measNoise.getColumnDimension());
  192.         }
  193.     }

  194.     /**
  195.      * Returns the dimension of the state estimation vector.
  196.      *
  197.      * @return the state dimension
  198.      */
  199.     public int getStateDimension() {
  200.         return stateEstimation.getDimension();
  201.     }

  202.     /**
  203.      * Returns the dimension of the measurement vector.
  204.      *
  205.      * @return the measurement vector dimension
  206.      */
  207.     public int getMeasurementDimension() {
  208.         return measurementMatrix.getRowDimension();
  209.     }

  210.     /**
  211.      * Returns the current state estimation vector.
  212.      *
  213.      * @return the state estimation vector
  214.      */
  215.     public double[] getStateEstimation() {
  216.         return stateEstimation.toArray();
  217.     }

  218.     /**
  219.      * Returns a copy of the current state estimation vector.
  220.      *
  221.      * @return the state estimation vector
  222.      */
  223.     public RealVector getStateEstimationVector() {
  224.         return stateEstimation.copy();
  225.     }

  226.     /**
  227.      * Returns the current error covariance matrix.
  228.      *
  229.      * @return the error covariance matrix
  230.      */
  231.     public double[][] getErrorCovariance() {
  232.         return errorCovariance.getData();
  233.     }

  234.     /**
  235.      * Returns a copy of the current error covariance matrix.
  236.      *
  237.      * @return the error covariance matrix
  238.      */
  239.     public RealMatrix getErrorCovarianceMatrix() {
  240.         return errorCovariance.copy();
  241.     }

  242.     /**
  243.      * Predict the internal state estimation one time step ahead.
  244.      */
  245.     public void predict() {
  246.         predict((RealVector) null);
  247.     }

  248.     /**
  249.      * Predict the internal state estimation one time step ahead.
  250.      *
  251.      * @param u
  252.      *            the control vector
  253.      * @throws DimensionMismatchException
  254.      *             if the dimension of the control vector does not fit
  255.      */
  256.     public void predict(final double[] u) throws DimensionMismatchException {
  257.         predict(new ArrayRealVector(u, false));
  258.     }

  259.     /**
  260.      * Predict the internal state estimation one time step ahead.
  261.      *
  262.      * @param u
  263.      *            the control vector
  264.      * @throws DimensionMismatchException
  265.      *             if the dimension of the control vector does not match
  266.      */
  267.     public void predict(final RealVector u) throws DimensionMismatchException {
  268.         // sanity checks
  269.         if (u != null &&
  270.             u.getDimension() != controlMatrix.getColumnDimension()) {
  271.             throw new DimensionMismatchException(u.getDimension(),
  272.                                                  controlMatrix.getColumnDimension());
  273.         }

  274.         // project the state estimation ahead (a priori state)
  275.         // xHat(k)- = A * xHat(k-1) + B * u(k-1)
  276.         stateEstimation = transitionMatrix.operate(stateEstimation);

  277.         // add control input if it is available
  278.         if (u != null) {
  279.             stateEstimation = stateEstimation.add(controlMatrix.operate(u));
  280.         }

  281.         // project the error covariance ahead
  282.         // P(k)- = A * P(k-1) * A' + Q
  283.         errorCovariance = transitionMatrix.multiply(errorCovariance)
  284.                 .multiply(transitionMatrixT)
  285.                 .add(processModel.getProcessNoise());
  286.     }

  287.     /**
  288.      * Correct the current state estimate with an actual measurement.
  289.      *
  290.      * @param z
  291.      *            the measurement vector
  292.      * @throws NullArgumentException
  293.      *             if the measurement vector is {@code null}
  294.      * @throws DimensionMismatchException
  295.      *             if the dimension of the measurement vector does not fit
  296.      * @throws SingularMatrixException
  297.      *             if the covariance matrix could not be inverted
  298.      */
  299.     public void correct(final double[] z)
  300.             throws NullArgumentException, DimensionMismatchException, SingularMatrixException {
  301.         correct(new ArrayRealVector(z, false));
  302.     }

  303.     /**
  304.      * Correct the current state estimate with an actual measurement.
  305.      *
  306.      * @param z
  307.      *            the measurement vector
  308.      * @throws NullArgumentException
  309.      *             if the measurement vector is {@code null}
  310.      * @throws DimensionMismatchException
  311.      *             if the dimension of the measurement vector does not fit
  312.      * @throws SingularMatrixException
  313.      *             if the covariance matrix could not be inverted
  314.      */
  315.     public void correct(final RealVector z)
  316.             throws NullArgumentException, DimensionMismatchException, SingularMatrixException {

  317.         // sanity checks
  318.         NullArgumentException.check(z);
  319.         if (z.getDimension() != measurementMatrix.getRowDimension()) {
  320.             throw new DimensionMismatchException(z.getDimension(),
  321.                                                  measurementMatrix.getRowDimension());
  322.         }

  323.         // S = H * P(k) * H' + R
  324.         RealMatrix s = measurementMatrix.multiply(errorCovariance)
  325.             .multiply(measurementMatrixT)
  326.             .add(measurementModel.getMeasurementNoise());

  327.         // Inn = z(k) - H * xHat(k)-
  328.         RealVector innovation = z.subtract(measurementMatrix.operate(stateEstimation));

  329.         // calculate gain matrix
  330.         // K(k) = P(k)- * H' * (H * P(k)- * H' + R)^-1
  331.         // K(k) = P(k)- * H' * S^-1

  332.         // instead of calculating the inverse of S we can rearrange the formula,
  333.         // and then solve the linear equation A x X = B with A = S', X = K' and B = (H * P)'

  334.         // K(k) * S = P(k)- * H'
  335.         // S' * K(k)' = H * P(k)-'
  336.         RealMatrix kalmanGain = new CholeskyDecomposition(s).getSolver()
  337.                 .solve(measurementMatrix.multiply(errorCovariance.transpose()))
  338.                 .transpose();

  339.         // update estimate with measurement z(k)
  340.         // xHat(k) = xHat(k)- + K * Inn
  341.         stateEstimation = stateEstimation.add(kalmanGain.operate(innovation));

  342.         // update covariance of prediction error
  343.         // P(k) = (I - K * H) * P(k)-
  344.         RealMatrix identity = MatrixUtils.createRealIdentityMatrix(kalmanGain.getRowDimension());
  345.         errorCovariance = identity.subtract(kalmanGain.multiply(measurementMatrix)).multiply(errorCovariance);
  346.     }
  347. }