Skip to content

Commit 5319cfb

Browse files
feat: allow range(x, y, bound=N) (#3679)
- allow range where both start and end arguments are variables, so long as a bound is supplied - ban range expressions of the form `range(x, x + N)` since the new form is cleaner and supersedes it. - also do a bit of refactoring of the codegen for range --------- Co-authored-by: Charles Cooper <cooper.charles.m@gmail.com>
1 parent 2e41873 commit 5319cfb

File tree

10 files changed

+390
-147
lines changed

10 files changed

+390
-147
lines changed

docs/control-structures.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,11 @@ Another use of range can be with ``START`` and ``STOP`` bounds.
287287
288288
Here, ``START`` and ``STOP`` are literal integers, with ``STOP`` being a greater value than ``START``. ``i`` begins as ``START`` and increments by one until it is equal to ``STOP``.
289289

290+
Finally, it is possible to use ``range`` with runtime `start` and `stop` values as long as a constant `bound` value is provided.
291+
In this case, Vyper checks at runtime that `end - start <= bound`.
292+
``N`` must be a compile-time constant.
293+
290294
.. code-block:: python
291295
292-
for i in range(a, a + N):
296+
for i in range(start, end, bound=N):
293297
...
294-
295-
``a`` is a variable with an integer type and ``N`` is a literal integer greater than zero. ``i`` begins as ``a`` and increments by one until it is equal to ``a + N``. If ``a + N`` would overflow, execution will revert.

tests/functional/codegen/features/iteration/test_for_in_list.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from decimal import Decimal
23

34
import pytest
@@ -700,13 +701,16 @@ def foo():
700701
""",
701702
StateAccessViolation,
702703
),
703-
"""
704+
(
705+
"""
704706
@external
705707
def foo():
706708
a: int128 = 6
707709
for i in range(a,a-3):
708710
pass
709711
""",
712+
StateAccessViolation,
713+
),
710714
# invalid argument length
711715
(
712716
"""
@@ -789,10 +793,13 @@ def test_for() -> int128:
789793
),
790794
]
791795

796+
BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE]
797+
for_code_regex = re.compile(r"for .+ in (.*):")
798+
bad_code_names = [
799+
f"{i} {for_code_regex.search(code).group(1)}" for i, (code, _) in enumerate(BAD_CODE)
800+
]
801+
792802

793-
@pytest.mark.parametrize("code", BAD_CODE)
794-
def test_bad_code(assert_compile_failed, get_contract, code):
795-
err = StructureException
796-
if not isinstance(code, str):
797-
code, err = code
803+
@pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names)
804+
def test_bad_code(assert_compile_failed, get_contract, code, err):
798805
assert_compile_failed(lambda: get_contract(code), err)

tests/functional/codegen/features/iteration/test_for_range.py

Lines changed: 107 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,102 @@ def repeat(n: uint256) -> uint256:
3232
c.repeat(7)
3333

3434

35+
def test_range_bound_constant_end(get_contract, tx_failed):
36+
code = """
37+
@external
38+
def repeat(n: uint256) -> uint256:
39+
x: uint256 = 0
40+
for i in range(n, 7, bound=6):
41+
x += i + 1
42+
return x
43+
"""
44+
c = get_contract(code)
45+
for n in range(1, 5):
46+
assert c.repeat(n) == sum(i + 1 for i in range(n, 7))
47+
48+
# check assertion for `start <= end`
49+
with tx_failed():
50+
c.repeat(8)
51+
# check assertion for `start + bound <= end`
52+
with tx_failed():
53+
c.repeat(0)
54+
55+
56+
def test_range_bound_two_args(get_contract, tx_failed):
57+
code = """
58+
@external
59+
def repeat(n: uint256) -> uint256:
60+
x: uint256 = 0
61+
for i in range(1, n, bound=6):
62+
x += i + 1
63+
return x
64+
"""
65+
c = get_contract(code)
66+
for n in range(1, 8):
67+
assert c.repeat(n) == sum(i + 1 for i in range(1, n))
68+
69+
# check assertion for `start <= end`
70+
with tx_failed():
71+
c.repeat(0)
72+
73+
# check codegen inserts assertion for `start + bound <= end`
74+
with tx_failed():
75+
c.repeat(8)
76+
77+
78+
def test_range_bound_two_runtime_args(get_contract, tx_failed):
79+
code = """
80+
@external
81+
def repeat(start: uint256, end: uint256) -> uint256:
82+
x: uint256 = 0
83+
for i in range(start, end, bound=6):
84+
x += i
85+
return x
86+
"""
87+
c = get_contract(code)
88+
for n in range(0, 7):
89+
assert c.repeat(0, n) == sum(range(0, n))
90+
assert c.repeat(n, n * 2) == sum(range(n, n * 2))
91+
92+
# check assertion for `start <= end`
93+
with tx_failed():
94+
c.repeat(1, 0)
95+
with tx_failed():
96+
c.repeat(7, 0)
97+
with tx_failed():
98+
c.repeat(8, 7)
99+
100+
# check codegen inserts assertion for `start + bound <= end`
101+
with tx_failed():
102+
c.repeat(0, 7)
103+
with tx_failed():
104+
c.repeat(14, 21)
105+
106+
107+
def test_range_overflow(get_contract, tx_failed):
108+
code = """
109+
@external
110+
def get_last(start: uint256, end: uint256) -> uint256:
111+
x: uint256 = 0
112+
for i in range(start, end, bound=6):
113+
x = i
114+
return x
115+
"""
116+
c = get_contract(code)
117+
UINT_MAX = 2**256 - 1
118+
assert c.get_last(UINT_MAX, UINT_MAX) == 0 # initial value of x
119+
120+
for n in range(1, 6):
121+
assert c.get_last(UINT_MAX - n, UINT_MAX) == UINT_MAX - 1
122+
123+
# check for `start + bound <= end`, overflow cases
124+
for n in range(1, 7):
125+
with tx_failed():
126+
c.get_last(UINT_MAX - n, 0)
127+
with tx_failed():
128+
c.get_last(UINT_MAX, UINT_MAX - n)
129+
130+
35131
def test_digit_reverser(get_contract_with_gas_estimation):
36132
digit_reverser = """
37133
@external
@@ -89,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ):
89185
@external
90186
def sum(frm: {typ}, to: {typ}) -> {typ}:
91187
out: {typ} = 0
92-
for i in range(frm, frm + 101):
188+
for i in range(frm, frm + 101, bound=101):
93189
if i == to:
94190
break
95191
out = out + i
@@ -146,26 +242,28 @@ def foo(a: {typ}) -> {typ}:
146242
assert c.foo(100) == 31337
147243

148244

149-
# test that we can get to the upper range of an integer
150245
@pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"])
151246
def test_for_range_edge(get_contract, typ):
247+
"""
248+
Check that we can get to the upper range of an integer.
249+
Note that to avoid overflow in the bounds check for range(),
250+
we need to calculate i+1 inside the loop.
251+
"""
152252
code = f"""
153253
@external
154254
def test():
155255
found: bool = False
156256
x: {typ} = max_value({typ})
157-
for i in range(x, x + 1):
158-
if i == max_value({typ}):
257+
for i in range(x - 1, x, bound=1):
258+
if i + 1 == max_value({typ}):
159259
found = True
160-
161260
assert found
162261
163262
found = False
164263
x = max_value({typ}) - 1
165-
for i in range(x, x + 2):
166-
if i == max_value({typ}):
264+
for i in range(x - 1, x + 1, bound=2):
265+
if i + 1 == max_value({typ}):
167266
found = True
168-
169267
assert found
170268
"""
171269
c = get_contract(code)
@@ -178,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ):
178276
@external
179277
def test():
180278
x: {typ} = max_value({typ})
181-
for i in range(x, x+2):
279+
for i in range(x, x + 2, bound=2):
182280
pass
183281
"""
184282
c = get_contract(code)

tests/functional/codegen/integration/test_crowdfund.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def finalize():
5252
@external
5353
def refund():
5454
ind: int128 = self.refundIndex
55-
for i in range(ind, ind + 30):
55+
for i in range(ind, ind + 30, bound=30):
5656
if i >= self.nextFunderIndex:
5757
self.refundIndex = self.nextFunderIndex
5858
return
@@ -147,7 +147,7 @@ def finalize():
147147
@external
148148
def refund():
149149
ind: int128 = self.refundIndex
150-
for i in range(ind, ind + 30):
150+
for i in range(ind, ind + 30, bound=30):
151151
if i >= self.nextFunderIndex:
152152
self.refundIndex = self.nextFunderIndex
153153
return

tests/functional/syntax/exceptions/test_invalid_literal_exception.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,6 @@ def foo():
1818
""",
1919
"""
2020
@external
21-
def foo(x: int128):
22-
y: int128 = 7
23-
for i in range(x, x + y):
24-
pass
25-
""",
26-
"""
27-
@external
2821
def foo():
2922
x: String[100] = "these bytes are nо gооd because the o's are from the Russian alphabet"
3023
""",

0 commit comments

Comments
 (0)