1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.neuralnet.oned;
19
20 import java.util.ArrayList;
21 import java.util.Collection;
22
23 import org.junit.Assert;
24 import org.junit.Test;
25
26 import org.apache.commons.rng.UniformRandomProvider;
27 import org.apache.commons.rng.simple.RandomSource;
28
29 import org.apache.commons.math4.neuralnet.FeatureInitializer;
30 import org.apache.commons.math4.neuralnet.FeatureInitializerFactory;
31 import org.apache.commons.math4.neuralnet.Network;
32 import org.apache.commons.math4.neuralnet.Neuron;
33
34
35
36
37
38 public class NeuronStringTest {
39 private final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
40 private final FeatureInitializer init = FeatureInitializerFactory.uniform(rng, 0, 2);
41
42
43
44
45
46
47 @Test
48 public void testSegmentNetwork() {
49 final FeatureInitializer[] initArray = {init};
50 final NeuronString line = new NeuronString(4, false, initArray);
51 Assert.assertFalse(line.isWrapped());
52
53 final Network net = line.getNetwork();
54 Collection<Neuron> neighbours;
55
56
57 neighbours = net.getNeighbours(net.getNeuron(0));
58 for (long nId : new long[] {1}) {
59 Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
60 }
61
62 Assert.assertEquals(1, neighbours.size());
63
64
65 neighbours = net.getNeighbours(net.getNeuron(1));
66 for (long nId : new long[] {0, 2}) {
67 Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
68 }
69
70 Assert.assertEquals(2, neighbours.size());
71
72
73 neighbours = net.getNeighbours(net.getNeuron(2));
74 for (long nId : new long[] {1, 3}) {
75 Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
76 }
77
78 Assert.assertEquals(2, neighbours.size());
79
80
81 neighbours = net.getNeighbours(net.getNeuron(3));
82 for (long nId : new long[] {2}) {
83 Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
84 }
85
86 Assert.assertEquals(1, neighbours.size());
87 }
88
89
90
91
92
93
94 @Test
95 public void testCircleNetwork() {
96 final FeatureInitializer[] initArray = {init};
97 final NeuronString line = new NeuronString(4, true, initArray);
98 Assert.assertTrue(line.isWrapped());
99
100 final Network net = line.getNetwork();
101 Collection<Neuron> neighbours;
102
103
104 neighbours = net.getNeighbours(net.getNeuron(0));
105 for (long nId : new long[] {1, 3}) {
106 Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
107 }
108
109 Assert.assertEquals(2, neighbours.size());
110
111
112 neighbours = net.getNeighbours(net.getNeuron(1));
113 for (long nId : new long[] {0, 2}) {
114 Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
115 }
116
117 Assert.assertEquals(2, neighbours.size());
118
119
120 neighbours = net.getNeighbours(net.getNeuron(2));
121 for (long nId : new long[] {1, 3}) {
122 Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
123 }
124
125 Assert.assertEquals(2, neighbours.size());
126
127
128 neighbours = net.getNeighbours(net.getNeuron(3));
129 for (long nId : new long[] {0, 2}) {
130 Assert.assertTrue(neighbours.contains(net.getNeuron(nId)));
131 }
132
133 Assert.assertEquals(2, neighbours.size());
134 }
135
136
137
138
139
140
141 @Test
142 public void testGetNeighboursWithExclude() {
143 final FeatureInitializer[] initArray = {init};
144 final Network net = new NeuronString(5, true, initArray).getNetwork();
145 final Collection<Neuron> exclude = new ArrayList<>();
146 exclude.add(net.getNeuron(1));
147 final Collection<Neuron> neighbours = net.getNeighbours(net.getNeuron(0),
148 exclude);
149 Assert.assertTrue(neighbours.contains(net.getNeuron(4)));
150 Assert.assertEquals(1, neighbours.size());
151 }
152 }