1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.math4.neuralnet;
19
20 import java.util.List;
21 import java.util.ArrayList;
22 import java.util.Collections;
23 import java.util.Comparator;
24
25 import org.apache.commons.math4.neuralnet.internal.NeuralNetException;
26
27 /**
28 * Utility for ranking the units (neurons) of a network.
29 *
30 * @since 4.0
31 */
32 public class MapRanking {
33 /** List corresponding to the map passed to the constructor. */
34 private final List<Neuron> map = new ArrayList<>();
35 /** Distance function for sorting. */
36 private final DistanceMeasure distance;
37
38 /**
39 * @param neurons List to be ranked.
40 * No defensive copy is performed.
41 * The {@link #rank(double[],int) created list of units} will
42 * be sorted in increasing order of the {@code distance}.
43 * @param distance Distance function.
44 */
45 public MapRanking(Iterable<Neuron> neurons,
46 DistanceMeasure distance) {
47 this.distance = distance;
48
49 for (final Neuron n : neurons) {
50 map.add(n); // No defensive copy.
51 }
52 }
53
54 /**
55 * Creates a list of the neurons whose features best correspond to the
56 * given {@code features}.
57 *
58 * @param features Data.
59 * @return the list of neurons sorted in decreasing order of distance to
60 * the given data.
61 * @throws IllegalArgumentException if the size of the input is not
62 * compatible with the neurons features size.
63 */
64 public List<Neuron> rank(double[] features) {
65 return rank(features, map.size());
66 }
67
68 /**
69 * Creates a list of the neurons whose features best correspond to the
70 * given {@code features}.
71 *
72 * @param features Data.
73 * @param max Maximum size of the returned list.
74 * @return the list of neurons sorted in decreasing order of distance to
75 * the given data.
76 * @throws IllegalArgumentException if the size of the input is not
77 * compatible with the neurons features size or {@code max <= 0}.
78 */
79 public List<Neuron> rank(double[] features,
80 int max) {
81 if (max <= 0) {
82 throw new NeuralNetException(NeuralNetException.NOT_STRICTLY_POSITIVE, max);
83 }
84 final int m = max <= map.size() ?
85 max :
86 map.size();
87 final List<PairNeuronDouble> list = new ArrayList<>(m);
88
89 for (final Neuron n : map) {
90 final double d = distance.applyAsDouble(n.getFeatures(), features);
91 final PairNeuronDouble p = new PairNeuronDouble(n, d);
92
93 if (list.size() < m) {
94 list.add(p);
95 if (list.size() > 1) {
96 // Sort if there is more than 1 element.
97 Collections.sort(list, PairNeuronDouble.COMPARATOR);
98 }
99 } else {
100 final int last = list.size() - 1;
101 if (PairNeuronDouble.COMPARATOR.compare(p, list.get(last)) < 0) {
102 list.set(last, p); // Replace worst entry.
103 if (last > 0) {
104 // Sort if there is more than 1 element.
105 Collections.sort(list, PairNeuronDouble.COMPARATOR);
106 }
107 }
108 }
109 }
110
111 final List<Neuron> result = new ArrayList<>(m);
112 for (final PairNeuronDouble p : list) {
113 result.add(p.getNeuron());
114 }
115
116 return result;
117 }
118
119 /**
120 * Helper data structure holding a (Neuron, double) pair.
121 */
122 private static class PairNeuronDouble {
123 /** Comparator. */
124 static final Comparator<PairNeuronDouble> COMPARATOR
125 = new Comparator<PairNeuronDouble>() {
126 /** {@inheritDoc} */
127 @Override
128 public int compare(PairNeuronDouble o1,
129 PairNeuronDouble o2) {
130 return Double.compare(o1.value, o2.value);
131 }
132 };
133 /** Key. */
134 private final Neuron neuron;
135 /** Value. */
136 private final double value;
137
138 /**
139 * @param neuron Neuron.
140 * @param value Value.
141 */
142 PairNeuronDouble(Neuron neuron, double value) {
143 this.neuron = neuron;
144 this.value = value;
145 }
146
147 /** @return the neuron. */
148 public Neuron getNeuron() {
149 return neuron;
150 }
151 }
152 }