Stirling.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.numbers.combinatorics;

  18. /**
  19.  * Computation of <a href="https://en.wikipedia.org/wiki/Stirling_number">Stirling numbers</a>.
  20.  *
  21.  * @since 1.2
  22.  */
  23. public final class Stirling {
  24.     /** Stirling S1 error message. */
  25.     private static final String S1_ERROR_FORMAT = "s(n=%d, k=%d)";
  26.     /** Stirling S2 error message. */
  27.     private static final String S2_ERROR_FORMAT = "S(n=%d, k=%d)";
  28.     /** Overflow threshold for n when computing s(n, 1). */
  29.     private static final int S1_OVERFLOW_K_EQUALS_1 = 21;
  30.     /** Overflow threshold for n when computing s(n, n-2). */
  31.     private static final int S1_OVERFLOW_K_EQUALS_NM2 = 92682;
  32.     /** Overflow threshold for n when computing s(n, n-3). */
  33.     private static final int S1_OVERFLOW_K_EQUALS_NM3 = 2761;
  34.     /** Overflow threshold for n when computing S(n, n-2). */
  35.     private static final int S2_OVERFLOW_K_EQUALS_NM2 = 92683;
  36.     /** Overflow threshold for n when computing S(n, n-3). */
  37.     private static final int S2_OVERFLOW_K_EQUALS_NM3 = 2762;

  38.     /**
  39.      * Precomputed Stirling numbers of the first kind.
  40.      * Provides a thread-safe lazy initialization of the cache.
  41.      */
  42.     private static final class StirlingS1Cache {
  43.         /** Maximum n to compute (exclusive).
  44.          * As s(21,3) = 13803759753640704000 is larger than Long.MAX_VALUE
  45.          * we must stop computation at row 21. */
  46.         static final int MAX_N = 21;
  47.         /** Stirling numbers of the first kind. */
  48.         static final long[][] S1;

  49.         static {
  50.             S1 = new long[MAX_N][];
  51.             // Initialise first two rows to allow s(2, 1) to use s(1, 1)
  52.             S1[0] = new long[] {1};
  53.             S1[1] = new long[] {0, 1};
  54.             for (int n = 2; n < S1.length; n++) {
  55.                 S1[n] = new long[n + 1];
  56.                 S1[n][0] = 0;
  57.                 S1[n][n] = 1;
  58.                 for (int k = 1; k < n; k++) {
  59.                     S1[n][k] = S1[n - 1][k - 1] - (n - 1) * S1[n - 1][k];
  60.                 }
  61.             }
  62.         }
  63.     }

  64.     /**
  65.      * Precomputed Stirling numbers of the second kind.
  66.      * Provides a thread-safe lazy initialization of the cache.
  67.      */
  68.     private static final class StirlingS2Cache {
  69.         /** Maximum n to compute (exclusive).
  70.          * As S(26,9) = 11201516780955125625 is larger than Long.MAX_VALUE
  71.          * we must stop computation at row 26. */
  72.         static final int MAX_N = 26;
  73.         /** Stirling numbers of the second kind. */
  74.         static final long[][] S2;

  75.         static {
  76.             S2 = new long[MAX_N][];
  77.             S2[0] = new long[] {1};
  78.             for (int n = 1; n < S2.length; n++) {
  79.                 S2[n] = new long[n + 1];
  80.                 S2[n][0] = 0;
  81.                 S2[n][1] = 1;
  82.                 S2[n][n] = 1;
  83.                 for (int k = 2; k < n; k++) {
  84.                     S2[n][k] = k * S2[n - 1][k] + S2[n - 1][k - 1];
  85.                 }
  86.             }
  87.         }
  88.     }

  89.     /** Private constructor. */
  90.     private Stirling() {
  91.         // intentionally empty.
  92.     }

  93.     /**
  94.      * Returns the <em>signed</em> <a
  95.      * href="https://mathworld.wolfram.com/StirlingNumberoftheFirstKind.html">
  96.      * Stirling number of the first kind</a>, "{@code s(n,k)}". The number of permutations of
  97.      * {@code n} elements which contain exactly {@code k} permutation cycles is the
  98.      * nonnegative number: {@code |s(n,k)| = (-1)^(n-k) s(n,k)}
  99.      *
  100.      * @param n Size of the set
  101.      * @param k Number of permutation cycles ({@code 0 <= k <= n})
  102.      * @return {@code s(n,k)}
  103.      * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
  104.      * @throws ArithmeticException if some overflow happens, typically for n exceeding 20
  105.      * (s(n,n-1) is handled specifically and does not overflow)
  106.      */
  107.     public static long stirlingS1(int n, int k) {
  108.         checkArguments(n, k);

  109.         if (n < StirlingS1Cache.MAX_N) {
  110.             // The number is in the small cache
  111.             return StirlingS1Cache.S1[n][k];
  112.         }

  113.         // Simple cases
  114.         // https://en.wikipedia.org/wiki/Stirling_numbers_of_the_first_kind#Simple_identities
  115.         if (k == 0) {
  116.             return 0;
  117.         } else if (k == n) {
  118.             return 1;
  119.         } else if (k == 1) {
  120.             checkN(n, k, S1_OVERFLOW_K_EQUALS_1, S1_ERROR_FORMAT);
  121.             // Note: Only occurs for n=21 so avoid computing the sign with pow(-1, n-1) * (n-1)!
  122.             return Factorial.value(n - 1);
  123.         } else if (k == n - 1) {
  124.             return -BinomialCoefficient.value(n, 2);
  125.         } else if (k == n - 2) {
  126.             checkN(n, k, S1_OVERFLOW_K_EQUALS_NM2, S1_ERROR_FORMAT);
  127.             // (3n-1) * binom(n, 3) / 4
  128.             return productOver4(3L * n - 1, BinomialCoefficient.value(n, 3));
  129.         } else if (k == n - 3) {
  130.             checkN(n, k, S1_OVERFLOW_K_EQUALS_NM3, S1_ERROR_FORMAT);
  131.             return -BinomialCoefficient.value(n, 2) * BinomialCoefficient.value(n, 4);
  132.         }

  133.         // Compute using:
  134.         // s(n + 1, k) = s(n, k - 1)     - n       * s(n, k)
  135.         // s(n, k)     = s(n - 1, k - 1) - (n - 1) * s(n - 1, k)

  136.         // n >= 21 (MAX_N)
  137.         // 2 <= k <= n-4

  138.         // Start at the largest easily computed value: n < MAX_N or k < 2
  139.         final int reduction = Math.min(n - StirlingS1Cache.MAX_N, k - 2) + 1;
  140.         int n0 = n - reduction;
  141.         int k0 = k - reduction;

  142.         long sum = stirlingS1(n0, k0);
  143.         while (n0 < n) {
  144.             k0++;
  145.             sum = Math.subtractExact(
  146.                 sum,
  147.                 Math.multiplyExact(n0, stirlingS1(n0, k0))
  148.             );
  149.             n0++;
  150.         }

  151.         return sum;
  152.     }

  153.     /**
  154.      * Returns the <a
  155.      * href="https://mathworld.wolfram.com/StirlingNumberoftheSecondKind.html">
  156.      * Stirling number of the second kind</a>, "{@code S(n,k)}", the number of
  157.      * ways of partitioning an {@code n}-element set into {@code k} non-empty
  158.      * subsets.
  159.      *
  160.      * @param n Size of the set
  161.      * @param k Number of non-empty subsets ({@code 0 <= k <= n})
  162.      * @return {@code S(n,k)}
  163.      * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
  164.      * @throws ArithmeticException if some overflow happens, typically for n exceeding 25 and
  165.      * k between 20 and n-2 (S(n,n-1) is handled specifically and does not overflow)
  166.      */
  167.     public static long stirlingS2(int n, int k) {
  168.         checkArguments(n, k);

  169.         if (n < StirlingS2Cache.MAX_N) {
  170.             // The number is in the small cache
  171.             return StirlingS2Cache.S2[n][k];
  172.         }

  173.         // Simple cases
  174.         if (k == 0) {
  175.             return 0;
  176.         } else if (k == 1 || k == n) {
  177.             return 1;
  178.         } else if (k == 2) {
  179.             checkN(n, k, 64, S2_ERROR_FORMAT);
  180.             return (1L << (n - 1)) - 1L;
  181.         } else if (k == n - 1) {
  182.             return BinomialCoefficient.value(n, 2);
  183.         } else if (k == n - 2) {
  184.             checkN(n, k, S2_OVERFLOW_K_EQUALS_NM2, S2_ERROR_FORMAT);
  185.             // (3n-5) * binom(n, 3) / 4
  186.             return productOver4(3L * n - 5, BinomialCoefficient.value(n, 3));
  187.         } else if (k == n - 3) {
  188.             checkN(n, k, S2_OVERFLOW_K_EQUALS_NM3, S2_ERROR_FORMAT);
  189.             return BinomialCoefficient.value(n - 2, 2) * BinomialCoefficient.value(n, 4);
  190.         }

  191.         // Compute using:
  192.         // S(n, k) = k * S(n - 1, k) + S(n - 1, k - 1)

  193.         // n >= 26 (MAX_N)
  194.         // 3 <= k <= n-3

  195.         // Start at the largest easily computed value: n < MAX_N or k < 3
  196.         final int reduction = Math.min(n - StirlingS2Cache.MAX_N, k - 3) + 1;
  197.         int n0 = n - reduction;
  198.         int k0 = k - reduction;

  199.         long sum = stirlingS2(n0, k0);
  200.         while (n0 < n) {
  201.             k0++;
  202.             sum = Math.addExact(
  203.                 Math.multiplyExact(k0, stirlingS2(n0, k0)),
  204.                 sum
  205.             );
  206.             n0++;
  207.         }

  208.         return sum;
  209.     }

  210.     /**
  211.      * Check {@code 0 <= k <= n}.
  212.      *
  213.      * @param n N
  214.      * @param k K
  215.      * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
  216.      */
  217.     private static void checkArguments(int n, int k) {
  218.         // Combine all checks with a single branch:
  219.         // 0 <= n; 0 <= k <= n
  220.         // Note: If n >= 0 && k >= 0 && n - k < 0 then k > n.
  221.         // Bitwise or will detect a negative sign bit in any of the numbers
  222.         if ((n | k | (n - k)) < 0) {
  223.             // Raise the correct exception
  224.             if (n < 0) {
  225.                 throw new CombinatoricsException(CombinatoricsException.NEGATIVE, n);
  226.             }
  227.             throw new CombinatoricsException(CombinatoricsException.OUT_OF_RANGE, k, 0, n);
  228.         }
  229.     }

  230.     /**
  231.      * Check {@code n <= threshold}, or else throw an {@link ArithmeticException}.
  232.      *
  233.      * @param n N
  234.      * @param k K
  235.      * @param threshold Threshold for {@code n}
  236.      * @param msgFormat Error message format
  237.      * @throws ArithmeticException if overflow is expected to happen
  238.      */
  239.     private static void checkN(int n, int k, int threshold, String msgFormat) {
  240.         if (n > threshold) {
  241.             throw new ArithmeticException(String.format(msgFormat, n, k));
  242.         }
  243.     }

  244.     /**
  245.      * Return {@code a*b/4} without intermediate overflow.
  246.      * It is assumed that:
  247.      * <ul>
  248.      * <li>The coefficients a and b are positive
  249.      * <li>The product (a*b) is an exact multiple of 4
  250.      * <li>The result (a*b/4) is an exact integer that does not overflow a {@code long}
  251.      * </ul>
  252.      *
  253.      * <p>A conditional branch is performed on the odd/even property of {@code b}.
  254.      * The branch is predictable if {@code b} is typically the same parity.
  255.      *
  256.      * @param a Coefficient a
  257.      * @param b Coefficient b
  258.      * @return {@code a*b/4}
  259.      */
  260.     private static long productOver4(long a, long b) {
  261.         // Compute (a*b/4) without intermediate overflow.
  262.         // The product (a*b) must be an exact multiple of 4.
  263.         // If b is even: ((b/2) * a) / 2
  264.         // If b is odd then a must be even to make a*b even: ((a/2) * b) / 2
  265.         return (b & 1) == 0 ?
  266.             ((b >>> 1) * a) >>> 1 :
  267.             ((a >>> 1) * b) >>> 1;
  268.     }
  269. }