1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.stat.ranking;
19
20 import java.util.ArrayList;
21 import java.util.Arrays;
22 import java.util.Iterator;
23 import java.util.List;
24
25 import org.apache.commons.rng.UniformRandomProvider;
26 import org.apache.commons.rng.simple.RandomSource;
27 import org.apache.commons.rng.sampling.distribution.UniformLongSampler;
28 import org.apache.commons.math4.legacy.exception.MathInternalError;
29 import org.apache.commons.math4.legacy.exception.NotANumberException;
30 import org.apache.commons.math4.core.jdkmath.JdkMath;
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74 public class NaturalRanking implements RankingAlgorithm {
75
76
77 public static final NaNStrategy DEFAULT_NAN_STRATEGY = NaNStrategy.FAILED;
78
79
80 public static final TiesStrategy DEFAULT_TIES_STRATEGY = TiesStrategy.AVERAGE;
81
82
83 private final NaNStrategy nanStrategy;
84
85
86 private final TiesStrategy tiesStrategy;
87
88
89 private final UniformRandomProvider random;
90
91
92
93
94 public NaturalRanking() {
95 this(DEFAULT_NAN_STRATEGY, DEFAULT_TIES_STRATEGY, null);
96 }
97
98
99
100
101
102
103 public NaturalRanking(TiesStrategy tiesStrategy) {
104 this(DEFAULT_NAN_STRATEGY,
105 tiesStrategy,
106 RandomSource.WELL_19937_C.create());
107 }
108
109
110
111
112
113
114 public NaturalRanking(NaNStrategy nanStrategy) {
115 this(nanStrategy, DEFAULT_TIES_STRATEGY, null);
116 }
117
118
119
120
121
122
123
124 public NaturalRanking(NaNStrategy nanStrategy,
125 TiesStrategy tiesStrategy) {
126 this(nanStrategy,
127 tiesStrategy,
128 RandomSource.WELL_19937_C.create());
129 }
130
131
132
133
134
135
136
137 public NaturalRanking(UniformRandomProvider randomGenerator) {
138 this(DEFAULT_NAN_STRATEGY, TiesStrategy.RANDOM, randomGenerator);
139 }
140
141
142
143
144
145
146
147
148 public NaturalRanking(NaNStrategy nanStrategy,
149 UniformRandomProvider randomGenerator) {
150 this(nanStrategy, TiesStrategy.RANDOM, randomGenerator);
151 }
152
153
154
155
156
157
158 private NaturalRanking(NaNStrategy nanStrategy,
159 TiesStrategy tiesStrategy,
160 UniformRandomProvider random) {
161 this.nanStrategy = nanStrategy;
162 this.tiesStrategy = tiesStrategy;
163 this.random = random;
164 }
165
166
167
168
169
170
171 public NaNStrategy getNanStrategy() {
172 return nanStrategy;
173 }
174
175
176
177
178
179
180 public TiesStrategy getTiesStrategy() {
181 return tiesStrategy;
182 }
183
184
185
186
187
188
189
190
191
192
193
194 @Override
195 public double[] rank(double[] data) {
196
197
198 IntDoublePair[] ranks = new IntDoublePair[data.length];
199 for (int i = 0; i < data.length; i++) {
200 ranks[i] = new IntDoublePair(data[i], i);
201 }
202
203
204 List<Integer> nanPositions = null;
205 switch (nanStrategy) {
206 case MAXIMAL:
207 recodeNaNs(ranks, Double.POSITIVE_INFINITY);
208 break;
209 case MINIMAL:
210 recodeNaNs(ranks, Double.NEGATIVE_INFINITY);
211 break;
212 case REMOVED:
213 ranks = removeNaNs(ranks);
214 break;
215 case FIXED:
216 nanPositions = getNanPositions(ranks);
217 break;
218 case FAILED:
219 nanPositions = getNanPositions(ranks);
220 if (nanPositions.size() > 0) {
221 throw new NotANumberException();
222 }
223 break;
224 default:
225 throw new MathInternalError();
226 }
227
228
229 Arrays.sort(ranks);
230
231
232
233 double[] out = new double[ranks.length];
234 int pos = 1;
235 out[ranks[0].getPosition()] = pos;
236 List<Integer> tiesTrace = new ArrayList<>();
237 tiesTrace.add(ranks[0].getPosition());
238 for (int i = 1; i < ranks.length; i++) {
239 if (Double.compare(ranks[i].getValue(), ranks[i - 1].getValue()) > 0) {
240
241 pos = i + 1;
242 if (tiesTrace.size() > 1) {
243 resolveTie(out, tiesTrace);
244 }
245 tiesTrace = new ArrayList<>();
246 tiesTrace.add(ranks[i].getPosition());
247 } else {
248
249 tiesTrace.add(ranks[i].getPosition());
250 }
251 out[ranks[i].getPosition()] = pos;
252 }
253 if (tiesTrace.size() > 1) {
254 resolveTie(out, tiesTrace);
255 }
256 if (nanStrategy == NaNStrategy.FIXED) {
257 restoreNaNs(out, nanPositions);
258 }
259 return out;
260 }
261
262
263
264
265
266
267
268
269 private IntDoublePair[] removeNaNs(IntDoublePair[] ranks) {
270 if (!containsNaNs(ranks)) {
271 return ranks;
272 }
273 IntDoublePair[] outRanks = new IntDoublePair[ranks.length];
274 int j = 0;
275 for (int i = 0; i < ranks.length; i++) {
276 if (Double.isNaN(ranks[i].getValue())) {
277
278 for (int k = i + 1; k < ranks.length; k++) {
279 ranks[k] = new IntDoublePair(
280 ranks[k].getValue(), ranks[k].getPosition() - 1);
281 }
282 } else {
283 outRanks[j] = new IntDoublePair(
284 ranks[i].getValue(), ranks[i].getPosition());
285 j++;
286 }
287 }
288 IntDoublePair[] returnRanks = new IntDoublePair[j];
289 System.arraycopy(outRanks, 0, returnRanks, 0, j);
290 return returnRanks;
291 }
292
293
294
295
296
297
298
299 private void recodeNaNs(IntDoublePair[] ranks, double value) {
300 for (int i = 0; i < ranks.length; i++) {
301 if (Double.isNaN(ranks[i].getValue())) {
302 ranks[i] = new IntDoublePair(
303 value, ranks[i].getPosition());
304 }
305 }
306 }
307
308
309
310
311
312
313
314 private boolean containsNaNs(IntDoublePair[] ranks) {
315 for (int i = 0; i < ranks.length; i++) {
316 if (Double.isNaN(ranks[i].getValue())) {
317 return true;
318 }
319 }
320 return false;
321 }
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337 private void resolveTie(double[] ranks, List<Integer> tiesTrace) {
338
339
340 final double c = ranks[tiesTrace.get(0)];
341
342
343 final int length = tiesTrace.size();
344
345 switch (tiesStrategy) {
346 case AVERAGE:
347 fill(ranks, tiesTrace, (2 * c + length - 1) / 2d);
348 break;
349 case MAXIMUM:
350 fill(ranks, tiesTrace, c + length - 1);
351 break;
352 case MINIMUM:
353 fill(ranks, tiesTrace, c);
354 break;
355 case RANDOM:
356 Iterator<Integer> iterator = tiesTrace.iterator();
357 long f = JdkMath.round(c);
358 final UniformLongSampler sampler = UniformLongSampler.of(random, f, f + length - 1);
359 while (iterator.hasNext()) {
360
361 ranks[iterator.next()] = sampler.sample();
362 }
363 break;
364 case SEQUENTIAL:
365
366 iterator = tiesTrace.iterator();
367 f = JdkMath.round(c);
368 int i = 0;
369 while (iterator.hasNext()) {
370 ranks[iterator.next()] = f + i++;
371 }
372 break;
373 default:
374 throw new MathInternalError();
375 }
376 }
377
378
379
380
381
382
383
384
385 private void fill(double[] data, List<Integer> tiesTrace, double value) {
386 Iterator<Integer> iterator = tiesTrace.iterator();
387 while (iterator.hasNext()) {
388 data[iterator.next()] = value;
389 }
390 }
391
392
393
394
395
396
397
398 private void restoreNaNs(double[] ranks, List<Integer> nanPositions) {
399 if (nanPositions.isEmpty()) {
400 return;
401 }
402 Iterator<Integer> iterator = nanPositions.iterator();
403 while (iterator.hasNext()) {
404 ranks[iterator.next().intValue()] = Double.NaN;
405 }
406 }
407
408
409
410
411
412
413
414 private List<Integer> getNanPositions(IntDoublePair[] ranks) {
415 ArrayList<Integer> out = new ArrayList<>();
416 for (int i = 0; i < ranks.length; i++) {
417 if (Double.isNaN(ranks[i].getValue())) {
418 out.add(Integer.valueOf(i));
419 }
420 }
421 return out;
422 }
423
424
425
426
427
428
429
430 private static final class IntDoublePair implements Comparable<IntDoublePair> {
431
432
433 private final double value;
434
435
436 private final int position;
437
438
439
440
441
442
443 IntDoublePair(double value, int position) {
444 this.value = value;
445 this.position = position;
446 }
447
448
449
450
451
452
453
454
455 @Override
456 public int compareTo(IntDoublePair other) {
457 return Double.compare(value, other.value);
458 }
459
460
461
462
463
464
465
466 public double getValue() {
467 return value;
468 }
469
470
471
472
473
474 public int getPosition() {
475 return position;
476 }
477 }
478 }