TrapezoidalDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.statistics.distribution;

  18. import org.apache.commons.rng.UniformRandomProvider;

  19. /**
  20.  * Implementation of the trapezoidal distribution.
  21.  *
  22.  * <p>The probability density function of \( X \) is:
  23.  *
  24.  * <p>\[ f(x; a, b, c, d) = \begin{cases}
  25.  *       \frac{2}{d+c-a-b}\frac{x-a}{b-a} &amp; \text{for } a\le x \lt b \\
  26.  *       \frac{2}{d+c-a-b}                &amp; \text{for } b\le x \lt c \\
  27.  *       \frac{2}{d+c-a-b}\frac{d-x}{d-c} &amp; \text{for } c\le x \le d
  28.  *       \end{cases} \]
  29.  *
  30.  * <p>for \( -\infty \lt a \le b \le c \le d \lt \infty \) and
  31.  * \( x \in [a, d] \).
  32.  *
  33.  * <p>Note the special cases:
  34.  * <ul>
  35.  * <li>\( b = c \) is the triangular distribution
  36.  * <li>\( a = b \) and \( c = d \) is the uniform distribution
  37.  * </ul>
  38.  *
  39.  * @see <a href="https://en.wikipedia.org/wiki/Trapezoidal_distribution">Trapezoidal distribution (Wikipedia)</a>
  40.  */
  41. public abstract class TrapezoidalDistribution extends AbstractContinuousDistribution {
    /** Lower limit of this distribution (inclusive); returned as the lower bound of the support. */
    protected final double a;
    /** Start of the trapezoid constant density (first shape parameter). */
    protected final double b;
    /** End of the trapezoid constant density (second shape parameter). */
    protected final double c;
    /** Upper limit of this distribution (inclusive); returned as the upper bound of the support. */
    protected final double d;

  50.     /**
  51.      * Specialisation of the trapezoidal distribution used when the distribution simplifies
  52.      * to an alternative distribution.
  53.      */
  54.     private static class DelegatedTrapezoidalDistribution extends TrapezoidalDistribution {
  55.         /** Distribution delegate. */
  56.         private final ContinuousDistribution delegate;

  57.         /**
  58.          * @param a Lower limit of this distribution (inclusive).
  59.          * @param b Start of the trapezoid constant density.
  60.          * @param c End of the trapezoid constant density.
  61.          * @param d Upper limit of this distribution (inclusive).
  62.          * @param delegate Distribution delegate.
  63.          */
  64.         DelegatedTrapezoidalDistribution(double a, double b, double c, double d,
  65.                                          ContinuousDistribution delegate) {
  66.             super(a, b, c, d);
  67.             this.delegate = delegate;
  68.         }

  69.         @Override
  70.         public double density(double x) {
  71.             return delegate.density(x);
  72.         }

  73.         @Override
  74.         public double probability(double x0, double x1) {
  75.             return delegate.probability(x0, x1);
  76.         }

  77.         @Override
  78.         public double logDensity(double x) {
  79.             return delegate.logDensity(x);
  80.         }

  81.         @Override
  82.         public double cumulativeProbability(double x) {
  83.             return delegate.cumulativeProbability(x);
  84.         }

  85.         @Override
  86.         public double inverseCumulativeProbability(double p) {
  87.             return delegate.inverseCumulativeProbability(p);
  88.         }

  89.         @Override
  90.         public double survivalProbability(double x) {
  91.             return delegate.survivalProbability(x);
  92.         }

  93.         @Override
  94.         public double inverseSurvivalProbability(double p) {
  95.             return delegate.inverseSurvivalProbability(p);
  96.         }

  97.         @Override
  98.         public double getMean() {
  99.             return delegate.getMean();
  100.         }

  101.         @Override
  102.         public double getVariance() {
  103.             return delegate.getVariance();
  104.         }

  105.         @Override
  106.         public Sampler createSampler(UniformRandomProvider rng) {
  107.             return delegate.createSampler(rng);
  108.         }
  109.     }

  110.     /**
  111.      * Specialisation of the trapezoidal distribution used when {@code b == c}.
  112.      *
  113.      * <p>This delegates all methods to the triangular distribution.
  114.      */
  115.     private static class TriangularTrapezoidalDistribution extends DelegatedTrapezoidalDistribution {
  116.         /**
  117.          * @param a Lower limit of this distribution (inclusive).
  118.          * @param b Start/end of the trapezoid constant density (mode).
  119.          * @param d Upper limit of this distribution (inclusive).
  120.          */
  121.         TriangularTrapezoidalDistribution(double a, double b, double d) {
  122.             super(a, b, b, d, TriangularDistribution.of(a, b, d));
  123.         }
  124.     }

  125.     /**
  126.      * Specialisation of the trapezoidal distribution used when {@code a == b} and {@code c == d}.
  127.      *
  128.      * <p>This delegates all methods to the uniform distribution.
  129.      */
  130.     private static class UniformTrapezoidalDistribution extends DelegatedTrapezoidalDistribution {
  131.         /**
  132.          * @param a Lower limit of this distribution (inclusive).
  133.          * @param d Upper limit of this distribution (inclusive).
  134.          */
  135.         UniformTrapezoidalDistribution(double a, double d) {
  136.             super(a, a, d, d, UniformContinuousDistribution.of(a, d));
  137.         }
  138.     }

  139.     /**
  140.      * Regular implementation of the trapezoidal distribution.
  141.      */
  142.     private static class RegularTrapezoidalDistribution extends TrapezoidalDistribution {
        /** Cached value (d + c - a - b), computed as (d - a) + (c - b); the constant density is 2 / divisor. */
        private final double divisor;
        /** Cached value (b - a): width of the up-sloping region. */
        private final double bma;
        /** Cached value (d - c): width of the down-sloping region. */
        private final double dmc;
        /** Cumulative probability at b, computed as (b - a) / divisor. */
        private final double cdfB;
        /** Cumulative probability at c, computed as the complement of the survival probability at c. */
        private final double cdfC;
        /** Survival probability at b, computed as the complement of the cumulative probability at b. */
        private final double sfB;
        /** Survival probability at c, computed as (d - c) / divisor. */
        private final double sfC;

  157.         /**
  158.          * @param a Lower limit of this distribution (inclusive).
  159.          * @param b Start of the trapezoid constant density.
  160.          * @param c End of the trapezoid constant density.
  161.          * @param d Upper limit of this distribution (inclusive).
  162.          */
  163.         RegularTrapezoidalDistribution(double a, double b, double c, double d) {
  164.             super(a, b, c, d);

  165.             // Sum positive terms
  166.             divisor = (d - a) + (c - b);
  167.             bma = b - a;
  168.             dmc = d - c;

  169.             cdfB = bma / divisor;
  170.             sfB = 1 - cdfB;
  171.             sfC = dmc / divisor;
  172.             cdfC = 1 - sfC;
  173.         }

  174.         @Override
  175.         public double density(double x) {
  176.             // Note: x < a allows correct density where a == b
  177.             if (x < a) {
  178.                 return 0;
  179.             }
  180.             if (x < b) {
  181.                 final double divident = (x - a) / bma;
  182.                 return 2 * (divident / divisor);
  183.             }
  184.             if (x < c) {
  185.                 return 2 / divisor;
  186.             }
  187.             if (x < d) {
  188.                 final double divident = (d - x) / dmc;
  189.                 return 2 * (divident / divisor);
  190.             }
  191.             return 0;
  192.         }

  193.         @Override
  194.         public double cumulativeProbability(double x)  {
  195.             if (x <= a) {
  196.                 return 0;
  197.             }
  198.             if (x < b) {
  199.                 final double divident = (x - a) * (x - a) / bma;
  200.                 return divident / divisor;
  201.             }
  202.             if (x < c) {
  203.                 final double divident = 2 * x - b - a;
  204.                 return divident / divisor;
  205.             }
  206.             if (x < d) {
  207.                 final double divident = (d - x) * (d - x) / dmc;
  208.                 return 1 - divident / divisor;
  209.             }
  210.             return 1;
  211.         }

  212.         @Override
  213.         public double survivalProbability(double x)  {
  214.             // By symmetry:
  215.             if (x <= a) {
  216.                 return 1;
  217.             }
  218.             if (x < b) {
  219.                 final double divident = (x - a) * (x - a) / bma;
  220.                 return 1 - divident / divisor;
  221.             }
  222.             if (x < c) {
  223.                 final double divident = 2 * x - b - a;
  224.                 return 1 - divident / divisor;
  225.             }
  226.             if (x < d) {
  227.                 final double divident = (d - x) * (d - x) / dmc;
  228.                 return divident / divisor;
  229.             }
  230.             return 0;
  231.         }

  232.         @Override
  233.         public double inverseCumulativeProbability(double p) {
  234.             ArgumentUtils.checkProbability(p);
  235.             if (p == 0) {
  236.                 return a;
  237.             }
  238.             if (p == 1) {
  239.                 return d;
  240.             }
  241.             if (p < cdfB) {
  242.                 return a + Math.sqrt(p * divisor * bma);
  243.             }
  244.             if (p < cdfC) {
  245.                 return 0.5 * ((p * divisor) + a + b);
  246.             }
  247.             return d - Math.sqrt((1 - p) * divisor * dmc);
  248.         }

  249.         @Override
  250.         public double inverseSurvivalProbability(double p) {
  251.             // By symmetry:
  252.             ArgumentUtils.checkProbability(p);
  253.             if (p == 1) {
  254.                 return a;
  255.             }
  256.             if (p == 0) {
  257.                 return d;
  258.             }
  259.             if (p > sfB) {
  260.                 return a + Math.sqrt((1 - p) * divisor * bma);
  261.             }
  262.             if (p > sfC) {
  263.                 return 0.5 * (((1 - p) * divisor) + a + b);
  264.             }
  265.             return d - Math.sqrt(p * divisor * dmc);
  266.         }

  267.         @Override
  268.         public double getMean() {
  269.             // Compute using a standardized distribution
  270.             // b' = (b-a) / (d-a)
  271.             // c' = (c-a) / (d-a)
  272.             final double scale = d - a;
  273.             final double bp = bma / scale;
  274.             final double cp = (c - a) / scale;
  275.             return nonCentralMoment(1, bp, cp) * scale + a;
  276.         }

  277.         @Override
  278.         public double getVariance() {
  279.             // Compute using a standardized distribution
  280.             // b' = (b-a) / (d-a)
  281.             // c' = (c-a) / (d-a)
  282.             final double scale = d - a;
  283.             final double bp = bma / scale;
  284.             final double cp = (c - a) / scale;
  285.             final double mu = nonCentralMoment(1, bp, cp);
  286.             return (nonCentralMoment(2, bp, cp) - mu * mu) * scale * scale;
  287.         }

  288.         /**
  289.          * Compute the {@code k}-th non-central moment of the standardized trapezoidal
  290.          * distribution.
  291.          *
  292.          * <p>Shifting the distribution by scale {@code (d - a)} and location {@code a}
  293.          * creates a standardized trapezoidal distribution. This has a simplified
  294.          * non-central moment as {@code a = 0, d = 1, 0 <= b < c <= 1}.
  295.          * <pre>
  296.          *               2             1       ( 1 - c^(k+2)           )
  297.          * E[X^k] = ----------- -------------- ( ----------- - b^(k+1) )
  298.          *          (1 + c - b) (k + 1)(k + 2) (    1 - c              )
  299.          * </pre>
  300.          *
  301.          * <p>Simplification eliminates issues computing the moments when {@code a == b}
  302.          * or {@code c == d} in the original (non-standardized) distribution.
  303.          *
  304.          * @param k Moment to compute
  305.          * @param b Start of the trapezoid constant density (shape parameter in [0, 1]).
  306.          * @param c End of the trapezoid constant density (shape parameter in [0, 1]).
  307.          * @return the moment
  308.          */
  309.         private static double nonCentralMoment(int k, double b, double c) {
  310.             // As c -> 1 then (1 - c^(k+2)) loses precision
  311.             // 1 - x^y == -(x^y - 1)    [high precision powm1]
  312.             //         == -(exp(y * log(x)) - 1)
  313.             // Note: avoid log(1) using the limit:
  314.             // (1 - c^(k+2)) / (1-c) -> (k+2) as c -> 1
  315.             final double term1 = c == 1 ? k + 2 : Math.expm1((k + 2) * Math.log(c)) / (c - 1);
  316.             final double term2 = Math.pow(b, k + 1);
  317.             return 2 * ((term1 - term2) / (c - b + 1) / ((k + 1) * (k + 2)));
  318.         }
  319.     }

  320.     /**
  321.      * @param a Lower limit of this distribution (inclusive).
  322.      * @param b Start of the trapezoid constant density.
  323.      * @param c End of the trapezoid constant density.
  324.      * @param d Upper limit of this distribution (inclusive).
  325.      */
  326.     TrapezoidalDistribution(double a, double b, double c, double d) {
  327.         this.a = a;
  328.         this.b = b;
  329.         this.c = c;
  330.         this.d = d;
  331.     }

  332.     /**
  333.      * Creates a trapezoidal distribution.
  334.      *
  335.      * <p>The distribution density is represented as an up sloping line from
  336.      * {@code a} to {@code b}, constant from {@code b} to {@code c}, and then a down
  337.      * sloping line from {@code c} to {@code d}.
  338.      *
  339.      * @param a Lower limit of this distribution (inclusive).
  340.      * @param b Start of the trapezoid constant density (first shape parameter).
  341.      * @param c End of the trapezoid constant density (second shape parameter).
  342.      * @param d Upper limit of this distribution (inclusive).
  343.      * @return the distribution
  344.      * @throws IllegalArgumentException if {@code a >= d}, if {@code b < a}, if
  345.      * {@code c < b} or if {@code c > d}.
  346.      */
  347.     public static TrapezoidalDistribution of(double a, double b, double c, double d) {
  348.         if (a >= d) {
  349.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GTE_HIGH,
  350.                                             a, d);
  351.         }
  352.         if (b < a) {
  353.             throw new DistributionException(DistributionException.TOO_SMALL,
  354.                                             b, a);
  355.         }
  356.         if (c < b) {
  357.             throw new DistributionException(DistributionException.TOO_SMALL,
  358.                                             c, b);
  359.         }
  360.         if (c > d) {
  361.             throw new DistributionException(DistributionException.TOO_LARGE,
  362.                                             c, d);
  363.         }
  364.         // For consistency, delegate to the appropriate simplified distribution.
  365.         // Note: Floating-point equality comparison is intentional.
  366.         if (b == c) {
  367.             return new TriangularTrapezoidalDistribution(a, b, d);
  368.         }
  369.         if (d - a == c - b) {
  370.             return new UniformTrapezoidalDistribution(a, d);
  371.         }
  372.         return new RegularTrapezoidalDistribution(a, b, c, d);
  373.     }

    /**
     * {@inheritDoc}
     *
     * <p>For lower limit \( a \), start of the density constant region \( b \),
     * end of the density constant region \( c \) and upper limit \( d \), the
     * mean is:
     *
     * <p>\[ \frac{1}{3(d+c-b-a)}\left(\frac{d^3-c^3}{d-c}-\frac{b^3-a^3}{b-a}\right) \]
     *
     * <p>When \( a = b \) or \( c = d \) the corresponding quotient is to be read as
     * its limiting value.
     */
    @Override
    public abstract double getMean();

    /**
     * {@inheritDoc}
     *
     * <p>For lower limit \( a \), start of the density constant region \( b \),
     * end of the density constant region \( c \) and upper limit \( d \), the
     * variance is:
     *
     * <p>\[ \frac{1}{6(d+c-b-a)}\left(\frac{d^4-c^4}{d-c}-\frac{b^4-a^4}{b-a}\right) - \mu^2 \]
     *
     * <p>where \( \mu \) is the mean. When \( a = b \) or \( c = d \) the
     * corresponding quotient is to be read as its limiting value.
     */
    @Override
    public abstract double getVariance();

  398.     /**
  399.      * Gets the start of the constant region of the density function.
  400.      *
  401.      * <p>This is the first shape parameter {@code b} of the distribution.
  402.      *
  403.      * @return the first shape parameter {@code b}
  404.      */
  405.     public double getB() {
  406.         return b;
  407.     }

  408.     /**
  409.      * Gets the end of the constant region of the density function.
  410.      *
  411.      * <p>This is the second shape parameter {@code c} of the distribution.
  412.      *
  413.      * @return the second shape parameter {@code c}
  414.      */
  415.     public double getC() {
  416.         return c;
  417.     }

  418.     /**
  419.      * {@inheritDoc}
  420.      *
  421.      * <p>The lower bound of the support is equal to the lower limit parameter
  422.      * {@code a} of the distribution.
  423.      */
  424.     @Override
  425.     public double getSupportLowerBound() {
  426.         return a;
  427.     }

  428.     /**
  429.      * {@inheritDoc}
  430.      *
  431.      * <p>The upper bound of the support is equal to the upper limit parameter
  432.      * {@code d} of the distribution.
  433.      */
  434.     @Override
  435.     public double getSupportUpperBound() {
  436.         return d;
  437.     }
  438. }