Skip to content

Commit 12c73a9

Browse files
committed
Separately check equality of the scalar types and compound types in the order of declaration.
1 parent ebe9b00 commit 12c73a9

File tree

5 files changed

+414
-58
lines changed

5 files changed

+414
-58
lines changed

compiler/rustc_ast/src/ast.rs

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2195,6 +2195,34 @@ impl FloatTy {
21952195
}
21962196
}
21972197

2198+
impl<'a> TryFrom<&'a str> for FloatTy {
2199+
type Error = ();
2200+
2201+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2202+
Ok(match value {
2203+
"f16" => Self::F16,
2204+
"f32" => Self::F32,
2205+
"f64" => Self::F64,
2206+
"f128" => Self::F128,
2207+
_ => return Err(()),
2208+
})
2209+
}
2210+
}
2211+
2212+
impl TryFrom<Symbol> for FloatTy {
2213+
type Error = ();
2214+
2215+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2216+
Ok(match value {
2217+
sym::f16 => Self::F16,
2218+
sym::f32 => Self::F32,
2219+
sym::f64 => Self::F64,
2220+
sym::f128 => Self::F128,
2221+
_ => return Err(()),
2222+
})
2223+
}
2224+
}
2225+
21982226
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
21992227
#[derive(Encodable, Decodable, HashStable_Generic)]
22002228
pub enum IntTy {
@@ -2230,6 +2258,38 @@ impl IntTy {
22302258
}
22312259
}
22322260

2261+
impl<'a> TryFrom<&'a str> for IntTy {
2262+
type Error = ();
2263+
2264+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2265+
Ok(match value {
2266+
"isize" => Self::Isize,
2267+
"i8" => Self::I8,
2268+
"i16" => Self::I16,
2269+
"i32" => Self::I32,
2270+
"i64" => Self::I64,
2271+
"i128" => Self::I128,
2272+
_ => return Err(()),
2273+
})
2274+
}
2275+
}
2276+
2277+
impl TryFrom<Symbol> for IntTy {
2278+
type Error = ();
2279+
2280+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2281+
Ok(match value {
2282+
sym::isize => Self::Isize,
2283+
sym::i8 => Self::I8,
2284+
sym::i16 => Self::I16,
2285+
sym::i32 => Self::I32,
2286+
sym::i64 => Self::I64,
2287+
sym::i128 => Self::I128,
2288+
_ => return Err(()),
2289+
})
2290+
}
2291+
}
2292+
22332293
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Copy, Debug)]
22342294
#[derive(Encodable, Decodable, HashStable_Generic)]
22352295
pub enum UintTy {
@@ -2265,6 +2325,38 @@ impl UintTy {
22652325
}
22662326
}
22672327

2328+
impl<'a> TryFrom<&'a str> for UintTy {
2329+
type Error = ();
2330+
2331+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2332+
Ok(match value {
2333+
"usize" => Self::Usize,
2334+
"u8" => Self::U8,
2335+
"u16" => Self::U16,
2336+
"u32" => Self::U32,
2337+
"u64" => Self::U64,
2338+
"u128" => Self::U128,
2339+
_ => return Err(()),
2340+
})
2341+
}
2342+
}
2343+
2344+
impl TryFrom<Symbol> for UintTy {
2345+
type Error = ();
2346+
2347+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2348+
Ok(match value {
2349+
sym::usize => Self::Usize,
2350+
sym::u8 => Self::U8,
2351+
sym::u16 => Self::U16,
2352+
sym::u32 => Self::U32,
2353+
sym::u64 => Self::U64,
2354+
sym::u128 => Self::U128,
2355+
_ => return Err(()),
2356+
})
2357+
}
2358+
}
2359+
22682360
/// A constraint on an associated item.
22692361
///
22702362
/// ### Examples
@@ -2452,6 +2544,21 @@ impl TyKind {
24522544
None
24532545
}
24542546
}
2547+
2548+
pub fn is_scalar(&self) -> bool {
2549+
let Some(ty_kind) = self.is_simple_path() else {
2550+
match self {
2551+
TyKind::Tup(tys) => return tys.is_empty(), // unit type
2552+
_ => return false,
2553+
}
2554+
};
2555+
2556+
ty_kind == sym::bool
2557+
|| ty_kind == sym::char
2558+
|| IntTy::try_from(ty_kind).is_ok()
2559+
|| UintTy::try_from(ty_kind).is_ok()
2560+
|| FloatTy::try_from(ty_kind).is_ok()
2561+
}
24552562
}
24562563

24572564
/// A pattern type pattern.

compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs

Lines changed: 108 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::deriving::generic::ty::*;
88
use crate::deriving::generic::*;
99
use crate::deriving::{path_local, path_std};
1010

11+
/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the target item.
1112
pub(crate) fn expand_deriving_partial_eq(
1213
cx: &ExtCtxt<'_>,
1314
span: Span,
@@ -16,62 +17,6 @@ pub(crate) fn expand_deriving_partial_eq(
1617
push: &mut dyn FnMut(Annotatable),
1718
is_const: bool,
1819
) {
19-
fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr {
20-
let base = true;
21-
let expr = cs_fold(
22-
true, // use foldl
23-
cx,
24-
span,
25-
substr,
26-
|cx, fold| match fold {
27-
CsFold::Single(field) => {
28-
let [other_expr] = &field.other_selflike_exprs[..] else {
29-
cx.dcx()
30-
.span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
31-
};
32-
33-
// We received arguments of type `&T`. Convert them to type `T` by stripping
34-
// any leading `&`. This isn't necessary for type checking, but
35-
// it results in better error messages if something goes wrong.
36-
//
37-
// Note: for arguments that look like `&{ x }`, which occur with packed
38-
// structs, this would cause expressions like `{ self.x } == { other.x }`,
39-
// which isn't valid Rust syntax. This wouldn't break compilation because these
40-
// AST nodes are constructed within the compiler. But it would mean that code
41-
// printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
42-
// syntax, which would be suboptimal. So we wrap these in parens, giving
43-
// `({ self.x }) == ({ other.x })`, which is valid syntax.
44-
let convert = |expr: &P<Expr>| {
45-
if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) =
46-
&expr.kind
47-
{
48-
if let ExprKind::Block(..) = &inner.kind {
49-
// `&{ x }` form: remove the `&`, add parens.
50-
cx.expr_paren(field.span, inner.clone())
51-
} else {
52-
// `&x` form: remove the `&`.
53-
inner.clone()
54-
}
55-
} else {
56-
expr.clone()
57-
}
58-
};
59-
cx.expr_binary(
60-
field.span,
61-
BinOpKind::Eq,
62-
convert(&field.self_expr),
63-
convert(other_expr),
64-
)
65-
}
66-
CsFold::Combine(span, expr1, expr2) => {
67-
cx.expr_binary(span, BinOpKind::And, expr1, expr2)
68-
}
69-
CsFold::Fieldless => cx.expr_bool(span, base),
70-
},
71-
);
72-
BlockOrExpr::new_expr(expr)
73-
}
74-
7520
let structural_trait_def = TraitDef {
7621
span,
7722
path: path_std!(marker::StructuralPartialEq),
@@ -97,7 +42,9 @@ pub(crate) fn expand_deriving_partial_eq(
9742
ret_ty: Path(path_local!(bool)),
9843
attributes: thin_vec![cx.attr_word(sym::inline, span)],
9944
fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
100-
combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))),
45+
combine_substructure: combine_substructure(Box::new(|a, b, c| {
46+
BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
47+
})),
10148
}];
10249

10350
let trait_def = TraitDef {
@@ -113,3 +60,107 @@ pub(crate) fn expand_deriving_partial_eq(
11360
};
11461
trait_def.expand(cx, mitem, item, push)
11562
}
63+
64+
/// Generates the equality expression for a struct or enum variant when deriving `PartialEq`.
65+
///
66+
/// This function generates an expression that checks if all fields of a struct or enum variant are equal.
67+
/// - Scalar fields are compared first for efficiency, followed by compound fields.
68+
/// - If there are no fields, returns `true` (fieldless types are always equal).
69+
///
70+
/// For enums, the discriminant is compared first, then the rest of the fields.
71+
///
72+
/// # Panics
73+
///
74+
/// If called on static or all-fieldless enums/structs, which should not occur during derive expansion.
75+
fn get_substructure_equality_expr(
76+
cx: &ExtCtxt<'_>,
77+
span: Span,
78+
substructure: &Substructure<'_>,
79+
) -> P<Expr> {
80+
use SubstructureFields::*;
81+
82+
match substructure.fields {
83+
EnumMatching(.., fields) | Struct(.., fields) => {
84+
let combine = move |acc, field| {
85+
let rhs = get_field_equality_expr(cx, field);
86+
if let Some(lhs) = acc {
87+
return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
88+
}
89+
Some(rhs)
90+
};
91+
92+
// First compare scalar fields, then compound fields, combining all with logical AND.
93+
return fields
94+
.iter()
95+
.filter(|field| !field.is_scalar)
96+
.fold(fields.iter().filter(|field| field.is_scalar).fold(None, combine), combine)
97+
.unwrap_or_else(|| {
98+
// If there are no fields, treat as always equal.
99+
cx.expr_bool(span, true)
100+
});
101+
}
102+
EnumDiscr(disc, match_expr) => {
103+
let lhs = get_field_equality_expr(cx, disc);
104+
let Some(match_expr) = match_expr else {
105+
return lhs;
106+
};
107+
// Compare the discriminant first (cheaper), then the rest of the fields.
108+
return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
109+
}
110+
StaticEnum(..) => cx.dcx().span_bug(
111+
span,
112+
"unexpected static enum encountered during `derive(PartialEq)` expansion",
113+
),
114+
StaticStruct(..) => cx.dcx().span_bug(
115+
span,
116+
"unexpected static struct encountered during `derive(PartialEq)` expansion",
117+
),
118+
AllFieldlessEnum(..) => cx.dcx().span_bug(
119+
span,
120+
"unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
121+
),
122+
}
123+
}
124+
125+
/// Generates an equality comparison expression for a single struct or enum field.
126+
///
127+
/// This function produces an AST expression that compares the `self` and `other` values for a field using `==`.
128+
/// It removes any leading references from both sides for readability.
129+
/// If the field is a block expression, it is wrapped in parentheses to ensure valid syntax.
130+
///
131+
/// # Panics
132+
///
133+
/// Panics if there are not exactly two arguments to compare (should be `self` and `other`).
134+
fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> {
135+
let [rhs] = &field.other_selflike_exprs[..] else {
136+
cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
137+
};
138+
139+
cx.expr_binary(
140+
field.span,
141+
BinOpKind::Eq,
142+
wrap_block_expr(cx, peel_refs(&field.self_expr)),
143+
wrap_block_expr(cx, peel_refs(rhs)),
144+
)
145+
}
146+
147+
/// Removes all leading immutable references from an expression.
148+
///
149+
/// This is used to strip away any number of leading `&` from an expression (e.g., `&&&T` becomes `T`).
150+
/// Only removes immutable references; mutable references are preserved.
151+
fn peel_refs(mut expr: &P<Expr>) -> P<Expr> {
152+
while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
153+
expr = &inner;
154+
}
155+
expr.clone()
156+
}
157+
158+
/// Wraps a block expression in parentheses to ensure valid AST in macro expansion output.
159+
///
160+
/// If the given expression is a block, it is wrapped in parentheses; otherwise, it is returned unchanged.
161+
fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> {
162+
if matches!(&expr.kind, ExprKind::Block(..)) {
163+
return cx.expr_paren(expr.span, expr);
164+
}
165+
expr
166+
}

compiler/rustc_builtin_macros/src/deriving/generic/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ pub(crate) struct FieldInfo {
284284
/// The expressions corresponding to references to this field in
285285
/// the other selflike arguments.
286286
pub other_selflike_exprs: Vec<P<Expr>>,
287+
pub is_scalar: bool,
287288
}
288289

289290
#[derive(Copy, Clone)]
@@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> {
12201221

12211222
let self_expr = discr_exprs.remove(0);
12221223
let other_selflike_exprs = discr_exprs;
1223-
let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs };
1224+
let discr_field =
1225+
FieldInfo { span, name: None, self_expr, other_selflike_exprs, is_scalar: true };
12241226

12251227
let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args)
12261228
.map(|(&ident, selflike_arg)| {
@@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> {
15331535
name: struct_field.ident,
15341536
self_expr,
15351537
other_selflike_exprs,
1538+
is_scalar: struct_field.ty.peel_refs().kind.is_scalar(),
15361539
}
15371540
})
15381541
.collect()

tests/ui/deriving/deriving-all-codegen.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ struct Big {
4545
b1: u32, b2: u32, b3: u32, b4: u32, b5: u32, b6: u32, b7: u32, b8: u32,
4646
}
4747

48+
// It is more efficient to compare scalar types before non-scalar types.
49+
#[derive(PartialEq, PartialOrd)]
50+
struct Reorder {
51+
b1: Option<f32>,
52+
b2: u16,
53+
b3: &'static str,
54+
b4: i8,
55+
b5: u128,
56+
_b: *mut &'static dyn FnMut() -> (),
57+
b6: f64,
58+
b7: &'static mut (),
59+
b8: char,
60+
b9: &'static [i64],
61+
b10: &'static *const bool,
62+
}
63+
4864
// A struct that doesn't impl `Copy`, which means it gets the non-simple
4965
// `clone` implemention that clones the fields individually.
5066
#[derive(Clone)]
@@ -130,6 +146,20 @@ enum Mixed {
130146
S { d1: Option<u32>, d2: Option<i32> },
131147
}
132148

149+
// When comparing enum variant it is more efficient to compare scalar types before non-scalar types.
150+
#[derive(PartialEq, PartialOrd)]
151+
enum ReorderEnum {
152+
A(i32),
153+
B,
154+
C(i8),
155+
D,
156+
E,
157+
F,
158+
G(&'static mut str, *const u8, *const dyn Fn() -> ()),
159+
H,
160+
I,
161+
}
162+
133163
// An enum with no fieldless variants. Note that `Default` cannot be derived
134164
// for this enum.
135165
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]

0 commit comments

Comments
 (0)