LocationFinder.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet.twod.util;

  18. import java.util.Map;
  19. import java.util.concurrent.ConcurrentHashMap;
  20. import org.apache.commons.math4.neuralnet.Neuron;
  21. import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D;

  22. /**
  23.  * Helper class to find the grid coordinates of a neuron.
  24.  * @since 3.6
  25.  */
  26. public class LocationFinder {
  27.     /** Identifier to location mapping. */
  28.     private final Map<Long, Location> locations = new ConcurrentHashMap<>();

  29.     /**
  30.      * Container holding a (row, column) pair.
  31.      */
  32.     public static class Location {
  33.         /** Row index. */
  34.         private final int row;
  35.         /** Column index. */
  36.         private final int column;

  37.         /**
  38.          * @param row Row index.
  39.          * @param column Column index.
  40.          */
  41.         public Location(int row,
  42.                         int column) {
  43.             this.row = row;
  44.             this.column = column;
  45.         }

  46.         /**
  47.          * @return the row index.
  48.          */
  49.         public int getRow() {
  50.             return row;
  51.         }

  52.         /**
  53.          * @return the column index.
  54.          */
  55.         public int getColumn() {
  56.             return column;
  57.         }
  58.     }

  59.     /**
  60.      * Builds a finder to retrieve the locations of neurons that
  61.      * belong to the given {@code map}.
  62.      *
  63.      * @param map Map.
  64.      *
  65.      * @throws IllegalStateException if the network contains non-unique
  66.      * identifiers.  This indicates an inconsistent state due to a bug in
  67.      * the construction code of the underlying
  68.      * {@link org.apache.commons.math4.neuralnet.Network network}.
  69.      */
  70.     public LocationFinder(NeuronSquareMesh2D map) {
  71.         final int nR = map.getNumberOfRows();
  72.         final int nC = map.getNumberOfColumns();

  73.         for (int r = 0; r < nR; r++) {
  74.             for (int c = 0; c < nC; c++) {
  75.                 final Long id = map.getNeuron(r, c).getIdentifier();
  76.                 if (locations.get(id) != null) {
  77.                     throw new IllegalStateException();
  78.                 }
  79.                 locations.put(id, new Location(r, c));
  80.             }
  81.         }
  82.     }

  83.     /**
  84.      * Retrieves a neuron's grid coordinates.
  85.      *
  86.      * @param n Neuron.
  87.      * @return the (row, column) coordinates of {@code n}, or {@code null}
  88.      * if no such neuron belongs to the {@link #LocationFinder(NeuronSquareMesh2D)
  89.      * map used to build this instance}.
  90.      */
  91.     public Location getLocation(Neuron n) {
  92.         return locations.get(n.getIdentifier());
  93.     }
  94. }