001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.ml.neuralnet;
019
020import java.io.Serializable;
021import java.io.ObjectInputStream;
022import java.util.NoSuchElementException;
023import java.util.List;
024import java.util.ArrayList;
025import java.util.Set;
026import java.util.HashSet;
027import java.util.Collection;
028import java.util.Iterator;
029import java.util.Comparator;
030import java.util.Collections;
031import java.util.Map;
032import java.util.concurrent.ConcurrentHashMap;
033import java.util.concurrent.atomic.AtomicLong;
034import java.util.stream.Collectors;
035
036import org.apache.commons.math4.exception.DimensionMismatchException;
037import org.apache.commons.math4.exception.MathIllegalStateException;
038
039/**
040 * Neural network, composed of {@link Neuron} instances and the links
041 * between them.
042 *
043 * Although updating a neuron's state is thread-safe, modifying the
044 * network's topology (adding or removing links) is not.
045 *
046 * @since 3.3
047 */
048public class Network
049    implements Iterable<Neuron>,
050               Serializable {
051    /** Serializable. */
052    private static final long serialVersionUID = 20130207L;
053    /** Neurons. */
054    private final ConcurrentHashMap<Long, Neuron> neuronMap
055        = new ConcurrentHashMap<>();
056    /** Next available neuron identifier. */
057    private final AtomicLong nextId;
058    /** Neuron's features set size. */
059    private final int featureSize;
060    /** Links. */
061    private final ConcurrentHashMap<Long, Set<Long>> linkMap
062        = new ConcurrentHashMap<>();
063
064    /**
065     * Comparator that prescribes an order of the neurons according
066     * to the increasing order of their identifier.
067     */
068    public static class NeuronIdentifierComparator
069        implements Comparator<Neuron>,
070                   Serializable {
071        /** Version identifier. */
072        private static final long serialVersionUID = 20130207L;
073
074        /** {@inheritDoc} */
075        @Override
076        public int compare(Neuron a,
077                           Neuron b) {
078            final long aId = a.getIdentifier();
079            final long bId = b.getIdentifier();
080            return aId < bId ? -1 :
081                aId > bId ? 1 : 0;
082        }
083    }
084
085    /**
086     * Constructor with restricted access, solely used for deserialization.
087     *
088     * @param nextId Next available identifier.
089     * @param featureSize Number of features.
090     * @param neuronList Neurons.
091     * @param neighbourIdList Links associated to each of the neurons in
092     * {@code neuronList}.
093     * @throws MathIllegalStateException if an inconsistency is detected
094     * (which probably means that the serialized form has been corrupted).
095     */
096    Network(long nextId,
097            int featureSize,
098            Neuron[] neuronList,
099            long[][] neighbourIdList) {
100        final int numNeurons = neuronList.length;
101        if (numNeurons != neighbourIdList.length) {
102            throw new MathIllegalStateException();
103        }
104
105        for (int i = 0; i < numNeurons; i++) {
106            final Neuron n = neuronList[i];
107            final long id = n.getIdentifier();
108            if (id >= nextId) {
109                throw new MathIllegalStateException();
110            }
111            neuronMap.put(id, n);
112            linkMap.put(id, new HashSet<Long>());
113        }
114
115        for (int i = 0; i < numNeurons; i++) {
116            final long aId = neuronList[i].getIdentifier();
117            final Set<Long> aLinks = linkMap.get(aId);
118            for (Long bId : neighbourIdList[i]) {
119                if (neuronMap.get(bId) == null) {
120                    throw new MathIllegalStateException();
121                }
122                addLinkToLinkSet(aLinks, bId);
123            }
124        }
125
126        this.nextId = new AtomicLong(nextId);
127        this.featureSize = featureSize;
128    }
129
130    /**
131     * @param initialIdentifier Identifier for the first neuron that
132     * will be added to this network.
133     * @param featureSize Size of the neuron's features.
134     */
135    public Network(long initialIdentifier,
136                   int featureSize) {
137        nextId = new AtomicLong(initialIdentifier);
138        this.featureSize = featureSize;
139    }
140
141    /**
142     * Performs a deep copy of this instance.
143     * Upon return, the copied and original instances will be independent:
144     * Updating one will not affect the other.
145     *
146     * @return a new instance with the same state as this instance.
147     * @since 3.6
148     */
149    public synchronized Network copy() {
150        final Network copy = new Network(nextId.get(),
151                                         featureSize);
152
153
154        for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
155            copy.neuronMap.put(e.getKey(), e.getValue().copy());
156        }
157
158        for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
159            copy.linkMap.put(e.getKey(), new HashSet<>(e.getValue()));
160        }
161
162        return copy;
163    }
164
165    /**
166     * {@inheritDoc}
167     */
168    @Override
169    public Iterator<Neuron> iterator() {
170        return neuronMap.values().iterator();
171    }
172
173    /**
174     * Creates a list of the neurons, sorted in a custom order.
175     *
176     * @param comparator {@link Comparator} used for sorting the neurons.
177     * @return a list of neurons, sorted in the order prescribed by the
178     * given {@code comparator}.
179     * @see NeuronIdentifierComparator
180     */
181    public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
182        final List<Neuron> neurons = new ArrayList<>(neuronMap.values());
183
184        Collections.sort(neurons, comparator);
185
186        return neurons;
187    }
188
189    /**
190     * Creates a neuron and assigns it a unique identifier.
191     *
192     * @param features Initial values for the neuron's features.
193     * @return the neuron's identifier.
194     * @throws DimensionMismatchException if the length of {@code features}
195     * is different from the expected size (as set by the
196     * {@link #Network(long,int) constructor}).
197     */
198    public long createNeuron(double[] features) {
199        if (features.length != featureSize) {
200            throw new DimensionMismatchException(features.length, featureSize);
201        }
202
203        final long id = createNextId();
204        neuronMap.put(id, new Neuron(id, features));
205        linkMap.put(id, new HashSet<Long>());
206        return id;
207    }
208
209    /**
210     * Deletes a neuron.
211     * Links from all neighbours to the removed neuron will also be
212     * {@link #deleteLink(Neuron,Neuron) deleted}.
213     *
214     * @param neuron Neuron to be removed from this network.
215     * @throws NoSuchElementException if {@code n} does not belong to
216     * this network.
217     */
218    public void deleteNeuron(Neuron neuron) {
219        // Delete links to from neighbours.
220        getNeighbours(neuron).forEach(neighbour -> deleteLink(neighbour, neuron));
221
222        // Remove neuron.
223        neuronMap.remove(neuron.getIdentifier());
224    }
225
226    /**
227     * Gets the size of the neurons' features set.
228     *
229     * @return the size of the features set.
230     */
231    public int getFeaturesSize() {
232        return featureSize;
233    }
234
235    /**
236     * Adds a link from neuron {@code a} to neuron {@code b}.
237     * Note: the link is not bi-directional; if a bi-directional link is
238     * required, an additional call must be made with {@code a} and
239     * {@code b} exchanged in the argument list.
240     *
241     * @param a Neuron.
242     * @param b Neuron.
243     * @throws NoSuchElementException if the neurons do not exist in the
244     * network.
245     */
246    public void addLink(Neuron a,
247                        Neuron b) {
248        final long aId = a.getIdentifier();
249        final long bId = b.getIdentifier();
250
251        // Check that the neurons belong to this network.
252        if (a != getNeuron(aId)) {
253            throw new NoSuchElementException(Long.toString(aId));
254        }
255        if (b != getNeuron(bId)) {
256            throw new NoSuchElementException(Long.toString(bId));
257        }
258
259        // Add link from "a" to "b".
260        addLinkToLinkSet(linkMap.get(aId), bId);
261    }
262
263    /**
264     * Adds a link to neuron {@code id} in given {@code linkSet}.
265     * Note: no check verifies that the identifier indeed belongs
266     * to this network.
267     *
268     * @param linkSet Neuron identifier.
269     * @param id Neuron identifier.
270     */
271    private void addLinkToLinkSet(Set<Long> linkSet,
272                                  long id) {
273        linkSet.add(id);
274    }
275
276    /**
277     * Deletes the link between neurons {@code a} and {@code b}.
278     *
279     * @param a Neuron.
280     * @param b Neuron.
281     * @throws NoSuchElementException if the neurons do not exist in the
282     * network.
283     */
284    public void deleteLink(Neuron a,
285                           Neuron b) {
286        final long aId = a.getIdentifier();
287        final long bId = b.getIdentifier();
288
289        // Check that the neurons belong to this network.
290        if (a != getNeuron(aId)) {
291            throw new NoSuchElementException(Long.toString(aId));
292        }
293        if (b != getNeuron(bId)) {
294            throw new NoSuchElementException(Long.toString(bId));
295        }
296
297        // Delete link from "a" to "b".
298        deleteLinkFromLinkSet(linkMap.get(aId), bId);
299    }
300
301    /**
302     * Deletes a link to neuron {@code id} in given {@code linkSet}.
303     * Note: no check verifies that the identifier indeed belongs
304     * to this network.
305     *
306     * @param linkSet Neuron identifier.
307     * @param id Neuron identifier.
308     */
309    private void deleteLinkFromLinkSet(Set<Long> linkSet,
310                                       long id) {
311        linkSet.remove(id);
312    }
313
314    /**
315     * Retrieves the neuron with the given (unique) {@code id}.
316     *
317     * @param id Identifier.
318     * @return the neuron associated with the given {@code id}.
319     * @throws NoSuchElementException if the neuron does not exist in the
320     * network.
321     */
322    public Neuron getNeuron(long id) {
323        final Neuron n = neuronMap.get(id);
324        if (n == null) {
325            throw new NoSuchElementException(Long.toString(id));
326        }
327        return n;
328    }
329
330    /**
331     * Retrieves the neurons in the neighbourhood of any neuron in the
332     * {@code neurons} list.
333     * @param neurons Neurons for which to retrieve the neighbours.
334     * @return the list of neighbours.
335     * @see #getNeighbours(Iterable,Iterable)
336     */
337    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
338        return getNeighbours(neurons, null);
339    }
340
341    /**
342     * Retrieves the neurons in the neighbourhood of any neuron in the
343     * {@code neurons} list.
344     * The {@code exclude} list allows to retrieve the "concentric"
345     * neighbourhoods by removing the neurons that belong to the inner
346     * "circles".
347     *
348     * @param neurons Neurons for which to retrieve the neighbours.
349     * @param exclude Neurons to exclude from the returned list.
350     * Can be {@code null}.
351     * @return the list of neighbours.
352     */
353    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
354                                            Iterable<Neuron> exclude) {
355        final Set<Long> idList = new HashSet<>();
356        neurons.forEach(n -> idList.addAll(linkMap.get(n.getIdentifier())));
357
358        if (exclude != null) {
359            exclude.forEach(n -> idList.remove(n.getIdentifier()));
360        }
361
362        return idList.stream().map(this::getNeuron).collect(Collectors.toList());
363    }
364
365    /**
366     * Retrieves the neighbours of the given neuron.
367     *
368     * @param neuron Neuron for which to retrieve the neighbours.
369     * @return the list of neighbours.
370     * @see #getNeighbours(Neuron,Iterable)
371     */
372    public Collection<Neuron> getNeighbours(Neuron neuron) {
373        return getNeighbours(neuron, null);
374    }
375
376    /**
377     * Retrieves the neighbours of the given neuron.
378     *
379     * @param neuron Neuron for which to retrieve the neighbours.
380     * @param exclude Neurons to exclude from the returned list.
381     * Can be {@code null}.
382     * @return the list of neighbours.
383     */
384    public Collection<Neuron> getNeighbours(Neuron neuron,
385                                            Iterable<Neuron> exclude) {
386        final Set<Long> idList = linkMap.get(neuron.getIdentifier());
387        if (exclude != null) {
388            for (Neuron n : exclude) {
389                idList.remove(n.getIdentifier());
390            }
391        }
392
393        final List<Neuron> neuronList = new ArrayList<>();
394        for (Long id : idList) {
395            neuronList.add(getNeuron(id));
396        }
397
398        return neuronList;
399    }
400
401    /**
402     * Creates a neuron identifier.
403     *
404     * @return a value that will serve as a unique identifier.
405     */
406    private Long createNextId() {
407        return nextId.getAndIncrement();
408    }
409
410    /**
411     * Prevents proxy bypass.
412     *
413     * @param in Input stream.
414     */
415    private void readObject(ObjectInputStream in) {
416        throw new IllegalStateException();
417    }
418
419    /**
420     * Custom serialization.
421     *
422     * @return the proxy instance that will be actually serialized.
423     */
424    private Object writeReplace() {
425        final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
426        final long[][] neighbourIdList = new long[neuronList.length][];
427
428        for (int i = 0; i < neuronList.length; i++) {
429            final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
430            final long[] neighboursId = new long[neighbours.size()];
431            int count = 0;
432            for (Neuron n : neighbours) {
433                neighboursId[count] = n.getIdentifier();
434                ++count;
435            }
436            neighbourIdList[i] = neighboursId;
437        }
438
439        return new SerializationProxy(nextId.get(),
440                                      featureSize,
441                                      neuronList,
442                                      neighbourIdList);
443    }
444
445    /**
446     * Serialization.
447     */
448    private static class SerializationProxy implements Serializable {
449        /** Serializable. */
450        private static final long serialVersionUID = 20130207L;
451        /** Next identifier. */
452        private final long nextId;
453        /** Number of features. */
454        private final int featureSize;
455        /** Neurons. */
456        private final Neuron[] neuronList;
457        /** Links. */
458        private final long[][] neighbourIdList;
459
460        /**
461         * @param nextId Next available identifier.
462         * @param featureSize Number of features.
463         * @param neuronList Neurons.
464         * @param neighbourIdList Links associated to each of the neurons in
465         * {@code neuronList}.
466         */
467        SerializationProxy(long nextId,
468                           int featureSize,
469                           Neuron[] neuronList,
470                           long[][] neighbourIdList) {
471            this.nextId = nextId;
472            this.featureSize = featureSize;
473            this.neuronList = neuronList;
474            this.neighbourIdList = neighbourIdList;
475        }
476
477        /**
478         * Custom serialization.
479         *
480         * @return the {@link Network} for which this instance is the proxy.
481         */
482        private Object readResolve() {
483            return new Network(nextId,
484                               featureSize,
485                               neuronList,
486                               neighbourIdList);
487        }
488    }
489}