/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.neuralnet;

import java.util.List;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;

import org.apache.commons.math4.neuralnet.internal.NeuralNetException;

/**
 * Utility for ranking the units (neurons) of a network.
 *
 * @since 4.0
 */
public class MapRanking {
    /** List corresponding to the map passed to the constructor. */
    private final List<Neuron> map = new ArrayList<>();
    /** Distance function for sorting. */
    private final DistanceMeasure distance;

    /**
     * @param neurons List to be ranked.
     * The neuron instances are not copied (only the references are stored):
     * their features are read each time a ranking is computed.
     * The {@link #rank(double[],int) created list of units} will
     * be sorted in increasing order of the {@code distance}.
     * @param distance Distance function.
     */
    public MapRanking(Iterable<Neuron> neurons,
                      DistanceMeasure distance) {
        this.distance = distance;

        for (final Neuron n : neurons) {
            map.add(n); // No defensive copy of the neuron itself.
        }
    }

    /**
     * Creates a list of the neurons whose features best correspond to the
     * given {@code features}.
     *
     * @param features Data.
     * @return the list of all neurons sorted in increasing order of distance
     * to the given data (i.e. the closest neuron first).
     * @throws IllegalArgumentException if the size of the input is not
     * compatible with the neurons features size, or if the ranked list of
     * neurons is empty (the requested size would then be 0).
     */
    public List<Neuron> rank(double[] features) {
        return rank(features, map.size());
    }

    /**
     * Creates a list of the neurons whose features best correspond to the
     * given {@code features}.
     * Neurons at equal distance are listed in the iteration order of the
     * {@code neurons} passed to the constructor.
     *
     * @param features Data.
     * @param max Maximum size of the returned list.
     * @return the list of neurons sorted in increasing order of distance to
     * the given data (i.e. the closest neuron first).
     * @throws IllegalArgumentException if the size of the input is not
     * compatible with the neurons features size or {@code max <= 0}.
     */
    public List<Neuron> rank(double[] features,
                             int max) {
        if (max <= 0) {
            throw new NeuralNetException(NeuralNetException.NOT_STRICTLY_POSITIVE, max);
        }
        final int m = Math.min(max, map.size());
        final List<PairNeuronDouble> best = new ArrayList<>(m);

        // Seed with the first "m" units: a single (stable) sort orders them
        // while keeping equidistant units in iteration order.
        for (final Neuron n : map.subList(0, m)) {
            best.add(evaluate(n, features));
        }
        Collections.sort(best, PairNeuronDouble.COMPARATOR);

        // Each remaining unit displaces the current worst entry only when it
        // is strictly closer, so that ties favour the units seen earlier.
        // The loop body is never reached when "m" is 0 (empty map).
        final int last = m - 1;
        for (final Neuron n : map.subList(m, map.size())) {
            final PairNeuronDouble p = evaluate(n, features);
            if (PairNeuronDouble.COMPARATOR.compare(p, best.get(last)) < 0) {
                best.remove(last);
                best.add(insertionIndex(best, p), p);
            }
        }

        final List<Neuron> result = new ArrayList<>(m);
        for (final PairNeuronDouble p : best) {
            result.add(p.getNeuron());
        }

        return result;
    }

    /**
     * Computes the distance between the features of a neuron and the given data.
     *
     * @param n Neuron.
     * @param features Data.
     * @return the (neuron, distance) pair.
     */
    private PairNeuronDouble evaluate(Neuron n,
                                      double[] features) {
        return new PairNeuronDouble(n, distance.applyAsDouble(n.getFeatures(), features));
    }

    /**
     * Finds where {@code p} must be inserted so that {@code sorted} stays
     * sorted, with {@code p} placed after every entry that compares equal
     * to it (upper bound); this reproduces the order a stable sort would give.
     *
     * @param sorted List sorted according to {@link PairNeuronDouble#COMPARATOR}.
     * @param p Entry to be inserted.
     * @return the insertion index, in {@code [0, sorted.size()]}.
     */
    private static int insertionIndex(List<PairNeuronDouble> sorted,
                                      PairNeuronDouble p) {
        int lo = 0;
        int hi = sorted.size();
        while (lo < hi) {
            final int mid = (lo + hi) >>> 1;
            if (PairNeuronDouble.COMPARATOR.compare(sorted.get(mid), p) <= 0) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    /**
     * Helper data structure holding a (Neuron, double) pair.
     */
    private static final class PairNeuronDouble {
        /** Comparator (increasing order of the value; NaN sorts last). */
        static final Comparator<PairNeuronDouble> COMPARATOR
            = new Comparator<PairNeuronDouble>() {
                    /** {@inheritDoc} */
                    @Override
                    public int compare(PairNeuronDouble o1,
                                       PairNeuronDouble o2) {
                        return Double.compare(o1.value, o2.value);
                    }
                };
        /** Key. */
        private final Neuron neuron;
        /** Value. */
        private final double value;

        /**
         * @param neuron Neuron.
         * @param value Value.
         */
        PairNeuronDouble(Neuron neuron, double value) {
            this.neuron = neuron;
            this.value = value;
        }

        /** @return the neuron. */
        public Neuron getNeuron() {
            return neuron;
        }
    }
}