001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.neuralnet;
019
020import java.util.NoSuchElementException;
021import java.util.List;
022import java.util.ArrayList;
023import java.util.Set;
024import java.util.HashSet;
025import java.util.Collection;
026import java.util.Iterator;
027import java.util.Collections;
028import java.util.Map;
029import java.util.concurrent.ConcurrentHashMap;
030import java.util.concurrent.atomic.AtomicLong;
031import java.util.stream.Collectors;
032
033import org.apache.commons.math4.neuralnet.internal.NeuralNetException;
034
035/**
036 * Neural network, composed of {@link Neuron} instances and the links
037 * between them.
038 *
039 * Although updating a neuron's state is thread-safe, modifying the
040 * network's topology (adding or removing links) is not.
041 *
042 * @since 3.3
043 */
044public class Network
045    implements Iterable<Neuron> {
046    /** Neurons. */
047    private final ConcurrentHashMap<Long, Neuron> neuronMap
048        = new ConcurrentHashMap<>();
049    /** Next available neuron identifier. */
050    private final AtomicLong nextId;
051    /** Neuron's features set size. */
052    private final int featureSize;
053    /** Links. */
054    private final ConcurrentHashMap<Long, Set<Long>> linkMap
055        = new ConcurrentHashMap<>();
056
057    /**
058     * @param firstId Identifier of the first neuron that will be added
059     * to this network.
060     * @param featureSize Size of the neuron's features.
061     */
062    public Network(long firstId,
063                   int featureSize) {
064        this.nextId = new AtomicLong(firstId);
065        this.featureSize = featureSize;
066    }
067
068    /**
069     * Builds a network from a list of neurons and their neighbours.
070     *
071     * @param featureSize Number of features.
072     * @param idList List of neuron identifiers.
073     * @param featureList List of neuron features.
074     * @param neighbourIdList Links associated to each of the neurons in
075     * {@code idList}.
076     * @throws IllegalArgumentException if an inconsistency is detected.
077     * @return a new instance.
078     */
079    public static Network from(int featureSize,
080                               long[] idList,
081                               double[][] featureList,
082                               long[][] neighbourIdList) {
083        final int numNeurons = idList.length;
084        if (idList.length != featureList.length) {
085            throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
086                                         idList.length, featureList.length);
087        }
088        if (idList.length != neighbourIdList.length) {
089            throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
090                                         idList.length, neighbourIdList.length);
091        }
092
093        final Network net = new Network(Long.MIN_VALUE, featureSize);
094
095        for (int i = 0; i < numNeurons; i++) {
096            final long id = idList[i];
097            net.createNeuron(id, featureList[i]);
098        }
099
100        for (int i = 0; i < numNeurons; i++) {
101            final Neuron a = net.getNeuron(idList[i]);
102            for (final long id : neighbourIdList[i]) {
103                final Neuron b = net.neuronMap.get(id);
104                if (b == null) {
105                    throw new NeuralNetException(NeuralNetException.ID_NOT_FOUND, id);
106                }
107                net.addLink(a, b);
108            }
109        }
110
111        return net;
112    }
113
114    /**
115     * Performs a deep copy of this instance.
116     * Upon return, the copied and original instances will be independent:
117     * Updating one will not affect the other.
118     *
119     * @return a new instance with the same state as this instance.
120     * @since 3.6
121     */
122    public synchronized Network copy() {
123        final Network copy = new Network(nextId.get(),
124                                         featureSize);
125
126
127        for (final Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
128            copy.neuronMap.put(e.getKey(), e.getValue().copy());
129        }
130
131        for (final Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
132            copy.linkMap.put(e.getKey(), new HashSet<>(e.getValue()));
133        }
134
135        return copy;
136    }
137
138    /**
139     * {@inheritDoc}
140     */
141    @Override
142    public Iterator<Neuron> iterator() {
143        return neuronMap.values().iterator();
144    }
145
146    /**
147     * @return a shallow copy of the network's neurons.
148     */
149    public Collection<Neuron> getNeurons() {
150        return Collections.unmodifiableCollection(neuronMap.values());
151    }
152
153    /**
154     * Creates a neuron and assigns it a unique identifier.
155     *
156     * @param features Initial values for the neuron's features.
157     * @return the neuron's identifier.
158     * @throws IllegalArgumentException if the length of {@code features}
159     * is different from the expected size (as set by the
160     * {@link #Network(long,int) constructor}).
161     */
162    public long createNeuron(double[] features) {
163        return createNeuron(createNextId(), features);
164    }
165
166    /**
167     * @param id Identifier.
168     * @param features Features.
169     * @return {@¢ode id}.
170     * @throws IllegalArgumentException if the identifier is already used
171     * by a neuron that belongs to this network or the features size does
172     * not match the expected value.
173     */
174    private long createNeuron(long id,
175                              double[] features) {
176        if (neuronMap.get(id) != null) {
177            throw new NeuralNetException(NeuralNetException.ID_IN_USE, id);
178        }
179
180        if (features.length != featureSize) {
181            throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
182                                         features.length, featureSize);
183        }
184
185        neuronMap.put(id, new Neuron(id, features.clone()));
186        linkMap.put(id, new HashSet<>());
187
188        if (id > nextId.get()) {
189            nextId.set(id);
190        }
191
192        return id;
193    }
194
195    /**
196     * Deletes a neuron.
197     * Links from all neighbours to the removed neuron will also be
198     * {@link #deleteLink(Neuron,Neuron) deleted}.
199     *
200     * @param neuron Neuron to be removed from this network.
201     * @throws NoSuchElementException if {@code n} does not belong to
202     * this network.
203     */
204    public void deleteNeuron(Neuron neuron) {
205        // Delete links to from neighbours.
206        getNeighbours(neuron).forEach(neighbour -> deleteLink(neighbour, neuron));
207
208        // Remove neuron.
209        neuronMap.remove(neuron.getIdentifier());
210    }
211
212    /**
213     * Gets the size of the neurons' features set.
214     *
215     * @return the size of the features set.
216     */
217    public int getFeaturesSize() {
218        return featureSize;
219    }
220
221    /**
222     * Adds a link from neuron {@code a} to neuron {@code b}.
223     * Note: the link is not bi-directional; if a bi-directional link is
224     * required, an additional call must be made with {@code a} and
225     * {@code b} exchanged in the argument list.
226     *
227     * @param a Neuron.
228     * @param b Neuron.
229     * @throws NoSuchElementException if the neurons do not exist in the
230     * network.
231     */
232    public void addLink(Neuron a,
233                        Neuron b) {
234        // Check that the neurons belong to this network.
235        final long aId = a.getIdentifier();
236        if (a != getNeuron(aId)) {
237            throw new NoSuchElementException(Long.toString(aId));
238        }
239        final long bId = b.getIdentifier();
240        if (b != getNeuron(bId)) {
241            throw new NoSuchElementException(Long.toString(bId));
242        }
243
244        // Add link from "a" to "b".
245        addLinkToLinkSet(linkMap.get(aId), bId);
246    }
247
248    /**
249     * Adds a link to neuron {@code id} in given {@code linkSet}.
250     * Note: no check verifies that the identifier indeed belongs
251     * to this network.
252     *
253     * @param linkSet Neuron identifier.
254     * @param id Neuron identifier.
255     */
256    private void addLinkToLinkSet(Set<Long> linkSet,
257                                  long id) {
258        linkSet.add(id);
259    }
260
261    /**
262     * Deletes the link between neurons {@code a} and {@code b}.
263     *
264     * @param a Neuron.
265     * @param b Neuron.
266     * @throws NoSuchElementException if the neurons do not exist in the
267     * network.
268     */
269    public void deleteLink(Neuron a,
270                           Neuron b) {
271        // Check that the neurons belong to this network.
272        final long aId = a.getIdentifier();
273        if (a != getNeuron(aId)) {
274            throw new NoSuchElementException(Long.toString(aId));
275        }
276        final long bId = b.getIdentifier();
277        if (b != getNeuron(bId)) {
278            throw new NoSuchElementException(Long.toString(bId));
279        }
280
281        // Delete link from "a" to "b".
282        deleteLinkFromLinkSet(linkMap.get(aId), bId);
283    }
284
285    /**
286     * Deletes a link to neuron {@code id} in given {@code linkSet}.
287     * Note: no check verifies that the identifier indeed belongs
288     * to this network.
289     *
290     * @param linkSet Neuron identifier.
291     * @param id Neuron identifier.
292     */
293    private void deleteLinkFromLinkSet(Set<Long> linkSet,
294                                       long id) {
295        linkSet.remove(id);
296    }
297
298    /**
299     * Retrieves the neuron with the given (unique) {@code id}.
300     *
301     * @param id Identifier.
302     * @return the neuron associated with the given {@code id}.
303     * @throws NoSuchElementException if the neuron does not exist in the
304     * network.
305     */
306    public Neuron getNeuron(long id) {
307        final Neuron n = neuronMap.get(id);
308        if (n == null) {
309            throw new NoSuchElementException(Long.toString(id));
310        }
311        return n;
312    }
313
314    /**
315     * Retrieves the neurons in the neighbourhood of any neuron in the
316     * {@code neurons} list.
317     * @param neurons Neurons for which to retrieve the neighbours.
318     * @return the list of neighbours.
319     * @see #getNeighbours(Iterable,Iterable)
320     */
321    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
322        return getNeighbours(neurons, null);
323    }
324
325    /**
326     * Retrieves the neurons in the neighbourhood of any neuron in the
327     * {@code neurons} list.
328     * The {@code exclude} list allows to retrieve the "concentric"
329     * neighbourhoods by removing the neurons that belong to the inner
330     * "circles".
331     *
332     * @param neurons Neurons for which to retrieve the neighbours.
333     * @param exclude Neurons to exclude from the returned list.
334     * Can be {@code null}.
335     * @return the list of neighbours.
336     */
337    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
338                                            Iterable<Neuron> exclude) {
339        final Set<Long> idList = new HashSet<>();
340        neurons.forEach(n -> idList.addAll(linkMap.get(n.getIdentifier())));
341
342        if (exclude != null) {
343            exclude.forEach(n -> idList.remove(n.getIdentifier()));
344        }
345
346        return idList.stream().map(this::getNeuron).collect(Collectors.toList());
347    }
348
349    /**
350     * Retrieves the neighbours of the given neuron.
351     *
352     * @param neuron Neuron for which to retrieve the neighbours.
353     * @return the list of neighbours.
354     * @see #getNeighbours(Neuron,Iterable)
355     */
356    public Collection<Neuron> getNeighbours(Neuron neuron) {
357        return getNeighbours(neuron, null);
358    }
359
360    /**
361     * Retrieves the neighbours of the given neuron.
362     *
363     * @param neuron Neuron for which to retrieve the neighbours.
364     * @param exclude Neurons to exclude from the returned list.
365     * Can be {@code null}.
366     * @return the list of neighbours.
367     */
368    public Collection<Neuron> getNeighbours(Neuron neuron,
369                                            Iterable<Neuron> exclude) {
370        final Set<Long> idList = linkMap.get(neuron.getIdentifier());
371        if (exclude != null) {
372            for (final Neuron n : exclude) {
373                idList.remove(n.getIdentifier());
374            }
375        }
376
377        final List<Neuron> neuronList = new ArrayList<>();
378        for (final Long id : idList) {
379            neuronList.add(getNeuron(id));
380        }
381
382        return neuronList;
383    }
384
385    /**
386     * Creates a neuron identifier.
387     *
388     * @return a value that will serve as a unique identifier.
389     */
390    private Long createNextId() {
391        return nextId.getAndIncrement();
392    }
393}