Skip to content

Commit fc027e1

Browse files
authored
linear_congruential_engine: Fixes for __lce_alg_picker (#81080)
This fixes two major mistakes in the implementation of `linear_congruential_engine` that allowed it to produce incorrect output. Specifically, these mistakes are in `__lce_alg_picker`, which is used to determine whether Schrage's algorithm is valid and needed. The first mistake is in the definition of `_OverflowOK`. The code comment and the description of [D65041](https://reviews.llvm.org/D65041) both indicate that it's supposed to be true iff `m` is a power of two. However, the definition used does not work out to that, and instead is true whenever `m` is even. This could result in `linear_congruential_engine` using an invalid implementation, as it would incorrectly assume that any integer overflow can't change the result. I changed the implementation to one that accurately checks if `m` is a power of two. Technically, this implementation has an edge case where it considers `0` to be a power of two, but in this case this is actually accurate behavior, as `m = 0` indicates a modulus of 2^w where w is the size of `result_type` in bits, which *is* a power of two. The second mistake is in the static assert. The original static assert erroneously included an unnecessary `a != 0 || m != 0`. Combined with the `|| !_MightOverflow`, this actually resulted in the static assert being impossible to fail. Applying De Morgan's law and expanding `_MightOverflow` gives that the only way this static assert can be triggered is if `a == 0 && m == 0 && a != 0 && m != 0 && ...`, which clearly cannot be true. I simply removed the explicit checks against `a` and `m`, as the intended checks are already included in `_MightOverflow` and `_SchrageOK`, and their inclusion doesn't provide any obvious semantic benefit. This should fix all the current instances where `linear_congruential_engine` uses an invalid implementation. This technically isn't a complete implementation, though, since the static assert will cause some instantiations of `linear_congruential_engine` not disallowed by the standard from compiling. However, this should still be an improvement, as all compiling instantiations of `linear_congruential_engine` should use a valid implementation. Fixing the cases where the static assert triggers will require adding additional implementations, some of which will be fairly non-trivial, so I'd rather leave those for another PR so they don't hold up these more important fixes. Fixes #33554
1 parent 0f1847c commit fc027e1

File tree

6 files changed

+189
-85
lines changed

6 files changed

+189
-85
lines changed

libcxx/include/__random/linear_congruential_engine.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ template <unsigned long long __a,
3131
unsigned long long __m,
3232
unsigned long long _Mp,
3333
bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
34-
bool _OverflowOK = ((__m | (__m - 1)) > __m), // m = 2^n
34+
bool _OverflowOK = ((__m & (__m - 1)) == 0ull), // m = 2^n
3535
bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
3636
struct __lce_alg_picker {
37-
static_assert(__a != 0 || __m != 0 || !_MightOverflow || _OverflowOK || _SchrageOK,
37+
static_assert(!_MightOverflow || _OverflowOK || _SchrageOK,
3838
"The current values of a, c, and m cannot generate a number "
3939
"within bounds of linear_congruential_engine.");
4040

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,48 +22,63 @@ int main(int, char**)
2222
{
2323
typedef unsigned long long T;
2424

25-
// m might overflow, but the overflow is OK so it shouldn't use schrage's algorithm
26-
typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull<<48)> E1;
25+
// m might overflow, but the overflow is OK so it shouldn't use Schrage's algorithm
26+
typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull << 48)> E1;
2727
E1 e1;
2828
// make sure the right algorithm was used
29-
assert(e1() == 25214903918);
30-
assert(e1() == 205774354444503);
31-
assert(e1() == 158051849450892);
29+
assert(e1() == 25214903918ull);
30+
assert(e1() == 205774354444503ull);
31+
assert(e1() == 158051849450892ull);
3232
// make sure result is in bounds
33-
assert(e1() < (1ull<<48));
34-
assert(e1() < (1ull<<48));
35-
assert(e1() < (1ull<<48));
36-
assert(e1() < (1ull<<48));
37-
assert(e1() < (1ull<<48));
33+
assert(e1() < (1ull << 48));
34+
assert(e1() < (1ull << 48));
35+
assert(e1() < (1ull << 48));
36+
assert(e1() < (1ull << 48));
37+
assert(e1() < (1ull << 48));
3838

3939
// m might overflow. The overflow is not OK and result will be in bounds
40-
// so we should use shrage's algorithm
41-
typedef std::linear_congruential_engine<T, (1ull<<2), 0, (1ull<<63) + 1> E2;
40+
// so we should use Schrage's algorithm
41+
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1> E2;
4242
E2 e2;
43-
// make sure shrage's algorithm is used (it would be 0s otherwise)
44-
assert(e2() == 4);
45-
assert(e2() == 16);
46-
assert(e2() == 64);
43+
// make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
44+
assert(e2() == (1ull << 32));
45+
assert(e2() == (1ull << 63) - 1ull);
46+
assert(e2() == (1ull << 63) - (1ull << 33) + 1ull);
4747
// make sure result is in bounds
48-
assert(e2() < (1ull<<48) + 1);
49-
assert(e2() < (1ull<<48) + 1);
50-
assert(e2() < (1ull<<48) + 1);
51-
assert(e2() < (1ull<<48) + 1);
52-
assert(e2() < (1ull<<48) + 1);
48+
assert(e2() < (1ull << 63) + 1);
49+
assert(e2() < (1ull << 63) + 1);
50+
assert(e2() < (1ull << 63) + 1);
51+
assert(e2() < (1ull << 63) + 1);
52+
assert(e2() < (1ull << 63) + 1);
5353

54-
// m will not overflow so we should not use shrage's algorithm
55-
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull<<48)> E3;
54+
// m might overflow. The overflow is not OK and result will be in bounds
55+
// so we should use Schrage's algorithm. m is even
56+
typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
5657
E3 e3;
58+
// make sure Schrage's algorithm is used
59+
assert(e3() == 402727752ull);
60+
assert(e3() == 162159612030764687ull);
61+
assert(e3() == 108176466184989142ull);
62+
// make sure result is in bounds
63+
assert(e3() < (3ull << 56));
64+
assert(e3() < (3ull << 56));
65+
assert(e3() < (3ull << 56));
66+
assert(e3() < (3ull << 56));
67+
assert(e3() < (3ull << 56));
68+
69+
// m will not overflow so we should not use Schrage's algorithm
70+
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E4;
71+
E4 e4;
5772
// make sure the correct algorithm was used
58-
assert(e3() == 2);
59-
assert(e3() == 3);
60-
assert(e3() == 4);
73+
assert(e4() == 2ull);
74+
assert(e4() == 3ull);
75+
assert(e4() == 4ull);
6176
// make sure result is in bounds
62-
assert(e3() < (1ull<<48));
63-
assert(e3() < (1ull<<48));
64-
assert(e3() < (1ull<<48));
65-
assert(e3() < (1ull<<48));
66-
assert(e2() < (1ull<<48));
77+
assert(e4() < (1ull << 48));
78+
assert(e4() < (1ull << 48));
79+
assert(e4() < (1ull << 48));
80+
assert(e4() < (1ull << 48));
81+
assert(e4() < (1ull << 48));
6782

6883
return 0;
6984
}

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include <random>
1717
#include <cassert>
18+
#include <climits>
1819

1920
#include "test_macros.h"
2021

@@ -35,19 +36,41 @@ template <class T>
3536
void
3637
test()
3738
{
38-
test1<T, 0, 0, 0>();
39-
test1<T, 0, 1, 2>();
40-
test1<T, 1, 1, 2>();
41-
const T M(static_cast<T>(-1));
42-
test1<T, 0, 0, M>();
43-
test1<T, 0, M-2, M>();
44-
test1<T, 0, M-1, M>();
45-
test1<T, M-2, 0, M>();
46-
test1<T, M-2, M-2, M>();
47-
test1<T, M-2, M-1, M>();
48-
test1<T, M-1, 0, M>();
49-
test1<T, M-1, M-2, M>();
50-
test1<T, M-1, M-1, M>();
39+
const int W = sizeof(T) * CHAR_BIT;
40+
const T M(static_cast<T>(-1));
41+
const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
42+
43+
// Cases where m = 0
44+
test1<T, 0, 0, 0>();
45+
test1<T, A, 0, 0>();
46+
test1<T, 0, 1, 0>();
47+
test1<T, A, 1, 0>();
48+
49+
// Cases where m = 2^n for n < w
50+
test1<T, 0, 0, 256>();
51+
test1<T, 5, 0, 256>();
52+
test1<T, 0, 1, 256>();
53+
test1<T, 5, 1, 256>();
54+
55+
// Cases where m is odd and a = 0
56+
test1<T, 0, 0, M>();
57+
test1<T, 0, M - 2, M>();
58+
test1<T, 0, M - 1, M>();
59+
60+
// Cases where m is odd and m % a <= m / a (Schrage)
61+
test1<T, A, 0, M>();
62+
test1<T, A, M - 2, M>();
63+
test1<T, A, M - 1, M>();
64+
65+
/*
66+
// Cases where m is odd and m % a > m / a (not implemented)
67+
test1<T, M - 2, 0, M>();
68+
test1<T, M - 2, M - 2, M>();
69+
test1<T, M - 2, M - 1, M>();
70+
test1<T, M - 1, 0, M>();
71+
test1<T, M - 1, M - 2, M>();
72+
test1<T, M - 1, M - 1, M>();
73+
*/
5174
}
5275

5376
int main(int, char**)

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,41 @@ template <class T>
3535
void
3636
test()
3737
{
38-
test1<T, 0, 0, 0>();
39-
test1<T, 0, 1, 2>();
40-
test1<T, 1, 1, 2>();
41-
const T M(static_cast<T>(-1));
42-
test1<T, 0, 0, M>();
43-
test1<T, 0, M-2, M>();
44-
test1<T, 0, M-1, M>();
45-
test1<T, M-2, 0, M>();
46-
test1<T, M-2, M-2, M>();
47-
test1<T, M-2, M-1, M>();
48-
test1<T, M-1, 0, M>();
49-
test1<T, M-1, M-2, M>();
50-
test1<T, M-1, M-1, M>();
38+
const int W = sizeof(T) * CHAR_BIT;
39+
const T M(static_cast<T>(-1));
40+
const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
41+
42+
// Cases where m = 0
43+
test1<T, 0, 0, 0>();
44+
test1<T, A, 0, 0>();
45+
test1<T, 0, 1, 0>();
46+
test1<T, A, 1, 0>();
47+
48+
// Cases where m = 2^n for n < w
49+
test1<T, 0, 0, 256>();
50+
test1<T, 5, 0, 256>();
51+
test1<T, 0, 1, 256>();
52+
test1<T, 5, 1, 256>();
53+
54+
// Cases where m is odd and a = 0
55+
test1<T, 0, 0, M>();
56+
test1<T, 0, M - 2, M>();
57+
test1<T, 0, M - 1, M>();
58+
59+
// Cases where m is odd and m % a <= m / a (Schrage)
60+
test1<T, A, 0, M>();
61+
test1<T, A, M - 2, M>();
62+
test1<T, A, M - 1, M>();
63+
64+
/*
65+
// Cases where m is odd and m % a > m / a (not implemented)
66+
test1<T, M - 2, 0, M>();
67+
test1<T, M - 2, M - 2, M>();
68+
test1<T, M - 2, M - 1, M>();
69+
test1<T, M - 1, 0, M>();
70+
test1<T, M - 1, M - 2, M>();
71+
test1<T, M - 1, M - 1, M>();
72+
*/
5173
}
5274

5375
int main(int, char**)

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,41 @@ template <class T>
3333
void
3434
test()
3535
{
36-
test1<T, 0, 0, 0>();
37-
test1<T, 0, 1, 2>();
38-
test1<T, 1, 1, 2>();
39-
const T M(static_cast<T>(-1));
40-
test1<T, 0, 0, M>();
41-
test1<T, 0, M-2, M>();
42-
test1<T, 0, M-1, M>();
43-
test1<T, M-2, 0, M>();
44-
test1<T, M-2, M-2, M>();
45-
test1<T, M-2, M-1, M>();
46-
test1<T, M-1, 0, M>();
47-
test1<T, M-1, M-2, M>();
48-
test1<T, M-1, M-1, M>();
36+
const int W = sizeof(T) * CHAR_BIT;
37+
const T M(static_cast<T>(-1));
38+
const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
39+
40+
// Cases where m = 0
41+
test1<T, 0, 0, 0>();
42+
test1<T, A, 0, 0>();
43+
test1<T, 0, 1, 0>();
44+
test1<T, A, 1, 0>();
45+
46+
// Cases where m = 2^n for n < w
47+
test1<T, 0, 0, 256>();
48+
test1<T, 5, 0, 256>();
49+
test1<T, 0, 1, 256>();
50+
test1<T, 5, 1, 256>();
51+
52+
// Cases where m is odd and a = 0
53+
test1<T, 0, 0, M>();
54+
test1<T, 0, M - 2, M>();
55+
test1<T, 0, M - 1, M>();
56+
57+
// Cases where m is odd and m % a <= m / a (Schrage)
58+
test1<T, A, 0, M>();
59+
test1<T, A, M - 2, M>();
60+
test1<T, A, M - 1, M>();
61+
62+
/*
63+
// Cases where m is odd and m % a > m / a (not implemented)
64+
test1<T, M - 2, 0, M>();
65+
test1<T, M - 2, M - 2, M>();
66+
test1<T, M - 2, M - 1, M>();
67+
test1<T, M - 1, 0, M>();
68+
test1<T, M - 1, M - 2, M>();
69+
test1<T, M - 1, M - 1, M>();
70+
*/
4971
}
5072

5173
int main(int, char**)

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,41 @@ template <class T>
6666
void
6767
test()
6868
{
69-
test1<T, 0, 0, 0>();
70-
test1<T, 0, 1, 2>();
71-
test1<T, 1, 1, 2>();
72-
const T M(static_cast<T>(-1));
73-
test1<T, 0, 0, M>();
74-
test1<T, 0, M-2, M>();
75-
test1<T, 0, M-1, M>();
76-
test1<T, M-2, 0, M>();
77-
test1<T, M-2, M-2, M>();
78-
test1<T, M-2, M-1, M>();
79-
test1<T, M-1, 0, M>();
80-
test1<T, M-1, M-2, M>();
81-
test1<T, M-1, M-1, M>();
69+
const int W = sizeof(T) * CHAR_BIT;
70+
const T M(static_cast<T>(-1));
71+
const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
72+
73+
// Cases where m = 0
74+
test1<T, 0, 0, 0>();
75+
test1<T, A, 0, 0>();
76+
test1<T, 0, 1, 0>();
77+
test1<T, A, 1, 0>();
78+
79+
// Cases where m = 2^n for n < w
80+
test1<T, 0, 0, 256>();
81+
test1<T, 5, 0, 256>();
82+
test1<T, 0, 1, 256>();
83+
test1<T, 5, 1, 256>();
84+
85+
// Cases where m is odd and a = 0
86+
test1<T, 0, 0, M>();
87+
test1<T, 0, M - 2, M>();
88+
test1<T, 0, M - 1, M>();
89+
90+
// Cases where m is odd and m % a <= m / a (Schrage)
91+
test1<T, A, 0, M>();
92+
test1<T, A, M - 2, M>();
93+
test1<T, A, M - 1, M>();
94+
95+
/*
96+
// Cases where m is odd and m % a > m / a (not implemented)
97+
test1<T, M - 2, 0, M>();
98+
test1<T, M - 2, M - 2, M>();
99+
test1<T, M - 2, M - 1, M>();
100+
test1<T, M - 1, 0, M>();
101+
test1<T, M - 1, M - 2, M>();
102+
test1<T, M - 1, M - 1, M>();
103+
*/
82104
}
83105

84106
int main(int, char**)

0 commit comments

Comments
 (0)