Skip to content

Commit 54fbccd

Browse files
committed
mylis_2: tail-call optimization implemented
1 parent 2001225 commit 54fbccd

File tree

5 files changed

+104
-47
lines changed

5 files changed

+104
-47
lines changed

mylis/mylis_2/examples/fact-iter.scm

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
(define (! n)
2+
(factorial-iter n 1))
3+
4+
(define (factorial-iter n product)
5+
(if (= n 1)
6+
product
7+
(factorial-iter (- n 1) (* n product))))
8+
(display (! x))

mylis/mylis_2/lis.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
Atom: TypeAlias = int | float | Symbol
2323
Expression: TypeAlias = Atom | list
2424

25+
TCO_ENABLED = True
26+
2527

2628
class Environment(ChainMap):
2729
def change(self, key: Symbol, value: object):
@@ -227,16 +229,16 @@ def evaluate(exp: Expression, env: Environment) -> Any:
227229
case [op, *args] if op not in KEYWORDS: # (proc arg...)
228230
proc = evaluate(op, env)
229231
values = (evaluate(arg, env) for arg in args)
230-
# if isinstance(proc, Procedure):
231-
# exp = ['begin', proc.body]
232-
# local_env = dict(zip(proc.parms, values))
233-
# env = Environment(local_env, proc.env)
234-
# else:
235-
try:
236-
return proc(*values)
237-
except TypeError as exc:
238-
msg = f'{exc!r} invoking {proc!r}({args!r}): {lispstr(exp)}\nAST={exp!r}'
239-
raise EvaluatorException(msg) from exc
232+
if TCO_ENABLED and isinstance(proc, Procedure):
233+
exp = ['begin', *proc.body]
234+
local_env = dict(zip(proc.parms, values))
235+
env = Environment(local_env, proc.env)
236+
else:
237+
try:
238+
return proc(*values)
239+
except TypeError as exc:
240+
msg = f'{exc!r} invoking {proc!r}({args!r}): {lispstr(exp)}\nAST={exp!r}'
241+
raise EvaluatorException(msg) from exc
240242
case _:
241243
raise InvalidSyntax(lispstr(exp))
242244

mylis/mylis_2/lis_test.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import io
1+
import math
22
from exceptions import InvalidSyntax
33

44
from pytest import mark, fixture, raises
@@ -167,36 +167,25 @@ def test_begin(std_env: Environment) -> None:
167167
assert got == 42
168168

169169

170-
def test_invocation_builtin_car(std_env: Environment) -> None:
170+
def test_call_builtin_car(std_env: Environment) -> None:
171171
source = '(car (quote (11 22 33)))'
172172
got = evaluate(parse(source), std_env)
173173
assert got == 11
174174

175175

176-
def test_invocation_builtin_append(std_env: Environment) -> None:
176+
def test_call_builtin_append(std_env: Environment) -> None:
177177
source = '(append (quote (a b)) (quote (c d)))'
178178
got = evaluate(parse(source), std_env)
179179
assert got == ['a', 'b', 'c', 'd']
180180

181181

182-
def test_invocation_builtin_map(std_env: Environment) -> None:
182+
def test_call_builtin_map(std_env: Environment) -> None:
183183
source = '(map (lambda (x) (* x 2)) (quote (1 2 3))))'
184184
got = evaluate(parse(source), std_env)
185185
assert got == [2, 4, 6]
186186

187187

188-
def test_invocation_user_procedure(std_env: Environment) -> None:
189-
source = """
190-
(begin
191-
(define max (lambda (a b) (if (>= a b) a b)))
192-
(max 22 11)
193-
)
194-
"""
195-
got = evaluate(parse(source), std_env)
196-
assert got == 22
197-
198-
199-
def test_define_function(std_env: Environment) -> None:
188+
def test_define_procedure(std_env: Environment) -> None:
200189
source = '(define (max a b) (if (>= a b) a b))'
201190
got = evaluate(parse(source), std_env)
202191
assert got is None
@@ -209,6 +198,17 @@ def test_define_function(std_env: Environment) -> None:
209198
assert max_fn(3, 2) == 3
210199

211200

201+
def test_call_user_procedure(std_env: Environment) -> None:
202+
source = """
203+
(begin
204+
(define max (lambda (a b) (if (>= a b) a b)))
205+
(max 22 11)
206+
)
207+
"""
208+
got = evaluate(parse(source), std_env)
209+
assert got == 22
210+
211+
212212
def test_cond(std_env: Environment) -> None:
213213
source = """
214214
(cond ((> x 0) x)
@@ -262,4 +262,66 @@ def test_or(source: str, expected: Expression) -> None:
262262
])
263263
def test_and(source: str, expected: Expression) -> None:
264264
got = evaluate(parse(source), {})
265-
assert got == expected
265+
assert got == expected
266+
267+
############### tail-call optimization (TCO)
268+
269+
def test_simple_user_procedure_call(std_env: Environment) -> None:
270+
source = """
271+
(begin
272+
(define (answer) 42)
273+
(answer)
274+
)
275+
"""
276+
got = evaluate(parse(source), std_env)
277+
assert got == 42
278+
279+
280+
@fixture
281+
def tco(request):
282+
import lis
283+
initial_tco_setting = lis.TCO_ENABLED
284+
marker = request.node.get_closest_marker('tail_call_optimization')
285+
tco_flag = marker.args[0]
286+
assert tco_flag in (True, False)
287+
lis.TCO_ENABLED = tco_flag
288+
yield
289+
lis.TCO_ENABLED = initial_tco_setting
290+
291+
292+
@mark.tail_call_optimization(True)
293+
def test_tail_call_countdown(std_env: Environment, tco) -> None:
294+
countdown_scm = """
295+
(define (countdown n)
296+
(if (= n 0)
297+
0
298+
(countdown (- n 1))))
299+
"""
300+
evaluate(parse(countdown_scm), std_env)
301+
# maximum with TCO: n=475; without TCO: n=316
302+
n = 475
303+
source = f'(countdown {n})'
304+
got = evaluate(parse(source), std_env)
305+
assert got == 0
306+
307+
308+
@mark.tail_call_optimization(True)
309+
def test_tail_call_factorial(std_env: Environment, tco) -> None:
310+
factorial_scm = """
311+
(begin
312+
(define (factorial n)
313+
(factorial-iter n 1))
314+
315+
(define (factorial-iter n product)
316+
(if (= n 1)
317+
product
318+
(factorial-iter (- n 1) (* n product))))
319+
)
320+
"""
321+
evaluate(parse(factorial_scm), std_env)
322+
import lis
323+
# maximum with TCO: n=475; without TCO: n=316
324+
n = 475
325+
source = f'(factorial {n})'
326+
got = evaluate(parse(source), std_env)
327+
assert got == math.prod(range(2, n + 1))

mylis/mylis_2/mylis_test.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,24 +121,6 @@ def test_variadic_comparison(source: str, expected: lis.Expression) -> None:
121121
assert got == expected
122122

123123

124-
############### tail-call optimization (TCO)
125-
126-
@mark.skip('requires TCO for n > 159')
127-
def test_tail_call() -> None:
128-
factorial_scm = """
129-
(define (! n)
130-
(if (<= n 1)
131-
1
132-
(* n (! (- n 1)))))
133-
"""
134-
env = standard_env()
135-
lis.evaluate(lis.parse(factorial_scm), env)
136-
n = 200 # maximum without TCO: n = 159
137-
source = f'(! {n})'
138-
got = lis.evaluate(lis.parse(source), env)
139-
assert got == math.prod(range(2, n+1))
140-
141-
142124
############### multi-line REPL
143125

144126

@@ -250,8 +232,9 @@ def test_repl_gcd_example_multiline(capsys):
250232
assert dlg.session == normalize(captured.out)
251233

252234

253-
# Note: flip-flop is defined as with shared closure,
254-
# there is no "maker" function surrounding it.
235+
# Note: flip-flop keep state in a global variable
236+
# because lis.py has no `let` form, and there is
237+
# no "maker" function surrounding the function.
255238
# Contrast with make-counter in examples_test.py
256239
# where each call to (make-counter) creates a new procedure
257240
# with its own closure.

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
[pytest]
22
addopts = --doctest-modules
3+
markers =
4+
tail_call_optimization: enable or disable TCO

0 commit comments

Comments
 (0)