View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.neuralnet;
19  
20  import java.util.NoSuchElementException;
21  import java.util.List;
22  import java.util.ArrayList;
23  import java.util.Set;
24  import java.util.HashSet;
25  import java.util.Collection;
26  import java.util.Iterator;
27  import java.util.Collections;
28  import java.util.Map;
29  import java.util.concurrent.ConcurrentHashMap;
30  import java.util.concurrent.atomic.AtomicLong;
31  import java.util.stream.Collectors;
32  
33  import org.apache.commons.math4.neuralnet.internal.NeuralNetException;
34  
35  /**
36   * Neural network, composed of {@link Neuron} instances and the links
37   * between them.
38   *
39   * Although updating a neuron's state is thread-safe, modifying the
40   * network's topology (adding or removing links) is not.
41   *
42   * @since 3.3
43   */
44  public class Network
45      implements Iterable<Neuron> {
46      /** Neurons. */
47      private final ConcurrentHashMap<Long, Neuron> neuronMap
48          = new ConcurrentHashMap<>();
49      /** Next available neuron identifier. */
50      private final AtomicLong nextId;
51      /** Neuron's features set size. */
52      private final int featureSize;
53      /** Links. */
54      private final ConcurrentHashMap<Long, Set<Long>> linkMap
55          = new ConcurrentHashMap<>();
56  
57      /**
58       * @param firstId Identifier of the first neuron that will be added
59       * to this network.
60       * @param featureSize Size of the neuron's features.
61       */
62      public Network(long firstId,
63                     int featureSize) {
64          this.nextId = new AtomicLong(firstId);
65          this.featureSize = featureSize;
66      }
67  
68      /**
69       * Builds a network from a list of neurons and their neighbours.
70       *
71       * @param featureSize Number of features.
72       * @param idList List of neuron identifiers.
73       * @param featureList List of neuron features.
74       * @param neighbourIdList Links associated to each of the neurons in
75       * {@code idList}.
76       * @throws IllegalArgumentException if an inconsistency is detected.
77       * @return a new instance.
78       */
79      public static Network from(int featureSize,
80                                 long[] idList,
81                                 double[][] featureList,
82                                 long[][] neighbourIdList) {
83          final int numNeurons = idList.length;
84          if (idList.length != featureList.length) {
85              throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
86                                           idList.length, featureList.length);
87          }
88          if (idList.length != neighbourIdList.length) {
89              throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
90                                           idList.length, neighbourIdList.length);
91          }
92  
93          final Network net = new Network(Long.MIN_VALUE, featureSize);
94  
95          for (int i = 0; i < numNeurons; i++) {
96              final long id = idList[i];
97              net.createNeuron(id, featureList[i]);
98          }
99  
100         for (int i = 0; i < numNeurons; i++) {
101             final Neuron a = net.getNeuron(idList[i]);
102             for (final long id : neighbourIdList[i]) {
103                 final Neuron b = net.neuronMap.get(id);
104                 if (b == null) {
105                     throw new NeuralNetException(NeuralNetException.ID_NOT_FOUND, id);
106                 }
107                 net.addLink(a, b);
108             }
109         }
110 
111         return net;
112     }
113 
114     /**
115      * Performs a deep copy of this instance.
116      * Upon return, the copied and original instances will be independent:
117      * Updating one will not affect the other.
118      *
119      * @return a new instance with the same state as this instance.
120      * @since 3.6
121      */
122     public synchronized Network copy() {
123         final Network copy = new Network(nextId.get(),
124                                          featureSize);
125 
126 
127         for (final Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
128             copy.neuronMap.put(e.getKey(), e.getValue().copy());
129         }
130 
131         for (final Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
132             copy.linkMap.put(e.getKey(), new HashSet<>(e.getValue()));
133         }
134 
135         return copy;
136     }
137 
138     /**
139      * {@inheritDoc}
140      */
141     @Override
142     public Iterator<Neuron> iterator() {
143         return neuronMap.values().iterator();
144     }
145 
146     /**
147      * @return a shallow copy of the network's neurons.
148      */
149     public Collection<Neuron> getNeurons() {
150         return Collections.unmodifiableCollection(neuronMap.values());
151     }
152 
153     /**
154      * Creates a neuron and assigns it a unique identifier.
155      *
156      * @param features Initial values for the neuron's features.
157      * @return the neuron's identifier.
158      * @throws IllegalArgumentException if the length of {@code features}
159      * is different from the expected size (as set by the
160      * {@link #Network(long,int) constructor}).
161      */
162     public long createNeuron(double[] features) {
163         return createNeuron(createNextId(), features);
164     }
165 
166     /**
167      * @param id Identifier.
168      * @param features Features.
169      * @return {@¢ode id}.
170      * @throws IllegalArgumentException if the identifier is already used
171      * by a neuron that belongs to this network or the features size does
172      * not match the expected value.
173      */
174     private long createNeuron(long id,
175                               double[] features) {
176         if (neuronMap.get(id) != null) {
177             throw new NeuralNetException(NeuralNetException.ID_IN_USE, id);
178         }
179 
180         if (features.length != featureSize) {
181             throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
182                                          features.length, featureSize);
183         }
184 
185         neuronMap.put(id, new Neuron(id, features.clone()));
186         linkMap.put(id, new HashSet<>());
187 
188         if (id > nextId.get()) {
189             nextId.set(id);
190         }
191 
192         return id;
193     }
194 
195     /**
196      * Deletes a neuron.
197      * Links from all neighbours to the removed neuron will also be
198      * {@link #deleteLink(Neuron,Neuron) deleted}.
199      *
200      * @param neuron Neuron to be removed from this network.
201      * @throws NoSuchElementException if {@code n} does not belong to
202      * this network.
203      */
204     public void deleteNeuron(Neuron neuron) {
205         // Delete links to from neighbours.
206         getNeighbours(neuron).forEach(neighbour -> deleteLink(neighbour, neuron));
207 
208         // Remove neuron.
209         neuronMap.remove(neuron.getIdentifier());
210     }
211 
212     /**
213      * Gets the size of the neurons' features set.
214      *
215      * @return the size of the features set.
216      */
217     public int getFeaturesSize() {
218         return featureSize;
219     }
220 
221     /**
222      * Adds a link from neuron {@code a} to neuron {@code b}.
223      * Note: the link is not bi-directional; if a bi-directional link is
224      * required, an additional call must be made with {@code a} and
225      * {@code b} exchanged in the argument list.
226      *
227      * @param a Neuron.
228      * @param b Neuron.
229      * @throws NoSuchElementException if the neurons do not exist in the
230      * network.
231      */
232     public void addLink(Neuron a,
233                         Neuron b) {
234         // Check that the neurons belong to this network.
235         final long aId = a.getIdentifier();
236         if (a != getNeuron(aId)) {
237             throw new NoSuchElementException(Long.toString(aId));
238         }
239         final long bId = b.getIdentifier();
240         if (b != getNeuron(bId)) {
241             throw new NoSuchElementException(Long.toString(bId));
242         }
243 
244         // Add link from "a" to "b".
245         addLinkToLinkSet(linkMap.get(aId), bId);
246     }
247 
248     /**
249      * Adds a link to neuron {@code id} in given {@code linkSet}.
250      * Note: no check verifies that the identifier indeed belongs
251      * to this network.
252      *
253      * @param linkSet Neuron identifier.
254      * @param id Neuron identifier.
255      */
256     private void addLinkToLinkSet(Set<Long> linkSet,
257                                   long id) {
258         linkSet.add(id);
259     }
260 
261     /**
262      * Deletes the link between neurons {@code a} and {@code b}.
263      *
264      * @param a Neuron.
265      * @param b Neuron.
266      * @throws NoSuchElementException if the neurons do not exist in the
267      * network.
268      */
269     public void deleteLink(Neuron a,
270                            Neuron b) {
271         // Check that the neurons belong to this network.
272         final long aId = a.getIdentifier();
273         if (a != getNeuron(aId)) {
274             throw new NoSuchElementException(Long.toString(aId));
275         }
276         final long bId = b.getIdentifier();
277         if (b != getNeuron(bId)) {
278             throw new NoSuchElementException(Long.toString(bId));
279         }
280 
281         // Delete link from "a" to "b".
282         deleteLinkFromLinkSet(linkMap.get(aId), bId);
283     }
284 
285     /**
286      * Deletes a link to neuron {@code id} in given {@code linkSet}.
287      * Note: no check verifies that the identifier indeed belongs
288      * to this network.
289      *
290      * @param linkSet Neuron identifier.
291      * @param id Neuron identifier.
292      */
293     private void deleteLinkFromLinkSet(Set<Long> linkSet,
294                                        long id) {
295         linkSet.remove(id);
296     }
297 
298     /**
299      * Retrieves the neuron with the given (unique) {@code id}.
300      *
301      * @param id Identifier.
302      * @return the neuron associated with the given {@code id}.
303      * @throws NoSuchElementException if the neuron does not exist in the
304      * network.
305      */
306     public Neuron getNeuron(long id) {
307         final Neuron n = neuronMap.get(id);
308         if (n == null) {
309             throw new NoSuchElementException(Long.toString(id));
310         }
311         return n;
312     }
313 
314     /**
315      * Retrieves the neurons in the neighbourhood of any neuron in the
316      * {@code neurons} list.
317      * @param neurons Neurons for which to retrieve the neighbours.
318      * @return the list of neighbours.
319      * @see #getNeighbours(Iterable,Iterable)
320      */
321     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
322         return getNeighbours(neurons, null);
323     }
324 
325     /**
326      * Retrieves the neurons in the neighbourhood of any neuron in the
327      * {@code neurons} list.
328      * The {@code exclude} list allows to retrieve the "concentric"
329      * neighbourhoods by removing the neurons that belong to the inner
330      * "circles".
331      *
332      * @param neurons Neurons for which to retrieve the neighbours.
333      * @param exclude Neurons to exclude from the returned list.
334      * Can be {@code null}.
335      * @return the list of neighbours.
336      */
337     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
338                                             Iterable<Neuron> exclude) {
339         final Set<Long> idList = new HashSet<>();
340         neurons.forEach(n -> idList.addAll(linkMap.get(n.getIdentifier())));
341 
342         if (exclude != null) {
343             exclude.forEach(n -> idList.remove(n.getIdentifier()));
344         }
345 
346         return idList.stream().map(this::getNeuron).collect(Collectors.toList());
347     }
348 
349     /**
350      * Retrieves the neighbours of the given neuron.
351      *
352      * @param neuron Neuron for which to retrieve the neighbours.
353      * @return the list of neighbours.
354      * @see #getNeighbours(Neuron,Iterable)
355      */
356     public Collection<Neuron> getNeighbours(Neuron neuron) {
357         return getNeighbours(neuron, null);
358     }
359 
360     /**
361      * Retrieves the neighbours of the given neuron.
362      *
363      * @param neuron Neuron for which to retrieve the neighbours.
364      * @param exclude Neurons to exclude from the returned list.
365      * Can be {@code null}.
366      * @return the list of neighbours.
367      */
368     public Collection<Neuron> getNeighbours(Neuron neuron,
369                                             Iterable<Neuron> exclude) {
370         final Set<Long> idList = linkMap.get(neuron.getIdentifier());
371         if (exclude != null) {
372             for (final Neuron n : exclude) {
373                 idList.remove(n.getIdentifier());
374             }
375         }
376 
377         final List<Neuron> neuronList = new ArrayList<>();
378         for (final Long id : idList) {
379             neuronList.add(getNeuron(id));
380         }
381 
382         return neuronList;
383     }
384 
385     /**
386      * Creates a neuron identifier.
387      *
388      * @return a value that will serve as a unique identifier.
389      */
390     private Long createNextId() {
391         return nextId.getAndIncrement();
392     }
393 }