View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.analysis.function;
19  
20  import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
21  import org.apache.commons.math3.analysis.FunctionUtils;
22  import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
23  import org.apache.commons.math3.analysis.UnivariateFunction;
24  import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
25  import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
26  import org.apache.commons.math3.exception.DimensionMismatchException;
27  import org.apache.commons.math3.exception.NullArgumentException;
28  import org.apache.commons.math3.exception.OutOfRangeException;
29  import org.apache.commons.math3.util.FastMath;
30  
31  /**
32   * <a href="http://en.wikipedia.org/wiki/Logit">
33   *  Logit</a> function.
34   * It is the inverse of the {@link Sigmoid sigmoid} function.
35   *
36   * @since 3.0
37   */
38  public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
39      /** Lower bound. */
40      private final double lo;
41      /** Higher bound. */
42      private final double hi;
43  
44      /**
45       * Usual logit function, where the lower bound is 0 and the higher
46       * bound is 1.
47       */
48      public Logit() {
49          this(0, 1);
50      }
51  
52      /**
53       * Logit function.
54       *
55       * @param lo Lower bound of the function domain.
56       * @param hi Higher bound of the function domain.
57       */
58      public Logit(double lo,
59                   double hi) {
60          this.lo = lo;
61          this.hi = hi;
62      }
63  
64      /** {@inheritDoc} */
65      public double value(double x)
66          throws OutOfRangeException {
67          return value(x, lo, hi);
68      }
69  
70      /** {@inheritDoc}
71       * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
72       */
73      @Deprecated
74      public UnivariateFunction derivative() {
75          return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
76      }
77  
78      /**
79       * Parametric function where the input array contains the parameters of
80       * the logit function, ordered as follows:
81       * <ul>
82       *  <li>Lower bound</li>
83       *  <li>Higher bound</li>
84       * </ul>
85       */
86      public static class Parametric implements ParametricUnivariateFunction {
87          /**
88           * Computes the value of the logit at {@code x}.
89           *
90           * @param x Value for which the function must be computed.
91           * @param param Values of lower bound and higher bounds.
92           * @return the value of the function.
93           * @throws NullArgumentException if {@code param} is {@code null}.
94           * @throws DimensionMismatchException if the size of {@code param} is
95           * not 2.
96           */
97          public double value(double x, double ... param)
98              throws NullArgumentException,
99                     DimensionMismatchException {
100             validateParameters(param);
101             return Logit.value(x, param[0], param[1]);
102         }
103 
104         /**
105          * Computes the value of the gradient at {@code x}.
106          * The components of the gradient vector are the partial
107          * derivatives of the function with respect to each of the
108          * <em>parameters</em> (lower bound and higher bound).
109          *
110          * @param x Value at which the gradient must be computed.
111          * @param param Values for lower and higher bounds.
112          * @return the gradient vector at {@code x}.
113          * @throws NullArgumentException if {@code param} is {@code null}.
114          * @throws DimensionMismatchException if the size of {@code param} is
115          * not 2.
116          */
117         public double[] gradient(double x, double ... param)
118             throws NullArgumentException,
119                    DimensionMismatchException {
120             validateParameters(param);
121 
122             final double lo = param[0];
123             final double hi = param[1];
124 
125             return new double[] { 1 / (lo - x), 1 / (hi - x) };
126         }
127 
128         /**
129          * Validates parameters to ensure they are appropriate for the evaluation of
130          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
131          * methods.
132          *
133          * @param param Values for lower and higher bounds.
134          * @throws NullArgumentException if {@code param} is {@code null}.
135          * @throws DimensionMismatchException if the size of {@code param} is
136          * not 2.
137          */
138         private void validateParameters(double[] param)
139             throws NullArgumentException,
140                    DimensionMismatchException {
141             if (param == null) {
142                 throw new NullArgumentException();
143             }
144             if (param.length != 2) {
145                 throw new DimensionMismatchException(param.length, 2);
146             }
147         }
148     }
149 
150     /**
151      * @param x Value at which to compute the logit.
152      * @param lo Lower bound.
153      * @param hi Higher bound.
154      * @return the value of the logit function at {@code x}.
155      * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
156      */
157     private static double value(double x,
158                                 double lo,
159                                 double hi)
160         throws OutOfRangeException {
161         if (x < lo || x > hi) {
162             throw new OutOfRangeException(x, lo, hi);
163         }
164         return FastMath.log((x - lo) / (hi - x));
165     }
166 
167     /** {@inheritDoc}
168      * @since 3.1
169      * @exception OutOfRangeException if parameter is outside of function domain
170      */
171     public DerivativeStructure value(final DerivativeStructure t)
172         throws OutOfRangeException {
173         final double x = t.getValue();
174         if (x < lo || x > hi) {
175             throw new OutOfRangeException(x, lo, hi);
176         }
177         double[] f = new double[t.getOrder() + 1];
178 
179         // function value
180         f[0] = FastMath.log((x - lo) / (hi - x));
181 
182         if (Double.isInfinite(f[0])) {
183 
184             if (f.length > 1) {
185                 f[1] = Double.POSITIVE_INFINITY;
186             }
187             // fill the array with infinities
188             // (for x close to lo the signs will flip between -inf and +inf,
189             //  for x close to hi the signs will always be +inf)
190             // this is probably overkill, since the call to compose at the end
191             // of the method will transform most infinities into NaN ...
192             for (int i = 2; i < f.length; ++i) {
193                 f[i] = f[i - 2];
194             }
195 
196         } else {
197 
198             // function derivatives
199             final double invL = 1.0 / (x - lo);
200             double xL = invL;
201             final double invH = 1.0 / (hi - x);
202             double xH = invH;
203             for (int i = 1; i < f.length; ++i) {
204                 f[i] = xL + xH;
205                 xL  *= -i * invL;
206                 xH  *=  i * invH;
207             }
208         }
209 
210         return t.compose(f);
211     }
212 }