Skip to content

Commit 19fe882

Browse files
committed
[ty] Fix normalization of unions containing instances generic over unions
1 parent 9aa6330 commit 19fe882

File tree

6 files changed

+44
-12
lines changed

6 files changed

+44
-12
lines changed

crates/ty_python_semantic/resources/mdtest/directives/cast.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,8 @@ def f(x: Any, y: Unknown, z: Any | str | int):
7373
c = cast(Unknown, y)
7474
reveal_type(c) # revealed: Unknown
7575

76-
d = cast(str | int | Any, z) # error: [redundant-cast]
76+
d = cast(Unknown, x)
77+
reveal_type(d) # revealed: Unknown
78+
79+
e = cast(str | int | Any, z) # error: [redundant-cast]
7780
```

crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,23 @@ class R: ...
118118
static_assert(is_equivalent_to(Intersection[tuple[P | Q], R], Intersection[tuple[Q | P], R]))
119119
```
120120

121+
## Unions containing generic instances parameterized by unions
122+
123+
```toml
124+
[environment]
125+
python-version = "3.12"
126+
```
127+
128+
```py
129+
from ty_extensions import is_equivalent_to, static_assert
130+
131+
class A: ...
132+
class B: ...
133+
class Foo[T]: ...
134+
135+
static_assert(is_equivalent_to(A | Foo[A | B], Foo[B | A] | A))
136+
```
137+
121138
## Callable
122139

123140
### Equivalent

crates/ty_python_semantic/src/types.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -985,15 +985,15 @@ impl<'db> Type<'db> {
985985
Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)),
986986
Type::Callable(callable) => Type::Callable(callable.normalized(db)),
987987
Type::ProtocolInstance(protocol) => protocol.normalized(db),
988+
Type::NominalInstance(instance) => Type::NominalInstance(instance.normalized(db)),
989+
Type::Dynamic(_) => Type::any(),
988990
Type::LiteralString
989-
| Type::NominalInstance(_)
990991
| Type::PropertyInstance(_)
991992
| Type::AlwaysFalsy
992993
| Type::AlwaysTruthy
993994
| Type::BooleanLiteral(_)
994995
| Type::BytesLiteral(_)
995996
| Type::StringLiteral(_)
996-
| Type::Dynamic(_)
997997
| Type::Never
998998
| Type::FunctionLiteral(_)
999999
| Type::MethodWrapper(_)
@@ -1007,10 +1007,7 @@ impl<'db> Type<'db> {
10071007
| Type::IntLiteral(_)
10081008
| Type::BoundSuper(_)
10091009
| Type::SubclassOf(_) => self,
1010-
Type::GenericAlias(generic) => {
1011-
let specialization = generic.specialization(db).normalized(db);
1012-
Type::GenericAlias(GenericAlias::new(db, generic.origin(db), specialization))
1013-
}
1010+
Type::GenericAlias(generic) => Type::GenericAlias(generic.normalized(db)),
10141011
Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) {
10151012
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
10161013
Type::TypeVar(TypeVarInstance::new(

crates/ty_python_semantic/src/types/class.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ pub struct GenericAlias<'db> {
164164
}
165165

166166
impl<'db> GenericAlias<'db> {
167+
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
168+
Self::new(db, self.origin(db), self.specialization(db).normalized(db))
169+
}
170+
167171
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
168172
self.origin(db).definition(db)
169173
}
@@ -207,6 +211,13 @@ pub enum ClassType<'db> {
207211

208212
#[salsa::tracked]
209213
impl<'db> ClassType<'db> {
214+
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
215+
match self {
216+
Self::NonGeneric(_) => self,
217+
Self::Generic(generic) => Self::Generic(generic.normalized(db)),
218+
}
219+
}
220+
210221
/// Returns the class literal and specialization for this class. For a non-generic class, this
211222
/// is the class itself. For a generic alias, this is the alias's origin.
212223
pub(crate) fn class_literal(

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5048,15 +5048,15 @@ impl<'db> TypeInferenceBuilder<'db> {
50485048
overload.parameter_types()
50495049
{
50505050
let db = self.db();
5051+
let contains_unknown_or_todo = |ty| matches!(ty, Type::Dynamic(dynamic) if dynamic != DynamicType::Any);
50515052
if (source_type.is_equivalent_to(db, *casted_type)
50525053
|| source_type.normalized(db)
50535054
== casted_type.normalized(db))
5055+
&& !casted_type.any_over_type(db, &|ty| {
5056+
contains_unknown_or_todo(ty)
5057+
})
50545058
&& !source_type.any_over_type(db, &|ty| {
5055-
matches!(
5056-
ty,
5057-
Type::Dynamic(dynamic)
5058-
if dynamic != DynamicType::Any
5059-
)
5059+
contains_unknown_or_todo(ty)
50605060
})
50615061
{
50625062
if let Some(builder) = self

crates/ty_python_semantic/src/types/instance.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ impl<'db> NominalInstanceType<'db> {
7575
}
7676
}
7777

78+
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
79+
Self::from_class(self.class.normalized(db))
80+
}
81+
7882
pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool {
7983
// N.B. The subclass relation is fully static
8084
self.class.is_subclass_of(db, other.class)

0 commit comments

Comments
 (0)