1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.math4.legacy.analysis.function; 19 20 import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction; 21 import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure; 22 import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction; 23 import org.apache.commons.math4.legacy.exception.DimensionMismatchException; 24 import org.apache.commons.math4.legacy.exception.NullArgumentException; 25 import org.apache.commons.math4.legacy.exception.OutOfRangeException; 26 import org.apache.commons.math4.core.jdkmath.JdkMath; 27 28 /** 29 * <a href="http://en.wikipedia.org/wiki/Logit"> 30 * Logit</a> function. 31 * It is the inverse of the {@link Sigmoid sigmoid} function. 32 * 33 * @since 3.0 34 */ 35 public class Logit implements UnivariateDifferentiableFunction { 36 /** Lower bound. */ 37 private final double lo; 38 /** Higher bound. */ 39 private final double hi; 40 41 /** 42 * Usual logit function, where the lower bound is 0 and the higher 43 * bound is 1. 44 */ 45 public Logit() { 46 this(0, 1); 47 } 48 49 /** 50 * Logit function. 51 * 52 * @param lo Lower bound of the function domain. 53 * @param hi Higher bound of the function domain. 54 */ 55 public Logit(double lo, 56 double hi) { 57 this.lo = lo; 58 this.hi = hi; 59 } 60 61 /** {@inheritDoc} */ 62 @Override 63 public double value(double x) 64 throws OutOfRangeException { 65 return value(x, lo, hi); 66 } 67 68 /** 69 * Parametric function where the input array contains the parameters of 70 * the logit function. Ordered as follows: 71 * <ul> 72 * <li>Lower bound</li> 73 * <li>Higher bound</li> 74 * </ul> 75 */ 76 public static class Parametric implements ParametricUnivariateFunction { 77 /** 78 * Computes the value of the logit at {@code x}. 79 * 80 * @param x Value for which the function must be computed. 81 * @param param Values of lower bound and higher bounds. 82 * @return the value of the function. 83 * @throws NullArgumentException if {@code param} is {@code null}. 84 * @throws DimensionMismatchException if the size of {@code param} is 85 * not 2. 86 */ 87 @Override 88 public double value(double x, double ... param) 89 throws NullArgumentException, 90 DimensionMismatchException { 91 validateParameters(param); 92 return Logit.value(x, param[0], param[1]); 93 } 94 95 /** 96 * Computes the value of the gradient at {@code x}. 97 * The components of the gradient vector are the partial 98 * derivatives of the function with respect to each of the 99 * <em>parameters</em> (lower bound and higher bound). 100 * 101 * @param x Value at which the gradient must be computed. 102 * @param param Values for lower and higher bounds. 103 * @return the gradient vector at {@code x}. 104 * @throws NullArgumentException if {@code param} is {@code null}. 105 * @throws DimensionMismatchException if the size of {@code param} is 106 * not 2. 107 */ 108 @Override 109 public double[] gradient(double x, double ... param) 110 throws NullArgumentException, 111 DimensionMismatchException { 112 validateParameters(param); 113 114 final double lo = param[0]; 115 final double hi = param[1]; 116 117 return new double[] { 1 / (lo - x), 1 / (hi - x) }; 118 } 119 120 /** 121 * Validates parameters to ensure they are appropriate for the evaluation of 122 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 123 * methods. 124 * 125 * @param param Values for lower and higher bounds. 126 * @throws NullArgumentException if {@code param} is {@code null}. 127 * @throws DimensionMismatchException if the size of {@code param} is 128 * not 2. 129 */ 130 private void validateParameters(double[] param) 131 throws NullArgumentException, 132 DimensionMismatchException { 133 if (param == null) { 134 throw new NullArgumentException(); 135 } 136 if (param.length != 2) { 137 throw new DimensionMismatchException(param.length, 2); 138 } 139 } 140 } 141 142 /** 143 * @param x Value at which to compute the logit. 144 * @param lo Lower bound. 145 * @param hi Higher bound. 146 * @return the value of the logit function at {@code x}. 147 * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}. 148 */ 149 private static double value(double x, 150 double lo, 151 double hi) 152 throws OutOfRangeException { 153 if (x < lo || x > hi) { 154 throw new OutOfRangeException(x, lo, hi); 155 } 156 return JdkMath.log((x - lo) / (hi - x)); 157 } 158 159 /** {@inheritDoc} 160 * @since 3.1 161 * @exception OutOfRangeException if parameter is outside of function domain 162 */ 163 @Override 164 public DerivativeStructure value(final DerivativeStructure t) 165 throws OutOfRangeException { 166 final double x = t.getValue(); 167 if (x < lo || x > hi) { 168 throw new OutOfRangeException(x, lo, hi); 169 } 170 double[] f = new double[t.getOrder() + 1]; 171 172 // function value 173 f[0] = JdkMath.log((x - lo) / (hi - x)); 174 175 if (Double.isInfinite(f[0])) { 176 177 if (f.length > 1) { 178 f[1] = Double.POSITIVE_INFINITY; 179 } 180 // fill the array with infinities 181 // (for x close to lo the signs will flip between -inf and +inf, 182 // for x close to hi the signs will always be +inf) 183 // this is probably overkill, since the call to compose at the end 184 // of the method will transform most infinities into NaN ... 185 for (int i = 2; i < f.length; ++i) { 186 f[i] = f[i - 2]; 187 } 188 } else { 189 190 // function derivatives 191 final double invL = 1.0 / (x - lo); 192 double xL = invL; 193 final double invH = 1.0 / (hi - x); 194 double xH = invH; 195 for (int i = 1; i < f.length; ++i) { 196 f[i] = xL + xH; 197 xL *= -i * invL; 198 xH *= i * invH; 199 } 200 } 201 202 return t.compose(f); 203 } 204 }