From 43b2486a9589380fd10698c2169351641a142f2e Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Tue, 27 Dec 2022 02:10:54 +0000
Subject: [PATCH 1/2] Clean up operator type checking

---
 compiler/rustc_hir_typeck/src/method/mod.rs | 71 +++++----------------
 compiler/rustc_hir_typeck/src/op.rs         | 43 ++++---------
 2 files changed, 27 insertions(+), 87 deletions(-)

diff --git a/compiler/rustc_hir_typeck/src/method/mod.rs b/compiler/rustc_hir_typeck/src/method/mod.rs
index fddb8a458a7d5..bd3b90134d7ff 100644
--- a/compiler/rustc_hir_typeck/src/method/mod.rs
+++ b/compiler/rustc_hir_typeck/src/method/mod.rs
@@ -306,8 +306,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         span: Span,
         trait_def_id: DefId,
         self_ty: Ty<'tcx>,
-        opt_input_type: Option<Ty<'tcx>>,
-        opt_input_expr: Option<&'tcx hir::Expr<'tcx>>,
+        opt_rhs: Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)>,
         expected: Expectation<'tcx>,
     ) -> (traits::Obligation<'tcx, ty::Predicate<'tcx>>, &'tcx ty::List<ty::subst::GenericArg<'tcx>>)
     {
@@ -318,7 +317,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 GenericParamDefKind::Type { .. } => {
                     if param.index == 0 {
                         return self_ty.into();
-                    } else if let Some(input_type) = opt_input_type {
+                    } else if let Some((_, input_type)) = opt_rhs {
                         return input_type.into();
                     }
                 }
@@ -339,9 +338,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     span,
                     self.body_id,
                     traits::BinOp {
-                        rhs_span: opt_input_expr.map(|expr| expr.span),
-                        is_lit: opt_input_expr
-                            .map_or(false, |expr| matches!(expr.kind, hir::ExprKind::Lit(_))),
+                        rhs_span: opt_rhs.map(|(expr, _)| expr.span),
+                        is_lit: opt_rhs
+                            .map_or(false, |(expr, _)| matches!(expr.kind, hir::ExprKind::Lit(_))),
                         output_ty,
                     },
                 ),
@@ -368,15 +367,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     ) -> Option<InferOk<'tcx, MethodCallee<'tcx>>> {
         let (obligation, substs) =
             self.obligation_for_method(span, trait_def_id, self_ty, opt_input_types);
-        self.construct_obligation_for_trait(
-            span,
-            m_name,
-            trait_def_id,
-            obligation,
-            substs,
-            None,
-            false,
-        )
+        self.construct_obligation_for_trait(span, m_name, trait_def_id, obligation, substs)
     }
 
     pub(super) fn lookup_op_method_in_trait(
@@ -385,27 +376,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         m_name: Ident,
         trait_def_id: DefId,
         self_ty: Ty<'tcx>,
-        opt_input_type: Option<Ty<'tcx>>,
-        opt_input_expr: Option<&'tcx hir::Expr<'tcx>>,
+        opt_rhs: Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)>,
         expected: Expectation<'tcx>,
     ) -> Option<InferOk<'tcx, MethodCallee<'tcx>>> {
-        let (obligation, substs) = self.obligation_for_op_method(
-            span,
-            trait_def_id,
-            self_ty,
-            opt_input_type,
-            opt_input_expr,
-            expected,
-        );
-        self.construct_obligation_for_trait(
-            span,
-            m_name,
-            trait_def_id,
-            obligation,
-            substs,
-            opt_input_expr,
-            true,
-        )
+        let (obligation, substs) =
+            self.obligation_for_op_method(span, trait_def_id, self_ty, opt_rhs, expected);
+        self.construct_obligation_for_trait(span, m_name, trait_def_id, obligation, substs)
     }
 
     // FIXME(#18741): it seems likely that we can consolidate some of this
@@ -418,8 +394,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         trait_def_id: DefId,
         obligation: traits::PredicateObligation<'tcx>,
         substs: &'tcx ty::List<ty::subst::GenericArg<'tcx>>,
-        opt_input_expr: Option<&'tcx hir::Expr<'tcx>>,
-        is_op: bool,
     ) -> Option<InferOk<'tcx, MethodCallee<'tcx>>> {
         debug!(?obligation);
 
@@ -463,22 +437,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         let fn_sig = fn_sig.subst(self.tcx, substs);
         let fn_sig = self.replace_bound_vars_with_fresh_vars(span, infer::FnCall, fn_sig);
 
-        let cause = if is_op {
-            ObligationCause::new(
-                span,
-                self.body_id,
-                traits::BinOp {
-                    rhs_span: opt_input_expr.map(|expr| expr.span),
-                    is_lit: opt_input_expr
-                        .map_or(false, |expr| matches!(expr.kind, hir::ExprKind::Lit(_))),
-                    output_ty: None,
-                },
-            )
-        } else {
-            traits::ObligationCause::misc(span, self.body_id)
-        };
-
-        let InferOk { value, obligations: o } = self.at(&cause, self.param_env).normalize(fn_sig);
+        let InferOk { value, obligations: o } =
+            self.at(&obligation.cause, self.param_env).normalize(fn_sig);
         let fn_sig = {
             obligations.extend(o);
             value
@@ -494,7 +454,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         // any late-bound regions appearing in its bounds.
         let bounds = self.tcx.predicates_of(def_id).instantiate(self.tcx, substs);
 
-        let InferOk { value, obligations: o } = self.at(&cause, self.param_env).normalize(bounds);
+        let InferOk { value, obligations: o } =
+            self.at(&obligation.cause, self.param_env).normalize(bounds);
         let bounds = {
             obligations.extend(o);
             value
@@ -502,7 +463,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
         assert!(!bounds.has_escaping_bound_vars());
 
-        let predicates_cause = cause.clone();
+        let predicates_cause = obligation.cause.clone();
         obligations.extend(traits::predicates_for_generics(
             move |_, _| predicates_cause.clone(),
             self.param_env,
@@ -517,7 +478,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         );
         obligations.push(traits::Obligation::new(
             tcx,
-            cause,
+            obligation.cause,
             self.param_env,
             ty::Binder::dummy(ty::PredicateKind::WellFormed(method_ty.into())),
         ));
diff --git a/compiler/rustc_hir_typeck/src/op.rs b/compiler/rustc_hir_typeck/src/op.rs
index 9f0d175c4c669..632f4e9401e44 100644
--- a/compiler/rustc_hir_typeck/src/op.rs
+++ b/compiler/rustc_hir_typeck/src/op.rs
@@ -48,8 +48,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 if self
                     .lookup_op_method(
                         lhs_deref_ty,
-                        Some(rhs_ty),
-                        Some(rhs),
+                        Some((rhs, rhs_ty)),
                         Op::Binary(op, IsAssign::Yes),
                         expected,
                     )
@@ -60,8 +59,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     if self
                         .lookup_op_method(
                             lhs_ty,
-                            Some(rhs_ty),
-                            Some(rhs),
+                            Some((rhs, rhs_ty)),
                             Op::Binary(op, IsAssign::Yes),
                             expected,
                         )
@@ -248,8 +246,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
         let result = self.lookup_op_method(
             lhs_ty,
-            Some(rhs_ty_var),
-            Some(rhs_expr),
+            Some((rhs_expr, rhs_ty_var)),
             Op::Binary(op, is_assign),
             expected,
         );
@@ -382,8 +379,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     if self
                         .lookup_op_method(
                             lhs_deref_ty,
-                            Some(rhs_ty),
-                            Some(rhs_expr),
+                            Some((rhs_expr, rhs_ty)),
                             Op::Binary(op, is_assign),
                             expected,
                         )
@@ -410,8 +406,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 let is_compatible = |lhs_ty, rhs_ty| {
                     self.lookup_op_method(
                         lhs_ty,
-                        Some(rhs_ty),
-                        Some(rhs_expr),
+                        Some((rhs_expr, rhs_ty)),
                         Op::Binary(op, is_assign),
                         expected,
                     )
@@ -471,8 +466,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         let errors = self
                             .lookup_op_method(
                                 lhs_ty,
-                                Some(rhs_ty),
-                                Some(rhs_expr),
+                                Some((rhs_expr, rhs_ty)),
                                 Op::Binary(op, is_assign),
                                 expected,
                             )
@@ -625,7 +619,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         expected: Expectation<'tcx>,
     ) -> Ty<'tcx> {
         assert!(op.is_by_value());
-        match self.lookup_op_method(operand_ty, None, None, Op::Unary(op, ex.span), expected) {
+        match self.lookup_op_method(operand_ty, None, Op::Unary(op, ex.span), expected) {
             Ok(method) => {
                 self.write_method_call(ex.hir_id, method);
                 method.sig.output()
@@ -712,8 +706,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     fn lookup_op_method(
         &self,
         lhs_ty: Ty<'tcx>,
-        other_ty: Option<Ty<'tcx>>,
-        other_ty_expr: Option<&'tcx hir::Expr<'tcx>>,
+        opt_rhs: Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)>,
         op: Op,
         expected: Expectation<'tcx>,
     ) -> Result<MethodCallee<'tcx>, Vec<FulfillmentError<'tcx>>> {
@@ -747,15 +740,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
         let opname = Ident::with_dummy_span(opname);
         let method = trait_did.and_then(|trait_did| {
-            self.lookup_op_method_in_trait(
-                span,
-                opname,
-                trait_did,
-                lhs_ty,
-                other_ty,
-                other_ty_expr,
-                expected,
-            )
+            self.lookup_op_method_in_trait(span, opname, trait_did, lhs_ty, opt_rhs, expected)
         });
 
         match (method, trait_did) {
@@ -766,14 +751,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             }
             (None, None) => Err(vec![]),
             (None, Some(trait_did)) => {
-                let (obligation, _) = self.obligation_for_op_method(
-                    span,
-                    trait_did,
-                    lhs_ty,
-                    other_ty,
-                    other_ty_expr,
-                    expected,
-                );
+                let (obligation, _) =
+                    self.obligation_for_op_method(span, trait_did, lhs_ty, opt_rhs, expected);
                 Err(rustc_trait_selection::traits::fully_solve_obligation(self, obligation))
             }
         }

From 8bf7ec75343cde1c72a5e16b0171e259412b8958 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Tue, 27 Dec 2022 02:40:56 +0000
Subject: [PATCH 2/2] Deduplicate more op-flavored methods

---
 compiler/rustc_hir_typeck/src/callee.rs     |  2 +-
 compiler/rustc_hir_typeck/src/method/mod.rs | 89 +++------------------
 compiler/rustc_hir_typeck/src/op.rs         | 26 +++++-
 compiler/rustc_hir_typeck/src/place_op.rs   |  4 +-
 4 files changed, 37 insertions(+), 84 deletions(-)

diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs
index af14ee08a9981..829913d278d06 100644
--- a/compiler/rustc_hir_typeck/src/callee.rs
+++ b/compiler/rustc_hir_typeck/src/callee.rs
@@ -241,7 +241,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             });
 
             if let Some(ok) = self.lookup_method_in_trait(
-                call_expr.span,
+                self.misc(call_expr.span),
                 method_name,
                 trait_def_id,
                 adjusted_ty,
diff --git a/compiler/rustc_hir_typeck/src/method/mod.rs b/compiler/rustc_hir_typeck/src/method/mod.rs
index bd3b90134d7ff..b9b27e8627aff 100644
--- a/compiler/rustc_hir_typeck/src/method/mod.rs
+++ b/compiler/rustc_hir_typeck/src/method/mod.rs
@@ -11,7 +11,7 @@ pub use self::suggest::SelfSource;
 pub use self::MethodError::*;
 
 use crate::errors::OpMethodGenericParams;
-use crate::{Expectation, FnCtxt};
+use crate::FnCtxt;
 use rustc_data_structures::sync::Lrc;
 use rustc_errors::{Applicability, Diagnostic};
 use rustc_hir as hir;
@@ -264,7 +264,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
     pub(super) fn obligation_for_method(
         &self,
-        span: Span,
+        cause: ObligationCause<'tcx>,
         trait_def_id: DefId,
         self_ty: Ty<'tcx>,
         opt_input_types: Option<&[Ty<'tcx>]>,
@@ -282,70 +282,19 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     }
                 }
             }
-            self.var_for_def(span, param)
-        });
-
-        let trait_ref = self.tcx.mk_trait_ref(trait_def_id, substs);
-
-        // Construct an obligation
-        let poly_trait_ref = ty::Binder::dummy(trait_ref);
-        (
-            traits::Obligation::misc(
-                self.tcx,
-                span,
-                self.body_id,
-                self.param_env,
-                poly_trait_ref.without_const(),
-            ),
-            substs,
-        )
-    }
-
-    pub(super) fn obligation_for_op_method(
-        &self,
-        span: Span,
-        trait_def_id: DefId,
-        self_ty: Ty<'tcx>,
-        opt_rhs: Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)>,
-        expected: Expectation<'tcx>,
-    ) -> (traits::Obligation<'tcx, ty::Predicate<'tcx>>, &'tcx ty::List<ty::subst::GenericArg<'tcx>>)
-    {
-        // Construct a trait-reference `self_ty : Trait<input_tys>`
-        let substs = InternalSubsts::for_item(self.tcx, trait_def_id, |param, _| {
-            match param.kind {
-                GenericParamDefKind::Lifetime | GenericParamDefKind::Const { .. } => {}
-                GenericParamDefKind::Type { .. } => {
-                    if param.index == 0 {
-                        return self_ty.into();
-                    } else if let Some((_, input_type)) = opt_rhs {
-                        return input_type.into();
-                    }
-                }
-            }
-            self.var_for_def(span, param)
+            self.var_for_def(cause.span, param)
         });
 
         let trait_ref = self.tcx.mk_trait_ref(trait_def_id, substs);
 
         // Construct an obligation
         let poly_trait_ref = ty::Binder::dummy(trait_ref);
-        let output_ty = expected.only_has_type(self).and_then(|ty| (!ty.needs_infer()).then(|| ty));
-
         (
             traits::Obligation::new(
                 self.tcx,
-                traits::ObligationCause::new(
-                    span,
-                    self.body_id,
-                    traits::BinOp {
-                        rhs_span: opt_rhs.map(|(expr, _)| expr.span),
-                        is_lit: opt_rhs
-                            .map_or(false, |(expr, _)| matches!(expr.kind, hir::ExprKind::Lit(_))),
-                        output_ty,
-                    },
-                ),
+                cause,
                 self.param_env,
-                poly_trait_ref,
+                poly_trait_ref.without_const(),
             ),
             substs,
         )
@@ -356,32 +305,18 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     /// In particular, it doesn't really do any probing: it simply constructs
     /// an obligation for a particular trait with the given self type and checks
     /// whether that trait is implemented.
-    #[instrument(level = "debug", skip(self, span))]
+    #[instrument(level = "debug", skip(self))]
     pub(super) fn lookup_method_in_trait(
         &self,
-        span: Span,
+        cause: ObligationCause<'tcx>,
         m_name: Ident,
         trait_def_id: DefId,
         self_ty: Ty<'tcx>,
         opt_input_types: Option<&[Ty<'tcx>]>,
     ) -> Option<InferOk<'tcx, MethodCallee<'tcx>>> {
         let (obligation, substs) =
-            self.obligation_for_method(span, trait_def_id, self_ty, opt_input_types);
-        self.construct_obligation_for_trait(span, m_name, trait_def_id, obligation, substs)
-    }
-
-    pub(super) fn lookup_op_method_in_trait(
-        &self,
-        span: Span,
-        m_name: Ident,
-        trait_def_id: DefId,
-        self_ty: Ty<'tcx>,
-        opt_rhs: Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)>,
-        expected: Expectation<'tcx>,
-    ) -> Option<InferOk<'tcx, MethodCallee<'tcx>>> {
-        let (obligation, substs) =
-            self.obligation_for_op_method(span, trait_def_id, self_ty, opt_rhs, expected);
-        self.construct_obligation_for_trait(span, m_name, trait_def_id, obligation, substs)
+            self.obligation_for_method(cause, trait_def_id, self_ty, opt_input_types);
+        self.construct_obligation_for_trait(m_name, trait_def_id, obligation, substs)
     }
 
     // FIXME(#18741): it seems likely that we can consolidate some of this
@@ -389,7 +324,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     // of this method is basically the same as confirmation.
     fn construct_obligation_for_trait(
         &self,
-        span: Span,
         m_name: Ident,
         trait_def_id: DefId,
         obligation: traits::PredicateObligation<'tcx>,
@@ -409,7 +343,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         let tcx = self.tcx;
         let Some(method_item) = self.associated_value(trait_def_id, m_name) else {
             tcx.sess.delay_span_bug(
-                span,
+                obligation.cause.span,
                 "operator trait does not have corresponding operator method",
             );
             return None;
@@ -435,7 +369,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         // with bound regions.
         let fn_sig = tcx.bound_fn_sig(def_id);
         let fn_sig = fn_sig.subst(self.tcx, substs);
-        let fn_sig = self.replace_bound_vars_with_fresh_vars(span, infer::FnCall, fn_sig);
+        let fn_sig =
+            self.replace_bound_vars_with_fresh_vars(obligation.cause.span, infer::FnCall, fn_sig);
 
         let InferOk { value, obligations: o } =
             self.at(&obligation.cause, self.param_env).normalize(fn_sig);
diff --git a/compiler/rustc_hir_typeck/src/op.rs b/compiler/rustc_hir_typeck/src/op.rs
index 632f4e9401e44..34140f3e1fe3e 100644
--- a/compiler/rustc_hir_typeck/src/op.rs
+++ b/compiler/rustc_hir_typeck/src/op.rs
@@ -12,14 +12,16 @@ use rustc_middle::ty::adjustment::{
     Adjust, Adjustment, AllowTwoPhase, AutoBorrow, AutoBorrowMutability,
 };
 use rustc_middle::ty::print::with_no_trimmed_paths;
-use rustc_middle::ty::{self, DefIdTree, Ty, TyCtxt, TypeFolder, TypeSuperFoldable, TypeVisitable};
+use rustc_middle::ty::{
+    self, DefIdTree, IsSuggestable, Ty, TyCtxt, TypeFolder, TypeSuperFoldable, TypeVisitable,
+};
 use rustc_session::errors::ExprParenthesesNeeded;
 use rustc_span::source_map::Spanned;
 use rustc_span::symbol::{sym, Ident};
 use rustc_span::Span;
 use rustc_trait_selection::infer::InferCtxtExt;
 use rustc_trait_selection::traits::error_reporting::suggestions::TypeErrCtxtExt as _;
-use rustc_trait_selection::traits::FulfillmentError;
+use rustc_trait_selection::traits::{self, FulfillmentError};
 use rustc_type_ir::sty::TyKind::*;
 
 impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
@@ -486,6 +488,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                                             if let Some(output_def_id) = output_def_id
                                                 && let Some(trait_def_id) = trait_def_id
                                                 && self.tcx.parent(output_def_id) == trait_def_id
+                                                && output_ty.is_suggestable(self.tcx, false)
                                             {
                                                 Some(("Output", *output_ty))
                                             } else {
@@ -735,12 +738,27 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 Op::Unary(..) => 0,
             },
         ) {
+            self.tcx
+                .sess
+                .delay_span_bug(span, "operator didn't have the right number of generic args");
             return Err(vec![]);
         }
 
         let opname = Ident::with_dummy_span(opname);
+        let input_types =
+            opt_rhs.as_ref().map(|(_, ty)| std::slice::from_ref(ty)).unwrap_or_default();
+        let cause = self.cause(
+            span,
+            traits::BinOp {
+                rhs_span: opt_rhs.map(|(expr, _)| expr.span),
+                is_lit: opt_rhs
+                    .map_or(false, |(expr, _)| matches!(expr.kind, hir::ExprKind::Lit(_))),
+                output_ty: expected.only_has_type(self),
+            },
+        );
+
         let method = trait_did.and_then(|trait_did| {
-            self.lookup_op_method_in_trait(span, opname, trait_did, lhs_ty, opt_rhs, expected)
+            self.lookup_method_in_trait(cause.clone(), opname, trait_did, lhs_ty, Some(input_types))
         });
 
         match (method, trait_did) {
@@ -752,7 +770,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             (None, None) => Err(vec![]),
             (None, Some(trait_did)) => {
                 let (obligation, _) =
-                    self.obligation_for_op_method(span, trait_did, lhs_ty, opt_rhs, expected);
+                    self.obligation_for_method(cause, trait_did, lhs_ty, Some(input_types));
                 Err(rustc_trait_selection::traits::fully_solve_obligation(self, obligation))
             }
         }
diff --git a/compiler/rustc_hir_typeck/src/place_op.rs b/compiler/rustc_hir_typeck/src/place_op.rs
index 952ea14887f7b..a0f048fc09b9b 100644
--- a/compiler/rustc_hir_typeck/src/place_op.rs
+++ b/compiler/rustc_hir_typeck/src/place_op.rs
@@ -225,7 +225,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
         imm_tr.and_then(|trait_did| {
             self.lookup_method_in_trait(
-                span,
+                self.misc(span),
                 Ident::with_dummy_span(imm_op),
                 trait_did,
                 base_ty,
@@ -264,7 +264,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
         mut_tr.and_then(|trait_did| {
             self.lookup_method_in_trait(
-                span,
+                self.misc(span),
                 Ident::with_dummy_span(mut_op),
                 trait_did,
                 base_ty,