001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math4.legacy.analysis.function; 019 020import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction; 021import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure; 022import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction; 023import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 024import org.apache.commons.math4.legacy.exception.NullArgumentException; 025import org.apache.commons.math4.legacy.exception.OutOfRangeException; 026import org.apache.commons.math4.core.jdkmath.JdkMath; 027 028/** 029 * <a href="http://en.wikipedia.org/wiki/Logit"> 030 * Logit</a> function. 031 * It is the inverse of the {@link Sigmoid sigmoid} function. 032 * 033 * @since 3.0 034 */ 035public class Logit implements UnivariateDifferentiableFunction { 036 /** Lower bound. */ 037 private final double lo; 038 /** Higher bound. */ 039 private final double hi; 040 041 /** 042 * Usual logit function, where the lower bound is 0 and the higher 043 * bound is 1. 044 */ 045 public Logit() { 046 this(0, 1); 047 } 048 049 /** 050 * Logit function. 051 * 052 * @param lo Lower bound of the function domain. 053 * @param hi Higher bound of the function domain. 054 */ 055 public Logit(double lo, 056 double hi) { 057 this.lo = lo; 058 this.hi = hi; 059 } 060 061 /** {@inheritDoc} */ 062 @Override 063 public double value(double x) 064 throws OutOfRangeException { 065 return value(x, lo, hi); 066 } 067 068 /** 069 * Parametric function where the input array contains the parameters of 070 * the logit function. Ordered as follows: 071 * <ul> 072 * <li>Lower bound</li> 073 * <li>Higher bound</li> 074 * </ul> 075 */ 076 public static class Parametric implements ParametricUnivariateFunction { 077 /** 078 * Computes the value of the logit at {@code x}. 079 * 080 * @param x Value for which the function must be computed. 081 * @param param Values of lower bound and higher bounds. 082 * @return the value of the function. 083 * @throws NullArgumentException if {@code param} is {@code null}. 084 * @throws DimensionMismatchException if the size of {@code param} is 085 * not 2. 086 */ 087 @Override 088 public double value(double x, double ... param) 089 throws NullArgumentException, 090 DimensionMismatchException { 091 validateParameters(param); 092 return Logit.value(x, param[0], param[1]); 093 } 094 095 /** 096 * Computes the value of the gradient at {@code x}. 097 * The components of the gradient vector are the partial 098 * derivatives of the function with respect to each of the 099 * <em>parameters</em> (lower bound and higher bound). 100 * 101 * @param x Value at which the gradient must be computed. 102 * @param param Values for lower and higher bounds. 103 * @return the gradient vector at {@code x}. 104 * @throws NullArgumentException if {@code param} is {@code null}. 105 * @throws DimensionMismatchException if the size of {@code param} is 106 * not 2. 107 */ 108 @Override 109 public double[] gradient(double x, double ... param) 110 throws NullArgumentException, 111 DimensionMismatchException { 112 validateParameters(param); 113 114 final double lo = param[0]; 115 final double hi = param[1]; 116 117 return new double[] { 1 / (lo - x), 1 / (hi - x) }; 118 } 119 120 /** 121 * Validates parameters to ensure they are appropriate for the evaluation of 122 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 123 * methods. 124 * 125 * @param param Values for lower and higher bounds. 126 * @throws NullArgumentException if {@code param} is {@code null}. 127 * @throws DimensionMismatchException if the size of {@code param} is 128 * not 2. 129 */ 130 private void validateParameters(double[] param) 131 throws NullArgumentException, 132 DimensionMismatchException { 133 if (param == null) { 134 throw new NullArgumentException(); 135 } 136 if (param.length != 2) { 137 throw new DimensionMismatchException(param.length, 2); 138 } 139 } 140 } 141 142 /** 143 * @param x Value at which to compute the logit. 144 * @param lo Lower bound. 145 * @param hi Higher bound. 146 * @return the value of the logit function at {@code x}. 147 * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}. 148 */ 149 private static double value(double x, 150 double lo, 151 double hi) 152 throws OutOfRangeException { 153 if (x < lo || x > hi) { 154 throw new OutOfRangeException(x, lo, hi); 155 } 156 return JdkMath.log((x - lo) / (hi - x)); 157 } 158 159 /** {@inheritDoc} 160 * @since 3.1 161 * @exception OutOfRangeException if parameter is outside of function domain 162 */ 163 @Override 164 public DerivativeStructure value(final DerivativeStructure t) 165 throws OutOfRangeException { 166 final double x = t.getValue(); 167 if (x < lo || x > hi) { 168 throw new OutOfRangeException(x, lo, hi); 169 } 170 double[] f = new double[t.getOrder() + 1]; 171 172 // function value 173 f[0] = JdkMath.log((x - lo) / (hi - x)); 174 175 if (Double.isInfinite(f[0])) { 176 177 if (f.length > 1) { 178 f[1] = Double.POSITIVE_INFINITY; 179 } 180 // fill the array with infinities 181 // (for x close to lo the signs will flip between -inf and +inf, 182 // for x close to hi the signs will always be +inf) 183 // this is probably overkill, since the call to compose at the end 184 // of the method will transform most infinities into NaN ... 185 for (int i = 2; i < f.length; ++i) { 186 f[i] = f[i - 2]; 187 } 188 } else { 189 190 // function derivatives 191 final double invL = 1.0 / (x - lo); 192 double xL = invL; 193 final double invH = 1.0 / (hi - x); 194 double xH = invH; 195 for (int i = 1; i < f.length; ++i) { 196 f[i] = xL + xH; 197 xL *= -i * invL; 198 xH *= i * invH; 199 } 200 } 201 202 return t.compose(f); 203 } 204}