001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.ml.neuralnet.sofm;
019
020import java.util.Collection;
021import java.util.HashSet;
022import java.util.concurrent.atomic.AtomicLong;
023
024import org.apache.commons.math4.analysis.function.Gaussian;
025import org.apache.commons.math4.linear.ArrayRealVector;
026import org.apache.commons.math4.ml.distance.DistanceMeasure;
027import org.apache.commons.math4.ml.neuralnet.MapRanking;
028import org.apache.commons.math4.ml.neuralnet.Network;
029import org.apache.commons.math4.ml.neuralnet.Neuron;
030import org.apache.commons.math4.ml.neuralnet.UpdateAction;
031
032/**
033 * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
034 * Kohonen's Self-Organizing Map</a>.
035 * <br>
036 * The {@link #update(Network,double[]) update} method modifies the
037 * features {@code w} of the "winning" neuron and its neighbours
038 * according to the following rule:
039 * <code>
 *  w<sub>new</sub> = w<sub>old</sub> + &alpha; e<sup>(-d<sup>2</sup> / (2 &sigma;<sup>2</sup>))</sup> * (sample - w<sub>old</sub>)
041 * </code>
042 * where
043 * <ul>
044 *  <li>&alpha; is the current <em>learning rate</em>, </li>
045 *  <li>&sigma; is the current <em>neighbourhood size</em>, and</li>
046 *  <li>{@code d} is the number of links to traverse in order to reach
047 *   the neuron from the winning neuron.</li>
048 * </ul>
049 * <br>
050 * This class is thread-safe as long as the arguments passed to the
051 * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
052 * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
053 * classes.
054 * <br>
055 * Each call to the {@link #update(Network,double[]) update} method
056 * will increment the internal counter used to compute the current
057 * values for
058 * <ul>
059 *  <li>the <em>learning rate</em>, and</li>
060 *  <li>the <em>neighbourhood size</em>.</li>
061 * </ul>
062 * Consequently, the function instances that compute those values (passed
063 * to the constructor of this class) must take into account whether this
064 * class's instance will be shared by multiple threads, as this will impact
065 * the training process.
066 *
067 * @since 3.3
068 */
public class KohonenUpdateAction implements UpdateAction {
    /** Distance function. */
    private final DistanceMeasure distance;
    /** Learning factor update function. */
    private final LearningFactorFunction learningFactor;
    /** Neighbourhood size update function. */
    private final NeighbourhoodSizeFunction neighbourhoodSize;
    /** Number of calls to {@link #update(Network,double[])}. */
    private final AtomicLong numberOfCalls = new AtomicLong(0);

    /**
     * @param distance Distance function.
     * @param learningFactor Learning factor update function.
     * @param neighbourhoodSize Neighbourhood size update function.
     */
    public KohonenUpdateAction(DistanceMeasure distance,
                               LearningFactorFunction learningFactor,
                               NeighbourhoodSizeFunction neighbourhoodSize) {
        this.distance = distance;
        this.learningFactor = learningFactor;
        this.neighbourhoodSize = neighbourhoodSize;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void update(Network net,
                       double[] features) {
        // Atomically reserve this call's time step: the first call uses
        // index 0, so "incrementAndGet() - 1" (not "getAndIncrement()",
        // which would be equivalent but is kept as-is here).
        final long numCalls = numberOfCalls.incrementAndGet() - 1;
        final double currentLearning = learningFactor.value(numCalls);
        // Find the neuron closest to the sample and move it towards the
        // sample by the full current learning rate.
        final Neuron best = findAndUpdateBestNeuron(net,
                                                    features,
                                                    currentLearning);

        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
        // The farther away the neighbour is from the winning neuron, the
        // smaller the learning rate will become: the decay is a Gaussian
        // with peak value "currentLearning" (at distance 0) and standard
        // deviation "currentNeighbourhood".
        final Gaussian neighbourhoodDecay
            = new Gaussian(currentLearning,
                           0,
                           currentNeighbourhood);

        if (currentNeighbourhood > 0) {
            // Breadth-first expansion outwards from the winner, one link
            // ("radius") at a time, up to "currentNeighbourhood" links.
            // Initial set of neurons only contains the winning neuron.
            Collection<Neuron> neighbours = new HashSet<>();
            neighbours.add(best);
            // Winning neuron must be excluded from the neighbours.
            final HashSet<Neuron> exclude = new HashSet<>();
            exclude.add(best);

            int radius = 1;
            do {
                // Retrieve immediate neighbours of the current set of neurons.
                neighbours = net.getNeighbours(neighbours, exclude);

                // Update all the neighbours with the learning rate decayed
                // by their link-distance from the winner.
                for (Neuron n : neighbours) {
                    updateNeighbouringNeuron(n, features, neighbourhoodDecay.value(radius));
                }

                // Add the neighbours to the exclude list so that they will
                // not be updated more than once per training step.
                exclude.addAll(neighbours);
                ++radius;
            } while (radius <= currentNeighbourhood);
        }
    }

    /**
     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
     * method.
     *
     * @return the current number of calls.
     */
    public long getNumberOfCalls() {
        return numberOfCalls.get();
    }

    /**
     * Tries to update a neuron.
     * The update is applied with a single compare-and-set, so it is lost
     * (and reported as such) if another thread modified the neuron's
     * features between the read and the write.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     * @return {@code true} if the update succeeded, {@code false} if a
     * concurrent update has been detected.
     */
    private boolean attemptNeuronUpdate(Neuron n,
                                        double[] features,
                                        double learningRate) {
        final double[] expect = n.getFeatures();
        final double[] update = computeFeatures(expect,
                                                features,
                                                learningRate);

        // Succeeds only if the features are still equal to "expect".
        return n.compareAndSetFeatures(expect, update);
    }

    /**
     * Atomically updates the given neuron.
     * Retries the compare-and-set until it succeeds; unlike the winner
     * update, no re-ranking is needed for a plain neighbour.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     */
    private void updateNeighbouringNeuron(Neuron n,
                                          double[] features,
                                          double learningRate) {
        while (true) {
            if (attemptNeuronUpdate(n, features, learningRate)) {
                break;
            }
        }
    }

    /**
     * Searches for the neuron whose features are closest to the given
     * sample, and atomically updates its features.
     *
     * @param net Network.
     * @param features Sample data.
     * @param learningRate Current learning factor.
     * @return the winning neuron.
     */
    private Neuron findAndUpdateBestNeuron(Network net,
                                           double[] features,
                                           double learningRate) {
        final MapRanking rank = new MapRanking(net, distance);

        while (true) {
            // Closest neuron according to the "distance" measure.
            final Neuron best = rank.rank(features, 1).get(0);

            if (attemptNeuronUpdate(best, features, learningRate)) {
                return best;
            }

            // If another thread modified the state of the winning neuron,
            // it may not be the best match anymore for the given training
            // sample: Hence, the winner search is performed again.
        }
    }

    /**
     * Computes the new value of the features set.
     *
     * @param current Current values of the features.
     * @param sample Training data.
     * @param learningRate Learning factor.
     * @return the new values for the features.
     */
    private double[] computeFeatures(double[] current,
                                     double[] sample,
                                     double learningRate) {
        // "false" avoids copying: the vectors wrap the caller's arrays,
        // which are not modified ("mapMultiplyToSelf" operates on the
        // new array produced by "subtract").
        final ArrayRealVector c = new ArrayRealVector(current, false);
        final ArrayRealVector s = new ArrayRealVector(sample, false);
        // c + learningRate * (s - c)
        return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray();
    }
}