/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.ml.neuralnet.twod.util;

import java.util.Map;
import java.util.HashMap;
import org.apache.commons.math3.ml.neuralnet.Neuron;
import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
import org.apache.commons.math3.exception.MathIllegalStateException;

/**
 * Lookup table that associates the identifier of each neuron of a
 * two-dimensional mesh with the (row, column) cell it occupies.
 * The table is filled once, at construction, and only read afterwards.
 *
 * @since 3.6
 */
public class LocationFinder {
    /** Grid position of every neuron, keyed by neuron identifier. */
    private final Map<Long, Location> locations = new HashMap<Long, Location>();

    /**
     * Scans every cell of the given {@code map} and records, for each
     * neuron identifier, the cell in which it was found.
     *
     * @param map Mesh whose neurons must be indexed.
     *
     * @throws MathIllegalStateException if the same identifier is
     * encountered in more than one cell. Identifiers are expected to be
     * unique, so this signals an inconsistent state of the underlying
     * {@link org.apache.commons.math3.ml.neuralnet.Network network}.
     */
    public LocationFinder(NeuronSquareMesh2D map) {
        final int rowCount = map.getNumberOfRows();
        final int columnCount = map.getNumberOfColumns();

        for (int i = 0; i < rowCount; i++) {
            for (int j = 0; j < columnCount; j++) {
                final Long key = map.getNeuron(i, j).getIdentifier();

                // Only non-null values are ever stored, hence testing for
                // the key is enough to detect a duplicate identifier.
                if (locations.containsKey(key)) {
                    throw new MathIllegalStateException();
                }

                locations.put(key, new Location(i, j));
            }
        }
    }

    /**
     * Looks up the grid coordinates of a neuron, using its identifier.
     *
     * @param n Neuron.
     * @return the (row, column) coordinates recorded for the identifier
     * of {@code n}, or {@code null} if that identifier was not found in the
     * {@link #LocationFinder(NeuronSquareMesh2D) map used to build this
     * instance}.
     */
    public Location getLocation(Neuron n) {
        final Long key = n.getIdentifier();
        return locations.get(key);
    }

    /**
     * Immutable (row, column) pair designating a cell of the grid.
     */
    public static class Location {
        /** Index of the row. */
        private final int row;
        /** Index of the column. */
        private final int column;

        /**
         * Creates a pair from the given indices; the values are stored
         * as they are, without any validation.
         *
         * @param row Row index.
         * @param column Column index.
         */
        public Location(int row,
                        int column) {
            this.row = row;
            this.column = column;
        }

        /**
         * @return the row index.
         */
        public int getRow() {
            return row;
        }

        /**
         * @return the column index.
         */
        public int getColumn() {
            return column;
        }
    }
}