Logit.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.analysis.function;

  18. import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
  19. import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
  20. import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
  21. import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
  22. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  23. import org.apache.commons.math4.legacy.exception.OutOfRangeException;
  24. import org.apache.commons.math4.core.jdkmath.JdkMath;

  25. /**
  26.  * <a href="http://en.wikipedia.org/wiki/Logit">
  27.  *  Logit</a> function.
  28.  * It is the inverse of the {@link Sigmoid sigmoid} function.
  29.  *
  30.  * @since 3.0
  31.  */
  32. public class Logit implements UnivariateDifferentiableFunction {
  33.     /** Lower bound. */
  34.     private final double lo;
  35.     /** Higher bound. */
  36.     private final double hi;

  37.     /**
  38.      * Usual logit function, where the lower bound is 0 and the higher
  39.      * bound is 1.
  40.      */
  41.     public Logit() {
  42.         this(0, 1);
  43.     }

  44.     /**
  45.      * Logit function.
  46.      *
  47.      * @param lo Lower bound of the function domain.
  48.      * @param hi Higher bound of the function domain.
  49.      */
  50.     public Logit(double lo,
  51.                  double hi) {
  52.         this.lo = lo;
  53.         this.hi = hi;
  54.     }

  55.     /** {@inheritDoc} */
  56.     @Override
  57.     public double value(double x)
  58.         throws OutOfRangeException {
  59.         return value(x, lo, hi);
  60.     }

  61.     /**
  62.      * Parametric function where the input array contains the parameters of
  63.      * the logit function. Ordered as follows:
  64.      * <ul>
  65.      *  <li>Lower bound</li>
  66.      *  <li>Higher bound</li>
  67.      * </ul>
  68.      */
  69.     public static class Parametric implements ParametricUnivariateFunction {
  70.         /**
  71.          * Computes the value of the logit at {@code x}.
  72.          *
  73.          * @param x Value for which the function must be computed.
  74.          * @param param Values of lower bound and higher bounds.
  75.          * @return the value of the function.
  76.          * @throws NullArgumentException if {@code param} is {@code null}.
  77.          * @throws DimensionMismatchException if the size of {@code param} is
  78.          * not 2.
  79.          */
  80.         @Override
  81.         public double value(double x, double ... param)
  82.             throws NullArgumentException,
  83.                    DimensionMismatchException {
  84.             validateParameters(param);
  85.             return Logit.value(x, param[0], param[1]);
  86.         }

  87.         /**
  88.          * Computes the value of the gradient at {@code x}.
  89.          * The components of the gradient vector are the partial
  90.          * derivatives of the function with respect to each of the
  91.          * <em>parameters</em> (lower bound and higher bound).
  92.          *
  93.          * @param x Value at which the gradient must be computed.
  94.          * @param param Values for lower and higher bounds.
  95.          * @return the gradient vector at {@code x}.
  96.          * @throws NullArgumentException if {@code param} is {@code null}.
  97.          * @throws DimensionMismatchException if the size of {@code param} is
  98.          * not 2.
  99.          */
  100.         @Override
  101.         public double[] gradient(double x, double ... param)
  102.             throws NullArgumentException,
  103.                    DimensionMismatchException {
  104.             validateParameters(param);

  105.             final double lo = param[0];
  106.             final double hi = param[1];

  107.             return new double[] { 1 / (lo - x), 1 / (hi - x) };
  108.         }

  109.         /**
  110.          * Validates parameters to ensure they are appropriate for the evaluation of
  111.          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
  112.          * methods.
  113.          *
  114.          * @param param Values for lower and higher bounds.
  115.          * @throws NullArgumentException if {@code param} is {@code null}.
  116.          * @throws DimensionMismatchException if the size of {@code param} is
  117.          * not 2.
  118.          */
  119.         private void validateParameters(double[] param)
  120.             throws NullArgumentException,
  121.                    DimensionMismatchException {
  122.             if (param == null) {
  123.                 throw new NullArgumentException();
  124.             }
  125.             if (param.length != 2) {
  126.                 throw new DimensionMismatchException(param.length, 2);
  127.             }
  128.         }
  129.     }

  130.     /**
  131.      * @param x Value at which to compute the logit.
  132.      * @param lo Lower bound.
  133.      * @param hi Higher bound.
  134.      * @return the value of the logit function at {@code x}.
  135.      * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
  136.      */
  137.     private static double value(double x,
  138.                                 double lo,
  139.                                 double hi)
  140.         throws OutOfRangeException {
  141.         if (x < lo || x > hi) {
  142.             throw new OutOfRangeException(x, lo, hi);
  143.         }
  144.         return JdkMath.log((x - lo) / (hi - x));
  145.     }

  146.     /** {@inheritDoc}
  147.      * @since 3.1
  148.      * @exception OutOfRangeException if parameter is outside of function domain
  149.      */
  150.     @Override
  151.     public DerivativeStructure value(final DerivativeStructure t)
  152.         throws OutOfRangeException {
  153.         final double x = t.getValue();
  154.         if (x < lo || x > hi) {
  155.             throw new OutOfRangeException(x, lo, hi);
  156.         }
  157.         double[] f = new double[t.getOrder() + 1];

  158.         // function value
  159.         f[0] = JdkMath.log((x - lo) / (hi - x));

  160.         if (Double.isInfinite(f[0])) {

  161.             if (f.length > 1) {
  162.                 f[1] = Double.POSITIVE_INFINITY;
  163.             }
  164.             // fill the array with infinities
  165.             // (for x close to lo the signs will flip between -inf and +inf,
  166.             //  for x close to hi the signs will always be +inf)
  167.             // this is probably overkill, since the call to compose at the end
  168.             // of the method will transform most infinities into NaN ...
  169.             for (int i = 2; i < f.length; ++i) {
  170.                 f[i] = f[i - 2];
  171.             }
  172.         } else {

  173.             // function derivatives
  174.             final double invL = 1.0 / (x - lo);
  175.             double xL = invL;
  176.             final double invH = 1.0 / (hi - x);
  177.             double xH = invH;
  178.             for (int i = 1; i < f.length; ++i) {
  179.                 f[i] = xL + xH;
  180.                 xL  *= -i * invL;
  181.                 xH  *=  i * invH;
  182.             }
  183.         }

  184.         return t.compose(f);
  185.     }
  186. }