001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.ml.neuralnet.twod;
019
020import java.util.List;
021import java.util.ArrayList;
022import java.util.Iterator;
023import java.io.Serializable;
024import java.io.ObjectInputStream;
025import org.apache.commons.math3.ml.neuralnet.Neuron;
026import org.apache.commons.math3.ml.neuralnet.Network;
027import org.apache.commons.math3.ml.neuralnet.FeatureInitializer;
028import org.apache.commons.math3.ml.neuralnet.SquareNeighbourhood;
029import org.apache.commons.math3.exception.NumberIsTooSmallException;
030import org.apache.commons.math3.exception.OutOfRangeException;
031import org.apache.commons.math3.exception.MathInternalError;
032
033/**
034 * Neural network with the topology of a two-dimensional surface.
035 * Each neuron defines one surface element.
036 * <br/>
037 * This network is primarily intended to represent a
038 * <a href="http://en.wikipedia.org/wiki/Kohonen">
039 *  Self Organizing Feature Map</a>.
040 *
041 * @see org.apache.commons.math3.ml.neuralnet.sofm
042 * @since 3.3
043 */
044public class NeuronSquareMesh2D
045    implements Iterable<Neuron>,
046               Serializable {
047    /** Serial version ID */
048    private static final long serialVersionUID = 1L;
049    /** Underlying network. */
050    private final Network network;
051    /** Number of rows. */
052    private final int numberOfRows;
053    /** Number of columns. */
054    private final int numberOfColumns;
055    /** Wrap. */
056    private final boolean wrapRows;
057    /** Wrap. */
058    private final boolean wrapColumns;
059    /** Neighbourhood type. */
060    private final SquareNeighbourhood neighbourhood;
061    /**
062     * Mapping of the 2D coordinates (in the rectangular mesh) to
063     * the neuron identifiers (attributed by the {@link #network}
064     * instance).
065     */
066    private final long[][] identifiers;
067
068    /**
069     * Horizontal (along row) direction.
070     * @since 3.6
071     */
072    public enum HorizontalDirection {
073        /** Column at the right of the current column. */
074       RIGHT,
075       /** Current column. */
076       CENTER,
077       /** Column at the left of the current column. */
078       LEFT,
079    }
080    /**
081     * Vertical (along column) direction.
082     * @since 3.6
083     */
084    public enum VerticalDirection {
085        /** Row above the current row. */
086        UP,
087        /** Current row. */
088        CENTER,
089        /** Row below the current row. */
090        DOWN,
091    }
092
093    /**
094     * Constructor with restricted access, solely used for deserialization.
095     *
096     * @param wrapRowDim Whether to wrap the first dimension (i.e the first
097     * and last neurons will be linked together).
098     * @param wrapColDim Whether to wrap the second dimension (i.e the first
099     * and last neurons will be linked together).
100     * @param neighbourhoodType Neighbourhood type.
101     * @param featuresList Arrays that will initialize the features sets of
102     * the network's neurons.
103     * @throws NumberIsTooSmallException if {@code numRows < 2} or
104     * {@code numCols < 2}.
105     */
106    NeuronSquareMesh2D(boolean wrapRowDim,
107                       boolean wrapColDim,
108                       SquareNeighbourhood neighbourhoodType,
109                       double[][][] featuresList) {
110        numberOfRows = featuresList.length;
111        numberOfColumns = featuresList[0].length;
112
113        if (numberOfRows < 2) {
114            throw new NumberIsTooSmallException(numberOfRows, 2, true);
115        }
116        if (numberOfColumns < 2) {
117            throw new NumberIsTooSmallException(numberOfColumns, 2, true);
118        }
119
120        wrapRows = wrapRowDim;
121        wrapColumns = wrapColDim;
122        neighbourhood = neighbourhoodType;
123
124        final int fLen = featuresList[0][0].length;
125        network = new Network(0, fLen);
126        identifiers = new long[numberOfRows][numberOfColumns];
127
128        // Add neurons.
129        for (int i = 0; i < numberOfRows; i++) {
130            for (int j = 0; j < numberOfColumns; j++) {
131                identifiers[i][j] = network.createNeuron(featuresList[i][j]);
132            }
133        }
134
135        // Add links.
136        createLinks();
137    }
138
139    /**
140     * Creates a two-dimensional network composed of square cells:
141     * Each neuron not located on the border of the mesh has four
142     * neurons linked to it.
143     * <br/>
144     * The links are bi-directional.
145     * <br/>
146     * The topology of the network can also be a cylinder (if one
147     * of the dimensions is wrapped) or a torus (if both dimensions
148     * are wrapped).
149     *
150     * @param numRows Number of neurons in the first dimension.
151     * @param wrapRowDim Whether to wrap the first dimension (i.e the first
152     * and last neurons will be linked together).
153     * @param numCols Number of neurons in the second dimension.
154     * @param wrapColDim Whether to wrap the second dimension (i.e the first
155     * and last neurons will be linked together).
156     * @param neighbourhoodType Neighbourhood type.
157     * @param featureInit Array of functions that will initialize the
158     * corresponding element of the features set of each newly created
159     * neuron. In particular, the size of this array defines the size of
160     * feature set.
161     * @throws NumberIsTooSmallException if {@code numRows < 2} or
162     * {@code numCols < 2}.
163     */
164    public NeuronSquareMesh2D(int numRows,
165                              boolean wrapRowDim,
166                              int numCols,
167                              boolean wrapColDim,
168                              SquareNeighbourhood neighbourhoodType,
169                              FeatureInitializer[] featureInit) {
170        if (numRows < 2) {
171            throw new NumberIsTooSmallException(numRows, 2, true);
172        }
173        if (numCols < 2) {
174            throw new NumberIsTooSmallException(numCols, 2, true);
175        }
176
177        numberOfRows = numRows;
178        wrapRows = wrapRowDim;
179        numberOfColumns = numCols;
180        wrapColumns = wrapColDim;
181        neighbourhood = neighbourhoodType;
182        identifiers = new long[numberOfRows][numberOfColumns];
183
184        final int fLen = featureInit.length;
185        network = new Network(0, fLen);
186
187        // Add neurons.
188        for (int i = 0; i < numRows; i++) {
189            for (int j = 0; j < numCols; j++) {
190                final double[] features = new double[fLen];
191                for (int fIndex = 0; fIndex < fLen; fIndex++) {
192                    features[fIndex] = featureInit[fIndex].value();
193                }
194                identifiers[i][j] = network.createNeuron(features);
195            }
196        }
197
198        // Add links.
199        createLinks();
200    }
201
202    /**
203     * Constructor with restricted access, solely used for making a
204     * {@link #copy() deep copy}.
205     *
206     * @param wrapRowDim Whether to wrap the first dimension (i.e the first
207     * and last neurons will be linked together).
208     * @param wrapColDim Whether to wrap the second dimension (i.e the first
209     * and last neurons will be linked together).
210     * @param neighbourhoodType Neighbourhood type.
211     * @param net Underlying network.
212     * @param idGrid Neuron identifiers.
213     */
214    private NeuronSquareMesh2D(boolean wrapRowDim,
215                               boolean wrapColDim,
216                               SquareNeighbourhood neighbourhoodType,
217                               Network net,
218                               long[][] idGrid) {
219        numberOfRows = idGrid.length;
220        numberOfColumns = idGrid[0].length;
221        wrapRows = wrapRowDim;
222        wrapColumns = wrapColDim;
223        neighbourhood = neighbourhoodType;
224        network = net;
225        identifiers = idGrid;
226    }
227
228    /**
229     * Performs a deep copy of this instance.
230     * Upon return, the copied and original instances will be independent:
231     * Updating one will not affect the other.
232     *
233     * @return a new instance with the same state as this instance.
234     * @since 3.6
235     */
236    public synchronized NeuronSquareMesh2D copy() {
237        final long[][] idGrid = new long[numberOfRows][numberOfColumns];
238        for (int r = 0; r < numberOfRows; r++) {
239            for (int c = 0; c < numberOfColumns; c++) {
240                idGrid[r][c] = identifiers[r][c];
241            }
242        }
243
244        return new NeuronSquareMesh2D(wrapRows,
245                                      wrapColumns,
246                                      neighbourhood,
247                                      network.copy(),
248                                      idGrid);
249    }
250
251    /**
252     * {@inheritDoc}
253     *  @since 3.6
254     */
255    public Iterator<Neuron> iterator() {
256        return network.iterator();
257    }
258
259    /**
260     * Retrieves the underlying network.
261     * A reference is returned (enabling, for example, the network to be
262     * trained).
263     * This also implies that calling methods that modify the {@link Network}
264     * topology may cause this class to become inconsistent.
265     *
266     * @return the network.
267     */
268    public Network getNetwork() {
269        return network;
270    }
271
272    /**
273     * Gets the number of neurons in each row of this map.
274     *
275     * @return the number of rows.
276     */
277    public int getNumberOfRows() {
278        return numberOfRows;
279    }
280
281    /**
282     * Gets the number of neurons in each column of this map.
283     *
284     * @return the number of column.
285     */
286    public int getNumberOfColumns() {
287        return numberOfColumns;
288    }
289
290    /**
291     * Retrieves the neuron at location {@code (i, j)} in the map.
292     * The neuron at position {@code (0, 0)} is located at the upper-left
293     * corner of the map.
294     *
295     * @param i Row index.
296     * @param j Column index.
297     * @return the neuron at {@code (i, j)}.
298     * @throws OutOfRangeException if {@code i} or {@code j} is
299     * out of range.
300     *
301     * @see #getNeuron(int,int,HorizontalDirection,VerticalDirection)
302     */
303    public Neuron getNeuron(int i,
304                            int j) {
305        if (i < 0 ||
306            i >= numberOfRows) {
307            throw new OutOfRangeException(i, 0, numberOfRows - 1);
308        }
309        if (j < 0 ||
310            j >= numberOfColumns) {
311            throw new OutOfRangeException(j, 0, numberOfColumns - 1);
312        }
313
314        return network.getNeuron(identifiers[i][j]);
315    }
316
317    /**
318     * Retrieves the neuron at {@code (location[0], location[1])} in the map.
319     * The neuron at position {@code (0, 0)} is located at the upper-left
320     * corner of the map.
321     *
322     * @param row Row index.
323     * @param col Column index.
324     * @param alongRowDir Direction along the given {@code row} (i.e. an
325     * offset will be added to the given <em>column</em> index.
326     * @param alongColDir Direction along the given {@code col} (i.e. an
327     * offset will be added to the given <em>row</em> index.
328     * @return the neuron at the requested location, or {@code null} if
329     * the location is not on the map.
330     *
331     * @see #getNeuron(int,int)
332     */
333    public Neuron getNeuron(int row,
334                            int col,
335                            HorizontalDirection alongRowDir,
336                            VerticalDirection alongColDir) {
337        final int[] location = getLocation(row, col, alongRowDir, alongColDir);
338
339        return location == null ? null : getNeuron(location[0], location[1]);
340    }
341
342    /**
343     * Computes the location of a neighbouring neuron.
344     * It will return {@code null} if the resulting location is not part
345     * of the map.
346     * Position {@code (0, 0)} is at the upper-left corner of the map.
347     *
348     * @param row Row index.
349     * @param col Column index.
350     * @param alongRowDir Direction along the given {@code row} (i.e. an
351     * offset will be added to the given <em>column</em> index.
352     * @param alongColDir Direction along the given {@code col} (i.e. an
353     * offset will be added to the given <em>row</em> index.
354     * @return an array of length 2 containing the indices of the requested
355     * location, or {@code null} if that location is not part of the map.
356     *
357     * @see #getNeuron(int,int)
358     */
359    private int[] getLocation(int row,
360                              int col,
361                              HorizontalDirection alongRowDir,
362                              VerticalDirection alongColDir) {
363        final int colOffset;
364        switch (alongRowDir) {
365        case LEFT:
366            colOffset = -1;
367            break;
368        case RIGHT:
369            colOffset = 1;
370            break;
371        case CENTER:
372            colOffset = 0;
373            break;
374        default:
375            // Should never happen.
376            throw new MathInternalError();
377        }
378        int colIndex = col + colOffset;
379        if (wrapColumns) {
380            if (colIndex < 0) {
381                colIndex += numberOfColumns;
382            } else {
383                colIndex %= numberOfColumns;
384            }
385        }
386
387        final int rowOffset;
388        switch (alongColDir) {
389        case UP:
390            rowOffset = -1;
391            break;
392        case DOWN:
393            rowOffset = 1;
394            break;
395        case CENTER:
396            rowOffset = 0;
397            break;
398        default:
399            // Should never happen.
400            throw new MathInternalError();
401        }
402        int rowIndex = row + rowOffset;
403        if (wrapRows) {
404            if (rowIndex < 0) {
405                rowIndex += numberOfRows;
406            } else {
407                rowIndex %= numberOfRows;
408            }
409        }
410
411        if (rowIndex < 0 ||
412            rowIndex >= numberOfRows ||
413            colIndex < 0 ||
414            colIndex >= numberOfColumns) {
415            return null;
416        } else {
417            return new int[] { rowIndex, colIndex };
418        }
419    }
420
421    /**
422     * Creates the neighbour relationships between neurons.
423     */
424    private void createLinks() {
425        // "linkEnd" will store the identifiers of the "neighbours".
426        final List<Long> linkEnd = new ArrayList<Long>();
427        final int iLast = numberOfRows - 1;
428        final int jLast = numberOfColumns - 1;
429        for (int i = 0; i < numberOfRows; i++) {
430            for (int j = 0; j < numberOfColumns; j++) {
431                linkEnd.clear();
432
433                switch (neighbourhood) {
434
435                case MOORE:
436                    // Add links to "diagonal" neighbours.
437                    if (i > 0) {
438                        if (j > 0) {
439                            linkEnd.add(identifiers[i - 1][j - 1]);
440                        }
441                        if (j < jLast) {
442                            linkEnd.add(identifiers[i - 1][j + 1]);
443                        }
444                    }
445                    if (i < iLast) {
446                        if (j > 0) {
447                            linkEnd.add(identifiers[i + 1][j - 1]);
448                        }
449                        if (j < jLast) {
450                            linkEnd.add(identifiers[i + 1][j + 1]);
451                        }
452                    }
453                    if (wrapRows) {
454                        if (i == 0) {
455                            if (j > 0) {
456                                linkEnd.add(identifiers[iLast][j - 1]);
457                            }
458                            if (j < jLast) {
459                                linkEnd.add(identifiers[iLast][j + 1]);
460                            }
461                        } else if (i == iLast) {
462                            if (j > 0) {
463                                linkEnd.add(identifiers[0][j - 1]);
464                            }
465                            if (j < jLast) {
466                                linkEnd.add(identifiers[0][j + 1]);
467                            }
468                        }
469                    }
470                    if (wrapColumns) {
471                        if (j == 0) {
472                            if (i > 0) {
473                                linkEnd.add(identifiers[i - 1][jLast]);
474                            }
475                            if (i < iLast) {
476                                linkEnd.add(identifiers[i + 1][jLast]);
477                            }
478                        } else if (j == jLast) {
479                             if (i > 0) {
480                                 linkEnd.add(identifiers[i - 1][0]);
481                             }
482                             if (i < iLast) {
483                                 linkEnd.add(identifiers[i + 1][0]);
484                             }
485                        }
486                    }
487                    if (wrapRows &&
488                        wrapColumns) {
489                        if (i == 0 &&
490                            j == 0) {
491                            linkEnd.add(identifiers[iLast][jLast]);
492                        } else if (i == 0 &&
493                                   j == jLast) {
494                            linkEnd.add(identifiers[iLast][0]);
495                        } else if (i == iLast &&
496                                   j == 0) {
497                            linkEnd.add(identifiers[0][jLast]);
498                        } else if (i == iLast &&
499                                   j == jLast) {
500                            linkEnd.add(identifiers[0][0]);
501                        }
502                    }
503
504                    // Case falls through since the "Moore" neighbourhood
505                    // also contains the neurons that belong to the "Von
506                    // Neumann" neighbourhood.
507
508                    // fallthru (CheckStyle)
509                case VON_NEUMANN:
510                    // Links to preceding and following "row".
511                    if (i > 0) {
512                        linkEnd.add(identifiers[i - 1][j]);
513                    }
514                    if (i < iLast) {
515                        linkEnd.add(identifiers[i + 1][j]);
516                    }
517                    if (wrapRows) {
518                        if (i == 0) {
519                            linkEnd.add(identifiers[iLast][j]);
520                        } else if (i == iLast) {
521                            linkEnd.add(identifiers[0][j]);
522                        }
523                    }
524
525                    // Links to preceding and following "column".
526                    if (j > 0) {
527                        linkEnd.add(identifiers[i][j - 1]);
528                    }
529                    if (j < jLast) {
530                        linkEnd.add(identifiers[i][j + 1]);
531                    }
532                    if (wrapColumns) {
533                        if (j == 0) {
534                            linkEnd.add(identifiers[i][jLast]);
535                        } else if (j == jLast) {
536                            linkEnd.add(identifiers[i][0]);
537                        }
538                    }
539                    break;
540
541                default:
542                    throw new MathInternalError(); // Cannot happen.
543                }
544
545                final Neuron aNeuron = network.getNeuron(identifiers[i][j]);
546                for (long b : linkEnd) {
547                    final Neuron bNeuron = network.getNeuron(b);
548                    // Link to all neighbours.
549                    // The reverse links will be added as the loop proceeds.
550                    network.addLink(aNeuron, bNeuron);
551                }
552            }
553        }
554    }
555
556    /**
557     * Prevents proxy bypass.
558     *
559     * @param in Input stream.
560     */
561    private void readObject(ObjectInputStream in) {
562        throw new IllegalStateException();
563    }
564
565    /**
566     * Custom serialization.
567     *
568     * @return the proxy instance that will be actually serialized.
569     */
570    private Object writeReplace() {
571        final double[][][] featuresList = new double[numberOfRows][numberOfColumns][];
572        for (int i = 0; i < numberOfRows; i++) {
573            for (int j = 0; j < numberOfColumns; j++) {
574                featuresList[i][j] = getNeuron(i, j).getFeatures();
575            }
576        }
577
578        return new SerializationProxy(wrapRows,
579                                      wrapColumns,
580                                      neighbourhood,
581                                      featuresList);
582    }
583
584    /**
585     * Serialization.
586     */
587    private static class SerializationProxy implements Serializable {
588        /** Serializable. */
589        private static final long serialVersionUID = 20130226L;
590        /** Wrap. */
591        private final boolean wrapRows;
592        /** Wrap. */
593        private final boolean wrapColumns;
594        /** Neighbourhood type. */
595        private final SquareNeighbourhood neighbourhood;
596        /** Neurons' features. */
597        private final double[][][] featuresList;
598
599        /**
600         * @param wrapRows Whether the row dimension is wrapped.
601         * @param wrapColumns Whether the column dimension is wrapped.
602         * @param neighbourhood Neighbourhood type.
603         * @param featuresList List of neurons features.
604         * {@code neuronList}.
605         */
606        SerializationProxy(boolean wrapRows,
607                           boolean wrapColumns,
608                           SquareNeighbourhood neighbourhood,
609                           double[][][] featuresList) {
610            this.wrapRows = wrapRows;
611            this.wrapColumns = wrapColumns;
612            this.neighbourhood = neighbourhood;
613            this.featuresList = featuresList;
614        }
615
616        /**
617         * Custom serialization.
618         *
619         * @return the {@link Neuron} for which this instance is the proxy.
620         */
621        private Object readResolve() {
622            return new NeuronSquareMesh2D(wrapRows,
623                                          wrapColumns,
624                                          neighbourhood,
625                                          featuresList);
626        }
627    }
628}