001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.neuralnet.twod.util;
019
020import java.util.Map;
021import java.util.concurrent.ConcurrentHashMap;
022import org.apache.commons.math4.neuralnet.Neuron;
023import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D;
024
025/**
026 * Helper class to find the grid coordinates of a neuron.
027 * @since 3.6
028 */
029public class LocationFinder {
030    /** Identifier to location mapping. */
031    private final Map<Long, Location> locations = new ConcurrentHashMap<>();
032
033    /**
034     * Container holding a (row, column) pair.
035     */
036    public static class Location {
037        /** Row index. */
038        private final int row;
039        /** Column index. */
040        private final int column;
041
042        /**
043         * @param row Row index.
044         * @param column Column index.
045         */
046        public Location(int row,
047                        int column) {
048            this.row = row;
049            this.column = column;
050        }
051
052        /**
053         * @return the row index.
054         */
055        public int getRow() {
056            return row;
057        }
058
059        /**
060         * @return the column index.
061         */
062        public int getColumn() {
063            return column;
064        }
065    }
066
067    /**
068     * Builds a finder to retrieve the locations of neurons that
069     * belong to the given {@code map}.
070     *
071     * @param map Map.
072     *
073     * @throws IllegalStateException if the network contains non-unique
074     * identifiers.  This indicates an inconsistent state due to a bug in
075     * the construction code of the underlying
076     * {@link org.apache.commons.math4.neuralnet.Network network}.
077     */
078    public LocationFinder(NeuronSquareMesh2D map) {
079        final int nR = map.getNumberOfRows();
080        final int nC = map.getNumberOfColumns();
081
082        for (int r = 0; r < nR; r++) {
083            for (int c = 0; c < nC; c++) {
084                final Long id = map.getNeuron(r, c).getIdentifier();
085                if (locations.get(id) != null) {
086                    throw new IllegalStateException();
087                }
088                locations.put(id, new Location(r, c));
089            }
090        }
091    }
092
093    /**
094     * Retrieves a neuron's grid coordinates.
095     *
096     * @param n Neuron.
097     * @return the (row, column) coordinates of {@code n}, or {@code null}
098     * if no such neuron belongs to the {@link #LocationFinder(NeuronSquareMesh2D)
099     * map used to build this instance}.
100     */
101    public Location getLocation(Neuron n) {
102        return locations.get(n.getIdentifier());
103    }
104}