Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2483,6 +2483,8 @@ def visit_class_def(self, defn: ClassDef) -> None:
context=defn,
code=codes.TYPE_VAR,
)
if typ.defn.type_vars:
self.check_typevar_defaults(typ.defn.type_vars)

if typ.is_protocol and typ.defn.type_vars:
self.check_protocol_variance(defn)
Expand Down Expand Up @@ -2546,6 +2548,15 @@ def check_init_subclass(self, defn: ClassDef) -> None:
# all other bases have already been checked.
break

def check_typevar_defaults(self, tvars: list[TypeVarLikeType]) -> None:
for tv in tvars:
if not (isinstance(tv, TypeVarType) and tv.has_default()):
continue
if not is_subtype(tv.default, tv.upper_bound):
self.fail("TypeVar default must be a subtype of the bound type", tv)
if tv.values and not any(tv.default == value for value in tv.values):
self.fail("TypeVar default must be one of the constraint types", tv)

def check_enum(self, defn: ClassDef) -> None:
assert defn.info.is_enum
if defn.info.fullname not in ENUM_BASES:
Expand Down Expand Up @@ -5365,6 +5376,9 @@ def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var,
del type_map[expr]

def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
if o.alias_node:
self.check_typevar_defaults(o.alias_node.alias_tvars)

with self.msg.filter_errors():
self.expr_checker.accept(o.value)

Expand Down
22 changes: 11 additions & 11 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,19 +1196,17 @@ def validate_type_param(self, type_param: ast_TypeVar) -> None:
def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
explicit_type_params = []
for p in type_params:
bound = None
bound: Type | None = None
values: list[Type] = []
if sys.version_info >= (3, 13) and p.default_value is not None:
self.fail(
message_registry.TYPE_PARAM_DEFAULT_NOT_SUPPORTED,
p.lineno,
p.col_offset,
blocker=False,
)
default: Type | None = None
if sys.version_info >= (3, 13):
default = TypeConverter(self.errors, line=p.lineno).visit(p.default_value)
if isinstance(p, ast_ParamSpec): # type: ignore[misc]
explicit_type_params.append(TypeParam(p.name, PARAM_SPEC_KIND, None, []))
explicit_type_params.append(TypeParam(p.name, PARAM_SPEC_KIND, None, [], default))
elif isinstance(p, ast_TypeVarTuple): # type: ignore[misc]
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_TUPLE_KIND, None, []))
explicit_type_params.append(
TypeParam(p.name, TYPE_VAR_TUPLE_KIND, None, [], default)
)
else:
if isinstance(p.bound, ast3.Tuple):
if len(p.bound.elts) < 2:
Expand All @@ -1224,7 +1222,9 @@ def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
elif p.bound is not None:
self.validate_type_param(p)
bound = TypeConverter(self.errors, line=p.lineno).visit(p.bound)
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_KIND, bound, values))
explicit_type_params.append(
TypeParam(p.name, TYPE_VAR_KIND, bound, values, default)
)
return explicit_type_params

# Return(expr? value)
Expand Down
5 changes: 0 additions & 5 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,3 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage(
"Await expression cannot be used within a type alias", codes.SYNTAX
)

TYPE_PARAM_DEFAULT_NOT_SUPPORTED: Final = ErrorMessage(
"Type parameter default types not supported when using Python 3.12 type parameter syntax",
codes.MISC,
)
10 changes: 7 additions & 3 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,19 +670,21 @@ def set_line(


class TypeParam:
__slots__ = ("name", "kind", "upper_bound", "values")
__slots__ = ("name", "kind", "upper_bound", "values", "default")

def __init__(
self,
name: str,
kind: int,
upper_bound: mypy.types.Type | None,
values: list[mypy.types.Type],
default: mypy.types.Type | None,
) -> None:
self.name = name
self.kind = kind
self.upper_bound = upper_bound
self.values = values
self.default = default


FUNCITEM_FLAGS: Final = FUNCBASE_FLAGS + [
Expand Down Expand Up @@ -782,7 +784,7 @@ class FuncDef(FuncItem, SymbolNode, Statement):
"deco_line",
"is_trivial_body",
"is_mypy_only",
# Present only when a function is decorated with @typing.datasclass_transform or similar
# Present only when a function is decorated with @typing.dataclass_transform or similar
"dataclass_transform_spec",
"docstring",
"deprecated",
Expand Down Expand Up @@ -1657,21 +1659,23 @@ def accept(self, visitor: StatementVisitor[T]) -> T:


class TypeAliasStmt(Statement):
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias")
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias", "alias_node")

__match_args__ = ("name", "type_args", "value")

name: NameExpr
type_args: list[TypeParam]
value: LambdaExpr # Return value will get translated into a type
invalid_recursive_alias: bool
alias_node: TypeAlias | None

def __init__(self, name: NameExpr, type_args: list[TypeParam], value: LambdaExpr) -> None:
super().__init__()
self.name = name
self.type_args = type_args
self.value = value
self.invalid_recursive_alias = False
self.alias_node = None

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_type_alias_stmt(self)
Expand Down
81 changes: 57 additions & 24 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,7 +1808,26 @@ def analyze_type_param(
upper_bound = self.named_type("builtins.tuple", [self.object_type()])
else:
upper_bound = self.object_type()
default = AnyType(TypeOfAny.from_omitted_generics)
if type_param.default:
default = self.anal_type(
type_param.default,
allow_placeholder=True,
allow_unbound_tvars=True,
report_invalid_types=False,
allow_param_spec_literals=type_param.kind == PARAM_SPEC_KIND,
allow_tuple_literal=type_param.kind == PARAM_SPEC_KIND,
allow_unpack=type_param.kind == TYPE_VAR_TUPLE_KIND,
)
if default is None:
default = PlaceholderType(None, [], context.line)
elif type_param.kind == TYPE_VAR_KIND:
default = self.check_typevar_default(default, type_param.default)
elif type_param.kind == PARAM_SPEC_KIND:
default = self.check_paramspec_default(default, type_param.default)
elif type_param.kind == TYPE_VAR_TUPLE_KIND:
default = self.check_typevartuple_default(default, type_param.default)
else:
default = AnyType(TypeOfAny.from_omitted_generics)
if type_param.kind == TYPE_VAR_KIND:
values = []
if type_param.values:
Expand Down Expand Up @@ -4615,6 +4634,40 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool:
self.add_symbol(name, call.analyzed, s)
return True

def check_typevar_default(self, default: Type, context: Context) -> Type:
typ = get_proper_type(default)
if isinstance(typ, AnyType) and typ.is_from_error:
self.fail(
message_registry.TYPEVAR_ARG_MUST_BE_TYPE.format("TypeVar", "default"), context
)
return default

def check_paramspec_default(self, default: Type, context: Context) -> Type:
typ = get_proper_type(default)
if isinstance(typ, Parameters):
for i, arg_type in enumerate(typ.arg_types):
arg_ptype = get_proper_type(arg_type)
if isinstance(arg_ptype, AnyType) and arg_ptype.is_from_error:
self.fail(f"Argument {i} of ParamSpec default must be a type", context)
elif (
isinstance(typ, AnyType)
and typ.is_from_error
or not isinstance(typ, (AnyType, UnboundType))
):
self.fail(
"The default argument to ParamSpec must be a list expression, ellipsis, or a ParamSpec",
context,
)
default = AnyType(TypeOfAny.from_error)
return default

def check_typevartuple_default(self, default: Type, context: Context) -> Type:
typ = get_proper_type(default)
if not isinstance(typ, UnpackType):
self.fail("The default argument to TypeVarTuple must be an Unpacked tuple", context)
default = AnyType(TypeOfAny.from_error)
return default

def check_typevarlike_name(self, call: CallExpr, name: str, context: Context) -> bool:
"""Checks that the name of a TypeVar or ParamSpec matches its variable."""
name = unmangle(name)
Expand Down Expand Up @@ -4822,23 +4875,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool:
report_invalid_typevar_arg=False,
)
default = tv_arg or AnyType(TypeOfAny.from_error)
if isinstance(tv_arg, Parameters):
for i, arg_type in enumerate(tv_arg.arg_types):
typ = get_proper_type(arg_type)
if isinstance(typ, AnyType) and typ.is_from_error:
self.fail(
f"Argument {i} of ParamSpec default must be a type", param_value
)
elif (
isinstance(default, AnyType)
and default.is_from_error
or not isinstance(default, (AnyType, UnboundType))
):
self.fail(
"The default argument to ParamSpec must be a list expression, ellipsis, or a ParamSpec",
param_value,
)
default = AnyType(TypeOfAny.from_error)
default = self.check_paramspec_default(default, param_value)
else:
# ParamSpec is different from a regular TypeVar:
# arguments are not semantically valid. But, allowed in runtime.
Expand Down Expand Up @@ -4899,12 +4936,7 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool:
allow_unpack=True,
)
default = tv_arg or AnyType(TypeOfAny.from_error)
if not isinstance(default, UnpackType):
self.fail(
"The default argument to TypeVarTuple must be an Unpacked tuple",
param_value,
)
default = AnyType(TypeOfAny.from_error)
default = self.check_typevartuple_default(default, param_value)
else:
self.fail(f'Unexpected keyword argument "{param_name}" for "TypeVarTuple"', s)

Expand Down Expand Up @@ -5503,6 +5535,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
eager=eager,
python_3_12_type_alias=True,
)
s.alias_node = alias_node

if (
existing
Expand Down
2 changes: 2 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ def type_param(self, p: mypy.nodes.TypeParam) -> list[Any]:
a.append(p.upper_bound)
if p.values:
a.append(("Values", p.values))
if p.default:
a.append(("Default", [p.default]))
return [("TypeParam", a)]

# Expressions
Expand Down
4 changes: 4 additions & 0 deletions mypy/test/testparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class ParserSuite(DataSuite):
files.remove("parse-python310.test")
if sys.version_info < (3, 12):
files.remove("parse-python312.test")
if sys.version_info < (3, 13):
files.remove("parse-python313.test")

def run_case(self, testcase: DataDrivenTestCase) -> None:
test_parser(testcase)
Expand All @@ -43,6 +45,8 @@ def test_parser(testcase: DataDrivenTestCase) -> None:
options.python_version = (3, 10)
elif testcase.file.endswith("python312.test"):
options.python_version = (3, 12)
elif testcase.file.endswith("python313.test"):
options.python_version = (3, 13)
else:
options.python_version = defaults.PYTHON3_VERSION

Expand Down
Loading