SquareMatrixSupport.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.statistics.inference;
/**
 * Provide support for square matrix basic algebraic operations.
 *
 * <p>Matrix element indexing is 0-based e.g. {@code get(0, 0)}
 * returns the element in the first row, first column of the matrix.
 *
 * <p>This class supports computations in the {@link KolmogorovSmirnovTest}.
 *
 * @since 1.1
 */
final class SquareMatrixSupport {
    /**
     * Define a real-valued square matrix.
     *
     * <p>This matrix supports a scale to protect against overflow. The true value
     * of any matrix value is multiplied by {@code 2^scale}. This is readily performed
     * using {@link Math#scalb(double, int)}.
     */
    interface RealSquareMatrix {
        /**
         * Gets the dimension for the rows and columns.
         *
         * @return the dimension
         */
        int dimension();

        /**
         * Gets the scale of the matrix values.
         * The true value is the value returned from {@link #get(int, int)} multiplied by
         * {@code 2^scale}.
         *
         * @return the scale
         */
        int scale();

        /**
         * Gets the value. This is a scaled value. The true value is the value returned
         * multiplied by {@code 2^scale}.
         *
         * @param i Row
         * @param j Column
         * @return the value
         * @see #scale
         */
        double get(int i, int j);

        /**
         * Returns {@code this} raised to the power {@code n}, i.e. the matrix product
         * of {@code n} copies of {@code this}. A power of zero yields the identity matrix.
         *
         * @param n raise {@code this} to power {@code n}
         * @return {@code this^n}
         * @throws IllegalArgumentException if {@code n < 0}
         */
        RealSquareMatrix power(int n);
    }

    /**
     * Implementation of {@link RealSquareMatrix} using a {@code double[]} array to
     * store entries. Values are addressed using {@code i*dim + j} where {@code dim} is
     * the square dimension.
     *
     * <p>Scaling is supported using the central element {@code [m][m]} where
     * {@code m = dimension/2}. Scaling is only implemented post-multiplication
     * to protect against overflow during repeat multiplication operations.
     *
     * <p>Note: The scaling is implemented to support computation of Kolmogorov's
     * distribution as described in:
     * <ul>
     * <li>
     * Marsaglia, G., Tsang, W. W., &amp; Wang, J. (2003).
     * <a href="https://doi.org/10.18637/jss.v008.i18">Evaluating Kolmogorov's Distribution.</a>
     * Journal of Statistical Software, 8(18), 1–4.
     * </ul>
     */
    private static class ArrayRealSquareMatrix implements RealSquareMatrix {
        /** The scaling threshold. Marsaglia used 1e140. This uses 2^400 ~ 2.58e120 */
        private static final double SCALE_THRESHOLD = 0x1.0p400;
        /** Dimension. */
        private final int dim;
        /** Entries of the matrix. */
        private final double[] data;
        /** Matrix scale. */
        private final int exp;

        /**
         * @param dimension Matrix dimension.
         * @param data Matrix data (used in-place, not copied).
         * @param scale Matrix scale.
         */
        ArrayRealSquareMatrix(int dimension, double[] data, int scale) {
            this.dim = dimension;
            this.data = data;
            this.exp = scale;
        }

        @Override
        public int dimension() {
            return dim;
        }

        @Override
        public int scale() {
            return exp;
        }

        @Override
        public double get(int i, int j) {
            return data[i * dim + j];
        }

        @Override
        public RealSquareMatrix power(int n) {
            checkExponent(n);
            if (n == 0) {
                return identity();
            }
            if (n == 1) {
                return this;
            }

            // Here at least 1 multiplication occurs.
            // Compute the power by repeat squaring and multiplication:
            // 13 = 1101
            // x^13 = x^8 * x^4 * x^1
            //      = ((x^2 * x)^2)^2 * x
            // 21 = 10101
            // x^21 = x^16 * x^4 * x^1
            //      = (((x^2)^2 * x)^2)^2 * x
            // 1. Find highest set bit in n
            // 2. Initialise result as x
            // 3. For remaining bits (0 or 1) below the highest set bit:
            //    - square the current result
            //    - if the current bit is 1 then multiply by x
            // In this scheme we require 2 matrix array allocations and a column array.

            // Working arrays
            final double[] col = new double[dim];
            double[] b = new double[data.length];
            double[] tmp;

            // Initialise result as A^1.
            final double[] a = data;
            final int ea = exp;
            double[] r = a.clone();
            int er = ea;

            // Shift the highest set bit off the top.
            // Any remaining bits are detected in the sign bit.
            final int shift = Integer.numberOfLeadingZeros(n) + 1;
            int bits = n << shift;

            // Process remaining bits below highest set bit.
            for (int i = 32 - shift; i != 0; i--, bits <<= 1) {
                // Square the result
                er = multiply(r, er, r, er, col, b);
                // Recycle working array
                tmp = b;
                b = r;
                r = tmp;
                if (bits < 0) {
                    // Multiply by A
                    er = multiply(r, er, a, ea, col, b);
                    // Recycle working array
                    tmp = b;
                    b = r;
                    r = tmp;
                }
            }

            return new ArrayRealSquareMatrix(dim, r, er);
        }

        /**
         * Creates the identity matrix I with the same dimension as {@code this}.
         *
         * <p>The returned matrix validates the exponent passed to its own
         * {@code power} method so that it honours the {@link RealSquareMatrix#power(int)}
         * contract; any non-negative power of I is I.
         *
         * @return I
         */
        private RealSquareMatrix identity() {
            final int n = dimension();
            return new RealSquareMatrix() {
                @Override
                public int dimension() {
                    return n;
                }

                @Override
                public int scale() {
                    return 0;
                }

                @Override
                public double get(int i, int j) {
                    return i == j ? 1 : 0;
                }

                @Override
                public RealSquareMatrix power(int p) {
                    // Negative exponents must be rejected as for any other matrix.
                    checkExponent(p);
                    return this;
                }
            };
        }

        /**
         * Returns the result of postmultiplying {@code a} by {@code b}. It is expected
         * the scale of the result will be the sum of the scale of the arguments; this
         * may be adjusted by the scale power if the result is scaled by a power of two
         * for overflow protection.
         *
         * <p>{@code out} must not be the same array as {@code a} or {@code b}.
         *
         * @param a Matrix.
         * @param sa Scale of matrix a.
         * @param b Matrix to postmultiply by.
         * @param sb Scale of matrix b.
         * @param col Working array for a column of the matrix.
         * @param out Output {@code a * b}
         * @return Scale of {@code a * b}
         */
        private static int multiply(double[] a, int sa, double[] b, int sb, double[] col, double[] out) {
            final int m = col.length;
            // Rows are contiguous; Columns are non-contiguous
            int k;
            for (int c = 0; c < m; c++) {
                // Extract column from b to contiguous memory
                k = c;
                for (int i = 0; i < m; i++, k += m) {
                    col[i] = b[k];
                }
                // row * col
                k = 0;
                for (int r = 0; r < m; r++) {
                    double sum = 0;
                    for (int i = 0; i < m; i++, k++) {
                        sum += a[k] * col[i];
                    }
                    out[r * m + c] = sum;
                }
            }

            int s = sa + sb;

            if (m == 0) {
                // Empty matrix: there is no central element to inspect.
                return s;
            }

            // Overflow protection. Ideally we would check all elements but for speed
            // we check the central one only.
            k = m >> 1;
            if (out[k * m + k] > SCALE_THRESHOLD) {
                // Downscale
                // We could downscale by the inverse of SCALE_THRESHOLD.
                // However this does not account for how far above the threshold
                // the central element is. Here we downscale so the central element
                // is roughly 1 allowing other elements to be larger and still protected
                // from overflow.
                final int centralExp = Math.getExponent(out[k * m + k]);
                final double downScale = Math.scalb(1.0, -centralExp);
                s += centralExp;
                for (int i = 0; i < out.length; i++) {
                    out[i] *= downScale;
                }
            }

            return s;
        }

        /**
         * Check the exponent is not negative.
         *
         * @param p Exponent.
         * @throws IllegalArgumentException if the exponent is negative
         */
        private static void checkExponent(int p) {
            if (p < 0) {
                throw new IllegalArgumentException("Not positive exponent: " + p);
            }
        }
    }

    /** No instances. */
    private SquareMatrixSupport() {}

    /**
     * Creates a square matrix. Data may be used in-place.
     *
     * <p>Values are addressed using {@code a[i][j] = i*dimension + j}.
     *
     * @param dimension Matrix dimension.
     * @param data Matrix data.
     * @return the square matrix
     * @throws IllegalArgumentException if the dimension is negative, or the matrix
     * data is not square (length = dimension * dimension)
     */
    static RealSquareMatrix create(int dimension, double[] data) {
        if (dimension < 0) {
            throw new IllegalArgumentException("Negative dimension: " + dimension);
        }
        // Use long arithmetic: dimension * dimension can overflow an int and wrap
        // to a value that spuriously matches data.length.
        if ((long) dimension * dimension != data.length) {
            // Note: %<d is 'relative indexing' to re-use the last argument
            throw new IllegalArgumentException(String.format("Not square: %d * %<d != %d", dimension, data.length));
        }
        return new ArrayRealSquareMatrix(dimension, data, 0);
    }
}