View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.neuralnet;
19  
20  import java.util.Collection;
21  import java.util.Arrays;
22  import java.util.NoSuchElementException;
23  
24  import org.junit.Assert;
25  import org.junit.Test;
26  
27  import org.apache.commons.rng.UniformRandomProvider;
28  import org.apache.commons.rng.simple.RandomSource;
29  
30  import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D;
31  
32  /**
33   * Tests for {@link Network}.
34   */
35  public class NetworkTest {
36      private final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
37      private final FeatureInitializer init = FeatureInitializerFactory.uniform(rng, 0, 2);
38  
39      @Test
40      public void testGetFeaturesSize() {
41          final FeatureInitializer[] initArray = {init, init, init};
42  
43          final Network net = new NeuronSquareMesh2D(2, false,
44                                                     2, false,
45                                                     SquareNeighbourhood.VON_NEUMANN,
46                                                     initArray).getNetwork();
47          Assert.assertEquals(3, net.getFeaturesSize());
48      }
49  
50      /*
51       * Test assumes that the network is
52       *
53       *  0-----1
54       *  |     |
55       *  |     |
56       *  2-----3
57       */
58      @Test
59      public void testDeleteLink() {
60          final FeatureInitializer[] initArray = {init};
61          final Network net = new NeuronSquareMesh2D(2, false,
62                                                     2, false,
63                                                     SquareNeighbourhood.VON_NEUMANN,
64                                                     initArray).getNetwork();
65          Collection<Neuron> neighbours;
66  
67          // Delete 0 --> 1.
68          net.deleteLink(net.getNeuron(0),
69                         net.getNeuron(1));
70  
71          // Link from 0 to 1 was deleted.
72          neighbours = net.getNeighbours(net.getNeuron(0));
73          Assert.assertFalse(neighbours.contains(net.getNeuron(1)));
74          // Link from 1 to 0 still exists.
75          neighbours = net.getNeighbours(net.getNeuron(1));
76          Assert.assertTrue(neighbours.contains(net.getNeuron(0)));
77      }
78  
79      /*
80       * Test assumes that the network is
81       *
82       *  0-----1
83       *  |     |
84       *  |     |
85       *  2-----3
86       */
87      @Test
88      public void testDeleteNeuron() {
89          final FeatureInitializer[] initArray = {init};
90          final Network net = new NeuronSquareMesh2D(2, false,
91                                                     2, false,
92                                                     SquareNeighbourhood.VON_NEUMANN,
93                                                     initArray).getNetwork();
94  
95          Assert.assertEquals(2, net.getNeighbours(net.getNeuron(0)).size());
96          Assert.assertEquals(2, net.getNeighbours(net.getNeuron(3)).size());
97  
98          // Delete neuron 1.
99          net.deleteNeuron(net.getNeuron(1));
100 
101         try {
102             net.getNeuron(1);
103         } catch (NoSuchElementException expected) {
104           // Ignore
105         }
106 
107         Assert.assertEquals(1, net.getNeighbours(net.getNeuron(0)).size());
108         Assert.assertEquals(1, net.getNeighbours(net.getNeuron(3)).size());
109     }
110 
111     @Test
112     public void testIdentifierAssignment() {
113         final FeatureInitializer[] initArray = {init};
114         final long[] ids = getIdentifiers(new NeuronSquareMesh2D(4, false,
115                                                                 3, true,
116                                                                 SquareNeighbourhood.VON_NEUMANN,
117                                                                 initArray).getNetwork());
118 
119         Assert.assertEquals(12, ids.length);
120         Assert.assertEquals(0, ids[0]);
121         Assert.assertEquals(11, ids[ids.length - 1]);
122     }
123 
124     /*
125      * Test assumes that the network is
126      *
127      *  0-----1
128      *  |     |
129      *  |     |
130      *  2-----3
131      */
132     @Test
133     public void testCopy() {
134         final FeatureInitializer[] initArray = {init};
135         final Network net = new NeuronSquareMesh2D(2, false,
136                                                    2, false,
137                                                    SquareNeighbourhood.VON_NEUMANN,
138                                                    initArray).getNetwork();
139 
140         final Network copy = net.copy();
141 
142         final Neuron netNeuron0 = net.getNeuron(0);
143         final Neuron copyNeuron0 = copy.getNeuron(0);
144         final Neuron netNeuron1 = net.getNeuron(1);
145         final Neuron copyNeuron1 = copy.getNeuron(1);
146         Collection<Neuron> netNeighbours;
147         Collection<Neuron> copyNeighbours;
148 
149         // Check that both networks have the same connections.
150         netNeighbours = net.getNeighbours(netNeuron0);
151         copyNeighbours = copy.getNeighbours(copyNeuron0);
152         Assert.assertTrue(netNeighbours.contains(netNeuron1));
153         Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
154 
155         // Delete neuron 1 from original.
156         net.deleteNeuron(netNeuron1);
157 
158         // Check that the networks now differ.
159         netNeighbours = net.getNeighbours(netNeuron0);
160         copyNeighbours = copy.getNeighbours(copyNeuron0);
161         Assert.assertFalse(netNeighbours.contains(netNeuron1));
162         Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
163     }
164 
165     /**
166      * @param net Network.
167      * @return the sorted list identifiers.
168      */
169     private long[] getIdentifiers(Network net) {
170         final Collection<Neuron> neurons = net.getNeurons();
171         final long[] identifiers = new long[neurons.size()];
172 
173         int idx = 0;
174         for (Neuron n : neurons) {
175             identifiers[idx++] = n.getIdentifier();
176         }
177         Arrays.sort(identifiers);
178         return identifiers;
179     }
180 }