001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.neuralnet;
019
020import java.io.Serializable;
021import java.io.ObjectInputStream;
022import java.util.NoSuchElementException;
023import java.util.List;
024import java.util.ArrayList;
025import java.util.Set;
026import java.util.HashSet;
027import java.util.Collection;
028import java.util.Iterator;
029import java.util.Comparator;
030import java.util.Collections;
031import java.util.Map;
032import java.util.concurrent.ConcurrentHashMap;
033import java.util.concurrent.atomic.AtomicLong;
034import org.apache.commons.math3.exception.DimensionMismatchException;
035import org.apache.commons.math3.exception.MathIllegalStateException;
036
037/**
038 * Neural network, composed of {@link Neuron} instances and the links
039 * between them.
040 *
041 * Although updating a neuron's state is thread-safe, modifying the
042 * network's topology (adding or removing links) is not.
043 *
044 * @since 3.3
045 */
046public class Network
047    implements Iterable<Neuron>,
048               Serializable {
049    /** Serializable. */
050    private static final long serialVersionUID = 20130207L;
051    /** Neurons. */
052    private final ConcurrentHashMap<Long, Neuron> neuronMap
053        = new ConcurrentHashMap<Long, Neuron>();
054    /** Next available neuron identifier. */
055    private final AtomicLong nextId;
056    /** Neuron's features set size. */
057    private final int featureSize;
058    /** Links. */
059    private final ConcurrentHashMap<Long, Set<Long>> linkMap
060        = new ConcurrentHashMap<Long, Set<Long>>();
061
062    /**
063     * Comparator that prescribes an order of the neurons according
064     * to the increasing order of their identifier.
065     */
066    public static class NeuronIdentifierComparator
067        implements Comparator<Neuron>,
068                   Serializable {
069        /** Version identifier. */
070        private static final long serialVersionUID = 20130207L;
071
072        /** {@inheritDoc} */
073        public int compare(Neuron a,
074                           Neuron b) {
075            final long aId = a.getIdentifier();
076            final long bId = b.getIdentifier();
077            return aId < bId ? -1 :
078                aId > bId ? 1 : 0;
079        }
080    }
081
082    /**
083     * Constructor with restricted access, solely used for deserialization.
084     *
085     * @param nextId Next available identifier.
086     * @param featureSize Number of features.
087     * @param neuronList Neurons.
088     * @param neighbourIdList Links associated to each of the neurons in
089     * {@code neuronList}.
090     * @throws MathIllegalStateException if an inconsistency is detected
091     * (which probably means that the serialized form has been corrupted).
092     */
093    Network(long nextId,
094            int featureSize,
095            Neuron[] neuronList,
096            long[][] neighbourIdList) {
097        final int numNeurons = neuronList.length;
098        if (numNeurons != neighbourIdList.length) {
099            throw new MathIllegalStateException();
100        }
101
102        for (int i = 0; i < numNeurons; i++) {
103            final Neuron n = neuronList[i];
104            final long id = n.getIdentifier();
105            if (id >= nextId) {
106                throw new MathIllegalStateException();
107            }
108            neuronMap.put(id, n);
109            linkMap.put(id, new HashSet<Long>());
110        }
111
112        for (int i = 0; i < numNeurons; i++) {
113            final long aId = neuronList[i].getIdentifier();
114            final Set<Long> aLinks = linkMap.get(aId);
115            for (Long bId : neighbourIdList[i]) {
116                if (neuronMap.get(bId) == null) {
117                    throw new MathIllegalStateException();
118                }
119                addLinkToLinkSet(aLinks, bId);
120            }
121        }
122
123        this.nextId = new AtomicLong(nextId);
124        this.featureSize = featureSize;
125    }
126
127    /**
128     * @param initialIdentifier Identifier for the first neuron that
129     * will be added to this network.
130     * @param featureSize Size of the neuron's features.
131     */
132    public Network(long initialIdentifier,
133                   int featureSize) {
134        nextId = new AtomicLong(initialIdentifier);
135        this.featureSize = featureSize;
136    }
137
138    /**
139     * Performs a deep copy of this instance.
140     * Upon return, the copied and original instances will be independent:
141     * Updating one will not affect the other.
142     *
143     * @return a new instance with the same state as this instance.
144     * @since 3.6
145     */
146    public synchronized Network copy() {
147        final Network copy = new Network(nextId.get(),
148                                         featureSize);
149
150
151        for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
152            copy.neuronMap.put(e.getKey(), e.getValue().copy());
153        }
154
155        for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
156            copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue()));
157        }
158
159        return copy;
160    }
161
162    /**
163     * {@inheritDoc}
164     */
165    public Iterator<Neuron> iterator() {
166        return neuronMap.values().iterator();
167    }
168
169    /**
170     * Creates a list of the neurons, sorted in a custom order.
171     *
172     * @param comparator {@link Comparator} used for sorting the neurons.
173     * @return a list of neurons, sorted in the order prescribed by the
174     * given {@code comparator}.
175     * @see NeuronIdentifierComparator
176     */
177    public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
178        final List<Neuron> neurons = new ArrayList<Neuron>();
179        neurons.addAll(neuronMap.values());
180
181        Collections.sort(neurons, comparator);
182
183        return neurons;
184    }
185
186    /**
187     * Creates a neuron and assigns it a unique identifier.
188     *
189     * @param features Initial values for the neuron's features.
190     * @return the neuron's identifier.
191     * @throws DimensionMismatchException if the length of {@code features}
192     * is different from the expected size (as set by the
193     * {@link #Network(long,int) constructor}).
194     */
195    public long createNeuron(double[] features) {
196        if (features.length != featureSize) {
197            throw new DimensionMismatchException(features.length, featureSize);
198        }
199
200        final long id = createNextId();
201        neuronMap.put(id, new Neuron(id, features));
202        linkMap.put(id, new HashSet<Long>());
203        return id;
204    }
205
206    /**
207     * Deletes a neuron.
208     * Links from all neighbours to the removed neuron will also be
209     * {@link #deleteLink(Neuron,Neuron) deleted}.
210     *
211     * @param neuron Neuron to be removed from this network.
212     * @throws NoSuchElementException if {@code n} does not belong to
213     * this network.
214     */
215    public void deleteNeuron(Neuron neuron) {
216        final Collection<Neuron> neighbours = getNeighbours(neuron);
217
218        // Delete links to from neighbours.
219        for (Neuron n : neighbours) {
220            deleteLink(n, neuron);
221        }
222
223        // Remove neuron.
224        neuronMap.remove(neuron.getIdentifier());
225    }
226
227    /**
228     * Gets the size of the neurons' features set.
229     *
230     * @return the size of the features set.
231     */
232    public int getFeaturesSize() {
233        return featureSize;
234    }
235
236    /**
237     * Adds a link from neuron {@code a} to neuron {@code b}.
238     * Note: the link is not bi-directional; if a bi-directional link is
239     * required, an additional call must be made with {@code a} and
240     * {@code b} exchanged in the argument list.
241     *
242     * @param a Neuron.
243     * @param b Neuron.
244     * @throws NoSuchElementException if the neurons do not exist in the
245     * network.
246     */
247    public void addLink(Neuron a,
248                        Neuron b) {
249        final long aId = a.getIdentifier();
250        final long bId = b.getIdentifier();
251
252        // Check that the neurons belong to this network.
253        if (a != getNeuron(aId)) {
254            throw new NoSuchElementException(Long.toString(aId));
255        }
256        if (b != getNeuron(bId)) {
257            throw new NoSuchElementException(Long.toString(bId));
258        }
259
260        // Add link from "a" to "b".
261        addLinkToLinkSet(linkMap.get(aId), bId);
262    }
263
264    /**
265     * Adds a link to neuron {@code id} in given {@code linkSet}.
266     * Note: no check verifies that the identifier indeed belongs
267     * to this network.
268     *
269     * @param linkSet Neuron identifier.
270     * @param id Neuron identifier.
271     */
272    private void addLinkToLinkSet(Set<Long> linkSet,
273                                  long id) {
274        linkSet.add(id);
275    }
276
277    /**
278     * Deletes the link between neurons {@code a} and {@code b}.
279     *
280     * @param a Neuron.
281     * @param b Neuron.
282     * @throws NoSuchElementException if the neurons do not exist in the
283     * network.
284     */
285    public void deleteLink(Neuron a,
286                           Neuron b) {
287        final long aId = a.getIdentifier();
288        final long bId = b.getIdentifier();
289
290        // Check that the neurons belong to this network.
291        if (a != getNeuron(aId)) {
292            throw new NoSuchElementException(Long.toString(aId));
293        }
294        if (b != getNeuron(bId)) {
295            throw new NoSuchElementException(Long.toString(bId));
296        }
297
298        // Delete link from "a" to "b".
299        deleteLinkFromLinkSet(linkMap.get(aId), bId);
300    }
301
302    /**
303     * Deletes a link to neuron {@code id} in given {@code linkSet}.
304     * Note: no check verifies that the identifier indeed belongs
305     * to this network.
306     *
307     * @param linkSet Neuron identifier.
308     * @param id Neuron identifier.
309     */
310    private void deleteLinkFromLinkSet(Set<Long> linkSet,
311                                       long id) {
312        linkSet.remove(id);
313    }
314
315    /**
316     * Retrieves the neuron with the given (unique) {@code id}.
317     *
318     * @param id Identifier.
319     * @return the neuron associated with the given {@code id}.
320     * @throws NoSuchElementException if the neuron does not exist in the
321     * network.
322     */
323    public Neuron getNeuron(long id) {
324        final Neuron n = neuronMap.get(id);
325        if (n == null) {
326            throw new NoSuchElementException(Long.toString(id));
327        }
328        return n;
329    }
330
331    /**
332     * Retrieves the neurons in the neighbourhood of any neuron in the
333     * {@code neurons} list.
334     * @param neurons Neurons for which to retrieve the neighbours.
335     * @return the list of neighbours.
336     * @see #getNeighbours(Iterable,Iterable)
337     */
338    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
339        return getNeighbours(neurons, null);
340    }
341
342    /**
343     * Retrieves the neurons in the neighbourhood of any neuron in the
344     * {@code neurons} list.
345     * The {@code exclude} list allows to retrieve the "concentric"
346     * neighbourhoods by removing the neurons that belong to the inner
347     * "circles".
348     *
349     * @param neurons Neurons for which to retrieve the neighbours.
350     * @param exclude Neurons to exclude from the returned list.
351     * Can be {@code null}.
352     * @return the list of neighbours.
353     */
354    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
355                                            Iterable<Neuron> exclude) {
356        final Set<Long> idList = new HashSet<Long>();
357
358        for (Neuron n : neurons) {
359            idList.addAll(linkMap.get(n.getIdentifier()));
360        }
361        if (exclude != null) {
362            for (Neuron n : exclude) {
363                idList.remove(n.getIdentifier());
364            }
365        }
366
367        final List<Neuron> neuronList = new ArrayList<Neuron>();
368        for (Long id : idList) {
369            neuronList.add(getNeuron(id));
370        }
371
372        return neuronList;
373    }
374
375    /**
376     * Retrieves the neighbours of the given neuron.
377     *
378     * @param neuron Neuron for which to retrieve the neighbours.
379     * @return the list of neighbours.
380     * @see #getNeighbours(Neuron,Iterable)
381     */
382    public Collection<Neuron> getNeighbours(Neuron neuron) {
383        return getNeighbours(neuron, null);
384    }
385
386    /**
387     * Retrieves the neighbours of the given neuron.
388     *
389     * @param neuron Neuron for which to retrieve the neighbours.
390     * @param exclude Neurons to exclude from the returned list.
391     * Can be {@code null}.
392     * @return the list of neighbours.
393     */
394    public Collection<Neuron> getNeighbours(Neuron neuron,
395                                            Iterable<Neuron> exclude) {
396        final Set<Long> idList = linkMap.get(neuron.getIdentifier());
397        if (exclude != null) {
398            for (Neuron n : exclude) {
399                idList.remove(n.getIdentifier());
400            }
401        }
402
403        final List<Neuron> neuronList = new ArrayList<Neuron>();
404        for (Long id : idList) {
405            neuronList.add(getNeuron(id));
406        }
407
408        return neuronList;
409    }
410
411    /**
412     * Creates a neuron identifier.
413     *
414     * @return a value that will serve as a unique identifier.
415     */
416    private Long createNextId() {
417        return nextId.getAndIncrement();
418    }
419
420    /**
421     * Prevents proxy bypass.
422     *
423     * @param in Input stream.
424     */
425    private void readObject(ObjectInputStream in) {
426        throw new IllegalStateException();
427    }
428
429    /**
430     * Custom serialization.
431     *
432     * @return the proxy instance that will be actually serialized.
433     */
434    private Object writeReplace() {
435        final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
436        final long[][] neighbourIdList = new long[neuronList.length][];
437
438        for (int i = 0; i < neuronList.length; i++) {
439            final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
440            final long[] neighboursId = new long[neighbours.size()];
441            int count = 0;
442            for (Neuron n : neighbours) {
443                neighboursId[count] = n.getIdentifier();
444                ++count;
445            }
446            neighbourIdList[i] = neighboursId;
447        }
448
449        return new SerializationProxy(nextId.get(),
450                                      featureSize,
451                                      neuronList,
452                                      neighbourIdList);
453    }
454
455    /**
456     * Serialization.
457     */
458    private static class SerializationProxy implements Serializable {
459        /** Serializable. */
460        private static final long serialVersionUID = 20130207L;
461        /** Next identifier. */
462        private final long nextId;
463        /** Number of features. */
464        private final int featureSize;
465        /** Neurons. */
466        private final Neuron[] neuronList;
467        /** Links. */
468        private final long[][] neighbourIdList;
469
470        /**
471         * @param nextId Next available identifier.
472         * @param featureSize Number of features.
473         * @param neuronList Neurons.
474         * @param neighbourIdList Links associated to each of the neurons in
475         * {@code neuronList}.
476         */
477        SerializationProxy(long nextId,
478                           int featureSize,
479                           Neuron[] neuronList,
480                           long[][] neighbourIdList) {
481            this.nextId = nextId;
482            this.featureSize = featureSize;
483            this.neuronList = neuronList;
484            this.neighbourIdList = neighbourIdList;
485        }
486
487        /**
488         * Custom serialization.
489         *
490         * @return the {@link Network} for which this instance is the proxy.
491         */
492        private Object readResolve() {
493            return new Network(nextId,
494                               featureSize,
495                               neuronList,
496                               neighbourIdList);
497        }
498    }
499}