Network.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet;

  18. import java.util.NoSuchElementException;
  19. import java.util.List;
  20. import java.util.ArrayList;
  21. import java.util.Set;
  22. import java.util.HashSet;
  23. import java.util.Collection;
  24. import java.util.Iterator;
  25. import java.util.Collections;
  26. import java.util.Map;
  27. import java.util.concurrent.ConcurrentHashMap;
  28. import java.util.concurrent.atomic.AtomicLong;
  29. import java.util.stream.Collectors;

  30. import org.apache.commons.math4.neuralnet.internal.NeuralNetException;

  31. /**
  32.  * Neural network, composed of {@link Neuron} instances and the links
  33.  * between them.
  34.  *
  35.  * Although updating a neuron's state is thread-safe, modifying the
  36.  * network's topology (adding or removing links) is not.
  37.  *
  38.  * @since 3.3
  39.  */
  40. public class Network
  41.     implements Iterable<Neuron> {
  42.     /** Neurons. */
  43.     private final ConcurrentHashMap<Long, Neuron> neuronMap
  44.         = new ConcurrentHashMap<>();
  45.     /** Next available neuron identifier. */
  46.     private final AtomicLong nextId;
  47.     /** Neuron's features set size. */
  48.     private final int featureSize;
  49.     /** Links. */
  50.     private final ConcurrentHashMap<Long, Set<Long>> linkMap
  51.         = new ConcurrentHashMap<>();

  52.     /**
  53.      * @param firstId Identifier of the first neuron that will be added
  54.      * to this network.
  55.      * @param featureSize Size of the neuron's features.
  56.      */
  57.     public Network(long firstId,
  58.                    int featureSize) {
  59.         this.nextId = new AtomicLong(firstId);
  60.         this.featureSize = featureSize;
  61.     }

  62.     /**
  63.      * Builds a network from a list of neurons and their neighbours.
  64.      *
  65.      * @param featureSize Number of features.
  66.      * @param idList List of neuron identifiers.
  67.      * @param featureList List of neuron features.
  68.      * @param neighbourIdList Links associated to each of the neurons in
  69.      * {@code idList}.
  70.      * @throws IllegalArgumentException if an inconsistency is detected.
  71.      * @return a new instance.
  72.      */
  73.     public static Network from(int featureSize,
  74.                                long[] idList,
  75.                                double[][] featureList,
  76.                                long[][] neighbourIdList) {
  77.         final int numNeurons = idList.length;
  78.         if (idList.length != featureList.length) {
  79.             throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
  80.                                          idList.length, featureList.length);
  81.         }
  82.         if (idList.length != neighbourIdList.length) {
  83.             throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
  84.                                          idList.length, neighbourIdList.length);
  85.         }

  86.         final Network net = new Network(Long.MIN_VALUE, featureSize);

  87.         for (int i = 0; i < numNeurons; i++) {
  88.             final long id = idList[i];
  89.             net.createNeuron(id, featureList[i]);
  90.         }

  91.         for (int i = 0; i < numNeurons; i++) {
  92.             final Neuron a = net.getNeuron(idList[i]);
  93.             for (final long id : neighbourIdList[i]) {
  94.                 final Neuron b = net.neuronMap.get(id);
  95.                 if (b == null) {
  96.                     throw new NeuralNetException(NeuralNetException.ID_NOT_FOUND, id);
  97.                 }
  98.                 net.addLink(a, b);
  99.             }
  100.         }

  101.         return net;
  102.     }

  103.     /**
  104.      * Performs a deep copy of this instance.
  105.      * Upon return, the copied and original instances will be independent:
  106.      * Updating one will not affect the other.
  107.      *
  108.      * @return a new instance with the same state as this instance.
  109.      * @since 3.6
  110.      */
  111.     public synchronized Network copy() {
  112.         final Network copy = new Network(nextId.get(),
  113.                                          featureSize);


  114.         for (final Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
  115.             copy.neuronMap.put(e.getKey(), e.getValue().copy());
  116.         }

  117.         for (final Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
  118.             copy.linkMap.put(e.getKey(), new HashSet<>(e.getValue()));
  119.         }

  120.         return copy;
  121.     }

  122.     /**
  123.      * {@inheritDoc}
  124.      */
  125.     @Override
  126.     public Iterator<Neuron> iterator() {
  127.         return neuronMap.values().iterator();
  128.     }

  129.     /**
  130.      * @return a shallow copy of the network's neurons.
  131.      */
  132.     public Collection<Neuron> getNeurons() {
  133.         return Collections.unmodifiableCollection(neuronMap.values());
  134.     }

  135.     /**
  136.      * Creates a neuron and assigns it a unique identifier.
  137.      *
  138.      * @param features Initial values for the neuron's features.
  139.      * @return the neuron's identifier.
  140.      * @throws IllegalArgumentException if the length of {@code features}
  141.      * is different from the expected size (as set by the
  142.      * {@link #Network(long,int) constructor}).
  143.      */
  144.     public long createNeuron(double[] features) {
  145.         return createNeuron(createNextId(), features);
  146.     }

  147.     /**
  148.      * @param id Identifier.
  149.      * @param features Features.
  150.      * @return {@¢ode id}.
  151.      * @throws IllegalArgumentException if the identifier is already used
  152.      * by a neuron that belongs to this network or the features size does
  153.      * not match the expected value.
  154.      */
  155.     private long createNeuron(long id,
  156.                               double[] features) {
  157.         if (neuronMap.get(id) != null) {
  158.             throw new NeuralNetException(NeuralNetException.ID_IN_USE, id);
  159.         }

  160.         if (features.length != featureSize) {
  161.             throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
  162.                                          features.length, featureSize);
  163.         }

  164.         neuronMap.put(id, new Neuron(id, features.clone()));
  165.         linkMap.put(id, new HashSet<>());

  166.         if (id > nextId.get()) {
  167.             nextId.set(id);
  168.         }

  169.         return id;
  170.     }

  171.     /**
  172.      * Deletes a neuron.
  173.      * Links from all neighbours to the removed neuron will also be
  174.      * {@link #deleteLink(Neuron,Neuron) deleted}.
  175.      *
  176.      * @param neuron Neuron to be removed from this network.
  177.      * @throws NoSuchElementException if {@code n} does not belong to
  178.      * this network.
  179.      */
  180.     public void deleteNeuron(Neuron neuron) {
  181.         // Delete links to from neighbours.
  182.         getNeighbours(neuron).forEach(neighbour -> deleteLink(neighbour, neuron));

  183.         // Remove neuron.
  184.         neuronMap.remove(neuron.getIdentifier());
  185.     }

  186.     /**
  187.      * Gets the size of the neurons' features set.
  188.      *
  189.      * @return the size of the features set.
  190.      */
  191.     public int getFeaturesSize() {
  192.         return featureSize;
  193.     }

  194.     /**
  195.      * Adds a link from neuron {@code a} to neuron {@code b}.
  196.      * Note: the link is not bi-directional; if a bi-directional link is
  197.      * required, an additional call must be made with {@code a} and
  198.      * {@code b} exchanged in the argument list.
  199.      *
  200.      * @param a Neuron.
  201.      * @param b Neuron.
  202.      * @throws NoSuchElementException if the neurons do not exist in the
  203.      * network.
  204.      */
  205.     public void addLink(Neuron a,
  206.                         Neuron b) {
  207.         // Check that the neurons belong to this network.
  208.         final long aId = a.getIdentifier();
  209.         if (a != getNeuron(aId)) {
  210.             throw new NoSuchElementException(Long.toString(aId));
  211.         }
  212.         final long bId = b.getIdentifier();
  213.         if (b != getNeuron(bId)) {
  214.             throw new NoSuchElementException(Long.toString(bId));
  215.         }

  216.         // Add link from "a" to "b".
  217.         addLinkToLinkSet(linkMap.get(aId), bId);
  218.     }

  219.     /**
  220.      * Adds a link to neuron {@code id} in given {@code linkSet}.
  221.      * Note: no check verifies that the identifier indeed belongs
  222.      * to this network.
  223.      *
  224.      * @param linkSet Neuron identifier.
  225.      * @param id Neuron identifier.
  226.      */
  227.     private void addLinkToLinkSet(Set<Long> linkSet,
  228.                                   long id) {
  229.         linkSet.add(id);
  230.     }

  231.     /**
  232.      * Deletes the link between neurons {@code a} and {@code b}.
  233.      *
  234.      * @param a Neuron.
  235.      * @param b Neuron.
  236.      * @throws NoSuchElementException if the neurons do not exist in the
  237.      * network.
  238.      */
  239.     public void deleteLink(Neuron a,
  240.                            Neuron b) {
  241.         // Check that the neurons belong to this network.
  242.         final long aId = a.getIdentifier();
  243.         if (a != getNeuron(aId)) {
  244.             throw new NoSuchElementException(Long.toString(aId));
  245.         }
  246.         final long bId = b.getIdentifier();
  247.         if (b != getNeuron(bId)) {
  248.             throw new NoSuchElementException(Long.toString(bId));
  249.         }

  250.         // Delete link from "a" to "b".
  251.         deleteLinkFromLinkSet(linkMap.get(aId), bId);
  252.     }

  253.     /**
  254.      * Deletes a link to neuron {@code id} in given {@code linkSet}.
  255.      * Note: no check verifies that the identifier indeed belongs
  256.      * to this network.
  257.      *
  258.      * @param linkSet Neuron identifier.
  259.      * @param id Neuron identifier.
  260.      */
  261.     private void deleteLinkFromLinkSet(Set<Long> linkSet,
  262.                                        long id) {
  263.         linkSet.remove(id);
  264.     }

  265.     /**
  266.      * Retrieves the neuron with the given (unique) {@code id}.
  267.      *
  268.      * @param id Identifier.
  269.      * @return the neuron associated with the given {@code id}.
  270.      * @throws NoSuchElementException if the neuron does not exist in the
  271.      * network.
  272.      */
  273.     public Neuron getNeuron(long id) {
  274.         final Neuron n = neuronMap.get(id);
  275.         if (n == null) {
  276.             throw new NoSuchElementException(Long.toString(id));
  277.         }
  278.         return n;
  279.     }

  280.     /**
  281.      * Retrieves the neurons in the neighbourhood of any neuron in the
  282.      * {@code neurons} list.
  283.      * @param neurons Neurons for which to retrieve the neighbours.
  284.      * @return the list of neighbours.
  285.      * @see #getNeighbours(Iterable,Iterable)
  286.      */
  287.     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
  288.         return getNeighbours(neurons, null);
  289.     }

  290.     /**
  291.      * Retrieves the neurons in the neighbourhood of any neuron in the
  292.      * {@code neurons} list.
  293.      * The {@code exclude} list allows to retrieve the "concentric"
  294.      * neighbourhoods by removing the neurons that belong to the inner
  295.      * "circles".
  296.      *
  297.      * @param neurons Neurons for which to retrieve the neighbours.
  298.      * @param exclude Neurons to exclude from the returned list.
  299.      * Can be {@code null}.
  300.      * @return the list of neighbours.
  301.      */
  302.     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
  303.                                             Iterable<Neuron> exclude) {
  304.         final Set<Long> idList = new HashSet<>();
  305.         neurons.forEach(n -> idList.addAll(linkMap.get(n.getIdentifier())));

  306.         if (exclude != null) {
  307.             exclude.forEach(n -> idList.remove(n.getIdentifier()));
  308.         }

  309.         return idList.stream().map(this::getNeuron).collect(Collectors.toList());
  310.     }

  311.     /**
  312.      * Retrieves the neighbours of the given neuron.
  313.      *
  314.      * @param neuron Neuron for which to retrieve the neighbours.
  315.      * @return the list of neighbours.
  316.      * @see #getNeighbours(Neuron,Iterable)
  317.      */
  318.     public Collection<Neuron> getNeighbours(Neuron neuron) {
  319.         return getNeighbours(neuron, null);
  320.     }

  321.     /**
  322.      * Retrieves the neighbours of the given neuron.
  323.      *
  324.      * @param neuron Neuron for which to retrieve the neighbours.
  325.      * @param exclude Neurons to exclude from the returned list.
  326.      * Can be {@code null}.
  327.      * @return the list of neighbours.
  328.      */
  329.     public Collection<Neuron> getNeighbours(Neuron neuron,
  330.                                             Iterable<Neuron> exclude) {
  331.         final Set<Long> idList = linkMap.get(neuron.getIdentifier());
  332.         if (exclude != null) {
  333.             for (final Neuron n : exclude) {
  334.                 idList.remove(n.getIdentifier());
  335.             }
  336.         }

  337.         final List<Neuron> neuronList = new ArrayList<>();
  338.         for (final Long id : idList) {
  339.             neuronList.add(getNeuron(id));
  340.         }

  341.         return neuronList;
  342.     }

  343.     /**
  344.      * Creates a neuron identifier.
  345.      *
  346.      * @return a value that will serve as a unique identifier.
  347.      */
  348.     private Long createNextId() {
  349.         return nextId.getAndIncrement();
  350.     }
  351. }