From c5bf15c757007e3fb53b781164c08e7474d28f7e Mon Sep 17 00:00:00 2001
From: lcnr <rust@lcnr.de>
Date: Fri, 8 Mar 2024 12:44:01 +0100
Subject: [PATCH 1/4] snapshot: avoid leaking inference vars

---
 compiler/rustc_hir_typeck/src/coercion.rs     |   2 +-
 compiler/rustc_hir_typeck/src/method/mod.rs   |   8 +-
 compiler/rustc_hir_typeck/src/method/probe.rs |   8 +
 compiler/rustc_infer/src/infer/mod.rs         |   2 +-
 .../src/infer/snapshot/check_leaks.rs         | 137 +++++++++++
 .../rustc_infer/src/infer/snapshot/fudge.rs   | 212 ++++++++++--------
 .../rustc_infer/src/infer/snapshot/mod.rs     | 170 +++++++++++++-
 compiler/rustc_infer/src/traits/project.rs    |   2 +-
 compiler/rustc_middle/src/ty/error.rs         |   2 +-
 compiler/rustc_trait_selection/src/infer.rs   |  36 +--
 .../src/solve/eval_ctxt/probe.rs              |   2 +
 .../src/solve/fulfill.rs                      |  30 +--
 .../src/solve/inspect/analyse.rs              |   3 +-
 .../src/traits/select/candidate_assembly.rs   |   8 +-
 .../src/traits/select/mod.rs                  |   2 +
 15 files changed, 487 insertions(+), 137 deletions(-)
 create mode 100644 compiler/rustc_infer/src/infer/snapshot/check_leaks.rs

diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs
index 792359c9dda1b..573e093d0684b 100644
--- a/compiler/rustc_hir_typeck/src/coercion.rs
+++ b/compiler/rustc_hir_typeck/src/coercion.rs
@@ -1076,7 +1076,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         let coerce = Coerce::new(self, cause, AllowTwoPhase::No);
         coerce
             .autoderef(rustc_span::DUMMY_SP, expr_ty)
-            .find_map(|(ty, steps)| self.probe(|_| coerce.unify(ty, target)).ok().map(|_| steps))
+            .find_map(|(ty, steps)| self.probe(|_| coerce.unify(ty, target).ok().map(|_| steps)))
     }
 
     /// Given a type, this function will calculate and return the type given
diff --git a/compiler/rustc_hir_typeck/src/method/mod.rs b/compiler/rustc_hir_typeck/src/method/mod.rs
index 3b26a791f6577..cc538927ecbb6 100644
--- a/compiler/rustc_hir_typeck/src/method/mod.rs
+++ b/compiler/rustc_hir_typeck/src/method/mod.rs
@@ -16,6 +16,7 @@ use rustc_hir as hir;
 use rustc_hir::def::{CtorOf, DefKind, Namespace};
 use rustc_hir::def_id::DefId;
 use rustc_infer::infer::{self, InferOk};
+use rustc_infer::trivial_no_snapshot_leaks;
 use rustc_middle::query::Providers;
 use rustc_middle::traits::ObligationCause;
 use rustc_middle::ty::{self, GenericParamDefKind, Ty, TypeVisitableExt};
@@ -43,6 +44,8 @@ pub struct MethodCallee<'tcx> {
     pub sig: ty::FnSig<'tcx>,
 }
 
+// FIXME(#122188): This is wrong, as this type may leak inference vars.
+trivial_no_snapshot_leaks!('tcx, MethodError<'tcx>);
 #[derive(Debug)]
 pub enum MethodError<'tcx> {
     // Did not find an applicable method, but we did find various near-misses that may work.
@@ -79,8 +82,9 @@ pub struct NoMatchData<'tcx> {
     pub mode: probe::Mode,
 }
 
-// A pared down enum describing just the places from which a method
-// candidate can arise. Used for error reporting only.
+trivial_no_snapshot_leaks!('tcx, CandidateSource);
+/// A pared down enum describing just the places from which a method
+/// candidate can arise. Used for error reporting only.
 #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
 pub enum CandidateSource {
     Impl(DefId),
diff --git a/compiler/rustc_hir_typeck/src/method/probe.rs b/compiler/rustc_hir_typeck/src/method/probe.rs
index bdc796aca3a46..a7894297ed35c 100644
--- a/compiler/rustc_hir_typeck/src/method/probe.rs
+++ b/compiler/rustc_hir_typeck/src/method/probe.rs
@@ -13,8 +13,10 @@ use rustc_hir_analysis::autoderef::{self, Autoderef};
 use rustc_infer::infer::canonical::OriginalQueryValues;
 use rustc_infer::infer::canonical::{Canonical, QueryResponse};
 use rustc_infer::infer::error_reporting::TypeAnnotationNeeded::E0282;
+use rustc_infer::infer::snapshot::NoSnapshotLeaks;
 use rustc_infer::infer::DefineOpaqueTypes;
 use rustc_infer::infer::{self, InferOk, TyCtxtInferExt};
+use rustc_infer::trivial_no_snapshot_leaks;
 use rustc_middle::middle::stability;
 use rustc_middle::query::Providers;
 use rustc_middle::ty::fast_reject::{simplify_type, TreatParams};
@@ -97,6 +99,8 @@ impl<'a, 'tcx> Deref for ProbeContext<'a, 'tcx> {
     }
 }
 
+// FIXME(#122188): This is wrong as this type may leak inference variables.
+trivial_no_snapshot_leaks!('tcx, Candidate<'tcx>);
 #[derive(Debug, Clone)]
 pub(crate) struct Candidate<'tcx> {
     // Candidates are (I'm not quite sure, but they are mostly) basically
@@ -152,6 +156,7 @@ pub(crate) enum CandidateKind<'tcx> {
     ),
 }
 
+trivial_no_snapshot_leaks!('tcx, ProbeResult);
 #[derive(Debug, PartialEq, Eq, Copy, Clone)]
 enum ProbeResult {
     NoMatch,
@@ -195,6 +200,8 @@ impl AutorefOrPtrAdjustment {
     }
 }
 
+// FIXME(#122188): This is wrong as this type may leak inference variables.
+trivial_no_snapshot_leaks!('tcx, Pick<'tcx>);
 #[derive(Debug, Clone)]
 pub struct Pick<'tcx> {
     pub item: ty::AssocItem,
@@ -368,6 +375,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         op: OP,
     ) -> Result<R, MethodError<'tcx>>
     where
+        R: NoSnapshotLeaks<'tcx>,
         OP: FnOnce(ProbeContext<'_, 'tcx>) -> Result<R, MethodError<'tcx>>,
     {
         let mut orig_values = OriginalQueryValues::default();
diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index 15cdd6a910e77..0df66187d1b2d 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -62,7 +62,7 @@ mod projection;
 pub mod region_constraints;
 mod relate;
 pub mod resolve;
-pub(crate) mod snapshot;
+pub mod snapshot;
 pub mod type_variable;
 
 #[must_use]
diff --git a/compiler/rustc_infer/src/infer/snapshot/check_leaks.rs b/compiler/rustc_infer/src/infer/snapshot/check_leaks.rs
new file mode 100644
index 0000000000000..3a0f3465a8777
--- /dev/null
+++ b/compiler/rustc_infer/src/infer/snapshot/check_leaks.rs
@@ -0,0 +1,137 @@
+use super::VariableLengths;
+use crate::infer::InferCtxt;
+use rustc_middle::ty::{self, Ty, TyCtxt};
+use rustc_middle::ty::{TypeSuperVisitable, TypeVisitor};
+use std::ops::ControlFlow;
+
+pub struct HasSnapshotLeaksVisitor {
+    universe: ty::UniverseIndex,
+    variable_lengths: VariableLengths,
+}
+impl HasSnapshotLeaksVisitor {
+    pub fn new<'tcx>(infcx: &InferCtxt<'tcx>) -> Self {
+        HasSnapshotLeaksVisitor {
+            universe: infcx.universe(),
+            variable_lengths: infcx.variable_lengths(),
+        }
+    }
+}
+
+fn continue_if(b: bool) -> ControlFlow<()> {
+    if b { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
+}
+
+impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for HasSnapshotLeaksVisitor {
+    type Result = ControlFlow<()>;
+
+    fn visit_region(&mut self, r: ty::Region<'tcx>) -> Self::Result {
+        match r.kind() {
+            ty::ReVar(var) => continue_if(var.as_usize() < self.variable_lengths.region_vars),
+            ty::RePlaceholder(p) => continue_if(self.universe.can_name(p.universe)),
+            ty::ReEarlyParam(_)
+            | ty::ReBound(_, _)
+            | ty::ReLateParam(_)
+            | ty::ReStatic
+            | ty::ReErased
+            | ty::ReError(_) => ControlFlow::Continue(()),
+        }
+    }
+    fn visit_ty(&mut self, t: Ty<'tcx>) -> Self::Result {
+        match t.kind() {
+            ty::Infer(ty::TyVar(var)) => {
+                continue_if(var.as_usize() < self.variable_lengths.type_vars)
+            }
+            ty::Infer(ty::IntVar(var)) => {
+                continue_if(var.as_usize() < self.variable_lengths.int_vars)
+            }
+            ty::Infer(ty::FloatVar(var)) => {
+                continue_if(var.as_usize() < self.variable_lengths.float_vars)
+            }
+            ty::Placeholder(p) => continue_if(self.universe.can_name(p.universe)),
+            ty::Infer(ty::FreshTy(..) | ty::FreshIntTy(..) | ty::FreshFloatTy(..))
+            | ty::Bool
+            | ty::Char
+            | ty::Int(_)
+            | ty::Uint(_)
+            | ty::Float(_)
+            | ty::Adt(_, _)
+            | ty::Foreign(_)
+            | ty::Str
+            | ty::Array(_, _)
+            | ty::Slice(_)
+            | ty::RawPtr(_)
+            | ty::Ref(_, _, _)
+            | ty::FnDef(_, _)
+            | ty::FnPtr(_)
+            | ty::Dynamic(_, _, _)
+            | ty::Closure(_, _)
+            | ty::CoroutineClosure(_, _)
+            | ty::Coroutine(_, _)
+            | ty::CoroutineWitness(_, _)
+            | ty::Never
+            | ty::Tuple(_)
+            | ty::Alias(_, _)
+            | ty::Param(_)
+            | ty::Bound(_, _)
+            | ty::Error(_) => t.super_visit_with(self),
+        }
+    }
+    fn visit_const(&mut self, c: ty::Const<'tcx>) -> Self::Result {
+        match c.kind() {
+            ty::ConstKind::Infer(ty::InferConst::Var(var)) => {
+                continue_if(var.as_usize() < self.variable_lengths.const_vars)
+            }
+            // FIXME(const_trait_impl): need to handle effect vars here and in `fudge_inference_if_ok`.
+            ty::ConstKind::Infer(ty::InferConst::EffectVar(_)) => ControlFlow::Continue(()),
+            ty::ConstKind::Placeholder(p) => continue_if(self.universe.can_name(p.universe)),
+            ty::ConstKind::Infer(ty::InferConst::Fresh(_))
+            | ty::ConstKind::Param(_)
+            | ty::ConstKind::Bound(_, _)
+            | ty::ConstKind::Unevaluated(_)
+            | ty::ConstKind::Value(_)
+            | ty::ConstKind::Expr(_)
+            | ty::ConstKind::Error(_) => c.super_visit_with(self),
+        }
+    }
+}
+
+#[macro_export]
+#[cfg(debug_assertions)]
+macro_rules! type_foldable_verify_no_snapshot_leaks {
+    ($tcx:lifetime, $t:ty) => {
+        const _: () = {
+            use rustc_middle::ty::TypeVisitable;
+            use $crate::infer::snapshot::check_leaks::HasSnapshotLeaksVisitor;
+            use $crate::infer::InferCtxt;
+            impl<$tcx> $crate::infer::snapshot::NoSnapshotLeaks<$tcx> for $t {
+                type StartData = HasSnapshotLeaksVisitor;
+                type EndData = ($t, HasSnapshotLeaksVisitor);
+                fn snapshot_start_data(infcx: &$crate::infer::InferCtxt<$tcx>) -> Self::StartData {
+                    HasSnapshotLeaksVisitor::new(infcx)
+                }
+                fn end_of_snapshot(
+                    _: &InferCtxt<'tcx>,
+                    value: $t,
+                    visitor: Self::StartData,
+                ) -> Self::EndData {
+                    (value, visitor)
+                }
+                fn avoid_leaks(_: &InferCtxt<$tcx>, (value, mut visitor): Self::EndData) -> Self {
+                    if value.visit_with(&mut visitor).is_break() {
+                        bug!("leaking vars from snapshot: {value:?}");
+                    }
+
+                    value
+                }
+            }
+        };
+    };
+}
+
+#[macro_export]
+#[cfg(not(debug_assertions))]
+macro_rules! type_foldable_verify_no_snapshot_leaks {
+    ($tcx:lifetime, $t:ty) => {
+        trivial_no_snapshot_leaks!($tcx, $t);
+    };
+}
diff --git a/compiler/rustc_infer/src/infer/snapshot/fudge.rs b/compiler/rustc_infer/src/infer/snapshot/fudge.rs
index 14de461cd17eb..0e182cb181105 100644
--- a/compiler/rustc_infer/src/infer/snapshot/fudge.rs
+++ b/compiler/rustc_infer/src/infer/snapshot/fudge.rs
@@ -1,6 +1,7 @@
 use rustc_middle::infer::unify_key::{ConstVariableOriginKind, ConstVariableValue, ConstVidKey};
-use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
+use rustc_middle::ty::TypeVisitableExt;
 use rustc_middle::ty::{self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid};
+use rustc_middle::ty::{TypeFoldable, TypeFolder, TypeSuperFoldable};
 
 use crate::infer::type_variable::TypeVariableOrigin;
 use crate::infer::InferCtxt;
@@ -12,6 +13,8 @@ use ut::UnifyKey;
 
 use std::ops::Range;
 
+use super::{NoSnapshotLeaks, VariableLengths};
+
 fn vars_since_snapshot<'tcx, T>(
     table: &mut UnificationTable<'_, 'tcx, T>,
     snapshot_var_len: usize,
@@ -43,26 +46,7 @@ fn const_vars_since_snapshot<'tcx>(
     )
 }
 
-struct VariableLengths {
-    type_var_len: usize,
-    const_var_len: usize,
-    int_var_len: usize,
-    float_var_len: usize,
-    region_constraints_len: usize,
-}
-
 impl<'tcx> InferCtxt<'tcx> {
-    fn variable_lengths(&self) -> VariableLengths {
-        let mut inner = self.inner.borrow_mut();
-        VariableLengths {
-            type_var_len: inner.type_variables().num_vars(),
-            const_var_len: inner.const_unification_table().len(),
-            int_var_len: inner.int_unification_table().len(),
-            float_var_len: inner.float_unification_table().len(),
-            region_constraints_len: inner.unwrap_region_constraints().num_region_vars(),
-        }
-    }
-
     /// This rather funky routine is used while processing expected
     /// types. What happens here is that we want to propagate a
     /// coercion through the return type of a fn to its
@@ -107,75 +91,79 @@ impl<'tcx> InferCtxt<'tcx> {
     where
         F: FnOnce() -> Result<T, E>,
         T: TypeFoldable<TyCtxt<'tcx>>,
+        E: NoSnapshotLeaks<'tcx>,
     {
-        let variable_lengths = self.variable_lengths();
-        let (mut fudger, value) = self.probe(|_| {
-            match f() {
-                Ok(value) => {
-                    let value = self.resolve_vars_if_possible(value);
-
-                    // At this point, `value` could in principle refer
-                    // to inference variables that have been created during
-                    // the snapshot. Once we exit `probe()`, those are
-                    // going to be popped, so we will have to
-                    // eliminate any references to them.
-
-                    let mut inner = self.inner.borrow_mut();
-                    let type_vars =
-                        inner.type_variables().vars_since_snapshot(variable_lengths.type_var_len);
-                    let int_vars = vars_since_snapshot(
-                        &mut inner.int_unification_table(),
-                        variable_lengths.int_var_len,
-                    );
-                    let float_vars = vars_since_snapshot(
-                        &mut inner.float_unification_table(),
-                        variable_lengths.float_var_len,
-                    );
-                    let region_vars = inner
-                        .unwrap_region_constraints()
-                        .vars_since_snapshot(variable_lengths.region_constraints_len);
-                    let const_vars = const_vars_since_snapshot(
-                        &mut inner.const_unification_table(),
-                        variable_lengths.const_var_len,
-                    );
-
-                    let fudger = InferenceFudger {
-                        infcx: self,
-                        type_vars,
-                        int_vars,
-                        float_vars,
-                        region_vars,
-                        const_vars,
-                    };
+        self.probe(|_| f().map(|value| FudgeInference(self.resolve_vars_if_possible(value))))
+            .map(|FudgeInference(value)| value)
+    }
+}
 
-                    Ok((fudger, value))
+#[macro_export]
+macro_rules! fudge_vars_no_snapshot_leaks {
+    ($tcx:lifetime, $t:ty) => {
+        const _: () = {
+            use rustc_middle::ty::TypeVisitableExt;
+            use $crate::infer::snapshot::fudge::InferenceFudgeData;
+            impl<$tcx> $crate::infer::snapshot::NoSnapshotLeaks<$tcx> for $t {
+                type StartData = $crate::infer::snapshot::VariableLengths;
+                type EndData = ($t, Option<InferenceFudgeData>);
+                fn snapshot_start_data(infcx: &InferCtxt<$tcx>) -> Self::StartData {
+                    infcx.variable_lengths()
+                }
+                fn end_of_snapshot(
+                    infcx: &InferCtxt<$tcx>,
+                    value: $t,
+                    variable_lengths: Self::StartData,
+                ) -> Self::EndData {
+                    if value.has_infer() {
+                        (value, Some(InferenceFudgeData::new(infcx, variable_lengths)))
+                    } else {
+                        (value, None)
+                    }
+                }
+                fn avoid_leaks(
+                    infcx: &InferCtxt<'tcx>,
+                    (value, fudge_data): Self::EndData,
+                ) -> Self {
+                    if let Some(fudge_data) = fudge_data {
+                        fudge_data.fudge_inference(infcx, value)
+                    } else {
+                        value
+                    }
                 }
-                Err(e) => Err(e),
             }
-        })?;
-
-        // At this point, we need to replace any of the now-popped
-        // type/region variables that appear in `value` with a fresh
-        // variable of the appropriate kind. We can't do this during
-        // the probe because they would just get popped then too. =)
+        };
+    };
+}
 
-        // Micro-optimization: if no variables have been created, then
-        // `value` can't refer to any of them. =) So we can just return it.
-        if fudger.type_vars.0.is_empty()
-            && fudger.int_vars.is_empty()
-            && fudger.float_vars.is_empty()
-            && fudger.region_vars.0.is_empty()
-            && fudger.const_vars.0.is_empty()
-        {
-            Ok(value)
+struct FudgeInference<T>(T);
+impl<'tcx, T: TypeFoldable<TyCtxt<'tcx>>> NoSnapshotLeaks<'tcx> for FudgeInference<T> {
+    type StartData = VariableLengths;
+    type EndData = (T, Option<InferenceFudgeData>);
+    fn snapshot_start_data(infcx: &InferCtxt<'tcx>) -> Self::StartData {
+        infcx.variable_lengths()
+    }
+    fn end_of_snapshot(
+        infcx: &InferCtxt<'tcx>,
+        FudgeInference(value): FudgeInference<T>,
+        variable_lengths: Self::StartData,
+    ) -> Self::EndData {
+        if value.has_infer() {
+            (value, Some(InferenceFudgeData::new(infcx, variable_lengths)))
+        } else {
+            (value, None)
+        }
+    }
+    fn avoid_leaks(infcx: &InferCtxt<'tcx>, (value, fudge_data): Self::EndData) -> Self {
+        if let Some(fudge_data) = fudge_data {
+            FudgeInference(fudge_data.fudge_inference(infcx, value))
         } else {
-            Ok(value.fold_with(&mut fudger))
+            FudgeInference(value)
         }
     }
 }
 
-pub struct InferenceFudger<'a, 'tcx> {
-    infcx: &'a InferCtxt<'tcx>,
+pub struct InferenceFudgeData {
     type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
     int_vars: Range<IntVid>,
     float_vars: Range<FloatVid>,
@@ -183,6 +171,50 @@ pub struct InferenceFudger<'a, 'tcx> {
     const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
 }
 
+impl InferenceFudgeData {
+    pub fn new<'tcx>(
+        infcx: &InferCtxt<'tcx>,
+        variable_lengths: VariableLengths,
+    ) -> InferenceFudgeData {
+        let mut inner = infcx.inner.borrow_mut();
+        let type_vars = inner.type_variables().vars_since_snapshot(variable_lengths.type_vars);
+        let int_vars =
+            vars_since_snapshot(&mut inner.int_unification_table(), variable_lengths.int_vars);
+        let float_vars =
+            vars_since_snapshot(&mut inner.float_unification_table(), variable_lengths.float_vars);
+        let region_vars =
+            inner.unwrap_region_constraints().vars_since_snapshot(variable_lengths.region_vars);
+        let const_vars = const_vars_since_snapshot(
+            &mut inner.const_unification_table(),
+            variable_lengths.const_vars,
+        );
+
+        InferenceFudgeData { type_vars, int_vars, float_vars, region_vars, const_vars }
+    }
+
+    pub fn fudge_inference<'tcx, T: TypeFoldable<TyCtxt<'tcx>>>(
+        self,
+        infcx: &InferCtxt<'tcx>,
+        value: T,
+    ) -> T {
+        if self.type_vars.0.is_empty()
+            && self.int_vars.is_empty()
+            && self.float_vars.is_empty()
+            && self.region_vars.0.is_empty()
+            && self.const_vars.0.is_empty()
+        {
+            value
+        } else {
+            value.fold_with(&mut InferenceFudger { infcx, data: self })
+        }
+    }
+}
+
+struct InferenceFudger<'a, 'tcx> {
+    infcx: &'a InferCtxt<'tcx>,
+    data: InferenceFudgeData,
+}
+
 impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
     fn interner(&self) -> TyCtxt<'tcx> {
         self.infcx.tcx
@@ -191,11 +223,11 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
     fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
         match *ty.kind() {
             ty::Infer(ty::InferTy::TyVar(vid)) => {
-                if self.type_vars.0.contains(&vid) {
+                if self.data.type_vars.0.contains(&vid) {
                     // This variable was created during the fudging.
                     // Recreate it with a fresh variable here.
-                    let idx = vid.as_usize() - self.type_vars.0.start.as_usize();
-                    let origin = self.type_vars.1[idx];
+                    let idx = vid.as_usize() - self.data.type_vars.0.start.as_usize();
+                    let origin = self.data.type_vars.1[idx];
                     self.infcx.next_ty_var(origin)
                 } else {
                     // This variable was created before the
@@ -210,14 +242,14 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
                 }
             }
             ty::Infer(ty::InferTy::IntVar(vid)) => {
-                if self.int_vars.contains(&vid) {
+                if self.data.int_vars.contains(&vid) {
                     self.infcx.next_int_var()
                 } else {
                     ty
                 }
             }
             ty::Infer(ty::InferTy::FloatVar(vid)) => {
-                if self.float_vars.contains(&vid) {
+                if self.data.float_vars.contains(&vid) {
                     self.infcx.next_float_var()
                 } else {
                     ty
@@ -229,10 +261,10 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
 
     fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
         if let ty::ReVar(vid) = *r
-            && self.region_vars.0.contains(&vid)
+            && self.data.region_vars.0.contains(&vid)
         {
-            let idx = vid.index() - self.region_vars.0.start.index();
-            let origin = self.region_vars.1[idx];
+            let idx = vid.index() - self.data.region_vars.0.start.index();
+            let origin = self.data.region_vars.1[idx];
             return self.infcx.next_region_var(origin);
         }
         r
@@ -240,11 +272,11 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
 
     fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
         if let ty::ConstKind::Infer(ty::InferConst::Var(vid)) = ct.kind() {
-            if self.const_vars.0.contains(&vid) {
+            if self.data.const_vars.0.contains(&vid) {
                 // This variable was created during the fudging.
                 // Recreate it with a fresh variable here.
-                let idx = vid.index() - self.const_vars.0.start.index();
-                let origin = self.const_vars.1[idx];
+                let idx = vid.index() - self.data.const_vars.0.start.index();
+                let origin = self.data.const_vars.1[idx];
                 self.infcx.next_const_var(ct.ty(), origin)
             } else {
                 ct
diff --git a/compiler/rustc_infer/src/infer/snapshot/mod.rs b/compiler/rustc_infer/src/infer/snapshot/mod.rs
index 9eef1471b1af2..5ab1af7cf9637 100644
--- a/compiler/rustc_infer/src/infer/snapshot/mod.rs
+++ b/compiler/rustc_infer/src/infer/snapshot/mod.rs
@@ -3,6 +3,7 @@ use super::InferCtxt;
 use rustc_data_structures::undo_log::UndoLogs;
 use rustc_middle::ty;
 
+pub mod check_leaks;
 mod fudge;
 pub(crate) mod undo_log;
 
@@ -60,24 +61,41 @@ impl<'tcx> InferCtxt<'tcx> {
     pub fn commit_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
     where
         F: FnOnce(&CombinedSnapshot<'tcx>) -> Result<T, E>,
+        E: NoSnapshotLeaks<'tcx>,
     {
+        let no_leaks_data = E::snapshot_start_data(self);
         let snapshot = self.start_snapshot();
         let r = f(&snapshot);
         debug!("commit_if_ok() -- r.is_ok() = {}", r.is_ok());
         match r {
-            Ok(_) => {
+            Ok(value) => {
                 self.commit_from(snapshot);
+                Ok(value)
             }
-            Err(_) => {
+            Err(err) => {
+                let no_leaks_data = E::end_of_snapshot(self, err, no_leaks_data);
                 self.rollback_to(snapshot);
+                Err(E::avoid_leaks(self, no_leaks_data))
             }
         }
-        r
     }
 
     /// Execute `f` then unroll any bindings it creates.
     #[instrument(skip(self, f), level = "debug")]
     pub fn probe<R, F>(&self, f: F) -> R
+    where
+        F: FnOnce(&CombinedSnapshot<'tcx>) -> R,
+        R: NoSnapshotLeaks<'tcx>,
+    {
+        let no_leaks_data = R::snapshot_start_data(self);
+        let snapshot = self.start_snapshot();
+        let r = f(&snapshot);
+        let no_leaks_data = R::end_of_snapshot(self, r, no_leaks_data);
+        self.rollback_to(snapshot);
+        R::avoid_leaks(self, no_leaks_data)
+    }
+
+    pub fn probe_unchecked<R, F>(&self, f: F) -> R
     where
         F: FnOnce(&CombinedSnapshot<'tcx>) -> R,
     {
@@ -99,4 +117,150 @@ impl<'tcx> InferCtxt<'tcx> {
     pub fn opaque_types_added_in_snapshot(&self, snapshot: &CombinedSnapshot<'tcx>) -> bool {
         self.inner.borrow().undo_log.opaque_types_in_snapshot(&snapshot.undo_snapshot)
     }
+
+    pub fn variable_lengths(&self) -> VariableLengths {
+        let mut inner = self.inner.borrow_mut();
+        VariableLengths {
+            type_vars: inner.type_variables().num_vars(),
+            const_vars: inner.const_unification_table().len(),
+            int_vars: inner.int_unification_table().len(),
+            float_vars: inner.float_unification_table().len(),
+            region_vars: inner.unwrap_region_constraints().num_region_vars(),
+        }
+    }
+}
+
+pub struct VariableLengths {
+    region_vars: usize,
+    type_vars: usize,
+    int_vars: usize,
+    float_vars: usize,
+    const_vars: usize,
+}
+
+pub trait NoSnapshotLeaks<'tcx> {
+    type StartData;
+    type EndData;
+    fn snapshot_start_data(infcx: &InferCtxt<'tcx>) -> Self::StartData;
+    fn end_of_snapshot(
+        infcx: &InferCtxt<'tcx>,
+        this: Self,
+        start: Self::StartData,
+    ) -> Self::EndData;
+    fn avoid_leaks(infcx: &InferCtxt<'tcx>, data: Self::EndData) -> Self;
+}
+
+pub trait TrivialNoSnapshotLeaks<'tcx> {}
+impl<'tcx, T: TrivialNoSnapshotLeaks<'tcx>> NoSnapshotLeaks<'tcx> for T {
+    type StartData = ();
+    type EndData = T;
+    #[inline]
+    fn snapshot_start_data(_: &InferCtxt<'tcx>) {}
+    #[inline]
+    fn end_of_snapshot(_: &InferCtxt<'tcx>, this: Self, _: ()) -> T {
+        this
+    }
+    #[inline]
+    fn avoid_leaks(_: &InferCtxt<'tcx>, this: T) -> Self {
+        this
+    }
+}
+
+#[macro_export]
+macro_rules! trivial_no_snapshot_leaks {
+    ($tcx:lifetime, $t:ty) => {
+        impl<$tcx> $crate::infer::snapshot::TrivialNoSnapshotLeaks<$tcx> for $t {}
+    };
+}
+
+mod impls {
+    use super::{NoSnapshotLeaks, TrivialNoSnapshotLeaks};
+    use crate::fudge_vars_no_snapshot_leaks;
+    use crate::infer::InferCtxt;
+    use crate::traits::solve::{CanonicalResponse, Certainty};
+    use crate::traits::MismatchedProjectionTypes;
+    use crate::type_foldable_verify_no_snapshot_leaks;
+    use rustc_hir::def_id::DefId;
+    use rustc_middle::infer::canonical::Canonical;
+    use rustc_middle::traits::query::{MethodAutoderefStepsResult, NoSolution};
+    use rustc_middle::traits::{BuiltinImplSource, EvaluationResult, OverflowError};
+    use rustc_middle::ty;
+    use rustc_middle::ty::error::TypeError;
+    use rustc_span::symbol::Ident;
+    use rustc_span::ErrorGuaranteed;
+
+    trivial_no_snapshot_leaks!('tcx, ());
+    trivial_no_snapshot_leaks!('tcx, bool);
+    trivial_no_snapshot_leaks!('tcx, usize);
+    trivial_no_snapshot_leaks!('tcx, ty::AssocItem);
+    trivial_no_snapshot_leaks!('tcx, BuiltinImplSource);
+    trivial_no_snapshot_leaks!('tcx, DefId);
+    trivial_no_snapshot_leaks!('tcx, ErrorGuaranteed);
+    trivial_no_snapshot_leaks!('tcx, EvaluationResult);
+    trivial_no_snapshot_leaks!('tcx, Ident);
+    trivial_no_snapshot_leaks!('tcx, OverflowError);
+    trivial_no_snapshot_leaks!('tcx, NoSolution);
+    trivial_no_snapshot_leaks!('tcx, Vec<(CanonicalResponse<'tcx>, BuiltinImplSource)>);
+    trivial_no_snapshot_leaks!('tcx, (bool, Certainty));
+    // FIXME(#122188): This is wrong, this can leak inference vars in `opt_bad_ty` and `steps`.
+    trivial_no_snapshot_leaks!('tcx, MethodAutoderefStepsResult<'tcx>);
+    type_foldable_verify_no_snapshot_leaks!('tcx, ty::PolyFnSig<'tcx>);
+    fudge_vars_no_snapshot_leaks!('tcx, TypeError<'tcx>);
+    fudge_vars_no_snapshot_leaks!('tcx, MismatchedProjectionTypes<'tcx>);
+
+    impl<'tcx, T: NoSnapshotLeaks<'tcx>> NoSnapshotLeaks<'tcx> for Option<T> {
+        type StartData = T::StartData;
+        type EndData = Option<T::EndData>;
+        #[inline]
+        fn snapshot_start_data(infcx: &InferCtxt<'tcx>) -> T::StartData {
+            T::snapshot_start_data(infcx)
+        }
+        #[inline]
+        fn end_of_snapshot(
+            infcx: &InferCtxt<'tcx>,
+            this: Option<T>,
+            start_data: T::StartData,
+        ) -> Option<T::EndData> {
+            this.map(|this| T::end_of_snapshot(infcx, this, start_data))
+        }
+        #[inline]
+        fn avoid_leaks(infcx: &InferCtxt<'tcx>, data: Self::EndData) -> Self {
+            data.map(|data| T::avoid_leaks(infcx, data))
+        }
+    }
+
+    impl<'tcx, T, E> NoSnapshotLeaks<'tcx> for Result<T, E>
+    where
+        T: NoSnapshotLeaks<'tcx>,
+        E: NoSnapshotLeaks<'tcx>,
+    {
+        type StartData = (T::StartData, E::StartData);
+        type EndData = Result<T::EndData, E::EndData>;
+        #[inline]
+        fn snapshot_start_data(infcx: &InferCtxt<'tcx>) -> Self::StartData {
+            (T::snapshot_start_data(infcx), E::snapshot_start_data(infcx))
+        }
+        #[inline]
+        fn end_of_snapshot(
+            infcx: &InferCtxt<'tcx>,
+            this: Self,
+            (t, e): Self::StartData,
+        ) -> Self::EndData {
+            match this {
+                Ok(value) => Ok(T::end_of_snapshot(infcx, value, t)),
+                Err(err) => Err(E::end_of_snapshot(infcx, err, e)),
+            }
+        }
+
+        #[inline]
+        fn avoid_leaks(infcx: &InferCtxt<'tcx>, data: Self::EndData) -> Self {
+            match data {
+                Ok(value) => Ok(T::avoid_leaks(infcx, value)),
+                Err(err) => Err(E::avoid_leaks(infcx, err)),
+            }
+        }
+    }
+
+    impl<'tcx, T: TrivialNoSnapshotLeaks<'tcx>> TrivialNoSnapshotLeaks<'tcx> for Vec<T> {}
+    impl<'tcx, V> TrivialNoSnapshotLeaks<'tcx> for Canonical<'tcx, V> {}
 }
diff --git a/compiler/rustc_infer/src/traits/project.rs b/compiler/rustc_infer/src/traits/project.rs
index 31ceb2343324d..ce1ec1154a39c 100644
--- a/compiler/rustc_infer/src/traits/project.rs
+++ b/compiler/rustc_infer/src/traits/project.rs
@@ -15,7 +15,7 @@ pub use rustc_middle::traits::{EvaluationResult, Reveal};
 pub(crate) type UndoLog<'tcx> =
     snapshot_map::UndoLog<ProjectionCacheKey<'tcx>, ProjectionCacheEntry<'tcx>>;
 
-#[derive(Clone)]
+#[derive(Clone, TypeVisitable, TypeFoldable)]
 pub struct MismatchedProjectionTypes<'tcx> {
     pub err: ty::error::TypeError<'tcx>,
 }
diff --git a/compiler/rustc_middle/src/ty/error.rs b/compiler/rustc_middle/src/ty/error.rs
index e15f03788464e..da0ea6496536b 100644
--- a/compiler/rustc_middle/src/ty/error.rs
+++ b/compiler/rustc_middle/src/ty/error.rs
@@ -27,7 +27,7 @@ impl<T> ExpectedFound<T> {
 }
 
 // Data structures used in type unification
-#[derive(Copy, Clone, Debug, TypeVisitable, PartialEq, Eq)]
+#[derive(Copy, Clone, Debug, TypeVisitable, TypeFoldable, PartialEq, Eq)]
 #[rustc_pass_by_value]
 pub enum TypeError<'tcx> {
     Mismatch,
diff --git a/compiler/rustc_trait_selection/src/infer.rs b/compiler/rustc_trait_selection/src/infer.rs
index f694dd0070363..880d0a666890d 100644
--- a/compiler/rustc_trait_selection/src/infer.rs
+++ b/compiler/rustc_trait_selection/src/infer.rs
@@ -69,15 +69,17 @@ impl<'tcx> InferCtxt<'tcx> {
         self.evaluate_obligation(&obligation).unwrap_or(traits::EvaluationResult::EvaluatedToErr)
     }
 
-    /// Returns `Some` if a type implements a trait shallowly, without side-effects,
-    /// along with any errors that would have been reported upon further obligation
-    /// processing.
+    /// Returns `Some` if a type implements a trait shallowly, along with any errors
+    /// that would have been reported upon further obligation processing.
     ///
     /// - If this returns `Some([])`, then the trait holds modulo regions.
     /// - If this returns `Some([errors..])`, then the trait has an impl for
     /// the self type, but some nested obligations do not hold.
     /// - If this returns `None`, no implementation that applies could be found.
     ///
+    /// FIXME: This cannot use a probe as the `FulfillmentError` would otherwise leak
+    /// inference variables.
+    ///
     /// FIXME(-Znext-solver): Due to the recursive nature of the new solver,
     /// this will probably only ever return `Some([])` or `None`.
     fn type_implements_trait_shallow(
@@ -86,22 +88,20 @@ impl<'tcx> InferCtxt<'tcx> {
         ty: Ty<'tcx>,
         param_env: ty::ParamEnv<'tcx>,
     ) -> Option<Vec<traits::FulfillmentError<'tcx>>> {
-        self.probe(|_snapshot| {
-            let mut selcx = SelectionContext::new(self);
-            match selcx.select(&Obligation::new(
-                self.tcx,
-                ObligationCause::dummy(),
-                param_env,
-                ty::TraitRef::new(self.tcx, trait_def_id, [ty]),
-            )) {
-                Ok(Some(selection)) => {
-                    let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self);
-                    fulfill_cx.register_predicate_obligations(self, selection.nested_obligations());
-                    Some(fulfill_cx.select_all_or_error(self))
-                }
-                Ok(None) | Err(_) => None,
+        let mut selcx = SelectionContext::new(self);
+        match selcx.select(&Obligation::new(
+            self.tcx,
+            ObligationCause::dummy(),
+            param_env,
+            ty::TraitRef::new(self.tcx, trait_def_id, [ty]),
+        )) {
+            Ok(Some(selection)) => {
+                let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self);
+                fulfill_cx.register_predicate_obligations(self, selection.nested_obligations());
+                Some(fulfill_cx.select_all_or_error(self))
             }
-        })
+            Ok(None) | Err(_) => None,
+        }
     }
 }
 
diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/probe.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/probe.rs
index 91fd48807a4d8..540101407bd57 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/probe.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/probe.rs
@@ -1,6 +1,7 @@
 use crate::solve::assembly::Candidate;
 
 use super::EvalCtxt;
+use rustc_infer::infer::snapshot::NoSnapshotLeaks;
 use rustc_middle::traits::{
     query::NoSolution,
     solve::{inspect, CandidateSource, QueryResult},
@@ -15,6 +16,7 @@ pub(in crate::solve) struct ProbeCtxt<'me, 'a, 'tcx, F, T> {
 
 impl<'tcx, F, T> ProbeCtxt<'_, '_, 'tcx, F, T>
 where
+    T: NoSnapshotLeaks<'tcx>,
     F: FnOnce(&T) -> inspect::ProbeKind<'tcx>,
 {
     pub(in crate::solve) fn enter(self, f: impl FnOnce(&mut EvalCtxt<'_, 'tcx>) -> T) -> T {
diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs
index 3fa409eefffcd..fc98b729801fc 100644
--- a/compiler/rustc_trait_selection/src/solve/fulfill.rs
+++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs
@@ -241,22 +241,22 @@ fn fulfillment_error_for_stalled<'tcx>(
     infcx: &InferCtxt<'tcx>,
     obligation: PredicateObligation<'tcx>,
 ) -> FulfillmentError<'tcx> {
-    let code = infcx.probe(|_| {
-        match infcx.evaluate_root_goal(obligation.clone().into(), GenerateProofTree::Never).0 {
-            Ok((_, Certainty::Maybe(MaybeCause::Ambiguity))) => {
-                FulfillmentErrorCode::Ambiguity { overflow: None }
-            }
-            Ok((_, Certainty::Maybe(MaybeCause::Overflow { suggest_increasing_limit }))) => {
-                FulfillmentErrorCode::Ambiguity { overflow: Some(suggest_increasing_limit) }
-            }
-            Ok((_, Certainty::Yes)) => {
-                bug!("did not expect successful goal when collecting ambiguity errors")
-            }
-            Err(_) => {
-                bug!("did not expect selection error when collecting ambiguity errors")
-            }
+    let code = match infcx
+        .probe(|_| infcx.evaluate_root_goal(obligation.clone().into(), GenerateProofTree::Never).0)
+    {
+        Ok((_, Certainty::Maybe(MaybeCause::Ambiguity))) => {
+            FulfillmentErrorCode::Ambiguity { overflow: None }
         }
-    });
+        Ok((_, Certainty::Maybe(MaybeCause::Overflow { suggest_increasing_limit }))) => {
+            FulfillmentErrorCode::Ambiguity { overflow: Some(suggest_increasing_limit) }
+        }
+        Ok((_, Certainty::Yes)) => {
+            bug!("did not expect successful goal when collecting ambiguity errors")
+        }
+        Err(_) => {
+            bug!("did not expect selection error when collecting ambiguity errors")
+        }
+    };
 
     FulfillmentError { obligation: obligation.clone(), code, root_obligation: obligation }
 }
diff --git a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
index 9e3e6a4676efb..08876be739ed7 100644
--- a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
+++ b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
@@ -11,6 +11,7 @@
 
 use rustc_ast_ir::try_visit;
 use rustc_ast_ir::visit::VisitorResult;
+use rustc_infer::infer::snapshot::NoSnapshotLeaks;
 use rustc_infer::infer::InferCtxt;
 use rustc_middle::traits::query::NoSolution;
 use rustc_middle::traits::solve::{inspect, QueryResult};
@@ -201,7 +202,7 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
 
 /// The public API to interact with proof trees.
 pub trait ProofTreeVisitor<'tcx> {
-    type Result: VisitorResult = ();
+    type Result: VisitorResult + NoSnapshotLeaks<'tcx> = ();
 
     fn visit_goal(&mut self, goal: &InspectGoal<'_, 'tcx>) -> Self::Result;
 }
diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
index 66f740b761d32..d1c6e9331a609 100644
--- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
@@ -893,7 +893,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         ty: Ty<'tcx>,
         param_env: ty::ParamEnv<'tcx>,
         cause: &ObligationCause<'tcx>,
-    ) -> Option<ty::PolyExistentialTraitRef<'tcx>> {
+    ) -> Option<DefId> {
         let tcx = self.tcx();
         if tcx.features().trait_upcasting {
             return None;
@@ -922,7 +922,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             .ty()
             .unwrap();
 
-            if let ty::Dynamic(data, ..) = ty.kind() { data.principal() } else { None }
+            if let ty::Dynamic(data, ..) = ty.kind() { data.principal_def_id() } else { None }
         })
     }
 
@@ -993,12 +993,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                     let principal_a = a_data.principal().unwrap();
                     let target_trait_did = principal_def_id_b.unwrap();
                     let source_trait_ref = principal_a.with_self_ty(self.tcx(), source);
-                    if let Some(deref_trait_ref) = self.need_migrate_deref_output_trait_object(
+                    if let Some(deref_trait_def_id) = self.need_migrate_deref_output_trait_object(
                         source,
                         obligation.param_env,
                         &obligation.cause,
                     ) {
-                        if deref_trait_ref.def_id() == target_trait_did {
+                        if deref_trait_def_id == target_trait_did {
                             return;
                         }
                     }
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index a6bd1ba9c3f88..d695724bee110 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -35,6 +35,7 @@ use rustc_hir::def_id::DefId;
 use rustc_infer::infer::BoundRegionConversionTime;
 use rustc_infer::infer::DefineOpaqueTypes;
 use rustc_infer::traits::TraitObligation;
+use rustc_infer::trivial_no_snapshot_leaks;
 use rustc_middle::dep_graph::dep_kinds;
 use rustc_middle::dep_graph::DepNodeIndex;
 use rustc_middle::mir::interpret::ErrorHandled;
@@ -3055,6 +3056,7 @@ impl<'o, 'tcx> fmt::Debug for TraitObligationStack<'o, 'tcx> {
     }
 }
 
+trivial_no_snapshot_leaks!('tcx, ProjectionMatchesProjection);
 pub enum ProjectionMatchesProjection {
     Yes,
     Ambiguous,

From 08c0326872db6b97c79ae8f9e3a580eef1a1bfda Mon Sep 17 00:00:00 2001
From: lcnr <rust@lcnr.de>
Date: Mon, 11 Mar 2024 12:24:15 +0100
Subject: [PATCH 2/4] add comments

---
 .../src/infer/snapshot/check_leaks.rs         |  3 +++
 .../rustc_infer/src/infer/snapshot/fudge.rs   | 24 +++++++++++++++++
 .../rustc_infer/src/infer/snapshot/mod.rs     | 26 +++++++++++++++++++
 3 files changed, 53 insertions(+)

diff --git a/compiler/rustc_infer/src/infer/snapshot/check_leaks.rs b/compiler/rustc_infer/src/infer/snapshot/check_leaks.rs
index 3a0f3465a8777..5dc9139e3eaa6 100644
--- a/compiler/rustc_infer/src/infer/snapshot/check_leaks.rs
+++ b/compiler/rustc_infer/src/infer/snapshot/check_leaks.rs
@@ -4,6 +4,9 @@ use rustc_middle::ty::{self, Ty, TyCtxt};
 use rustc_middle::ty::{TypeSuperVisitable, TypeVisitor};
 use std::ops::ControlFlow;
 
+/// Check for leaking inference variables and placeholders
+/// from snapshot. This is only used if `debug_assertions`
+/// are enabled.
 pub struct HasSnapshotLeaksVisitor {
     universe: ty::UniverseIndex,
     variable_lengths: VariableLengths,
diff --git a/compiler/rustc_infer/src/infer/snapshot/fudge.rs b/compiler/rustc_infer/src/infer/snapshot/fudge.rs
index 0e182cb181105..013d46997c688 100644
--- a/compiler/rustc_infer/src/infer/snapshot/fudge.rs
+++ b/compiler/rustc_infer/src/infer/snapshot/fudge.rs
@@ -98,6 +98,12 @@ impl<'tcx> InferCtxt<'tcx> {
     }
 }
 
+/// To avoid leaking inference variables from snapshots, fudge inference
+/// by replacing inference variables from the snapshot with fresh ones
+/// created outside of it.
+///
+/// To see how this works, check out the documentation of the [`FudgeInference`]
+/// wrapper used by [`fn InferCtxt::fudge_inference_if_ok`].
 #[macro_export]
 macro_rules! fudge_vars_no_snapshot_leaks {
     ($tcx:lifetime, $t:ty) => {
@@ -136,13 +142,22 @@ macro_rules! fudge_vars_no_snapshot_leaks {
     };
 }
 
+/// When rolling back a snapshot, replaces inference variables in `T` created
+/// during the snapshot with new inference variables created afterwards.
 struct FudgeInference<T>(T);
 impl<'tcx, T: TypeFoldable<TyCtxt<'tcx>>> NoSnapshotLeaks<'tcx> for FudgeInference<T> {
     type StartData = VariableLengths;
     type EndData = (T, Option<InferenceFudgeData>);
+
+    /// Store which inference variables already exist at the start
+    /// of the snapshot.
     fn snapshot_start_data(infcx: &InferCtxt<'tcx>) -> Self::StartData {
         infcx.variable_lengths()
     }
+    /// At the end of the snapshot, fetch the metadata for all variables
+    /// created during the snapshot. As these variables get discarded during
+    /// rollback, we have to get this information before rollback and use it
+    /// to create new inference variables after.
     fn end_of_snapshot(
         infcx: &InferCtxt<'tcx>,
         FudgeInference(value): FudgeInference<T>,
@@ -154,6 +169,9 @@ impl<'tcx, T: TypeFoldable<TyCtxt<'tcx>>> NoSnapshotLeaks<'tcx> for FudgeInferen
             (value, None)
         }
     }
+    /// Using the metadata fetched in `fn end_of_snapshot`, replace all leaking
+    /// inference variables with new ones, reusing the metadata of the leaked
+    /// variables.
     fn avoid_leaks(infcx: &InferCtxt<'tcx>, (value, fudge_data): Self::EndData) -> Self {
         if let Some(fudge_data) = fudge_data {
             FudgeInference(fudge_data.fudge_inference(infcx, value))
@@ -163,6 +181,8 @@ impl<'tcx, T: TypeFoldable<TyCtxt<'tcx>>> NoSnapshotLeaks<'tcx> for FudgeInferen
     }
 }
 
+/// At the end of a snpashot, right before rollback, remember all newly created
+/// inference variables and their metadata.
 pub struct InferenceFudgeData {
     type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
     int_vars: Range<IntVid>,
@@ -210,6 +230,10 @@ impl InferenceFudgeData {
     }
 }
 
+/// Using the `InferenceFudgeData` created right before rollback, replace
+/// all leaked inference variables of the snapshot with newly created ones.
+///
+/// This is used after the snapshot has already been rolled back.
 struct InferenceFudger<'a, 'tcx> {
     infcx: &'a InferCtxt<'tcx>,
     data: InferenceFudgeData,
diff --git a/compiler/rustc_infer/src/infer/snapshot/mod.rs b/compiler/rustc_infer/src/infer/snapshot/mod.rs
index 5ab1af7cf9637..1389bde3d4ce5 100644
--- a/compiler/rustc_infer/src/infer/snapshot/mod.rs
+++ b/compiler/rustc_infer/src/infer/snapshot/mod.rs
@@ -138,6 +138,22 @@ pub struct VariableLengths {
     const_vars: usize,
 }
 
+/// When rolling back a snapshot, we discard all inference constraints
+/// added during that snapshot. We also completely remove any inference
+/// variables created during the snapshot. Leaking these inference
+/// variables from the snapshot and later using them can then result
+/// either in an ICE or even accidentally reuse a newly created, totally
+/// separate, inference variable.
+///
+/// To avoid this we make sure that when rolling back snapshots in
+///  `fn probe` and `fn commit_if_ok` we do not return any inference
+/// variables created during this snapshot.
+///
+/// This has a fairly involved setup as we previously did not check this
+/// and now rely on leaking inference variables, e.g. via `TypeError`.
+/// To still avoid ICE we now "fudge inference" in these cases, replacing
+/// any newly created inference variables from inside the snapshot with
+/// new inference variables created outside of it.
 pub trait NoSnapshotLeaks<'tcx> {
     type StartData;
     type EndData;
@@ -150,6 +166,16 @@ pub trait NoSnapshotLeaks<'tcx> {
     fn avoid_leaks(infcx: &InferCtxt<'tcx>, data: Self::EndData) -> Self;
 }
 
+/// A trait implemented by types which cannot contain any inference variables
+/// which could be leaked. The [`NoSnapshotLeaks`] impl for these types is
+/// trivial.
+///
+/// You can mostly think of this as if it is an auto-trait with negative
+/// impls for `Region`, `Ty` and `Const` and a positive impl for `Canonical`.
+/// Actually using an auto-trait instead of manually implementing it for
+/// all types of interest results in overlaps during coherence. Not using
+/// auto-traits will also make it easier to share this code with Rust Analyzer
+/// in the future, as they want to avoid any unstable features.
 pub trait TrivialNoSnapshotLeaks<'tcx> {}
 impl<'tcx, T: TrivialNoSnapshotLeaks<'tcx>> NoSnapshotLeaks<'tcx> for T {
     type StartData = ();

From 0afbf751b01665d672d56785360bf36ca421f813 Mon Sep 17 00:00:00 2001
From: lcnr <rust@lcnr.de>
Date: Mon, 11 Mar 2024 12:45:07 +0100
Subject: [PATCH 3/4] move folder

---
 src/tools/tidy/src/issues.txt                               | 1 -
 .../cycle-trait/cycle-trait-default-type-trait.rs           | 0
 .../cycle-trait/cycle-trait-default-type-trait.stderr       | 0
 .../cycle-trait/cycle-trait-supertrait-direct.rs            | 0
 .../cycle-trait/cycle-trait-supertrait-direct.stderr        | 0
 .../cycle-trait/cycle-trait-supertrait-indirect.rs          | 0
 .../cycle-trait/cycle-trait-supertrait-indirect.stderr      | 0
 .../cycle-trait/super-trait-issue-12511.rs}                 | 0
 .../cycle-trait/super-trait-issue-12511.stderr}             | 6 +++---
 9 files changed, 3 insertions(+), 4 deletions(-)
 rename tests/ui/{ => traits}/cycle-trait/cycle-trait-default-type-trait.rs (100%)
 rename tests/ui/{ => traits}/cycle-trait/cycle-trait-default-type-trait.stderr (100%)
 rename tests/ui/{ => traits}/cycle-trait/cycle-trait-supertrait-direct.rs (100%)
 rename tests/ui/{ => traits}/cycle-trait/cycle-trait-supertrait-direct.stderr (100%)
 rename tests/ui/{ => traits}/cycle-trait/cycle-trait-supertrait-indirect.rs (100%)
 rename tests/ui/{ => traits}/cycle-trait/cycle-trait-supertrait-indirect.stderr (100%)
 rename tests/ui/{cycle-trait/issue-12511.rs => traits/cycle-trait/super-trait-issue-12511.rs} (100%)
 rename tests/ui/{cycle-trait/issue-12511.stderr => traits/cycle-trait/super-trait-issue-12511.stderr} (84%)

diff --git a/src/tools/tidy/src/issues.txt b/src/tools/tidy/src/issues.txt
index 91bbf5041ff58..840ea2aab4752 100644
--- a/src/tools/tidy/src/issues.txt
+++ b/src/tools/tidy/src/issues.txt
@@ -844,7 +844,6 @@
 "ui/coroutine/issue-91477.rs",
 "ui/coroutine/issue-93161.rs",
 "ui/cross-crate/issue-64872/issue-64872.rs",
-"ui/cycle-trait/issue-12511.rs",
 "ui/debuginfo/issue-105386-debuginfo-ub.rs",
 "ui/deprecation/issue-66340-deprecated-attr-non-meta-grammar.rs",
 "ui/deprecation/issue-84637-deprecated-associated-function.rs",
diff --git a/tests/ui/cycle-trait/cycle-trait-default-type-trait.rs b/tests/ui/traits/cycle-trait/cycle-trait-default-type-trait.rs
similarity index 100%
rename from tests/ui/cycle-trait/cycle-trait-default-type-trait.rs
rename to tests/ui/traits/cycle-trait/cycle-trait-default-type-trait.rs
diff --git a/tests/ui/cycle-trait/cycle-trait-default-type-trait.stderr b/tests/ui/traits/cycle-trait/cycle-trait-default-type-trait.stderr
similarity index 100%
rename from tests/ui/cycle-trait/cycle-trait-default-type-trait.stderr
rename to tests/ui/traits/cycle-trait/cycle-trait-default-type-trait.stderr
diff --git a/tests/ui/cycle-trait/cycle-trait-supertrait-direct.rs b/tests/ui/traits/cycle-trait/cycle-trait-supertrait-direct.rs
similarity index 100%
rename from tests/ui/cycle-trait/cycle-trait-supertrait-direct.rs
rename to tests/ui/traits/cycle-trait/cycle-trait-supertrait-direct.rs
diff --git a/tests/ui/cycle-trait/cycle-trait-supertrait-direct.stderr b/tests/ui/traits/cycle-trait/cycle-trait-supertrait-direct.stderr
similarity index 100%
rename from tests/ui/cycle-trait/cycle-trait-supertrait-direct.stderr
rename to tests/ui/traits/cycle-trait/cycle-trait-supertrait-direct.stderr
diff --git a/tests/ui/cycle-trait/cycle-trait-supertrait-indirect.rs b/tests/ui/traits/cycle-trait/cycle-trait-supertrait-indirect.rs
similarity index 100%
rename from tests/ui/cycle-trait/cycle-trait-supertrait-indirect.rs
rename to tests/ui/traits/cycle-trait/cycle-trait-supertrait-indirect.rs
diff --git a/tests/ui/cycle-trait/cycle-trait-supertrait-indirect.stderr b/tests/ui/traits/cycle-trait/cycle-trait-supertrait-indirect.stderr
similarity index 100%
rename from tests/ui/cycle-trait/cycle-trait-supertrait-indirect.stderr
rename to tests/ui/traits/cycle-trait/cycle-trait-supertrait-indirect.stderr
diff --git a/tests/ui/cycle-trait/issue-12511.rs b/tests/ui/traits/cycle-trait/super-trait-issue-12511.rs
similarity index 100%
rename from tests/ui/cycle-trait/issue-12511.rs
rename to tests/ui/traits/cycle-trait/super-trait-issue-12511.rs
diff --git a/tests/ui/cycle-trait/issue-12511.stderr b/tests/ui/traits/cycle-trait/super-trait-issue-12511.stderr
similarity index 84%
rename from tests/ui/cycle-trait/issue-12511.stderr
rename to tests/ui/traits/cycle-trait/super-trait-issue-12511.stderr
index 0246bf219831a..68e0c34456c70 100644
--- a/tests/ui/cycle-trait/issue-12511.stderr
+++ b/tests/ui/traits/cycle-trait/super-trait-issue-12511.stderr
@@ -1,17 +1,17 @@
 error[E0391]: cycle detected when computing the super predicates of `T1`
-  --> $DIR/issue-12511.rs:1:12
+  --> $DIR/super-trait-issue-12511.rs:1:12
    |
 LL | trait T1 : T2 {
    |            ^^
    |
 note: ...which requires computing the super predicates of `T2`...
-  --> $DIR/issue-12511.rs:5:12
+  --> $DIR/super-trait-issue-12511.rs:5:12
    |
 LL | trait T2 : T1 {
    |            ^^
    = note: ...which again requires computing the super predicates of `T1`, completing the cycle
 note: cycle used when checking that `T1` is well-formed
-  --> $DIR/issue-12511.rs:1:1
+  --> $DIR/super-trait-issue-12511.rs:1:1
    |
 LL | / trait T1 : T2 {
 LL | |

From b0ad8d78fe7f264409355c1dc0283697d3904a5b Mon Sep 17 00:00:00 2001
From: lcnr <rust@lcnr.de>
Date: Mon, 11 Mar 2024 13:13:47 +0100
Subject: [PATCH 4/4] add tests

---
 tests/ui/snapshot/leaked-vars-issue-114056.rs | 10 ++++
 .../snapshot/leaked-vars-issue-114056.stderr  | 27 +++++++++++
 tests/ui/snapshot/leaked-vars-issue-122098.rs | 14 ++++++
 .../snapshot/leaked-vars-issue-122098.stderr  | 46 +++++++++++++++++++
 4 files changed, 97 insertions(+)
 create mode 100644 tests/ui/snapshot/leaked-vars-issue-114056.rs
 create mode 100644 tests/ui/snapshot/leaked-vars-issue-114056.stderr
 create mode 100644 tests/ui/snapshot/leaked-vars-issue-122098.rs
 create mode 100644 tests/ui/snapshot/leaked-vars-issue-122098.stderr

diff --git a/tests/ui/snapshot/leaked-vars-issue-114056.rs b/tests/ui/snapshot/leaked-vars-issue-114056.rs
new file mode 100644
index 0000000000000..039692306c59d
--- /dev/null
+++ b/tests/ui/snapshot/leaked-vars-issue-114056.rs
@@ -0,0 +1,10 @@
+// Regression test for #114056. Fixed by #111516.
+struct P<Q>(Q);
+impl<Q> P<Q> {
+    fn foo(&self) {
+        self.partial_cmp(())
+        //~^ ERROR the method `partial_cmp` exists for reference `&P<Q>`
+    }
+}
+
+fn main() {}
diff --git a/tests/ui/snapshot/leaked-vars-issue-114056.stderr b/tests/ui/snapshot/leaked-vars-issue-114056.stderr
new file mode 100644
index 0000000000000..f4c4f49028042
--- /dev/null
+++ b/tests/ui/snapshot/leaked-vars-issue-114056.stderr
@@ -0,0 +1,27 @@
+error[E0599]: the method `partial_cmp` exists for reference `&P<Q>`, but its trait bounds were not satisfied
+  --> $DIR/leaked-vars-issue-114056.rs:5:14
+   |
+LL | struct P<Q>(Q);
+   | ----------- doesn't satisfy `P<Q>: Iterator` or `P<Q>: PartialOrd<_>`
+...
+LL |         self.partial_cmp(())
+   |              ^^^^^^^^^^^ method cannot be called on `&P<Q>` due to unsatisfied trait bounds
+   |
+   = note: the following trait bounds were not satisfied:
+           `P<Q>: PartialOrd<_>`
+           which is required by `&P<Q>: PartialOrd<&_>`
+           `&P<Q>: Iterator`
+           which is required by `&mut &P<Q>: Iterator`
+           `P<Q>: Iterator`
+           which is required by `&mut P<Q>: Iterator`
+note: the trait `Iterator` must be implemented
+  --> $SRC_DIR/core/src/iter/traits/iterator.rs:LL:COL
+help: consider annotating `P<Q>` with `#[derive(PartialEq, PartialOrd)]`
+   |
+LL + #[derive(PartialEq, PartialOrd)]
+LL | struct P<Q>(Q);
+   |
+
+error: aborting due to 1 previous error
+
+For more information about this error, try `rustc --explain E0599`.
diff --git a/tests/ui/snapshot/leaked-vars-issue-122098.rs b/tests/ui/snapshot/leaked-vars-issue-122098.rs
new file mode 100644
index 0000000000000..dd9344d93980d
--- /dev/null
+++ b/tests/ui/snapshot/leaked-vars-issue-122098.rs
@@ -0,0 +1,14 @@
+// Regression test for #122098. Has been slightly minimized, fixed by #122189.
+trait LendingIterator: Sized {
+    type Item<'q>;
+
+    fn for_each(self, f: Box<dyn FnMut(Self::Item<'_>)>) {}
+}
+
+struct Query;
+fn main() {
+    LendingIterator::for_each(Query, Box::new);
+    //~^ ERROR the trait bound `Query: LendingIterator` is not satisfied
+    //~| ERROR mismatched types
+    //~| ERROR the trait bound `Query: LendingIterator` is not satisfied
+}
diff --git a/tests/ui/snapshot/leaked-vars-issue-122098.stderr b/tests/ui/snapshot/leaked-vars-issue-122098.stderr
new file mode 100644
index 0000000000000..7667349755152
--- /dev/null
+++ b/tests/ui/snapshot/leaked-vars-issue-122098.stderr
@@ -0,0 +1,46 @@
+error[E0277]: the trait bound `Query: LendingIterator` is not satisfied
+  --> $DIR/leaked-vars-issue-122098.rs:10:31
+   |
+LL |     LendingIterator::for_each(Query, Box::new);
+   |     ------------------------- ^^^^^ the trait `LendingIterator` is not implemented for `Query`
+   |     |
+   |     required by a bound introduced by this call
+   |
+help: this trait has no implementations, consider adding one
+  --> $DIR/leaked-vars-issue-122098.rs:2:1
+   |
+LL | trait LendingIterator: Sized {
+   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+error[E0308]: mismatched types
+  --> $DIR/leaked-vars-issue-122098.rs:10:38
+   |
+LL |     LendingIterator::for_each(Query, Box::new);
+   |     -------------------------        ^^^^^^^^ expected `Box<dyn FnMut(...)>`, found fn item
+   |     |
+   |     arguments to this function are incorrect
+   |
+   = note: expected struct `Box<(dyn for<'a> FnMut(<Query as LendingIterator>::Item<'a>) + 'static)>`
+             found fn item `fn(_) -> Box<_> {Box::<_>::new}`
+note: method defined here
+  --> $DIR/leaked-vars-issue-122098.rs:5:8
+   |
+LL |     fn for_each(self, f: Box<dyn FnMut(Self::Item<'_>)>) {}
+   |        ^^^^^^^^       ---------------------------------
+
+error[E0277]: the trait bound `Query: LendingIterator` is not satisfied
+  --> $DIR/leaked-vars-issue-122098.rs:10:38
+   |
+LL |     LendingIterator::for_each(Query, Box::new);
+   |                                      ^^^^^^^^ the trait `LendingIterator` is not implemented for `Query`
+   |
+help: this trait has no implementations, consider adding one
+  --> $DIR/leaked-vars-issue-122098.rs:2:1
+   |
+LL | trait LendingIterator: Sized {
+   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+error: aborting due to 3 previous errors
+
+Some errors have detailed explanations: E0277, E0308.
+For more information about an error, try `rustc --explain E0277`.