NeuronSquareMesh2D.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet.twod;

  18. import java.util.List;
  19. import java.util.ArrayList;
  20. import java.util.Iterator;
  21. import java.util.Collection;

  22. import org.apache.commons.math4.neuralnet.DistanceMeasure;
  23. import org.apache.commons.math4.neuralnet.EuclideanDistance;
  24. import org.apache.commons.math4.neuralnet.FeatureInitializer;
  25. import org.apache.commons.math4.neuralnet.Network;
  26. import org.apache.commons.math4.neuralnet.Neuron;
  27. import org.apache.commons.math4.neuralnet.SquareNeighbourhood;
  28. import org.apache.commons.math4.neuralnet.MapRanking;
  29. import org.apache.commons.math4.neuralnet.internal.NeuralNetException;
  30. import org.apache.commons.math4.neuralnet.twod.util.LocationFinder;

  31. /**
  32.  * Neural network with the topology of a two-dimensional surface.
  33.  * Each neuron defines one surface element.
  34.  * <br>
  35.  * This network is primarily intended to represent a
  36.  * <a href="http://en.wikipedia.org/wiki/Kohonen">
  37.  *  Self Organizing Feature Map</a>.
  38.  *
  39.  * @see org.apache.commons.math4.neuralnet.sofm
  40.  * @since 3.3
  41.  */
  42. public class NeuronSquareMesh2D
  43.     implements Iterable<Neuron> {
  44.     /** Minimal number of rows or columns. */
  45.     private static final int MIN_ROWS = 2;
  46.     /** Underlying network. */
  47.     private final Network network;
  48.     /** Number of rows. */
  49.     private final int numberOfRows;
  50.     /** Number of columns. */
  51.     private final int numberOfColumns;
  52.     /** Wrap. */
  53.     private final boolean wrapRows;
  54.     /** Wrap. */
  55.     private final boolean wrapColumns;
  56.     /** Neighbourhood type. */
  57.     private final SquareNeighbourhood neighbourhood;
  58.     /**
  59.      * Mapping of the 2D coordinates (in the rectangular mesh) to
  60.      * the neuron identifiers (attributed by the {@link #network}
  61.      * instance).
  62.      */
  63.     private final long[][] identifiers;

  64.     /**
  65.      * Horizontal (along row) direction.
  66.      * @since 3.6
  67.      */
  68.     public enum HorizontalDirection {
  69.         /** Column at the right of the current column. */
  70.        RIGHT,
  71.        /** Current column. */
  72.        CENTER,
  73.        /** Column at the left of the current column. */
  74.        LEFT,
  75.     }
  76.     /**
  77.      * Vertical (along column) direction.
  78.      * @since 3.6
  79.      */
  80.     public enum VerticalDirection {
  81.         /** Row above the current row. */
  82.         UP,
  83.         /** Current row. */
  84.         CENTER,
  85.         /** Row below the current row. */
  86.         DOWN,
  87.     }

  88.     /**
  89.      * @param wrapRowDim Whether to wrap the first dimension (i.e the first
  90.      * and last neurons will be linked together).
  91.      * @param wrapColDim Whether to wrap the second dimension (i.e the first
  92.      * and last neurons will be linked together).
  93.      * @param neighbourhoodType Neighbourhood type.
  94.      * @param featuresList Arrays that will initialize the features sets of
  95.      * the network's neurons.
  96.      * @throws IllegalArgumentException if {@code numRows < 2} or
  97.      * {@code numCols < 2}.
  98.      */
  99.     public NeuronSquareMesh2D(boolean wrapRowDim,
  100.                               boolean wrapColDim,
  101.                               SquareNeighbourhood neighbourhoodType,
  102.                               double[][][] featuresList) {
  103.         numberOfRows = featuresList.length;
  104.         numberOfColumns = featuresList[0].length;

  105.         if (numberOfRows < MIN_ROWS) {
  106.             throw new NeuralNetException(NeuralNetException.TOO_SMALL, numberOfRows, MIN_ROWS);
  107.         }
  108.         if (numberOfColumns < MIN_ROWS) {
  109.             throw new NeuralNetException(NeuralNetException.TOO_SMALL, numberOfColumns, MIN_ROWS);
  110.         }

  111.         wrapRows = wrapRowDim;
  112.         wrapColumns = wrapColDim;
  113.         neighbourhood = neighbourhoodType;

  114.         final int fLen = featuresList[0][0].length;
  115.         network = new Network(0, fLen);
  116.         identifiers = new long[numberOfRows][numberOfColumns];

  117.         // Add neurons.
  118.         for (int i = 0; i < numberOfRows; i++) {
  119.             for (int j = 0; j < numberOfColumns; j++) {
  120.                 identifiers[i][j] = network.createNeuron(featuresList[i][j]);
  121.             }
  122.         }

  123.         // Add links.
  124.         createLinks();
  125.     }

  126.     /**
  127.      * Creates a two-dimensional network composed of square cells:
  128.      * Each neuron not located on the border of the mesh has four
  129.      * neurons linked to it.
  130.      * <br>
  131.      * The links are bi-directional.
  132.      * <br>
  133.      * The topology of the network can also be a cylinder (if one
  134.      * of the dimensions is wrapped) or a torus (if both dimensions
  135.      * are wrapped).
  136.      *
  137.      * @param numRows Number of neurons in the first dimension.
  138.      * @param wrapRowDim Whether to wrap the first dimension (i.e the first
  139.      * and last neurons will be linked together).
  140.      * @param numCols Number of neurons in the second dimension.
  141.      * @param wrapColDim Whether to wrap the second dimension (i.e the first
  142.      * and last neurons will be linked together).
  143.      * @param neighbourhoodType Neighbourhood type.
  144.      * @param featureInit Array of functions that will initialize the
  145.      * corresponding element of the features set of each newly created
  146.      * neuron. In particular, the size of this array defines the size of
  147.      * feature set.
  148.      * @throws IllegalArgumentException if {@code numRows < 2} or
  149.      * {@code numCols < 2}.
  150.      */
  151.     public NeuronSquareMesh2D(int numRows,
  152.                               boolean wrapRowDim,
  153.                               int numCols,
  154.                               boolean wrapColDim,
  155.                               SquareNeighbourhood neighbourhoodType,
  156.                               FeatureInitializer[] featureInit) {
  157.         if (numRows < MIN_ROWS) {
  158.             throw new NeuralNetException(NeuralNetException.TOO_SMALL, numRows, MIN_ROWS);
  159.         }
  160.         if (numCols < MIN_ROWS) {
  161.             throw new NeuralNetException(NeuralNetException.TOO_SMALL, numCols, MIN_ROWS);
  162.         }

  163.         numberOfRows = numRows;
  164.         wrapRows = wrapRowDim;
  165.         numberOfColumns = numCols;
  166.         wrapColumns = wrapColDim;
  167.         neighbourhood = neighbourhoodType;
  168.         identifiers = new long[numberOfRows][numberOfColumns];

  169.         final int fLen = featureInit.length;
  170.         network = new Network(0, fLen);

  171.         // Add neurons.
  172.         for (int i = 0; i < numRows; i++) {
  173.             for (int j = 0; j < numCols; j++) {
  174.                 final double[] features = new double[fLen];
  175.                 for (int fIndex = 0; fIndex < fLen; fIndex++) {
  176.                     features[fIndex] = featureInit[fIndex].value();
  177.                 }
  178.                 identifiers[i][j] = network.createNeuron(features);
  179.             }
  180.         }

  181.         // Add links.
  182.         createLinks();
  183.     }

  184.     /**
  185.      * Constructor with restricted access, solely used for making a
  186.      * {@link #copy() deep copy}.
  187.      *
  188.      * @param wrapRowDim Whether to wrap the first dimension (i.e the first
  189.      * and last neurons will be linked together).
  190.      * @param wrapColDim Whether to wrap the second dimension (i.e the first
  191.      * and last neurons will be linked together).
  192.      * @param neighbourhoodType Neighbourhood type.
  193.      * @param net Underlying network.
  194.      * @param idGrid Neuron identifiers.
  195.      */
  196.     private NeuronSquareMesh2D(boolean wrapRowDim,
  197.                                boolean wrapColDim,
  198.                                SquareNeighbourhood neighbourhoodType,
  199.                                Network net,
  200.                                long[][] idGrid) {
  201.         numberOfRows = idGrid.length;
  202.         numberOfColumns = idGrid[0].length;
  203.         wrapRows = wrapRowDim;
  204.         wrapColumns = wrapColDim;
  205.         neighbourhood = neighbourhoodType;
  206.         network = net;
  207.         identifiers = idGrid;
  208.     }

  209.     /**
  210.      * Performs a deep copy of this instance.
  211.      * Upon return, the copied and original instances will be independent:
  212.      * Updating one will not affect the other.
  213.      *
  214.      * @return a new instance with the same state as this instance.
  215.      * @since 3.6
  216.      */
  217.     public synchronized NeuronSquareMesh2D copy() {
  218.         final long[][] idGrid = new long[numberOfRows][numberOfColumns];
  219.         for (int r = 0; r < numberOfRows; r++) {
  220.             System.arraycopy(identifiers[r], 0, idGrid[r], 0, numberOfColumns);
  221.         }

  222.         return new NeuronSquareMesh2D(wrapRows,
  223.                                       wrapColumns,
  224.                                       neighbourhood,
  225.                                       network.copy(),
  226.                                       idGrid);
  227.     }

  228.     /** {@inheritDoc} */
  229.     @Override
  230.     public Iterator<Neuron> iterator() {
  231.         return network.iterator();
  232.     }

  233.     /**
  234.      * Retrieves the underlying network.
  235.      * A reference is returned (enabling, for example, the network to be
  236.      * trained).
  237.      * This also implies that calling methods that modify the {@link Network}
  238.      * topology may cause this class to become inconsistent.
  239.      *
  240.      * @return the network.
  241.      */
  242.     public Network getNetwork() {
  243.         return network;
  244.     }

  245.     /**
  246.      * Gets the number of neurons in each row of this map.
  247.      *
  248.      * @return the number of rows.
  249.      */
  250.     public int getNumberOfRows() {
  251.         return numberOfRows;
  252.     }

  253.     /**
  254.      * Gets the number of neurons in each column of this map.
  255.      *
  256.      * @return the number of column.
  257.      */
  258.     public int getNumberOfColumns() {
  259.         return numberOfColumns;
  260.     }

  261.     /**
  262.      * Indicates whether the map is wrapped along the first dimension.
  263.      *
  264.      * @return {@code true} if the last neuron of a row is linked to
  265.      * the first neuron of that row.
  266.      */
  267.     public boolean isWrappedRow() {
  268.         return wrapRows;
  269.     }

  270.     /**
  271.      * Indicates whether the map is wrapped along the second dimension.
  272.      *
  273.      * @return {@code true} if the last neuron of a column is linked to
  274.      * the first neuron of that column.
  275.      */
  276.     public boolean isWrappedColumn() {
  277.         return wrapColumns;
  278.     }

  279.     /**
  280.      * Indicates the {@link SquareNeighbourhood type of connectivity}
  281.      * between neurons.
  282.      *
  283.      * @return the neighbourhood type.
  284.      */
  285.     public SquareNeighbourhood getSquareNeighbourhood() {
  286.         return neighbourhood;
  287.     }

  288.     /**
  289.      * Retrieves the neuron at location {@code (i, j)} in the map.
  290.      * The neuron at position {@code (0, 0)} is located at the upper-left
  291.      * corner of the map.
  292.      *
  293.      * @param i Row index.
  294.      * @param j Column index.
  295.      * @return the neuron at {@code (i, j)}.
  296.      * @throws IllegalArgumentException if {@code i} or {@code j} is
  297.      * out of range.
  298.      *
  299.      * @see #getNeuron(int,int,HorizontalDirection,VerticalDirection)
  300.      */
  301.     public Neuron getNeuron(int i,
  302.                             int j) {
  303.         if (i < 0 ||
  304.             i >= numberOfRows) {
  305.             throw new NeuralNetException(NeuralNetException.OUT_OF_RANGE,
  306.                                          i, 0, numberOfRows - 1);
  307.         }
  308.         if (j < 0 ||
  309.             j >= numberOfColumns) {
  310.             throw new NeuralNetException(NeuralNetException.OUT_OF_RANGE,
  311.                                          i, 0, numberOfColumns - 1);
  312.         }

  313.         return network.getNeuron(identifiers[i][j]);
  314.     }

  315.     /**
  316.      * Retrieves the requested neuron relative to the given {@code (row, col)}
  317.      * position.
  318.      * The neuron at position {@code (0, 0)} is located at the upper-left
  319.      * corner of the map.
  320.      *
  321.      * @param row Row index.
  322.      * @param col Column index.
  323.      * @param alongRowDir Direction along the given {@code row} (i.e. an
  324.      * offset will be added to the given <em>column</em> index.
  325.      * @param alongColDir Direction along the given {@code col} (i.e. an
  326.      * offset will be added to the given <em>row</em> index.
  327.      * @return the neuron at the requested location, or {@code null} if
  328.      * the location is not on the map.
  329.      *
  330.      * @see #getNeuron(int,int)
  331.      */
  332.     public Neuron getNeuron(int row,
  333.                             int col,
  334.                             HorizontalDirection alongRowDir,
  335.                             VerticalDirection alongColDir) {
  336.         final int[] location = getLocation(row, col, alongRowDir, alongColDir);

  337.         return location == null ? null : getNeuron(location[0], location[1]);
  338.     }

  339.     /**
  340.      * Computes various {@link DataVisualization indicators} of the quality
  341.      * of the representation of the given {@code data} by this map.
  342.      *
  343.      * @param data Features.
  344.      * @return a new instance holding quality indicators.
  345.      */
  346.     public DataVisualization computeQualityIndicators(Iterable<double[]> data) {
  347.         return DataVisualization.from(copy(), data);
  348.     }

  349.     /**
  350.      * Computes the location of a neighbouring neuron.
  351.      * Returns {@code null} if the resulting location is not part
  352.      * of the map.
  353.      * Position {@code (0, 0)} is at the upper-left corner of the map.
  354.      *
  355.      * @param row Row index.
  356.      * @param col Column index.
  357.      * @param alongRowDir Direction along the given {@code row} (i.e. an
  358.      * offset will be added to the given <em>column</em> index.
  359.      * @param alongColDir Direction along the given {@code col} (i.e. an
  360.      * offset will be added to the given <em>row</em> index.
  361.      * @return an array of length 2 containing the indices of the requested
  362.      * location, or {@code null} if that location is not part of the map.
  363.      *
  364.      * @see #getNeuron(int,int)
  365.      */
  366.     private int[] getLocation(int row,
  367.                               int col,
  368.                               HorizontalDirection alongRowDir,
  369.                               VerticalDirection alongColDir) {
  370.         final int colOffset;
  371.         switch (alongRowDir) {
  372.         case LEFT:
  373.             colOffset = -1;
  374.             break;
  375.         case RIGHT:
  376.             colOffset = 1;
  377.             break;
  378.         case CENTER:
  379.             colOffset = 0;
  380.             break;
  381.         default:
  382.             // Should never happen.
  383.             throw new IllegalStateException();
  384.         }
  385.         int colIndex = col + colOffset;
  386.         if (wrapColumns) {
  387.             if (colIndex < 0) {
  388.                 colIndex += numberOfColumns;
  389.             } else {
  390.                 colIndex %= numberOfColumns;
  391.             }
  392.         }

  393.         final int rowOffset;
  394.         switch (alongColDir) {
  395.         case UP:
  396.             rowOffset = -1;
  397.             break;
  398.         case DOWN:
  399.             rowOffset = 1;
  400.             break;
  401.         case CENTER:
  402.             rowOffset = 0;
  403.             break;
  404.         default:
  405.             // Should never happen.
  406.             throw new IllegalStateException();
  407.         }
  408.         int rowIndex = row + rowOffset;
  409.         if (wrapRows) {
  410.             if (rowIndex < 0) {
  411.                 rowIndex += numberOfRows;
  412.             } else {
  413.                 rowIndex %= numberOfRows;
  414.             }
  415.         }

  416.         if (rowIndex < 0 ||
  417.             rowIndex >= numberOfRows ||
  418.             colIndex < 0 ||
  419.             colIndex >= numberOfColumns) {
  420.             return null;
  421.         } else {
  422.             return new int[] {rowIndex, colIndex};
  423.         }
  424.     }

  425.     /**
  426.      * Creates the neighbour relationships between neurons.
  427.      */
  428.     private void createLinks() {
  429.         // "linkEnd" will store the identifiers of the "neighbours".
  430.         final List<Long> linkEnd = new ArrayList<>();
  431.         final int iLast = numberOfRows - 1;
  432.         final int jLast = numberOfColumns - 1;
  433.         for (int i = 0; i < numberOfRows; i++) {
  434.             for (int j = 0; j < numberOfColumns; j++) {
  435.                 linkEnd.clear();

  436.                 switch (neighbourhood) {

  437.                 case MOORE:
  438.                     // Add links to "diagonal" neighbours.
  439.                     if (i > 0) {
  440.                         if (j > 0) {
  441.                             linkEnd.add(identifiers[i - 1][j - 1]);
  442.                         }
  443.                         if (j < jLast) {
  444.                             linkEnd.add(identifiers[i - 1][j + 1]);
  445.                         }
  446.                     }
  447.                     if (i < iLast) {
  448.                         if (j > 0) {
  449.                             linkEnd.add(identifiers[i + 1][j - 1]);
  450.                         }
  451.                         if (j < jLast) {
  452.                             linkEnd.add(identifiers[i + 1][j + 1]);
  453.                         }
  454.                     }
  455.                     if (wrapRows) {
  456.                         if (i == 0) {
  457.                             if (j > 0) {
  458.                                 linkEnd.add(identifiers[iLast][j - 1]);
  459.                             }
  460.                             if (j < jLast) {
  461.                                 linkEnd.add(identifiers[iLast][j + 1]);
  462.                             }
  463.                         } else if (i == iLast) {
  464.                             if (j > 0) {
  465.                                 linkEnd.add(identifiers[0][j - 1]);
  466.                             }
  467.                             if (j < jLast) {
  468.                                 linkEnd.add(identifiers[0][j + 1]);
  469.                             }
  470.                         }
  471.                     }
  472.                     if (wrapColumns) {
  473.                         if (j == 0) {
  474.                             if (i > 0) {
  475.                                 linkEnd.add(identifiers[i - 1][jLast]);
  476.                             }
  477.                             if (i < iLast) {
  478.                                 linkEnd.add(identifiers[i + 1][jLast]);
  479.                             }
  480.                         } else if (j == jLast) {
  481.                             if (i > 0) {
  482.                                 linkEnd.add(identifiers[i - 1][0]);
  483.                             }
  484.                             if (i < iLast) {
  485.                                 linkEnd.add(identifiers[i + 1][0]);
  486.                             }
  487.                         }
  488.                     }
  489.                     if (wrapRows &&
  490.                         wrapColumns) {
  491.                         if (i == 0 &&
  492.                             j == 0) {
  493.                             linkEnd.add(identifiers[iLast][jLast]);
  494.                         } else if (i == 0 &&
  495.                                    j == jLast) {
  496.                             linkEnd.add(identifiers[iLast][0]);
  497.                         } else if (i == iLast &&
  498.                                    j == 0) {
  499.                             linkEnd.add(identifiers[0][jLast]);
  500.                         } else if (i == iLast &&
  501.                                    j == jLast) {
  502.                             linkEnd.add(identifiers[0][0]);
  503.                         }
  504.                     }

  505.                     // Case falls through since the "Moore" neighbourhood
  506.                     // also contains the neurons that belong to the "Von
  507.                     // Neumann" neighbourhood.

  508.                     // fallthru (CheckStyle)
  509.                 case VON_NEUMANN:
  510.                     // Links to preceding and following "row".
  511.                     if (i > 0) {
  512.                         linkEnd.add(identifiers[i - 1][j]);
  513.                     }
  514.                     if (i < iLast) {
  515.                         linkEnd.add(identifiers[i + 1][j]);
  516.                     }
  517.                     if (wrapRows) {
  518.                         if (i == 0) {
  519.                             linkEnd.add(identifiers[iLast][j]);
  520.                         } else if (i == iLast) {
  521.                             linkEnd.add(identifiers[0][j]);
  522.                         }
  523.                     }

  524.                     // Links to preceding and following "column".
  525.                     if (j > 0) {
  526.                         linkEnd.add(identifiers[i][j - 1]);
  527.                     }
  528.                     if (j < jLast) {
  529.                         linkEnd.add(identifiers[i][j + 1]);
  530.                     }
  531.                     if (wrapColumns) {
  532.                         if (j == 0) {
  533.                             linkEnd.add(identifiers[i][jLast]);
  534.                         } else if (j == jLast) {
  535.                             linkEnd.add(identifiers[i][0]);
  536.                         }
  537.                     }
  538.                     break;

  539.                 default:
  540.                     throw new IllegalStateException(); // Cannot happen.
  541.                 }

  542.                 final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
  543.                 for (final long b : linkEnd) {
  544.                     final Neuron bNeuron = network.getNeuron(b);
  545.                     // Link to all neighbours.
  546.                     // The reverse links will be added as the loop proceeds.
  547.                     network.addLink(aNeuron, bNeuron);
  548.                 }
  549.             }
  550.         }
  551.     }

  552.     /**
  553.      * Miscellaneous indicators of the map quality.
  554.      * <ul>
  555.      *  <li>Hit histogram</li>
  556.      *  <li>Quantization error</li>
  557.      *  <li>Topographic error</li>
  558.      *  <li>Unified distance matrix</li>
  559.      * </ul>
  560.      */
  561.     public static final class DataVisualization {
  562.         /** Distance function. */
  563.         private static final DistanceMeasure DISTANCE = new EuclideanDistance();
  564.         /** Total number of samples. */
  565.         private final int numberOfSamples;
  566.         /** Hit histogram. */
  567.         private final double[][] hitHistogram;
  568.         /** Quantization error. */
  569.         private final double[][] quantizationError;
  570.         /** Mean quantization error. */
  571.         private final double meanQuantizationError;
  572.         /** Topographic error. */
  573.         private final double[][] topographicError;
  574.         /** Mean topographic error. */
  575.         private final double meanTopographicError;
  576.         /** U-matrix. */
  577.         private final double[][] uMatrix;

  578.         /**
  579.          * @param numberOfSamples Number of samples.
  580.          * @param hitHistogram Hit histogram.
  581.          * @param quantizationError Quantization error.
  582.          * @param topographicError Topographic error.
  583.          * @param uMatrix U-matrix.
  584.          */
  585.         private DataVisualization(int numberOfSamples,
  586.                                   double[][] hitHistogram,
  587.                                   double[][] quantizationError,
  588.                                   double[][] topographicError,
  589.                                   double[][] uMatrix) {
  590.             this.numberOfSamples = numberOfSamples;
  591.             this.hitHistogram = hitHistogram;
  592.             this.quantizationError = quantizationError;
  593.             meanQuantizationError = hitWeightedMean(quantizationError, hitHistogram);
  594.             this.topographicError = topographicError;
  595.             meanTopographicError = hitWeightedMean(topographicError, hitHistogram);
  596.             this.uMatrix = uMatrix;
  597.         }

  598.         /**
  599.          * @param map Map
  600.          * @param data Data.
  601.          * @return the metrics.
  602.          */
  603.         static DataVisualization from(NeuronSquareMesh2D map,
  604.                                       Iterable<double[]> data) {
  605.             final LocationFinder finder = new LocationFinder(map);
  606.             final MapRanking rank = new MapRanking(map, DISTANCE);
  607.             final Network net = map.getNetwork();
  608.             final int nR = map.getNumberOfRows();
  609.             final int nC = map.getNumberOfColumns();

  610.             // Hit bins.
  611.             final int[][] hitCounter = new int[nR][nC];
  612.             // Hit bins.
  613.             final double[][] hitHistogram = new double[nR][nC];
  614.             // Quantization error bins.
  615.             final double[][] quantizationError = new double[nR][nC];
  616.             // Topographic error bins.
  617.             final double[][] topographicError = new double[nR][nC];
  618.             // U-matrix.
  619.             final double[][] uMatrix = new double[nR][nC];

  620.             int numSamples = 0;
  621.             for (final double[] sample : data) {
  622.                 ++numSamples;

  623.                 final List<Neuron> winners = rank.rank(sample, 2);
  624.                 final Neuron best = winners.get(0);
  625.                 final Neuron secondBest = winners.get(1);

  626.                 final LocationFinder.Location locBest = finder.getLocation(best);
  627.                 final int rowBest = locBest.getRow();
  628.                 final int colBest = locBest.getColumn();
  629.                 // Increment hit counter.
  630.                 hitCounter[rowBest][colBest] += 1;

  631.                 // Aggregate quantization error.
  632.                 quantizationError[rowBest][colBest] += DISTANCE.applyAsDouble(sample, best.getFeatures());

  633.                 // Aggregate topographic error.
  634.                 if (!net.getNeighbours(best).contains(secondBest)) {
  635.                     // Increment count if first and second best matching units
  636.                     // are not neighbours.
  637.                     topographicError[rowBest][colBest] += 1;
  638.                 }
  639.             }

  640.             for (int r = 0; r < nR; r++) {
  641.                 for (int c = 0; c < nC; c++) {
  642.                     final Neuron neuron = map.getNeuron(r, c);
  643.                     final Collection<Neuron> neighbours = net.getNeighbours(neuron);
  644.                     final double[] features = neuron.getFeatures();
  645.                     double uDistance = 0;
  646.                     int neighbourCount = 0;
  647.                     for (final Neuron n : neighbours) {
  648.                         ++neighbourCount;
  649.                         uDistance += DISTANCE.applyAsDouble(features, n.getFeatures());
  650.                     }

  651.                     final int hitCount = hitCounter[r][c];
  652.                     if (hitCount != 0) {
  653.                         hitHistogram[r][c] = hitCount / (double) numSamples;
  654.                         quantizationError[r][c] /= hitCount;
  655.                         topographicError[r][c] /= hitCount;
  656.                     }

  657.                     uMatrix[r][c] = uDistance / neighbourCount;
  658.                 }
  659.             }

  660.             return new DataVisualization(numSamples,
  661.                                          hitHistogram,
  662.                                          quantizationError,
  663.                                          topographicError,
  664.                                          uMatrix);
  665.         }

  666.         /**
  667.          * @return the total number of samples.
  668.          */
  669.         public int getNumberOfSamples() {
  670.             return numberOfSamples;
  671.         }

  672.         /**
  673.          * @return the quantization error.
  674.          * Each bin will contain the average of the distances between samples
  675.          * mapped to the corresponding unit and the weight vector of that unit.
  676.          * @see #getMeanQuantizationError()
  677.          */
  678.         public double[][] getQuantizationError() {
  679.             return copy(quantizationError);
  680.         }

  681.         /**
  682.          * @return the topographic error.
  683.          * Each bin will contain the number of data for which the first and
  684.          * second best matching units are not adjacent in the map.
  685.          * @see #getMeanTopographicError()
  686.          */
  687.         public double[][] getTopographicError() {
  688.             return copy(topographicError);
  689.         }

  690.         /**
  691.          * @return the hits histogram (normalized).
  692.          * Each bin will contain the number of data for which the corresponding
  693.          * neuron is the best matching unit.
  694.          */
  695.         public double[][] getNormalizedHits() {
  696.             return copy(hitHistogram);
  697.         }

  698.         /**
  699.          * @return the U-matrix.
  700.          * Each bin will contain the average distance between a unit and all its
  701.          * neighbours will be computed (and stored in the pixel corresponding to
  702.          * that unit of the 2D-map).  The number of neighbours taken into account
  703.          * depends on the network {@link org.apache.commons.math4.neuralnet.SquareNeighbourhood
  704.          * neighbourhood type}.
  705.          */
  706.         public double[][] getUMatrix() {
  707.             return copy(uMatrix);
  708.         }

  709.         /**
  710.          * @return the mean (hit-weighted) quantization error.
  711.          * @see #getQuantizationError()
  712.          */
  713.         public double getMeanQuantizationError() {
  714.             return meanQuantizationError;
  715.         }

  716.         /**
  717.          * @return the mean (hit-weighted) topographic error.
  718.          * @see #getTopographicError()
  719.          */
  720.         public double getMeanTopographicError() {
  721.             return meanTopographicError;
  722.         }

  723.         /**
  724.          * @param orig Source.
  725.          * @return a deep copy of the original array.
  726.          */
  727.         private static double[][] copy(double[][] orig) {
  728.             final double[][] copy = new double[orig.length][];
  729.             for (int i = 0; i < orig.length; i++) {
  730.                 copy[i] = orig[i].clone();
  731.             }

  732.             return copy;
  733.         }

  734.         /**
  735.          * @param metrics Metrics.
  736.          * @param normalizedHits Hits histogram (normalized).
  737.          * @return the hit-weighted mean of the given {@code metrics}.
  738.          */
  739.         private static double hitWeightedMean(double[][] metrics,
  740.                                               double[][] normalizedHits) {
  741.             double mean = 0;
  742.             final int rows = metrics.length;
  743.             final int cols = metrics[0].length;
  744.             for (int i = 0; i < rows; i++) {
  745.                 for (int j = 0; j < cols; j++) {
  746.                     mean += normalizedHits[i][j] * metrics[i][j];
  747.                 }
  748.             }

  749.             return mean;
  750.         }
  751.     }
  752. }