1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.math4.legacy.analysis.function;
19
20 import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
21 import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22 import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
23 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
25 import org.apache.commons.math4.legacy.exception.NullArgumentException;
26 import org.apache.commons.math4.core.jdkmath.JdkMath;
27
28 /**
29 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
30 * Generalised logistic</a> function.
31 *
32 * @since 3.0
33 */
34 public class Logistic implements UnivariateDifferentiableFunction {
35 /** Lower asymptote. */
36 private final double a;
37 /** Upper asymptote. */
38 private final double k;
39 /** Growth rate. */
40 private final double b;
41 /** Parameter that affects near which asymptote maximum growth occurs. */
42 private final double oneOverN;
43 /** Parameter that affects the position of the curve along the ordinate axis. */
44 private final double q;
45 /** Abscissa of maximum growth. */
46 private final double m;
47
48 /**
49 * @param k If {@code b > 0}, value of the function for x going towards +∞.
50 * If {@code b < 0}, value of the function for x going towards -∞.
51 * @param m Abscissa of maximum growth.
52 * @param b Growth rate.
53 * @param q Parameter that affects the position of the curve along the
54 * ordinate axis.
55 * @param a If {@code b > 0}, value of the function for x going towards -∞.
56 * If {@code b < 0}, value of the function for x going towards +∞.
57 * @param n Parameter that affects near which asymptote the maximum
58 * growth occurs.
59 * @throws NotStrictlyPositiveException if {@code n <= 0}.
60 */
61 public Logistic(double k,
62 double m,
63 double b,
64 double q,
65 double a,
66 double n)
67 throws NotStrictlyPositiveException {
68 if (n <= 0) {
69 throw new NotStrictlyPositiveException(n);
70 }
71
72 this.k = k;
73 this.m = m;
74 this.b = b;
75 this.q = q;
76 this.a = a;
77 oneOverN = 1 / n;
78 }
79
80 /** {@inheritDoc} */
81 @Override
82 public double value(double x) {
83 return value(m - x, k, b, q, a, oneOverN);
84 }
85
86 /**
87 * Parametric function where the input array contains the parameters of
88 * the {@link Logistic#Logistic(double,double,double,double,double,double)
89 * logistic function}. Ordered as follows:
90 * <ul>
91 * <li>k</li>
92 * <li>m</li>
93 * <li>b</li>
94 * <li>q</li>
95 * <li>a</li>
96 * <li>n</li>
97 * </ul>
98 */
99 public static class Parametric implements ParametricUnivariateFunction {
100 /**
101 * Computes the value of the sigmoid at {@code x}.
102 *
103 * @param x Value for which the function must be computed.
104 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
105 * {@code a} and {@code n}.
106 * @return the value of the function.
107 * @throws NullArgumentException if {@code param} is {@code null}.
108 * @throws DimensionMismatchException if the size of {@code param} is
109 * not 6.
110 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
111 */
112 @Override
113 public double value(double x, double ... param)
114 throws NullArgumentException,
115 DimensionMismatchException,
116 NotStrictlyPositiveException {
117 validateParameters(param);
118 return Logistic.value(param[1] - x, param[0],
119 param[2], param[3],
120 param[4], 1 / param[5]);
121 }
122
123 /**
124 * Computes the value of the gradient at {@code x}.
125 * The components of the gradient vector are the partial
126 * derivatives of the function with respect to each of the
127 * <em>parameters</em>.
128 *
129 * @param x Value at which the gradient must be computed.
130 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
131 * {@code a} and {@code n}.
132 * @return the gradient vector at {@code x}.
133 * @throws NullArgumentException if {@code param} is {@code null}.
134 * @throws DimensionMismatchException if the size of {@code param} is
135 * not 6.
136 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
137 */
138 @Override
139 public double[] gradient(double x, double ... param)
140 throws NullArgumentException,
141 DimensionMismatchException,
142 NotStrictlyPositiveException {
143 validateParameters(param);
144
145 final double b = param[2];
146 final double q = param[3];
147
148 final double mMinusX = param[1] - x;
149 final double oneOverN = 1 / param[5];
150 final double exp = JdkMath.exp(b * mMinusX);
151 final double qExp = q * exp;
152 final double qExp1 = qExp + 1;
153 final double factor1 = (param[0] - param[4]) * oneOverN / JdkMath.pow(qExp1, oneOverN);
154 final double factor2 = -factor1 / qExp1;
155
156 // Components of the gradient.
157 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
158 final double gm = factor2 * b * qExp;
159 final double gb = factor2 * mMinusX * qExp;
160 final double gq = factor2 * exp;
161 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
162 final double gn = factor1 * JdkMath.log(qExp1) * oneOverN;
163
164 return new double[] { gk, gm, gb, gq, ga, gn };
165 }
166
167 /**
168 * Validates parameters to ensure they are appropriate for the evaluation of
169 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
170 * methods.
171 *
172 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
173 * {@code a} and {@code n}.
174 * @throws NullArgumentException if {@code param} is {@code null}.
175 * @throws DimensionMismatchException if the size of {@code param} is
176 * not 6.
177 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
178 */
179 private void validateParameters(double[] param)
180 throws NullArgumentException,
181 DimensionMismatchException,
182 NotStrictlyPositiveException {
183 if (param == null) {
184 throw new NullArgumentException();
185 }
186 if (param.length != 6) {
187 throw new DimensionMismatchException(param.length, 6);
188 }
189 if (param[5] <= 0) {
190 throw new NotStrictlyPositiveException(param[5]);
191 }
192 }
193 }
194
195 /**
196 * @param mMinusX {@code m - x}.
197 * @param k {@code k}.
198 * @param b {@code b}.
199 * @param q {@code q}.
200 * @param a {@code a}.
201 * @param oneOverN {@code 1 / n}.
202 * @return the value of the function.
203 */
204 private static double value(double mMinusX,
205 double k,
206 double b,
207 double q,
208 double a,
209 double oneOverN) {
210 return a + (k - a) / JdkMath.pow(1 + q * JdkMath.exp(b * mMinusX), oneOverN);
211 }
212
213 /** {@inheritDoc}
214 * @since 3.1
215 */
216 @Override
217 public DerivativeStructure value(final DerivativeStructure t) {
218 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
219 }
220 }