001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.legacy.analysis.function;
019
020import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
021import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
022import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
023import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
024import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
025import org.apache.commons.math4.legacy.exception.NullArgumentException;
026import org.apache.commons.math4.core.jdkmath.JdkMath;
027
028/**
029 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
030 *  Generalised logistic</a> function.
031 *
032 * @since 3.0
033 */
034public class Logistic implements UnivariateDifferentiableFunction {
035    /** Lower asymptote. */
036    private final double a;
037    /** Upper asymptote. */
038    private final double k;
039    /** Growth rate. */
040    private final double b;
041    /** Parameter that affects near which asymptote maximum growth occurs. */
042    private final double oneOverN;
043    /** Parameter that affects the position of the curve along the ordinate axis. */
044    private final double q;
045    /** Abscissa of maximum growth. */
046    private final double m;
047
048    /**
049     * @param k If {@code b > 0}, value of the function for x going towards +&infin;.
050     * If {@code b < 0}, value of the function for x going towards -&infin;.
051     * @param m Abscissa of maximum growth.
052     * @param b Growth rate.
053     * @param q Parameter that affects the position of the curve along the
054     * ordinate axis.
055     * @param a If {@code b > 0}, value of the function for x going towards -&infin;.
056     * If {@code b < 0}, value of the function for x going towards +&infin;.
057     * @param n Parameter that affects near which asymptote the maximum
058     * growth occurs.
059     * @throws NotStrictlyPositiveException if {@code n <= 0}.
060     */
061    public Logistic(double k,
062                    double m,
063                    double b,
064                    double q,
065                    double a,
066                    double n)
067        throws NotStrictlyPositiveException {
068        if (n <= 0) {
069            throw new NotStrictlyPositiveException(n);
070        }
071
072        this.k = k;
073        this.m = m;
074        this.b = b;
075        this.q = q;
076        this.a = a;
077        oneOverN = 1 / n;
078    }
079
080    /** {@inheritDoc} */
081    @Override
082    public double value(double x) {
083        return value(m - x, k, b, q, a, oneOverN);
084    }
085
086    /**
087     * Parametric function where the input array contains the parameters of
088     * the {@link Logistic#Logistic(double,double,double,double,double,double)
089     * logistic function}. Ordered as follows:
090     * <ul>
091     *  <li>k</li>
092     *  <li>m</li>
093     *  <li>b</li>
094     *  <li>q</li>
095     *  <li>a</li>
096     *  <li>n</li>
097     * </ul>
098     */
099    public static class Parametric implements ParametricUnivariateFunction {
100        /**
101         * Computes the value of the sigmoid at {@code x}.
102         *
103         * @param x Value for which the function must be computed.
104         * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
105         * {@code a} and  {@code n}.
106         * @return the value of the function.
107         * @throws NullArgumentException if {@code param} is {@code null}.
108         * @throws DimensionMismatchException if the size of {@code param} is
109         * not 6.
110         * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
111         */
112        @Override
113        public double value(double x, double ... param)
114            throws NullArgumentException,
115                   DimensionMismatchException,
116                   NotStrictlyPositiveException {
117            validateParameters(param);
118            return Logistic.value(param[1] - x, param[0],
119                                  param[2], param[3],
120                                  param[4], 1 / param[5]);
121        }
122
123        /**
124         * Computes the value of the gradient at {@code x}.
125         * The components of the gradient vector are the partial
126         * derivatives of the function with respect to each of the
127         * <em>parameters</em>.
128         *
129         * @param x Value at which the gradient must be computed.
130         * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
131         * {@code a} and  {@code n}.
132         * @return the gradient vector at {@code x}.
133         * @throws NullArgumentException if {@code param} is {@code null}.
134         * @throws DimensionMismatchException if the size of {@code param} is
135         * not 6.
136         * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
137         */
138        @Override
139        public double[] gradient(double x, double ... param)
140            throws NullArgumentException,
141                   DimensionMismatchException,
142                   NotStrictlyPositiveException {
143            validateParameters(param);
144
145            final double b = param[2];
146            final double q = param[3];
147
148            final double mMinusX = param[1] - x;
149            final double oneOverN = 1 / param[5];
150            final double exp = JdkMath.exp(b * mMinusX);
151            final double qExp = q * exp;
152            final double qExp1 = qExp + 1;
153            final double factor1 = (param[0] - param[4]) * oneOverN / JdkMath.pow(qExp1, oneOverN);
154            final double factor2 = -factor1 / qExp1;
155
156            // Components of the gradient.
157            final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
158            final double gm = factor2 * b * qExp;
159            final double gb = factor2 * mMinusX * qExp;
160            final double gq = factor2 * exp;
161            final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
162            final double gn = factor1 * JdkMath.log(qExp1) * oneOverN;
163
164            return new double[] { gk, gm, gb, gq, ga, gn };
165        }
166
167        /**
168         * Validates parameters to ensure they are appropriate for the evaluation of
169         * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
170         * methods.
171         *
172         * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
173         * {@code a} and {@code n}.
174         * @throws NullArgumentException if {@code param} is {@code null}.
175         * @throws DimensionMismatchException if the size of {@code param} is
176         * not 6.
177         * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
178         */
179        private void validateParameters(double[] param)
180            throws NullArgumentException,
181                   DimensionMismatchException,
182                   NotStrictlyPositiveException {
183            if (param == null) {
184                throw new NullArgumentException();
185            }
186            if (param.length != 6) {
187                throw new DimensionMismatchException(param.length, 6);
188            }
189            if (param[5] <= 0) {
190                throw new NotStrictlyPositiveException(param[5]);
191            }
192        }
193    }
194
195    /**
196     * @param mMinusX {@code m - x}.
197     * @param k {@code k}.
198     * @param b {@code b}.
199     * @param q {@code q}.
200     * @param a {@code a}.
201     * @param oneOverN {@code 1 / n}.
202     * @return the value of the function.
203     */
204    private static double value(double mMinusX,
205                                double k,
206                                double b,
207                                double q,
208                                double a,
209                                double oneOverN) {
210        return a + (k - a) / JdkMath.pow(1 + q * JdkMath.exp(b * mMinusX), oneOverN);
211    }
212
213    /** {@inheritDoc}
214     * @since 3.1
215     */
216    @Override
217    public DerivativeStructure value(final DerivativeStructure t) {
218        return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
219    }
220}