SquareMatrixSupport.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.inference;

  18. /**
  19.  * Provide support for square matrix basic algebraic operations.
  20.  *
  21.  * <p>Matrix element indexing is 0-based e.g. {@code get(0, 0)}
  22.  * returns the element in the first row, first column of the matrix.
  23.  *
  24.  * <p>This class supports computations in the {@link KolmogorovSmirnovTest}.
  25.  *
  26.  * @since 1.1
  27.  */
  28. final class SquareMatrixSupport {
  29.     /**
  30.      * Define a real-valued square matrix.
  31.      *
  32.      * <p>This matrix supports a scale to protect against overflow. The true value
  33.      * of any matrix value is multiplied by {@code 2^scale}. This is readily performed
  34.      * using {@link Math#scalb(double, int)}.
  35.      */
  36.     interface RealSquareMatrix {
  37.         /**
  38.          * Gets the dimension for the rows and columns.
  39.          *
  40.          * @return the dimension
  41.          */
  42.         int dimension();

  43.         /**
  44.          * Gets the scale of the matrix values.
  45.          * The true value is the value returned from {@link #get(int, int)} multiplied by
  46.          * {@code 2^scale}.
  47.          *
  48.          * @return the scale
  49.          */
  50.         int scale();

  51.         /**
  52.          * Gets the value. This is a scaled value. The true value is the value returned
  53.          * multiplied by {@code 2^scale}.
  54.          *
  55.          * @param i Row
  56.          * @param j Column
  57.          * @return the value
  58.          * @see #scale
  59.          */
  60.         double get(int i, int j);

  61.         /**
  62.          * Returns the result of multiplying {@code this} with itself {@code n} times.
  63.          *
  64.          * @param n raise {@code this} to power {@code n}
  65.          * @return {@code this^n}
  66.          * @throws IllegalArgumentException if {@code n < 0}
  67.          */
  68.         RealSquareMatrix power(int n);
  69.     }

  70.     /**
  71.      * Implementation of {@link RealSquareMatrix} using a {@code double[]} array to
  72.      * store entries. Values are addressed using {@code i*dim + j} where {@code dim} is
  73.      * the square dimension.
  74.      *
  75.      * <p>Scaling is supported using the central element {@code [m][m]} where
  76.      * {@code m = dimension/2}. Scaling is only implemented post-multiplication
  77.      * to protect against overflow during repeat multiplication operations.
  78.      *
  79.      * <p>Note: The scaling is implemented to support computation of Kolmogorov's
  80.      * distribution as described in:
  81.      * <ul>
  82.      * <li>
  83.      * Marsaglia, G., Tsang, W. W., &amp; Wang, J. (2003).
  84.      * <a href="https://doi.org/10.18637/jss.v008.i18">Evaluating Kolmogorov's Distribution.</a>
  85.      * Journal of Statistical Software, 8(18), 1–4.
  86.      * </ul>
  87.      */
  88.     private static class ArrayRealSquareMatrix implements RealSquareMatrix {
  89.         /** The scaling threshold. Marsaglia used 1e140. This uses 2^400 ~ 2.58e120 */
  90.         private static final double SCALE_THRESHOLD = 0x1.0p400;
  91.         /** Dimension. */
  92.         private final int dim;
  93.         /** Entries of the matrix. */
  94.         private final double[] data;
  95.         /** Matrix scale. */
  96.         private final int exp;

  97.         /**
  98.          * @param dimension Matrix dimension.
  99.          * @param data Matrix data.
  100.          * @param scale Matrix scale.
  101.          */
  102.         ArrayRealSquareMatrix(int dimension, double[] data, int scale) {
  103.             this.dim = dimension;
  104.             this.data = data;
  105.             this.exp = scale;
  106.         }

  107.         @Override
  108.         public int dimension() {
  109.             return dim;
  110.         }

  111.         @Override
  112.         public int scale() {
  113.             return exp;
  114.         }

  115.         @Override
  116.         public double get(int i, int j) {
  117.             return data[i * dim + j];
  118.         }

  119.         @Override
  120.         public RealSquareMatrix power(int n) {
  121.             checkExponent(n);
  122.             if (n == 0) {
  123.                 return identity();
  124.             }
  125.             if (n == 1) {
  126.                 return this;
  127.             }

  128.             // Here at least 1 multiplication occurs.
  129.             // Compute the power by repeat squaring and multiplication:
  130.             // 13 = 1101
  131.             // x^13 = x^8 * x^4 * x^1
  132.             //      = ((x^2 * x)^2)^2 * x
  133.             // 21 = 10101
  134.             // x^21 = x^16 * x^4 * x^1
  135.             //      = (((x^2)^2 * x)^2)^2 * x
  136.             // 1. Find highest set bit in n
  137.             // 2. Initialise result as x
  138.             // 3. For remaining bits (0 or 1) below the highest set bit:
  139.             //    - square the current result
  140.             //    - if the current bit is 1 then multiply by x
  141.             // In this scheme we require 2 matrix array allocations and a column array.

  142.             // Working arrays
  143.             final double[] col = new double[dim];
  144.             double[] b = new double[data.length];
  145.             double[] tmp;

  146.             // Initialise result as A^1.
  147.             final double[] a = data;
  148.             final int ea = exp;
  149.             double[] r = a.clone();
  150.             int er = ea;

  151.             // Shift the highest set bit off the top.
  152.             // Any remaining bits are detected in the sign bit.
  153.             final int shift = Integer.numberOfLeadingZeros(n) + 1;
  154.             int bits = n << shift;

  155.             // Process remaining bits below highest set bit.
  156.             for (int i = 32 - shift; i != 0; i--, bits <<= 1) {
  157.                 // Square the result
  158.                 er = multiply(r, er, r, er, col, b);
  159.                 // Recycle working array
  160.                 tmp = b;
  161.                 b = r;
  162.                 r = tmp;
  163.                 if (bits < 0) {
  164.                     // Multiply by A
  165.                     er = multiply(r, er, a, ea, col, b);
  166.                     // Recycle working array
  167.                     tmp = b;
  168.                     b = r;
  169.                     r = tmp;
  170.                 }
  171.             }

  172.             return new ArrayRealSquareMatrix(dim, r, er);
  173.         }

  174.         /**
  175.          * Creates the identity matrix I with the same dimension as {@code this}.
  176.          *
  177.          * @return I
  178.          */
  179.         private RealSquareMatrix identity() {
  180.             final int n = dimension();
  181.             return new RealSquareMatrix() {
  182.                 @Override
  183.                 public int dimension() {
  184.                     return n;
  185.                 }

  186.                 @Override
  187.                 public int scale() {
  188.                     return 0;
  189.                 }

  190.                 @Override
  191.                 public double get(int i, int j) {
  192.                     return i == j ? 1 : 0;
  193.                 }

  194.                 @Override
  195.                 public RealSquareMatrix power(int p) {
  196.                     return this;
  197.                 }
  198.             };
  199.         }

  200.         /**
  201.          * Returns the result of postmultiplying {@code a} by {@code b}. It is expected
  202.          * the scale of the result will be the sum of the scale of the arguments; this
  203.          * may be adjusted by the scale power if the result is scaled by a power of two
  204.          * for overflow protection.
  205.          *
  206.          * @param a Matrix.
  207.          * @param sa Scale of matrix a.
  208.          * @param b Matrix to postmultiply by.
  209.          * @param sb Scale of matrix b.
  210.          * @param col Working array for a column of the matrix.
  211.          * @param out Output {@code a * b}
  212.          * @return Scale of {@code a * b}
  213.          */
  214.         private static int multiply(double[] a, int sa, double[] b, int sb, double[] col, double[] out) {
  215.             final int m = col.length;
  216.             // Rows are contiguous; Columns are non-contiguous
  217.             int k;
  218.             for (int c = 0; c < m; c++) {
  219.                 // Extract column from b to contiguous memory
  220.                 k = c;
  221.                 for (int i = 0; i < m; i++, k += m) {
  222.                     col[i] = b[k];
  223.                 }
  224.                 // row * col
  225.                 k = 0;
  226.                 for (int r = 0; r < m; r++) {
  227.                     double sum = 0;
  228.                     for (int i = 0; i < m; i++, k++) {
  229.                         sum += a[k] * col[i];
  230.                     }
  231.                     out[r * m + c] = sum;
  232.                 }
  233.             }
  234.             int s = sa + sb;
  235.             // Overflow protection. Ideally we would check all elements but for speed
  236.             // we check the central one only.
  237.             k = m >> 1;
  238.             if (out[k * m + k] > SCALE_THRESHOLD) {
  239.                 // Downscale
  240.                 // We could downscale by the inverse of SCALE_THRESHOLD.
  241.                 // However this does not account for how far above the threshold
  242.                 // the central element is. Here we downscale so the central element
  243.                 // is roughly 1 allowing other elements to be larger and still protected
  244.                 // from overflow.
  245.                 final int exp = Math.getExponent(out[k * m + k]);
  246.                 final double downScale = Math.scalb(1.0, -exp);
  247.                 s += exp;
  248.                 for (int i = 0; i < out.length; i++) {
  249.                     out[i] *= downScale;
  250.                 }
  251.             }
  252.             return s;
  253.         }

  254.         /**
  255.          * Check the exponent is not negative.
  256.          *
  257.          * @param p Exponent.
  258.          * @throws IllegalArgumentException if the exponent is negative
  259.          */
  260.         private static void checkExponent(int p) {
  261.             if (p < 0) {
  262.                 throw new IllegalArgumentException("Not positive exponent: " + p);
  263.             }
  264.         }
  265.     }

  266.     /** No instances. */
  267.     private SquareMatrixSupport() {}

  268.     /**
  269.      * Creates a square matrix. Data may be used in-place.
  270.      *
  271.      * <p>Values are addressed using {@code a[i][j] = i*dimension + j}.
  272.      *
  273.      * @param dimension Matrix dimension.
  274.      * @param data Matrix data.
  275.      * @return the square matrix
  276.      * @throws IllegalArgumentException if the matrix data is not square (length = dimension * dimension)
  277.      */
  278.     static RealSquareMatrix create(int dimension, double[] data) {
  279.         if (dimension * dimension != data.length) {
  280.             // Note: %<d is 'relative indexing' to re-use the last argument
  281.             throw new IllegalArgumentException(String.format("Not square: %d * %<d != %d", dimension, data.length));
  282.         }
  283.         return new ArrayRealSquareMatrix(dimension, data, 0);
  284.     }
  285. }