001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math4.neuralnet;
019
020import java.util.List;
021import java.util.ArrayList;
022import java.util.Collections;
023import java.util.Comparator;
024
025import org.apache.commons.math4.neuralnet.internal.NeuralNetException;
026
027/**
028 * Utility for ranking the units (neurons) of a network.
029 *
030 * @since 4.0
031 */
032public class MapRanking {
033    /** List corresponding to the map passed to the constructor. */
034    private final List<Neuron> map = new ArrayList<>();
035    /** Distance function for sorting. */
036    private final DistanceMeasure distance;
037
038    /**
039     * @param neurons List to be ranked.
040     * No defensive copy is performed.
041     * The {@link #rank(double[],int) created list of units} will
042     * be sorted in increasing order of the {@code distance}.
043     * @param distance Distance function.
044     */
045    public MapRanking(Iterable<Neuron> neurons,
046                      DistanceMeasure distance) {
047        this.distance = distance;
048
049        for (final Neuron n : neurons) {
050            map.add(n); // No defensive copy.
051        }
052    }
053
054    /**
055     * Creates a list of the neurons whose features best correspond to the
056     * given {@code features}.
057     *
058     * @param features Data.
059     * @return the list of neurons sorted in decreasing order of distance to
060     * the given data.
061     * @throws IllegalArgumentException if the size of the input is not
062     * compatible with the neurons features size.
063     */
064    public List<Neuron> rank(double[] features) {
065        return rank(features, map.size());
066    }
067
068    /**
069     * Creates a list of the neurons whose features best correspond to the
070     * given {@code features}.
071     *
072     * @param features Data.
073     * @param max Maximum size of the returned list.
074     * @return the list of neurons sorted in decreasing order of distance to
075     * the given data.
076     * @throws IllegalArgumentException if the size of the input is not
077     * compatible with the neurons features size or {@code max <= 0}.
078     */
079    public List<Neuron> rank(double[] features,
080                             int max) {
081        if (max <= 0) {
082            throw new NeuralNetException(NeuralNetException.NOT_STRICTLY_POSITIVE, max);
083        }
084        final int m = max <= map.size() ?
085            max :
086            map.size();
087        final List<PairNeuronDouble> list = new ArrayList<>(m);
088
089        for (final Neuron n : map) {
090            final double d = distance.applyAsDouble(n.getFeatures(), features);
091            final PairNeuronDouble p = new PairNeuronDouble(n, d);
092
093            if (list.size() < m) {
094                list.add(p);
095                if (list.size() > 1) {
096                    // Sort if there is more than 1 element.
097                    Collections.sort(list, PairNeuronDouble.COMPARATOR);
098                }
099            } else {
100                final int last = list.size() - 1;
101                if (PairNeuronDouble.COMPARATOR.compare(p, list.get(last)) < 0) {
102                    list.set(last, p); // Replace worst entry.
103                    if (last > 0) {
104                        // Sort if there is more than 1 element.
105                        Collections.sort(list, PairNeuronDouble.COMPARATOR);
106                    }
107                }
108            }
109        }
110
111        final List<Neuron> result = new ArrayList<>(m);
112        for (final PairNeuronDouble p : list) {
113            result.add(p.getNeuron());
114        }
115
116        return result;
117    }
118
119    /**
120     * Helper data structure holding a (Neuron, double) pair.
121     */
122    private static class PairNeuronDouble {
123        /** Comparator. */
124        static final Comparator<PairNeuronDouble> COMPARATOR
125            = new Comparator<PairNeuronDouble>() {
126                /** {@inheritDoc} */
127                @Override
128                public int compare(PairNeuronDouble o1,
129                                   PairNeuronDouble o2) {
130                    return Double.compare(o1.value, o2.value);
131                }
132            };
133        /** Key. */
134        private final Neuron neuron;
135        /** Value. */
136        private final double value;
137
138        /**
139         * @param neuron Neuron.
140         * @param value Value.
141         */
142        PairNeuronDouble(Neuron neuron, double value) {
143            this.neuron = neuron;
144            this.value = value;
145        }
146
147        /** @return the neuron. */
148        public Neuron getNeuron() {
149            return neuron;
150        }
151    }
152}