QuasiSigmoidDecayFunction.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet.sofm.util;

  18. import java.util.function.DoubleUnaryOperator;
  19. import java.util.function.LongToDoubleFunction;

  20. import org.apache.commons.math4.neuralnet.internal.NeuralNetException;

  21. /**
  22.  * Decay function whose shape is similar to a sigmoid.
  23.  * <br>
  24.  * Class is immutable.
  25.  *
  26.  * @since 3.3
  27.  */
  28. public class QuasiSigmoidDecayFunction implements LongToDoubleFunction {
  29.     /** Sigmoid. */
  30.     private final DoubleUnaryOperator sigmoid;
  31.     /** See {@link #value(long)}. */
  32.     private final double scale;

  33.     /**
  34.      * Creates an instance.
  35.      * The function {@code f} will have the following properties:
  36.      * <ul>
  37.      *  <li>{@code f(0) = initValue}</li>
  38.      *  <li>{@code numCall} is the inflexion point</li>
  39.      *  <li>{@code slope = f'(numCall)}</li>
  40.      * </ul>
  41.      *
  42.      * @param initValue Initial value, i.e. {@link #applyAsDouble(long) applyAsDouble(0)}.
  43.      * @param slope Value of the function derivative at {@code numCall}.
  44.      * @param numCall Inflexion point.
  45.      * @throws IllegalArgumentException if {@code initValue <= 0},
  46.      * {@code slope >= 0} or {@code numCall <= 0}.
  47.      */
  48.     public QuasiSigmoidDecayFunction(double initValue,
  49.                                      double slope,
  50.                                      long numCall) {
  51.         if (initValue <= 0) {
  52.             throw new NeuralNetException(NeuralNetException.NOT_STRICTLY_POSITIVE, initValue);
  53.         }
  54.         if (slope >= 0) {
  55.             throw new NeuralNetException(NeuralNetException.TOO_LARGE, slope, 0);
  56.         }
  57.         if (numCall <= 1) {
  58.             throw new NeuralNetException(NeuralNetException.TOO_SMALL, numCall, 1);
  59.         }

  60.         final double k = initValue;
  61.         final double m = numCall;
  62.         final double b = 4 * slope / initValue;
  63.         sigmoid = x -> k / (1 + Math.exp(b * (m - x)));

  64.         final double y0 = sigmoid.applyAsDouble(0d);
  65.         scale = k / y0;
  66.     }

  67.     /**
  68.      * Computes the value of the learning factor.
  69.      *
  70.      * @param numCall Current step of the training task.
  71.      * @return the value of the function at {@code numCall}.
  72.      */
  73.     @Override
  74.     public double applyAsDouble(long numCall) {
  75.         return scale * sigmoid.applyAsDouble((double) numCall);
  76.     }
  77. }