001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.neuralnet;
019
020import java.io.Serializable;
021import java.io.ObjectInputStream;
022import java.util.NoSuchElementException;
023import java.util.List;
024import java.util.ArrayList;
025import java.util.Set;
026import java.util.HashSet;
027import java.util.Collection;
028import java.util.Iterator;
029import java.util.Comparator;
030import java.util.Collections;
031import java.util.concurrent.ConcurrentHashMap;
032import java.util.concurrent.atomic.AtomicLong;
033import org.apache.commons.math3.exception.DimensionMismatchException;
034import org.apache.commons.math3.exception.MathIllegalStateException;
035
036/**
037 * Neural network, composed of {@link Neuron} instances and the links
038 * between them.
039 *
040 * Although updating a neuron's state is thread-safe, modifying the
041 * network's topology (adding or removing links) is not.
042 *
043 * @since 3.3
044 */
045public class Network
046    implements Iterable<Neuron>,
047               Serializable {
048    /** Serializable. */
049    private static final long serialVersionUID = 20130207L;
050    /** Neurons. */
051    private final ConcurrentHashMap<Long, Neuron> neuronMap
052        = new ConcurrentHashMap<Long, Neuron>();
053    /** Next available neuron identifier. */
054    private final AtomicLong nextId;
055    /** Neuron's features set size. */
056    private final int featureSize;
057    /** Links. */
058    private final ConcurrentHashMap<Long, Set<Long>> linkMap
059        = new ConcurrentHashMap<Long, Set<Long>>();
060
061    /**
062     * Comparator that prescribes an order of the neurons according
063     * to the increasing order of their identifier.
064     */
065    public static class NeuronIdentifierComparator
066        implements Comparator<Neuron>,
067                   Serializable {
068        /** Version identifier. */
069        private static final long serialVersionUID = 20130207L;
070
071        /** {@inheritDoc} */
072        public int compare(Neuron a,
073                           Neuron b) {
074            final long aId = a.getIdentifier();
075            final long bId = b.getIdentifier();
076            return aId < bId ? -1 :
077                aId > bId ? 1 : 0;
078        }
079    }
080
081    /**
082     * Constructor with restricted access, solely used for deserialization.
083     *
084     * @param nextId Next available identifier.
085     * @param featureSize Number of features.
086     * @param neuronList Neurons.
087     * @param neighbourIdList Links associated to each of the neurons in
088     * {@code neuronList}.
089     * @throws MathIllegalStateException if an inconsistency is detected
090     * (which probably means that the serialized form has been corrupted).
091     */
092    Network(long nextId,
093            int featureSize,
094            Neuron[] neuronList,
095            long[][] neighbourIdList) {
096        final int numNeurons = neuronList.length;
097        if (numNeurons != neighbourIdList.length) {
098            throw new MathIllegalStateException();
099        }
100
101        for (int i = 0; i < numNeurons; i++) {
102            final Neuron n = neuronList[i];
103            final long id = n.getIdentifier();
104            if (id >= nextId) {
105                throw new MathIllegalStateException();
106            }
107            neuronMap.put(id, n);
108            linkMap.put(id, new HashSet<Long>());
109        }
110
111        for (int i = 0; i < numNeurons; i++) {
112            final long aId = neuronList[i].getIdentifier();
113            final Set<Long> aLinks = linkMap.get(aId);
114            for (Long bId : neighbourIdList[i]) {
115                if (neuronMap.get(bId) == null) {
116                    throw new MathIllegalStateException();
117                }
118                addLinkToLinkSet(aLinks, bId);
119            }
120        }
121
122        this.nextId = new AtomicLong(nextId);
123        this.featureSize = featureSize;
124    }
125
126    /**
127     * @param initialIdentifier Identifier for the first neuron that
128     * will be added to this network.
129     * @param featureSize Size of the neuron's features.
130     */
131    public Network(long initialIdentifier,
132                   int featureSize) {
133        nextId = new AtomicLong(initialIdentifier);
134        this.featureSize = featureSize;
135    }
136
137    /**
138     * {@inheritDoc}
139     */
140    public Iterator<Neuron> iterator() {
141        return neuronMap.values().iterator();
142    }
143
144    /**
145     * Creates a list of the neurons, sorted in a custom order.
146     *
147     * @param comparator {@link Comparator} used for sorting the neurons.
148     * @return a list of neurons, sorted in the order prescribed by the
149     * given {@code comparator}.
150     * @see NeuronIdentifierComparator
151     */
152    public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
153        final List<Neuron> neurons = new ArrayList<Neuron>();
154        neurons.addAll(neuronMap.values());
155
156        Collections.sort(neurons, comparator);
157
158        return neurons;
159    }
160
161    /**
162     * Creates a neuron and assigns it a unique identifier.
163     *
164     * @param features Initial values for the neuron's features.
165     * @return the neuron's identifier.
166     * @throws DimensionMismatchException if the length of {@code features}
167     * is different from the expected size (as set by the
168     * {@link #Network(long,int) constructor}).
169     */
170    public long createNeuron(double[] features) {
171        if (features.length != featureSize) {
172            throw new DimensionMismatchException(features.length, featureSize);
173        }
174
175        final long id = createNextId();
176        neuronMap.put(id, new Neuron(id, features));
177        linkMap.put(id, new HashSet<Long>());
178        return id;
179    }
180
181    /**
182     * Deletes a neuron.
183     * Links from all neighbours to the removed neuron will also be
184     * {@link #deleteLink(Neuron,Neuron) deleted}.
185     *
186     * @param neuron Neuron to be removed from this network.
187     * @throws NoSuchElementException if {@code n} does not belong to
188     * this network.
189     */
190    public void deleteNeuron(Neuron neuron) {
191        final Collection<Neuron> neighbours = getNeighbours(neuron);
192
193        // Delete links to from neighbours.
194        for (Neuron n : neighbours) {
195            deleteLink(n, neuron);
196        }
197
198        // Remove neuron.
199        neuronMap.remove(neuron.getIdentifier());
200    }
201
202    /**
203     * Gets the size of the neurons' features set.
204     *
205     * @return the size of the features set.
206     */
207    public int getFeaturesSize() {
208        return featureSize;
209    }
210
211    /**
212     * Adds a link from neuron {@code a} to neuron {@code b}.
213     * Note: the link is not bi-directional; if a bi-directional link is
214     * required, an additional call must be made with {@code a} and
215     * {@code b} exchanged in the argument list.
216     *
217     * @param a Neuron.
218     * @param b Neuron.
219     * @throws NoSuchElementException if the neurons do not exist in the
220     * network.
221     */
222    public void addLink(Neuron a,
223                        Neuron b) {
224        final long aId = a.getIdentifier();
225        final long bId = b.getIdentifier();
226
227        // Check that the neurons belong to this network.
228        if (a != getNeuron(aId)) {
229            throw new NoSuchElementException(Long.toString(aId));
230        }
231        if (b != getNeuron(bId)) {
232            throw new NoSuchElementException(Long.toString(bId));
233        }
234
235        // Add link from "a" to "b".
236        addLinkToLinkSet(linkMap.get(aId), bId);
237    }
238
239    /**
240     * Adds a link to neuron {@code id} in given {@code linkSet}.
241     * Note: no check verifies that the identifier indeed belongs
242     * to this network.
243     *
244     * @param linkSet Neuron identifier.
245     * @param id Neuron identifier.
246     */
247    private void addLinkToLinkSet(Set<Long> linkSet,
248                                  long id) {
249        linkSet.add(id);
250    }
251
252    /**
253     * Deletes the link between neurons {@code a} and {@code b}.
254     *
255     * @param a Neuron.
256     * @param b Neuron.
257     * @throws NoSuchElementException if the neurons do not exist in the
258     * network.
259     */
260    public void deleteLink(Neuron a,
261                           Neuron b) {
262        final long aId = a.getIdentifier();
263        final long bId = b.getIdentifier();
264
265        // Check that the neurons belong to this network.
266        if (a != getNeuron(aId)) {
267            throw new NoSuchElementException(Long.toString(aId));
268        }
269        if (b != getNeuron(bId)) {
270            throw new NoSuchElementException(Long.toString(bId));
271        }
272
273        // Delete link from "a" to "b".
274        deleteLinkFromLinkSet(linkMap.get(aId), bId);
275    }
276
277    /**
278     * Deletes a link to neuron {@code id} in given {@code linkSet}.
279     * Note: no check verifies that the identifier indeed belongs
280     * to this network.
281     *
282     * @param linkSet Neuron identifier.
283     * @param id Neuron identifier.
284     */
285    private void deleteLinkFromLinkSet(Set<Long> linkSet,
286                                       long id) {
287        linkSet.remove(id);
288    }
289
290    /**
291     * Retrieves the neuron with the given (unique) {@code id}.
292     *
293     * @param id Identifier.
294     * @return the neuron associated with the given {@code id}.
295     * @throws NoSuchElementException if the neuron does not exist in the
296     * network.
297     */
298    public Neuron getNeuron(long id) {
299        final Neuron n = neuronMap.get(id);
300        if (n == null) {
301            throw new NoSuchElementException(Long.toString(id));
302        }
303        return n;
304    }
305
306    /**
307     * Retrieves the neurons in the neighbourhood of any neuron in the
308     * {@code neurons} list.
309     * @param neurons Neurons for which to retrieve the neighbours.
310     * @return the list of neighbours.
311     * @see #getNeighbours(Iterable,Iterable)
312     */
313    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
314        return getNeighbours(neurons, null);
315    }
316
317    /**
318     * Retrieves the neurons in the neighbourhood of any neuron in the
319     * {@code neurons} list.
320     * The {@code exclude} list allows to retrieve the "concentric"
321     * neighbourhoods by removing the neurons that belong to the inner
322     * "circles".
323     *
324     * @param neurons Neurons for which to retrieve the neighbours.
325     * @param exclude Neurons to exclude from the returned list.
326     * Can be {@code null}.
327     * @return the list of neighbours.
328     */
329    public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
330                                            Iterable<Neuron> exclude) {
331        final Set<Long> idList = new HashSet<Long>();
332
333        for (Neuron n : neurons) {
334            idList.addAll(linkMap.get(n.getIdentifier()));
335        }
336        if (exclude != null) {
337            for (Neuron n : exclude) {
338                idList.remove(n.getIdentifier());
339            }
340        }
341
342        final List<Neuron> neuronList = new ArrayList<Neuron>();
343        for (Long id : idList) {
344            neuronList.add(getNeuron(id));
345        }
346
347        return neuronList;
348    }
349
350    /**
351     * Retrieves the neighbours of the given neuron.
352     *
353     * @param neuron Neuron for which to retrieve the neighbours.
354     * @return the list of neighbours.
355     * @see #getNeighbours(Neuron,Iterable)
356     */
357    public Collection<Neuron> getNeighbours(Neuron neuron) {
358        return getNeighbours(neuron, null);
359    }
360
361    /**
362     * Retrieves the neighbours of the given neuron.
363     *
364     * @param neuron Neuron for which to retrieve the neighbours.
365     * @param exclude Neurons to exclude from the returned list.
366     * Can be {@code null}.
367     * @return the list of neighbours.
368     */
369    public Collection<Neuron> getNeighbours(Neuron neuron,
370                                            Iterable<Neuron> exclude) {
371        final Set<Long> idList = linkMap.get(neuron.getIdentifier());
372        if (exclude != null) {
373            for (Neuron n : exclude) {
374                idList.remove(n.getIdentifier());
375            }
376        }
377
378        final List<Neuron> neuronList = new ArrayList<Neuron>();
379        for (Long id : idList) {
380            neuronList.add(getNeuron(id));
381        }
382
383        return neuronList;
384    }
385
386    /**
387     * Creates a neuron identifier.
388     *
389     * @return a value that will serve as a unique identifier.
390     */
391    private Long createNextId() {
392        return nextId.getAndIncrement();
393    }
394
395    /**
396     * Prevents proxy bypass.
397     *
398     * @param in Input stream.
399     */
400    private void readObject(ObjectInputStream in) {
401        throw new IllegalStateException();
402    }
403
404    /**
405     * Custom serialization.
406     *
407     * @return the proxy instance that will be actually serialized.
408     */
409    private Object writeReplace() {
410        final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
411        final long[][] neighbourIdList = new long[neuronList.length][];
412
413        for (int i = 0; i < neuronList.length; i++) {
414            final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
415            final long[] neighboursId = new long[neighbours.size()];
416            int count = 0;
417            for (Neuron n : neighbours) {
418                neighboursId[count] = n.getIdentifier();
419                ++count;
420            }
421            neighbourIdList[i] = neighboursId;
422        }
423
424        return new SerializationProxy(nextId.get(),
425                                      featureSize,
426                                      neuronList,
427                                      neighbourIdList);
428    }
429
430    /**
431     * Serialization.
432     */
433    private static class SerializationProxy implements Serializable {
434        /** Serializable. */
435        private static final long serialVersionUID = 20130207L;
436        /** Next identifier. */
437        private final long nextId;
438        /** Number of features. */
439        private final int featureSize;
440        /** Neurons. */
441        private final Neuron[] neuronList;
442        /** Links. */
443        private final long[][] neighbourIdList;
444
445        /**
446         * @param nextId Next available identifier.
447         * @param featureSize Number of features.
448         * @param neuronList Neurons.
449         * @param neighbourIdList Links associated to each of the neurons in
450         * {@code neuronList}.
451         */
452        SerializationProxy(long nextId,
453                           int featureSize,
454                           Neuron[] neuronList,
455                           long[][] neighbourIdList) {
456            this.nextId = nextId;
457            this.featureSize = featureSize;
458            this.neuronList = neuronList;
459            this.neighbourIdList = neighbourIdList;
460        }
461
462        /**
463         * Custom serialization.
464         *
465         * @return the {@link Network} for which this instance is the proxy.
466         */
467        private Object readResolve() {
468            return new Network(nextId,
469                               featureSize,
470                               neuronList,
471                               neighbourIdList);
472        }
473    }
474}