/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.neuralnet.sofm;

import java.util.Collection;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.DoubleUnaryOperator;

import org.apache.commons.math4.neuralnet.DistanceMeasure;
import org.apache.commons.math4.neuralnet.MapRanking;
import org.apache.commons.math4.neuralnet.Network;
import org.apache.commons.math4.neuralnet.Neuron;
import org.apache.commons.math4.neuralnet.UpdateAction;

/**
 * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
 * Kohonen's Self-Organizing Map</a>.
 * <br>
 * The {@link #update(Network,double[]) update} method modifies the
 * features {@code w} of the "winning" neuron and its neighbours
 * according to the following rule:
 * <code>
 *  w<sub>new</sub> = w<sub>old</sub> + α e<sup>(-d² / (2 σ²))</sup> * (sample - w<sub>old</sub>)
 * </code>
 * where
 * <ul>
 *  <li>α is the current <em>learning rate</em>,</li>
 *  <li>σ is the current <em>neighbourhood size</em>, and</li>
 *  <li>{@code d} is the number of links to traverse in order to reach
 *   the neuron from the winning neuron.</li>
 * </ul>
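 * For instance, with α = 0.1 and σ = 2, a neighbour at distance
 * d = 1 from the winner is updated with factor
 * 0.1 e<sup>-1/8</sup> ≈ 0.088, while the winning neuron itself
 * is updated with the full learning rate α = 0.1.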
 * <br>
 * This class is thread-safe as long as the arguments passed to the
 * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
 * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
 * classes.
 * <br>
 * Each call to the {@link #update(Network,double[]) update} method
 * will increment the internal counter used to compute the current
 * values for
 * <ul>
 *  <li>the <em>learning rate</em>, and</li>
 *  <li>the <em>neighbourhood size</em>.</li>
 * </ul>
 * Consequently, the function instances that compute those values (passed
 * to the constructor of this class) must take into account whether this
 * class's instance will be shared by multiple threads, as this will impact
 * the training process.
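 * <br>
 * A minimal usage sketch, assuming the {@code EuclideanDistance}
 * implementation and the exponential-decay factory methods provided
 * by this package ({@code net} and {@code trainingData} are supposed
 * to be defined elsewhere):
 * <pre>{@code
 *  UpdateAction action
 *      = new KohonenUpdateAction(new EuclideanDistance(),
 *                                LearningFactorFunctionFactory.exponentialDecay(0.1, 0.01, 1000),
 *                                NeighbourhoodSizeFunctionFactory.exponentialDecay(5, 1, 1000));
 *  // net and trainingData are assumed to be defined elsewhere.
 *  for (double[] sample : trainingData) {
 *      action.update(net, sample);
 *  }
 * }</pre>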
 *
 * @since 3.3
 */
public class KohonenUpdateAction implements UpdateAction {
    /** Distance function. */
    private final DistanceMeasure distance;
    /** Learning factor update function. */
    private final LearningFactorFunction learningFactor;
    /** Neighbourhood size update function. */
    private final NeighbourhoodSizeFunction neighbourhoodSize;
    /** Number of calls to {@link #update(Network,double[])}. */
    private final AtomicLong numberOfCalls = new AtomicLong(0);

    /**
     * @param distance Distance function.
     * @param learningFactor Learning factor update function.
     * @param neighbourhoodSize Neighbourhood size update function.
     */
    public KohonenUpdateAction(DistanceMeasure distance,
                               LearningFactorFunction learningFactor,
                               NeighbourhoodSizeFunction neighbourhoodSize) {
        this.distance = distance;
        this.learningFactor = learningFactor;
        this.neighbourhoodSize = neighbourhoodSize;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void update(Network net,
                       double[] features) {
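        // The counter is incremented for this call, but the decay functions
        // are evaluated at the pre-increment value, so that the first call
        // corresponds to "time" 0.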
        final long numCalls = numberOfCalls.incrementAndGet() - 1;
        final double currentLearning = learningFactor.value(numCalls);
        final Neuron best = findAndUpdateBestNeuron(net,
                                                    features,
                                                    currentLearning);

        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
        // The farther away the neighbour is from the winning neuron, the
        // smaller the learning rate will become.
        final Gaussian neighbourhoodDecay
            = new Gaussian(currentLearning,
                           currentNeighbourhood);

        if (currentNeighbourhood > 0) {
            // Initial set of neurons only contains the winning neuron.
            Collection<Neuron> neighbours = new HashSet<>();
            neighbours.add(best);
            // Winning neuron must be excluded from the neighbours.
            final HashSet<Neuron> exclude = new HashSet<>();
            exclude.add(best);

            int radius = 1;
            do {
                // Retrieve immediate neighbours of the current set of neurons.
                neighbours = net.getNeighbours(neighbours, exclude);

                // Update all the neighbours.
                for (final Neuron n : neighbours) {
                    updateNeighbouringNeuron(n, features, neighbourhoodDecay.applyAsDouble(radius));
                }

                // Add the neighbours to the exclude list so that they will
                // not be updated more than once per training step.
                exclude.addAll(neighbours);
                ++radius;
            } while (radius <= currentNeighbourhood);
        }
    }

    /**
     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
     * method.
     *
     * @return the current number of calls.
     */
    public long getNumberOfCalls() {
        return numberOfCalls.get();
    }

    /**
     * Tries to update a neuron.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     * @return {@code true} if the update succeeded, {@code false} if a
     * concurrent update has been detected.
     */
    private boolean attemptNeuronUpdate(Neuron n,
                                        double[] features,
                                        double learningRate) {
        final double[] expect = n.getFeatures();
        final double[] update = computeFeatures(expect,
                                                features,
                                                learningRate);

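        // The new features are set atomically, and only if the current
        // features still match "expect", i.e. if no concurrent update
        // occurred since they were read above.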
        return n.compareAndSetFeatures(expect, update);
    }

    /**
     * Atomically updates the given neuron.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     */
    private void updateNeighbouringNeuron(Neuron n,
                                          double[] features,
                                          double learningRate) {
        while (true) {
            if (attemptNeuronUpdate(n, features, learningRate)) {
                break;
            }
        }
    }

    /**
     * Searches for the neuron whose features are closest to the given
     * sample, and atomically updates its features.
     *
     * @param net Network.
     * @param features Sample data.
     * @param learningRate Current learning factor.
     * @return the winning neuron.
     */
    private Neuron findAndUpdateBestNeuron(Network net,
                                           double[] features,
                                           double learningRate) {
        final MapRanking rank = new MapRanking(net, distance);

        while (true) {
            final Neuron best = rank.rank(features, 1).get(0);

            if (attemptNeuronUpdate(best, features, learningRate)) {
                return best;
            }

            // If another thread modified the state of the winning neuron,
            // it may not be the best match anymore for the given training
            // sample: Hence, the winner search is performed again.
        }
    }

    /**
     * Computes the new value of the features set.
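     * Each feature is moved towards the corresponding sample value by a
     * fraction {@code learningRate} of the difference. For example, with
     * {@code learningRate = 0.1}, a current value of {@code 0} and a
     * sample value of {@code 1}, the new value is {@code 0 + 0.1 * (1 - 0) = 0.1}.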
     *
     * @param current Current values of the features.
     * @param sample Training data.
     * @param learningRate Learning factor.
     * @return the new values for the features.
     */
    private double[] computeFeatures(double[] current,
                                     double[] sample,
                                     double learningRate) {
        final int len = current.length;
        final double[] r = new double[len];
        for (int i = 0; i < len; i++) {
            final double c = current[i];
            final double s = sample[i];
            r[i] = c + learningRate * (s - c);
        }
        return r;
    }

    /**
     * Gaussian function with zero mean.
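     * It computes {@code norm * exp(-x * x / (2 * sigma * sigma))}.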
     */
    private static class Gaussian implements DoubleUnaryOperator {
        /** Inverse of twice the square of the standard deviation. */
        private final double i2s2;
        /** Normalization factor. */
        private final double norm;

        /**
         * @param norm Normalization factor.
         * @param sigma Standard deviation.
         */
        Gaussian(double norm,
                 double sigma) {
            this.norm = norm;
            i2s2 = 1d / (2 * sigma * sigma);
        }

        @Override
        public double applyAsDouble(double x) {
            return norm * Math.exp(-x * x * i2s2);
        }
    }
}