001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.analysis.function;
019
020import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
021import org.apache.commons.math3.analysis.FunctionUtils;
022import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
023import org.apache.commons.math3.analysis.UnivariateFunction;
024import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
025import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
026import org.apache.commons.math3.exception.DimensionMismatchException;
027import org.apache.commons.math3.exception.NullArgumentException;
028import org.apache.commons.math3.exception.OutOfRangeException;
029import org.apache.commons.math3.util.FastMath;
030
031/**
032 * <a href="http://en.wikipedia.org/wiki/Logit">
033 *  Logit</a> function.
034 * It is the inverse of the {@link Sigmoid sigmoid} function.
035 *
036 * @since 3.0
037 */
038public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
039    /** Lower bound. */
040    private final double lo;
041    /** Higher bound. */
042    private final double hi;
043
044    /**
045     * Usual logit function, where the lower bound is 0 and the higher
046     * bound is 1.
047     */
048    public Logit() {
049        this(0, 1);
050    }
051
052    /**
053     * Logit function.
054     *
055     * @param lo Lower bound of the function domain.
056     * @param hi Higher bound of the function domain.
057     */
058    public Logit(double lo,
059                 double hi) {
060        this.lo = lo;
061        this.hi = hi;
062    }
063
064    /** {@inheritDoc} */
065    public double value(double x)
066        throws OutOfRangeException {
067        return value(x, lo, hi);
068    }
069
070    /** {@inheritDoc}
071     * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
072     */
073    @Deprecated
074    public UnivariateFunction derivative() {
075        return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
076    }
077
078    /**
079     * Parametric function where the input array contains the parameters of
080     * the logit function, ordered as follows:
081     * <ul>
082     *  <li>Lower bound</li>
083     *  <li>Higher bound</li>
084     * </ul>
085     */
086    public static class Parametric implements ParametricUnivariateFunction {
087        /**
088         * Computes the value of the logit at {@code x}.
089         *
090         * @param x Value for which the function must be computed.
091         * @param param Values of lower bound and higher bounds.
092         * @return the value of the function.
093         * @throws NullArgumentException if {@code param} is {@code null}.
094         * @throws DimensionMismatchException if the size of {@code param} is
095         * not 2.
096         */
097        public double value(double x, double ... param)
098            throws NullArgumentException,
099                   DimensionMismatchException {
100            validateParameters(param);
101            return Logit.value(x, param[0], param[1]);
102        }
103
104        /**
105         * Computes the value of the gradient at {@code x}.
106         * The components of the gradient vector are the partial
107         * derivatives of the function with respect to each of the
108         * <em>parameters</em> (lower bound and higher bound).
109         *
110         * @param x Value at which the gradient must be computed.
111         * @param param Values for lower and higher bounds.
112         * @return the gradient vector at {@code x}.
113         * @throws NullArgumentException if {@code param} is {@code null}.
114         * @throws DimensionMismatchException if the size of {@code param} is
115         * not 2.
116         */
117        public double[] gradient(double x, double ... param)
118            throws NullArgumentException,
119                   DimensionMismatchException {
120            validateParameters(param);
121
122            final double lo = param[0];
123            final double hi = param[1];
124
125            return new double[] { 1 / (lo - x), 1 / (hi - x) };
126        }
127
128        /**
129         * Validates parameters to ensure they are appropriate for the evaluation of
130         * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
131         * methods.
132         *
133         * @param param Values for lower and higher bounds.
134         * @throws NullArgumentException if {@code param} is {@code null}.
135         * @throws DimensionMismatchException if the size of {@code param} is
136         * not 2.
137         */
138        private void validateParameters(double[] param)
139            throws NullArgumentException,
140                   DimensionMismatchException {
141            if (param == null) {
142                throw new NullArgumentException();
143            }
144            if (param.length != 2) {
145                throw new DimensionMismatchException(param.length, 2);
146            }
147        }
148    }
149
150    /**
151     * @param x Value at which to compute the logit.
152     * @param lo Lower bound.
153     * @param hi Higher bound.
154     * @return the value of the logit function at {@code x}.
155     * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
156     */
157    private static double value(double x,
158                                double lo,
159                                double hi)
160        throws OutOfRangeException {
161        if (x < lo || x > hi) {
162            throw new OutOfRangeException(x, lo, hi);
163        }
164        return FastMath.log((x - lo) / (hi - x));
165    }
166
167    /** {@inheritDoc}
168     * @since 3.1
169     * @exception OutOfRangeException if parameter is outside of function domain
170     */
171    public DerivativeStructure value(final DerivativeStructure t)
172        throws OutOfRangeException {
173        final double x = t.getValue();
174        if (x < lo || x > hi) {
175            throw new OutOfRangeException(x, lo, hi);
176        }
177        double[] f = new double[t.getOrder() + 1];
178
179        // function value
180        f[0] = FastMath.log((x - lo) / (hi - x));
181
182        if (Double.isInfinite(f[0])) {
183
184            if (f.length > 1) {
185                f[1] = Double.POSITIVE_INFINITY;
186            }
187            // fill the array with infinities
188            // (for x close to lo the signs will flip between -inf and +inf,
189            //  for x close to hi the signs will always be +inf)
190            // this is probably overkill, since the call to compose at the end
191            // of the method will transform most infinities into NaN ...
192            for (int i = 2; i < f.length; ++i) {
193                f[i] = f[i - 2];
194            }
195
196        } else {
197
198            // function derivatives
199            final double invL = 1.0 / (x - lo);
200            double xL = invL;
201            final double invH = 1.0 / (hi - x);
202            double xH = invH;
203            for (int i = 1; i < f.length; ++i) {
204                f[i] = xL + xH;
205                xL  *= -i * invL;
206                xH  *=  i * invH;
207            }
208        }
209
210        return t.compose(f);
211    }
212}