MapRanking.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet;

  18. import java.util.List;
  19. import java.util.ArrayList;
  20. import java.util.Collections;
  21. import java.util.Comparator;

  22. import org.apache.commons.math4.neuralnet.internal.NeuralNetException;

  23. /**
  24.  * Utility for ranking the units (neurons) of a network.
  25.  *
  26.  * @since 4.0
  27.  */
  28. public class MapRanking {
  29.     /** List corresponding to the map passed to the constructor. */
  30.     private final List<Neuron> map = new ArrayList<>();
  31.     /** Distance function for sorting. */
  32.     private final DistanceMeasure distance;

  33.     /**
  34.      * @param neurons List to be ranked.
  35.      * No defensive copy is performed.
  36.      * The {@link #rank(double[],int) created list of units} will
  37.      * be sorted in increasing order of the {@code distance}.
  38.      * @param distance Distance function.
  39.      */
  40.     public MapRanking(Iterable<Neuron> neurons,
  41.                       DistanceMeasure distance) {
  42.         this.distance = distance;

  43.         for (final Neuron n : neurons) {
  44.             map.add(n); // No defensive copy.
  45.         }
  46.     }

  47.     /**
  48.      * Creates a list of the neurons whose features best correspond to the
  49.      * given {@code features}.
  50.      *
  51.      * @param features Data.
  52.      * @return the list of neurons sorted in decreasing order of distance to
  53.      * the given data.
  54.      * @throws IllegalArgumentException if the size of the input is not
  55.      * compatible with the neurons features size.
  56.      */
  57.     public List<Neuron> rank(double[] features) {
  58.         return rank(features, map.size());
  59.     }

  60.     /**
  61.      * Creates a list of the neurons whose features best correspond to the
  62.      * given {@code features}.
  63.      *
  64.      * @param features Data.
  65.      * @param max Maximum size of the returned list.
  66.      * @return the list of neurons sorted in decreasing order of distance to
  67.      * the given data.
  68.      * @throws IllegalArgumentException if the size of the input is not
  69.      * compatible with the neurons features size or {@code max <= 0}.
  70.      */
  71.     public List<Neuron> rank(double[] features,
  72.                              int max) {
  73.         if (max <= 0) {
  74.             throw new NeuralNetException(NeuralNetException.NOT_STRICTLY_POSITIVE, max);
  75.         }
  76.         final int m = max <= map.size() ?
  77.             max :
  78.             map.size();
  79.         final List<PairNeuronDouble> list = new ArrayList<>(m);

  80.         for (final Neuron n : map) {
  81.             final double d = distance.applyAsDouble(n.getFeatures(), features);
  82.             final PairNeuronDouble p = new PairNeuronDouble(n, d);

  83.             if (list.size() < m) {
  84.                 list.add(p);
  85.                 if (list.size() > 1) {
  86.                     // Sort if there is more than 1 element.
  87.                     Collections.sort(list, PairNeuronDouble.COMPARATOR);
  88.                 }
  89.             } else {
  90.                 final int last = list.size() - 1;
  91.                 if (PairNeuronDouble.COMPARATOR.compare(p, list.get(last)) < 0) {
  92.                     list.set(last, p); // Replace worst entry.
  93.                     if (last > 0) {
  94.                         // Sort if there is more than 1 element.
  95.                         Collections.sort(list, PairNeuronDouble.COMPARATOR);
  96.                     }
  97.                 }
  98.             }
  99.         }

  100.         final List<Neuron> result = new ArrayList<>(m);
  101.         for (final PairNeuronDouble p : list) {
  102.             result.add(p.getNeuron());
  103.         }

  104.         return result;
  105.     }

  106.     /**
  107.      * Helper data structure holding a (Neuron, double) pair.
  108.      */
  109.     private static class PairNeuronDouble {
  110.         /** Comparator. */
  111.         static final Comparator<PairNeuronDouble> COMPARATOR
  112.             = new Comparator<PairNeuronDouble>() {
  113.                 /** {@inheritDoc} */
  114.                 @Override
  115.                 public int compare(PairNeuronDouble o1,
  116.                                    PairNeuronDouble o2) {
  117.                     return Double.compare(o1.value, o2.value);
  118.                 }
  119.             };
  120.         /** Key. */
  121.         private final Neuron neuron;
  122.         /** Value. */
  123.         private final double value;

  124.         /**
  125.          * @param neuron Neuron.
  126.          * @param value Value.
  127.          */
  128.         PairNeuronDouble(Neuron neuron, double value) {
  129.             this.neuron = neuron;
  130.             this.value = value;
  131.         }

  132.         /** @return the neuron. */
  133.         public Neuron getNeuron() {
  134.             return neuron;
  135.         }
  136.     }
  137. }