View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.analysis.function;
19  
20  import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
21  import org.apache.commons.math3.analysis.FunctionUtils;
22  import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
23  import org.apache.commons.math3.analysis.UnivariateFunction;
24  import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
25  import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
26  import org.apache.commons.math3.exception.DimensionMismatchException;
27  import org.apache.commons.math3.exception.NullArgumentException;
28  import org.apache.commons.math3.exception.OutOfRangeException;
29  import org.apache.commons.math3.util.FastMath;
30  
31  /**
32   * <a href="http://en.wikipedia.org/wiki/Logit">
33   *  Logit</a> function.
34   * It is the inverse of the {@link Sigmoid sigmoid} function.
35   *
36   * @since 3.0
37   * @version $Id: Logit.java 1391927 2012-09-30 00:03:30Z erans $
38   */
39  public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
40      /** Lower bound. */
41      private final double lo;
42      /** Higher bound. */
43      private final double hi;
44  
45      /**
46       * Usual logit function, where the lower bound is 0 and the higher
47       * bound is 1.
48       */
49      public Logit() {
50          this(0, 1);
51      }
52  
53      /**
54       * Logit function.
55       *
56       * @param lo Lower bound of the function domain.
57       * @param hi Higher bound of the function domain.
58       */
59      public Logit(double lo,
60                   double hi) {
61          this.lo = lo;
62          this.hi = hi;
63      }
64  
65      /** {@inheritDoc} */
66      public double value(double x)
67          throws OutOfRangeException {
68          return value(x, lo, hi);
69      }
70  
71      /** {@inheritDoc}
72       * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
73       */
74      @Deprecated
75      public UnivariateFunction derivative() {
76          return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
77      }
78  
79      /**
80       * Parametric function where the input array contains the parameters of
81       * the logit function, ordered as follows:
82       * <ul>
83       *  <li>Lower bound</li>
84       *  <li>Higher bound</li>
85       * </ul>
86       */
87      public static class Parametric implements ParametricUnivariateFunction {
88          /**
89           * Computes the value of the logit at {@code x}.
90           *
91           * @param x Value for which the function must be computed.
92           * @param param Values of lower bound and higher bounds.
93           * @return the value of the function.
94           * @throws NullArgumentException if {@code param} is {@code null}.
95           * @throws DimensionMismatchException if the size of {@code param} is
96           * not 2.
97           */
98          public double value(double x, double ... param)
99              throws NullArgumentException,
100                    DimensionMismatchException {
101             validateParameters(param);
102             return Logit.value(x, param[0], param[1]);
103         }
104 
105         /**
106          * Computes the value of the gradient at {@code x}.
107          * The components of the gradient vector are the partial
108          * derivatives of the function with respect to each of the
109          * <em>parameters</em> (lower bound and higher bound).
110          *
111          * @param x Value at which the gradient must be computed.
112          * @param param Values for lower and higher bounds.
113          * @return the gradient vector at {@code x}.
114          * @throws NullArgumentException if {@code param} is {@code null}.
115          * @throws DimensionMismatchException if the size of {@code param} is
116          * not 2.
117          */
118         public double[] gradient(double x, double ... param)
119             throws NullArgumentException,
120                    DimensionMismatchException {
121             validateParameters(param);
122 
123             final double lo = param[0];
124             final double hi = param[1];
125 
126             return new double[] { 1 / (lo - x), 1 / (hi - x) };
127         }
128 
129         /**
130          * Validates parameters to ensure they are appropriate for the evaluation of
131          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
132          * methods.
133          *
134          * @param param Values for lower and higher bounds.
135          * @throws NullArgumentException if {@code param} is {@code null}.
136          * @throws DimensionMismatchException if the size of {@code param} is
137          * not 2.
138          */
139         private void validateParameters(double[] param)
140             throws NullArgumentException,
141                    DimensionMismatchException {
142             if (param == null) {
143                 throw new NullArgumentException();
144             }
145             if (param.length != 2) {
146                 throw new DimensionMismatchException(param.length, 2);
147             }
148         }
149     }
150 
151     /**
152      * @param x Value at which to compute the logit.
153      * @param lo Lower bound.
154      * @param hi Higher bound.
155      * @return the value of the logit function at {@code x}.
156      * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
157      */
158     private static double value(double x,
159                                 double lo,
160                                 double hi)
161         throws OutOfRangeException {
162         if (x < lo || x > hi) {
163             throw new OutOfRangeException(x, lo, hi);
164         }
165         return FastMath.log((x - lo) / (hi - x));
166     }
167 
168     /** {@inheritDoc}
169      * @since 3.1
170      * @exception OutOfRangeException if parameter is outside of function domain
171      */
172     public DerivativeStructure value(final DerivativeStructure t)
173         throws OutOfRangeException {
174         final double x = t.getValue();
175         if (x < lo || x > hi) {
176             throw new OutOfRangeException(x, lo, hi);
177         }
178         double[] f = new double[t.getOrder() + 1];
179 
180         // function value
181         f[0] = FastMath.log((x - lo) / (hi - x));
182 
183         if (Double.isInfinite(f[0])) {
184 
185             if (f.length > 1) {
186                 f[1] = Double.POSITIVE_INFINITY;
187             }
188             // fill the array with infinities
189             // (for x close to lo the signs will flip between -inf and +inf,
190             //  for x close to hi the signs will always be +inf)
191             // this is probably overkill, since the call to compose at the end
192             // of the method will transform most infinities into NaN ...
193             for (int i = 2; i < f.length; ++i) {
194                 f[i] = f[i - 2];
195             }
196 
197         } else {
198 
199             // function derivatives
200             final double invL = 1.0 / (x - lo);
201             double xL = invL;
202             final double invH = 1.0 / (hi - x);
203             double xH = invH;
204             for (int i = 1; i < f.length; ++i) {
205                 f[i] = xL + xH;
206                 xL  *= -i * invL;
207                 xH  *=  i * invH;
208             }
209         }
210 
211         return t.compose(f);
212     }
213 }