1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.neuralnet.twod.util;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.simple.RandomSource;
22
23 import org.apache.commons.math4.neuralnet.Network;
24 import org.apache.commons.math4.neuralnet.FeatureInitializer;
25 import org.apache.commons.math4.neuralnet.FeatureInitializerFactory;
26 import org.apache.commons.math4.neuralnet.SquareNeighbourhood;
27 import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D;
28 import org.junit.Assert;
29 import org.junit.Test;
30
31
32
33
34 public class LocationFinderTest {
35 private final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
36 private final FeatureInitializer init = FeatureInitializerFactory.uniform(rng, 0, 2);
37
38
39
40
41
42
43
44
45
46 @Test
47 public void test2x2Network() {
48 final FeatureInitializer[] initArray = {init};
49 final NeuronSquareMesh2D map = new NeuronSquareMesh2D(2, false,
50 2, false,
51 SquareNeighbourhood.VON_NEUMANN,
52 initArray);
53 final LocationFinder finder = new LocationFinder(map);
54 final Network net = map.getNetwork();
55 LocationFinder.Location loc;
56
57 loc = finder.getLocation(net.getNeuron(0));
58 Assert.assertEquals(0, loc.getRow());
59 Assert.assertEquals(0, loc.getColumn());
60
61 loc = finder.getLocation(net.getNeuron(1));
62 Assert.assertEquals(0, loc.getRow());
63 Assert.assertEquals(1, loc.getColumn());
64
65 loc = finder.getLocation(net.getNeuron(2));
66 Assert.assertEquals(1, loc.getRow());
67 Assert.assertEquals(0, loc.getColumn());
68
69 loc = finder.getLocation(net.getNeuron(3));
70 Assert.assertEquals(1, loc.getRow());
71 Assert.assertEquals(1, loc.getColumn());
72 }
73 }