001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.statistics.distribution;
019
020import org.apache.commons.numbers.gamma.Erfc;
021import org.apache.commons.numbers.gamma.InverseErf;
022import org.apache.commons.numbers.gamma.ErfDifference;
023import org.apache.commons.rng.UniformRandomProvider;
024import org.apache.commons.rng.sampling.distribution.GaussianSampler;
025import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
026
027/**
028 * Implementation of the <a href="http://en.wikipedia.org/wiki/Normal_distribution">normal (Gaussian) distribution</a>.
029 */
030public class NormalDistribution extends AbstractContinuousDistribution {
031    /** &radic;(2). */
032    private static final double SQRT2 = Math.sqrt(2.0);
033    /** Mean of this distribution. */
034    private final double mean;
035    /** Standard deviation of this distribution. */
036    private final double standardDeviation;
037    /** The value of {@code log(sd) + 0.5*log(2*pi)} stored for faster computation. */
038    private final double logStandardDeviationPlusHalfLog2Pi;
039
040    /**
041     * Creates a distribution.
042     *
043     * @param mean Mean for this distribution.
044     * @param sd Standard deviation for this distribution.
045     * @throws IllegalArgumentException if {@code sd <= 0}.
046     */
047    public NormalDistribution(double mean,
048                              double sd) {
049        if (sd <= 0) {
050            throw new DistributionException(DistributionException.NEGATIVE, sd);
051        }
052
053        this.mean = mean;
054        standardDeviation = sd;
055        logStandardDeviationPlusHalfLog2Pi = Math.log(sd) + 0.5 * Math.log(2 * Math.PI);
056    }
057
058    /**
059     * Access the standard deviation.
060     *
061     * @return the standard deviation for this distribution.
062     */
063    public double getStandardDeviation() {
064        return standardDeviation;
065    }
066
067    /** {@inheritDoc} */
068    @Override
069    public double density(double x) {
070        return Math.exp(logDensity(x));
071    }
072
073    /** {@inheritDoc} */
074    @Override
075    public double logDensity(double x) {
076        final double x0 = x - mean;
077        final double x1 = x0 / standardDeviation;
078        return -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi;
079    }
080
081    /**
082     * {@inheritDoc}
083     *
084     * If {@code x} is more than 40 standard deviations from the mean, 0 or 1
085     * is returned, as in these cases the actual value is within
086     * {@code Double.MIN_VALUE} of 0 or 1.
087     */
088    @Override
089    public double cumulativeProbability(double x)  {
090        final double dev = x - mean;
091        if (Math.abs(dev) > 40 * standardDeviation) {
092            return dev < 0 ? 0.0d : 1.0d;
093        }
094        return 0.5 * Erfc.value(-dev / (standardDeviation * SQRT2));
095    }
096
097    /** {@inheritDoc} */
098    @Override
099    public double inverseCumulativeProbability(final double p) {
100        if (p < 0 ||
101            p > 1) {
102            throw new DistributionException(DistributionException.INVALID_PROBABILITY, p);
103        }
104        return mean + standardDeviation * SQRT2 * InverseErf.value(2 * p - 1);
105    }
106
107    /** {@inheritDoc} */
108    @Override
109    public double probability(double x0,
110                              double x1) {
111        if (x0 > x1) {
112            throw new DistributionException(DistributionException.TOO_LARGE,
113                                            x0, x1);
114        }
115        final double denom = standardDeviation * SQRT2;
116        final double v0 = (x0 - mean) / denom;
117        final double v1 = (x1 - mean) / denom;
118        return 0.5 * ErfDifference.value(v0, v1);
119    }
120
121    /** {@inheritDoc} */
122    @Override
123    public double getMean() {
124        return mean;
125    }
126
127    /**
128     * {@inheritDoc}
129     *
130     * For standard deviation parameter {@code s}, the variance is {@code s^2}.
131     */
132    @Override
133    public double getVariance() {
134        final double s = getStandardDeviation();
135        return s * s;
136    }
137
138    /**
139     * {@inheritDoc}
140     *
141     * The lower bound of the support is always negative infinity
142     * no matter the parameters.
143     *
144     * @return lower bound of the support (always
145     * {@code Double.NEGATIVE_INFINITY})
146     */
147    @Override
148    public double getSupportLowerBound() {
149        return Double.NEGATIVE_INFINITY;
150    }
151
152    /**
153     * {@inheritDoc}
154     *
155     * The upper bound of the support is always positive infinity
156     * no matter the parameters.
157     *
158     * @return upper bound of the support (always
159     * {@code Double.POSITIVE_INFINITY})
160     */
161    @Override
162    public double getSupportUpperBound() {
163        return Double.POSITIVE_INFINITY;
164    }
165
166    /**
167     * {@inheritDoc}
168     *
169     * The support of this distribution is connected.
170     *
171     * @return {@code true}
172     */
173    @Override
174    public boolean isSupportConnected() {
175        return true;
176    }
177
178    /** {@inheritDoc} */
179    @Override
180    public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
181        // Gaussian distribution sampler.
182        return new GaussianSampler(new ZigguratNormalizedGaussianSampler(rng),
183                                   mean, standardDeviation)::sample;
184    }
185}