View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.numbers.arrays;
18  
19  import java.util.Arrays;
20  
21  /**
22   * Support class for sorting arrays.
23   *
24   * <p>Optimal sorting networks are used for small fixed size array sorting.
25   *
26   * <p>Note: Requires that the floating-point data contains no NaN values; sorting
27   * does not respect the order of signed zeros imposed by {@link Double#compare(double, double)}.
28   *
29   * @see <a href="https://en.wikipedia.org/wiki/Sorting_network">Sorting network (Wikipedia)</a>
30   * @see <a href="https://bertdobbelaere.github.io/sorting_networks.html">Sorting Networks (Bert Dobbelaere)</a>
31   *
32   * @since 1.2
33   */
34  final class Sorting {
35  
36      /** No instances. */
37      private Sorting() {}
38  
39      /**
40       * Sorts an array using an insertion sort.
41       *
42       * @param x Data array.
43       * @param left Lower bound (inclusive).
44       * @param right Upper bound (inclusive).
45       */
46      static void sort(double[] x, int left, int right) {
47          for (int i = left; ++i <= right;) {
48              final double v = x[i];
49              // Move preceding higher elements above (if required)
50              if (v < x[i - 1]) {
51                  int j = i;
52                  while (--j >= left && v < x[j]) {
53                      x[j + 1] = x[j];
54                  }
55                  x[j + 1] = v;
56              }
57          }
58      }
59  
60      /**
61       * Sorts the elements at the given distinct indices in an array.
62       *
63       * @param x Data array.
64       * @param a Index.
65       * @param b Index.
66       * @param c Index.
67       */
68      static void sort3(double[] x, int a, int b, int c) {
69          // Decision tree avoiding swaps:
70          // Order [(0,2)]
71          // Move point 1 above point 2 or below point 0
72          final double u = x[a];
73          final double v = x[b];
74          final double w = x[c];
75          if (w < u) {
76              if (v < w) {
77                  x[a] = v;
78                  x[b] = w;
79                  x[c] = u;
80                  return;
81              }
82              if (u < v) {
83                  x[a] = w;
84                  x[b] = u;
85                  x[c] = v;
86                  return;
87              }
88              // w < v < u
89              x[a] = w;
90              x[c] = u;
91              return;
92          }
93          if (v < u) {
94              // v < u < w
95              x[a] = v;
96              x[b] = u;
97              return;
98          }
99          if (w < v) {
100             // u < w < v
101             x[b] = w;
102             x[c] = v;
103         }
104         // u < v < w
105     }
106 
107     /**
108      * Sorts the elements at the given distinct indices in an array.
109      *
110      * @param x Data array.
111      * @param a Index.
112      * @param b Index.
113      * @param c Index.
114      * @param d Index.
115      * @param e Index.
116      */
117     static void sort5(double[] x, int a, int b, int c, int d, int e) {
118         // Uses an optimal sorting network from Knuth's Art of Computer Programming.
119         // 9 comparisons.
120         // Order pairs:
121         // [(0,3),(1,4)]
122         // [(0,2),(1,3)]
123         // [(0,1),(2,4)]
124         // [(1,2),(3,4)]
125         // [(2,3)]
126         if (x[e] < x[b]) {
127             final double u = x[e];
128             x[e] = x[b];
129             x[b] = u;
130         }
131         if (x[d] < x[a]) {
132             final double v = x[d];
133             x[d] = x[a];
134             x[a] = v;
135         }
136 
137         if (x[d] < x[b]) {
138             final double u = x[d];
139             x[d] = x[b];
140             x[b] = u;
141         }
142         if (x[c] < x[a]) {
143             final double v = x[c];
144             x[c] = x[a];
145             x[a] = v;
146         }
147 
148         if (x[e] < x[c]) {
149             final double u = x[e];
150             x[e] = x[c];
151             x[c] = u;
152         }
153         if (x[b] < x[a]) {
154             final double v = x[b];
155             x[b] = x[a];
156             x[a] = v;
157         }
158 
159         if (x[e] < x[d]) {
160             final double u = x[e];
161             x[e] = x[d];
162             x[d] = u;
163         }
164         if (x[c] < x[b]) {
165             final double v = x[c];
166             x[c] = x[b];
167             x[b] = v;
168         }
169 
170         if (x[d] < x[c]) {
171             final double u = x[d];
172             x[d] = x[c];
173             x[c] = u;
174         }
175     }
176 
177     /**
178      * Place the lower median of 4 elements in {@code b}; the smaller element in
179      * {@code a}; and the larger two elements in {@code c, d}.
180      *
181      * @param x Values
182      * @param a Index.
183      * @param b Index.
184      * @param c Index.
185      * @param d Index.
186      */
187     static void lowerMedian4(double[] x, int a, int b, int c, int d) {
188         // 3 to 5 comparisons
189         if (x[d] < x[b]) {
190             final double u = x[d];
191             x[d] = x[b];
192             x[b] = u;
193         }
194         if (x[c] < x[a]) {
195             final double v = x[c];
196             x[c] = x[a];
197             x[a] = v;
198         }
199         // a--c
200         // b--d
201         if (x[c] < x[b]) {
202             final double u = x[c];
203             x[c] = x[b];
204             x[b] = u;
205         } else if (x[b] < x[a]) {
206             //    a--c
207             // b--d
208             final double xb = x[a];
209             x[a] = x[b];
210             x[b] = xb;
211             //    b--c
212             // a--d
213             if (x[d] < xb) {
214                 x[b] = x[d];
215                 // Move a pair to maintain the sorted order
216                 x[d] = x[c];
217                 x[c] = xb;
218             }
219         }
220     }
221 
222     /**
223      * Place the upper median of 4 elements in {@code c}; the smaller two elements in
224      * {@code a,b}; and the larger element in {@code d}.
225      *
226      * @param x Values
227      * @param a Index.
228      * @param b Index.
229      * @param c Index.
230      * @param d Index.
231      */
232     static void upperMedian4(double[] x, int a, int b, int c, int d) {
233         // 3 to 5 comparisons
234         if (x[d] < x[b]) {
235             final double u = x[d];
236             x[d] = x[b];
237             x[b] = u;
238         }
239         if (x[c] < x[a]) {
240             final double v = x[c];
241             x[c] = x[a];
242             x[a] = v;
243         }
244         // a--c
245         // b--d
246         if (x[b] > x[c]) {
247             final double u = x[c];
248             x[c] = x[b];
249             x[b] = u;
250         } else if (x[c] > x[d]) {
251             //    a--c
252             // b--d
253             final double xc = x[d];
254             x[d] = x[c];
255             x[c] = xc;
256             //    a--d
257             // b--c
258             if (x[a] > xc) {
259                 x[c] = x[a];
260                 // Move a pair to maintain the sorted order
261                 x[a] = x[b];
262                 x[b] = xc;
263             }
264         }
265     }
266 
267     /**
268      * Sorts an array using an insertion sort.
269      *
270      * @param x Data array.
271      * @param left Lower bound (inclusive).
272      * @param right Upper bound (inclusive).
273      */
274     static void sort(int[] x, int left, int right) {
275         for (int i = left; ++i <= right;) {
276             final int v = x[i];
277             // Move preceding higher elements above (if required)
278             if (v < x[i - 1]) {
279                 int j = i;
280                 while (--j >= left && v < x[j]) {
281                     x[j + 1] = x[j];
282                 }
283                 x[j + 1] = v;
284             }
285         }
286     }
287 
288     /**
289      * Sorts the elements at the given distinct indices in an array.
290      *
291      * @param x Data array.
292      * @param a Index.
293      * @param b Index.
294      * @param c Index.
295      */
296     static void sort3(int[] x, int a, int b, int c) {
297         // Decision tree avoiding swaps:
298         // Order [(0,2)]
299         // Move point 1 above point 2 or below point 0
300         final int u = x[a];
301         final int v = x[b];
302         final int w = x[c];
303         if (w < u) {
304             if (v < w) {
305                 x[a] = v;
306                 x[b] = w;
307                 x[c] = u;
308                 return;
309             }
310             if (u < v) {
311                 x[a] = w;
312                 x[b] = u;
313                 x[c] = v;
314                 return;
315             }
316             // w < v < u
317             x[a] = w;
318             x[c] = u;
319             return;
320         }
321         if (v < u) {
322             // v < u < w
323             x[a] = v;
324             x[b] = u;
325             return;
326         }
327         if (w < v) {
328             // u < w < v
329             x[b] = w;
330             x[c] = v;
331         }
332         // u < v < w
333     }
334 
335     /**
336      * Sorts the elements at the given distinct indices in an array.
337      *
338      * @param x Data array.
339      * @param a Index.
340      * @param b Index.
341      * @param c Index.
342      * @param d Index.
343      * @param e Index.
344      */
345     static void sort5(int[] x, int a, int b, int c, int d, int e) {
346         // Uses an optimal sorting network from Knuth's Art of Computer Programming.
347         // 9 comparisons.
348         // Order pairs:
349         // [(0,3),(1,4)]
350         // [(0,2),(1,3)]
351         // [(0,1),(2,4)]
352         // [(1,2),(3,4)]
353         // [(2,3)]
354         if (x[e] < x[b]) {
355             final int u = x[e];
356             x[e] = x[b];
357             x[b] = u;
358         }
359         if (x[d] < x[a]) {
360             final int v = x[d];
361             x[d] = x[a];
362             x[a] = v;
363         }
364 
365         if (x[d] < x[b]) {
366             final int u = x[d];
367             x[d] = x[b];
368             x[b] = u;
369         }
370         if (x[c] < x[a]) {
371             final int v = x[c];
372             x[c] = x[a];
373             x[a] = v;
374         }
375 
376         if (x[e] < x[c]) {
377             final int u = x[e];
378             x[e] = x[c];
379             x[c] = u;
380         }
381         if (x[b] < x[a]) {
382             final int v = x[b];
383             x[b] = x[a];
384             x[a] = v;
385         }
386 
387         if (x[e] < x[d]) {
388             final int u = x[e];
389             x[e] = x[d];
390             x[d] = u;
391         }
392         if (x[c] < x[b]) {
393             final int v = x[c];
394             x[c] = x[b];
395             x[b] = v;
396         }
397 
398         if (x[d] < x[c]) {
399             final int u = x[d];
400             x[d] = x[c];
401             x[c] = u;
402         }
403     }
404 
405     /**
406      * Place the lower median of 4 elements in {@code b}; the smaller element in
407      * {@code a}; and the larger two elements in {@code c, d}.
408      *
409      * @param x Values
410      * @param a Index.
411      * @param b Index.
412      * @param c Index.
413      * @param d Index.
414      */
415     static void lowerMedian4(int[] x, int a, int b, int c, int d) {
416         // 3 to 5 comparisons
417         if (x[d] < x[b]) {
418             final int u = x[d];
419             x[d] = x[b];
420             x[b] = u;
421         }
422         if (x[c] < x[a]) {
423             final int v = x[c];
424             x[c] = x[a];
425             x[a] = v;
426         }
427         // a--c
428         // b--d
429         if (x[c] < x[b]) {
430             final int u = x[c];
431             x[c] = x[b];
432             x[b] = u;
433         } else if (x[b] < x[a]) {
434             //    a--c
435             // b--d
436             final int xb = x[a];
437             x[a] = x[b];
438             x[b] = xb;
439             //    b--c
440             // a--d
441             if (x[d] < xb) {
442                 x[b] = x[d];
443                 // Move a pair to maintain the sorted order
444                 x[d] = x[c];
445                 x[c] = xb;
446             }
447         }
448     }
449 
450     /**
451      * Place the upper median of 4 elements in {@code c}; the smaller two elements in
452      * {@code a,b}; and the larger element in {@code d}.
453      *
454      * @param x Values
455      * @param a Index.
456      * @param b Index.
457      * @param c Index.
458      * @param d Index.
459      */
460     static void upperMedian4(int[] x, int a, int b, int c, int d) {
461         // 3 to 5 comparisons
462         if (x[d] < x[b]) {
463             final int u = x[d];
464             x[d] = x[b];
465             x[b] = u;
466         }
467         if (x[c] < x[a]) {
468             final int v = x[c];
469             x[c] = x[a];
470             x[a] = v;
471         }
472         // a--c
473         // b--d
474         if (x[b] > x[c]) {
475             final int u = x[c];
476             x[c] = x[b];
477             x[b] = u;
478         } else if (x[c] > x[d]) {
479             //    a--c
480             // b--d
481             final int xc = x[d];
482             x[d] = x[c];
483             x[c] = xc;
484             //    a--d
485             // b--c
486             if (x[a] > xc) {
487                 x[c] = x[a];
488                 // Move a pair to maintain the sorted order
489                 x[a] = x[b];
490                 x[b] = xc;
491             }
492         }
493     }
494 
495     /**
496      * Sorts an array using an insertion sort.
497      *
498      * @param x Data array.
499      * @param left Lower bound (inclusive).
500      * @param right Upper bound (inclusive).
501      */
502     static void sort(long[] x, int left, int right) {
503         for (int i = left; ++i <= right;) {
504             final long v = x[i];
505             // Move preceding higher elements above (if required)
506             if (v < x[i - 1]) {
507                 int j = i;
508                 while (--j >= left && v < x[j]) {
509                     x[j + 1] = x[j];
510                 }
511                 x[j + 1] = v;
512             }
513         }
514     }
515 
516     /**
517      * Sorts the elements at the given distinct indices in an array.
518      *
519      * @param x Data array.
520      * @param a Index.
521      * @param b Index.
522      * @param c Index.
523      */
524     static void sort3(long[] x, int a, int b, int c) {
525         // Decision tree avoiding swaps:
526         // Order [(0,2)]
527         // Move point 1 above point 2 or below point 0
528         final long u = x[a];
529         final long v = x[b];
530         final long w = x[c];
531         if (w < u) {
532             if (v < w) {
533                 x[a] = v;
534                 x[b] = w;
535                 x[c] = u;
536                 return;
537             }
538             if (u < v) {
539                 x[a] = w;
540                 x[b] = u;
541                 x[c] = v;
542                 return;
543             }
544             // w < v < u
545             x[a] = w;
546             x[c] = u;
547             return;
548         }
549         if (v < u) {
550             // v < u < w
551             x[a] = v;
552             x[b] = u;
553             return;
554         }
555         if (w < v) {
556             // u < w < v
557             x[b] = w;
558             x[c] = v;
559         }
560         // u < v < w
561     }
562 
563     /**
564      * Sorts the elements at the given distinct indices in an array.
565      *
566      * @param x Data array.
567      * @param a Index.
568      * @param b Index.
569      * @param c Index.
570      * @param d Index.
571      * @param e Index.
572      */
573     static void sort5(long[] x, int a, int b, int c, int d, int e) {
574         // Uses an optimal sorting network from Knuth's Art of Computer Programming.
575         // 9 comparisons.
576         // Order pairs:
577         // [(0,3),(1,4)]
578         // [(0,2),(1,3)]
579         // [(0,1),(2,4)]
580         // [(1,2),(3,4)]
581         // [(2,3)]
582         if (x[e] < x[b]) {
583             final long u = x[e];
584             x[e] = x[b];
585             x[b] = u;
586         }
587         if (x[d] < x[a]) {
588             final long v = x[d];
589             x[d] = x[a];
590             x[a] = v;
591         }
592 
593         if (x[d] < x[b]) {
594             final long u = x[d];
595             x[d] = x[b];
596             x[b] = u;
597         }
598         if (x[c] < x[a]) {
599             final long v = x[c];
600             x[c] = x[a];
601             x[a] = v;
602         }
603 
604         if (x[e] < x[c]) {
605             final long u = x[e];
606             x[e] = x[c];
607             x[c] = u;
608         }
609         if (x[b] < x[a]) {
610             final long v = x[b];
611             x[b] = x[a];
612             x[a] = v;
613         }
614 
615         if (x[e] < x[d]) {
616             final long u = x[e];
617             x[e] = x[d];
618             x[d] = u;
619         }
620         if (x[c] < x[b]) {
621             final long v = x[c];
622             x[c] = x[b];
623             x[b] = v;
624         }
625 
626         if (x[d] < x[c]) {
627             final long u = x[d];
628             x[d] = x[c];
629             x[c] = u;
630         }
631     }
632 
633     /**
634      * Place the lower median of 4 elements in {@code b}; the smaller element in
635      * {@code a}; and the larger two elements in {@code c, d}.
636      *
637      * @param x Values
638      * @param a Index.
639      * @param b Index.
640      * @param c Index.
641      * @param d Index.
642      */
643     static void lowerMedian4(long[] x, int a, int b, int c, int d) {
644         // 3 to 5 comparisons
645         if (x[d] < x[b]) {
646             final long u = x[d];
647             x[d] = x[b];
648             x[b] = u;
649         }
650         if (x[c] < x[a]) {
651             final long v = x[c];
652             x[c] = x[a];
653             x[a] = v;
654         }
655         // a--c
656         // b--d
657         if (x[c] < x[b]) {
658             final long u = x[c];
659             x[c] = x[b];
660             x[b] = u;
661         } else if (x[b] < x[a]) {
662             //    a--c
663             // b--d
664             final long xb = x[a];
665             x[a] = x[b];
666             x[b] = xb;
667             //    b--c
668             // a--d
669             if (x[d] < xb) {
670                 x[b] = x[d];
671                 // Move a pair to maintain the sorted order
672                 x[d] = x[c];
673                 x[c] = xb;
674             }
675         }
676     }
677 
678     /**
679      * Place the upper median of 4 elements in {@code c}; the smaller two elements in
680      * {@code a,b}; and the larger element in {@code d}.
681      *
682      * @param x Values
683      * @param a Index.
684      * @param b Index.
685      * @param c Index.
686      * @param d Index.
687      */
688     static void upperMedian4(long[] x, int a, int b, int c, int d) {
689         // 3 to 5 comparisons
690         if (x[d] < x[b]) {
691             final long u = x[d];
692             x[d] = x[b];
693             x[b] = u;
694         }
695         if (x[c] < x[a]) {
696             final long v = x[c];
697             x[c] = x[a];
698             x[a] = v;
699         }
700         // a--c
701         // b--d
702         if (x[b] > x[c]) {
703             final long u = x[c];
704             x[c] = x[b];
705             x[b] = u;
706         } else if (x[c] > x[d]) {
707             //    a--c
708             // b--d
709             final long xc = x[d];
710             x[d] = x[c];
711             x[c] = xc;
712             //    a--d
713             // b--c
714             if (x[a] > xc) {
715                 x[c] = x[a];
716                 // Move a pair to maintain the sorted order
717                 x[a] = x[b];
718                 x[b] = xc;
719             }
720         }
721     }
722 
723     /**
724      * Sort the unique indices in-place to the start of the array. The number of
725      * unique indices is returned.
726      *
727      * <p>Uses an insertion sort modified to ignore duplicates. Use on small {@code n}.
728      *
729      * <p>Warning: Requires {@code n > 0}. The array contents after the count of unique
730      * indices {@code c} are unchanged (i.e. {@code [c, n)}. This may change the count of
731      * each unique index in the entire array.
732      *
733      * @param x Indices.
734      * @param n Number of indices.
735      * @return the number of unique indices
736      */
737     static int insertionSortIndices(int[] x, int n) {
738         // Index of last unique value
739         int unique = 0;
740         // Do an insertion sort but only compare the current set of unique values.
741         for (int i = 1; i < n; i++) {
742             final int v = x[i];
743             int j = unique;
744             if (v > x[j]) {
745                 // Insert at end
746                 x[++unique] = v;
747             } else if (v < x[j]) {
748                 // Find insertion point in the unique indices
749                 do {
750                     --j;
751                 } while (j >= 0 && v < x[j]);
752                 // Insertion point = j + 1
753                 // Insert if at start or non-duplicate
754                 if (j < 0 || v != x[j]) {
755                     // Move (j, unique] to (j+1, unique+1]
756                     for (int k = unique; k > j; --k) {
757                         x[k + 1] = x[k];
758                     }
759                     x[j + 1] = v;
760                     ++unique;
761                 }
762             }
763         }
764         return unique + 1;
765     }
766 
767     /**
768      * Sort the unique indices in-place to the start of the array. The number of
769      * unique indices is returned.
770      *
771      * <p>Uses an Order(1) data structure to ignore duplicates.
772      *
773      * <p>Warning: Requires {@code n > 0}. The array contents after the count of unique
774      * indices {@code c} are unchanged (i.e. {@code [c, n)}. This may change the count of
775      * each unique index in the entire array.
776      *
777      * @param x Indices.
778      * @param n Number of indices.
779      * @return the number of unique indices
780      */
781     static int sortIndices(int[] x, int n) {
782         // Duplicates are checked using a primitive hash set.
783         // Storage (bytes) = 4 * next-power-of-2(n*2) => 2-4 times n
784         final HashIndexSet set = HashIndexSet.create(n);
785         int i = 0;
786         int last = 0;
787         set.add(x[0]);
788         while (++i < n) {
789             final int v = x[i];
790             if (set.add(v)) {
791                 x[++last] = v;
792             }
793         }
794         Arrays.sort(x, 0, ++last);
795         return last;
796     }
797 }