001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math4.neuralnet.twod.util; 019 020import java.util.Map; 021import java.util.concurrent.ConcurrentHashMap; 022import org.apache.commons.math4.neuralnet.Neuron; 023import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D; 024 025/** 026 * Helper class to find the grid coordinates of a neuron. 027 * @since 3.6 028 */ 029public class LocationFinder { 030 /** Identifier to location mapping. */ 031 private final Map<Long, Location> locations = new ConcurrentHashMap<>(); 032 033 /** 034 * Container holding a (row, column) pair. 035 */ 036 public static class Location { 037 /** Row index. */ 038 private final int row; 039 /** Column index. */ 040 private final int column; 041 042 /** 043 * @param row Row index. 044 * @param column Column index. 045 */ 046 public Location(int row, 047 int column) { 048 this.row = row; 049 this.column = column; 050 } 051 052 /** 053 * @return the row index. 054 */ 055 public int getRow() { 056 return row; 057 } 058 059 /** 060 * @return the column index. 061 */ 062 public int getColumn() { 063 return column; 064 } 065 } 066 067 /** 068 * Builds a finder to retrieve the locations of neurons that 069 * belong to the given {@code map}. 070 * 071 * @param map Map. 072 * 073 * @throws IllegalStateException if the network contains non-unique 074 * identifiers. This indicates an inconsistent state due to a bug in 075 * the construction code of the underlying 076 * {@link org.apache.commons.math4.neuralnet.Network network}. 077 */ 078 public LocationFinder(NeuronSquareMesh2D map) { 079 final int nR = map.getNumberOfRows(); 080 final int nC = map.getNumberOfColumns(); 081 082 for (int r = 0; r < nR; r++) { 083 for (int c = 0; c < nC; c++) { 084 final Long id = map.getNeuron(r, c).getIdentifier(); 085 if (locations.get(id) != null) { 086 throw new IllegalStateException(); 087 } 088 locations.put(id, new Location(r, c)); 089 } 090 } 091 } 092 093 /** 094 * Retrieves a neuron's grid coordinates. 095 * 096 * @param n Neuron. 097 * @return the (row, column) coordinates of {@code n}, or {@code null} 098 * if no such neuron belongs to the {@link #LocationFinder(NeuronSquareMesh2D) 099 * map used to build this instance}. 100 */ 101 public Location getLocation(Neuron n) { 102 return locations.get(n.getIdentifier()); 103 } 104}