001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.neuralnet.twod.util;
019
020import java.util.Map;
021import java.util.HashMap;
022import org.apache.commons.math3.ml.neuralnet.Neuron;
023import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
024import org.apache.commons.math3.exception.MathIllegalStateException;
025
026/**
027 * Helper class to find the grid coordinates of a neuron.
028 * @since 3.6
029 */
030public class LocationFinder {
031    /** Identifier to location mapping. */
032    private final Map<Long, Location> locations = new HashMap<Long, Location>();
033
034    /**
035     * Container holding a (row, column) pair.
036     */
037    public static class Location {
038        /** Row index. */
039        private final int row;
040        /** Column index. */
041        private final int column;
042
043        /**
044         * @param row Row index.
045         * @param column Column index.
046         */
047        public Location(int row,
048                        int column) {
049            this.row = row;
050            this.column = column;
051        }
052
053        /**
054         * @return the row index.
055         */
056        public int getRow() {
057            return row;
058        }
059
060        /**
061         * @return the column index.
062         */
063        public int getColumn() {
064            return column;
065        }
066    }
067
068    /**
069     * Builds a finder to retrieve the locations of neurons that
070     * belong to the given {@code map}.
071     *
072     * @param map Map.
073     *
074     * @throws MathIllegalStateException if the network contains non-unique
075     * identifiers.  This indicates an inconsistent state due to a bug in
076     * the construction code of the underlying
077     * {@link org.apache.commons.math3.ml.neuralnet.Network network}.
078     */
079    public LocationFinder(NeuronSquareMesh2D map) {
080        final int nR = map.getNumberOfRows();
081        final int nC = map.getNumberOfColumns();
082
083        for (int r = 0; r < nR; r++) {
084            for (int c = 0; c < nC; c++) {
085                final Long id = map.getNeuron(r, c).getIdentifier();
086                if (locations.get(id) != null) {
087                    throw new MathIllegalStateException();
088                }
089                locations.put(id, new Location(r, c));
090            }
091        }
092    }
093
094    /**
095     * Retrieves a neuron's grid coordinates.
096     *
097     * @param n Neuron.
098     * @return the (row, column) coordinates of {@code n}, or {@code null}
099     * if no such neuron belongs to the {@link #LocationFinder(NeuronSquareMesh2D)
100     * map used to build this instance}.
101     */
102    public Location getLocation(Neuron n) {
103        return locations.get(n.getIdentifier());
104    }
105}