CorrelatedVectorFactory.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.random;

  18. import java.util.function.Supplier;

  19. import org.apache.commons.rng.UniformRandomProvider;
  20. import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
  21. import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
  22. import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
  23. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  24. import org.apache.commons.math4.core.jdkmath.JdkMath;
  25. import org.apache.commons.math4.legacy.linear.RealMatrix;
  26. import org.apache.commons.math4.legacy.linear.RectangularCholeskyDecomposition;

  27. /**
  28.  * Generates vectors with with correlated components.
  29.  *
  30.  * <p>Random vectors with correlated components are built by combining
  31.  * the uncorrelated components of another random vector in such a way
  32.  * that the resulting correlations are the ones specified by a positive
  33.  * definite covariance matrix.</p>
  34.  *
  35.  * <p>The main use of correlated random vector generation is for Monte-Carlo
  36.  * simulation of physical problems with several variables (for example to
  37.  * generate error vectors to be added to a nominal vector). A particularly
  38.  * common case is when the generated vector should be drawn from a
  39.  * <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">
  40.  * Multivariate Normal Distribution</a>, usually using Cholesky decomposition.
  41.  * Other distributions are possible as long as the underlying sampler provides
  42.  * normalized values (unit standard deviation).</p>
  43.  *
  44.  * <p>Sometimes, the covariance matrix for a given simulation is not
  45.  * strictly positive definite. This means that the correlations are
  46.  * not all independent from each other. In this case, however, the non
  47.  * strictly positive elements found during the Cholesky decomposition
  48.  * of the covariance matrix should not be negative either, they
  49.  * should be null. Another non-conventional extension handling this case
  50.  * is used here. Rather than computing <code>C = U<sup>T</sup> U</code>
  51.  * where <code>C</code> is the covariance matrix and <code>U</code>
  52.  * is an upper-triangular matrix, we compute <code>C = B B<sup>T</sup></code>
  53.  * where <code>B</code> is a rectangular matrix having more rows than
  54.  * columns. The number of columns of <code>B</code> is the rank of the
  55.  * covariance matrix, and it is the dimension of the uncorrelated
  56.  * random vector that is needed to compute the component of the
  57.  * correlated vector. This class handles this situation automatically.</p>
  58.  */
  59. public class CorrelatedVectorFactory {
  60.     /** Square root of three. */
  61.     private static final double SQRT3 = JdkMath.sqrt(3);
  62.     /** Mean vector. */
  63.     private final double[] mean;
  64.     /** Root of the covariance matrix. */
  65.     private final RealMatrix root;
  66.     /** Size of uncorrelated vector. */
  67.     private final int lengthUncorrelated;
  68.     /** Size of correlated vector. */
  69.     private final int lengthCorrelated;

  70.     /**
  71.      * Correlated vector factory.
  72.      *
  73.      * @param mean Expected mean values of the components.
  74.      * @param covariance Covariance matrix.
  75.      * @param small Diagonal elements threshold under which columns are
  76.      * considered to be dependent on previous ones and are discarded.
  77.      * @throws org.apache.commons.math4.legacy.linear.NonPositiveDefiniteMatrixException
  78.      * if the covariance matrix is not strictly positive definite.
  79.      * @throws DimensionMismatchException if the mean and covariance
  80.      * arrays dimensions do not match.
  81.      */
  82.     public CorrelatedVectorFactory(double[] mean,
  83.                                    RealMatrix covariance,
  84.                                    double small) {
  85.         lengthCorrelated = covariance.getRowDimension();
  86.         if (mean.length != lengthCorrelated) {
  87.             throw new DimensionMismatchException(mean.length, lengthCorrelated);
  88.         }
  89.         this.mean = mean.clone();

  90.         final RectangularCholeskyDecomposition decomposition
  91.             = new RectangularCholeskyDecomposition(covariance, small);
  92.         root = decomposition.getRootMatrix();

  93.         lengthUncorrelated = decomposition.getRank();
  94.     }

  95.     /**
  96.      * Null mean correlated vector factory.
  97.      *
  98.      * @param covariance Covariance matrix.
  99.      * @param small Diagonal elements threshold under which columns are
  100.      * considered to be dependent on previous ones and are discarded.
  101.      * @throws org.apache.commons.math4.legacy.linear.NonPositiveDefiniteMatrixException
  102.      * if the covariance matrix is not strictly positive definite.
  103.      */
  104.     public CorrelatedVectorFactory(RealMatrix covariance,
  105.                                    double small) {
  106.         this(new double[covariance.getRowDimension()],
  107.              covariance,
  108.              small);
  109.     }

  110.     /**
  111.      * @param rng RNG.
  112.      * @return a generator of vectors with correlated components sampled
  113.      * from a uniform distribution.
  114.      */
  115.     public Supplier<double[]> uniform(UniformRandomProvider rng) {
  116.         return with(new ContinuousUniformSampler(rng, -SQRT3, SQRT3));
  117.     }

  118.     /**
  119.      * @param rng RNG.
  120.      * @return a generator of vectors with correlated components sampled
  121.      * from a normal distribution.
  122.      */
  123.     public Supplier<double[]> gaussian(UniformRandomProvider rng) {
  124.         return with(new ZigguratNormalizedGaussianSampler(rng));
  125.     }

  126.     /**
  127.      * @param sampler Generator of samples from a normalized distribution.
  128.      * @return a generator of vectors with correlated components.
  129.      */
  130.     private Supplier<double[]> with(final ContinuousSampler sampler) {
  131.         return new Supplier<double[]>() {
  132.             @Override
  133.             public double[] get() {
  134.                 // Uncorrelated vector.
  135.                 final double[] uncorrelated = new double[lengthUncorrelated];
  136.                 for (int i = 0; i < lengthUncorrelated; i++) {
  137.                     uncorrelated[i] = sampler.sample();
  138.                 }

  139.                 // Correlated vector.
  140.                 final double[] correlated = mean.clone();
  141.                 for (int i = 0; i < correlated.length; i++) {
  142.                     for (int j = 0; j < lengthUncorrelated; j++) {
  143.                         correlated[i] += root.getEntry(i, j) * uncorrelated[j];
  144.                     }
  145.                 }

  146.                 return correlated;
  147.             }
  148.         };
  149.     }
  150. }