Skip to content

[mypyc] Support yields while values are live #16305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 22, 2025
20 changes: 15 additions & 5 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Cast,
ComparisonOp,
ControlOp,
DecRef,
Extend,
Float,
FloatComparisonOp,
Expand All @@ -25,6 +26,7 @@
GetAttr,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
Expand Down Expand Up @@ -77,12 +79,11 @@ def __str__(self) -> str:
return f"exits: {exits}\nsucc: {self.succ}\npred: {self.pred}"


def get_cfg(blocks: list[BasicBlock]) -> CFG:
def get_cfg(blocks: list[BasicBlock], *, use_yields: bool = False) -> CFG:
"""Calculate basic block control-flow graph.

The result is a dictionary like this:

basic block index -> (successors blocks, predecesssor blocks)
If use_yields is set, then we treat returns inserted by yields as gotos
instead of exits.
"""
succ_map = {}
pred_map: dict[BasicBlock, list[BasicBlock]] = {}
Expand All @@ -92,7 +93,10 @@ def get_cfg(blocks: list[BasicBlock]) -> CFG:
isinstance(op, ControlOp) for op in block.ops[:-1]
), "Control-flow ops must be at the end of blocks"

succ = list(block.terminator.targets())
if use_yields and isinstance(block.terminator, Return) and block.terminator.yield_target:
succ = [block.terminator.yield_target]
else:
succ = list(block.terminator.targets())
if not succ:
exits.add(block)

Expand Down Expand Up @@ -474,6 +478,12 @@ def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
return non_trivial_sources(op), set()

def visit_inc_ref(self, op: IncRef) -> GenAndKill[Value]:
return set(), set()

def visit_dec_ref(self, op: DecRef) -> GenAndKill[Value]:
return set(), set()


def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
"""Calculate live registers at each CFG location.
Expand Down
11 changes: 11 additions & 0 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.lower import lower_ir
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.spill import insert_spills
from mypyc.transform.uninit import insert_uninit_checks

# All of the modules being compiled are divided into "groups". A group
Expand Down Expand Up @@ -228,6 +229,12 @@ def compile_scc_to_ir(
if errors.num_errors > 0:
return modules

env_user_functions = {}
for module in modules.values():
for cls in module.classes:
if cls.env_user_function:
env_user_functions[cls.env_user_function] = cls

for module in modules.values():
for fn in module.functions:
# Insert uninit checks.
Expand All @@ -236,6 +243,10 @@ def compile_scc_to_ir(
insert_exception_handling(fn)
# Insert refcount handling.
insert_ref_count_opcodes(fn)

if fn in env_user_functions:
insert_spills(fn, env_user_functions[fn])

# Switch to lower abstraction level IR.
lower_ir(fn, compiler_options)
# Perform optimizations.
Expand Down
7 changes: 7 additions & 0 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def __init__(
# value of an attribute is the same as the error value.
self.bitmap_attrs: list[str] = []

# If this is a generator environment class, what is the actual method for it
self.env_user_function: FuncIR | None = None

def __repr__(self) -> str:
return (
"ClassIR("
Expand Down Expand Up @@ -394,6 +397,7 @@ def serialize(self) -> JsonDict:
"_always_initialized_attrs": sorted(self._always_initialized_attrs),
"_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs),
"init_self_leak": self.init_self_leak,
"env_user_function": self.env_user_function.id if self.env_user_function else None,
}

@classmethod
Expand Down Expand Up @@ -446,6 +450,9 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
ir._always_initialized_attrs = set(data["_always_initialized_attrs"])
ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"])
ir.init_self_leak = data["init_self_leak"]
ir.env_user_function = (
ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
)

return ir

Expand Down
Loading