001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.neuralnet.sofm;
019
020import java.util.Collection;
021import java.util.HashSet;
022import java.util.concurrent.atomic.AtomicLong;
023import org.apache.commons.math3.ml.neuralnet.Network;
024import org.apache.commons.math3.ml.neuralnet.MapUtils;
025import org.apache.commons.math3.ml.neuralnet.Neuron;
026import org.apache.commons.math3.ml.neuralnet.UpdateAction;
027import org.apache.commons.math3.ml.distance.DistanceMeasure;
028import org.apache.commons.math3.linear.ArrayRealVector;
029import org.apache.commons.math3.analysis.function.Gaussian;
030
031/**
032 * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
033 * Kohonen's Self-Organizing Map</a>.
034 * <br/>
035 * The {@link #update(Network,double[]) update} method modifies the
036 * features {@code w} of the "winning" neuron and its neighbours
037 * according to the following rule:
038 * <code>
039 *  w<sub>new</sub> = w<sub>old</sub> + &alpha; e<sup>(-d / &sigma;)</sup> * (sample - w<sub>old</sub>)
040 * </code>
041 * where
042 * <ul>
043 *  <li>&alpha; is the current <em>learning rate</em>, </li>
044 *  <li>&sigma; is the current <em>neighbourhood size</em>, and</li>
045 *  <li>{@code d} is the number of links to traverse in order to reach
046 *   the neuron from the winning neuron.</li>
047 * </ul>
048 * <br/>
049 * This class is thread-safe as long as the arguments passed to the
050 * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
051 * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
052 * classes.
053 * <br/>
054 * Each call to the {@link #update(Network,double[]) update} method
055 * will increment the internal counter used to compute the current
056 * values for
057 * <ul>
058 *  <li>the <em>learning rate</em>, and</li>
059 *  <li>the <em>neighbourhood size</em>.</li>
060 * </ul>
061 * Consequently, the function instances that compute those values (passed
062 * to the constructor of this class) must take into account whether this
063 * class's instance will be shared by multiple threads, as this will impact
064 * the training process.
065 *
066 * @since 3.3
067 */
068public class KohonenUpdateAction implements UpdateAction {
069    /** Distance function. */
070    private final DistanceMeasure distance;
071    /** Learning factor update function. */
072    private final LearningFactorFunction learningFactor;
073    /** Neighbourhood size update function. */
074    private final NeighbourhoodSizeFunction neighbourhoodSize;
075    /** Number of calls to {@link #update(Network,double[])}. */
076    private final AtomicLong numberOfCalls = new AtomicLong(-1);
077
078    /**
079     * @param distance Distance function.
080     * @param learningFactor Learning factor update function.
081     * @param neighbourhoodSize Neighbourhood size update function.
082     */
083    public KohonenUpdateAction(DistanceMeasure distance,
084                               LearningFactorFunction learningFactor,
085                               NeighbourhoodSizeFunction neighbourhoodSize) {
086        this.distance = distance;
087        this.learningFactor = learningFactor;
088        this.neighbourhoodSize = neighbourhoodSize;
089    }
090
091    /**
092     * {@inheritDoc}
093     */
094    public void update(Network net,
095                       double[] features) {
096        final long numCalls = numberOfCalls.incrementAndGet();
097        final double currentLearning = learningFactor.value(numCalls);
098        final Neuron best = findAndUpdateBestNeuron(net,
099                                                    features,
100                                                    currentLearning);
101
102        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
103        // The farther away the neighbour is from the winning neuron, the
104        // smaller the learning rate will become.
105        final Gaussian neighbourhoodDecay
106            = new Gaussian(currentLearning,
107                           0,
108                           1d / currentNeighbourhood);
109
110        if (currentNeighbourhood > 0) {
111            // Initial set of neurons only contains the winning neuron.
112            Collection<Neuron> neighbours = new HashSet<Neuron>();
113            neighbours.add(best);
114            // Winning neuron must be excluded from the neighbours.
115            final HashSet<Neuron> exclude = new HashSet<Neuron>();
116            exclude.add(best);
117
118            int radius = 1;
119            do {
120                // Retrieve immediate neighbours of the current set of neurons.
121                neighbours = net.getNeighbours(neighbours, exclude);
122
123                // Update all the neighbours.
124                for (Neuron n : neighbours) {
125                    updateNeighbouringNeuron(n, features, neighbourhoodDecay.value(radius));
126                }
127
128                // Add the neighbours to the exclude list so that they will
129                // not be update more than once per training step.
130                exclude.addAll(neighbours);
131                ++radius;
132            } while (radius <= currentNeighbourhood);
133        }
134    }
135
136    /**
137     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
138     * method.
139     *
140     * @return the current number of calls.
141     */
142    public long getNumberOfCalls() {
143        return numberOfCalls.get();
144    }
145
146    /**
147     * Atomically updates the given neuron.
148     *
149     * @param n Neuron to be updated.
150     * @param features Training data.
151     * @param learningRate Learning factor.
152     */
153    private void updateNeighbouringNeuron(Neuron n,
154                                          double[] features,
155                                          double learningRate) {
156        while (true) {
157            final double[] expect = n.getFeatures();
158            final double[] update = computeFeatures(expect,
159                                                    features,
160                                                    learningRate);
161            if (n.compareAndSetFeatures(expect, update)) {
162                break;
163            }
164        }
165    }
166
167    /**
168     * Searches for the neuron whose features are closest to the given
169     * sample, and atomically updates its features.
170     *
171     * @param net Network.
172     * @param features Sample data.
173     * @param learningRate Current learning factor.
174     * @return the winning neuron.
175     */
176    private Neuron findAndUpdateBestNeuron(Network net,
177                                           double[] features,
178                                           double learningRate) {
179        while (true) {
180            final Neuron best = MapUtils.findBest(features, net, distance);
181
182            final double[] expect = best.getFeatures();
183            final double[] update = computeFeatures(expect,
184                                                    features,
185                                                    learningRate);
186            if (best.compareAndSetFeatures(expect, update)) {
187                return best;
188            }
189
190            // If another thread modified the state of the winning neuron,
191            // it may not be the best match anymore for the given training
192            // sample: Hence, the winner search is performed again.
193        }
194    }
195
196    /**
197     * Computes the new value of the features set.
198     *
199     * @param current Current values of the features.
200     * @param sample Training data.
201     * @param learningRate Learning factor.
202     * @return the new values for the features.
203     */
204    private double[] computeFeatures(double[] current,
205                                     double[] sample,
206                                     double learningRate) {
207        final ArrayRealVector c = new ArrayRealVector(current, false);
208        final ArrayRealVector s = new ArrayRealVector(sample, false);
209        // c + learningRate * (s - c)
210        return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray();
211    }
212}