/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.analysis.function;

import java.util.Arrays;

import org.apache.commons.math3.analysis.FunctionUtils;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.util.FastMath;

/**
 * <a href="http://en.wikipedia.org/wiki/Sigmoid_function">
 *  Sigmoid</a> function.
 * It is the inverse of the {@link Logit logit} function.
 * A more flexible version, the generalised logistic, is implemented
 * by the {@link Logistic} class.
 * <p>
 * The function evaluated here is
 * {@code lo + (hi - lo) / (1 + exp(-x))}, where {@code lo} and {@code hi}
 * are the two asymptotes supplied at construction time.
 * </p>
 *
 * @since 3.0
 */
public class Sigmoid implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
    /** Lower asymptote. */
    private final double lo;
    /** Higher asymptote. */
    private final double hi;

    /**
     * Builds the usual sigmoid, whose lower asymptote is 0 and whose
     * higher asymptote is 1.
     */
    public Sigmoid() {
        this(0, 1);
    }

    /**
     * Builds a sigmoid with the given asymptotes.
     *
     * @param lo Lower asymptote.
     * @param hi Higher asymptote.
     */
    public Sigmoid(double lo,
                   double hi) {
        this.lo = lo;
        this.hi = hi;
    }

    /** {@inheritDoc}
     * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
     */
    @Deprecated
    public UnivariateFunction derivative() {
        return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
    }

    /** {@inheritDoc} */
    public double value(double x) {
        return value(x, lo, hi);
    }

    /**
     * Evaluates {@code lo + (hi - lo) / (1 + exp(-x))}.
     *
     * @param x Value at which to compute the sigmoid.
     * @param lo Lower asymptote.
     * @param hi Higher asymptote.
     * @return the value of the sigmoid function at {@code x}.
     */
    private static double value(double x,
                                double lo,
                                double hi) {
        return lo + (hi - lo) / (1 + FastMath.exp(-x));
    }

    /**
     * Parametric function where the input array contains the parameters of
     * the {@link Sigmoid#Sigmoid(double,double) sigmoid function}, ordered
     * as follows:
     * <ul>
     *  <li>Lower asymptote</li>
     *  <li>Higher asymptote</li>
     * </ul>
     */
    public static class Parametric implements ParametricUnivariateFunction {
        /**
         * Computes the value of the sigmoid at {@code x}.
         *
         * @param x Value for which the function must be computed.
         * @param param Values of lower asymptote and higher asymptote.
         * @return the value of the function.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws DimensionMismatchException if the size of {@code param} is
         * not 2.
         */
        public double value(double x, double ... param)
            throws NullArgumentException,
                   DimensionMismatchException {
            validateParameters(param);

            final double lower = param[0];
            final double upper = param[1];
            return Sigmoid.value(x, lower, upper);
        }

        /**
         * Computes the value of the gradient at {@code x}.
         * The components of the gradient vector are the partial
         * derivatives of the function with respect to each of the
         * <em>parameters</em> (lower asymptote and higher asymptote).
         * The parameter values themselves are only validated; the returned
         * components are computed from {@code x} alone.
         *
         * @param x Value at which the gradient must be computed.
         * @param param Values for lower asymptote and higher asymptote.
         * @return the gradient vector at {@code x}.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws DimensionMismatchException if the size of {@code param} is
         * not 2.
         */
        public double[] gradient(double x, double ... param)
            throws NullArgumentException,
                   DimensionMismatchException {
            validateParameters(param);

            // Unit sigmoid s = 1 / (1 + exp(-x)); the function is lo * (1 - s) + hi * s.
            final double unitSigmoid = 1 / (1 + FastMath.exp(-x));
            final double wrtLower = 1 - unitSigmoid;
            final double wrtUpper = unitSigmoid;

            return new double[] { wrtLower, wrtUpper };
        }

        /**
         * Validates parameters to ensure they are appropriate for the evaluation of
         * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
         * methods.
         *
         * @param param Values for lower and higher asymptotes.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws DimensionMismatchException if the size of {@code param} is
         * not 2.
         */
        private void validateParameters(double[] param)
            throws NullArgumentException,
                   DimensionMismatchException {
            // The null check must come first: the length check dereferences the array.
            if (param == null) {
                throw new NullArgumentException();
            }

            final int actualSize = param.length;
            if (actualSize != 2) {
                throw new DimensionMismatchException(actualSize, 2);
            }
        }
    }

    /** {@inheritDoc}
     * @since 3.1
     */
    public DerivativeStructure value(final DerivativeStructure t)
        throws DimensionMismatchException {

        // derivatives[n] receives the n-th derivative of the sigmoid at t.getValue()
        double[] derivatives = new double[t.getOrder() + 1];
        final double expNegX = FastMath.exp(-t.getValue());

        if (Double.isInfinite(expNegX)) {
            // special handling near lower boundary, to avoid NaN:
            // the value is the lower asymptote and every derivative is zero
            derivatives[0] = lo;
            Arrays.fill(derivatives, 1, derivatives.length, 0.0);
            return t.compose(derivatives);
        }

        // the nth order derivative of sigmoid has the form:
        // d^n(sigmoid(x))/dx^n = P_n(exp(-x)) / (1 + exp(-x))^(n+1)
        // where P_n(t) is a degree n polynomial with normalized higher term
        // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t...
        // the general recurrence relation for P_n is:
        // P_n(t) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t)
        final double[] polynomial = new double[derivatives.length];

        final double invDenominator = 1 / (1 + expNegX);
        // accumulates (hi - lo) / (1 + exp(-x))^(n+1) as n grows
        double scale = hi - lo;

        for (int n = 0; n < derivatives.length; ++n) {

            // Turn the coefficients of P_(n-1) into those of P_n in place while
            // evaluating P_n at exp(-x) with Horner's scheme. Walking from the
            // highest degree down guarantees that polynomial[k - 2] still holds
            // the previous polynomial's coefficient when polynomial[k - 1] is
            // rewritten.
            double horner = 0;
            polynomial[n] = 1;
            for (int k = n; k >= 0; --k) {
                horner = horner * expNegX + polynomial[k];
                if (k <= 1) {
                    polynomial[0] = 0;
                } else {
                    polynomial[k - 1] = (n - k + 2) * polynomial[k - 2] - (k - 1) * polynomial[k - 1];
                }
            }

            scale *= invDenominator;
            derivatives[n] = scale * horner;

        }

        // fix function value: only the zeroth-order term is shifted by the lower asymptote
        derivatives[0] += lo;

        return t.compose(derivatives);

    }

}