KohonenUpdateAction.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet.sofm;

  18. import java.util.Collection;
  19. import java.util.HashSet;
  20. import java.util.concurrent.atomic.AtomicLong;
  21. import java.util.function.DoubleUnaryOperator;

  22. import org.apache.commons.math4.neuralnet.DistanceMeasure;
  23. import org.apache.commons.math4.neuralnet.MapRanking;
  24. import org.apache.commons.math4.neuralnet.Network;
  25. import org.apache.commons.math4.neuralnet.Neuron;
  26. import org.apache.commons.math4.neuralnet.UpdateAction;

  27. /**
  28.  * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
  29.  * Kohonen's Self-Organizing Map</a>.
  30.  * <br>
  31.  * The {@link #update(Network,double[]) update} method modifies the
  32.  * features {@code w} of the "winning" neuron and its neighbours
  33.  * according to the following rule:
  34.  * <code>
 *  w<sub>new</sub> = w<sub>old</sub> + &alpha; e<sup>(-d&sup2; / (2 &sigma;&sup2;))</sup> * (sample - w<sub>old</sub>)
  36.  * </code>
  37.  * where
  38.  * <ul>
  39.  *  <li>&alpha; is the current <em>learning rate</em>, </li>
  40.  *  <li>&sigma; is the current <em>neighbourhood size</em>, and</li>
  41.  *  <li>{@code d} is the number of links to traverse in order to reach
  42.  *   the neuron from the winning neuron.</li>
  43.  * </ul>
  44.  * <br>
  45.  * This class is thread-safe as long as the arguments passed to the
  46.  * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
  47.  * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
  48.  * classes.
  49.  * <br>
  50.  * Each call to the {@link #update(Network,double[]) update} method
  51.  * will increment the internal counter used to compute the current
  52.  * values for
  53.  * <ul>
  54.  *  <li>the <em>learning rate</em>, and</li>
  55.  *  <li>the <em>neighbourhood size</em>.</li>
  56.  * </ul>
  57.  * Consequently, the function instances that compute those values (passed
  58.  * to the constructor of this class) must take into account whether this
  59.  * class's instance will be shared by multiple threads, as this will impact
  60.  * the training process.
  61.  *
  62.  * @since 3.3
  63.  */
  64. public class KohonenUpdateAction implements UpdateAction {
  65.     /** Distance function. */
  66.     private final DistanceMeasure distance;
  67.     /** Learning factor update function. */
  68.     private final LearningFactorFunction learningFactor;
  69.     /** Neighbourhood size update function. */
  70.     private final NeighbourhoodSizeFunction neighbourhoodSize;
  71.     /** Number of calls to {@link #update(Network,double[])}. */
  72.     private final AtomicLong numberOfCalls = new AtomicLong(0);

  73.     /**
  74.      * @param distance Distance function.
  75.      * @param learningFactor Learning factor update function.
  76.      * @param neighbourhoodSize Neighbourhood size update function.
  77.      */
  78.     public KohonenUpdateAction(DistanceMeasure distance,
  79.                                LearningFactorFunction learningFactor,
  80.                                NeighbourhoodSizeFunction neighbourhoodSize) {
  81.         this.distance = distance;
  82.         this.learningFactor = learningFactor;
  83.         this.neighbourhoodSize = neighbourhoodSize;
  84.     }

  85.     /**
  86.      * {@inheritDoc}
  87.      */
  88.     @Override
  89.     public void update(Network net,
  90.                        double[] features) {
  91.         final long numCalls = numberOfCalls.incrementAndGet() - 1;
  92.         final double currentLearning = learningFactor.value(numCalls);
  93.         final Neuron best = findAndUpdateBestNeuron(net,
  94.                                                     features,
  95.                                                     currentLearning);

  96.         final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
  97.         // The farther away the neighbour is from the winning neuron, the
  98.         // smaller the learning rate will become.
  99.         final Gaussian neighbourhoodDecay
  100.             = new Gaussian(currentLearning, currentNeighbourhood);

  101.         if (currentNeighbourhood > 0) {
  102.             // Initial set of neurons only contains the winning neuron.
  103.             Collection<Neuron> neighbours = new HashSet<>();
  104.             neighbours.add(best);
  105.             // Winning neuron must be excluded from the neighbours.
  106.             final HashSet<Neuron> exclude = new HashSet<>();
  107.             exclude.add(best);

  108.             int radius = 1;
  109.             do {
  110.                 // Retrieve immediate neighbours of the current set of neurons.
  111.                 neighbours = net.getNeighbours(neighbours, exclude);

  112.                 // Update all the neighbours.
  113.                 for (final Neuron n : neighbours) {
  114.                     updateNeighbouringNeuron(n, features, neighbourhoodDecay.applyAsDouble(radius));
  115.                 }

  116.                 // Add the neighbours to the exclude list so that they will
  117.                 // not be updated more than once per training step.
  118.                 exclude.addAll(neighbours);
  119.                 ++radius;
  120.             } while (radius <= currentNeighbourhood);
  121.         }
  122.     }

  123.     /**
  124.      * Retrieves the number of calls to the {@link #update(Network,double[]) update}
  125.      * method.
  126.      *
  127.      * @return the current number of calls.
  128.      */
  129.     public long getNumberOfCalls() {
  130.         return numberOfCalls.get();
  131.     }

  132.     /**
  133.      * Tries to update a neuron.
  134.      *
  135.      * @param n Neuron to be updated.
  136.      * @param features Training data.
  137.      * @param learningRate Learning factor.
  138.      * @return {@code true} if the update succeeded, {@code true} if a
  139.      * concurrent update has been detected.
  140.      */
  141.     private boolean attemptNeuronUpdate(Neuron n,
  142.                                         double[] features,
  143.                                         double learningRate) {
  144.         final double[] expect = n.getFeatures();
  145.         final double[] update = computeFeatures(expect,
  146.                                                 features,
  147.                                                 learningRate);

  148.         return n.compareAndSetFeatures(expect, update);
  149.     }

  150.     /**
  151.      * Atomically updates the given neuron.
  152.      *
  153.      * @param n Neuron to be updated.
  154.      * @param features Training data.
  155.      * @param learningRate Learning factor.
  156.      */
  157.     private void updateNeighbouringNeuron(Neuron n,
  158.                                           double[] features,
  159.                                           double learningRate) {
  160.         while (true) {
  161.             if (attemptNeuronUpdate(n, features, learningRate)) {
  162.                 break;
  163.             }
  164.         }
  165.     }

  166.     /**
  167.      * Searches for the neuron whose features are closest to the given
  168.      * sample, and atomically updates its features.
  169.      *
  170.      * @param net Network.
  171.      * @param features Sample data.
  172.      * @param learningRate Current learning factor.
  173.      * @return the winning neuron.
  174.      */
  175.     private Neuron findAndUpdateBestNeuron(Network net,
  176.                                            double[] features,
  177.                                            double learningRate) {
  178.         final MapRanking rank = new MapRanking(net, distance);

  179.         while (true) {
  180.             final Neuron best = rank.rank(features, 1).get(0);

  181.             if (attemptNeuronUpdate(best, features, learningRate)) {
  182.                 return best;
  183.             }

  184.             // If another thread modified the state of the winning neuron,
  185.             // it may not be the best match anymore for the given training
  186.             // sample: Hence, the winner search is performed again.
  187.         }
  188.     }

  189.     /**
  190.      * Computes the new value of the features set.
  191.      *
  192.      * @param current Current values of the features.
  193.      * @param sample Training data.
  194.      * @param learningRate Learning factor.
  195.      * @return the new values for the features.
  196.      */
  197.     private double[] computeFeatures(double[] current,
  198.                                      double[] sample,
  199.                                      double learningRate) {
  200.         final int len = current.length;
  201.         final double[] r = new double[len];
  202.         for (int i = 0; i < len; i++) {
  203.             final double c = current[i];
  204.             final double s = sample[i];
  205.             r[i] = c + learningRate * (s - c);
  206.         }
  207.         return r;
  208.     }

  209.     /**
  210.      * Gaussian function with zero mean.
  211.      */
  212.     private static class Gaussian implements DoubleUnaryOperator {
  213.         /** Inverse of twice the square of the standard deviation. */
  214.         private final double i2s2;
  215.         /** Normalization factor. */
  216.         private final double norm;

  217.         /**
  218.          * @param norm Normalization factor.
  219.          * @param sigma Standard deviation.
  220.          */
  221.         Gaussian(double norm,
  222.                  double sigma) {
  223.             this.norm = norm;
  224.             i2s2 = 1d / (2 * sigma * sigma);
  225.         }

  226.         @Override
  227.         public double applyAsDouble(double x) {
  228.             return norm * Math.exp(-x * x * i2s2);
  229.         }
  230.     }
  231. }