001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math3.analysis.function; 019 020import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction; 021import org.apache.commons.math3.analysis.FunctionUtils; 022import org.apache.commons.math3.analysis.ParametricUnivariateFunction; 023import org.apache.commons.math3.analysis.UnivariateFunction; 024import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; 025import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; 026import org.apache.commons.math3.exception.DimensionMismatchException; 027import org.apache.commons.math3.exception.NullArgumentException; 028import org.apache.commons.math3.exception.OutOfRangeException; 029import org.apache.commons.math3.util.FastMath; 030 031/** 032 * <a href="http://en.wikipedia.org/wiki/Logit"> 033 * Logit</a> function. 034 * It is the inverse of the {@link Sigmoid sigmoid} function. 035 * 036 * @since 3.0 037 */ 038public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction { 039 /** Lower bound. */ 040 private final double lo; 041 /** Higher bound. */ 042 private final double hi; 043 044 /** 045 * Usual logit function, where the lower bound is 0 and the higher 046 * bound is 1. 047 */ 048 public Logit() { 049 this(0, 1); 050 } 051 052 /** 053 * Logit function. 054 * 055 * @param lo Lower bound of the function domain. 056 * @param hi Higher bound of the function domain. 057 */ 058 public Logit(double lo, 059 double hi) { 060 this.lo = lo; 061 this.hi = hi; 062 } 063 064 /** {@inheritDoc} */ 065 public double value(double x) 066 throws OutOfRangeException { 067 return value(x, lo, hi); 068 } 069 070 /** {@inheritDoc} 071 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)} 072 */ 073 @Deprecated 074 public UnivariateFunction derivative() { 075 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative(); 076 } 077 078 /** 079 * Parametric function where the input array contains the parameters of 080 * the logit function, ordered as follows: 081 * <ul> 082 * <li>Lower bound</li> 083 * <li>Higher bound</li> 084 * </ul> 085 */ 086 public static class Parametric implements ParametricUnivariateFunction { 087 /** 088 * Computes the value of the logit at {@code x}. 089 * 090 * @param x Value for which the function must be computed. 091 * @param param Values of lower bound and higher bounds. 092 * @return the value of the function. 093 * @throws NullArgumentException if {@code param} is {@code null}. 094 * @throws DimensionMismatchException if the size of {@code param} is 095 * not 2. 096 */ 097 public double value(double x, double ... param) 098 throws NullArgumentException, 099 DimensionMismatchException { 100 validateParameters(param); 101 return Logit.value(x, param[0], param[1]); 102 } 103 104 /** 105 * Computes the value of the gradient at {@code x}. 106 * The components of the gradient vector are the partial 107 * derivatives of the function with respect to each of the 108 * <em>parameters</em> (lower bound and higher bound). 109 * 110 * @param x Value at which the gradient must be computed. 111 * @param param Values for lower and higher bounds. 112 * @return the gradient vector at {@code x}. 113 * @throws NullArgumentException if {@code param} is {@code null}. 114 * @throws DimensionMismatchException if the size of {@code param} is 115 * not 2. 116 */ 117 public double[] gradient(double x, double ... param) 118 throws NullArgumentException, 119 DimensionMismatchException { 120 validateParameters(param); 121 122 final double lo = param[0]; 123 final double hi = param[1]; 124 125 return new double[] { 1 / (lo - x), 1 / (hi - x) }; 126 } 127 128 /** 129 * Validates parameters to ensure they are appropriate for the evaluation of 130 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 131 * methods. 132 * 133 * @param param Values for lower and higher bounds. 134 * @throws NullArgumentException if {@code param} is {@code null}. 135 * @throws DimensionMismatchException if the size of {@code param} is 136 * not 2. 137 */ 138 private void validateParameters(double[] param) 139 throws NullArgumentException, 140 DimensionMismatchException { 141 if (param == null) { 142 throw new NullArgumentException(); 143 } 144 if (param.length != 2) { 145 throw new DimensionMismatchException(param.length, 2); 146 } 147 } 148 } 149 150 /** 151 * @param x Value at which to compute the logit. 152 * @param lo Lower bound. 153 * @param hi Higher bound. 154 * @return the value of the logit function at {@code x}. 155 * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}. 156 */ 157 private static double value(double x, 158 double lo, 159 double hi) 160 throws OutOfRangeException { 161 if (x < lo || x > hi) { 162 throw new OutOfRangeException(x, lo, hi); 163 } 164 return FastMath.log((x - lo) / (hi - x)); 165 } 166 167 /** {@inheritDoc} 168 * @since 3.1 169 * @exception OutOfRangeException if parameter is outside of function domain 170 */ 171 public DerivativeStructure value(final DerivativeStructure t) 172 throws OutOfRangeException { 173 final double x = t.getValue(); 174 if (x < lo || x > hi) { 175 throw new OutOfRangeException(x, lo, hi); 176 } 177 double[] f = new double[t.getOrder() + 1]; 178 179 // function value 180 f[0] = FastMath.log((x - lo) / (hi - x)); 181 182 if (Double.isInfinite(f[0])) { 183 184 if (f.length > 1) { 185 f[1] = Double.POSITIVE_INFINITY; 186 } 187 // fill the array with infinities 188 // (for x close to lo the signs will flip between -inf and +inf, 189 // for x close to hi the signs will always be +inf) 190 // this is probably overkill, since the call to compose at the end 191 // of the method will transform most infinities into NaN ... 192 for (int i = 2; i < f.length; ++i) { 193 f[i] = f[i - 2]; 194 } 195 196 } else { 197 198 // function derivatives 199 final double invL = 1.0 / (x - lo); 200 double xL = invL; 201 final double invH = 1.0 / (hi - x); 202 double xH = invH; 203 for (int i = 1; i < f.length; ++i) { 204 f[i] = xL + xH; 205 xL *= -i * invL; 206 xH *= i * invH; 207 } 208 } 209 210 return t.compose(f); 211 } 212}