/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.ml.neuralnet.sofm;

import java.util.Collection;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.math3.analysis.function.Gaussian;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.neuralnet.MapUtils;
import org.apache.commons.math3.ml.neuralnet.Network;
import org.apache.commons.math3.ml.neuralnet.Neuron;
import org.apache.commons.math3.ml.neuralnet.UpdateAction;

/**
 * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
 * Kohonen's Self-Organizing Map</a>.
 * <br/>
 * The {@link #update(Network,double[]) update} method modifies the
 * features {@code w} of the "winning" neuron and its neighbours
 * according to the following rule (a Gaussian decay of the learning
 * rate with the distance from the winning neuron):
 * <code>
 *  w<sub>new</sub> = w<sub>old</sub> + α e<sup>(-d<sup>2</sup> / (2 σ<sup>2</sup>))</sup> * (sample - w<sub>old</sub>)
 * </code>
 * where
 * <ul>
 *  <li>α is the current <em>learning rate</em>,</li>
 *  <li>σ is the current <em>neighbourhood size</em>, and</li>
 *  <li>{@code d} is the number of links to traverse in order to reach
 *   the neuron from the winning neuron.</li>
 * </ul>
 * <br/>
 * This class is thread-safe as long as the arguments passed to the
 * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
 * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
 * classes.
 * <br/>
 * Each call to the {@link #update(Network,double[]) update} method
 * will increment the internal counter used to compute the current
 * values for
 * <ul>
 *  <li>the <em>learning rate</em>, and</li>
 *  <li>the <em>neighbourhood size</em>.</li>
 * </ul>
 * Consequently, the function instances that compute those values (passed
 * to the constructor of this class) must take into account whether this
 * class's instance will be shared by multiple threads, as this will impact
 * the training process.
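 * <br/>
 * A minimal usage sketch, assuming a two-dimensional map built with
 * {@code NeuronSquareMesh2D}; the mesh size, decay parameters and
 * feature initializer are illustrative values only:
 * <pre>{@code
 * // Two features per neuron, uniformly initialized (illustrative).
 * FeatureInitializer init = FeatureInitializerFactory.uniform(0, 1);
 * Network net = new NeuronSquareMesh2D(10, false,
 *                                      10, false,
 *                                      SquareNeighbourhood.VON_NEUMANN,
 *                                      new FeatureInitializer[] { init, init }).getNetwork();
 * long numSamples = 1000; // Expected number of training samples (illustrative).
 * UpdateAction action
 *     = new KohonenUpdateAction(new EuclideanDistance(),
 *                               LearningFactorFunctionFactory.exponentialDecay(0.9, 0.05, numSamples),
 *                               NeighbourhoodSizeFunctionFactory.exponentialDecay(5, 1, numSamples));
 * for (double[] sample : trainingData) { // "trainingData" is supplied by the caller.
 *     action.update(net, sample);
 * }
 * }</pre>
 * For running the training in one or more background threads, see
 * {@code KohonenTrainingTask}.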
 *
 * @since 3.3
 */
public class KohonenUpdateAction implements UpdateAction {
    /** Distance function. */
    private final DistanceMeasure distance;
    /** Learning factor update function. */
    private final LearningFactorFunction learningFactor;
    /** Neighbourhood size update function. */
    private final NeighbourhoodSizeFunction neighbourhoodSize;
    /** Number of calls to {@link #update(Network,double[])}. */
    private final AtomicLong numberOfCalls = new AtomicLong(0);

    /**
     * @param distance Distance function.
     * @param learningFactor Learning factor update function.
     * @param neighbourhoodSize Neighbourhood size update function.
     */
    public KohonenUpdateAction(DistanceMeasure distance,
                               LearningFactorFunction learningFactor,
                               NeighbourhoodSizeFunction neighbourhoodSize) {
        this.distance = distance;
        this.learningFactor = learningFactor;
        this.neighbourhoodSize = neighbourhoodSize;
    }

    /**
     * {@inheritDoc}
     */
    public void update(Network net,
                       double[] features) {
        final long numCalls = numberOfCalls.incrementAndGet() - 1;
        final double currentLearning = learningFactor.value(numCalls);
        final Neuron best = findAndUpdateBestNeuron(net,
                                                    features,
                                                    currentLearning);

        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);

        if (currentNeighbourhood > 0) {
            // The farther away the neighbour is from the winning neuron,
            // the smaller the learning rate will become.
            // The Gaussian must be created inside this block: its
            // constructor rejects a non-positive standard deviation, and
            // "currentNeighbourhood" may have decayed to zero.
            final Gaussian neighbourhoodDecay
                = new Gaussian(currentLearning,
                               0,
                               currentNeighbourhood);

            // Initial set of neurons only contains the winning neuron.
            Collection<Neuron> neighbours = new HashSet<Neuron>();
            neighbours.add(best);
            // Winning neuron must be excluded from the neighbours.
            final HashSet<Neuron> exclude = new HashSet<Neuron>();
            exclude.add(best);

            int radius = 1;
            do {
                // Retrieve immediate neighbours of the current set of neurons.
                neighbours = net.getNeighbours(neighbours, exclude);

                // Update all the neighbours.
                for (Neuron n : neighbours) {
                    updateNeighbouringNeuron(n, features, neighbourhoodDecay.value(radius));
                }

                // Add the neighbours to the exclude list so that they will
                // not be updated more than once per training step.
                exclude.addAll(neighbours);
                ++radius;
            } while (radius <= currentNeighbourhood);
        }
    }

    /**
     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
     * method.
     *
     * @return the current number of calls.
     */
    public long getNumberOfCalls() {
        return numberOfCalls.get();
    }

    /**
     * Tries to update a neuron.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     * @return {@code true} if the update succeeded, {@code false} if a
     * concurrent update has been detected.
     */
    private boolean attemptNeuronUpdate(Neuron n,
                                        double[] features,
                                        double learningRate) {
        final double[] expect = n.getFeatures();
        final double[] update = computeFeatures(expect,
                                                features,
                                                learningRate);

        return n.compareAndSetFeatures(expect, update);
    }

    /**
     * Atomically updates the given neuron, retrying until no concurrent
     * update interferes.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     */
    private void updateNeighbouringNeuron(Neuron n,
                                          double[] features,
                                          double learningRate) {
        while (true) {
            if (attemptNeuronUpdate(n, features, learningRate)) {
                break;
            }
        }
    }
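    // Implementation note: neuron updates are lock-free. Each attempt
    // re-reads the neuron's current features and commits the new values
    // through "Neuron.compareAndSetFeatures"; if a concurrent update has
    // replaced the features in the meantime, the attempt fails and is
    // simply recomputed from the fresh values.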
    /**
     * Searches for the neuron whose features are closest to the given
     * sample, and atomically updates its features.
     *
     * @param net Network.
     * @param features Sample data.
     * @param learningRate Current learning factor.
     * @return the winning neuron.
     */
    private Neuron findAndUpdateBestNeuron(Network net,
                                           double[] features,
                                           double learningRate) {
        while (true) {
            final Neuron best = MapUtils.findBest(features, net, distance);

            if (attemptNeuronUpdate(best, features, learningRate)) {
                return best;
            }

            // If another thread modified the state of the winning neuron,
            // it may not be the best match anymore for the given training
            // sample; hence the winner search is performed again.
        }
    }

    /**
     * Computes the new values of the features set.
     *
     * @param current Current values of the features.
     * @param sample Training data.
     * @param learningRate Learning factor.
     * @return the new values for the features.
     */
    private double[] computeFeatures(double[] current,
                                     double[] sample,
                                     double learningRate) {
        // Wrap the arrays without copying; they are not modified.
        final ArrayRealVector c = new ArrayRealVector(current, false);
        final ArrayRealVector s = new ArrayRealVector(sample, false);
        // c + learningRate * (s - c)
        return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray();
    }
}