KohonenUpdateAction.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.math4.neuralnet.sofm;
- import java.util.Collection;
- import java.util.HashSet;
- import java.util.concurrent.atomic.AtomicLong;
- import java.util.function.DoubleUnaryOperator;
- import org.apache.commons.math4.neuralnet.DistanceMeasure;
- import org.apache.commons.math4.neuralnet.MapRanking;
- import org.apache.commons.math4.neuralnet.Network;
- import org.apache.commons.math4.neuralnet.Neuron;
- import org.apache.commons.math4.neuralnet.UpdateAction;
- /**
- * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
- * Kohonen's Self-Organizing Map</a>.
- * <br>
- * The {@link #update(Network,double[]) update} method modifies the
- * features {@code w} of the "winning" neuron and its neighbours
- * according to the following rule:
- * <code>
- * w<sub>new</sub> = w<sub>old</sub> + α e<sup>(-d / σ)</sup> * (sample - w<sub>old</sub>)
- * </code>
- * where
- * <ul>
- * <li>α is the current <em>learning rate</em>, </li>
- * <li>σ is the current <em>neighbourhood size</em>, and</li>
- * <li>{@code d} is the number of links to traverse in order to reach
- * the neuron from the winning neuron.</li>
- * </ul>
- * <br>
- * This class is thread-safe as long as the arguments passed to the
- * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
- * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
- * classes.
- * <br>
- * Each call to the {@link #update(Network,double[]) update} method
- * will increment the internal counter used to compute the current
- * values for
- * <ul>
- * <li>the <em>learning rate</em>, and</li>
- * <li>the <em>neighbourhood size</em>.</li>
- * </ul>
- * Consequently, the function instances that compute those values (passed
- * to the constructor of this class) must take into account whether this
- * class's instance will be shared by multiple threads, as this will impact
- * the training process.
- *
- * @since 3.3
- */
- public class KohonenUpdateAction implements UpdateAction {
- /** Distance function. */
- private final DistanceMeasure distance;
- /** Learning factor update function. */
- private final LearningFactorFunction learningFactor;
- /** Neighbourhood size update function. */
- private final NeighbourhoodSizeFunction neighbourhoodSize;
- /** Number of calls to {@link #update(Network,double[])}. */
- private final AtomicLong numberOfCalls = new AtomicLong(0);
- /**
- * @param distance Distance function.
- * @param learningFactor Learning factor update function.
- * @param neighbourhoodSize Neighbourhood size update function.
- */
- public KohonenUpdateAction(DistanceMeasure distance,
- LearningFactorFunction learningFactor,
- NeighbourhoodSizeFunction neighbourhoodSize) {
- this.distance = distance;
- this.learningFactor = learningFactor;
- this.neighbourhoodSize = neighbourhoodSize;
- }
- /**
- * {@inheritDoc}
- */
- @Override
- public void update(Network net,
- double[] features) {
- final long numCalls = numberOfCalls.incrementAndGet() - 1;
- final double currentLearning = learningFactor.value(numCalls);
- final Neuron best = findAndUpdateBestNeuron(net,
- features,
- currentLearning);
- final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
- // The farther away the neighbour is from the winning neuron, the
- // smaller the learning rate will become.
- final Gaussian neighbourhoodDecay
- = new Gaussian(currentLearning, currentNeighbourhood);
- if (currentNeighbourhood > 0) {
- // Initial set of neurons only contains the winning neuron.
- Collection<Neuron> neighbours = new HashSet<>();
- neighbours.add(best);
- // Winning neuron must be excluded from the neighbours.
- final HashSet<Neuron> exclude = new HashSet<>();
- exclude.add(best);
- int radius = 1;
- do {
- // Retrieve immediate neighbours of the current set of neurons.
- neighbours = net.getNeighbours(neighbours, exclude);
- // Update all the neighbours.
- for (final Neuron n : neighbours) {
- updateNeighbouringNeuron(n, features, neighbourhoodDecay.applyAsDouble(radius));
- }
- // Add the neighbours to the exclude list so that they will
- // not be updated more than once per training step.
- exclude.addAll(neighbours);
- ++radius;
- } while (radius <= currentNeighbourhood);
- }
- }
- /**
- * Retrieves the number of calls to the {@link #update(Network,double[]) update}
- * method.
- *
- * @return the current number of calls.
- */
- public long getNumberOfCalls() {
- return numberOfCalls.get();
- }
- /**
- * Tries to update a neuron.
- *
- * @param n Neuron to be updated.
- * @param features Training data.
- * @param learningRate Learning factor.
- * @return {@code true} if the update succeeded, {@code true} if a
- * concurrent update has been detected.
- */
- private boolean attemptNeuronUpdate(Neuron n,
- double[] features,
- double learningRate) {
- final double[] expect = n.getFeatures();
- final double[] update = computeFeatures(expect,
- features,
- learningRate);
- return n.compareAndSetFeatures(expect, update);
- }
- /**
- * Atomically updates the given neuron.
- *
- * @param n Neuron to be updated.
- * @param features Training data.
- * @param learningRate Learning factor.
- */
- private void updateNeighbouringNeuron(Neuron n,
- double[] features,
- double learningRate) {
- while (true) {
- if (attemptNeuronUpdate(n, features, learningRate)) {
- break;
- }
- }
- }
- /**
- * Searches for the neuron whose features are closest to the given
- * sample, and atomically updates its features.
- *
- * @param net Network.
- * @param features Sample data.
- * @param learningRate Current learning factor.
- * @return the winning neuron.
- */
- private Neuron findAndUpdateBestNeuron(Network net,
- double[] features,
- double learningRate) {
- final MapRanking rank = new MapRanking(net, distance);
- while (true) {
- final Neuron best = rank.rank(features, 1).get(0);
- if (attemptNeuronUpdate(best, features, learningRate)) {
- return best;
- }
- // If another thread modified the state of the winning neuron,
- // it may not be the best match anymore for the given training
- // sample: Hence, the winner search is performed again.
- }
- }
- /**
- * Computes the new value of the features set.
- *
- * @param current Current values of the features.
- * @param sample Training data.
- * @param learningRate Learning factor.
- * @return the new values for the features.
- */
- private double[] computeFeatures(double[] current,
- double[] sample,
- double learningRate) {
- final int len = current.length;
- final double[] r = new double[len];
- for (int i = 0; i < len; i++) {
- final double c = current[i];
- final double s = sample[i];
- r[i] = c + learningRate * (s - c);
- }
- return r;
- }
- /**
- * Gaussian function with zero mean.
- */
- private static class Gaussian implements DoubleUnaryOperator {
- /** Inverse of twice the square of the standard deviation. */
- private final double i2s2;
- /** Normalization factor. */
- private final double norm;
- /**
- * @param norm Normalization factor.
- * @param sigma Standard deviation.
- */
- Gaussian(double norm,
- double sigma) {
- this.norm = norm;
- i2s2 = 1d / (2 * sigma * sigma);
- }
- @Override
- public double applyAsDouble(double x) {
- return norm * Math.exp(-x * x * i2s2);
- }
- }
- }