View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.analysis.function;
19  
20  import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
21  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22  import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
23  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24  import org.apache.commons.math4.legacy.exception.NullArgumentException;
25  import org.apache.commons.math4.legacy.exception.OutOfRangeException;
26  import org.apache.commons.math4.core.jdkmath.JdkMath;
27  
28  /**
29   * <a href="http://en.wikipedia.org/wiki/Logit">
30   *  Logit</a> function.
31   * It is the inverse of the {@link Sigmoid sigmoid} function.
32   *
33   * @since 3.0
34   */
35  public class Logit implements UnivariateDifferentiableFunction {
36      /** Lower bound. */
37      private final double lo;
38      /** Higher bound. */
39      private final double hi;
40  
41      /**
42       * Usual logit function, where the lower bound is 0 and the higher
43       * bound is 1.
44       */
45      public Logit() {
46          this(0, 1);
47      }
48  
49      /**
50       * Logit function.
51       *
52       * @param lo Lower bound of the function domain.
53       * @param hi Higher bound of the function domain.
54       */
55      public Logit(double lo,
56                   double hi) {
57          this.lo = lo;
58          this.hi = hi;
59      }
60  
61      /** {@inheritDoc} */
62      @Override
63      public double value(double x)
64          throws OutOfRangeException {
65          return value(x, lo, hi);
66      }
67  
68      /**
69       * Parametric function where the input array contains the parameters of
70       * the logit function. Ordered as follows:
71       * <ul>
72       *  <li>Lower bound</li>
73       *  <li>Higher bound</li>
74       * </ul>
75       */
76      public static class Parametric implements ParametricUnivariateFunction {
77          /**
78           * Computes the value of the logit at {@code x}.
79           *
80           * @param x Value for which the function must be computed.
81           * @param param Values of lower bound and higher bounds.
82           * @return the value of the function.
83           * @throws NullArgumentException if {@code param} is {@code null}.
84           * @throws DimensionMismatchException if the size of {@code param} is
85           * not 2.
86           */
87          @Override
88          public double value(double x, double ... param)
89              throws NullArgumentException,
90                     DimensionMismatchException {
91              validateParameters(param);
92              return Logit.value(x, param[0], param[1]);
93          }
94  
95          /**
96           * Computes the value of the gradient at {@code x}.
97           * The components of the gradient vector are the partial
98           * derivatives of the function with respect to each of the
99           * <em>parameters</em> (lower bound and higher bound).
100          *
101          * @param x Value at which the gradient must be computed.
102          * @param param Values for lower and higher bounds.
103          * @return the gradient vector at {@code x}.
104          * @throws NullArgumentException if {@code param} is {@code null}.
105          * @throws DimensionMismatchException if the size of {@code param} is
106          * not 2.
107          */
108         @Override
109         public double[] gradient(double x, double ... param)
110             throws NullArgumentException,
111                    DimensionMismatchException {
112             validateParameters(param);
113 
114             final double lo = param[0];
115             final double hi = param[1];
116 
117             return new double[] { 1 / (lo - x), 1 / (hi - x) };
118         }
119 
120         /**
121          * Validates parameters to ensure they are appropriate for the evaluation of
122          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
123          * methods.
124          *
125          * @param param Values for lower and higher bounds.
126          * @throws NullArgumentException if {@code param} is {@code null}.
127          * @throws DimensionMismatchException if the size of {@code param} is
128          * not 2.
129          */
130         private void validateParameters(double[] param)
131             throws NullArgumentException,
132                    DimensionMismatchException {
133             if (param == null) {
134                 throw new NullArgumentException();
135             }
136             if (param.length != 2) {
137                 throw new DimensionMismatchException(param.length, 2);
138             }
139         }
140     }
141 
142     /**
143      * @param x Value at which to compute the logit.
144      * @param lo Lower bound.
145      * @param hi Higher bound.
146      * @return the value of the logit function at {@code x}.
147      * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
148      */
149     private static double value(double x,
150                                 double lo,
151                                 double hi)
152         throws OutOfRangeException {
153         if (x < lo || x > hi) {
154             throw new OutOfRangeException(x, lo, hi);
155         }
156         return JdkMath.log((x - lo) / (hi - x));
157     }
158 
159     /** {@inheritDoc}
160      * @since 3.1
161      * @exception OutOfRangeException if parameter is outside of function domain
162      */
163     @Override
164     public DerivativeStructure value(final DerivativeStructure t)
165         throws OutOfRangeException {
166         final double x = t.getValue();
167         if (x < lo || x > hi) {
168             throw new OutOfRangeException(x, lo, hi);
169         }
170         double[] f = new double[t.getOrder() + 1];
171 
172         // function value
173         f[0] = JdkMath.log((x - lo) / (hi - x));
174 
175         if (Double.isInfinite(f[0])) {
176 
177             if (f.length > 1) {
178                 f[1] = Double.POSITIVE_INFINITY;
179             }
180             // fill the array with infinities
181             // (for x close to lo the signs will flip between -inf and +inf,
182             //  for x close to hi the signs will always be +inf)
183             // this is probably overkill, since the call to compose at the end
184             // of the method will transform most infinities into NaN ...
185             for (int i = 2; i < f.length; ++i) {
186                 f[i] = f[i - 2];
187             }
188         } else {
189 
190             // function derivatives
191             final double invL = 1.0 / (x - lo);
192             double xL = invL;
193             final double invH = 1.0 / (hi - x);
194             double xH = invH;
195             for (int i = 1; i < f.length; ++i) {
196                 f[i] = xL + xH;
197                 xL  *= -i * invL;
198                 xH  *=  i * invH;
199             }
200         }
201 
202         return t.compose(f);
203     }
204 }