Skip to content

Commit 7efb8db

Browse files
committed
WIP: start on spilling
1 parent a63b610 commit 7efb8db

File tree

9 files changed

+169
-7
lines changed

9 files changed

+169
-7
lines changed

mypyc/analysis/dataflow.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Cast,
1818
ComparisonOp,
1919
ControlOp,
20+
DecRef,
2021
Extend,
2122
Float,
2223
FloatComparisonOp,
@@ -25,6 +26,7 @@
2526
GetAttr,
2627
GetElementPtr,
2728
Goto,
29+
IncRef,
2830
InitStatic,
2931
Integer,
3032
IntOp,
@@ -79,12 +81,11 @@ def __str__(self) -> str:
7981
return "\n".join(lines)
8082

8183

82-
def get_cfg(blocks: list[BasicBlock]) -> CFG:
84+
def get_cfg(blocks: list[BasicBlock], *, use_yields: bool = False) -> CFG:
8385
"""Calculate basic block control-flow graph.
8486
85-
The result is a dictionary like this:
86-
87-
basic block index -> (successors blocks, predecesssor blocks)
87+
If use_yields is set, then we treat returns inserted by yields as gotos
88+
instead of exits.
8889
"""
8990
succ_map = {}
9091
pred_map: dict[BasicBlock, list[BasicBlock]] = {}
@@ -94,7 +95,10 @@ def get_cfg(blocks: list[BasicBlock]) -> CFG:
9495
isinstance(op, ControlOp) for op in block.ops[:-1]
9596
), "Control-flow ops must be at the end of blocks"
9697

97-
succ = list(block.terminator.targets())
98+
if use_yields and isinstance(block.terminator, Return) and block.terminator.yield_target:
99+
succ = [block.terminator.yield_target]
100+
else:
101+
succ = list(block.terminator.targets())
98102
if not succ:
99103
exits.add(block)
100104

@@ -494,6 +498,12 @@ def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
494498
def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
495499
return non_trivial_sources(op), set()
496500

501+
def visit_inc_ref(self, op: IncRef) -> GenAndKill[Value]:
502+
return set(), set()
503+
504+
def visit_dec_ref(self, op: DecRef) -> GenAndKill[Value]:
505+
return set(), set()
506+
497507

498508
def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
499509
"""Calculate live registers at each CFG location.

mypyc/codegen/emitmodule.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from mypyc.options import CompilerOptions
5959
from mypyc.transform.exceptions import insert_exception_handling
6060
from mypyc.transform.refcount import insert_ref_count_opcodes
61+
from mypyc.transform.spill import insert_spills
6162
from mypyc.transform.uninit import insert_uninit_checks
6263

6364
# All of the modules being compiled are divided into "groups". A group
@@ -225,6 +226,10 @@ def compile_scc_to_ir(
225226
if errors.num_errors > 0:
226227
return modules
227228

229+
# XXX: HOW WILL WE DEAL WITH REFCOUNTING ON THE SPILLAGE
230+
# DO WE DO IT... LAST? MAYBE MAYBE MAYBE YES
231+
# ONLY DO UNINIT.... YEAH OK
232+
228233
# Insert uninit checks.
229234
for module in modules.values():
230235
for fn in module.functions:
@@ -237,6 +242,10 @@ def compile_scc_to_ir(
237242
for module in modules.values():
238243
for fn in module.functions:
239244
insert_ref_count_opcodes(fn)
245+
for module in modules.values():
246+
for cls in module.classes:
247+
if cls.env_user_function:
248+
insert_spills(cls.env_user_function, cls)
240249

241250
return modules
242251

mypyc/ir/class_ir.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ def __init__(
194194
# value of an attribute is the same as the error value.
195195
self.bitmap_attrs: list[str] = []
196196

197+
# If this is a generator environment class, what is the actual method for it
198+
self.env_user_function: FuncIR | None = None
199+
197200
def __repr__(self) -> str:
198201
return (
199202
"ClassIR("
@@ -391,6 +394,7 @@ def serialize(self) -> JsonDict:
391394
"_always_initialized_attrs": sorted(self._always_initialized_attrs),
392395
"_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs),
393396
"init_self_leak": self.init_self_leak,
397+
"env_user_function": self.env_user_function.id if self.env_user_function else None,
394398
}
395399

396400
@classmethod
@@ -442,6 +446,9 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
442446
ir._always_initialized_attrs = set(data["_always_initialized_attrs"])
443447
ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"])
444448
ir.init_self_leak = data["init_self_leak"]
449+
ir.env_user_function = (
450+
ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
451+
)
445452

446453
return ir
447454

mypyc/ir/ops.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,16 @@ class Return(ControlOp):
430430

431431
error_kind = ERR_NEVER
432432

433-
def __init__(self, value: Value, line: int = -1) -> None:
433+
def __init__(
434+
self, value: Value, line: int = -1, *, yield_target: BasicBlock | None = None
435+
) -> None:
434436
super().__init__(line)
435437
self.value = value
438+
# If this return is created by a yield, keep track of the next
439+
# basic block. This doesn't affect the code we generate but
440+
# can feed into analysis that need to understand the
441+
# *original* CFG.
442+
self.yield_target = yield_target
436443

437444
def sources(self) -> list[Value]:
438445
return [self.value]

mypyc/irbuild/function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def c() -> None:
269269
# Re-enter the FuncItem and visit the body of the function this time.
270270
builder.enter(fn_info)
271271
setup_env_for_generator_class(builder)
272+
272273
load_outer_envs(builder, builder.fn_info.generator_class)
273274
top_level = builder.top_level_fn_info()
274275
if (

mypyc/irbuild/generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def add_helper_to_generator_class(
181181
)
182182
fn_info.generator_class.ir.methods["__mypyc_generator_helper__"] = helper_fn_ir
183183
builder.functions.append(helper_fn_ir)
184+
fn_info.env_class.env_user_function = helper_fn_ir
185+
184186
return helper_fn_decl
185187

186188

mypyc/irbuild/statement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ def emit_yield(builder: IRBuilder, val: Value, line: int) -> Value:
894894
next_label = len(cls.continuation_blocks)
895895
cls.continuation_blocks.append(next_block)
896896
builder.assign(cls.next_label_target, Integer(next_label), line)
897-
builder.add(Return(retval))
897+
builder.add(Return(retval, yield_target=next_block))
898898
builder.activate_block(next_block)
899899

900900
add_raise_exception_blocks_to_generator_class(builder, line)

mypyc/test-data/run-generators.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,3 +679,21 @@ def test_basic() -> None:
679679
with context:
680680
assert context.x == 1
681681
assert context.x == 0
682+
683+
684+
[case testYieldSpill]
685+
from typing import Generator
686+
687+
def f() -> int:
688+
return 1
689+
690+
def yield_spill() -> Generator[str, int, int]:
691+
return f() + (yield "foo")
692+
693+
[file driver.py]
694+
from native import yield_spill
695+
from testutil import run_generator
696+
697+
yields, val = run_generator(yield_spill(), [2])
698+
assert yields == ('foo',)
699+
assert val == 3, val

mypyc/transform/spill.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Insert spills for values that are live across yields."""
2+
3+
from __future__ import annotations
4+
5+
from mypyc.analysis.dataflow import AnalysisResult, analyze_live_regs, get_cfg
6+
from mypyc.common import TEMP_ATTR_NAME
7+
from mypyc.ir.class_ir import ClassIR
8+
from mypyc.ir.func_ir import FuncIR
9+
from mypyc.ir.ops import (
10+
BasicBlock,
11+
Branch,
12+
DecRef,
13+
GetAttr,
14+
IncRef,
15+
LoadErrorValue,
16+
Register,
17+
SetAttr,
18+
Value,
19+
)
20+
21+
22+
def insert_spills(ir: FuncIR, env: ClassIR) -> None:
23+
cfg = get_cfg(ir.blocks, use_yields=True)
24+
live = analyze_live_regs(ir.blocks, cfg)
25+
entry_live = live.before[ir.blocks[0], 0]
26+
27+
# from mypyc.ir.pprint import format_func
28+
29+
# print('\n'.join(format_func(ir)))
30+
31+
entry_live = {op for op in entry_live if not (isinstance(op, Register) and op.is_arg)}
32+
# XXX: Actually for now, no Registers at all -- we keep the manual spills
33+
entry_live = {op for op in entry_live if not isinstance(op, Register)}
34+
35+
ir.blocks = spill_regs(ir.blocks, env, entry_live, live)
36+
# print("\n".join(format_func(ir)))
37+
# print("\n\n\n=========")
38+
39+
40+
def spill_regs(
41+
blocks: list[BasicBlock], env: ClassIR, to_spill: set[Value], live: AnalysisResult[Value]
42+
) -> list[BasicBlock]:
43+
for op in blocks[0].ops:
44+
if isinstance(op, GetAttr) and op.attr == "__mypyc_env__":
45+
env_reg = op
46+
break
47+
else:
48+
raise AssertionError("could not find __mypyc_env__")
49+
50+
spill_locs = {}
51+
for i, val in enumerate(to_spill):
52+
name = f"{TEMP_ATTR_NAME}2_{i}"
53+
env.attributes[name] = val.type
54+
spill_locs[val] = name
55+
56+
for block in blocks:
57+
ops = block.ops
58+
block.ops = []
59+
60+
for i, op in enumerate(ops):
61+
to_decref = []
62+
63+
if isinstance(op, IncRef) and op.src in spill_locs:
64+
raise AssertionError("not sure what to do with an incref of a spill...")
65+
if isinstance(op, DecRef) and op.src in spill_locs:
66+
# When we decref a spilled value, we turn that into
67+
# NULLing out the attribute, but only if the spilled
68+
# value is not live *when we include yields in the
69+
# CFG*. (The original decrefs are computed without that.)
70+
#
71+
# We also skip a decref is the env register is not
72+
# live. That should only happen when an exception is
73+
# being raised, so everything should be handled there.
74+
if op.src not in live.after[block, i] and env_reg in live.after[block, i]:
75+
# Skip the DecRef but null out the spilled location
76+
null = LoadErrorValue(op.src.type)
77+
block.ops.extend([null, SetAttr(env_reg, spill_locs[op.src], null, op.line)])
78+
continue
79+
80+
if (
81+
any(src in spill_locs for src in op.sources())
82+
# N.B: IS_ERROR should be before a spill happens
83+
# XXX: but could we have a regular branch?
84+
and not (isinstance(op, Branch) and op.op == Branch.IS_ERROR)
85+
):
86+
new_sources: list[Value] = []
87+
for src in op.sources():
88+
if src in spill_locs:
89+
read = GetAttr(env_reg, spill_locs[src], op.line)
90+
block.ops.append(read)
91+
new_sources.append(read)
92+
if src.type.is_refcounted:
93+
to_decref.append(read)
94+
else:
95+
new_sources.append(src)
96+
97+
op.set_sources(new_sources)
98+
99+
block.ops.append(op)
100+
101+
for dec in to_decref:
102+
block.ops.append(DecRef(dec))
103+
104+
if op in spill_locs:
105+
# XXX: could we set uninit?
106+
block.ops.append(SetAttr(env_reg, spill_locs[op], op, op.line))
107+
108+
return blocks

0 commit comments

Comments
 (0)