View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.ranking;
18  
19  import java.util.Arrays;
20  import java.util.Objects;
21  import java.util.SplittableRandom;
22  import java.util.function.DoubleUnaryOperator;
23  import java.util.function.IntUnaryOperator;
24  
25  /**
26   * Ranking based on the natural ordering on floating-point values.
27   *
28   * <p>{@link Double#NaN NaNs} are treated according to the configured
29   * {@link NaNStrategy} and ties are handled using the selected
30   * {@link TiesStrategy}. Configuration settings are supplied in optional
31   * constructor arguments. Defaults are {@link NaNStrategy#FAILED} and
32   * {@link TiesStrategy#AVERAGE}, respectively.
33   *
34   * <p>When using {@link TiesStrategy#RANDOM}, a generator of random values in {@code [0, x)}
35   * can be supplied as a {@link IntUnaryOperator} argument; otherwise a default is created
36   * on-demand. The source of randomness can be supplied using a method reference.
37   * The following example creates a ranking with NaN values with the highest
38   * ranking and ties resolved randomly:
39   *
40   * <pre>
41   * NaturalRanking ranking = new NaturalRanking(NaNStrategy.MAXIMAL,
42   *                                             new SplittableRandom()::nextInt);
43   * </pre>
44   *
45   * <p>Note: Using {@link TiesStrategy#RANDOM} is not thread-safe due to the mutable
46   * generator of randomness. Instances not using random resolution of ties are
47   * thread-safe.
48   *
49   * <p>Examples:
50   *
51   * <table border="">
52   * <caption>Examples</caption>
53   * <tr><th colspan="3">
54   * Input data: [20, 17, 30, 42.3, 17, 50, Double.NaN, Double.NEGATIVE_INFINITY, 17]
55   * </th></tr>
56   * <tr><th>NaNStrategy</th><th>TiesStrategy</th>
57   * <th>{@code rank(data)}</th>
58   * <tr>
59   * <td>MAXIMAL</td>
60   * <td>default (ties averaged)</td>
61   * <td>[5, 3, 6, 7, 3, 8, 9, 1, 3]</td></tr>
62   * <tr>
63   * <td>MAXIMAL</td>
64   * <td>MINIMUM</td>
65   * <td>[5, 2, 6, 7, 2, 8, 9, 1, 2]</td></tr>
66   * <tr>
67   * <td>MINIMAL</td>
68   * <td>default (ties averaged]</td>
69   * <td>[6, 4, 7, 8, 4, 9, 1.5, 1.5, 4]</td></tr>
70   * <tr>
71   * <td>REMOVED</td>
72   * <td>SEQUENTIAL</td>
73   * <td>[5, 2, 6, 7, 3, 8, 1, 4]</td></tr>
74   * <tr>
75   * <td>MINIMAL</td>
76   * <td>MAXIMUM</td>
77   * <td>[6, 5, 7, 8, 5, 9, 2, 2, 5]</td></tr>
78   * <tr>
79   * <td>MINIMAL</td>
80   * <td>MAXIMUM</td>
81   * <td>[6, 5, 7, 8, 5, 9, 2, 2, 5]</td></tr>
82   * </table>
83   *
84   * @since 1.1
85   */
86  public class NaturalRanking implements RankingAlgorithm {
87      /** Message for a null user-supplied {@link NaNStrategy}. */
88      private static final String NULL_NAN_STRATEGY = "nanStrategy";
89      /** Message for a null user-supplied {@link TiesStrategy}. */
90      private static final String NULL_TIES_STRATEGY = "tiesStrategy";
91      /** Message for a null user-supplied source of randomness. */
92      private static final String NULL_RANDOM_SOURCE = "randomIntFunction";
93      /** Default NaN strategy. */
94      private static final NaNStrategy DEFAULT_NAN_STRATEGY = NaNStrategy.FAILED;
95      /** Default ties strategy. */
96      private static final TiesStrategy DEFAULT_TIES_STRATEGY = TiesStrategy.AVERAGE;
97      /** Map values to positive infinity. */
98      private static final DoubleUnaryOperator ACTION_POS_INF = x -> Double.POSITIVE_INFINITY;
99      /** Map values to negative infinity. */
100     private static final DoubleUnaryOperator ACTION_NEG_INF = x -> Double.NEGATIVE_INFINITY;
101     /** Raise an exception for values. */
102     private static final DoubleUnaryOperator ACTION_ERROR = operand -> {
103         throw new IllegalArgumentException("Invalid data: " + operand);
104     };
105 
106     /** NaN strategy. */
107     private final NaNStrategy nanStrategy;
108     /** Ties strategy. */
109     private final TiesStrategy tiesStrategy;
110     /** Source of randomness when ties strategy is RANDOM.
111      * Function maps positive x to {@code [0, x)}.
112      * Can be null to default to a JDK implementation. */
113     private IntUnaryOperator randomIntFunction;
114 
115     /**
116      * Creates an instance with {@link NaNStrategy#FAILED} and
117      * {@link TiesStrategy#AVERAGE}.
118      */
119     public NaturalRanking() {
120         this(DEFAULT_NAN_STRATEGY, DEFAULT_TIES_STRATEGY, null);
121     }
122 
123     /**
124      * Creates an instance with {@link NaNStrategy#FAILED} and the
125      * specified @{@code tiesStrategy}.
126      *
127      * <p>If the ties strategy is {@link TiesStrategy#RANDOM RANDOM} a default
128      * source of randomness is used to resolve ties.
129      *
130      * @param tiesStrategy TiesStrategy to use.
131      * @throws NullPointerException if the strategy is {@code null}
132      */
133     public NaturalRanking(TiesStrategy tiesStrategy) {
134         this(DEFAULT_NAN_STRATEGY,
135             Objects.requireNonNull(tiesStrategy, NULL_TIES_STRATEGY), null);
136     }
137 
138     /**
139      * Creates an instance with the specified @{@code nanStrategy} and
140      * {@link TiesStrategy#AVERAGE}.
141      *
142      * @param nanStrategy NaNStrategy to use.
143      * @throws NullPointerException if the strategy is {@code null}
144      */
145     public NaturalRanking(NaNStrategy nanStrategy) {
146         this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY),
147             DEFAULT_TIES_STRATEGY, null);
148     }
149 
150     /**
151      * Creates an instance with the specified @{@code nanStrategy} and the
152      * specified @{@code tiesStrategy}.
153      *
154      * <p>If the ties strategy is {@link TiesStrategy#RANDOM RANDOM} a default
155      * source of randomness is used to resolve ties.
156      *
157      * @param nanStrategy NaNStrategy to use.
158      * @param tiesStrategy TiesStrategy to use.
159      * @throws NullPointerException if any strategy is {@code null}
160      */
161     public NaturalRanking(NaNStrategy nanStrategy,
162                           TiesStrategy tiesStrategy) {
163         this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY),
164             Objects.requireNonNull(tiesStrategy, NULL_TIES_STRATEGY), null);
165     }
166 
167     /**
168      * Creates an instance with {@link NaNStrategy#FAILED},
169      * {@link TiesStrategy#RANDOM} and the given the source of random index data.
170      *
171      * @param randomIntFunction Source of random index data.
172      * Function maps positive {@code x} randomly to {@code [0, x)}
173      * @throws NullPointerException if the source of randomness is {@code null}
174      */
175     public NaturalRanking(IntUnaryOperator randomIntFunction) {
176         this(DEFAULT_NAN_STRATEGY, TiesStrategy.RANDOM,
177             Objects.requireNonNull(randomIntFunction, NULL_RANDOM_SOURCE));
178     }
179 
180     /**
181      * Creates an instance with the specified @{@code nanStrategy},
182      * {@link TiesStrategy#RANDOM} and the given the source of random index data.
183      *
184      * @param nanStrategy NaNStrategy to use.
185      * @param randomIntFunction Source of random index data.
186      * Function maps positive {@code x} randomly to {@code [0, x)}
187      * @throws NullPointerException if the strategy or source of randomness are {@code null}
188      */
189     public NaturalRanking(NaNStrategy nanStrategy,
190                           IntUnaryOperator randomIntFunction) {
191         this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY), TiesStrategy.RANDOM,
192             Objects.requireNonNull(randomIntFunction, NULL_RANDOM_SOURCE));
193     }
194 
195     /**
196      * @param nanStrategy NaNStrategy to use.
197      * @param tiesStrategy TiesStrategy to use.
198      * @param randomIntFunction Source of random index data.
199      */
200     private NaturalRanking(NaNStrategy nanStrategy,
201                            TiesStrategy tiesStrategy,
202                            IntUnaryOperator randomIntFunction) {
203         // User-supplied arguments are checked for non-null in the respective constructor
204         this.nanStrategy = nanStrategy;
205         this.tiesStrategy = tiesStrategy;
206         this.randomIntFunction = randomIntFunction;
207     }
208 
209     /**
210      * Return the {@link NaNStrategy}.
211      *
212      * @return the strategy for handling NaN
213      */
214     public NaNStrategy getNanStrategy() {
215         return nanStrategy;
216     }
217 
218     /**
219      * Return the {@link TiesStrategy}.
220      *
221      * @return the strategy for handling ties
222      */
223     public TiesStrategy getTiesStrategy() {
224         return tiesStrategy;
225     }
226 
227     /**
228      * Rank {@code data} using the natural ordering on floating-point values, with
229      * NaN values handled according to {@code nanStrategy} and ties resolved using
230      * {@code tiesStrategy}.
231      *
232      * @throws IllegalArgumentException if the selected {@link NaNStrategy} is
233      * {@code FAILED} and a {@link Double#NaN} is encountered in the input data.
234      */
235     @Override
236     public double[] apply(double[] data) {
237         // Convert data for sorting.
238         // NaNs are counted for the FIXED strategy.
239         final int[] nanCount = {0};
240         final DataPosition[] ranks = createRankData(data, nanCount);
241 
242         // Sorting will move NaNs to the end and we do not have to resolve ties in them.
243         final int nonNanSize = ranks.length - nanCount[0];
244 
245         // Edge case for empty data
246         if (nonNanSize == 0) {
247             // Either NaN are left in-place or removed
248             return nanStrategy == NaNStrategy.FIXED ? data : new double[0];
249         }
250 
251         Arrays.sort(ranks);
252 
253         // Walk the sorted array, filling output array using sorted positions,
254         // resolving ties as we go.
255         int pos = 1;
256         final double[] out = new double[ranks.length];
257 
258         DataPosition current = ranks[0];
259         out[current.getPosition()] = pos;
260 
261         // Store all previous elements of a tie.
262         // Note this lags behind the length of the tie sequence by 1.
263         // In the event there are no ties this is not used.
264         final IntList tiesTrace = new IntList(ranks.length);
265 
266         for (int i = 1; i < nonNanSize; i++) {
267             final DataPosition previous = current;
268             current = ranks[i];
269             if (current.compareTo(previous) > 0) {
270                 // Check for a previous tie sequence
271                 if (tiesTrace.size() != 0) {
272                     resolveTie(out, tiesTrace, previous.getPosition());
273                 }
274                 pos = i + 1;
275             } else {
276                 // Tie sequence. Add the matching previous element.
277                 tiesTrace.add(previous.getPosition());
278             }
279             out[current.getPosition()] = pos;
280         }
281         // Handle tie sequence at end
282         if (tiesTrace.size() != 0) {
283             resolveTie(out, tiesTrace, current.getPosition());
284         }
285         // For the FIXED strategy consume the remaining NaN elements
286         if (nanStrategy == NaNStrategy.FIXED) {
287             for (int i = nonNanSize; i < ranks.length; i++) {
288                 out[ranks[i].getPosition()] = Double.NaN;
289             }
290         }
291         return out;
292     }
293 
294     /**
295      * Creates the rank data. If using {@link NaNStrategy#REMOVED} then NaNs are
296      * filtered. Otherwise NaNs may be mapped to an infinite value, counted to allow
297      * subsequent processing, or cause an exception to be thrown.
298      *
299      * @param data Source data.
300      * @param nanCount Output counter for NaN values.
301      * @return the rank data
302      * @throws IllegalArgumentException if the data contains NaN values when using
303      * {@link NaNStrategy#FAILED}.
304      */
305     private DataPosition[] createRankData(double[] data, final int[] nanCount) {
306         return nanStrategy == NaNStrategy.REMOVED ?
307                 createNonNaNRankData(data) :
308                 createMappedRankData(data, createNaNAction(nanCount));
309     }
310 
311     /**
312      * Creates the NaN action.
313      *
314      * @param nanCount Output counter for NaN values.
315      * @return the operator applied to NaN values
316      */
317     private DoubleUnaryOperator createNaNAction(int[] nanCount) {
318         // Exhaustive switch statement
319         switch (nanStrategy) {
320         case MAXIMAL: // Replace NaNs with +INFs
321             return ACTION_POS_INF;
322         case MINIMAL: // Replace NaNs with -INFs
323             return ACTION_NEG_INF;
324         case REMOVED: // NaNs are removed
325         case FIXED:   // NaNs are unchanged
326             // Count the NaNs in the data that must be handled
327             return x -> {
328                 nanCount[0]++;
329                 return x;
330             };
331         case FAILED:
332             return ACTION_ERROR;
333         }
334         // Unreachable code
335         throw new IllegalStateException(String.valueOf(nanStrategy));
336     }
337 
338     /**
339      * Creates the rank data with NaNs removed.
340      *
341      * @param data Source data.
342      * @return the rank data
343      */
344     private static DataPosition[] createNonNaNRankData(double[] data) {
345         final DataPosition[] ranks = new DataPosition[data.length];
346         int size = 0;
347         for (final double v : data) {
348             if (!Double.isNaN(v)) {
349                 ranks[size] = new DataPosition(v, size);
350                 size++;
351             }
352         }
353         return size == data.length ? ranks : Arrays.copyOf(ranks, size);
354     }
355 
356     /**
357      * Creates the rank data.
358      *
359      * @param data Source data.
360      * @param nanAction Mapping operator applied to NaN values.
361      * @return the rank data
362      */
363     private static DataPosition[] createMappedRankData(double[] data, DoubleUnaryOperator nanAction) {
364         final DataPosition[] ranks = new DataPosition[data.length];
365         for (int i = 0; i < data.length; i++) {
366             double v = data[i];
367             if (Double.isNaN(v)) {
368                 v = nanAction.applyAsDouble(v);
369             }
370             ranks[i] = new DataPosition(v, i);
371         }
372         return ranks;
373     }
374 
375     /**
376      * Resolve a sequence of ties, using the configured {@link TiesStrategy}. The
377      * input {@code ranks} array is expected to take the same value for all indices
378      * in {@code tiesTrace}. The common value is recoded according to the
379      * tiesStrategy. For example, if ranks = [5,8,2,6,2,7,1,2], tiesTrace = [2,4,7]
380      * and tiesStrategy is MINIMUM, ranks will be unchanged. The same array and
381      * trace with tiesStrategy AVERAGE will come out [5,8,3,6,3,7,1,3].
382      *
383      * <p>Note: For convenience the final index of the trace is passed as an argument;
384      * it is assumed the list is already non-empty. At the end of the method the
385      * list of indices is cleared.
386      *
387      * @param ranks Array of ranks.
388      * @param tiesTrace List of indices where {@code ranks} is constant, that is,
389      * for any i and j in {@code tiesTrace}: {@code ranks[i] == ranks[j]}.
390      * @param finalIndex The final index to add to the sequence of ties.
391      */
392     private void resolveTie(double[] ranks, IntList tiesTrace, int finalIndex) {
393         tiesTrace.add(finalIndex);
394 
395         // Constant value of ranks over tiesTrace.
396         // Note: c is a rank counter starting from 1 so limited to an int.
397         final double c = ranks[tiesTrace.get(0)];
398 
399         // length of sequence of tied ranks
400         final int length = tiesTrace.size();
401 
402         // Exhaustive switch
403         switch (tiesStrategy) {
404         case  AVERAGE:   // Replace ranks with average: (lower + upper) / 2
405             fill(ranks, tiesTrace, (2 * c + length - 1) * 0.5);
406             break;
407         case MAXIMUM:    // Replace ranks with maximum values
408             fill(ranks, tiesTrace, c + length - 1);
409             break;
410         case MINIMUM:    // Replace ties with minimum
411             // Note that the tie sequence already has all values set to c so
412             // no requirement to fill again.
413             break;
414         case SEQUENTIAL: // Fill sequentially from c to c + length - 1
415         case RANDOM:     // Fill with randomized sequential values in [c, c + length - 1]
416             // This cast is safe as c is a counter.
417             int r = (int) c;
418             if (tiesStrategy == TiesStrategy.RANDOM) {
419                 tiesTrace.shuffle(getRandomIntFunction());
420             }
421             final int size = tiesTrace.size();
422             for (int i = 0; i < size; i++) {
423                 ranks[tiesTrace.get(i)] = r++;
424             }
425             break;
426         }
427 
428         tiesTrace.clear();
429     }
430 
431     /**
432      * Sets {@code data[i] = value} for each i in {@code tiesTrace}.
433      *
434      * @param data Array to modify.
435      * @param tiesTrace List of index values to set.
436      * @param value Value to set.
437      */
438     private static void fill(double[] data, IntList tiesTrace, double value) {
439         final int size = tiesTrace.size();
440         for (int i = 0; i < size; i++) {
441             data[tiesTrace.get(i)] = value;
442         }
443     }
444 
445     /**
446      * Gets the function to map positive {@code x} randomly to {@code [0, x)}.
447      * Defaults to a system provided generator if the constructor source of randomness is null.
448      *
449      * @return the RNG
450      */
451     private IntUnaryOperator getRandomIntFunction() {
452         IntUnaryOperator r = randomIntFunction;
453         if (r == null) {
454             // Default to a SplittableRandom
455             randomIntFunction = r = new SplittableRandom()::nextInt;
456         }
457         return r;
458     }
459 
460     /**
461      * An expandable list of int values. This allows tracking array positions
462      * without using boxed values in a {@code List<Integer>}.
463      */
464     private static class IntList {
465         /** The maximum size of array to allocate. */
466         private final int max;
467 
468         /** The size of the list. */
469         private int size;
470         /** The list data. Initialised with space to store a tie of 2 values. */
471         private int[] data = new int[2];
472 
473         /**
474          * @param max Maximum size of array to allocate. Can use the length of the parent array
475          * for which this is used to track indices.
476          */
477         IntList(int max) {
478             this.max = max;
479         }
480 
481         /**
482          * Adds the value to the list.
483          *
484          * @param value the value
485          */
486         void add(int value) {
487             if (size == data.length) {
488                 // Overflow safe doubling of the current size.
489                 data = Arrays.copyOf(data, (int) Math.min(max, size * 2L));
490             }
491             data[size++] = value;
492         }
493 
494         /**
495          * Gets the element at the specified {@code index}.
496          *
497          * @param index Element index
498          * @return the element
499          */
500         int get(int index) {
501             return data[index];
502         }
503 
504         /**
505          * Gets the number of elements in the list.
506          *
507          * @return the size
508          */
509         int size() {
510             return size;
511         }
512 
513         /**
514          * Clear the list.
515          */
516         void clear() {
517             size = 0;
518         }
519 
520         /**
521          * Shuffle the list.
522          *
523          * @param randomIntFunction Function maps positive {@code x} randomly to {@code [0, x)}.
524          */
525         void shuffle(IntUnaryOperator randomIntFunction) {
526             // Fisher-Yates shuffle
527             final int[] array = data;
528             for (int i = size; i > 1; i--) {
529                 swap(array, i - 1, randomIntFunction.applyAsInt(i));
530             }
531         }
532 
533         /**
534          * Swaps the two specified elements in the specified array.
535          *
536          * @param array Data array
537          * @param i     First index
538          * @param j     Second index
539          */
540         private static void swap(int[] array, int i, int j) {
541             final int tmp = array[i];
542             array[i] = array[j];
543             array[j] = tmp;
544         }
545     }
546 
547     /**
548      * Represents the position of a {@code double} value in a data array. The
549      * Comparable interface is implemented so Arrays.sort can be used to sort an
550      * array of data positions by value. Note that the implicitly defined natural
551      * ordering is NOT consistent with equals.
552      */
553     private static class DataPosition implements Comparable<DataPosition>  {
554         /** Data value. */
555         private final double value;
556         /** Data position. */
557         private final int position;
558 
559         /**
560          * Create an instance with the given value and position.
561          *
562          * @param value Data value.
563          * @param position Data position.
564          */
565         DataPosition(double value, int position) {
566             this.value = value;
567             this.position = position;
568         }
569 
570         /**
571          * Compare this value to another.
572          * Only the <strong>values</strong> are compared.
573          *
574          * @param other the other pair to compare this to
575          * @return result of {@code Double.compare(value, other.value)}
576          */
577         @Override
578         public int compareTo(DataPosition other) {
579             return Double.compare(value, other.value);
580         }
581 
582         // equals() and hashCode() are not implemented; see MATH-610 for discussion.
583 
584         /**
585          * Returns the data position.
586          *
587          * @return position
588          */
589         int getPosition() {
590             return position;
591         }
592     }
593 }