Skip to content

fix(#141141): When expanding PartialEq, check equality of scalar types first. #141724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,39 @@ impl TyKind {
None
}
}

/// Returns `true` if this type is considered a scalar primitive (e.g.,
/// `i32`, `u8`, `bool`, etc).
///
/// This check is based on **symbol equality** and does **not** remove any
/// path prefixes or references. If a type alias or shadowing is present
/// (e.g., `type i32 = CustomType;`), this method will still return `true`
/// for `i32`, even though it may not refer to the primitive type.
pub fn maybe_scalar(&self) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A brief comment explaining the maybe_ prefix would be good, i.e. that it's not perfectly accurate because of typedefs.

let Some(ty_sym) = self.is_simple_path() else {
// unit type
return self.is_unit();
};
matches!(
ty_sym,
sym::i8
| sym::i16
| sym::i32
| sym::i64
| sym::i128
| sym::u8
| sym::u16
| sym::u32
| sym::u64
| sym::u128
| sym::f16
| sym::f32
| sym::f64
| sym::f128
| sym::char
| sym::bool
)
}
}

/// A pattern type pattern.
Expand Down
215 changes: 158 additions & 57 deletions compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::deriving::generic::ty::*;
use crate::deriving::generic::*;
use crate::deriving::{path_local, path_std};

/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
/// target item.
pub(crate) fn expand_deriving_partial_eq(
cx: &ExtCtxt<'_>,
span: Span,
Expand All @@ -16,62 +18,6 @@ pub(crate) fn expand_deriving_partial_eq(
push: &mut dyn FnMut(Annotatable),
is_const: bool,
) {
fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr {
let base = true;
let expr = cs_fold(
true, // use foldl
cx,
span,
substr,
|cx, fold| match fold {
CsFold::Single(field) => {
let [other_expr] = &field.other_selflike_exprs[..] else {
cx.dcx()
.span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
};

// We received arguments of type `&T`. Convert them to type `T` by stripping
// any leading `&`. This isn't necessary for type checking, but
// it results in better error messages if something goes wrong.
//
// Note: for arguments that look like `&{ x }`, which occur with packed
// structs, this would cause expressions like `{ self.x } == { other.x }`,
// which isn't valid Rust syntax. This wouldn't break compilation because these
// AST nodes are constructed within the compiler. But it would mean that code
// printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
// syntax, which would be suboptimal. So we wrap these in parens, giving
// `({ self.x }) == ({ other.x })`, which is valid syntax.
let convert = |expr: &P<Expr>| {
if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) =
&expr.kind
{
if let ExprKind::Block(..) = &inner.kind {
// `&{ x }` form: remove the `&`, add parens.
cx.expr_paren(field.span, inner.clone())
} else {
// `&x` form: remove the `&`.
inner.clone()
}
} else {
expr.clone()
}
};
cx.expr_binary(
field.span,
BinOpKind::Eq,
convert(&field.self_expr),
convert(other_expr),
)
}
CsFold::Combine(span, expr1, expr2) => {
cx.expr_binary(span, BinOpKind::And, expr1, expr2)
}
CsFold::Fieldless => cx.expr_bool(span, base),
},
);
BlockOrExpr::new_expr(expr)
}

let structural_trait_def = TraitDef {
span,
path: path_std!(marker::StructuralPartialEq),
Expand All @@ -97,7 +43,9 @@ pub(crate) fn expand_deriving_partial_eq(
ret_ty: Path(path_local!(bool)),
attributes: thin_vec![cx.attr_word(sym::inline, span)],
fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))),
combine_substructure: combine_substructure(Box::new(|a, b, c| {
BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
})),
}];

let trait_def = TraitDef {
Expand All @@ -113,3 +61,156 @@ pub(crate) fn expand_deriving_partial_eq(
};
trait_def.expand(cx, mitem, item, push)
}

/// Generates the equality expression for a struct or enum variant when deriving
/// `PartialEq`.
///
/// This function generates an expression that checks if all fields of a struct
/// or enum variant are equal.
/// - Scalar fields are compared first for efficiency, followed by compound
/// fields.
/// - If there are no fields, returns `true` (fieldless types are always equal).
///
/// Whether a field is considered "scalar" is determined by comparing the symbol
/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
/// This check is based on the type's symbol.
///
/// ### Example 1
/// ```
/// #[derive(PartialEq)]
/// struct i32;
///
/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
/// // the primitive), it will not be treated as scalar. The function will still
/// // check equality of `field_2` first because the symbol matches `i32`.
/// #[derive(PartialEq)]
/// struct Struct {
/// field_1: &'static str,
/// field_2: i32,
/// }
/// ```
///
/// ### Example 2
/// ```
/// mod ty {
/// pub type i32 = i32;
/// }
///
/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
/// // However, the function will not reorder the fields because the symbol for
/// // `ty::i32` does not match the symbol for the primitive `i32`
/// // ("ty::i32" != "i32").
/// #[derive(PartialEq)]
/// struct Struct {
/// field_1: &'static str,
/// field_2: ty::i32,
/// }
/// ```
///
/// For enums, the discriminant is compared first, then the rest of the fields.
///
/// # Panics
///
/// If called on static or all-fieldless enums/structs, which should not occur
/// during derive expansion.
fn get_substructure_equality_expr(
cx: &ExtCtxt<'_>,
span: Span,
substructure: &Substructure<'_>,
) -> P<Expr> {
use SubstructureFields::*;

match substructure.fields {
EnumMatching(.., fields) | Struct(.., fields) => {
let combine = move |acc, field| {
let rhs = get_field_equality_expr(cx, field);
if let Some(lhs) = acc {
// Combine the previous comparison with the current field
// using logical AND.
return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
}
// Start the chain with the first field's comparison.
Some(rhs)
};

// First compare scalar fields, then compound fields, combining all
// with logical AND.
return fields
.iter()
.filter(|field| !field.maybe_scalar)
.fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine)
// If there are no fields, treat as always equal.
.unwrap_or_else(|| cx.expr_bool(span, true));
}
EnumDiscr(disc, match_expr) => {
let lhs = get_field_equality_expr(cx, disc);
let Some(match_expr) = match_expr else {
return lhs;
};
// Compare the discriminant first (cheaper), then the rest of the
// fields.
return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
}
StaticEnum(..) => cx.dcx().span_bug(
span,
"unexpected static enum encountered during `derive(PartialEq)` expansion",
),
StaticStruct(..) => cx.dcx().span_bug(
span,
"unexpected static struct encountered during `derive(PartialEq)` expansion",
),
AllFieldlessEnum(..) => cx.dcx().span_bug(
span,
"unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
),
}
}

/// Generates an equality comparison expression for a single struct or enum
/// field.
///
/// This function produces an AST expression that compares the `self` and
/// `other` values for a field using `==`. It removes any leading references
/// from both sides for readability. If the field is a block expression, it is
/// wrapped in parentheses to ensure valid syntax.
///
/// # Panics
///
/// Panics if there are not exactly two arguments to compare (should be `self`
/// and `other`).
fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> {
let [rhs] = &field.other_selflike_exprs[..] else {
cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
};

cx.expr_binary(
field.span,
BinOpKind::Eq,
wrap_block_expr(cx, peel_refs(&field.self_expr)),
wrap_block_expr(cx, peel_refs(rhs)),
)
}

/// Removes all leading immutable references from an expression.
///
/// This is used to strip away any number of leading `&` from an expression
/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
/// references are preserved.
fn peel_refs(mut expr: &P<Expr>) -> P<Expr> {
while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
expr = &inner;
}
expr.clone()
}

/// Wraps a block expression in parentheses to ensure valid AST in macro
/// expansion output.
///
/// If the given expression is a block, it is wrapped in parentheses; otherwise,
/// it is returned unchanged.
fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> {
if matches!(&expr.kind, ExprKind::Block(..)) {
return cx.expr_paren(expr.span, expr);
}
expr
}
5 changes: 4 additions & 1 deletion compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ pub(crate) struct FieldInfo {
/// The expressions corresponding to references to this field in
/// the other selflike arguments.
pub other_selflike_exprs: Vec<P<Expr>>,
pub maybe_scalar: bool,
}

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> {

let self_expr = discr_exprs.remove(0);
let other_selflike_exprs = discr_exprs;
let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs };
let discr_field =
FieldInfo { span, name: None, self_expr, other_selflike_exprs, maybe_scalar: true };

let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args)
.map(|(&ident, selflike_arg)| {
Expand Down Expand Up @@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> {
name: struct_field.ident,
self_expr,
other_selflike_exprs,
maybe_scalar: struct_field.ty.peel_refs().kind.maybe_scalar(),
}
})
.collect()
Expand Down
30 changes: 30 additions & 0 deletions tests/ui/deriving/deriving-all-codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ struct Big {
b1: u32, b2: u32, b3: u32, b4: u32, b5: u32, b6: u32, b7: u32, b8: u32,
}

// It is more efficient to compare scalar types before non-scalar types.
#[derive(PartialEq, PartialOrd)]
struct Reorder {
b1: Option<f32>,
b2: u16,
b3: &'static str,
b4: i8,
b5: u128,
_b: *mut &'static dyn FnMut() -> (),
b6: f64,
b7: &'static mut (),
b8: char,
b9: &'static [i64],
b10: &'static *const bool,
}

// A struct that doesn't impl `Copy`, which means it gets the non-simple
// `clone` implemention that clones the fields individually.
#[derive(Clone)]
Expand Down Expand Up @@ -130,6 +146,20 @@ enum Mixed {
S { d1: Option<u32>, d2: Option<i32> },
}

// When comparing enum variant it is more efficient to compare scalar types before non-scalar types.
#[derive(PartialEq, PartialOrd)]
enum ReorderEnum {
A(i32),
B,
C(i8),
D,
E,
F,
G(&'static mut str, *const u8, *const dyn Fn() -> ()),
H,
I,
}

// An enum with no fieldless variants. Note that `Default` cannot be derived
// for this enum.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
Expand Down
Loading
Loading