1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.neuralnet.twod;
19
20 import java.util.List;
21 import java.util.ArrayList;
22 import java.util.Iterator;
23 import java.util.Collection;
24
25 import org.apache.commons.math4.neuralnet.DistanceMeasure;
26 import org.apache.commons.math4.neuralnet.EuclideanDistance;
27 import org.apache.commons.math4.neuralnet.FeatureInitializer;
28 import org.apache.commons.math4.neuralnet.Network;
29 import org.apache.commons.math4.neuralnet.Neuron;
30 import org.apache.commons.math4.neuralnet.SquareNeighbourhood;
31 import org.apache.commons.math4.neuralnet.MapRanking;
32 import org.apache.commons.math4.neuralnet.internal.NeuralNetException;
33 import org.apache.commons.math4.neuralnet.twod.util.LocationFinder;
34
35
36
37
38
39
40
41
42
43
44
45
46 public class NeuronSquareMesh2D
47 implements Iterable<Neuron> {
48
49 private static final int MIN_ROWS = 2;
50
51 private final Network network;
52
53 private final int numberOfRows;
54
55 private final int numberOfColumns;
56
57 private final boolean wrapRows;
58
59 private final boolean wrapColumns;
60
61 private final SquareNeighbourhood neighbourhood;
62
63
64
65
66
67 private final long[][] identifiers;
68
69
70
71
72
73 public enum HorizontalDirection {
74
75 RIGHT,
76
77 CENTER,
78
79 LEFT,
80 }
81
82
83
84
85 public enum VerticalDirection {
86
87 UP,
88
89 CENTER,
90
91 DOWN,
92 }
93
94
95
96
97
98
99
100
101
102
103
104
105 public NeuronSquareMesh2D(boolean wrapRowDim,
106 boolean wrapColDim,
107 SquareNeighbourhood neighbourhoodType,
108 double[][][] featuresList) {
109 numberOfRows = featuresList.length;
110 numberOfColumns = featuresList[0].length;
111
112 if (numberOfRows < MIN_ROWS) {
113 throw new NeuralNetException(NeuralNetException.TOO_SMALL, numberOfRows, MIN_ROWS);
114 }
115 if (numberOfColumns < MIN_ROWS) {
116 throw new NeuralNetException(NeuralNetException.TOO_SMALL, numberOfColumns, MIN_ROWS);
117 }
118
119 wrapRows = wrapRowDim;
120 wrapColumns = wrapColDim;
121 neighbourhood = neighbourhoodType;
122
123 final int fLen = featuresList[0][0].length;
124 network = new Network(0, fLen);
125 identifiers = new long[numberOfRows][numberOfColumns];
126
127
128 for (int i = 0; i < numberOfRows; i++) {
129 for (int j = 0; j < numberOfColumns; j++) {
130 identifiers[i][j] = network.createNeuron(featuresList[i][j]);
131 }
132 }
133
134
135 createLinks();
136 }
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163 public NeuronSquareMesh2D(int numRows,
164 boolean wrapRowDim,
165 int numCols,
166 boolean wrapColDim,
167 SquareNeighbourhood neighbourhoodType,
168 FeatureInitializer[] featureInit) {
169 if (numRows < MIN_ROWS) {
170 throw new NeuralNetException(NeuralNetException.TOO_SMALL, numRows, MIN_ROWS);
171 }
172 if (numCols < MIN_ROWS) {
173 throw new NeuralNetException(NeuralNetException.TOO_SMALL, numCols, MIN_ROWS);
174 }
175
176 numberOfRows = numRows;
177 wrapRows = wrapRowDim;
178 numberOfColumns = numCols;
179 wrapColumns = wrapColDim;
180 neighbourhood = neighbourhoodType;
181 identifiers = new long[numberOfRows][numberOfColumns];
182
183 final int fLen = featureInit.length;
184 network = new Network(0, fLen);
185
186
187 for (int i = 0; i < numRows; i++) {
188 for (int j = 0; j < numCols; j++) {
189 final double[] features = new double[fLen];
190 for (int fIndex = 0; fIndex < fLen; fIndex++) {
191 features[fIndex] = featureInit[fIndex].value();
192 }
193 identifiers[i][j] = network.createNeuron(features);
194 }
195 }
196
197
198 createLinks();
199 }
200
201
202
203
204
205
206
207
208
209
210
211
212
213 private NeuronSquareMesh2D(boolean wrapRowDim,
214 boolean wrapColDim,
215 SquareNeighbourhood neighbourhoodType,
216 Network net,
217 long[][] idGrid) {
218 numberOfRows = idGrid.length;
219 numberOfColumns = idGrid[0].length;
220 wrapRows = wrapRowDim;
221 wrapColumns = wrapColDim;
222 neighbourhood = neighbourhoodType;
223 network = net;
224 identifiers = idGrid;
225 }
226
227
228
229
230
231
232
233
234
235 public synchronized NeuronSquareMesh2D copy() {
236 final long[][] idGrid = new long[numberOfRows][numberOfColumns];
237 for (int r = 0; r < numberOfRows; r++) {
238 System.arraycopy(identifiers[r], 0, idGrid[r], 0, numberOfColumns);
239 }
240
241 return new NeuronSquareMesh2D(wrapRows,
242 wrapColumns,
243 neighbourhood,
244 network.copy(),
245 idGrid);
246 }
247
248
249 @Override
250 public Iterator<Neuron> iterator() {
251 return network.iterator();
252 }
253
254
255
256
257
258
259
260
261
262
263 public Network getNetwork() {
264 return network;
265 }
266
267
268
269
270
271
272 public int getNumberOfRows() {
273 return numberOfRows;
274 }
275
276
277
278
279
280
281 public int getNumberOfColumns() {
282 return numberOfColumns;
283 }
284
285
286
287
288
289
290
291 public boolean isWrappedRow() {
292 return wrapRows;
293 }
294
295
296
297
298
299
300
301 public boolean isWrappedColumn() {
302 return wrapColumns;
303 }
304
305
306
307
308
309
310
311 public SquareNeighbourhood getSquareNeighbourhood() {
312 return neighbourhood;
313 }
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328 public Neuron getNeuron(int i,
329 int j) {
330 if (i < 0 ||
331 i >= numberOfRows) {
332 throw new NeuralNetException(NeuralNetException.OUT_OF_RANGE,
333 i, 0, numberOfRows - 1);
334 }
335 if (j < 0 ||
336 j >= numberOfColumns) {
337 throw new NeuralNetException(NeuralNetException.OUT_OF_RANGE,
338 i, 0, numberOfColumns - 1);
339 }
340
341 return network.getNeuron(identifiers[i][j]);
342 }
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361 public Neuron getNeuron(int row,
362 int col,
363 HorizontalDirection alongRowDir,
364 VerticalDirection alongColDir) {
365 final int[] location = getLocation(row, col, alongRowDir, alongColDir);
366
367 return location == null ? null : getNeuron(location[0], location[1]);
368 }
369
370
371
372
373
374
375
376
377 public DataVisualization computeQualityIndicators(Iterable<double[]> data) {
378 return DataVisualization.from(copy(), data);
379 }
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398 private int[] getLocation(int row,
399 int col,
400 HorizontalDirection alongRowDir,
401 VerticalDirection alongColDir) {
402 final int colOffset;
403 switch (alongRowDir) {
404 case LEFT:
405 colOffset = -1;
406 break;
407 case RIGHT:
408 colOffset = 1;
409 break;
410 case CENTER:
411 colOffset = 0;
412 break;
413 default:
414
415 throw new IllegalStateException();
416 }
417 int colIndex = col + colOffset;
418 if (wrapColumns) {
419 if (colIndex < 0) {
420 colIndex += numberOfColumns;
421 } else {
422 colIndex %= numberOfColumns;
423 }
424 }
425
426 final int rowOffset;
427 switch (alongColDir) {
428 case UP:
429 rowOffset = -1;
430 break;
431 case DOWN:
432 rowOffset = 1;
433 break;
434 case CENTER:
435 rowOffset = 0;
436 break;
437 default:
438
439 throw new IllegalStateException();
440 }
441 int rowIndex = row + rowOffset;
442 if (wrapRows) {
443 if (rowIndex < 0) {
444 rowIndex += numberOfRows;
445 } else {
446 rowIndex %= numberOfRows;
447 }
448 }
449
450 if (rowIndex < 0 ||
451 rowIndex >= numberOfRows ||
452 colIndex < 0 ||
453 colIndex >= numberOfColumns) {
454 return null;
455 } else {
456 return new int[] {rowIndex, colIndex};
457 }
458 }
459
460
461
462
463 private void createLinks() {
464
465 final List<Long> linkEnd = new ArrayList<>();
466 final int iLast = numberOfRows - 1;
467 final int jLast = numberOfColumns - 1;
468 for (int i = 0; i < numberOfRows; i++) {
469 for (int j = 0; j < numberOfColumns; j++) {
470 linkEnd.clear();
471
472 switch (neighbourhood) {
473
474 case MOORE:
475
476 if (i > 0) {
477 if (j > 0) {
478 linkEnd.add(identifiers[i - 1][j - 1]);
479 }
480 if (j < jLast) {
481 linkEnd.add(identifiers[i - 1][j + 1]);
482 }
483 }
484 if (i < iLast) {
485 if (j > 0) {
486 linkEnd.add(identifiers[i + 1][j - 1]);
487 }
488 if (j < jLast) {
489 linkEnd.add(identifiers[i + 1][j + 1]);
490 }
491 }
492 if (wrapRows) {
493 if (i == 0) {
494 if (j > 0) {
495 linkEnd.add(identifiers[iLast][j - 1]);
496 }
497 if (j < jLast) {
498 linkEnd.add(identifiers[iLast][j + 1]);
499 }
500 } else if (i == iLast) {
501 if (j > 0) {
502 linkEnd.add(identifiers[0][j - 1]);
503 }
504 if (j < jLast) {
505 linkEnd.add(identifiers[0][j + 1]);
506 }
507 }
508 }
509 if (wrapColumns) {
510 if (j == 0) {
511 if (i > 0) {
512 linkEnd.add(identifiers[i - 1][jLast]);
513 }
514 if (i < iLast) {
515 linkEnd.add(identifiers[i + 1][jLast]);
516 }
517 } else if (j == jLast) {
518 if (i > 0) {
519 linkEnd.add(identifiers[i - 1][0]);
520 }
521 if (i < iLast) {
522 linkEnd.add(identifiers[i + 1][0]);
523 }
524 }
525 }
526 if (wrapRows &&
527 wrapColumns) {
528 if (i == 0 &&
529 j == 0) {
530 linkEnd.add(identifiers[iLast][jLast]);
531 } else if (i == 0 &&
532 j == jLast) {
533 linkEnd.add(identifiers[iLast][0]);
534 } else if (i == iLast &&
535 j == 0) {
536 linkEnd.add(identifiers[0][jLast]);
537 } else if (i == iLast &&
538 j == jLast) {
539 linkEnd.add(identifiers[0][0]);
540 }
541 }
542
543
544
545
546
547
548 case VON_NEUMANN:
549
550 if (i > 0) {
551 linkEnd.add(identifiers[i - 1][j]);
552 }
553 if (i < iLast) {
554 linkEnd.add(identifiers[i + 1][j]);
555 }
556 if (wrapRows) {
557 if (i == 0) {
558 linkEnd.add(identifiers[iLast][j]);
559 } else if (i == iLast) {
560 linkEnd.add(identifiers[0][j]);
561 }
562 }
563
564
565 if (j > 0) {
566 linkEnd.add(identifiers[i][j - 1]);
567 }
568 if (j < jLast) {
569 linkEnd.add(identifiers[i][j + 1]);
570 }
571 if (wrapColumns) {
572 if (j == 0) {
573 linkEnd.add(identifiers[i][jLast]);
574 } else if (j == jLast) {
575 linkEnd.add(identifiers[i][0]);
576 }
577 }
578 break;
579
580 default:
581 throw new IllegalStateException();
582 }
583
584 final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
585 for (final long b : linkEnd) {
586 final Neuron bNeuron = network.getNeuron(b);
587
588
589 network.addLink(aNeuron, bNeuron);
590 }
591 }
592 }
593 }
594
595
596
597
598
599
600
601
602
603
604 public static final class DataVisualization {
605
606 private static final DistanceMeasure DISTANCE = new EuclideanDistance();
607
608 private final int numberOfSamples;
609
610 private final double[][] hitHistogram;
611
612 private final double[][] quantizationError;
613
614 private final double meanQuantizationError;
615
616 private final double[][] topographicError;
617
618 private final double meanTopographicError;
619
620 private final double[][] uMatrix;
621
622
623
624
625
626
627
628
629 private DataVisualization(int numberOfSamples,
630 double[][] hitHistogram,
631 double[][] quantizationError,
632 double[][] topographicError,
633 double[][] uMatrix) {
634 this.numberOfSamples = numberOfSamples;
635 this.hitHistogram = hitHistogram;
636 this.quantizationError = quantizationError;
637 meanQuantizationError = hitWeightedMean(quantizationError, hitHistogram);
638 this.topographicError = topographicError;
639 meanTopographicError = hitWeightedMean(topographicError, hitHistogram);
640 this.uMatrix = uMatrix;
641 }
642
643
644
645
646
647
648 static DataVisualization from(NeuronSquareMesh2D map,
649 Iterable<double[]> data) {
650 final LocationFinder finder = new LocationFinder(map);
651 final MapRanking rank = new MapRanking(map, DISTANCE);
652 final Network net = map.getNetwork();
653 final int nR = map.getNumberOfRows();
654 final int nC = map.getNumberOfColumns();
655
656
657 final int[][] hitCounter = new int[nR][nC];
658
659 final double[][] hitHistogram = new double[nR][nC];
660
661 final double[][] quantizationError = new double[nR][nC];
662
663 final double[][] topographicError = new double[nR][nC];
664
665 final double[][] uMatrix = new double[nR][nC];
666
667 int numSamples = 0;
668 for (final double[] sample : data) {
669 ++numSamples;
670
671 final List<Neuron> winners = rank.rank(sample, 2);
672 final Neuron best = winners.get(0);
673 final Neuron secondBest = winners.get(1);
674
675 final LocationFinder.Location locBest = finder.getLocation(best);
676 final int rowBest = locBest.getRow();
677 final int colBest = locBest.getColumn();
678
679 hitCounter[rowBest][colBest] += 1;
680
681
682 quantizationError[rowBest][colBest] += DISTANCE.applyAsDouble(sample, best.getFeatures());
683
684
685 if (!net.getNeighbours(best).contains(secondBest)) {
686
687
688 topographicError[rowBest][colBest] += 1;
689 }
690 }
691
692 for (int r = 0; r < nR; r++) {
693 for (int c = 0; c < nC; c++) {
694 final Neuron neuron = map.getNeuron(r, c);
695 final Collection<Neuron> neighbours = net.getNeighbours(neuron);
696 final double[] features = neuron.getFeatures();
697 double uDistance = 0;
698 int neighbourCount = 0;
699 for (final Neuron n : neighbours) {
700 ++neighbourCount;
701 uDistance += DISTANCE.applyAsDouble(features, n.getFeatures());
702 }
703
704 final int hitCount = hitCounter[r][c];
705 if (hitCount != 0) {
706 hitHistogram[r][c] = hitCount / (double) numSamples;
707 quantizationError[r][c] /= hitCount;
708 topographicError[r][c] /= hitCount;
709 }
710
711 uMatrix[r][c] = uDistance / neighbourCount;
712 }
713 }
714
715 return new DataVisualization(numSamples,
716 hitHistogram,
717 quantizationError,
718 topographicError,
719 uMatrix);
720 }
721
722
723
724
725 public int getNumberOfSamples() {
726 return numberOfSamples;
727 }
728
729
730
731
732
733
734
735 public double[][] getQuantizationError() {
736 return copy(quantizationError);
737 }
738
739
740
741
742
743
744
745 public double[][] getTopographicError() {
746 return copy(topographicError);
747 }
748
749
750
751
752
753
754 public double[][] getNormalizedHits() {
755 return copy(hitHistogram);
756 }
757
758
759
760
761
762
763
764
765
766 public double[][] getUMatrix() {
767 return copy(uMatrix);
768 }
769
770
771
772
773
774 public double getMeanQuantizationError() {
775 return meanQuantizationError;
776 }
777
778
779
780
781
782 public double getMeanTopographicError() {
783 return meanTopographicError;
784 }
785
786
787
788
789
790 private static double[][] copy(double[][] orig) {
791 final double[][] copy = new double[orig.length][];
792 for (int i = 0; i < orig.length; i++) {
793 copy[i] = orig[i].clone();
794 }
795
796 return copy;
797 }
798
799
800
801
802
803
804 private static double hitWeightedMean(double[][] metrics,
805 double[][] normalizedHits) {
806 double mean = 0;
807 final int rows = metrics.length;
808 final int cols = metrics[0].length;
809 for (int i = 0; i < rows; i++) {
810 for (int j = 0; j < cols; j++) {
811 mean += normalizedHits[i][j] * metrics[i][j];
812 }
813 }
814
815 return mean;
816 }
817 }
818 }