1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.core.source64;
18
19 import java.math.BigInteger;
20 import java.util.SplittableRandom;
21
22 import org.junit.jupiter.api.Assertions;
23 import org.junit.jupiter.api.Test;
24 import org.junit.jupiter.params.ParameterizedTest;
25 import org.junit.jupiter.params.provider.CsvSource;
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50 class LXMSupportTest {
51
52 private static final BigInteger TWO_POW_63 = BigInteger.ONE.shiftLeft(63);
53
54 private static final BigInteger MOD = BigInteger.ONE.shiftLeft(128);
55
56 private static final int CLEAR_LOWER_6 = -1 << 6;
57
58 private static final int CLEAR_LOWER_7 = -1 << 7;
59
60 @Test
61 void testLea64() {
62
63
64 final long[] expected = {
65 0x45b8512f9ff46f10L, 0xd6ce3db0dd63efc3L, 0x47bf6058710f2a88L, 0x85b8c74e40981596L,
66 0xd77442e45944235eL, 0x3ea4255636bfb1c3L, 0x296ec3c9d3e0addcL, 0x6c285eb9694f6eb2L,
67 0x8121aeca2ba15b66L, 0x2b6d5c2848c4fdc4L, 0xcc99bc57f5e3e024L, 0xc00f59a3ad3666cbL,
68 0x74e5285467c20ae7L, 0xf4d51701e3ea9555L, 0x3aeb92e31a9b1a0eL, 0x5a1a0ce875c7dcaL,
69 0xb9a561fb7d82d0f3L, 0x97095f0ab633bf2fL, 0xfe74b5290c07c1d1L, 0x9dfd354727d45838L,
70 0xf6279a8801201eddL, 0x2db471b1d42860eeL, 0x4ee66ceb27bd34ecL, 0x2005875ad25bd11aL,
71 0x92eac4d1446a0204L, 0xa46087d5dd5fa38eL, 0x7967530c43faabe1L, 0xc53e1dd74fd9bd15L,
72 0x259001ab97cca8bcL, 0x5edf024ee6cb1d8bL, 0x3fc021bba7d0d7e6L, 0xf82cae56e00245dbL,
73 0xf1dc30974b524d02L, 0xe1f2f1db0af7ace9L, 0x853d5892ebccb9f6L, 0xe266f36a3121da55L,
74 0x3b034a81bad01622L, 0x852b53c14569ada2L, 0xee902ddc658c86c9L, 0xd9e926b766013254L,
75 };
76 long state = 0x012de1babb3c4104L;
77 final long increment = 0xc8161b4202294965L;
78
79 for (final long e : expected) {
80 Assertions.assertEquals(e, LXMSupport.lea64(state += increment));
81 }
82 }
83
84 @Test
85 void testUnsignedMultiplyHighEdgeCases() {
86 final long[] values = {
87 -1, 0, 1, Long.MAX_VALUE, Long.MIN_VALUE, LXMSupport.M128L,
88 0xffL, 0xff00L, 0xff0000L, 0xff000000L,
89 0xff00000000L, 0xff0000000000L, 0xff000000000000L, 0xff000000000000L,
90 0xffffL, 0xffff0000L, 0xffff00000000L, 0xffff000000000000L,
91 0xffffffffL, 0xffffffff00000000L
92 };
93
94 for (final long v1 : values) {
95 for (final long v2 : values) {
96 assertMultiplyHigh(v1, v2, LXMSupport.unsignedMultiplyHigh(v1, v2));
97 }
98 }
99 }
100
101 @Test
102 void testUnsignedMultiplyHigh() {
103 final long[] values = new SplittableRandom().longs(100).toArray();
104 for (final long v1 : values) {
105 for (final long v2 : values) {
106 assertMultiplyHigh(v1, v2, LXMSupport.unsignedMultiplyHigh(v1, v2));
107 }
108 }
109 }
110
111 private static void assertMultiplyHigh(long v1, long v2, long hi) {
112 final BigInteger bi1 = toUnsignedBigInteger(v1);
113 final BigInteger bi2 = toUnsignedBigInteger(v2);
114 final BigInteger expected = bi1.multiply(bi2);
115 Assertions.assertTrue(expected.bitLength() <= 128);
116 Assertions.assertEquals(expected.shiftRight(64).longValue(), hi,
117 () -> String.format("%s * %s", bi1, bi2));
118 }
119
120
121
122
123
124
125
126 static BigInteger toUnsignedBigInteger(long v) {
127 return v < 0 ?
128 TWO_POW_63.or(BigInteger.valueOf(v & Long.MAX_VALUE)) :
129 BigInteger.valueOf(v);
130 }
131
132
133
134
135
136
137
138
139 static BigInteger toUnsignedBigInteger(long hi, long lo) {
140 return toUnsignedBigInteger(hi).shiftLeft(64).or(toUnsignedBigInteger(lo));
141 }
142
143 @Test
144 void testUnsignedAddHigh() {
145
146
147 long a = 1;
148 long b = -1;
149
150
151 final SplittableRandom sr = new SplittableRandom();
152
153 final int pow = 5;
154
155
156 final long range = 1L << (64 - pow);
157 for (int i = 1 << pow; i-- != 0;) {
158
159 Assertions.assertEquals(0L, a + b);
160 Assertions.assertEquals(1L, b & 0x1);
161
162 assertAddHigh(a, b);
163
164 assertAddHigh(a - 1, b);
165
166 final long step = sr.nextLong(range) & ~0x1;
167 a += step;
168 b -= step;
169 }
170
171
172 for (int i = 0; i < 1000; i++) {
173 assertAddHigh(sr.nextLong(), sr.nextLong() | 1);
174 }
175 }
176
177 private static void assertAddHigh(long a, long b) {
178
179
180 final long sum = a + b;
181 final long carry = Long.compareUnsigned(sum, a) < 0 ? 1 : 0;
182 Assertions.assertEquals(carry, LXMSupport.unsignedAddHigh(a, b),
183 () -> String.format("%d + %d", a, b));
184 }
185
186 @ParameterizedTest
187 @CsvSource({
188 "6364136223846793005, 1442695040888963407, 2738942865345",
189
190
191 "-3372029247567499371, 9832718632891239, 236823998",
192 "-3372029247567499371, -6152834681292394, -6378917984523",
193 "-3372029247567499371, 12638123, 21313",
194 "-3372029247567499371, -67123, 42",
195 })
196 void testLcgAdvancePow2(long m, long c, long state) {
197
198 long s = state;
199 for (int i = 0; i < 1; i++) {
200 s = m * s + c;
201 }
202 Assertions.assertEquals(s, lcgAdvancePow2(state, m, c, 0), "2^0 cycles");
203 for (int i = 0; i < 1; i++) {
204 s = m * s + c;
205 }
206 Assertions.assertEquals(s, lcgAdvancePow2(state, m, c, 1), "2^1 cycles");
207 for (int i = 0; i < 2; i++) {
208 s = m * s + c;
209 }
210 Assertions.assertEquals(s, lcgAdvancePow2(state, m, c, 2), "2^2 cycles");
211 for (int i = 0; i < 4; i++) {
212 s = m * s + c;
213 }
214 Assertions.assertEquals(s, lcgAdvancePow2(state, m, c, 3), "2^3 cycles");
215
216
217 for (int n = 3; n < 63; n++) {
218 final int n1 = n + 1;
219 Assertions.assertEquals(
220 lcgAdvancePow2(lcgAdvancePow2(state, m, c, n), m, c, n),
221 lcgAdvancePow2(state, m, c, n1), () -> "2^" + n1 + " cycles");
222 }
223
224
225 for (final int i : new int[] {64, 67868, Integer.MAX_VALUE, Integer.MIN_VALUE, -26762, -2, -1}) {
226 final int n = i;
227 Assertions.assertEquals(state, lcgAdvancePow2(state, m, c, n),
228 () -> "2^" + n + " cycles");
229 }
230 }
231
232 @ParameterizedTest
233 @CsvSource({
234 "126868183112323, 6364136223846793005, 1442695040888963407, 2738942865345, 3467819237274724, 12367842684328",
235
236 "-126836182831123, -1, 12678162381123, -12673162838122, 12313212312354235, 127384628323784",
237 "92349876232, -1, 92374923739482, 2394782347892, 1239748923479, 627348278239",
238
239
240 "1, -3024805186288043011, 9832718632891239, 236823998, -23564628723714323, -12361783268182",
241 "1, -3024805186288043011, -6152834681292394, -6378917984523, 127317381313, -12637618368172",
242 "1, -3024805186288043011, 1, 2, 3, 4",
243 "1, -3024805186288043011, -1, -78, -56775, 121",
244 })
245 void testLcg128AdvancePow2(long mh, long ml, long ch, long cl, long stateh, long statel) {
246
247 BigInteger s = toUnsignedBigInteger(stateh, statel);
248 final BigInteger m = toUnsignedBigInteger(mh, ml);
249 final BigInteger c = toUnsignedBigInteger(ch, cl);
250 for (int i = 0; i < 1; i++) {
251 s = m.multiply(s).add(c).mod(MOD);
252 }
253 Assertions.assertEquals(s.shiftRight(64).longValue(),
254 lcgAdvancePow2High(stateh, statel, mh, ml, ch, cl, 0), "2^0 cycles");
255 for (int i = 0; i < 1; i++) {
256 s = m.multiply(s).add(c).mod(MOD);
257 }
258 Assertions.assertEquals(s.shiftRight(64).longValue(),
259 lcgAdvancePow2High(stateh, statel, mh, ml, ch, cl, 1), "2^1 cycles");
260 for (int i = 0; i < 2; i++) {
261 s = m.multiply(s).add(c).mod(MOD);
262 }
263 Assertions.assertEquals(s.shiftRight(64).longValue(),
264 lcgAdvancePow2High(stateh, statel, mh, ml, ch, cl, 2), "2^2 cycles");
265 for (int i = 0; i < 4; i++) {
266 s = m.multiply(s).add(c).mod(MOD);
267 }
268 Assertions.assertEquals(s.shiftRight(64).longValue(),
269 lcgAdvancePow2High(stateh, statel, mh, ml, ch, cl, 3), "2^3 cycles");
270
271
272 for (int n = 3; n < 127; n++) {
273 final int n1 = n + 1;
274
275
276 final long lo = lcgAdvancePow2(statel, ml, cl, n);
277 final long hi = lcgAdvancePow2High(stateh, statel, mh, ml, ch, cl, n);
278 Assertions.assertEquals(
279 lcgAdvancePow2High(hi, lo, mh, ml, ch, cl, n),
280 lcgAdvancePow2High(stateh, statel, mh, ml, ch, cl, n1), () -> "2^" + n1 + " cycles");
281 }
282
283
284 for (final int i : new int[] {128, 67868, Integer.MAX_VALUE, Integer.MIN_VALUE, -26762, -2, -1}) {
285 final int n = i;
286 Assertions.assertEquals(stateh, lcgAdvancePow2High(stateh, statel, mh, ml, ch, cl, n),
287 () -> "2^" + n + " cycles");
288 }
289 }
290
291 @Test
292 void testLcg64Advance2Pow32Constants() {
293
294
295
296 final long[] out = new long[2];
297 lcgAdvancePow2(LXMSupport.M64, 1, 32, out);
298 Assertions.assertEquals(LXMSupport.M64P, out[0], "m'");
299 Assertions.assertEquals(LXMSupport.C64P, out[1], "c'");
300
301 Assertions.assertEquals(1, (int) out[0], "low m'");
302 Assertions.assertEquals(0, (int) out[1], "low c'");
303 }
304
305 @Test
306 void testLcg128Advance2Pow64Constants() {
307
308
309
310 final long[] out = new long[4];
311 lcgAdvancePow2(1, LXMSupport.M128L, 0, 1, 64, out);
312 Assertions.assertEquals(LXMSupport.M128PH, out[0], "high m'");
313 Assertions.assertEquals(LXMSupport.C128PH, out[2], "high c'");
314
315
316 Assertions.assertEquals(1, out[1], "low m'");
317 Assertions.assertEquals(0, out[3], "low c'");
318 }
319
320
321
322
323
324 @Test
325 void testLcgAdvance2Pow32() {
326 final SplittableRandom r = new SplittableRandom();
327 final long[] out = new long[2];
328
329 for (int i = 0; i < 2000; i++) {
330
331 final long c = r.nextLong() | 1;
332 lcgAdvancePow2(LXMSupport.M64, c, 32, out);
333 final long a = out[1];
334
335 Assertions.assertEquals(1, (a >>> 32) & 0x1, "High half c' should be odd");
336 Assertions.assertEquals(0, (int) a, "Low half c' should be 0");
337
338 Assertions.assertEquals(a, LXMSupport.C64P * c);
339 }
340 }
341
342
343
344
345
346 @Test
347 void testLcgAdvance2Pow64() {
348 final SplittableRandom r = new SplittableRandom();
349 final long[] out = new long[4];
350
351 for (int i = 0; i < 2000; i++) {
352
353 final long ch = r.nextLong();
354 final long cl = r.nextLong() | 1;
355 lcgAdvancePow2(1, LXMSupport.M128L, ch, cl, 64, out);
356 final long ah = out[2];
357
358 Assertions.assertEquals(1, ah & 0x1, "High half c' should be odd");
359 Assertions.assertEquals(0, out[3], "Low half c' should be 0");
360
361 Assertions.assertEquals(ah, LXMSupport.C128PH * cl);
362 }
363 }
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394 private static void lcgAdvancePow2(long m, long c, int k, long[] out) {
395
396
397
398 if ((k & CLEAR_LOWER_6) != 0) {
399
400 out[0] = 1;
401 out[1] = 0;
402 return;
403 }
404
405 long mp = m;
406 long a = c;
407
408 for (int i = k; i != 0; i--) {
409
410 a = (mp + 1) * a;
411 mp *= mp;
412 }
413 out[0] = mp;
414 out[1] = a;
415 }
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446 static long lcgAdvancePow2(long s, long m, long c, int k) {
447 final long[] out = new long[2];
448 lcgAdvancePow2(m, c, k, out);
449 final long mp = out[0];
450 final long ap = out[1];
451 return mp * s + ap;
452 }
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485 private static void lcgAdvancePow2(final long mh, final long ml,
486 final long ch, final long cl,
487 int k, long[] out) {
488
489
490
491 if ((k & CLEAR_LOWER_7) != 0) {
492
493 out[0] = out[2] = out[3] = 0;
494 out[1] = 1;
495 return;
496 }
497
498 long mph = mh;
499 long mpl = ml;
500 long ah = ch;
501 long al = cl;
502
503 for (int i = k; i != 0; i--) {
504
505
506
507 final long mp1l = mpl + 1;
508 final long mp1h = mp1l == 0 ? mph + 1 : mph;
509 ah = LXMSupport.unsignedMultiplyHigh(mp1l, al) + mp1h * al + mp1l * ah;
510 al = mp1l * al;
511
512
513
514 mph = LXMSupport.unsignedMultiplyHigh(mpl, mpl) + 2 * mph * mpl;
515 mpl = mpl * mpl;
516 }
517
518 out[0] = mph;
519 out[1] = mpl;
520 out[2] = ah;
521 out[3] = al;
522 }
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559 static long lcgAdvancePow2High(long sh, long sl,
560 long mh, long ml,
561 long ch, long cl,
562 int k) {
563 final long[] out = new long[4];
564 lcgAdvancePow2(mh, ml, ch, cl, k, out);
565 final long mph = out[0];
566 final long mpl = out[1];
567 final long ah = out[2];
568 final long al = out[3];
569
570
571
572 long hi = LXMSupport.unsignedMultiplyHigh(mpl, sl) + mpl * sh + mph * sl + ah;
573
574
575 final long lo = sl * mpl;
576 if (Long.compareUnsigned(lo + al, lo) < 0) {
577 ++hi;
578 }
579 return hi;
580 }
581 }