diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index a16219361c051..1d4df97da58b3 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -2452,6 +2452,39 @@ impl TyKind { None } } + + /// Returns `true` if this type is considered a scalar primitive (e.g., + /// `i32`, `u8`, `bool`, etc). + /// + /// This check is based on **symbol equality** and does **not** remove any + /// path prefixes or references. If a type alias or shadowing is present + /// (e.g., `type i32 = CustomType;`), this method will still return `true` + /// for `i32`, even though it may not refer to the primitive type. + pub fn maybe_scalar(&self) -> bool { + let Some(ty_sym) = self.is_simple_path() else { + // unit type + return self.is_unit(); + }; + matches!( + ty_sym, + sym::i8 + | sym::i16 + | sym::i32 + | sym::i64 + | sym::i128 + | sym::u8 + | sym::u16 + | sym::u32 + | sym::u64 + | sym::u128 + | sym::f16 + | sym::f32 + | sym::f64 + | sym::f128 + | sym::char + | sym::bool + ) + } } /// A pattern type pattern. diff --git a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs index 4b93b3414c76b..b1d950b8d89de 100644 --- a/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs +++ b/compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs @@ -8,6 +8,8 @@ use crate::deriving::generic::ty::*; use crate::deriving::generic::*; use crate::deriving::{path_local, path_std}; +/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the +/// target item. pub(crate) fn expand_deriving_partial_eq( cx: &ExtCtxt<'_>, span: Span, @@ -16,62 +18,6 @@ pub(crate) fn expand_deriving_partial_eq( push: &mut dyn FnMut(Annotatable), is_const: bool, ) { - fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr { - let base = true; - let expr = cs_fold( - true, // use foldl - cx, - span, - substr, - |cx, fold| match fold { - CsFold::Single(field) => { - let [other_expr] = &field.other_selflike_exprs[..] else { - cx.dcx() - .span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`"); - }; - - // We received arguments of type `&T`. Convert them to type `T` by stripping - // any leading `&`. This isn't necessary for type checking, but - // it results in better error messages if something goes wrong. - // - // Note: for arguments that look like `&{ x }`, which occur with packed - // structs, this would cause expressions like `{ self.x } == { other.x }`, - // which isn't valid Rust syntax. This wouldn't break compilation because these - // AST nodes are constructed within the compiler. But it would mean that code - // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid - // syntax, which would be suboptimal. So we wrap these in parens, giving - // `({ self.x }) == ({ other.x })`, which is valid syntax. - let convert = |expr: &P| { - if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = - &expr.kind - { - if let ExprKind::Block(..) = &inner.kind { - // `&{ x }` form: remove the `&`, add parens. - cx.expr_paren(field.span, inner.clone()) - } else { - // `&x` form: remove the `&`. - inner.clone() - } - } else { - expr.clone() - } - }; - cx.expr_binary( - field.span, - BinOpKind::Eq, - convert(&field.self_expr), - convert(other_expr), - ) - } - CsFold::Combine(span, expr1, expr2) => { - cx.expr_binary(span, BinOpKind::And, expr1, expr2) - } - CsFold::Fieldless => cx.expr_bool(span, base), - }, - ); - BlockOrExpr::new_expr(expr) - } - let structural_trait_def = TraitDef { span, path: path_std!(marker::StructuralPartialEq), @@ -97,7 +43,9 @@ pub(crate) fn expand_deriving_partial_eq( ret_ty: Path(path_local!(bool)), attributes: thin_vec![cx.attr_word(sym::inline, span)], fieldless_variants_strategy: FieldlessVariantsStrategy::Unify, - combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))), + combine_substructure: combine_substructure(Box::new(|a, b, c| { + BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c)) + })), }]; let trait_def = TraitDef { @@ -113,3 +61,156 @@ pub(crate) fn expand_deriving_partial_eq( }; trait_def.expand(cx, mitem, item, push) } + +/// Generates the equality expression for a struct or enum variant when deriving +/// `PartialEq`. +/// +/// This function generates an expression that checks if all fields of a struct +/// or enum variant are equal. +/// - Scalar fields are compared first for efficiency, followed by compound +/// fields. +/// - If there are no fields, returns `true` (fieldless types are always equal). +/// +/// Whether a field is considered "scalar" is determined by comparing the symbol +/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc). +/// This check is based on the type's symbol. +/// +/// ### Example 1 +/// ``` +/// #[derive(PartialEq)] +/// struct i32; +/// +/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not +/// // the primitive), it will not be treated as scalar. The function will still +/// // check equality of `field_2` first because the symbol matches `i32`. +/// #[derive(PartialEq)] +/// struct Struct { +/// field_1: &'static str, +/// field_2: i32, +/// } +/// ``` +/// +/// ### Example 2 +/// ``` +/// mod ty { +/// pub type i32 = i32; +/// } +/// +/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`. +/// // However, the function will not reorder the fields because the symbol for +/// // `ty::i32` does not match the symbol for the primitive `i32` +/// // ("ty::i32" != "i32"). +/// #[derive(PartialEq)] +/// struct Struct { +/// field_1: &'static str, +/// field_2: ty::i32, +/// } +/// ``` +/// +/// For enums, the discriminant is compared first, then the rest of the fields. +/// +/// # Panics +/// +/// If called on static or all-fieldless enums/structs, which should not occur +/// during derive expansion. +fn get_substructure_equality_expr( + cx: &ExtCtxt<'_>, + span: Span, + substructure: &Substructure<'_>, +) -> P { + use SubstructureFields::*; + + match substructure.fields { + EnumMatching(.., fields) | Struct(.., fields) => { + let combine = move |acc, field| { + let rhs = get_field_equality_expr(cx, field); + if let Some(lhs) = acc { + // Combine the previous comparison with the current field + // using logical AND. + return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs)); + } + // Start the chain with the first field's comparison. + Some(rhs) + }; + + // First compare scalar fields, then compound fields, combining all + // with logical AND. + return fields + .iter() + .filter(|field| !field.maybe_scalar) + .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine) + // If there are no fields, treat as always equal. + .unwrap_or_else(|| cx.expr_bool(span, true)); + } + EnumDiscr(disc, match_expr) => { + let lhs = get_field_equality_expr(cx, disc); + let Some(match_expr) = match_expr else { + return lhs; + }; + // Compare the discriminant first (cheaper), then the rest of the + // fields. + return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone()); + } + StaticEnum(..) => cx.dcx().span_bug( + span, + "unexpected static enum encountered during `derive(PartialEq)` expansion", + ), + StaticStruct(..) => cx.dcx().span_bug( + span, + "unexpected static struct encountered during `derive(PartialEq)` expansion", + ), + AllFieldlessEnum(..) => cx.dcx().span_bug( + span, + "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion", + ), + } +} + +/// Generates an equality comparison expression for a single struct or enum +/// field. +/// +/// This function produces an AST expression that compares the `self` and +/// `other` values for a field using `==`. It removes any leading references +/// from both sides for readability. If the field is a block expression, it is +/// wrapped in parentheses to ensure valid syntax. +/// +/// # Panics +/// +/// Panics if there are not exactly two arguments to compare (should be `self` +/// and `other`). +fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P { + let [rhs] = &field.other_selflike_exprs[..] else { + cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`"); + }; + + cx.expr_binary( + field.span, + BinOpKind::Eq, + wrap_block_expr(cx, peel_refs(&field.self_expr)), + wrap_block_expr(cx, peel_refs(rhs)), + ) +} + +/// Removes all leading immutable references from an expression. +/// +/// This is used to strip away any number of leading `&` from an expression +/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable +/// references are preserved. +fn peel_refs(mut expr: &P) -> P { + while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind { + expr = &inner; + } + expr.clone() +} + +/// Wraps a block expression in parentheses to ensure valid AST in macro +/// expansion output. +/// +/// If the given expression is a block, it is wrapped in parentheses; otherwise, +/// it is returned unchanged. +fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P) -> P { + if matches!(&expr.kind, ExprKind::Block(..)) { + return cx.expr_paren(expr.span, expr); + } + expr +} diff --git a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs index 9aa53f9e4f73b..e0e44841acb4a 100644 --- a/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs +++ b/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs @@ -284,6 +284,7 @@ pub(crate) struct FieldInfo { /// The expressions corresponding to references to this field in /// the other selflike arguments. pub other_selflike_exprs: Vec>, + pub maybe_scalar: bool, } #[derive(Copy, Clone)] @@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> { let self_expr = discr_exprs.remove(0); let other_selflike_exprs = discr_exprs; - let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs }; + let discr_field = + FieldInfo { span, name: None, self_expr, other_selflike_exprs, maybe_scalar: true }; let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args) .map(|(&ident, selflike_arg)| { @@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> { name: struct_field.ident, self_expr, other_selflike_exprs, + maybe_scalar: struct_field.ty.peel_refs().kind.maybe_scalar(), } }) .collect() diff --git a/tests/ui/deriving/deriving-all-codegen.rs b/tests/ui/deriving/deriving-all-codegen.rs index eab2b4f1f5335..e2b6804fbd1d8 100644 --- a/tests/ui/deriving/deriving-all-codegen.rs +++ b/tests/ui/deriving/deriving-all-codegen.rs @@ -45,6 +45,22 @@ struct Big { b1: u32, b2: u32, b3: u32, b4: u32, b5: u32, b6: u32, b7: u32, b8: u32, } +// It is more efficient to compare scalar types before non-scalar types. +#[derive(PartialEq, PartialOrd)] +struct Reorder { + b1: Option, + b2: u16, + b3: &'static str, + b4: i8, + b5: u128, + _b: *mut &'static dyn FnMut() -> (), + b6: f64, + b7: &'static mut (), + b8: char, + b9: &'static [i64], + b10: &'static *const bool, +} + // A struct that doesn't impl `Copy`, which means it gets the non-simple // `clone` implemention that clones the fields individually. #[derive(Clone)] @@ -130,6 +146,20 @@ enum Mixed { S { d1: Option, d2: Option }, } +// When comparing enum variant it is more efficient to compare scalar types before non-scalar types. +#[derive(PartialEq, PartialOrd)] +enum ReorderEnum { + A(i32), + B, + C(i8), + D, + E, + F, + G(&'static mut str, *const u8, *const dyn Fn() -> ()), + H, + I, +} + // An enum with no fieldless variants. Note that `Default` cannot be derived // for this enum. #[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] diff --git a/tests/ui/deriving/deriving-all-codegen.stdout b/tests/ui/deriving/deriving-all-codegen.stdout index 6503c87099040..fa8f249373d30 100644 --- a/tests/ui/deriving/deriving-all-codegen.stdout +++ b/tests/ui/deriving/deriving-all-codegen.stdout @@ -419,6 +419,100 @@ impl ::core::cmp::Ord for Big { } } +// It is more efficient to compare scalar types before non-scalar types. +struct Reorder { + b1: Option, + b2: u16, + b3: &'static str, + b4: i8, + b5: u128, + _b: *mut &'static dyn FnMut() -> (), + b6: f64, + b7: &'static mut (), + b8: char, + b9: &'static [i64], + b10: &'static *const bool, +} +#[automatically_derived] +impl ::core::marker::StructuralPartialEq for Reorder { } +#[automatically_derived] +impl ::core::cmp::PartialEq for Reorder { + #[inline] + fn eq(&self, other: &Reorder) -> bool { + self.b2 == other.b2 && self.b4 == other.b4 && self.b5 == other.b5 && + self.b6 == other.b6 && self.b7 == other.b7 && + self.b8 == other.b8 && self.b10 == other.b10 && + self.b1 == other.b1 && self.b3 == other.b3 && + self._b == other._b && self.b9 == other.b9 + } +} +#[automatically_derived] +impl ::core::cmp::PartialOrd for Reorder { + #[inline] + fn partial_cmp(&self, other: &Reorder) + -> ::core::option::Option<::core::cmp::Ordering> { + match ::core::cmp::PartialOrd::partial_cmp(&self.b1, &other.b1) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) => + match ::core::cmp::PartialOrd::partial_cmp(&self.b2, + &other.b2) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(&self.b3, + &other.b3) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(&self.b4, + &other.b4) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(&self.b5, + &other.b5) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(&self._b, + &other._b) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(&self.b6, + &other.b6) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(&self.b7, + &other.b7) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(&self.b8, + &other.b8) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(&self.b9, + &other.b9) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + ::core::cmp::PartialOrd::partial_cmp(&self.b10, &other.b10), + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + } + } +} + // A struct that doesn't impl `Copy`, which means it gets the non-simple // `clone` implemention that clones the fields individually. struct NonCopy(u32); @@ -1167,6 +1261,77 @@ impl ::core::cmp::Ord for Mixed { } } +// When comparing enum variant it is more efficient to compare scalar types before non-scalar types. +enum ReorderEnum { + A(i32), + B, + C(i8), + D, + E, + F, + G(&'static mut str, *const u8, *const dyn Fn() -> ()), + H, + I, +} +#[automatically_derived] +impl ::core::marker::StructuralPartialEq for ReorderEnum { } +#[automatically_derived] +impl ::core::cmp::PartialEq for ReorderEnum { + #[inline] + fn eq(&self, other: &ReorderEnum) -> bool { + let __self_discr = ::core::intrinsics::discriminant_value(self); + let __arg1_discr = ::core::intrinsics::discriminant_value(other); + __self_discr == __arg1_discr && + match (self, other) { + (ReorderEnum::A(__self_0), ReorderEnum::A(__arg1_0)) => + __self_0 == __arg1_0, + (ReorderEnum::C(__self_0), ReorderEnum::C(__arg1_0)) => + __self_0 == __arg1_0, + (ReorderEnum::G(__self_0, __self_1, __self_2), + ReorderEnum::G(__arg1_0, __arg1_1, __arg1_2)) => + __self_1 == __arg1_1 && __self_0 == __arg1_0 && + __self_2 == __arg1_2, + _ => true, + } + } +} +#[automatically_derived] +impl ::core::cmp::PartialOrd for ReorderEnum { + #[inline] + fn partial_cmp(&self, other: &ReorderEnum) + -> ::core::option::Option<::core::cmp::Ordering> { + let __self_discr = ::core::intrinsics::discriminant_value(self); + let __arg1_discr = ::core::intrinsics::discriminant_value(other); + match ::core::cmp::PartialOrd::partial_cmp(&__self_discr, + &__arg1_discr) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) => + match (self, other) { + (ReorderEnum::A(__self_0), ReorderEnum::A(__arg1_0)) => + ::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0), + (ReorderEnum::C(__self_0), ReorderEnum::C(__arg1_0)) => + ::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0), + (ReorderEnum::G(__self_0, __self_1, __self_2), + ReorderEnum::G(__arg1_0, __arg1_1, __arg1_2)) => + match ::core::cmp::PartialOrd::partial_cmp(__self_0, + __arg1_0) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => + match ::core::cmp::PartialOrd::partial_cmp(__self_1, + __arg1_1) { + ::core::option::Option::Some(::core::cmp::Ordering::Equal) + => ::core::cmp::PartialOrd::partial_cmp(__self_2, __arg1_2), + cmp => cmp, + }, + cmp => cmp, + }, + _ => + ::core::option::Option::Some(::core::cmp::Ordering::Equal), + }, + cmp => cmp, + } + } +} + // An enum with no fieldless variants. Note that `Default` cannot be derived // for this enum. enum Fielded { X(u32), Y(bool), Z(Option), }