001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math3.analysis.function; 019 020import org.apache.commons.math3.analysis.FunctionUtils; 021import org.apache.commons.math3.analysis.UnivariateFunction; 022import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction; 023import org.apache.commons.math3.analysis.ParametricUnivariateFunction; 024import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; 025import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; 026import org.apache.commons.math3.exception.NotStrictlyPositiveException; 027import org.apache.commons.math3.exception.NullArgumentException; 028import org.apache.commons.math3.exception.DimensionMismatchException; 029import org.apache.commons.math3.util.FastMath; 030 031/** 032 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function"> 033 * Generalised logistic</a> function. 034 * 035 * @since 3.0 036 */ 037public class Logistic implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction { 038 /** Lower asymptote. */ 039 private final double a; 040 /** Upper asymptote. */ 041 private final double k; 042 /** Growth rate. */ 043 private final double b; 044 /** Parameter that affects near which asymptote maximum growth occurs. */ 045 private final double oneOverN; 046 /** Parameter that affects the position of the curve along the ordinate axis. */ 047 private final double q; 048 /** Abscissa of maximum growth. */ 049 private final double m; 050 051 /** 052 * @param k If {@code b > 0}, value of the function for x going towards +∞. 053 * If {@code b < 0}, value of the function for x going towards -∞. 054 * @param m Abscissa of maximum growth. 055 * @param b Growth rate. 056 * @param q Parameter that affects the position of the curve along the 057 * ordinate axis. 058 * @param a If {@code b > 0}, value of the function for x going towards -∞. 059 * If {@code b < 0}, value of the function for x going towards +∞. 060 * @param n Parameter that affects near which asymptote the maximum 061 * growth occurs. 062 * @throws NotStrictlyPositiveException if {@code n <= 0}. 063 */ 064 public Logistic(double k, 065 double m, 066 double b, 067 double q, 068 double a, 069 double n) 070 throws NotStrictlyPositiveException { 071 if (n <= 0) { 072 throw new NotStrictlyPositiveException(n); 073 } 074 075 this.k = k; 076 this.m = m; 077 this.b = b; 078 this.q = q; 079 this.a = a; 080 oneOverN = 1 / n; 081 } 082 083 /** {@inheritDoc} */ 084 public double value(double x) { 085 return value(m - x, k, b, q, a, oneOverN); 086 } 087 088 /** {@inheritDoc} 089 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)} 090 */ 091 @Deprecated 092 public UnivariateFunction derivative() { 093 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative(); 094 } 095 096 /** 097 * Parametric function where the input array contains the parameters of 098 * the {@link Logistic#Logistic(double,double,double,double,double,double) 099 * logistic function}, ordered as follows: 100 * <ul> 101 * <li>k</li> 102 * <li>m</li> 103 * <li>b</li> 104 * <li>q</li> 105 * <li>a</li> 106 * <li>n</li> 107 * </ul> 108 */ 109 public static class Parametric implements ParametricUnivariateFunction { 110 /** 111 * Computes the value of the sigmoid at {@code x}. 112 * 113 * @param x Value for which the function must be computed. 114 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 115 * {@code a} and {@code n}. 116 * @return the value of the function. 117 * @throws NullArgumentException if {@code param} is {@code null}. 118 * @throws DimensionMismatchException if the size of {@code param} is 119 * not 6. 120 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 121 */ 122 public double value(double x, double ... param) 123 throws NullArgumentException, 124 DimensionMismatchException, 125 NotStrictlyPositiveException { 126 validateParameters(param); 127 return Logistic.value(param[1] - x, param[0], 128 param[2], param[3], 129 param[4], 1 / param[5]); 130 } 131 132 /** 133 * Computes the value of the gradient at {@code x}. 134 * The components of the gradient vector are the partial 135 * derivatives of the function with respect to each of the 136 * <em>parameters</em>. 137 * 138 * @param x Value at which the gradient must be computed. 139 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 140 * {@code a} and {@code n}. 141 * @return the gradient vector at {@code x}. 142 * @throws NullArgumentException if {@code param} is {@code null}. 143 * @throws DimensionMismatchException if the size of {@code param} is 144 * not 6. 145 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 146 */ 147 public double[] gradient(double x, double ... param) 148 throws NullArgumentException, 149 DimensionMismatchException, 150 NotStrictlyPositiveException { 151 validateParameters(param); 152 153 final double b = param[2]; 154 final double q = param[3]; 155 156 final double mMinusX = param[1] - x; 157 final double oneOverN = 1 / param[5]; 158 final double exp = FastMath.exp(b * mMinusX); 159 final double qExp = q * exp; 160 final double qExp1 = qExp + 1; 161 final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN); 162 final double factor2 = -factor1 / qExp1; 163 164 // Components of the gradient. 165 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN); 166 final double gm = factor2 * b * qExp; 167 final double gb = factor2 * mMinusX * qExp; 168 final double gq = factor2 * exp; 169 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN); 170 final double gn = factor1 * FastMath.log(qExp1) * oneOverN; 171 172 return new double[] { gk, gm, gb, gq, ga, gn }; 173 } 174 175 /** 176 * Validates parameters to ensure they are appropriate for the evaluation of 177 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 178 * methods. 179 * 180 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 181 * {@code a} and {@code n}. 182 * @throws NullArgumentException if {@code param} is {@code null}. 183 * @throws DimensionMismatchException if the size of {@code param} is 184 * not 6. 185 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 186 */ 187 private void validateParameters(double[] param) 188 throws NullArgumentException, 189 DimensionMismatchException, 190 NotStrictlyPositiveException { 191 if (param == null) { 192 throw new NullArgumentException(); 193 } 194 if (param.length != 6) { 195 throw new DimensionMismatchException(param.length, 6); 196 } 197 if (param[5] <= 0) { 198 throw new NotStrictlyPositiveException(param[5]); 199 } 200 } 201 } 202 203 /** 204 * @param mMinusX {@code m - x}. 205 * @param k {@code k}. 206 * @param b {@code b}. 207 * @param q {@code q}. 208 * @param a {@code a}. 209 * @param oneOverN {@code 1 / n}. 210 * @return the value of the function. 211 */ 212 private static double value(double mMinusX, 213 double k, 214 double b, 215 double q, 216 double a, 217 double oneOverN) { 218 return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN); 219 } 220 221 /** {@inheritDoc} 222 * @since 3.1 223 */ 224 public DerivativeStructure value(final DerivativeStructure t) { 225 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a); 226 } 227 228}