Skip to content

Commit ce90883

Browse files
authored
divmod: fix aliasing error, add tests (#180)
This change fixes a flaw in `DivMod` related to aliasing of input arguments.
1 parent 9fb9e97 commit ce90883

File tree

2 files changed

+49
-12
lines changed

2 files changed

+49
-12
lines changed

ternary_test.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ var ternaryOpFuncs = []struct {
2121
{"AddMod", (*Int).AddMod, bigAddMod},
2222
{"MulMod", (*Int).MulMod, bigMulMod},
2323
{"MulModWithReciprocal", (*Int).mulModWithReciprocalWrapper, bigMulMod},
24+
{"DivModZ", divModZ, bigDivModZ},
25+
{"DivModM", divModM, bigDivModM},
2426
}
2527

2628
func checkTernaryOperation(t *testing.T, opName string, op opThreeArgFunc, bigOp bigThreeArgFunc, x, y, z Int) {
@@ -49,7 +51,10 @@ func checkTernaryOperation(t *testing.T, opName string, op opThreeArgFunc, bigOp
4951
t.Fatalf("%v\nsecond argument had been modified: %x", operation, f2)
5052
}
5153
if !f3.Eq(f3orig) {
52-
t.Fatalf("%v\nthird argument had been modified: %x", operation, f3)
54+
if opName != "DivModZ" && opName != "DivModM" {
55+
// DivMod takes m as third argument, modifies it, and returns it. That is by design.
56+
t.Fatalf("%v\nthird argument had been modified: %x", operation, f3)
57+
}
5358
}
5459
// Check if reusing args as result works correctly.
5560
if have = op(f1, f1, f2orig, f3orig); have != f1 {
@@ -117,3 +122,29 @@ func (z *Int) mulModWithReciprocalWrapper(x, y, mod *Int) *Int {
117122
mu := Reciprocal(mod)
118123
return z.MulModWithReciprocal(x, y, mod, &mu)
119124
}
125+
126+
func divModZ(z, x, y, m *Int) *Int {
127+
z2, _ := z.DivMod(x, y, m)
128+
return z2
129+
}
130+
131+
func bigDivModZ(result, x, y, mod *big.Int) *big.Int {
132+
if y.Sign() == 0 {
133+
return result.SetUint64(0)
134+
}
135+
z2, _ := result.DivMod(x, y, mod)
136+
return z2
137+
}
138+
139+
func divModM(z, x, y, m *Int) *Int {
140+
_, m2 := z.DivMod(x, y, m)
141+
return z.Set(m2)
142+
}
143+
144+
func bigDivModM(result, x, y, mod *big.Int) *big.Int {
145+
if y.Sign() == 0 {
146+
return result.SetUint64(0)
147+
}
148+
_, m2 := result.DivMod(x, y, mod)
149+
return result.Set(m2)
150+
}

uint256.go

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -369,9 +369,9 @@ func umul(x, y *Int, res *[8]uint64) {
369369
func (z *Int) Mul(x, y *Int) *Int {
370370
var (
371371
carry0, carry1, carry2 uint64
372-
res1, res2 uint64
373-
x0, x1, x2, x3 = x[0], x[1], x[2], x[3]
374-
y0, y1, y2, y3 = y[0], y[1], y[2], y[3]
372+
res1, res2 uint64
373+
x0, x1, x2, x3 = x[0], x[1], x[2], x[3]
374+
y0, y1, y2, y3 = y[0], y[1], y[2], y[3]
375375
)
376376

377377
carry0, z[0] = bits.Mul64(x0, y0)
@@ -610,14 +610,20 @@ func (z *Int) Mod(x, y *Int) *Int {
610610
// DivMod sets z to the quotient x div y and m to the modulus x mod y and returns the pair (z, m) for y != 0.
611611
// If y == 0, both z and m are set to 0 (OBS: differs from the big.Int)
612612
func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) {
613+
if z == m {
614+
// We return both z and m as results, if they are aliased, we have to
615+
// un-alias them to be able to return separate results.
616+
m = new(Int).Set(m)
617+
}
613618
if y.IsZero() {
614619
return z.Clear(), m.Clear()
615620
}
616621
if x.Eq(y) {
617622
return z.SetOne(), m.Clear()
618623
}
619624
if x.Lt(y) {
620-
return z.Clear(), m.Set(x)
625+
m.Set(x)
626+
return z.Clear(), m
621627
}
622628

623629
// At this point:
@@ -1279,7 +1285,7 @@ func (z *Int) Sqrt(x *Int) *Int {
12791285
return z.SetUint64(x0)
12801286
}
12811287
for {
1282-
z2 = (z1 + x0 / z1) >> 1
1288+
z2 = (z1 + x0/z1) >> 1
12831289
if z2 >= z1 {
12841290
return z.SetUint64(z1)
12851291
}
@@ -1291,18 +1297,18 @@ func (z *Int) Sqrt(x *Int) *Int {
12911297
z2 := NewInt(0)
12921298

12931299
// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
1294-
z1.Lsh(z1, uint(x.BitLen() + 1) / 2) // must be ≥ √x
1300+
z1.Lsh(z1, uint(x.BitLen()+1)/2) // must be ≥ √x
12951301

12961302
// We can do the first division outside the loop
1297-
z2.Rsh(x, uint(x.BitLen() + 1) / 2) // The first div is equal to a right shift
1303+
z2.Rsh(x, uint(x.BitLen()+1)/2) // The first div is equal to a right shift
12981304

12991305
for {
13001306
z2.Add(z2, z1)
1301-
1307+
13021308
// z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster
1303-
z2[0] = (z2[0] >> 1) | z2[1] << 63
1304-
z2[1] = (z2[1] >> 1) | z2[2] << 63
1305-
z2[2] = (z2[2] >> 1) | z2[3] << 63
1309+
z2[0] = (z2[0] >> 1) | z2[1]<<63
1310+
z2[1] = (z2[1] >> 1) | z2[2]<<63
1311+
z2[2] = (z2[2] >> 1) | z2[3]<<63
13061312
z2[3] >>= 1
13071313

13081314
if !z2.Lt(z1) {

0 commit comments

Comments
 (0)