1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.math3.analysis.function;
19
20 import org.apache.commons.math3.analysis.FunctionUtils;
21 import org.apache.commons.math3.analysis.UnivariateFunction;
22 import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
23 import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
24 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
25 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
26 import org.apache.commons.math3.exception.NotStrictlyPositiveException;
27 import org.apache.commons.math3.exception.NullArgumentException;
28 import org.apache.commons.math3.exception.DimensionMismatchException;
29 import org.apache.commons.math3.util.FastMath;
30
31 /**
32 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
33 * Generalised logistic</a> function.
34 *
35 * @since 3.0
36 * @version $Id: Logistic.java 1391927 2012-09-30 00:03:30Z erans $
37 */
38 public class Logistic implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
39 /** Lower asymptote. */
40 private final double a;
41 /** Upper asymptote. */
42 private final double k;
43 /** Growth rate. */
44 private final double b;
45 /** Parameter that affects near which asymptote maximum growth occurs. */
46 private final double oneOverN;
47 /** Parameter that affects the position of the curve along the ordinate axis. */
48 private final double q;
49 /** Abscissa of maximum growth. */
50 private final double m;
51
52 /**
53 * @param k If {@code b > 0}, value of the function for x going towards +∞.
54 * If {@code b < 0}, value of the function for x going towards -∞.
55 * @param m Abscissa of maximum growth.
56 * @param b Growth rate.
57 * @param q Parameter that affects the position of the curve along the
58 * ordinate axis.
59 * @param a If {@code b > 0}, value of the function for x going towards -∞.
60 * If {@code b < 0}, value of the function for x going towards +∞.
61 * @param n Parameter that affects near which asymptote the maximum
62 * growth occurs.
63 * @throws NotStrictlyPositiveException if {@code n <= 0}.
64 */
65 public Logistic(double k,
66 double m,
67 double b,
68 double q,
69 double a,
70 double n)
71 throws NotStrictlyPositiveException {
72 if (n <= 0) {
73 throw new NotStrictlyPositiveException(n);
74 }
75
76 this.k = k;
77 this.m = m;
78 this.b = b;
79 this.q = q;
80 this.a = a;
81 oneOverN = 1 / n;
82 }
83
84 /** {@inheritDoc} */
85 public double value(double x) {
86 return value(m - x, k, b, q, a, oneOverN);
87 }
88
89 /** {@inheritDoc}
90 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
91 */
92 @Deprecated
93 public UnivariateFunction derivative() {
94 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
95 }
96
97 /**
98 * Parametric function where the input array contains the parameters of
99 * the logit function, ordered as follows:
100 * <ul>
101 * <li>Lower asymptote</li>
102 * <li>Higher asymptote</li>
103 * </ul>
104 */
105 public static class Parametric implements ParametricUnivariateFunction {
106 /**
107 * Computes the value of the sigmoid at {@code x}.
108 *
109 * @param x Value for which the function must be computed.
110 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
111 * {@code a} and {@code n}.
112 * @return the value of the function.
113 * @throws NullArgumentException if {@code param} is {@code null}.
114 * @throws DimensionMismatchException if the size of {@code param} is
115 * not 6.
116 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
117 */
118 public double value(double x, double ... param)
119 throws NullArgumentException,
120 DimensionMismatchException,
121 NotStrictlyPositiveException {
122 validateParameters(param);
123 return Logistic.value(param[1] - x, param[0],
124 param[2], param[3],
125 param[4], 1 / param[5]);
126 }
127
128 /**
129 * Computes the value of the gradient at {@code x}.
130 * The components of the gradient vector are the partial
131 * derivatives of the function with respect to each of the
132 * <em>parameters</em>.
133 *
134 * @param x Value at which the gradient must be computed.
135 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
136 * {@code a} and {@code n}.
137 * @return the gradient vector at {@code x}.
138 * @throws NullArgumentException if {@code param} is {@code null}.
139 * @throws DimensionMismatchException if the size of {@code param} is
140 * not 6.
141 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
142 */
143 public double[] gradient(double x, double ... param)
144 throws NullArgumentException,
145 DimensionMismatchException,
146 NotStrictlyPositiveException {
147 validateParameters(param);
148
149 final double b = param[2];
150 final double q = param[3];
151
152 final double mMinusX = param[1] - x;
153 final double oneOverN = 1 / param[5];
154 final double exp = FastMath.exp(b * mMinusX);
155 final double qExp = q * exp;
156 final double qExp1 = qExp + 1;
157 final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN);
158 final double factor2 = -factor1 / qExp1;
159
160 // Components of the gradient.
161 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
162 final double gm = factor2 * b * qExp;
163 final double gb = factor2 * mMinusX * qExp;
164 final double gq = factor2 * exp;
165 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
166 final double gn = factor1 * Math.log(qExp1) * oneOverN;
167
168 return new double[] { gk, gm, gb, gq, ga, gn };
169 }
170
171 /**
172 * Validates parameters to ensure they are appropriate for the evaluation of
173 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
174 * methods.
175 *
176 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
177 * {@code a} and {@code n}.
178 * @throws NullArgumentException if {@code param} is {@code null}.
179 * @throws DimensionMismatchException if the size of {@code param} is
180 * not 6.
181 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
182 */
183 private void validateParameters(double[] param)
184 throws NullArgumentException,
185 DimensionMismatchException,
186 NotStrictlyPositiveException {
187 if (param == null) {
188 throw new NullArgumentException();
189 }
190 if (param.length != 6) {
191 throw new DimensionMismatchException(param.length, 6);
192 }
193 if (param[5] <= 0) {
194 throw new NotStrictlyPositiveException(param[5]);
195 }
196 }
197 }
198
199 /**
200 * @param mMinusX {@code m - x}.
201 * @param k {@code k}.
202 * @param b {@code b}.
203 * @param q {@code q}.
204 * @param a {@code a}.
205 * @param oneOverN {@code 1 / n}.
206 * @return the value of the function.
207 */
208 private static double value(double mMinusX,
209 double k,
210 double b,
211 double q,
212 double a,
213 double oneOverN) {
214 return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN);
215 }
216
217 /** {@inheritDoc}
218 * @since 3.1
219 */
220 public DerivativeStructure value(final DerivativeStructure t) {
221 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
222 }
223
224 }