001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.neuralnet.twod;
019
020import java.util.List;
021import java.util.ArrayList;
022import java.io.Serializable;
023import java.io.ObjectInputStream;
024import org.apache.commons.math3.ml.neuralnet.Neuron;
025import org.apache.commons.math3.ml.neuralnet.Network;
026import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
027import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
028import org.apache.commons.math3.exception.NumberIsTooSmallException;
029import org.apache.commons.math3.exception.OutOfRangeException;
030import org.apache.commons.math3.exception.MathInternalError;
031
032/**
033 * Neural network with the topology of a two-dimensional surface.
034 * Each neuron defines one surface element.
035 * <br/>
036 * This network is primarily intended to represent a
037 * <a href="http://en.wikipedia.org/wiki/Kohonen">
038 *  Self Organizing Feature Map</a>.
039 *
040 * @see org.apache.commons.math3.ml.neuralnet.sofm
041 * @since 3.3
042 */
043public class NeuronSquareMesh2D implements Serializable {
044    /** Serial version ID */
045    private static final long serialVersionUID = 1L;
046    /** Underlying network. */
047    private final Network network;
048    /** Number of rows. */
049    private final int numberOfRows;
050    /** Number of columns. */
051    private final int numberOfColumns;
052    /** Wrap. */
053    private final boolean wrapRows;
054    /** Wrap. */
055    private final boolean wrapColumns;
056    /** Neighbourhood type. */
057    private final SquareNeighbourhood neighbourhood;
058    /**
059     * Mapping of the 2D coordinates (in the rectangular mesh) to
060     * the neuron identifiers (attributed by the {@link #network}
061     * instance).
062     */
063    private final long[][] identifiers;
064
065    /**
066     * Constructor with restricted access, solely used for deserialization.
067     *
068     * @param wrapRowDim Whether to wrap the first dimension (i.e the first
069     * and last neurons will be linked together).
070     * @param wrapColDim Whether to wrap the second dimension (i.e the first
071     * and last neurons will be linked together).
072     * @param neighbourhoodType Neighbourhood type.
073     * @param featuresList Arrays that will initialize the features sets of
074     * the network's neurons.
075     * @throws NumberIsTooSmallException if {@code numRows < 2} or
076     * {@code numCols < 2}.
077     */
078    NeuronSquareMesh2D(boolean wrapRowDim,
079                       boolean wrapColDim,
080                       SquareNeighbourhood neighbourhoodType,
081                       double[][][] featuresList) {
082        numberOfRows = featuresList.length;
083        numberOfColumns = featuresList[0].length;
084
085        if (numberOfRows < 2) {
086            throw new NumberIsTooSmallException(numberOfRows, 2, true);
087        }
088        if (numberOfColumns < 2) {
089            throw new NumberIsTooSmallException(numberOfColumns, 2, true);
090        }
091
092        wrapRows = wrapRowDim;
093        wrapColumns = wrapColDim;
094        neighbourhood = neighbourhoodType;
095
096        final int fLen = featuresList[0][0].length;
097        network = new Network(0, fLen);
098        identifiers = new long[numberOfRows][numberOfColumns];
099
100        // Add neurons.
101        for (int i = 0; i < numberOfRows; i++) {
102            for (int j = 0; j < numberOfColumns; j++) {
103                identifiers[i][j] = network.createNeuron(featuresList[i][j]);
104            }
105        }
106
107        // Add links.
108        createLinks();
109    }
110
111    /**
112     * Creates a two-dimensional network composed of square cells:
113     * Each neuron not located on the border of the mesh has four
114     * neurons linked to it.
115     * <br/>
116     * The links are bi-directional.
117     * <br/>
118     * The topology of the network can also be a cylinder (if one
119     * of the dimensions is wrapped) or a torus (if both dimensions
120     * are wrapped).
121     *
122     * @param numRows Number of neurons in the first dimension.
123     * @param wrapRowDim Whether to wrap the first dimension (i.e the first
124     * and last neurons will be linked together).
125     * @param numCols Number of neurons in the second dimension.
126     * @param wrapColDim Whether to wrap the second dimension (i.e the first
127     * and last neurons will be linked together).
128     * @param neighbourhoodType Neighbourhood type.
129     * @param featureInit Array of functions that will initialize the
130     * corresponding element of the features set of each newly created
131     * neuron. In particular, the size of this array defines the size of
132     * feature set.
133     * @throws NumberIsTooSmallException if {@code numRows < 2} or
134     * {@code numCols < 2}.
135     */
136    public NeuronSquareMesh2D(int numRows,
137                              boolean wrapRowDim,
138                              int numCols,
139                              boolean wrapColDim,
140                              SquareNeighbourhood neighbourhoodType,
141                              FeatureInitializer[] featureInit) {
142        if (numRows < 2) {
143            throw new NumberIsTooSmallException(numRows, 2, true);
144        }
145        if (numCols < 2) {
146            throw new NumberIsTooSmallException(numCols, 2, true);
147        }
148
149        numberOfRows = numRows;
150        wrapRows = wrapRowDim;
151        numberOfColumns = numCols;
152        wrapColumns = wrapColDim;
153        neighbourhood = neighbourhoodType;
154        identifiers = new long[numberOfRows][numberOfColumns];
155
156        final int fLen = featureInit.length;
157        network = new Network(0, fLen);
158
159        // Add neurons.
160        for (int i = 0; i < numRows; i++) {
161            for (int j = 0; j < numCols; j++) {
162                final double[] features = new double[fLen];
163                for (int fIndex = 0; fIndex < fLen; fIndex++) {
164                    features[fIndex] = featureInit[fIndex].value();
165                }
166                identifiers[i][j] = network.createNeuron(features);
167            }
168        }
169
170        // Add links.
171        createLinks();
172    }
173
174    /**
175     * Retrieves the underlying network.
176     * A reference is returned (enabling, for example, the network to be
177     * trained).
178     * This also implies that calling methods that modify the {@link Network}
179     * topology may cause this class to become inconsistent.
180     *
181     * @return the network.
182     */
183    public Network getNetwork() {
184        return network;
185    }
186
187    /**
188     * Gets the number of neurons in each row of this map.
189     *
190     * @return the number of rows.
191     */
192    public int getNumberOfRows() {
193        return numberOfRows;
194    }
195
196    /**
197     * Gets the number of neurons in each column of this map.
198     *
199     * @return the number of column.
200     */
201    public int getNumberOfColumns() {
202        return numberOfColumns;
203    }
204
205    /**
206     * Retrieves the neuron at location {@code (i, j)} in the map.
207     *
208     * @param i Row index.
209     * @param j Column index.
210     * @return the neuron at {@code (i, j)}.
211     * @throws OutOfRangeException if {@code i} or {@code j} is
212     * out of range.
213     */
214    public Neuron getNeuron(int i,
215                            int j) {
216        if (i < 0 ||
217            i >= numberOfRows) {
218            throw new OutOfRangeException(i, 0, numberOfRows - 1);
219        }
220        if (j < 0 ||
221            j >= numberOfColumns) {
222            throw new OutOfRangeException(j, 0, numberOfColumns - 1);
223        }
224
225        return network.getNeuron(identifiers[i][j]);
226    }
227
228    /**
229     * Creates the neighbour relationships between neurons.
230     */
231    private void createLinks() {
232        // "linkEnd" will store the identifiers of the "neighbours".
233        final List<Long> linkEnd = new ArrayList<Long>();
234        final int iLast = numberOfRows - 1;
235        final int jLast = numberOfColumns - 1;
236        for (int i = 0; i < numberOfRows; i++) {
237            for (int j = 0; j < numberOfColumns; j++) {
238                linkEnd.clear();
239
240                switch (neighbourhood) {
241
242                case MOORE:
243                    // Add links to "diagonal" neighbours.
244                    if (i > 0) {
245                        if (j > 0) {
246                            linkEnd.add(identifiers[i - 1][j - 1]);
247                        }
248                        if (j < jLast) {
249                            linkEnd.add(identifiers[i - 1][j + 1]);
250                        }
251                    }
252                    if (i < iLast) {
253                        if (j > 0) {
254                            linkEnd.add(identifiers[i + 1][j - 1]);
255                        }
256                        if (j < jLast) {
257                            linkEnd.add(identifiers[i + 1][j + 1]);
258                        }
259                    }
260                    if (wrapRows) {
261                        if (i == 0) {
262                            if (j > 0) {
263                                linkEnd.add(identifiers[iLast][j - 1]);
264                            }
265                            if (j < jLast) {
266                                linkEnd.add(identifiers[iLast][j + 1]);
267                            }
268                        } else if (i == iLast) {
269                            if (j > 0) {
270                                linkEnd.add(identifiers[0][j - 1]);
271                            }
272                            if (j < jLast) {
273                                linkEnd.add(identifiers[0][j + 1]);
274                            }
275                        }
276                    }
277                    if (wrapColumns) {
278                        if (j == 0) {
279                            if (i > 0) {
280                                linkEnd.add(identifiers[i - 1][jLast]);
281                            }
282                            if (i < iLast) {
283                                linkEnd.add(identifiers[i + 1][jLast]);
284                            }
285                        } else if (j == jLast) {
286                             if (i > 0) {
287                                 linkEnd.add(identifiers[i - 1][0]);
288                             }
289                             if (i < iLast) {
290                                 linkEnd.add(identifiers[i + 1][0]);
291                             }
292                        }
293                    }
294                    if (wrapRows &&
295                        wrapColumns) {
296                        if (i == 0 &&
297                            j == 0) {
298                            linkEnd.add(identifiers[iLast][jLast]);
299                        } else if (i == 0 &&
300                                   j == jLast) {
301                            linkEnd.add(identifiers[iLast][0]);
302                        } else if (i == iLast &&
303                                   j == 0) {
304                            linkEnd.add(identifiers[0][jLast]);
305                        } else if (i == iLast &&
306                                   j == jLast) {
307                            linkEnd.add(identifiers[0][0]);
308                        }
309                    }
310
311                    // Case falls through since the "Moore" neighbourhood
312                    // also contains the neurons that belong to the "Von
313                    // Neumann" neighbourhood.
314
315                    // fallthru (CheckStyle)
316                case VON_NEUMANN:
317                    // Links to preceding and following "row".
318                    if (i > 0) {
319                        linkEnd.add(identifiers[i - 1][j]);
320                    }
321                    if (i < iLast) {
322                        linkEnd.add(identifiers[i + 1][j]);
323                    }
324                    if (wrapRows) {
325                        if (i == 0) {
326                            linkEnd.add(identifiers[iLast][j]);
327                        } else if (i == iLast) {
328                            linkEnd.add(identifiers[0][j]);
329                        }
330                    }
331
332                    // Links to preceding and following "column".
333                    if (j > 0) {
334                        linkEnd.add(identifiers[i][j - 1]);
335                    }
336                    if (j < jLast) {
337                        linkEnd.add(identifiers[i][j + 1]);
338                    }
339                    if (wrapColumns) {
340                        if (j == 0) {
341                            linkEnd.add(identifiers[i][jLast]);
342                        } else if (j == jLast) {
343                            linkEnd.add(identifiers[i][0]);
344                        }
345                    }
346                    break;
347
348                default:
349                    throw new MathInternalError(); // Cannot happen.
350                }
351
352                final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
353                for (long b : linkEnd) {
354                    final Neuron bNeuron = network.getNeuron(b);
355                    // Link to all neighbours.
356                    // The reverse links will be added as the loop proceeds.
357                    network.addLink(aNeuron, bNeuron);
358                }
359            }
360        }
361    }
362
363    /**
364     * Prevents proxy bypass.
365     *
366     * @param in Input stream.
367     */
368    private void readObject(ObjectInputStream in) {
369        throw new IllegalStateException();
370    }
371
372    /**
373     * Custom serialization.
374     *
375     * @return the proxy instance that will be actually serialized.
376     */
377    private Object writeReplace() {
378        final double[][][] featuresList = new double[numberOfRows][numberOfColumns][];
379        for (int i = 0; i < numberOfRows; i++) {
380            for (int j = 0; j < numberOfColumns; j++) {
381                featuresList[i][j] = getNeuron(i, j).getFeatures();
382            }
383        }
384
385        return new SerializationProxy(wrapRows,
386                                      wrapColumns,
387                                      neighbourhood,
388                                      featuresList);
389    }
390
391    /**
392     * Serialization.
393     */
394    private static class SerializationProxy implements Serializable {
395        /** Serializable. */
396        private static final long serialVersionUID = 20130226L;
397        /** Wrap. */
398        private final boolean wrapRows;
399        /** Wrap. */
400        private final boolean wrapColumns;
401        /** Neighbourhood type. */
402        private final SquareNeighbourhood neighbourhood;
403        /** Neurons' features. */
404        private final double[][][] featuresList;
405
406        /**
407         * @param wrapRows Whether the row dimension is wrapped.
408         * @param wrapColumns Whether the column dimension is wrapped.
409         * @param neighbourhood Neighbourhood type.
410         * @param featuresList List of neurons features.
411         * {@code neuronList}.
412         */
413        SerializationProxy(boolean wrapRows,
414                           boolean wrapColumns,
415                           SquareNeighbourhood neighbourhood,
416                           double[][][] featuresList) {
417            this.wrapRows = wrapRows;
418            this.wrapColumns = wrapColumns;
419            this.neighbourhood = neighbourhood;
420            this.featuresList = featuresList;
421        }
422
423        /**
424         * Custom serialization.
425         *
426         * @return the {@link Neuron} for which this instance is the proxy.
427         */
428        private Object readResolve() {
429            return new NeuronSquareMesh2D(wrapRows,
430                                          wrapColumns,
431                                          neighbourhood,
432                                          featuresList);
433        }
434    }
435}