001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math3.ml.neuralnet; 019 020import java.io.Serializable; 021import java.io.ObjectInputStream; 022import java.util.NoSuchElementException; 023import java.util.List; 024import java.util.ArrayList; 025import java.util.Set; 026import java.util.HashSet; 027import java.util.Collection; 028import java.util.Iterator; 029import java.util.Comparator; 030import java.util.Collections; 031import java.util.Map; 032import java.util.concurrent.ConcurrentHashMap; 033import java.util.concurrent.atomic.AtomicLong; 034import org.apache.commons.math3.exception.DimensionMismatchException; 035import org.apache.commons.math3.exception.MathIllegalStateException; 036 037/** 038 * Neural network, composed of {@link Neuron} instances and the links 039 * between them. 040 * 041 * Although updating a neuron's state is thread-safe, modifying the 042 * network's topology (adding or removing links) is not. 043 * 044 * @since 3.3 045 */ 046public class Network 047 implements Iterable<Neuron>, 048 Serializable { 049 /** Serializable. */ 050 private static final long serialVersionUID = 20130207L; 051 /** Neurons. */ 052 private final ConcurrentHashMap<Long, Neuron> neuronMap 053 = new ConcurrentHashMap<Long, Neuron>(); 054 /** Next available neuron identifier. */ 055 private final AtomicLong nextId; 056 /** Neuron's features set size. */ 057 private final int featureSize; 058 /** Links. */ 059 private final ConcurrentHashMap<Long, Set<Long>> linkMap 060 = new ConcurrentHashMap<Long, Set<Long>>(); 061 062 /** 063 * Comparator that prescribes an order of the neurons according 064 * to the increasing order of their identifier. 065 */ 066 public static class NeuronIdentifierComparator 067 implements Comparator<Neuron>, 068 Serializable { 069 /** Version identifier. */ 070 private static final long serialVersionUID = 20130207L; 071 072 /** {@inheritDoc} */ 073 public int compare(Neuron a, 074 Neuron b) { 075 final long aId = a.getIdentifier(); 076 final long bId = b.getIdentifier(); 077 return aId < bId ? -1 : 078 aId > bId ? 1 : 0; 079 } 080 } 081 082 /** 083 * Constructor with restricted access, solely used for deserialization. 084 * 085 * @param nextId Next available identifier. 086 * @param featureSize Number of features. 087 * @param neuronList Neurons. 088 * @param neighbourIdList Links associated to each of the neurons in 089 * {@code neuronList}. 090 * @throws MathIllegalStateException if an inconsistency is detected 091 * (which probably means that the serialized form has been corrupted). 092 */ 093 Network(long nextId, 094 int featureSize, 095 Neuron[] neuronList, 096 long[][] neighbourIdList) { 097 final int numNeurons = neuronList.length; 098 if (numNeurons != neighbourIdList.length) { 099 throw new MathIllegalStateException(); 100 } 101 102 for (int i = 0; i < numNeurons; i++) { 103 final Neuron n = neuronList[i]; 104 final long id = n.getIdentifier(); 105 if (id >= nextId) { 106 throw new MathIllegalStateException(); 107 } 108 neuronMap.put(id, n); 109 linkMap.put(id, new HashSet<Long>()); 110 } 111 112 for (int i = 0; i < numNeurons; i++) { 113 final long aId = neuronList[i].getIdentifier(); 114 final Set<Long> aLinks = linkMap.get(aId); 115 for (Long bId : neighbourIdList[i]) { 116 if (neuronMap.get(bId) == null) { 117 throw new MathIllegalStateException(); 118 } 119 addLinkToLinkSet(aLinks, bId); 120 } 121 } 122 123 this.nextId = new AtomicLong(nextId); 124 this.featureSize = featureSize; 125 } 126 127 /** 128 * @param initialIdentifier Identifier for the first neuron that 129 * will be added to this network. 130 * @param featureSize Size of the neuron's features. 131 */ 132 public Network(long initialIdentifier, 133 int featureSize) { 134 nextId = new AtomicLong(initialIdentifier); 135 this.featureSize = featureSize; 136 } 137 138 /** 139 * Performs a deep copy of this instance. 140 * Upon return, the copied and original instances will be independent: 141 * Updating one will not affect the other. 142 * 143 * @return a new instance with the same state as this instance. 144 * @since 3.6 145 */ 146 public synchronized Network copy() { 147 final Network copy = new Network(nextId.get(), 148 featureSize); 149 150 151 for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) { 152 copy.neuronMap.put(e.getKey(), e.getValue().copy()); 153 } 154 155 for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) { 156 copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue())); 157 } 158 159 return copy; 160 } 161 162 /** 163 * {@inheritDoc} 164 */ 165 public Iterator<Neuron> iterator() { 166 return neuronMap.values().iterator(); 167 } 168 169 /** 170 * Creates a list of the neurons, sorted in a custom order. 171 * 172 * @param comparator {@link Comparator} used for sorting the neurons. 173 * @return a list of neurons, sorted in the order prescribed by the 174 * given {@code comparator}. 175 * @see NeuronIdentifierComparator 176 */ 177 public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) { 178 final List<Neuron> neurons = new ArrayList<Neuron>(); 179 neurons.addAll(neuronMap.values()); 180 181 Collections.sort(neurons, comparator); 182 183 return neurons; 184 } 185 186 /** 187 * Creates a neuron and assigns it a unique identifier. 188 * 189 * @param features Initial values for the neuron's features. 190 * @return the neuron's identifier. 191 * @throws DimensionMismatchException if the length of {@code features} 192 * is different from the expected size (as set by the 193 * {@link #Network(long,int) constructor}). 194 */ 195 public long createNeuron(double[] features) { 196 if (features.length != featureSize) { 197 throw new DimensionMismatchException(features.length, featureSize); 198 } 199 200 final long id = createNextId(); 201 neuronMap.put(id, new Neuron(id, features)); 202 linkMap.put(id, new HashSet<Long>()); 203 return id; 204 } 205 206 /** 207 * Deletes a neuron. 208 * Links from all neighbours to the removed neuron will also be 209 * {@link #deleteLink(Neuron,Neuron) deleted}. 210 * 211 * @param neuron Neuron to be removed from this network. 212 * @throws NoSuchElementException if {@code n} does not belong to 213 * this network. 214 */ 215 public void deleteNeuron(Neuron neuron) { 216 final Collection<Neuron> neighbours = getNeighbours(neuron); 217 218 // Delete links to from neighbours. 219 for (Neuron n : neighbours) { 220 deleteLink(n, neuron); 221 } 222 223 // Remove neuron. 224 neuronMap.remove(neuron.getIdentifier()); 225 } 226 227 /** 228 * Gets the size of the neurons' features set. 229 * 230 * @return the size of the features set. 231 */ 232 public int getFeaturesSize() { 233 return featureSize; 234 } 235 236 /** 237 * Adds a link from neuron {@code a} to neuron {@code b}. 238 * Note: the link is not bi-directional; if a bi-directional link is 239 * required, an additional call must be made with {@code a} and 240 * {@code b} exchanged in the argument list. 241 * 242 * @param a Neuron. 243 * @param b Neuron. 244 * @throws NoSuchElementException if the neurons do not exist in the 245 * network. 246 */ 247 public void addLink(Neuron a, 248 Neuron b) { 249 final long aId = a.getIdentifier(); 250 final long bId = b.getIdentifier(); 251 252 // Check that the neurons belong to this network. 253 if (a != getNeuron(aId)) { 254 throw new NoSuchElementException(Long.toString(aId)); 255 } 256 if (b != getNeuron(bId)) { 257 throw new NoSuchElementException(Long.toString(bId)); 258 } 259 260 // Add link from "a" to "b". 261 addLinkToLinkSet(linkMap.get(aId), bId); 262 } 263 264 /** 265 * Adds a link to neuron {@code id} in given {@code linkSet}. 266 * Note: no check verifies that the identifier indeed belongs 267 * to this network. 268 * 269 * @param linkSet Neuron identifier. 270 * @param id Neuron identifier. 271 */ 272 private void addLinkToLinkSet(Set<Long> linkSet, 273 long id) { 274 linkSet.add(id); 275 } 276 277 /** 278 * Deletes the link between neurons {@code a} and {@code b}. 279 * 280 * @param a Neuron. 281 * @param b Neuron. 282 * @throws NoSuchElementException if the neurons do not exist in the 283 * network. 284 */ 285 public void deleteLink(Neuron a, 286 Neuron b) { 287 final long aId = a.getIdentifier(); 288 final long bId = b.getIdentifier(); 289 290 // Check that the neurons belong to this network. 291 if (a != getNeuron(aId)) { 292 throw new NoSuchElementException(Long.toString(aId)); 293 } 294 if (b != getNeuron(bId)) { 295 throw new NoSuchElementException(Long.toString(bId)); 296 } 297 298 // Delete link from "a" to "b". 299 deleteLinkFromLinkSet(linkMap.get(aId), bId); 300 } 301 302 /** 303 * Deletes a link to neuron {@code id} in given {@code linkSet}. 304 * Note: no check verifies that the identifier indeed belongs 305 * to this network. 306 * 307 * @param linkSet Neuron identifier. 308 * @param id Neuron identifier. 309 */ 310 private void deleteLinkFromLinkSet(Set<Long> linkSet, 311 long id) { 312 linkSet.remove(id); 313 } 314 315 /** 316 * Retrieves the neuron with the given (unique) {@code id}. 317 * 318 * @param id Identifier. 319 * @return the neuron associated with the given {@code id}. 320 * @throws NoSuchElementException if the neuron does not exist in the 321 * network. 322 */ 323 public Neuron getNeuron(long id) { 324 final Neuron n = neuronMap.get(id); 325 if (n == null) { 326 throw new NoSuchElementException(Long.toString(id)); 327 } 328 return n; 329 } 330 331 /** 332 * Retrieves the neurons in the neighbourhood of any neuron in the 333 * {@code neurons} list. 334 * @param neurons Neurons for which to retrieve the neighbours. 335 * @return the list of neighbours. 336 * @see #getNeighbours(Iterable,Iterable) 337 */ 338 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) { 339 return getNeighbours(neurons, null); 340 } 341 342 /** 343 * Retrieves the neurons in the neighbourhood of any neuron in the 344 * {@code neurons} list. 345 * The {@code exclude} list allows to retrieve the "concentric" 346 * neighbourhoods by removing the neurons that belong to the inner 347 * "circles". 348 * 349 * @param neurons Neurons for which to retrieve the neighbours. 350 * @param exclude Neurons to exclude from the returned list. 351 * Can be {@code null}. 352 * @return the list of neighbours. 353 */ 354 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons, 355 Iterable<Neuron> exclude) { 356 final Set<Long> idList = new HashSet<Long>(); 357 358 for (Neuron n : neurons) { 359 idList.addAll(linkMap.get(n.getIdentifier())); 360 } 361 if (exclude != null) { 362 for (Neuron n : exclude) { 363 idList.remove(n.getIdentifier()); 364 } 365 } 366 367 final List<Neuron> neuronList = new ArrayList<Neuron>(); 368 for (Long id : idList) { 369 neuronList.add(getNeuron(id)); 370 } 371 372 return neuronList; 373 } 374 375 /** 376 * Retrieves the neighbours of the given neuron. 377 * 378 * @param neuron Neuron for which to retrieve the neighbours. 379 * @return the list of neighbours. 380 * @see #getNeighbours(Neuron,Iterable) 381 */ 382 public Collection<Neuron> getNeighbours(Neuron neuron) { 383 return getNeighbours(neuron, null); 384 } 385 386 /** 387 * Retrieves the neighbours of the given neuron. 388 * 389 * @param neuron Neuron for which to retrieve the neighbours. 390 * @param exclude Neurons to exclude from the returned list. 391 * Can be {@code null}. 392 * @return the list of neighbours. 393 */ 394 public Collection<Neuron> getNeighbours(Neuron neuron, 395 Iterable<Neuron> exclude) { 396 final Set<Long> idList = linkMap.get(neuron.getIdentifier()); 397 if (exclude != null) { 398 for (Neuron n : exclude) { 399 idList.remove(n.getIdentifier()); 400 } 401 } 402 403 final List<Neuron> neuronList = new ArrayList<Neuron>(); 404 for (Long id : idList) { 405 neuronList.add(getNeuron(id)); 406 } 407 408 return neuronList; 409 } 410 411 /** 412 * Creates a neuron identifier. 413 * 414 * @return a value that will serve as a unique identifier. 415 */ 416 private Long createNextId() { 417 return nextId.getAndIncrement(); 418 } 419 420 /** 421 * Prevents proxy bypass. 422 * 423 * @param in Input stream. 424 */ 425 private void readObject(ObjectInputStream in) { 426 throw new IllegalStateException(); 427 } 428 429 /** 430 * Custom serialization. 431 * 432 * @return the proxy instance that will be actually serialized. 433 */ 434 private Object writeReplace() { 435 final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]); 436 final long[][] neighbourIdList = new long[neuronList.length][]; 437 438 for (int i = 0; i < neuronList.length; i++) { 439 final Collection<Neuron> neighbours = getNeighbours(neuronList[i]); 440 final long[] neighboursId = new long[neighbours.size()]; 441 int count = 0; 442 for (Neuron n : neighbours) { 443 neighboursId[count] = n.getIdentifier(); 444 ++count; 445 } 446 neighbourIdList[i] = neighboursId; 447 } 448 449 return new SerializationProxy(nextId.get(), 450 featureSize, 451 neuronList, 452 neighbourIdList); 453 } 454 455 /** 456 * Serialization. 457 */ 458 private static class SerializationProxy implements Serializable { 459 /** Serializable. */ 460 private static final long serialVersionUID = 20130207L; 461 /** Next identifier. */ 462 private final long nextId; 463 /** Number of features. */ 464 private final int featureSize; 465 /** Neurons. */ 466 private final Neuron[] neuronList; 467 /** Links. */ 468 private final long[][] neighbourIdList; 469 470 /** 471 * @param nextId Next available identifier. 472 * @param featureSize Number of features. 473 * @param neuronList Neurons. 474 * @param neighbourIdList Links associated to each of the neurons in 475 * {@code neuronList}. 476 */ 477 SerializationProxy(long nextId, 478 int featureSize, 479 Neuron[] neuronList, 480 long[][] neighbourIdList) { 481 this.nextId = nextId; 482 this.featureSize = featureSize; 483 this.neuronList = neuronList; 484 this.neighbourIdList = neighbourIdList; 485 } 486 487 /** 488 * Custom serialization. 489 * 490 * @return the {@link Network} for which this instance is the proxy. 491 */ 492 private Object readResolve() { 493 return new Network(nextId, 494 featureSize, 495 neuronList, 496 neighbourIdList); 497 } 498 } 499}