1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.math4.legacy.analysis.function;
19
20 import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
21 import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22 import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
23 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24 import org.apache.commons.math4.legacy.exception.NullArgumentException;
25 import org.apache.commons.math4.legacy.exception.OutOfRangeException;
26 import org.apache.commons.math4.core.jdkmath.JdkMath;
27
28 /**
29 * <a href="http://en.wikipedia.org/wiki/Logit">
30 * Logit</a> function.
31 * It is the inverse of the {@link Sigmoid sigmoid} function.
32 *
33 * @since 3.0
34 */
35 public class Logit implements UnivariateDifferentiableFunction {
36 /** Lower bound. */
37 private final double lo;
38 /** Higher bound. */
39 private final double hi;
40
41 /**
42 * Usual logit function, where the lower bound is 0 and the higher
43 * bound is 1.
44 */
45 public Logit() {
46 this(0, 1);
47 }
48
49 /**
50 * Logit function.
51 *
52 * @param lo Lower bound of the function domain.
53 * @param hi Higher bound of the function domain.
54 */
55 public Logit(double lo,
56 double hi) {
57 this.lo = lo;
58 this.hi = hi;
59 }
60
61 /** {@inheritDoc} */
62 @Override
63 public double value(double x)
64 throws OutOfRangeException {
65 return value(x, lo, hi);
66 }
67
68 /**
69 * Parametric function where the input array contains the parameters of
70 * the logit function. Ordered as follows:
71 * <ul>
72 * <li>Lower bound</li>
73 * <li>Higher bound</li>
74 * </ul>
75 */
76 public static class Parametric implements ParametricUnivariateFunction {
77 /**
78 * Computes the value of the logit at {@code x}.
79 *
80 * @param x Value for which the function must be computed.
81 * @param param Values of lower bound and higher bounds.
82 * @return the value of the function.
83 * @throws NullArgumentException if {@code param} is {@code null}.
84 * @throws DimensionMismatchException if the size of {@code param} is
85 * not 2.
86 */
87 @Override
88 public double value(double x, double ... param)
89 throws NullArgumentException,
90 DimensionMismatchException {
91 validateParameters(param);
92 return Logit.value(x, param[0], param[1]);
93 }
94
95 /**
96 * Computes the value of the gradient at {@code x}.
97 * The components of the gradient vector are the partial
98 * derivatives of the function with respect to each of the
99 * <em>parameters</em> (lower bound and higher bound).
100 *
101 * @param x Value at which the gradient must be computed.
102 * @param param Values for lower and higher bounds.
103 * @return the gradient vector at {@code x}.
104 * @throws NullArgumentException if {@code param} is {@code null}.
105 * @throws DimensionMismatchException if the size of {@code param} is
106 * not 2.
107 */
108 @Override
109 public double[] gradient(double x, double ... param)
110 throws NullArgumentException,
111 DimensionMismatchException {
112 validateParameters(param);
113
114 final double lo = param[0];
115 final double hi = param[1];
116
117 return new double[] { 1 / (lo - x), 1 / (hi - x) };
118 }
119
120 /**
121 * Validates parameters to ensure they are appropriate for the evaluation of
122 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
123 * methods.
124 *
125 * @param param Values for lower and higher bounds.
126 * @throws NullArgumentException if {@code param} is {@code null}.
127 * @throws DimensionMismatchException if the size of {@code param} is
128 * not 2.
129 */
130 private void validateParameters(double[] param)
131 throws NullArgumentException,
132 DimensionMismatchException {
133 if (param == null) {
134 throw new NullArgumentException();
135 }
136 if (param.length != 2) {
137 throw new DimensionMismatchException(param.length, 2);
138 }
139 }
140 }
141
142 /**
143 * @param x Value at which to compute the logit.
144 * @param lo Lower bound.
145 * @param hi Higher bound.
146 * @return the value of the logit function at {@code x}.
147 * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
148 */
149 private static double value(double x,
150 double lo,
151 double hi)
152 throws OutOfRangeException {
153 if (x < lo || x > hi) {
154 throw new OutOfRangeException(x, lo, hi);
155 }
156 return JdkMath.log((x - lo) / (hi - x));
157 }
158
159 /** {@inheritDoc}
160 * @since 3.1
161 * @exception OutOfRangeException if parameter is outside of function domain
162 */
163 @Override
164 public DerivativeStructure value(final DerivativeStructure t)
165 throws OutOfRangeException {
166 final double x = t.getValue();
167 if (x < lo || x > hi) {
168 throw new OutOfRangeException(x, lo, hi);
169 }
170 double[] f = new double[t.getOrder() + 1];
171
172 // function value
173 f[0] = JdkMath.log((x - lo) / (hi - x));
174
175 if (Double.isInfinite(f[0])) {
176
177 if (f.length > 1) {
178 f[1] = Double.POSITIVE_INFINITY;
179 }
180 // fill the array with infinities
181 // (for x close to lo the signs will flip between -inf and +inf,
182 // for x close to hi the signs will always be +inf)
183 // this is probably overkill, since the call to compose at the end
184 // of the method will transform most infinities into NaN ...
185 for (int i = 2; i < f.length; ++i) {
186 f[i] = f[i - 2];
187 }
188 } else {
189
190 // function derivatives
191 final double invL = 1.0 / (x - lo);
192 double xL = invL;
193 final double invH = 1.0 / (hi - x);
194 double xH = invH;
195 for (int i = 1; i < f.length; ++i) {
196 f[i] = xL + xH;
197 xL *= -i * invL;
198 xH *= i * invH;
199 }
200 }
201
202 return t.compose(f);
203 }
204 }