View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.neuralnet.oned;
19  
20  import java.util.ArrayList;
21  import java.util.Collection;
22  
23  import org.junit.Assert;
24  import org.junit.Test;
25  
26  import org.apache.commons.rng.UniformRandomProvider;
27  import org.apache.commons.rng.simple.RandomSource;
28  
29  import org.apache.commons.math4.neuralnet.FeatureInitializer;
30  import org.apache.commons.math4.neuralnet.FeatureInitializerFactory;
31  import org.apache.commons.math4.neuralnet.Network;
32  import org.apache.commons.math4.neuralnet.Neuron;
33  
34  /**
35   * Tests for {@link NeuronString} and {@link Network} functionality for
36   * a one-dimensional network.
37   */
38  public class NeuronStringTest {
39      private final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
40      private final FeatureInitializer init = FeatureInitializerFactory.uniform(rng, 0, 2);
41  
42      /*
43       * Test assumes that the network is
44       *
45       *  0-----1-----2-----3
46       */
47      @Test
48      public void testSegmentNetwork() {
49          final FeatureInitializer[] initArray = {init};
50          final NeuronString line = new NeuronString(4, false, initArray);
51          Assert.assertFalse(line.isWrapped());
52  
53          final Network net = line.getNetwork();
54          Collection<Neuron> neighbours;
55  
56          // Neuron 0.
57          neighbours = net.getNeighbours(net.getNeuron(0));
58          for (long nId : new long[] {1}) {
59              Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
60          }
61          // Ensures that no other neurons is in the neighbourhood set.
62          Assert.assertEquals(1, neighbours.size());
63  
64          // Neuron 1.
65          neighbours = net.getNeighbours(net.getNeuron(1));
66          for (long nId : new long[] {0, 2}) {
67              Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
68          }
69          // Ensures that no other neurons is in the neighbourhood set.
70          Assert.assertEquals(2, neighbours.size());
71  
72          // Neuron 2.
73          neighbours = net.getNeighbours(net.getNeuron(2));
74          for (long nId : new long[] {1, 3}) {
75              Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
76          }
77          // Ensures that no other neurons is in the neighbourhood set.
78          Assert.assertEquals(2, neighbours.size());
79  
80          // Neuron 3.
81          neighbours = net.getNeighbours(net.getNeuron(3));
82          for (long nId : new long[] {2}) {
83              Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
84          }
85          // Ensures that no other neurons is in the neighbourhood set.
86          Assert.assertEquals(1, neighbours.size());
87      }
88  
89      /*
90       * Test assumes that the network is
91       *
92       *  0-----1-----2-----3
93       */
94      @Test
95      public void testCircleNetwork() {
96          final FeatureInitializer[] initArray = {init};
97          final NeuronString line = new NeuronString(4, true, initArray);
98          Assert.assertTrue(line.isWrapped());
99  
100         final Network net = line.getNetwork();
101         Collection<Neuron> neighbours;
102 
103         // Neuron 0.
104         neighbours = net.getNeighbours(net.getNeuron(0));
105         for (long nId : new long[] {1, 3}) {
106             Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
107         }
108         // Ensures that no other neurons is in the neighbourhood set.
109         Assert.assertEquals(2, neighbours.size());
110 
111         // Neuron 1.
112         neighbours = net.getNeighbours(net.getNeuron(1));
113         for (long nId : new long[] {0, 2}) {
114             Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
115         }
116         // Ensures that no other neurons is in the neighbourhood set.
117         Assert.assertEquals(2, neighbours.size());
118 
119         // Neuron 2.
120         neighbours = net.getNeighbours(net.getNeuron(2));
121         for (long nId : new long[] {1, 3}) {
122             Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
123         }
124         // Ensures that no other neurons is in the neighbourhood set.
125         Assert.assertEquals(2, neighbours.size());
126 
127         // Neuron 3.
128         neighbours = net.getNeighbours(net.getNeuron(3));
129         for (long nId : new long[] {0, 2}) {
130             Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
131         }
132         // Ensures that no other neurons is in the neighbourhood set.
133         Assert.assertEquals(2, neighbours.size());
134     }
135 
136     /*
137      * Test assumes that the network is
138      *
139      *  0-----1-----2-----3-----4
140      */
141     @Test
142     public void testGetNeighboursWithExclude() {
143         final FeatureInitializer[] initArray = {init};
144         final Network net = new NeuronString(5, true, initArray).getNetwork();
145         final Collection<Neuron> exclude = new ArrayList<>();
146         exclude.add(net.getNeuron(1));
147         final Collection<Neuron> neighbours = net.getNeighbours(net.getNeuron(0),
148                                                                 exclude);
149         Assert.assertTrue(neighbours.contains(net.getNeuron(4)));
150         Assert.assertEquals(1, neighbours.size());
151     }
152 }