Logistic.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.analysis.function;

  18. import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
  19. import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
  20. import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
  21. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  22. import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
  23. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  24. import org.apache.commons.math4.core.jdkmath.JdkMath;

  25. /**
  26.  * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
  27.  *  Generalised logistic</a> function.
  28.  *
  29.  * @since 3.0
  30.  */
  31. public class Logistic implements UnivariateDifferentiableFunction {
  32.     /** Lower asymptote. */
  33.     private final double a;
  34.     /** Upper asymptote. */
  35.     private final double k;
  36.     /** Growth rate. */
  37.     private final double b;
  38.     /** Parameter that affects near which asymptote maximum growth occurs. */
  39.     private final double oneOverN;
  40.     /** Parameter that affects the position of the curve along the ordinate axis. */
  41.     private final double q;
  42.     /** Abscissa of maximum growth. */
  43.     private final double m;

  44.     /**
  45.      * @param k If {@code b > 0}, value of the function for x going towards +&infin;.
  46.      * If {@code b < 0}, value of the function for x going towards -&infin;.
  47.      * @param m Abscissa of maximum growth.
  48.      * @param b Growth rate.
  49.      * @param q Parameter that affects the position of the curve along the
  50.      * ordinate axis.
  51.      * @param a If {@code b > 0}, value of the function for x going towards -&infin;.
  52.      * If {@code b < 0}, value of the function for x going towards +&infin;.
  53.      * @param n Parameter that affects near which asymptote the maximum
  54.      * growth occurs.
  55.      * @throws NotStrictlyPositiveException if {@code n <= 0}.
  56.      */
  57.     public Logistic(double k,
  58.                     double m,
  59.                     double b,
  60.                     double q,
  61.                     double a,
  62.                     double n)
  63.         throws NotStrictlyPositiveException {
  64.         if (n <= 0) {
  65.             throw new NotStrictlyPositiveException(n);
  66.         }

  67.         this.k = k;
  68.         this.m = m;
  69.         this.b = b;
  70.         this.q = q;
  71.         this.a = a;
  72.         oneOverN = 1 / n;
  73.     }

  74.     /** {@inheritDoc} */
  75.     @Override
  76.     public double value(double x) {
  77.         return value(m - x, k, b, q, a, oneOverN);
  78.     }

  79.     /**
  80.      * Parametric function where the input array contains the parameters of
  81.      * the {@link Logistic#Logistic(double,double,double,double,double,double)
  82.      * logistic function}. Ordered as follows:
  83.      * <ul>
  84.      *  <li>k</li>
  85.      *  <li>m</li>
  86.      *  <li>b</li>
  87.      *  <li>q</li>
  88.      *  <li>a</li>
  89.      *  <li>n</li>
  90.      * </ul>
  91.      */
  92.     public static class Parametric implements ParametricUnivariateFunction {
  93.         /**
  94.          * Computes the value of the sigmoid at {@code x}.
  95.          *
  96.          * @param x Value for which the function must be computed.
  97.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  98.          * {@code a} and  {@code n}.
  99.          * @return the value of the function.
  100.          * @throws NullArgumentException if {@code param} is {@code null}.
  101.          * @throws DimensionMismatchException if the size of {@code param} is
  102.          * not 6.
  103.          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
  104.          */
  105.         @Override
  106.         public double value(double x, double ... param)
  107.             throws NullArgumentException,
  108.                    DimensionMismatchException,
  109.                    NotStrictlyPositiveException {
  110.             validateParameters(param);
  111.             return Logistic.value(param[1] - x, param[0],
  112.                                   param[2], param[3],
  113.                                   param[4], 1 / param[5]);
  114.         }

  115.         /**
  116.          * Computes the value of the gradient at {@code x}.
  117.          * The components of the gradient vector are the partial
  118.          * derivatives of the function with respect to each of the
  119.          * <em>parameters</em>.
  120.          *
  121.          * @param x Value at which the gradient must be computed.
  122.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  123.          * {@code a} and  {@code n}.
  124.          * @return the gradient vector at {@code x}.
  125.          * @throws NullArgumentException if {@code param} is {@code null}.
  126.          * @throws DimensionMismatchException if the size of {@code param} is
  127.          * not 6.
  128.          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
  129.          */
  130.         @Override
  131.         public double[] gradient(double x, double ... param)
  132.             throws NullArgumentException,
  133.                    DimensionMismatchException,
  134.                    NotStrictlyPositiveException {
  135.             validateParameters(param);

  136.             final double b = param[2];
  137.             final double q = param[3];

  138.             final double mMinusX = param[1] - x;
  139.             final double oneOverN = 1 / param[5];
  140.             final double exp = JdkMath.exp(b * mMinusX);
  141.             final double qExp = q * exp;
  142.             final double qExp1 = qExp + 1;
  143.             final double factor1 = (param[0] - param[4]) * oneOverN / JdkMath.pow(qExp1, oneOverN);
  144.             final double factor2 = -factor1 / qExp1;

  145.             // Components of the gradient.
  146.             final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
  147.             final double gm = factor2 * b * qExp;
  148.             final double gb = factor2 * mMinusX * qExp;
  149.             final double gq = factor2 * exp;
  150.             final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
  151.             final double gn = factor1 * JdkMath.log(qExp1) * oneOverN;

  152.             return new double[] { gk, gm, gb, gq, ga, gn };
  153.         }

  154.         /**
  155.          * Validates parameters to ensure they are appropriate for the evaluation of
  156.          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
  157.          * methods.
  158.          *
  159.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  160.          * {@code a} and {@code n}.
  161.          * @throws NullArgumentException if {@code param} is {@code null}.
  162.          * @throws DimensionMismatchException if the size of {@code param} is
  163.          * not 6.
  164.          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
  165.          */
  166.         private void validateParameters(double[] param)
  167.             throws NullArgumentException,
  168.                    DimensionMismatchException,
  169.                    NotStrictlyPositiveException {
  170.             if (param == null) {
  171.                 throw new NullArgumentException();
  172.             }
  173.             if (param.length != 6) {
  174.                 throw new DimensionMismatchException(param.length, 6);
  175.             }
  176.             if (param[5] <= 0) {
  177.                 throw new NotStrictlyPositiveException(param[5]);
  178.             }
  179.         }
  180.     }

  181.     /**
  182.      * @param mMinusX {@code m - x}.
  183.      * @param k {@code k}.
  184.      * @param b {@code b}.
  185.      * @param q {@code q}.
  186.      * @param a {@code a}.
  187.      * @param oneOverN {@code 1 / n}.
  188.      * @return the value of the function.
  189.      */
  190.     private static double value(double mMinusX,
  191.                                 double k,
  192.                                 double b,
  193.                                 double q,
  194.                                 double a,
  195.                                 double oneOverN) {
  196.         return a + (k - a) / JdkMath.pow(1 + q * JdkMath.exp(b * mMinusX), oneOverN);
  197.     }

  198.     /** {@inheritDoc}
  199.      * @since 3.1
  200.      */
  201.     @Override
  202.     public DerivativeStructure value(final DerivativeStructure t) {
  203.         return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
  204.     }
  205. }