/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.legacy.analysis.function;

import java.util.Arrays;

import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.core.jdkmath.JdkMath;

/**
 * <a href="http://en.wikipedia.org/wiki/Sigmoid_function">
 *  Sigmoid</a> function.
 * It is the inverse of the {@link Logit logit} function.
 * A more flexible version, the generalised logistic, is implemented
 * by the {@link Logistic} class.
 * <p>
 * The function computed is
 * {@code lo + (hi - lo) / (1 + exp(-x))}, where {@code lo} and {@code hi}
 * are the lower and higher asymptotes.
 * </p>
 *
 * @since 3.0
 */
public class Sigmoid implements UnivariateDifferentiableFunction {
    /** Lower asymptote. */
    private final double lo;
    /** Higher asymptote. */
    private final double hi;

    /**
     * Usual sigmoid function, where the lower asymptote is 0 and the higher
     * asymptote is 1.
     */
    public Sigmoid() {
        this(0, 1);
    }

    /**
     * Sigmoid function.
     *
     * @param lo Lower asymptote.
     * @param hi Higher asymptote.
     */
    public Sigmoid(double lo,
                   double hi) {
        this.lo = lo;
        this.hi = hi;
    }

    /** {@inheritDoc} */
    @Override
    public double value(double x) {
        return value(x, lo, hi);
    }

    /**
     * Parametric function where the input array contains the parameters of
     * the {@link Sigmoid#Sigmoid(double,double) sigmoid function}. Ordered
     * as follows:
     * <ul>
     *  <li>Lower asymptote</li>
     *  <li>Higher asymptote</li>
     * </ul>
     */
    public static class Parametric implements ParametricUnivariateFunction {
        /**
         * Computes the value of the sigmoid at {@code x}.
         *
         * @param x Value for which the function must be computed.
         * @param param Values of lower asymptote and higher asymptote.
         * @return the value of the function.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws DimensionMismatchException if the size of {@code param} is
         * not 2.
         */
        @Override
        public double value(double x, double ... param)
            throws NullArgumentException,
                   DimensionMismatchException {
            validateParameters(param);
            return Sigmoid.value(x, param[0], param[1]);
        }

        /**
         * Computes the value of the gradient at {@code x}.
         * The components of the gradient vector are the partial
         * derivatives of the function with respect to each of the
         * <em>parameters</em> (lower asymptote and higher asymptote).
         * <p>
         * The returned array is {@code { 1 - s, s }} with
         * {@code s = 1 / (1 + exp(-x))}; it does not depend on the
         * parameter values, which are only validated.
         * </p>
         *
         * @param x Value at which the gradient must be computed.
         * @param param Values for lower asymptote and higher asymptote.
         * @return the gradient vector at {@code x}.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws DimensionMismatchException if the size of {@code param} is
         * not 2.
         */
        @Override
        public double[] gradient(double x, double ... param)
            throws NullArgumentException,
                   DimensionMismatchException {
            validateParameters(param);

            final double invExp1 = 1 / (1 + JdkMath.exp(-x));

            return new double[] { 1 - invExp1, invExp1 };
        }

        /**
         * Validates parameters to ensure they are appropriate for the evaluation of
         * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
         * methods.
         *
         * @param param Values for lower and higher asymptotes.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws DimensionMismatchException if the size of {@code param} is
         * not 2.
         */
        private void validateParameters(double[] param)
            throws NullArgumentException,
                   DimensionMismatchException {
            if (param == null) {
                throw new NullArgumentException();
            }
            if (param.length != 2) {
                throw new DimensionMismatchException(param.length, 2);
            }
        }
    }

    /**
     * Computes {@code lo + (hi - lo) / (1 + exp(-x))}.
     *
     * @param x Value at which to compute the sigmoid.
     * @param lo Lower asymptote.
     * @param hi Higher asymptote.
     * @return the value of the sigmoid function at {@code x}.
     */
    private static double value(double x,
                                double lo,
                                double hi) {
        return lo + (hi - lo) / (1 + JdkMath.exp(-x));
    }

    /**
     * {@inheritDoc}
     * <p>
     * The value and all derivatives up to the order of {@code t} are
     * computed and then composed with {@code t}. When {@code exp(-x)}
     * overflows to infinity, the value is set to the lower asymptote and
     * all derivatives to zero.
     * </p>
     *
     * @since 3.1
     */
    @Override
    public DerivativeStructure value(final DerivativeStructure t)
        throws DimensionMismatchException {

        double[] f = new double[t.getOrder() + 1];
        final double exp = JdkMath.exp(-t.getValue());
        if (Double.isInfinite(exp)) {

            // special handling near lower boundary, to avoid NaN
            f[0] = lo;
            Arrays.fill(f, 1, f.length, 0.0);
        } else {

            // the nth order derivative of sigmoid has the form:
            // d^n(sigmoid(x))/dx^n = P_n(exp(-x)) / (1+exp(-x))^(n+1)
            // where P_n(t) is a degree n polynomial with normalized higher term
            // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t...
            // the general recurrence relation for P_n is:
            // P_n(t) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t)
            final double[] p = new double[f.length];

            // Evaluating P_n(exp) and (1+exp)^-(n+1) separately yields
            // Infinity * 0 = NaN for large (but finite) exp. Instead, with
            //   inv = 1 / (1 + exp)   and   ratio = exp / (1 + exp),
            // both in [0, 1], the quotient is rewritten as
            //   P_n(exp) / (1+exp)^(n+1) = inv * sum_k p[k] ratio^k inv^(n-k)
            // which can neither overflow nor produce 0 * Infinity.
            final double inv = 1 / (1 + exp);
            final double ratio = exp * inv;
            final double scale = (hi - lo) * inv;
            for (int n = 0; n < f.length; ++n) {

                // update and evaluate polynomial P_n(t);
                // p[k] is read for the evaluation before the coefficient
                // below it is overwritten with its value for the next order
                double v = 0;
                double invPow = 1; // inv^(n-k)
                p[n] = 1;
                for (int k = n; k >= 0; --k) {
                    v = v * ratio + p[k] * invPow;
                    invPow *= inv;
                    if (k > 1) {
                        p[k - 1] = (n - k + 2) * p[k - 2] - (k - 1) * p[k - 1];
                    } else {
                        p[0] = 0;
                    }
                }

                f[n] = scale * v;
            }

            // fix function value
            f[0] += lo;
        }

        return t.compose(f);
    }
}