UnifiedDistanceMatrix.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet.twod.util;

  18. import org.apache.commons.math4.neuralnet.DistanceMeasure;
  19. import org.apache.commons.math4.neuralnet.Neuron;
  20. import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D;

  21. /**
  22.  * <a href="http://en.wikipedia.org/wiki/U-Matrix">U-Matrix</a>
  23.  * visualization of high-dimensional data projection.
  24.  * The 8 individual inter-units distances will be
  25.  * {@link #computeImage(NeuronSquareMesh2D) computed}.  They will be
  26.  * stored in additional pixels around each of the original units of the
  27.  * 2D-map.  The additional pixels that lie along a "diagonal" are shared
  28.  * by <em>two</em> pairs of units: their value will be set to the average
  29.  * distance between the units belonging to each of the pairs.  The value
  30.  * zero will be stored in the pixel corresponding to the location of a
  31.  * unit of the 2D-map.
  32.  *
  33.  * @since 3.6
  34.  * @see org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D.DataVisualization#getUMatrix()
  35.  */
  36. public class UnifiedDistanceMatrix implements MapVisualization {
  37.     /** Distance. */
  38.     private final DistanceMeasure distance;

  39.     /**
  40.      * @param distance Distance.
  41.      */
  42.     public UnifiedDistanceMatrix(DistanceMeasure distance) {
  43.         this.distance = distance;
  44.     }

  45.     /**
  46.      * Computes the distances between a unit of the map and its
  47.      * neighbours.
  48.      * The image will contain more pixels than the number of neurons
  49.      * in the given {@code map} because each neuron has 8 neighbours.
  50.      * The value zero will be stored in the pixels corresponding to
  51.      * the location of a map unit.
  52.      *
  53.      * @param map Map.
  54.      * @return an image representing the individual distances.
  55.      */
  56.     @Override
  57.     public double[][] computeImage(NeuronSquareMesh2D map) {
  58.         final int numRows = map.getNumberOfRows();
  59.         final int numCols = map.getNumberOfColumns();

  60.         final double[][] uMatrix = new double[numRows * 2 + 1][numCols * 2 + 1];

  61.         // 1.
  62.         // Fill right and bottom slots of each unit's location with the
  63.         // distance between the current unit and each of the two neighbours,
  64.         // respectively.
  65.         for (int i = 0; i < numRows; i++) {
  66.             // Current unit's row index in result image.
  67.             final int iR = 2 * i + 1;

  68.             for (int j = 0; j < numCols; j++) {
  69.                 // Current unit's column index in result image.
  70.                 final int jR = 2 * j + 1;

  71.                 final double[] current = map.getNeuron(i, j).getFeatures();
  72.                 Neuron neighbour;

  73.                 // Right neighbour.
  74.                 neighbour = map.getNeuron(i, j,
  75.                                           NeuronSquareMesh2D.HorizontalDirection.RIGHT,
  76.                                           NeuronSquareMesh2D.VerticalDirection.CENTER);
  77.                 if (neighbour != null) {
  78.                     uMatrix[iR][jR + 1] = distance.applyAsDouble(current,
  79.                                                                  neighbour.getFeatures());
  80.                 }

  81.                 // Bottom-center neighbour.
  82.                 neighbour = map.getNeuron(i, j,
  83.                                           NeuronSquareMesh2D.HorizontalDirection.CENTER,
  84.                                           NeuronSquareMesh2D.VerticalDirection.DOWN);
  85.                 if (neighbour != null) {
  86.                     uMatrix[iR + 1][jR] = distance.applyAsDouble(current,
  87.                                                                  neighbour.getFeatures());
  88.                 }
  89.             }
  90.         }

  91.         // 2.
  92.         // Fill the bottom-right slot of each unit's location with the average
  93.         // of the distances between
  94.         //  * the current unit and its bottom-right neighbour, and
  95.         //  * the bottom-center neighbour and the right neighbour.
  96.         for (int i = 0; i < numRows; i++) {
  97.             // Current unit's row index in result image.
  98.             final int iR = 2 * i + 1;

  99.             for (int j = 0; j < numCols; j++) {
  100.                 // Current unit's column index in result image.
  101.                 final int jR = 2 * j + 1;

  102.                 final Neuron current = map.getNeuron(i, j);
  103.                 final Neuron right = map.getNeuron(i, j,
  104.                                                    NeuronSquareMesh2D.HorizontalDirection.RIGHT,
  105.                                                    NeuronSquareMesh2D.VerticalDirection.CENTER);
  106.                 final Neuron bottom = map.getNeuron(i, j,
  107.                                                     NeuronSquareMesh2D.HorizontalDirection.CENTER,
  108.                                                     NeuronSquareMesh2D.VerticalDirection.DOWN);
  109.                 final Neuron bottomRight = map.getNeuron(i, j,
  110.                                                          NeuronSquareMesh2D.HorizontalDirection.RIGHT,
  111.                                                          NeuronSquareMesh2D.VerticalDirection.DOWN);

  112.                 final double current2BottomRight = bottomRight == null ?
  113.                     0 :
  114.                     distance.applyAsDouble(current.getFeatures(),
  115.                                            bottomRight.getFeatures());
  116.                 final double right2Bottom = (right == null ||
  117.                                              bottom == null) ?
  118.                     0 :
  119.                     distance.applyAsDouble(right.getFeatures(),
  120.                                            bottom.getFeatures());

  121.                 // Bottom-right slot.
  122.                 uMatrix[iR + 1][jR + 1] = 0.5 * (current2BottomRight + right2Bottom);
  123.             }
  124.         }

  125.         // 3. Copy last row into first row.
  126.         final int lastRow = uMatrix.length - 1;
  127.         uMatrix[0] = uMatrix[lastRow];

  128.         // 4.
  129.         // Copy last column into first column.
  130.         final int lastCol = uMatrix[0].length - 1;
  131.         for (int r = 0; r < lastRow; r++) {
  132.             uMatrix[r][0] = uMatrix[r][lastCol];
  133.         }

  134.         return uMatrix;
  135.     }
  136. }