001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math4.legacy.analysis.function; 019 020import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction; 021import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure; 022import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction; 023import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 024import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; 025import org.apache.commons.math4.legacy.exception.NullArgumentException; 026import org.apache.commons.math4.core.jdkmath.JdkMath; 027 028/** 029 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function"> 030 * Generalised logistic</a> function. 031 * 032 * @since 3.0 033 */ 034public class Logistic implements UnivariateDifferentiableFunction { 035 /** Lower asymptote. */ 036 private final double a; 037 /** Upper asymptote. */ 038 private final double k; 039 /** Growth rate. */ 040 private final double b; 041 /** Parameter that affects near which asymptote maximum growth occurs. */ 042 private final double oneOverN; 043 /** Parameter that affects the position of the curve along the ordinate axis. */ 044 private final double q; 045 /** Abscissa of maximum growth. */ 046 private final double m; 047 048 /** 049 * @param k If {@code b > 0}, value of the function for x going towards +∞. 050 * If {@code b < 0}, value of the function for x going towards -∞. 051 * @param m Abscissa of maximum growth. 052 * @param b Growth rate. 053 * @param q Parameter that affects the position of the curve along the 054 * ordinate axis. 055 * @param a If {@code b > 0}, value of the function for x going towards -∞. 056 * If {@code b < 0}, value of the function for x going towards +∞. 057 * @param n Parameter that affects near which asymptote the maximum 058 * growth occurs. 059 * @throws NotStrictlyPositiveException if {@code n <= 0}. 060 */ 061 public Logistic(double k, 062 double m, 063 double b, 064 double q, 065 double a, 066 double n) 067 throws NotStrictlyPositiveException { 068 if (n <= 0) { 069 throw new NotStrictlyPositiveException(n); 070 } 071 072 this.k = k; 073 this.m = m; 074 this.b = b; 075 this.q = q; 076 this.a = a; 077 oneOverN = 1 / n; 078 } 079 080 /** {@inheritDoc} */ 081 @Override 082 public double value(double x) { 083 return value(m - x, k, b, q, a, oneOverN); 084 } 085 086 /** 087 * Parametric function where the input array contains the parameters of 088 * the {@link Logistic#Logistic(double,double,double,double,double,double) 089 * logistic function}. Ordered as follows: 090 * <ul> 091 * <li>k</li> 092 * <li>m</li> 093 * <li>b</li> 094 * <li>q</li> 095 * <li>a</li> 096 * <li>n</li> 097 * </ul> 098 */ 099 public static class Parametric implements ParametricUnivariateFunction { 100 /** 101 * Computes the value of the sigmoid at {@code x}. 102 * 103 * @param x Value for which the function must be computed. 104 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 105 * {@code a} and {@code n}. 106 * @return the value of the function. 107 * @throws NullArgumentException if {@code param} is {@code null}. 108 * @throws DimensionMismatchException if the size of {@code param} is 109 * not 6. 110 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 111 */ 112 @Override 113 public double value(double x, double ... param) 114 throws NullArgumentException, 115 DimensionMismatchException, 116 NotStrictlyPositiveException { 117 validateParameters(param); 118 return Logistic.value(param[1] - x, param[0], 119 param[2], param[3], 120 param[4], 1 / param[5]); 121 } 122 123 /** 124 * Computes the value of the gradient at {@code x}. 125 * The components of the gradient vector are the partial 126 * derivatives of the function with respect to each of the 127 * <em>parameters</em>. 128 * 129 * @param x Value at which the gradient must be computed. 130 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 131 * {@code a} and {@code n}. 132 * @return the gradient vector at {@code x}. 133 * @throws NullArgumentException if {@code param} is {@code null}. 134 * @throws DimensionMismatchException if the size of {@code param} is 135 * not 6. 136 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 137 */ 138 @Override 139 public double[] gradient(double x, double ... param) 140 throws NullArgumentException, 141 DimensionMismatchException, 142 NotStrictlyPositiveException { 143 validateParameters(param); 144 145 final double b = param[2]; 146 final double q = param[3]; 147 148 final double mMinusX = param[1] - x; 149 final double oneOverN = 1 / param[5]; 150 final double exp = JdkMath.exp(b * mMinusX); 151 final double qExp = q * exp; 152 final double qExp1 = qExp + 1; 153 final double factor1 = (param[0] - param[4]) * oneOverN / JdkMath.pow(qExp1, oneOverN); 154 final double factor2 = -factor1 / qExp1; 155 156 // Components of the gradient. 157 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN); 158 final double gm = factor2 * b * qExp; 159 final double gb = factor2 * mMinusX * qExp; 160 final double gq = factor2 * exp; 161 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN); 162 final double gn = factor1 * JdkMath.log(qExp1) * oneOverN; 163 164 return new double[] { gk, gm, gb, gq, ga, gn }; 165 } 166 167 /** 168 * Validates parameters to ensure they are appropriate for the evaluation of 169 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 170 * methods. 171 * 172 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 173 * {@code a} and {@code n}. 174 * @throws NullArgumentException if {@code param} is {@code null}. 175 * @throws DimensionMismatchException if the size of {@code param} is 176 * not 6. 177 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 178 */ 179 private void validateParameters(double[] param) 180 throws NullArgumentException, 181 DimensionMismatchException, 182 NotStrictlyPositiveException { 183 if (param == null) { 184 throw new NullArgumentException(); 185 } 186 if (param.length != 6) { 187 throw new DimensionMismatchException(param.length, 6); 188 } 189 if (param[5] <= 0) { 190 throw new NotStrictlyPositiveException(param[5]); 191 } 192 } 193 } 194 195 /** 196 * @param mMinusX {@code m - x}. 197 * @param k {@code k}. 198 * @param b {@code b}. 199 * @param q {@code q}. 200 * @param a {@code a}. 201 * @param oneOverN {@code 1 / n}. 202 * @return the value of the function. 203 */ 204 private static double value(double mMinusX, 205 double k, 206 double b, 207 double q, 208 double a, 209 double oneOverN) { 210 return a + (k - a) / JdkMath.pow(1 + q * JdkMath.exp(b * mMinusX), oneOverN); 211 } 212 213 /** {@inheritDoc} 214 * @since 3.1 215 */ 216 @Override 217 public DerivativeStructure value(final DerivativeStructure t) { 218 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a); 219 } 220}