Skip to content

Commit 0e19fc9

Browse files
authored
[ty] defer calculating conjunctions in narrowing constraints (#23552)
1 parent 14bd2b2 commit 0e19fc9

File tree

4 files changed

+208
-79
lines changed

4 files changed

+208
-79
lines changed

crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ else:
3131
reveal_type(x) # revealed: Never
3232

3333
if x or not x:
34-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
34+
reveal_type(x) # revealed: Literal[-1, 0, "foo", "", b"bar", b""] | bool | None | tuple[()]
3535
else:
3636
reveal_type(x) # revealed: Never
3737

3838
if not (x or not x):
3939
reveal_type(x) # revealed: Never
4040
else:
41-
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
41+
reveal_type(x) # revealed: Literal[-1, 0, "foo", "", b"bar", b""] | bool | None | tuple[()]
4242

4343
if (isinstance(x, int) or isinstance(x, str)) and x:
4444
reveal_type(x) # revealed: Literal[-1, True, "foo"]

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -767,12 +767,11 @@ impl ReachabilityConstraintsBuilder {
767767

768768
/// AND a new optional narrowing constraint with an accumulated one.
769769
fn accumulate_constraint<'db>(
770-
db: &'db dyn Db,
771770
accumulated: Option<NarrowingConstraint<'db>>,
772771
new: Option<NarrowingConstraint<'db>>,
773772
) -> Option<NarrowingConstraint<'db>> {
774773
match (accumulated, new) {
775-
(Some(acc), Some(new_c)) => Some(new_c.merge_constraint_and(acc, db)),
774+
(Some(acc), Some(new_c)) => Some(new_c.merge_constraint_and(acc)),
776775
(None, Some(new_c)) => Some(new_c),
777776
(Some(acc), None) => Some(acc),
778777
(None, None) => None,
@@ -839,7 +838,7 @@ impl ReachabilityConstraints {
839838
// Apply all accumulated narrowing constraints to the base type
840839
match accumulated {
841840
Some(constraint) => NarrowingConstraint::intersection(base_ty)
842-
.merge_constraint_and(constraint, db)
841+
.merge_constraint_and(constraint)
843842
.evaluate_constraint_type(db),
844843
None => base_ty,
845844
}
@@ -888,7 +887,7 @@ impl ReachabilityConstraints {
888887
is_positive: !predicate.is_positive,
889888
};
890889
let neg_constraint = infer_narrowing_constraint(db, neg_predicate, place);
891-
let false_accumulated = accumulate_constraint(db, accumulated, neg_constraint);
890+
let false_accumulated = accumulate_constraint(accumulated, neg_constraint);
892891
return self.narrow_by_constraint_inner(
893892
db,
894893
predicates,
@@ -901,7 +900,7 @@ impl ReachabilityConstraints {
901900

902901
// If the false branch is statically unreachable, skip it entirely.
903902
if node.if_false == ALWAYS_FALSE {
904-
let true_accumulated = accumulate_constraint(db, accumulated, pos_constraint);
903+
let true_accumulated = accumulate_constraint(accumulated, pos_constraint);
905904
return self.narrow_by_constraint_inner(
906905
db,
907906
predicates,
@@ -913,8 +912,7 @@ impl ReachabilityConstraints {
913912
}
914913

915914
// True branch: predicate holds → accumulate positive narrowing
916-
let true_accumulated =
917-
accumulate_constraint(db, accumulated.clone(), pos_constraint);
915+
let true_accumulated = accumulate_constraint(accumulated.clone(), pos_constraint);
918916
let true_ty = self.narrow_by_constraint_inner(
919917
db,
920918
predicates,
@@ -930,7 +928,7 @@ impl ReachabilityConstraints {
930928
is_positive: !predicate.is_positive,
931929
};
932930
let neg_constraint = infer_narrowing_constraint(db, neg_predicate, place);
933-
let false_accumulated = accumulate_constraint(db, accumulated, neg_constraint);
931+
let false_accumulated = accumulate_constraint(accumulated, neg_constraint);
934932
let false_ty = self.narrow_by_constraint_inner(
935933
db,
936934
predicates,

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,87 @@ use crate::types::{
4545
use crate::{Db, FxOrderMap, FxOrderSet};
4646
use smallvec::SmallVec;
4747

48+
/// Extract `(core, guard)` from truthiness-guarded intersections.
49+
///
50+
/// e.g.
51+
/// - `A & ~AlwaysTruthy` -> `Some((A, ~AlwaysTruthy))`
52+
/// - `A & ~AlwaysFalsy` -> `Some((A, ~AlwaysFalsy))`
53+
/// - `A` -> `None`
54+
/// - `A & ~AlwaysTruthy & ~AlwaysFalsy` -> `None` (not a single-guard shape)
55+
///
56+
/// This only recognizes the "single truthiness guard" forms used by truthiness narrowing.
57+
fn split_truthiness_guarded_intersection<'db>(
58+
db: &'db dyn Db,
59+
ty: Type<'db>,
60+
) -> Option<(Type<'db>, Type<'db>)> {
61+
let Type::Intersection(intersection) = ty else {
62+
return None;
63+
};
64+
let falsy = Type::AlwaysTruthy.negate(db);
65+
let truthy = Type::AlwaysFalsy.negate(db);
66+
67+
let has_not_truthy = intersection.negative(db).contains(&Type::AlwaysTruthy);
68+
let has_not_falsy = intersection.negative(db).contains(&Type::AlwaysFalsy);
69+
let guard = match (has_not_truthy, has_not_falsy) {
70+
(true, false) => falsy,
71+
(false, true) => truthy,
72+
_ => return None,
73+
};
74+
75+
let mut core = IntersectionBuilder::new(db);
76+
for positive in intersection.positive(db) {
77+
core = core.add_positive(*positive);
78+
}
79+
for negative in intersection.negative(db) {
80+
if (guard == falsy && *negative == Type::AlwaysTruthy)
81+
|| (guard == truthy && *negative == Type::AlwaysFalsy)
82+
{
83+
continue;
84+
}
85+
core = core.add_negative(*negative);
86+
}
87+
Some((core.build(), guard))
88+
}
89+
90+
/// Try to merge a complementary guarded pair into an unguarded core.
91+
///
92+
/// e.g.
93+
/// - `(A & ~AlwaysTruthy, A & ~AlwaysFalsy)` -> `Some(A)`
94+
/// - `(A & ~AlwaysTruthy, B & ~AlwaysFalsy)` -> `Some(A | B)` if reconstruction is exact
95+
/// - `(A & ~AlwaysTruthy, C)` -> `None`
96+
///
97+
/// Safety rule:
98+
/// The candidate merge is accepted only if adding each original guard back reconstructs
99+
/// exactly the original operands (`left` and `right`).
100+
///
101+
/// TODO: This processing is specialized for `AlwaysTruthy/AlwaysFalsy`.
102+
/// It would be nice to generalize this in the future.
103+
/// Discussion: <https://github.com/astral-sh/ty/issues/224>
104+
fn merge_truthiness_guarded_pair<'db>(
105+
db: &'db dyn Db,
106+
left: Type<'db>,
107+
right: Type<'db>,
108+
) -> Option<Type<'db>> {
109+
let (left_core, left_guard) = split_truthiness_guarded_intersection(db, left)?;
110+
let (right_core, right_guard) = split_truthiness_guarded_intersection(db, right)?;
111+
if left_guard == right_guard {
112+
return None;
113+
}
114+
115+
if left_core.is_equivalent_to(db, right_core) {
116+
return Some(left_core);
117+
}
118+
119+
let candidate = UnionType::from_elements(db, [left_core, right_core]);
120+
let left_reconstructed = IntersectionType::from_two_elements(db, candidate, left_guard);
121+
let right_reconstructed = IntersectionType::from_two_elements(db, candidate, right_guard);
122+
if left_reconstructed == left && right_reconstructed == right {
123+
Some(candidate)
124+
} else {
125+
None
126+
}
127+
}
128+
48129
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
49130
enum LiteralKind<'db> {
50131
Int,
@@ -650,10 +731,13 @@ impl<'db> UnionBuilder<'db> {
650731
}
651732

652733
fn push_type(&mut self, ty: Type<'db>, seen_aliases: &mut Vec<Type<'db>>) {
653-
let bool_pair = if let Some(LiteralValueTypeKind::Bool(b)) = ty.as_literal_value_kind() {
654-
Some(LiteralValueTypeKind::Bool(!b))
655-
} else {
656-
None
734+
let mut ty = ty;
735+
let bool_pair = |ty: Type<'db>| {
736+
if let Some(LiteralValueTypeKind::Bool(b)) = ty.as_literal_value_kind() {
737+
Some(LiteralValueTypeKind::Bool(!b))
738+
} else {
739+
None
740+
}
657741
};
658742

659743
// If an alias gets here, it means we aren't unpacking aliases, and we also
@@ -686,9 +770,16 @@ impl<'db> UnionBuilder<'db> {
686770
return;
687771
}
688772

773+
// Fold `(T & ~AlwaysTruthy) | (T & ~AlwaysFalsy)` to `T`.
774+
if let Some(merged_type) = merge_truthiness_guarded_pair(self.db, ty, element_type) {
775+
to_remove.push(i);
776+
ty = merged_type;
777+
continue;
778+
}
779+
689780
if element_type
690781
.as_literal_value_kind()
691-
.zip(bool_pair)
782+
.zip(bool_pair(ty))
692783
.is_some_and(|(element, pair)| element == pair)
693784
{
694785
self.add_in_place_impl(KnownClass::Bool.to_instance(self.db), seen_aliases);

0 commit comments

Comments
 (0)