Skip to content

Commit 9db891a

Browse files
authored
Unrolled build for #141724
Rollup merge of #141724 - Sol-Ell:issue-141141-fix, r=nnethercote fix(#141141): When expanding `PartialEq`, check equality of scalar types first. Fixes #141141. Now, `cs_eq` function of `partial_eq.rs` compares [scalar types](https://doc.rust-lang.org/rust-by-example/primitives.html#scalar-types) first. - Add `is_scalar` field to `FieldInfo`. - Add `is_scalar` method to `TyKind`. - Pass `FieldInfo` via `CsFold::Combine` and refactor code relying on it. - Implement `TryFrom<&str>` and `TryFrom<Symbol>` for FloatTy. - Implement `TryFrom<&str>` and `TryFrom<Symbol>` for IntTy. - Implement `TryFrom<&str>` and `TryFrom<Symbol>` for UintTy.
2 parents aae43c4 + a6a1c1b commit 9db891a

File tree

5 files changed

+390
-58
lines changed

5 files changed

+390
-58
lines changed

compiler/rustc_ast/src/ast.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2465,6 +2465,39 @@ impl TyKind {
24652465
None
24662466
}
24672467
}
2468+
2469+
/// Returns `true` if this type is considered a scalar primitive (e.g.,
2470+
/// `i32`, `u8`, `bool`, etc).
2471+
///
2472+
/// This check is based on **symbol equality** and does **not** remove any
2473+
/// path prefixes or references. If a type alias or shadowing is present
2474+
/// (e.g., `type i32 = CustomType;`), this method will still return `true`
2475+
/// for `i32`, even though it may not refer to the primitive type.
2476+
pub fn maybe_scalar(&self) -> bool {
2477+
let Some(ty_sym) = self.is_simple_path() else {
2478+
// unit type
2479+
return self.is_unit();
2480+
};
2481+
matches!(
2482+
ty_sym,
2483+
sym::i8
2484+
| sym::i16
2485+
| sym::i32
2486+
| sym::i64
2487+
| sym::i128
2488+
| sym::u8
2489+
| sym::u16
2490+
| sym::u32
2491+
| sym::u64
2492+
| sym::u128
2493+
| sym::f16
2494+
| sym::f32
2495+
| sym::f64
2496+
| sym::f128
2497+
| sym::char
2498+
| sym::bool
2499+
)
2500+
}
24682501
}
24692502

24702503
/// A pattern type pattern.

compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs

Lines changed: 158 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use crate::deriving::generic::ty::*;
88
use crate::deriving::generic::*;
99
use crate::deriving::{path_local, path_std};
1010

11+
/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
12+
/// target item.
1113
pub(crate) fn expand_deriving_partial_eq(
1214
cx: &ExtCtxt<'_>,
1315
span: Span,
@@ -16,62 +18,6 @@ pub(crate) fn expand_deriving_partial_eq(
1618
push: &mut dyn FnMut(Annotatable),
1719
is_const: bool,
1820
) {
19-
fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr {
20-
let base = true;
21-
let expr = cs_fold(
22-
true, // use foldl
23-
cx,
24-
span,
25-
substr,
26-
|cx, fold| match fold {
27-
CsFold::Single(field) => {
28-
let [other_expr] = &field.other_selflike_exprs[..] else {
29-
cx.dcx()
30-
.span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
31-
};
32-
33-
// We received arguments of type `&T`. Convert them to type `T` by stripping
34-
// any leading `&`. This isn't necessary for type checking, but
35-
// it results in better error messages if something goes wrong.
36-
//
37-
// Note: for arguments that look like `&{ x }`, which occur with packed
38-
// structs, this would cause expressions like `{ self.x } == { other.x }`,
39-
// which isn't valid Rust syntax. This wouldn't break compilation because these
40-
// AST nodes are constructed within the compiler. But it would mean that code
41-
// printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
42-
// syntax, which would be suboptimal. So we wrap these in parens, giving
43-
// `({ self.x }) == ({ other.x })`, which is valid syntax.
44-
let convert = |expr: &P<Expr>| {
45-
if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) =
46-
&expr.kind
47-
{
48-
if let ExprKind::Block(..) = &inner.kind {
49-
// `&{ x }` form: remove the `&`, add parens.
50-
cx.expr_paren(field.span, inner.clone())
51-
} else {
52-
// `&x` form: remove the `&`.
53-
inner.clone()
54-
}
55-
} else {
56-
expr.clone()
57-
}
58-
};
59-
cx.expr_binary(
60-
field.span,
61-
BinOpKind::Eq,
62-
convert(&field.self_expr),
63-
convert(other_expr),
64-
)
65-
}
66-
CsFold::Combine(span, expr1, expr2) => {
67-
cx.expr_binary(span, BinOpKind::And, expr1, expr2)
68-
}
69-
CsFold::Fieldless => cx.expr_bool(span, base),
70-
},
71-
);
72-
BlockOrExpr::new_expr(expr)
73-
}
74-
7521
let structural_trait_def = TraitDef {
7622
span,
7723
path: path_std!(marker::StructuralPartialEq),
@@ -97,7 +43,9 @@ pub(crate) fn expand_deriving_partial_eq(
9743
ret_ty: Path(path_local!(bool)),
9844
attributes: thin_vec![cx.attr_word(sym::inline, span)],
9945
fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
100-
combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))),
46+
combine_substructure: combine_substructure(Box::new(|a, b, c| {
47+
BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
48+
})),
10149
}];
10250

10351
let trait_def = TraitDef {
@@ -113,3 +61,156 @@ pub(crate) fn expand_deriving_partial_eq(
11361
};
11462
trait_def.expand(cx, mitem, item, push)
11563
}
64+
65+
/// Generates the equality expression for a struct or enum variant when deriving
66+
/// `PartialEq`.
67+
///
68+
/// This function generates an expression that checks if all fields of a struct
69+
/// or enum variant are equal.
70+
/// - Scalar fields are compared first for efficiency, followed by compound
71+
/// fields.
72+
/// - If there are no fields, returns `true` (fieldless types are always equal).
73+
///
74+
/// Whether a field is considered "scalar" is determined by comparing the symbol
75+
/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
76+
/// This check is based on the type's symbol.
77+
///
78+
/// ### Example 1
79+
/// ```
80+
/// #[derive(PartialEq)]
81+
/// struct i32;
82+
///
83+
/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
84+
/// // the primitive), it will not be treated as scalar. The function will still
85+
/// // check equality of `field_2` first because the symbol matches `i32`.
86+
/// #[derive(PartialEq)]
87+
/// struct Struct {
88+
/// field_1: &'static str,
89+
/// field_2: i32,
90+
/// }
91+
/// ```
92+
///
93+
/// ### Example 2
94+
/// ```
95+
/// mod ty {
96+
/// pub type i32 = i32;
97+
/// }
98+
///
99+
/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
100+
/// // However, the function will not reorder the fields because the symbol for
101+
/// // `ty::i32` does not match the symbol for the primitive `i32`
102+
/// // ("ty::i32" != "i32").
103+
/// #[derive(PartialEq)]
104+
/// struct Struct {
105+
/// field_1: &'static str,
106+
/// field_2: ty::i32,
107+
/// }
108+
/// ```
109+
///
110+
/// For enums, the discriminant is compared first, then the rest of the fields.
111+
///
112+
/// # Panics
113+
///
114+
/// If called on static or all-fieldless enums/structs, which should not occur
115+
/// during derive expansion.
116+
fn get_substructure_equality_expr(
117+
cx: &ExtCtxt<'_>,
118+
span: Span,
119+
substructure: &Substructure<'_>,
120+
) -> P<Expr> {
121+
use SubstructureFields::*;
122+
123+
match substructure.fields {
124+
EnumMatching(.., fields) | Struct(.., fields) => {
125+
let combine = move |acc, field| {
126+
let rhs = get_field_equality_expr(cx, field);
127+
if let Some(lhs) = acc {
128+
// Combine the previous comparison with the current field
129+
// using logical AND.
130+
return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
131+
}
132+
// Start the chain with the first field's comparison.
133+
Some(rhs)
134+
};
135+
136+
// First compare scalar fields, then compound fields, combining all
137+
// with logical AND.
138+
return fields
139+
.iter()
140+
.filter(|field| !field.maybe_scalar)
141+
.fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine)
142+
// If there are no fields, treat as always equal.
143+
.unwrap_or_else(|| cx.expr_bool(span, true));
144+
}
145+
EnumDiscr(disc, match_expr) => {
146+
let lhs = get_field_equality_expr(cx, disc);
147+
let Some(match_expr) = match_expr else {
148+
return lhs;
149+
};
150+
// Compare the discriminant first (cheaper), then the rest of the
151+
// fields.
152+
return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
153+
}
154+
StaticEnum(..) => cx.dcx().span_bug(
155+
span,
156+
"unexpected static enum encountered during `derive(PartialEq)` expansion",
157+
),
158+
StaticStruct(..) => cx.dcx().span_bug(
159+
span,
160+
"unexpected static struct encountered during `derive(PartialEq)` expansion",
161+
),
162+
AllFieldlessEnum(..) => cx.dcx().span_bug(
163+
span,
164+
"unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
165+
),
166+
}
167+
}
168+
169+
/// Generates an equality comparison expression for a single struct or enum
170+
/// field.
171+
///
172+
/// This function produces an AST expression that compares the `self` and
173+
/// `other` values for a field using `==`. It removes any leading references
174+
/// from both sides for readability. If the field is a block expression, it is
175+
/// wrapped in parentheses to ensure valid syntax.
176+
///
177+
/// # Panics
178+
///
179+
/// Panics if there are not exactly two arguments to compare (should be `self`
180+
/// and `other`).
181+
fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> {
182+
let [rhs] = &field.other_selflike_exprs[..] else {
183+
cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
184+
};
185+
186+
cx.expr_binary(
187+
field.span,
188+
BinOpKind::Eq,
189+
wrap_block_expr(cx, peel_refs(&field.self_expr)),
190+
wrap_block_expr(cx, peel_refs(rhs)),
191+
)
192+
}
193+
194+
/// Removes all leading immutable references from an expression.
195+
///
196+
/// This is used to strip away any number of leading `&` from an expression
197+
/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
198+
/// references are preserved.
199+
fn peel_refs(mut expr: &P<Expr>) -> P<Expr> {
200+
while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
201+
expr = &inner;
202+
}
203+
expr.clone()
204+
}
205+
206+
/// Wraps a block expression in parentheses to ensure valid AST in macro
207+
/// expansion output.
208+
///
209+
/// If the given expression is a block, it is wrapped in parentheses; otherwise,
210+
/// it is returned unchanged.
211+
fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> {
212+
if matches!(&expr.kind, ExprKind::Block(..)) {
213+
return cx.expr_paren(expr.span, expr);
214+
}
215+
expr
216+
}

compiler/rustc_builtin_macros/src/deriving/generic/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ pub(crate) struct FieldInfo {
284284
/// The expressions corresponding to references to this field in
285285
/// the other selflike arguments.
286286
pub other_selflike_exprs: Vec<P<Expr>>,
287+
pub maybe_scalar: bool,
287288
}
288289

289290
#[derive(Copy, Clone)]
@@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> {
12201221

12211222
let self_expr = discr_exprs.remove(0);
12221223
let other_selflike_exprs = discr_exprs;
1223-
let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs };
1224+
let discr_field =
1225+
FieldInfo { span, name: None, self_expr, other_selflike_exprs, maybe_scalar: true };
12241226

12251227
let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args)
12261228
.map(|(&ident, selflike_arg)| {
@@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> {
15331535
name: struct_field.ident,
15341536
self_expr,
15351537
other_selflike_exprs,
1538+
maybe_scalar: struct_field.ty.peel_refs().kind.maybe_scalar(),
15361539
}
15371540
})
15381541
.collect()

tests/ui/deriving/deriving-all-codegen.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ struct Big {
4545
b1: u32, b2: u32, b3: u32, b4: u32, b5: u32, b6: u32, b7: u32, b8: u32,
4646
}
4747

48+
// It is more efficient to compare scalar types before non-scalar types.
49+
#[derive(PartialEq, PartialOrd)]
50+
struct Reorder {
51+
b1: Option<f32>,
52+
b2: u16,
53+
b3: &'static str,
54+
b4: i8,
55+
b5: u128,
56+
_b: *mut &'static dyn FnMut() -> (),
57+
b6: f64,
58+
b7: &'static mut (),
59+
b8: char,
60+
b9: &'static [i64],
61+
b10: &'static *const bool,
62+
}
63+
4864
// A struct that doesn't impl `Copy`, which means it gets the non-simple
4965
// `clone` implemention that clones the fields individually.
5066
#[derive(Clone)]
@@ -130,6 +146,20 @@ enum Mixed {
130146
S { d1: Option<u32>, d2: Option<i32> },
131147
}
132148

149+
// When comparing enum variant it is more efficient to compare scalar types before non-scalar types.
150+
#[derive(PartialEq, PartialOrd)]
151+
enum ReorderEnum {
152+
A(i32),
153+
B,
154+
C(i8),
155+
D,
156+
E,
157+
F,
158+
G(&'static mut str, *const u8, *const dyn Fn() -> ()),
159+
H,
160+
I,
161+
}
162+
133163
// An enum with no fieldless variants. Note that `Default` cannot be derived
134164
// for this enum.
135165
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]

0 commit comments

Comments
 (0)