1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.analysis.function;
19
20 import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
21 import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22 import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
23 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
25 import org.apache.commons.math4.legacy.exception.NullArgumentException;
26 import org.apache.commons.math4.core.jdkmath.JdkMath;
27
28
29
30
31
32
33
34 public class Logistic implements UnivariateDifferentiableFunction {
35
36 private final double a;
37
38 private final double k;
39
40 private final double b;
41
42 private final double oneOverN;
43
44 private final double q;
45
46 private final double m;
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61 public Logistic(double k,
62 double m,
63 double b,
64 double q,
65 double a,
66 double n)
67 throws NotStrictlyPositiveException {
68 if (n <= 0) {
69 throw new NotStrictlyPositiveException(n);
70 }
71
72 this.k = k;
73 this.m = m;
74 this.b = b;
75 this.q = q;
76 this.a = a;
77 oneOverN = 1 / n;
78 }
79
80
81 @Override
82 public double value(double x) {
83 return value(m - x, k, b, q, a, oneOverN);
84 }
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99 public static class Parametric implements ParametricUnivariateFunction {
100
101
102
103
104
105
106
107
108
109
110
111
112 @Override
113 public double value(double x, double ... param)
114 throws NullArgumentException,
115 DimensionMismatchException,
116 NotStrictlyPositiveException {
117 validateParameters(param);
118 return Logistic.value(param[1] - x, param[0],
119 param[2], param[3],
120 param[4], 1 / param[5]);
121 }
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138 @Override
139 public double[] gradient(double x, double ... param)
140 throws NullArgumentException,
141 DimensionMismatchException,
142 NotStrictlyPositiveException {
143 validateParameters(param);
144
145 final double b = param[2];
146 final double q = param[3];
147
148 final double mMinusX = param[1] - x;
149 final double oneOverN = 1 / param[5];
150 final double exp = JdkMath.exp(b * mMinusX);
151 final double qExp = q * exp;
152 final double qExp1 = qExp + 1;
153 final double factor1 = (param[0] - param[4]) * oneOverN / JdkMath.pow(qExp1, oneOverN);
154 final double factor2 = -factor1 / qExp1;
155
156
157 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
158 final double gm = factor2 * b * qExp;
159 final double gb = factor2 * mMinusX * qExp;
160 final double gq = factor2 * exp;
161 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
162 final double gn = factor1 * JdkMath.log(qExp1) * oneOverN;
163
164 return new double[] { gk, gm, gb, gq, ga, gn };
165 }
166
167
168
169
170
171
172
173
174
175
176
177
178
179 private void validateParameters(double[] param)
180 throws NullArgumentException,
181 DimensionMismatchException,
182 NotStrictlyPositiveException {
183 if (param == null) {
184 throw new NullArgumentException();
185 }
186 if (param.length != 6) {
187 throw new DimensionMismatchException(param.length, 6);
188 }
189 if (param[5] <= 0) {
190 throw new NotStrictlyPositiveException(param[5]);
191 }
192 }
193 }
194
195
196
197
198
199
200
201
202
203
204 private static double value(double mMinusX,
205 double k,
206 double b,
207 double q,
208 double a,
209 double oneOverN) {
210 return a + (k - a) / JdkMath.pow(1 + q * JdkMath.exp(b * mMinusX), oneOverN);
211 }
212
213
214
215
216 @Override
217 public DerivativeStructure value(final DerivativeStructure t) {
218 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
219 }
220 }