View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.neuralnet.sofm;
19  
20  import java.util.Collection;
21  import java.util.HashSet;
22  import java.util.concurrent.atomic.AtomicLong;
23  import java.util.function.DoubleUnaryOperator;
24  
25  import org.apache.commons.math4.neuralnet.DistanceMeasure;
26  import org.apache.commons.math4.neuralnet.MapRanking;
27  import org.apache.commons.math4.neuralnet.Network;
28  import org.apache.commons.math4.neuralnet.Neuron;
29  import org.apache.commons.math4.neuralnet.UpdateAction;
30  
31  /**
32   * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
33   * Kohonen's Self-Organizing Map</a>.
34   * <br>
35   * The {@link #update(Network,double[]) update} method modifies the
36   * features {@code w} of the "winning" neuron and its neighbours
37   * according to the following rule:
38   * <code>
39   *  w<sub>new</sub> = w<sub>old</sub> + &alpha; e<sup>(-d / &sigma;)</sup> * (sample - w<sub>old</sub>)
40   * </code>
41   * where
42   * <ul>
43   *  <li>&alpha; is the current <em>learning rate</em>, </li>
44   *  <li>&sigma; is the current <em>neighbourhood size</em>, and</li>
45   *  <li>{@code d} is the number of links to traverse in order to reach
46   *   the neuron from the winning neuron.</li>
47   * </ul>
48   * <br>
49   * This class is thread-safe as long as the arguments passed to the
50   * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
51   * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
52   * classes.
53   * <br>
54   * Each call to the {@link #update(Network,double[]) update} method
55   * will increment the internal counter used to compute the current
56   * values for
57   * <ul>
58   *  <li>the <em>learning rate</em>, and</li>
59   *  <li>the <em>neighbourhood size</em>.</li>
60   * </ul>
61   * Consequently, the function instances that compute those values (passed
62   * to the constructor of this class) must take into account whether this
63   * class's instance will be shared by multiple threads, as this will impact
64   * the training process.
65   *
66   * @since 3.3
67   */
68  public class KohonenUpdateAction implements UpdateAction {
69      /** Distance function. */
70      private final DistanceMeasure distance;
71      /** Learning factor update function. */
72      private final LearningFactorFunction learningFactor;
73      /** Neighbourhood size update function. */
74      private final NeighbourhoodSizeFunction neighbourhoodSize;
75      /** Number of calls to {@link #update(Network,double[])}. */
76      private final AtomicLong numberOfCalls = new AtomicLong(0);
77  
78      /**
79       * @param distance Distance function.
80       * @param learningFactor Learning factor update function.
81       * @param neighbourhoodSize Neighbourhood size update function.
82       */
83      public KohonenUpdateAction(DistanceMeasure distance,
84                                 LearningFactorFunction learningFactor,
85                                 NeighbourhoodSizeFunction neighbourhoodSize) {
86          this.distance = distance;
87          this.learningFactor = learningFactor;
88          this.neighbourhoodSize = neighbourhoodSize;
89      }
90  
91      /**
92       * {@inheritDoc}
93       */
94      @Override
95      public void update(Network net,
96                         double[] features) {
97          final long numCalls = numberOfCalls.incrementAndGet() - 1;
98          final double currentLearning = learningFactor.value(numCalls);
99          final Neuron best = findAndUpdateBestNeuron(net,
100                                                     features,
101                                                     currentLearning);
102 
103         final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
104         // The farther away the neighbour is from the winning neuron, the
105         // smaller the learning rate will become.
106         final Gaussian neighbourhoodDecay
107             = new Gaussian(currentLearning, currentNeighbourhood);
108 
109         if (currentNeighbourhood > 0) {
110             // Initial set of neurons only contains the winning neuron.
111             Collection<Neuron> neighbours = new HashSet<>();
112             neighbours.add(best);
113             // Winning neuron must be excluded from the neighbours.
114             final HashSet<Neuron> exclude = new HashSet<>();
115             exclude.add(best);
116 
117             int radius = 1;
118             do {
119                 // Retrieve immediate neighbours of the current set of neurons.
120                 neighbours = net.getNeighbours(neighbours, exclude);
121 
122                 // Update all the neighbours.
123                 for (final Neuron n : neighbours) {
124                     updateNeighbouringNeuron(n, features, neighbourhoodDecay.applyAsDouble(radius));
125                 }
126 
127                 // Add the neighbours to the exclude list so that they will
128                 // not be updated more than once per training step.
129                 exclude.addAll(neighbours);
130                 ++radius;
131             } while (radius <= currentNeighbourhood);
132         }
133     }
134 
135     /**
136      * Retrieves the number of calls to the {@link #update(Network,double[]) update}
137      * method.
138      *
139      * @return the current number of calls.
140      */
141     public long getNumberOfCalls() {
142         return numberOfCalls.get();
143     }
144 
145     /**
146      * Tries to update a neuron.
147      *
148      * @param n Neuron to be updated.
149      * @param features Training data.
150      * @param learningRate Learning factor.
151      * @return {@code true} if the update succeeded, {@code true} if a
152      * concurrent update has been detected.
153      */
154     private boolean attemptNeuronUpdate(Neuron n,
155                                         double[] features,
156                                         double learningRate) {
157         final double[] expect = n.getFeatures();
158         final double[] update = computeFeatures(expect,
159                                                 features,
160                                                 learningRate);
161 
162         return n.compareAndSetFeatures(expect, update);
163     }
164 
165     /**
166      * Atomically updates the given neuron.
167      *
168      * @param n Neuron to be updated.
169      * @param features Training data.
170      * @param learningRate Learning factor.
171      */
172     private void updateNeighbouringNeuron(Neuron n,
173                                           double[] features,
174                                           double learningRate) {
175         while (true) {
176             if (attemptNeuronUpdate(n, features, learningRate)) {
177                 break;
178             }
179         }
180     }
181 
182     /**
183      * Searches for the neuron whose features are closest to the given
184      * sample, and atomically updates its features.
185      *
186      * @param net Network.
187      * @param features Sample data.
188      * @param learningRate Current learning factor.
189      * @return the winning neuron.
190      */
191     private Neuron findAndUpdateBestNeuron(Network net,
192                                            double[] features,
193                                            double learningRate) {
194         final MapRanking rank = new MapRanking(net, distance);
195 
196         while (true) {
197             final Neuron best = rank.rank(features, 1).get(0);
198 
199             if (attemptNeuronUpdate(best, features, learningRate)) {
200                 return best;
201             }
202 
203             // If another thread modified the state of the winning neuron,
204             // it may not be the best match anymore for the given training
205             // sample: Hence, the winner search is performed again.
206         }
207     }
208 
209     /**
210      * Computes the new value of the features set.
211      *
212      * @param current Current values of the features.
213      * @param sample Training data.
214      * @param learningRate Learning factor.
215      * @return the new values for the features.
216      */
217     private double[] computeFeatures(double[] current,
218                                      double[] sample,
219                                      double learningRate) {
220         final int len = current.length;
221         final double[] r = new double[len];
222         for (int i = 0; i < len; i++) {
223             final double c = current[i];
224             final double s = sample[i];
225             r[i] = c + learningRate * (s - c);
226         }
227         return r;
228     }
229 
230     /**
231      * Gaussian function with zero mean.
232      */
233     private static class Gaussian implements DoubleUnaryOperator {
234         /** Inverse of twice the square of the standard deviation. */
235         private final double i2s2;
236         /** Normalization factor. */
237         private final double norm;
238 
239         /**
240          * @param norm Normalization factor.
241          * @param sigma Standard deviation.
242          */
243         Gaussian(double norm,
244                  double sigma) {
245             this.norm = norm;
246             i2s2 = 1d / (2 * sigma * sigma);
247         }
248 
249         @Override
250         public double applyAsDouble(double x) {
251             return norm * Math.exp(-x * x * i2s2);
252         }
253     }
254 }