001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.statistics.ranking;
018
019import java.util.Arrays;
020import java.util.Objects;
021import java.util.SplittableRandom;
022import java.util.function.DoubleUnaryOperator;
023import java.util.function.IntUnaryOperator;
024
025/**
026 * Ranking based on the natural ordering on floating-point values.
027 *
028 * <p>{@link Double#NaN NaNs} are treated according to the configured
029 * {@link NaNStrategy} and ties are handled using the selected
030 * {@link TiesStrategy}. Configuration settings are supplied in optional
031 * constructor arguments. Defaults are {@link NaNStrategy#FAILED} and
032 * {@link TiesStrategy#AVERAGE}, respectively.
033 *
034 * <p>When using {@link TiesStrategy#RANDOM}, a generator of random values in {@code [0, x)}
035 * can be supplied as a {@link IntUnaryOperator} argument; otherwise a default is created
036 * on-demand. The source of randomness can be supplied using a method reference.
037 * The following example creates a ranking with NaN values with the highest
038 * ranking and ties resolved randomly:
039 *
040 * <pre>
041 * NaturalRanking ranking = new NaturalRanking(NaNStrategy.MAXIMAL,
042 *                                             new SplittableRandom()::nextInt);
043 * </pre>
044 *
045 * <p>Note: Using {@link TiesStrategy#RANDOM} is not thread-safe due to the mutable
046 * generator of randomness. Instances not using random resolution of ties are
047 * thread-safe.
048 *
049 * <p>Examples:
050 *
051 * <table border="">
052 * <caption>Examples</caption>
053 * <tr><th colspan="3">
054 * Input data: [20, 17, 30, 42.3, 17, 50, Double.NaN, Double.NEGATIVE_INFINITY, 17]
055 * </th></tr>
056 * <tr><th>NaNStrategy</th><th>TiesStrategy</th>
057 * <th>{@code rank(data)}</th>
058 * <tr>
059 * <td>MAXIMAL</td>
060 * <td>default (ties averaged)</td>
061 * <td>[5, 3, 6, 7, 3, 8, 9, 1, 3]</td></tr>
062 * <tr>
063 * <td>MAXIMAL</td>
064 * <td>MINIMUM</td>
065 * <td>[5, 2, 6, 7, 2, 8, 9, 1, 2]</td></tr>
066 * <tr>
067 * <td>MINIMAL</td>
068 * <td>default (ties averaged]</td>
069 * <td>[6, 4, 7, 8, 4, 9, 1.5, 1.5, 4]</td></tr>
070 * <tr>
071 * <td>REMOVED</td>
072 * <td>SEQUENTIAL</td>
073 * <td>[5, 2, 6, 7, 3, 8, 1, 4]</td></tr>
074 * <tr>
075 * <td>MINIMAL</td>
076 * <td>MAXIMUM</td>
077 * <td>[6, 5, 7, 8, 5, 9, 2, 2, 5]</td></tr>
078 * <tr>
079 * <td>MINIMAL</td>
080 * <td>MAXIMUM</td>
081 * <td>[6, 5, 7, 8, 5, 9, 2, 2, 5]</td></tr>
082 * </table>
083 *
084 * @since 1.1
085 */
086public class NaturalRanking implements RankingAlgorithm {
087    /** Message for a null user-supplied {@link NaNStrategy}. */
088    private static final String NULL_NAN_STRATEGY = "nanStrategy";
089    /** Message for a null user-supplied {@link TiesStrategy}. */
090    private static final String NULL_TIES_STRATEGY = "tiesStrategy";
091    /** Message for a null user-supplied source of randomness. */
092    private static final String NULL_RANDOM_SOURCE = "randomIntFunction";
093    /** Default NaN strategy. */
094    private static final NaNStrategy DEFAULT_NAN_STRATEGY = NaNStrategy.FAILED;
095    /** Default ties strategy. */
096    private static final TiesStrategy DEFAULT_TIES_STRATEGY = TiesStrategy.AVERAGE;
097    /** Map values to positive infinity. */
098    private static final DoubleUnaryOperator ACTION_POS_INF = x -> Double.POSITIVE_INFINITY;
099    /** Map values to negative infinity. */
100    private static final DoubleUnaryOperator ACTION_NEG_INF = x -> Double.NEGATIVE_INFINITY;
101    /** Raise an exception for values. */
102    private static final DoubleUnaryOperator ACTION_ERROR = operand -> {
103        throw new IllegalArgumentException("Invalid data: " + operand);
104    };
105
106    /** NaN strategy. */
107    private final NaNStrategy nanStrategy;
108    /** Ties strategy. */
109    private final TiesStrategy tiesStrategy;
110    /** Source of randomness when ties strategy is RANDOM.
111     * Function maps positive x to {@code [0, x)}.
112     * Can be null to default to a JDK implementation. */
113    private IntUnaryOperator randomIntFunction;
114
115    /**
116     * Creates an instance with {@link NaNStrategy#FAILED} and
117     * {@link TiesStrategy#AVERAGE}.
118     */
119    public NaturalRanking() {
120        this(DEFAULT_NAN_STRATEGY, DEFAULT_TIES_STRATEGY, null);
121    }
122
123    /**
124     * Creates an instance with {@link NaNStrategy#FAILED} and the
125     * specified @{@code tiesStrategy}.
126     *
127     * <p>If the ties strategy is {@link TiesStrategy#RANDOM RANDOM} a default
128     * source of randomness is used to resolve ties.
129     *
130     * @param tiesStrategy TiesStrategy to use.
131     * @throws NullPointerException if the strategy is {@code null}
132     */
133    public NaturalRanking(TiesStrategy tiesStrategy) {
134        this(DEFAULT_NAN_STRATEGY,
135            Objects.requireNonNull(tiesStrategy, NULL_TIES_STRATEGY), null);
136    }
137
138    /**
139     * Creates an instance with the specified @{@code nanStrategy} and
140     * {@link TiesStrategy#AVERAGE}.
141     *
142     * @param nanStrategy NaNStrategy to use.
143     * @throws NullPointerException if the strategy is {@code null}
144     */
145    public NaturalRanking(NaNStrategy nanStrategy) {
146        this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY),
147            DEFAULT_TIES_STRATEGY, null);
148    }
149
150    /**
151     * Creates an instance with the specified @{@code nanStrategy} and the
152     * specified @{@code tiesStrategy}.
153     *
154     * <p>If the ties strategy is {@link TiesStrategy#RANDOM RANDOM} a default
155     * source of randomness is used to resolve ties.
156     *
157     * @param nanStrategy NaNStrategy to use.
158     * @param tiesStrategy TiesStrategy to use.
159     * @throws NullPointerException if any strategy is {@code null}
160     */
161    public NaturalRanking(NaNStrategy nanStrategy,
162                          TiesStrategy tiesStrategy) {
163        this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY),
164            Objects.requireNonNull(tiesStrategy, NULL_TIES_STRATEGY), null);
165    }
166
167    /**
168     * Creates an instance with {@link NaNStrategy#FAILED},
169     * {@link TiesStrategy#RANDOM} and the given the source of random index data.
170     *
171     * @param randomIntFunction Source of random index data.
172     * Function maps positive {@code x} randomly to {@code [0, x)}
173     * @throws NullPointerException if the source of randomness is {@code null}
174     */
175    public NaturalRanking(IntUnaryOperator randomIntFunction) {
176        this(DEFAULT_NAN_STRATEGY, TiesStrategy.RANDOM,
177            Objects.requireNonNull(randomIntFunction, NULL_RANDOM_SOURCE));
178    }
179
180    /**
181     * Creates an instance with the specified @{@code nanStrategy},
182     * {@link TiesStrategy#RANDOM} and the given the source of random index data.
183     *
184     * @param nanStrategy NaNStrategy to use.
185     * @param randomIntFunction Source of random index data.
186     * Function maps positive {@code x} randomly to {@code [0, x)}
187     * @throws NullPointerException if the strategy or source of randomness are {@code null}
188     */
189    public NaturalRanking(NaNStrategy nanStrategy,
190                          IntUnaryOperator randomIntFunction) {
191        this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY), TiesStrategy.RANDOM,
192            Objects.requireNonNull(randomIntFunction, NULL_RANDOM_SOURCE));
193    }
194
195    /**
196     * @param nanStrategy NaNStrategy to use.
197     * @param tiesStrategy TiesStrategy to use.
198     * @param randomIntFunction Source of random index data.
199     */
200    private NaturalRanking(NaNStrategy nanStrategy,
201                           TiesStrategy tiesStrategy,
202                           IntUnaryOperator randomIntFunction) {
203        // User-supplied arguments are checked for non-null in the respective constructor
204        this.nanStrategy = nanStrategy;
205        this.tiesStrategy = tiesStrategy;
206        this.randomIntFunction = randomIntFunction;
207    }
208
209    /**
210     * Return the {@link NaNStrategy}.
211     *
212     * @return the strategy for handling NaN
213     */
214    public NaNStrategy getNanStrategy() {
215        return nanStrategy;
216    }
217
218    /**
219     * Return the {@link TiesStrategy}.
220     *
221     * @return the strategy for handling ties
222     */
223    public TiesStrategy getTiesStrategy() {
224        return tiesStrategy;
225    }
226
227    /**
228     * Rank {@code data} using the natural ordering on floating-point values, with
229     * NaN values handled according to {@code nanStrategy} and ties resolved using
230     * {@code tiesStrategy}.
231     *
232     * @throws IllegalArgumentException if the selected {@link NaNStrategy} is
233     * {@code FAILED} and a {@link Double#NaN} is encountered in the input data.
234     */
235    @Override
236    public double[] apply(double[] data) {
237        // Convert data for sorting.
238        // NaNs are counted for the FIXED strategy.
239        final int[] nanCount = {0};
240        final DataPosition[] ranks = createRankData(data, nanCount);
241
242        // Sorting will move NaNs to the end and we do not have to resolve ties in them.
243        final int nonNanSize = ranks.length - nanCount[0];
244
245        // Edge case for empty data
246        if (nonNanSize == 0) {
247            // Either NaN are left in-place or removed
248            return nanStrategy == NaNStrategy.FIXED ? data : new double[0];
249        }
250
251        Arrays.sort(ranks);
252
253        // Walk the sorted array, filling output array using sorted positions,
254        // resolving ties as we go.
255        int pos = 1;
256        final double[] out = new double[ranks.length];
257
258        DataPosition current = ranks[0];
259        out[current.getPosition()] = pos;
260
261        // Store all previous elements of a tie.
262        // Note this lags behind the length of the tie sequence by 1.
263        // In the event there are no ties this is not used.
264        final IntList tiesTrace = new IntList(ranks.length);
265
266        for (int i = 1; i < nonNanSize; i++) {
267            final DataPosition previous = current;
268            current = ranks[i];
269            if (current.compareTo(previous) > 0) {
270                // Check for a previous tie sequence
271                if (tiesTrace.size() != 0) {
272                    resolveTie(out, tiesTrace, previous.getPosition());
273                }
274                pos = i + 1;
275            } else {
276                // Tie sequence. Add the matching previous element.
277                tiesTrace.add(previous.getPosition());
278            }
279            out[current.getPosition()] = pos;
280        }
281        // Handle tie sequence at end
282        if (tiesTrace.size() != 0) {
283            resolveTie(out, tiesTrace, current.getPosition());
284        }
285        // For the FIXED strategy consume the remaining NaN elements
286        if (nanStrategy == NaNStrategy.FIXED) {
287            for (int i = nonNanSize; i < ranks.length; i++) {
288                out[ranks[i].getPosition()] = Double.NaN;
289            }
290        }
291        return out;
292    }
293
294    /**
295     * Creates the rank data. If using {@link NaNStrategy#REMOVED} then NaNs are
296     * filtered. Otherwise NaNs may be mapped to an infinite value, counted to allow
297     * subsequent processing, or cause an exception to be thrown.
298     *
299     * @param data Source data.
300     * @param nanCount Output counter for NaN values.
301     * @return the rank data
302     * @throws IllegalArgumentException if the data contains NaN values when using
303     * {@link NaNStrategy#FAILED}.
304     */
305    private DataPosition[] createRankData(double[] data, final int[] nanCount) {
306        return nanStrategy == NaNStrategy.REMOVED ?
307                createNonNaNRankData(data) :
308                createMappedRankData(data, createNaNAction(nanCount));
309    }
310
311    /**
312     * Creates the NaN action.
313     *
314     * @param nanCount Output counter for NaN values.
315     * @return the operator applied to NaN values
316     */
317    private DoubleUnaryOperator createNaNAction(int[] nanCount) {
318        // Exhaustive switch statement
319        switch (nanStrategy) {
320        case MAXIMAL: // Replace NaNs with +INFs
321            return ACTION_POS_INF;
322        case MINIMAL: // Replace NaNs with -INFs
323            return ACTION_NEG_INF;
324        case REMOVED: // NaNs are removed
325        case FIXED:   // NaNs are unchanged
326            // Count the NaNs in the data that must be handled
327            return x -> {
328                nanCount[0]++;
329                return x;
330            };
331        case FAILED:
332            return ACTION_ERROR;
333        }
334        // Unreachable code
335        throw new IllegalStateException(String.valueOf(nanStrategy));
336    }
337
338    /**
339     * Creates the rank data with NaNs removed.
340     *
341     * @param data Source data.
342     * @return the rank data
343     */
344    private static DataPosition[] createNonNaNRankData(double[] data) {
345        final DataPosition[] ranks = new DataPosition[data.length];
346        int size = 0;
347        for (final double v : data) {
348            if (!Double.isNaN(v)) {
349                ranks[size] = new DataPosition(v, size);
350                size++;
351            }
352        }
353        return size == data.length ? ranks : Arrays.copyOf(ranks, size);
354    }
355
356    /**
357     * Creates the rank data.
358     *
359     * @param data Source data.
360     * @param nanAction Mapping operator applied to NaN values.
361     * @return the rank data
362     */
363    private static DataPosition[] createMappedRankData(double[] data, DoubleUnaryOperator nanAction) {
364        final DataPosition[] ranks = new DataPosition[data.length];
365        for (int i = 0; i < data.length; i++) {
366            double v = data[i];
367            if (Double.isNaN(v)) {
368                v = nanAction.applyAsDouble(v);
369            }
370            ranks[i] = new DataPosition(v, i);
371        }
372        return ranks;
373    }
374
375    /**
376     * Resolve a sequence of ties, using the configured {@link TiesStrategy}. The
377     * input {@code ranks} array is expected to take the same value for all indices
378     * in {@code tiesTrace}. The common value is recoded according to the
379     * tiesStrategy. For example, if ranks = [5,8,2,6,2,7,1,2], tiesTrace = [2,4,7]
380     * and tiesStrategy is MINIMUM, ranks will be unchanged. The same array and
381     * trace with tiesStrategy AVERAGE will come out [5,8,3,6,3,7,1,3].
382     *
383     * <p>Note: For convenience the final index of the trace is passed as an argument;
384     * it is assumed the list is already non-empty. At the end of the method the
385     * list of indices is cleared.
386     *
387     * @param ranks Array of ranks.
388     * @param tiesTrace List of indices where {@code ranks} is constant, that is,
389     * for any i and j in {@code tiesTrace}: {@code ranks[i] == ranks[j]}.
390     * @param finalIndex The final index to add to the sequence of ties.
391     */
392    private void resolveTie(double[] ranks, IntList tiesTrace, int finalIndex) {
393        tiesTrace.add(finalIndex);
394
395        // Constant value of ranks over tiesTrace.
396        // Note: c is a rank counter starting from 1 so limited to an int.
397        final double c = ranks[tiesTrace.get(0)];
398
399        // length of sequence of tied ranks
400        final int length = tiesTrace.size();
401
402        // Exhaustive switch
403        switch (tiesStrategy) {
404        case  AVERAGE:   // Replace ranks with average: (lower + upper) / 2
405            fill(ranks, tiesTrace, (2 * c + length - 1) * 0.5);
406            break;
407        case MAXIMUM:    // Replace ranks with maximum values
408            fill(ranks, tiesTrace, c + length - 1);
409            break;
410        case MINIMUM:    // Replace ties with minimum
411            // Note that the tie sequence already has all values set to c so
412            // no requirement to fill again.
413            break;
414        case SEQUENTIAL: // Fill sequentially from c to c + length - 1
415        case RANDOM:     // Fill with randomized sequential values in [c, c + length - 1]
416            // This cast is safe as c is a counter.
417            int r = (int) c;
418            if (tiesStrategy == TiesStrategy.RANDOM) {
419                tiesTrace.shuffle(getRandomIntFunction());
420            }
421            final int size = tiesTrace.size();
422            for (int i = 0; i < size; i++) {
423                ranks[tiesTrace.get(i)] = r++;
424            }
425            break;
426        }
427
428        tiesTrace.clear();
429    }
430
431    /**
432     * Sets {@code data[i] = value} for each i in {@code tiesTrace}.
433     *
434     * @param data Array to modify.
435     * @param tiesTrace List of index values to set.
436     * @param value Value to set.
437     */
438    private static void fill(double[] data, IntList tiesTrace, double value) {
439        final int size = tiesTrace.size();
440        for (int i = 0; i < size; i++) {
441            data[tiesTrace.get(i)] = value;
442        }
443    }
444
445    /**
446     * Gets the function to map positive {@code x} randomly to {@code [0, x)}.
447     * Defaults to a system provided generator if the constructor source of randomness is null.
448     *
449     * @return the RNG
450     */
451    private IntUnaryOperator getRandomIntFunction() {
452        IntUnaryOperator r = randomIntFunction;
453        if (r == null) {
454            // Default to a SplittableRandom
455            randomIntFunction = r = new SplittableRandom()::nextInt;
456        }
457        return r;
458    }
459
460    /**
461     * An expandable list of int values. This allows tracking array positions
462     * without using boxed values in a {@code List<Integer>}.
463     */
464    private static class IntList {
465        /** The maximum size of array to allocate. */
466        private final int max;
467
468        /** The size of the list. */
469        private int size;
470        /** The list data. Initialised with space to store a tie of 2 values. */
471        private int[] data = new int[2];
472
473        /**
474         * @param max Maximum size of array to allocate. Can use the length of the parent array
475         * for which this is used to track indices.
476         */
477        IntList(int max) {
478            this.max = max;
479        }
480
481        /**
482         * Adds the value to the list.
483         *
484         * @param value the value
485         */
486        void add(int value) {
487            if (size == data.length) {
488                // Overflow safe doubling of the current size.
489                data = Arrays.copyOf(data, (int) Math.min(max, size * 2L));
490            }
491            data[size++] = value;
492        }
493
494        /**
495         * Gets the element at the specified {@code index}.
496         *
497         * @param index Element index
498         * @return the element
499         */
500        int get(int index) {
501            return data[index];
502        }
503
504        /**
505         * Gets the number of elements in the list.
506         *
507         * @return the size
508         */
509        int size() {
510            return size;
511        }
512
513        /**
514         * Clear the list.
515         */
516        void clear() {
517            size = 0;
518        }
519
520        /**
521         * Shuffle the list.
522         *
523         * @param randomIntFunction Function maps positive {@code x} randomly to {@code [0, x)}.
524         */
525        void shuffle(IntUnaryOperator randomIntFunction) {
526            // Fisher-Yates shuffle
527            final int[] array = data;
528            for (int i = size; i > 1; i--) {
529                swap(array, i - 1, randomIntFunction.applyAsInt(i));
530            }
531        }
532
533        /**
534         * Swaps the two specified elements in the specified array.
535         *
536         * @param array Data array
537         * @param i     First index
538         * @param j     Second index
539         */
540        private static void swap(int[] array, int i, int j) {
541            final int tmp = array[i];
542            array[i] = array[j];
543            array[j] = tmp;
544        }
545    }
546
547    /**
548     * Represents the position of a {@code double} value in a data array. The
549     * Comparable interface is implemented so Arrays.sort can be used to sort an
550     * array of data positions by value. Note that the implicitly defined natural
551     * ordering is NOT consistent with equals.
552     */
553    private static class DataPosition implements Comparable<DataPosition>  {
554        /** Data value. */
555        private final double value;
556        /** Data position. */
557        private final int position;
558
559        /**
560         * Create an instance with the given value and position.
561         *
562         * @param value Data value.
563         * @param position Data position.
564         */
565        DataPosition(double value, int position) {
566            this.value = value;
567            this.position = position;
568        }
569
570        /**
571         * Compare this value to another.
572         * Only the <strong>values</strong> are compared.
573         *
574         * @param other the other pair to compare this to
575         * @return result of {@code Double.compare(value, other.value)}
576         */
577        @Override
578        public int compareTo(DataPosition other) {
579            return Double.compare(value, other.value);
580        }
581
582        // equals() and hashCode() are not implemented; see MATH-610 for discussion.
583
584        /**
585         * Returns the data position.
586         *
587         * @return position
588         */
589        int getPosition() {
590            return position;
591        }
592    }
593}