001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.neuralnet.sofm;
019
020import java.util.Collection;
021import java.util.HashSet;
022import java.util.concurrent.atomic.AtomicLong;
023import java.util.function.DoubleUnaryOperator;
024
025import org.apache.commons.math4.neuralnet.DistanceMeasure;
026import org.apache.commons.math4.neuralnet.MapRanking;
027import org.apache.commons.math4.neuralnet.Network;
028import org.apache.commons.math4.neuralnet.Neuron;
029import org.apache.commons.math4.neuralnet.UpdateAction;
030
031/**
032 * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
033 * Kohonen's Self-Organizing Map</a>.
034 * <br>
035 * The {@link #update(Network,double[]) update} method modifies the
036 * features {@code w} of the "winning" neuron and its neighbours
037 * according to the following rule:
038 * <code>
039 *  w<sub>new</sub> = w<sub>old</sub> + &alpha; e<sup>(-d / &sigma;)</sup> * (sample - w<sub>old</sub>)
040 * </code>
041 * where
042 * <ul>
043 *  <li>&alpha; is the current <em>learning rate</em>, </li>
044 *  <li>&sigma; is the current <em>neighbourhood size</em>, and</li>
045 *  <li>{@code d} is the number of links to traverse in order to reach
046 *   the neuron from the winning neuron.</li>
047 * </ul>
048 * <br>
049 * This class is thread-safe as long as the arguments passed to the
050 * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
051 * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
052 * classes.
053 * <br>
054 * Each call to the {@link #update(Network,double[]) update} method
055 * will increment the internal counter used to compute the current
056 * values for
057 * <ul>
058 *  <li>the <em>learning rate</em>, and</li>
059 *  <li>the <em>neighbourhood size</em>.</li>
060 * </ul>
061 * Consequently, the function instances that compute those values (passed
062 * to the constructor of this class) must take into account whether this
063 * class's instance will be shared by multiple threads, as this will impact
064 * the training process.
065 *
066 * @since 3.3
067 */
068public class KohonenUpdateAction implements UpdateAction {
069    /** Distance function. */
070    private final DistanceMeasure distance;
071    /** Learning factor update function. */
072    private final LearningFactorFunction learningFactor;
073    /** Neighbourhood size update function. */
074    private final NeighbourhoodSizeFunction neighbourhoodSize;
075    /** Number of calls to {@link #update(Network,double[])}. */
076    private final AtomicLong numberOfCalls = new AtomicLong(0);
077
078    /**
079     * @param distance Distance function.
080     * @param learningFactor Learning factor update function.
081     * @param neighbourhoodSize Neighbourhood size update function.
082     */
083    public KohonenUpdateAction(DistanceMeasure distance,
084                               LearningFactorFunction learningFactor,
085                               NeighbourhoodSizeFunction neighbourhoodSize) {
086        this.distance = distance;
087        this.learningFactor = learningFactor;
088        this.neighbourhoodSize = neighbourhoodSize;
089    }
090
091    /**
092     * {@inheritDoc}
093     */
094    @Override
095    public void update(Network net,
096                       double[] features) {
097        final long numCalls = numberOfCalls.incrementAndGet() - 1;
098        final double currentLearning = learningFactor.value(numCalls);
099        final Neuron best = findAndUpdateBestNeuron(net,
100                                                    features,
101                                                    currentLearning);
102
103        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
104        // The farther away the neighbour is from the winning neuron, the
105        // smaller the learning rate will become.
106        final Gaussian neighbourhoodDecay
107            = new Gaussian(currentLearning, currentNeighbourhood);
108
109        if (currentNeighbourhood > 0) {
110            // Initial set of neurons only contains the winning neuron.
111            Collection<Neuron> neighbours = new HashSet<>();
112            neighbours.add(best);
113            // Winning neuron must be excluded from the neighbours.
114            final HashSet<Neuron> exclude = new HashSet<>();
115            exclude.add(best);
116
117            int radius = 1;
118            do {
119                // Retrieve immediate neighbours of the current set of neurons.
120                neighbours = net.getNeighbours(neighbours, exclude);
121
122                // Update all the neighbours.
123                for (final Neuron n : neighbours) {
124                    updateNeighbouringNeuron(n, features, neighbourhoodDecay.applyAsDouble(radius));
125                }
126
127                // Add the neighbours to the exclude list so that they will
128                // not be updated more than once per training step.
129                exclude.addAll(neighbours);
130                ++radius;
131            } while (radius <= currentNeighbourhood);
132        }
133    }
134
135    /**
136     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
137     * method.
138     *
139     * @return the current number of calls.
140     */
141    public long getNumberOfCalls() {
142        return numberOfCalls.get();
143    }
144
145    /**
146     * Tries to update a neuron.
147     *
148     * @param n Neuron to be updated.
149     * @param features Training data.
150     * @param learningRate Learning factor.
151     * @return {@code true} if the update succeeded, {@code true} if a
152     * concurrent update has been detected.
153     */
154    private boolean attemptNeuronUpdate(Neuron n,
155                                        double[] features,
156                                        double learningRate) {
157        final double[] expect = n.getFeatures();
158        final double[] update = computeFeatures(expect,
159                                                features,
160                                                learningRate);
161
162        return n.compareAndSetFeatures(expect, update);
163    }
164
165    /**
166     * Atomically updates the given neuron.
167     *
168     * @param n Neuron to be updated.
169     * @param features Training data.
170     * @param learningRate Learning factor.
171     */
172    private void updateNeighbouringNeuron(Neuron n,
173                                          double[] features,
174                                          double learningRate) {
175        while (true) {
176            if (attemptNeuronUpdate(n, features, learningRate)) {
177                break;
178            }
179        }
180    }
181
182    /**
183     * Searches for the neuron whose features are closest to the given
184     * sample, and atomically updates its features.
185     *
186     * @param net Network.
187     * @param features Sample data.
188     * @param learningRate Current learning factor.
189     * @return the winning neuron.
190     */
191    private Neuron findAndUpdateBestNeuron(Network net,
192                                           double[] features,
193                                           double learningRate) {
194        final MapRanking rank = new MapRanking(net, distance);
195
196        while (true) {
197            final Neuron best = rank.rank(features, 1).get(0);
198
199            if (attemptNeuronUpdate(best, features, learningRate)) {
200                return best;
201            }
202
203            // If another thread modified the state of the winning neuron,
204            // it may not be the best match anymore for the given training
205            // sample: Hence, the winner search is performed again.
206        }
207    }
208
209    /**
210     * Computes the new value of the features set.
211     *
212     * @param current Current values of the features.
213     * @param sample Training data.
214     * @param learningRate Learning factor.
215     * @return the new values for the features.
216     */
217    private double[] computeFeatures(double[] current,
218                                     double[] sample,
219                                     double learningRate) {
220        final int len = current.length;
221        final double[] r = new double[len];
222        for (int i = 0; i < len; i++) {
223            final double c = current[i];
224            final double s = sample[i];
225            r[i] = c + learningRate * (s - c);
226        }
227        return r;
228    }
229
230    /**
231     * Gaussian function with zero mean.
232     */
233    private static class Gaussian implements DoubleUnaryOperator {
234        /** Inverse of twice the square of the standard deviation. */
235        private final double i2s2;
236        /** Normalization factor. */
237        private final double norm;
238
239        /**
240         * @param norm Normalization factor.
241         * @param sigma Standard deviation.
242         */
243        Gaussian(double norm,
244                 double sigma) {
245            this.norm = norm;
246            i2s2 = 1d / (2 * sigma * sigma);
247        }
248
249        @Override
250        public double applyAsDouble(double x) {
251            return norm * Math.exp(-x * x * i2s2);
252        }
253    }
254}