FeatureInitializerFactory.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet;

  18. import java.util.function.DoubleUnaryOperator;

  19. import org.apache.commons.rng.UniformRandomProvider;
  20. import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;

  21. /**
  22.  * Creates functions that will select the initial values of a neuron's
  23.  * features.
  24.  *
  25.  * @since 3.3
  26.  */
  27. public final class FeatureInitializerFactory {
  28.     /** Class contains only static methods. */
  29.     private FeatureInitializerFactory() {}

  30.     /**
  31.      * Uniform sampling of the given range.
  32.      *
  33.      * @param min Lower bound of the range.
  34.      * @param max Upper bound of the range.
  35.      * @param rng Random number generator used to draw samples from a
  36.      * uniform distribution.
  37.      * @return an initializer such that the features will be initialized with
  38.      * values within the given range.
  39.      * @throws IllegalArgumentException if {@code min >= max}.
  40.      */
  41.     public static FeatureInitializer uniform(final UniformRandomProvider rng,
  42.                                              final double min,
  43.                                              final double max) {
  44.         return randomize(new ContinuousUniformSampler(rng, min, max),
  45.                          function(x -> 0, 0, 0));
  46.     }

  47.     /**
  48.      * Creates an initializer from a univariate function {@code f(x)}.
  49.      * The argument {@code x} is set to {@code init} at the first call
  50.      * and will be incremented at each call.
  51.      *
  52.      * @param f Function.
  53.      * @param init Initial value.
  54.      * @param inc Increment
  55.      * @return the initializer.
  56.      */
  57.     public static FeatureInitializer function(final DoubleUnaryOperator f,
  58.                                               final double init,
  59.                                               final double inc) {
  60.         return new FeatureInitializer() {
  61.             /** Argument. */
  62.             private double arg = init;

  63.             /** {@inheritDoc} */
  64.             @Override
  65.             public double value() {
  66.                 final double result = f.applyAsDouble(arg);
  67.                 arg += inc;
  68.                 return result;
  69.             }
  70.         };
  71.     }

  72.     /**
  73.      * Adds some amount of random data to the given initializer.
  74.      *
  75.      * @param random Random variable distribution sampler.
  76.      * @param orig Original initializer.
  77.      * @return an initializer whose {@link FeatureInitializer#value() value}
  78.      * method will return {@code orig.value() + random.sample()}.
  79.      */
  80.     public static FeatureInitializer randomize(final ContinuousUniformSampler random,
  81.                                                final FeatureInitializer orig) {
  82.         return new FeatureInitializer() {
  83.             /** {@inheritDoc} */
  84.             @Override
  85.             public double value() {
  86.                 return orig.value() + random.sample();
  87.             }
  88.         };
  89.     }
  90. }