```001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.analysis.function;
019
020import org.apache.commons.math4.analysis.ParametricUnivariateFunction;
021import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
022import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
023import org.apache.commons.math4.exception.DimensionMismatchException;
024import org.apache.commons.math4.exception.NullArgumentException;
025import org.apache.commons.math4.exception.OutOfRangeException;
026import org.apache.commons.math4.util.FastMath;
027
028/**
029 * <a href="http://en.wikipedia.org/wiki/Logit">
030 *  Logit</a> function.
031 * It is the inverse of the {@link Sigmoid sigmoid} function.
032 *
033 * @since 3.0
034 */
035public class Logit implements UnivariateDifferentiableFunction {
036    /** Lower bound. */
037    private final double lo;
038    /** Higher bound. */
039    private final double hi;
040
041    /**
042     * Usual logit function, where the lower bound is 0 and the higher
043     * bound is 1.
044     */
045    public Logit() {
046        this(0, 1);
047    }
048
049    /**
050     * Logit function.
051     *
052     * @param lo Lower bound of the function domain.
053     * @param hi Higher bound of the function domain.
054     */
055    public Logit(double lo,
056                 double hi) {
057        this.lo = lo;
058        this.hi = hi;
059    }
060
061    /** {@inheritDoc} */
062    @Override
063    public double value(double x)
064        throws OutOfRangeException {
065        return value(x, lo, hi);
066    }
067
068    /**
069     * Parametric function where the input array contains the parameters of
070     * the logit function, ordered as follows:
071     * <ul>
072     *  <li>Lower bound</li>
073     *  <li>Higher bound</li>
074     * </ul>
075     */
076    public static class Parametric implements ParametricUnivariateFunction {
077        /**
078         * Computes the value of the logit at {@code x}.
079         *
080         * @param x Value for which the function must be computed.
081         * @param param Values of lower bound and higher bounds.
082         * @return the value of the function.
083         * @throws NullArgumentException if {@code param} is {@code null}.
084         * @throws DimensionMismatchException if the size of {@code param} is
085         * not 2.
086         */
087        @Override
088        public double value(double x, double ... param)
089            throws NullArgumentException,
090                   DimensionMismatchException {
091            validateParameters(param);
092            return Logit.value(x, param[0], param[1]);
093        }
094
095        /**
096         * Computes the value of the gradient at {@code x}.
097         * The components of the gradient vector are the partial
098         * derivatives of the function with respect to each of the
099         * <em>parameters</em> (lower bound and higher bound).
100         *
101         * @param x Value at which the gradient must be computed.
102         * @param param Values for lower and higher bounds.
103         * @return the gradient vector at {@code x}.
104         * @throws NullArgumentException if {@code param} is {@code null}.
105         * @throws DimensionMismatchException if the size of {@code param} is
106         * not 2.
107         */
108        @Override
109        public double[] gradient(double x, double ... param)
110            throws NullArgumentException,
111                   DimensionMismatchException {
112            validateParameters(param);
113
114            final double lo = param[0];
115            final double hi = param[1];
116
117            return new double[] { 1 / (lo - x), 1 / (hi - x) };
118        }
119
120        /**
121         * Validates parameters to ensure they are appropriate for the evaluation of
123         * methods.
124         *
125         * @param param Values for lower and higher bounds.
126         * @throws NullArgumentException if {@code param} is {@code null}.
127         * @throws DimensionMismatchException if the size of {@code param} is
128         * not 2.
129         */
130        private void validateParameters(double[] param)
131            throws NullArgumentException,
132                   DimensionMismatchException {
133            if (param == null) {
134                throw new NullArgumentException();
135            }
136            if (param.length != 2) {
137                throw new DimensionMismatchException(param.length, 2);
138            }
139        }
140    }
141
142    /**
143     * @param x Value at which to compute the logit.
144     * @param lo Lower bound.
145     * @param hi Higher bound.
146     * @return the value of the logit function at {@code x}.
147     * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
148     */
149    private static double value(double x,
150                                double lo,
151                                double hi)
152        throws OutOfRangeException {
153        if (x < lo || x > hi) {
154            throw new OutOfRangeException(x, lo, hi);
155        }
156        return FastMath.log((x - lo) / (hi - x));
157    }
158
159    /** {@inheritDoc}
160     * @since 3.1
161     * @exception OutOfRangeException if parameter is outside of function domain
162     */
163    @Override
164    public DerivativeStructure value(final DerivativeStructure t)
165        throws OutOfRangeException {
166        final double x = t.getValue();
167        if (x < lo || x > hi) {
168            throw new OutOfRangeException(x, lo, hi);
169        }
170        double[] f = new double[t.getOrder() + 1];
171
172        // function value
173        f[0] = FastMath.log((x - lo) / (hi - x));
174
175        if (Double.isInfinite(f[0])) {
176
177            if (f.length > 1) {
178                f[1] = Double.POSITIVE_INFINITY;
179            }
180            // fill the array with infinities
181            // (for x close to lo the signs will flip between -inf and +inf,
182            //  for x close to hi the signs will always be +inf)
183            // this is probably overkill, since the call to compose at the end
184            // of the method will transform most infinities into NaN ...
185            for (int i = 2; i < f.length; ++i) {
186                f[i] = f[i - 2];
187            }
188
189        } else {
190
191            // function derivatives
192            final double invL = 1.0 / (x - lo);
193            double xL = invL;
194            final double invH = 1.0 / (hi - x);
195            double xH = invH;
196            for (int i = 1; i < f.length; ++i) {
197                f[i] = xL + xH;
198                xL  *= -i * invL;
199                xH  *=  i * invH;
200            }
201        }
202
203        return t.compose(f);
204    }
205}

```