001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.ml.neuralnet;
019
020import java.util.List;
021import java.util.ArrayList;
022import java.util.Collections;
023import java.util.Comparator;
024
025import org.apache.commons.math4.exception.NotStrictlyPositiveException;
026import org.apache.commons.math4.ml.distance.DistanceMeasure;
027
028/**
029 * Utility for ranking the units (neurons) of a network.
030 *
031 * @since 4.0
032 */
033public class MapRanking {
034    /** List corresponding to the map passed to the constructor. */
035    private final List<Neuron> map = new ArrayList<>();
036    /** Distance function for sorting. */
037    private final DistanceMeasure distance;
038
039    /**
040     * @param neurons List to be ranked.
041     * No defensive copy is performed.
042     * The {@link #rank(double[],int) created list of units} will
043     * be sorted in increasing order of the {@code distance}.
044     * @param distance Distance function.
045     */
046    public MapRanking(Iterable<Neuron> neurons,
047                      DistanceMeasure distance) {
048        this.distance = distance;
049
050        for (Neuron n : neurons) {
051            map.add(n); // No defensive copy.
052        }
053    }
054
055    /**
056     * Creates a list of the neurons whose features best correspond to the
057     * given {@code features}.
058     *
059     * @param features Data.
060     * @return the list of neurons sorted in decreasing order of distance to
061     * the given data.
062     * @throws org.apache.commons.math4.exception.DimensionMismatchException
063     * if the size of the input is not compatible with the neurons features
064     * size.
065     */
066    public List<Neuron> rank(double[] features) {
067        return rank(features, map.size());
068    }
069
070    /**
071     * Creates a list of the neurons whose features best correspond to the
072     * given {@code features}.
073     *
074     * @param features Data.
075     * @param max Maximum size of the returned list.
076     * @return the list of neurons sorted in decreasing order of distance to
077     * the given data.
078     * @throws org.apache.commons.math4.exception.DimensionMismatchException
079     * if the size of the input is not compatible with the neurons features
080     * size.
081     * @throws NotStrictlyPositiveException if {@code max <= 0}.
082     */
083    public List<Neuron> rank(double[] features,
084                             int max) {
085        if (max <= 0) {
086            throw new NotStrictlyPositiveException(max);
087        }
088        final int m = max <= map.size() ?
089            max :
090            map.size();
091        final List<PairNeuronDouble> list = new ArrayList<>(m);
092
093        for (final Neuron n : map) {
094            final double d = distance.compute(n.getFeatures(), features);
095            final PairNeuronDouble p = new PairNeuronDouble(n, d);
096
097            if (list.size() < m) {
098                list.add(p);
099                if (list.size() > 1) {
100                    // Sort if there is more than 1 element.
101                    Collections.sort(list, PairNeuronDouble.COMPARATOR);
102                }
103            } else {
104                final int last = list.size() - 1;
105                if (PairNeuronDouble.COMPARATOR.compare(p, list.get(last)) < 0) {
106                    list.set(last, p); // Replace worst entry.
107                    if (last > 0) {
108                        // Sort if there is more than 1 element.
109                        Collections.sort(list, PairNeuronDouble.COMPARATOR);
110                    }
111                }
112            }
113        }
114
115        final List<Neuron> result = new ArrayList<>(m);
116        for (PairNeuronDouble p : list) {
117            result.add(p.getNeuron());
118        }
119
120        return result;
121    }
122
123    /**
124     * Helper data structure holding a (Neuron, double) pair.
125     */
126    private static class PairNeuronDouble {
127        /** Comparator. */
128        static final Comparator<PairNeuronDouble> COMPARATOR
129            = new Comparator<PairNeuronDouble>() {
130            /** {@inheritDoc} */
131            @Override
132            public int compare(PairNeuronDouble o1,
133                               PairNeuronDouble o2) {
134                return Double.compare(o1.value, o2.value);
135            }
136        };
137        /** Key. */
138        private final Neuron neuron;
139        /** Value. */
140        private final double value;
141
142        /**
143         * @param neuron Neuron.
144         * @param value Value.
145         */
146        PairNeuronDouble(Neuron neuron, double value) {
147            this.neuron = neuron;
148            this.value = value;
149        }
150
151        /** @return the neuron. */
152        public Neuron getNeuron() {
153            return neuron;
154        }
155    }
156}