1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.math4.legacy.analysis.function; 19 20 import java.util.Arrays; 21 22 import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction; 23 import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure; 24 import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction; 25 import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 26 import org.apache.commons.math4.legacy.exception.NullArgumentException; 27 import org.apache.commons.math4.core.jdkmath.JdkMath; 28 29 /** 30 * <a href="http://en.wikipedia.org/wiki/Sigmoid_function"> 31 * Sigmoid</a> function. 32 * It is the inverse of the {@link Logit logit} function. 33 * A more flexible version, the generalised logistic, is implemented 34 * by the {@link Logistic} class. 35 * 36 * @since 3.0 37 */ 38 public class Sigmoid implements UnivariateDifferentiableFunction { 39 /** Lower asymptote. */ 40 private final double lo; 41 /** Higher asymptote. */ 42 private final double hi; 43 44 /** 45 * Usual sigmoid function, where the lower asymptote is 0 and the higher 46 * asymptote is 1. 47 */ 48 public Sigmoid() { 49 this(0, 1); 50 } 51 52 /** 53 * Sigmoid function. 54 * 55 * @param lo Lower asymptote. 56 * @param hi Higher asymptote. 57 */ 58 public Sigmoid(double lo, 59 double hi) { 60 this.lo = lo; 61 this.hi = hi; 62 } 63 64 /** {@inheritDoc} */ 65 @Override 66 public double value(double x) { 67 return value(x, lo, hi); 68 } 69 70 /** 71 * Parametric function where the input array contains the parameters of 72 * the {@link Sigmoid#Sigmoid(double,double) sigmoid function}. Ordered 73 * as follows: 74 * <ul> 75 * <li>Lower asymptote</li> 76 * <li>Higher asymptote</li> 77 * </ul> 78 */ 79 public static class Parametric implements ParametricUnivariateFunction { 80 /** 81 * Computes the value of the sigmoid at {@code x}. 82 * 83 * @param x Value for which the function must be computed. 84 * @param param Values of lower asymptote and higher asymptote. 85 * @return the value of the function. 86 * @throws NullArgumentException if {@code param} is {@code null}. 87 * @throws DimensionMismatchException if the size of {@code param} is 88 * not 2. 89 */ 90 @Override 91 public double value(double x, double ... param) 92 throws NullArgumentException, 93 DimensionMismatchException { 94 validateParameters(param); 95 return Sigmoid.value(x, param[0], param[1]); 96 } 97 98 /** 99 * Computes the value of the gradient at {@code x}. 100 * The components of the gradient vector are the partial 101 * derivatives of the function with respect to each of the 102 * <em>parameters</em> (lower asymptote and higher asymptote). 103 * 104 * @param x Value at which the gradient must be computed. 105 * @param param Values for lower asymptote and higher asymptote. 106 * @return the gradient vector at {@code x}. 107 * @throws NullArgumentException if {@code param} is {@code null}. 108 * @throws DimensionMismatchException if the size of {@code param} is 109 * not 2. 110 */ 111 @Override 112 public double[] gradient(double x, double ... param) 113 throws NullArgumentException, 114 DimensionMismatchException { 115 validateParameters(param); 116 117 final double invExp1 = 1 / (1 + JdkMath.exp(-x)); 118 119 return new double[] { 1 - invExp1, invExp1 }; 120 } 121 122 /** 123 * Validates parameters to ensure they are appropriate for the evaluation of 124 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 125 * methods. 126 * 127 * @param param Values for lower and higher asymptotes. 128 * @throws NullArgumentException if {@code param} is {@code null}. 129 * @throws DimensionMismatchException if the size of {@code param} is 130 * not 2. 131 */ 132 private void validateParameters(double[] param) 133 throws NullArgumentException, 134 DimensionMismatchException { 135 if (param == null) { 136 throw new NullArgumentException(); 137 } 138 if (param.length != 2) { 139 throw new DimensionMismatchException(param.length, 2); 140 } 141 } 142 } 143 144 /** 145 * @param x Value at which to compute the sigmoid. 146 * @param lo Lower asymptote. 147 * @param hi Higher asymptote. 148 * @return the value of the sigmoid function at {@code x}. 149 */ 150 private static double value(double x, 151 double lo, 152 double hi) { 153 return lo + (hi - lo) / (1 + JdkMath.exp(-x)); 154 } 155 156 /** {@inheritDoc} 157 * @since 3.1 158 */ 159 @Override 160 public DerivativeStructure value(final DerivativeStructure t) 161 throws DimensionMismatchException { 162 163 double[] f = new double[t.getOrder() + 1]; 164 final double exp = JdkMath.exp(-t.getValue()); 165 if (Double.isInfinite(exp)) { 166 167 // special handling near lower boundary, to avoid NaN 168 f[0] = lo; 169 Arrays.fill(f, 1, f.length, 0.0); 170 } else { 171 172 // the nth order derivative of sigmoid has the form: 173 // dn(sigmoid(x)/dxn = P_n(exp(-x)) / (1+exp(-x))^(n+1) 174 // where P_n(t) is a degree n polynomial with normalized higher term 175 // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t... 176 // the general recurrence relation for P_n is: 177 // P_n(x) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t) 178 final double[] p = new double[f.length]; 179 180 final double inv = 1 / (1 + exp); 181 double coeff = hi - lo; 182 for (int n = 0; n < f.length; ++n) { 183 184 // update and evaluate polynomial P_n(t) 185 double v = 0; 186 p[n] = 1; 187 for (int k = n; k >= 0; --k) { 188 v = v * exp + p[k]; 189 if (k > 1) { 190 p[k - 1] = (n - k + 2) * p[k - 2] - (k - 1) * p[k - 1]; 191 } else { 192 p[0] = 0; 193 } 194 } 195 196 coeff *= inv; 197 f[n] = coeff * v; 198 } 199 200 // fix function value 201 f[0] += lo; 202 } 203 204 return t.compose(f); 205 } 206 }