Skip to content

Commit d2a3e66

Browse files
authored
Fix crash with PartialTypes and the enum plugin (#14021)
Fixes #12109. The original issue reported that the bug had to do with the use of the `--follow-imports=skip` flag. However, it turned out this was a red herring after closer inspection: I was able to trigger a more minimal repro both with and without this flag: ```python from enum import Enum class Foo(Enum): a = [] # E: Need type annotation for "a" (hint: "a: List[<type>] = ...") b = None def check(self) -> None: reveal_type(Foo.a.value) # N: Revealed type is "<partial list[?]>" reveal_type(Foo.b.value) # N: Revealed type is "<partial None>" ``` The first two `reveal_types` demonstrate the crux of the bug: the enum plugin does not correctly handle and convert partial types into regular types when inferring the type of the `.value` field. This can then cause any number of downstream problems. For example, suppose we modify `def check(...)` so it runs `reveal_type(self.value)`. Doing this will trigger a crash in mypy because it makes the enum plugin eventually try running `is_equivalent(...)` on the two partial types. But `is_equivalent` does not support partial types, so we crash. I opted to solve this problem by: 1. Making the enum plugin explicitly call the `fixup_partial_types` function on all field types. This prevents the code from crashing. 2. Modifies mypy so that Final vars are never marked as being PartialTypes. Without this, `reveal_type(Foo.b.value)` would report a type of `Union[Any, None]` instead of just `None`. (Note that all enum fields are implicitly final).
1 parent e8de6d1 commit d2a3e66

File tree

5 files changed

+70
-28
lines changed

5 files changed

+70
-28
lines changed

mypy/checker.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@
159159
erase_to_bound,
160160
erase_to_union_or_bound,
161161
false_only,
162+
fixup_partial_type,
162163
function_type,
163164
get_type_vars,
164165
is_literal_type_like,
@@ -2738,8 +2739,8 @@ def check_assignment(
27382739
# None initializers preserve the partial None type.
27392740
return
27402741

2741-
if is_valid_inferred_type(rvalue_type):
2742-
var = lvalue_type.var
2742+
var = lvalue_type.var
2743+
if is_valid_inferred_type(rvalue_type, is_lvalue_final=var.is_final):
27432744
partial_types = self.find_partial_types(var)
27442745
if partial_types is not None:
27452746
if not self.current_node_deferred:
@@ -3687,7 +3688,10 @@ def infer_variable_type(
36873688
"""Infer the type of initialized variables from initializer type."""
36883689
if isinstance(init_type, DeletedType):
36893690
self.msg.deleted_as_rvalue(init_type, context)
3690-
elif not is_valid_inferred_type(init_type) and not self.no_partial_types:
3691+
elif (
3692+
not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
3693+
and not self.no_partial_types
3694+
):
36913695
# We cannot use the type of the initialization expression for full type
36923696
# inference (it's not specific enough), but we might be able to give
36933697
# partial type which will be made more specific later. A partial type
@@ -6114,7 +6118,7 @@ def enter_partial_types(
61146118
self.msg.need_annotation_for_var(var, context, self.options.python_version)
61156119
self.partial_reported.add(var)
61166120
if var.type:
6117-
fixed = self.fixup_partial_type(var.type)
6121+
fixed = fixup_partial_type(var.type)
61186122
var.invalid_partial_type = fixed != var.type
61196123
var.type = fixed
61206124

@@ -6145,20 +6149,7 @@ def handle_partial_var_type(
61456149
else:
61466150
# Defer the node -- we might get a better type in the outer scope
61476151
self.handle_cannot_determine_type(node.name, context)
6148-
return self.fixup_partial_type(typ)
6149-
6150-
def fixup_partial_type(self, typ: Type) -> Type:
6151-
"""Convert a partial type that we couldn't resolve into something concrete.
6152-
6153-
This means, for None we make it Optional[Any], and for anything else we
6154-
fill in all of the type arguments with Any.
6155-
"""
6156-
if not isinstance(typ, PartialType):
6157-
return typ
6158-
if typ.type is None:
6159-
return UnionType.make_union([AnyType(TypeOfAny.unannotated), NoneType()])
6160-
else:
6161-
return Instance(typ.type, [AnyType(TypeOfAny.unannotated)] * len(typ.type.type_vars))
6152+
return fixup_partial_type(typ)
61626153

61636154
def is_defined_in_base_class(self, var: Var) -> bool:
61646155
if var.info:
@@ -7006,20 +6997,27 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> tuple[bool, st
70066997
return False, method
70076998

70086999

7009-
def is_valid_inferred_type(typ: Type) -> bool:
7010-
"""Is an inferred type valid?
7000+
def is_valid_inferred_type(typ: Type, is_lvalue_final: bool = False) -> bool:
7001+
"""Is an inferred type valid and needs no further refinement?
70117002
7012-
Examples of invalid types include the None type or List[<uninhabited>].
7003+
Examples of invalid types include the None type (when we are not assigning
7004+
None to a final lvalue) or List[<uninhabited>].
70137005
70147006
When not doing strict Optional checking, all types containing None are
70157007
invalid. When doing strict Optional checking, only None and types that are
70167008
incompletely defined (i.e. contain UninhabitedType) are invalid.
70177009
"""
7018-
if isinstance(get_proper_type(typ), (NoneType, UninhabitedType)):
7019-
# With strict Optional checking, we *may* eventually infer NoneType when
7020-
# the initializer is None, but we only do that if we can't infer a
7021-
# specific Optional type. This resolution happens in
7022-
# leave_partial_types when we pop a partial types scope.
7010+
proper_type = get_proper_type(typ)
7011+
if isinstance(proper_type, NoneType):
7012+
# If the lvalue is final, we may immediately infer NoneType when the
7013+
# initializer is None.
7014+
#
7015+
# If not, we want to defer making this decision. The final inferred
7016+
# type could either be NoneType or an Optional type, depending on
7017+
# the context. This resolution happens in leave_partial_types when
7018+
# we pop a partial types scope.
7019+
return is_lvalue_final
7020+
elif isinstance(proper_type, UninhabitedType):
70237021
return False
70247022
return not typ.accept(NothingSeeker())
70257023

mypy/checkexpr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
custom_special_method,
112112
erase_to_union_or_bound,
113113
false_only,
114+
fixup_partial_type,
114115
function_type,
115116
is_literal_type_like,
116117
make_simplified_union,
@@ -2925,7 +2926,7 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Type | None:
29252926
if isinstance(expr.node, Var):
29262927
result = self.analyze_var_ref(expr.node, expr)
29272928
if isinstance(result, PartialType) and result.type is not None:
2928-
self.chk.store_type(expr, self.chk.fixup_partial_type(result))
2929+
self.chk.store_type(expr, fixup_partial_type(result))
29292930
return result
29302931
return None
29312932

mypy/plugins/enums.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from mypy.nodes import TypeInfo
2020
from mypy.semanal_enum import ENUM_BASES
2121
from mypy.subtypes import is_equivalent
22-
from mypy.typeops import make_simplified_union
22+
from mypy.typeops import fixup_partial_type, make_simplified_union
2323
from mypy.types import CallableType, Instance, LiteralType, ProperType, Type, get_proper_type
2424

2525
ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | {
@@ -77,6 +77,7 @@ def _infer_value_type_with_auto_fallback(
7777
"""
7878
if proper_type is None:
7979
return None
80+
proper_type = get_proper_type(fixup_partial_type(proper_type))
8081
if not (isinstance(proper_type, Instance) and proper_type.type.fullname == "enum.auto"):
8182
return proper_type
8283
assert isinstance(ctx.type, Instance), "An incorrect ctx.type was passed."

mypy/typeops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Overloaded,
4242
Parameters,
4343
ParamSpecType,
44+
PartialType,
4445
ProperType,
4546
TupleType,
4647
Type,
@@ -1016,3 +1017,17 @@ def try_getting_instance_fallback(typ: Type) -> Instance | None:
10161017
elif isinstance(typ, TypeVarType):
10171018
return try_getting_instance_fallback(typ.upper_bound)
10181019
return None
1020+
1021+
1022+
def fixup_partial_type(typ: Type) -> Type:
1023+
"""Convert a partial type that we couldn't resolve into something concrete.
1024+
1025+
This means, for None we make it Optional[Any], and for anything else we
1026+
fill in all of the type arguments with Any.
1027+
"""
1028+
if not isinstance(typ, PartialType):
1029+
return typ
1030+
if typ.type is None:
1031+
return UnionType.make_union([AnyType(TypeOfAny.unannotated), NoneType()])
1032+
else:
1033+
return Instance(typ.type, [AnyType(TypeOfAny.unannotated)] * len(typ.type.type_vars))

test-data/unit/check-enum.test

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2100,3 +2100,30 @@ class Some:
21002100
class A(Some, Enum):
21012101
__labels__ = {1: "1"}
21022102
[builtins fixtures/dict.pyi]
2103+
2104+
[case testEnumWithPartialTypes]
2105+
from enum import Enum
2106+
2107+
class Mixed(Enum):
2108+
a = [] # E: Need type annotation for "a" (hint: "a: List[<type>] = ...")
2109+
b = None
2110+
2111+
def check(self) -> None:
2112+
reveal_type(Mixed.a.value) # N: Revealed type is "builtins.list[Any]"
2113+
reveal_type(Mixed.b.value) # N: Revealed type is "None"
2114+
2115+
# Inferring Any here instead of a union seems to be a deliberate
2116+
# choice; see the testEnumValueInhomogenous case above.
2117+
reveal_type(self.value) # N: Revealed type is "Any"
2118+
2119+
for field in Mixed:
2120+
reveal_type(field.value) # N: Revealed type is "Any"
2121+
if field.value is None:
2122+
pass
2123+
2124+
class AllPartialList(Enum):
2125+
a = [] # E: Need type annotation for "a" (hint: "a: List[<type>] = ...")
2126+
b = [] # E: Need type annotation for "b" (hint: "b: List[<type>] = ...")
2127+
2128+
def check(self) -> None:
2129+
reveal_type(self.value) # N: Revealed type is "builtins.list[Any]"

0 commit comments

Comments
 (0)