From da2054f389fbf631b0082dd3cdaf8724f9e93f56 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Fri, 12 Jul 2024 13:07:16 -0400
Subject: [PATCH] Add cross-crate precise capturing support to rustdoc

---
 compiler/rustc_hir_analysis/src/collect.rs    | 21 ++++++++++++
 .../src/rmeta/decoder/cstore_impl.rs          | 10 ++++++
 compiler/rustc_metadata/src/rmeta/encoder.rs  | 10 ++++++
 compiler/rustc_metadata/src/rmeta/mod.rs      |  1 +
 compiler/rustc_middle/src/query/mod.rs        |  8 +++++
 src/librustdoc/clean/mod.rs                   | 34 +++++++++----------
 tests/rustdoc/auxiliary/precise-capturing.rs  |  7 ++++
 tests/rustdoc/impl-trait-precise-capturing.rs | 13 +++++++
 8 files changed, 87 insertions(+), 17 deletions(-)
 create mode 100644 tests/rustdoc/auxiliary/precise-capturing.rs

diff --git a/compiler/rustc_hir_analysis/src/collect.rs b/compiler/rustc_hir_analysis/src/collect.rs
index e0aad29916323..8318e1f8955c8 100644
--- a/compiler/rustc_hir_analysis/src/collect.rs
+++ b/compiler/rustc_hir_analysis/src/collect.rs
@@ -84,6 +84,7 @@ pub fn provide(providers: &mut Providers) {
         coroutine_kind,
         coroutine_for_closure,
         is_type_alias_impl_trait,
+        rendered_precise_capturing_args,
         ..*providers
     };
 }
@@ -1880,3 +1881,23 @@ fn is_type_alias_impl_trait<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> bool
         _ => bug!("tried getting opaque_ty_origin for non-opaque: {:?}", def_id),
     }
 }
+
+fn rendered_precise_capturing_args<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    def_id: LocalDefId,
+) -> Option<&'tcx [Symbol]> {
+    if let Some(ty::ImplTraitInTraitData::Trait { opaque_def_id, .. }) =
+        tcx.opt_rpitit_info(def_id.to_def_id())
+    {
+        return tcx.rendered_precise_capturing_args(opaque_def_id);
+    }
+
+    tcx.hir_node_by_def_id(def_id).expect_item().expect_opaque_ty().bounds.iter().find_map(
+        |bound| match bound {
+            hir::GenericBound::Use(args, ..) => {
+                Some(&*tcx.arena.alloc_from_iter(args.iter().map(|arg| arg.name())))
+            }
+            _ => None,
+        },
+    )
+}
diff --git a/compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs b/compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs
index 6b240f0f0b3de..bbd9ab5704fd8 100644
--- a/compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs
+++ b/compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs
@@ -72,6 +72,15 @@ impl<'a, 'tcx, T: Copy + Decodable<DecodeContext<'a, 'tcx>>> ProcessQueryValue<'
     }
 }
 
+impl<'a, 'tcx, T: Copy + Decodable<DecodeContext<'a, 'tcx>>>
+    ProcessQueryValue<'tcx, Option<&'tcx [T]>> for Option<DecodeIterator<'a, 'tcx, T>>
+{
+    #[inline(always)]
+    fn process_decoded(self, tcx: TyCtxt<'tcx>, _err: impl Fn() -> !) -> Option<&'tcx [T]> {
+        if let Some(iter) = self { Some(&*tcx.arena.alloc_from_iter(iter)) } else { None }
+    }
+}
+
 impl ProcessQueryValue<'_, Option<DeprecationEntry>> for Option<Deprecation> {
     #[inline(always)]
     fn process_decoded(self, _tcx: TyCtxt<'_>, _err: impl Fn() -> !) -> Option<DeprecationEntry> {
@@ -249,6 +258,7 @@ provide! { tcx, def_id, other, cdata,
             .process_decoded(tcx, || panic!("{def_id:?} does not have coerce_unsized_info"))) }
     mir_const_qualif => { table }
     rendered_const => { table }
+    rendered_precise_capturing_args => { table }
     asyncness => { table_direct }
     fn_arg_names => { table }
     coroutine_kind => { table_direct }
diff --git a/compiler/rustc_metadata/src/rmeta/encoder.rs b/compiler/rustc_metadata/src/rmeta/encoder.rs
index 209316ca20fdb..8596a0645e4a9 100644
--- a/compiler/rustc_metadata/src/rmeta/encoder.rs
+++ b/compiler/rustc_metadata/src/rmeta/encoder.rs
@@ -1496,6 +1496,7 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
                 self.tables
                     .is_type_alias_impl_trait
                     .set(def_id.index, self.tcx.is_type_alias_impl_trait(def_id));
+                self.encode_precise_capturing_args(def_id);
             }
             if tcx.impl_method_has_trait_impl_trait_tys(def_id)
                 && let Ok(table) = self.tcx.collect_return_position_impl_trait_in_trait_tys(def_id)
@@ -1635,6 +1636,7 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
                     self.tables.assumed_wf_types_for_rpitit[def_id]
                         <- self.tcx.assumed_wf_types_for_rpitit(def_id)
                 );
+                self.encode_precise_capturing_args(def_id);
             }
         }
         if item.is_effects_desugaring {
@@ -1642,6 +1644,14 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
         }
     }
 
+    fn encode_precise_capturing_args(&mut self, def_id: DefId) {
+        let Some(precise_capturing_args) = self.tcx.rendered_precise_capturing_args(def_id) else {
+            return;
+        };
+
+        record_array!(self.tables.rendered_precise_capturing_args[def_id] <- precise_capturing_args);
+    }
+
     fn encode_mir(&mut self) {
         if self.is_proc_macro {
             return;
diff --git a/compiler/rustc_metadata/src/rmeta/mod.rs b/compiler/rustc_metadata/src/rmeta/mod.rs
index 2a44b3423ae2f..e565c8c1ea1c9 100644
--- a/compiler/rustc_metadata/src/rmeta/mod.rs
+++ b/compiler/rustc_metadata/src/rmeta/mod.rs
@@ -442,6 +442,7 @@ define_tables! {
     coerce_unsized_info: Table<DefIndex, LazyValue<ty::adjustment::CoerceUnsizedInfo>>,
     mir_const_qualif: Table<DefIndex, LazyValue<mir::ConstQualifs>>,
     rendered_const: Table<DefIndex, LazyValue<String>>,
+    rendered_precise_capturing_args: Table<DefIndex, LazyArray<Symbol>>,
     asyncness: Table<DefIndex, ty::Asyncness>,
     fn_arg_names: Table<DefIndex, LazyArray<Ident>>,
     coroutine_kind: Table<DefIndex, hir::CoroutineKind>,
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 817c7157b6820..c7ea1d4338366 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -1261,6 +1261,7 @@ rustc_queries! {
         desc { |tcx| "looking up function parameter names for `{}`", tcx.def_path_str(def_id) }
         separate_provide_extern
     }
+
     /// Gets the rendered value of the specified constant or associated constant.
     /// Used by rustdoc.
     query rendered_const(def_id: DefId) -> &'tcx String {
@@ -1268,6 +1269,13 @@ rustc_queries! {
         desc { |tcx| "rendering constant initializer of `{}`", tcx.def_path_str(def_id) }
         separate_provide_extern
     }
+
+    /// Gets the rendered precise capturing args for an opaque for use in rustdoc.
+    query rendered_precise_capturing_args(def_id: DefId) -> Option<&'tcx [Symbol]> {
+        desc { |tcx| "rendering precise capturing args for `{}`", tcx.def_path_str(def_id) }
+        separate_provide_extern
+    }
+
     query impl_parent(def_id: DefId) -> Option<DefId> {
         desc { |tcx| "computing specialization parent impl of `{}`", tcx.def_path_str(def_id) }
         separate_provide_extern
diff --git a/src/librustdoc/clean/mod.rs b/src/librustdoc/clean/mod.rs
index a0e28d2f55c70..4dd065f568087 100644
--- a/src/librustdoc/clean/mod.rs
+++ b/src/librustdoc/clean/mod.rs
@@ -461,13 +461,7 @@ fn clean_projection<'tcx>(
     def_id: Option<DefId>,
 ) -> Type {
     if cx.tcx.is_impl_trait_in_trait(ty.skip_binder().def_id) {
-        let bounds = cx
-            .tcx
-            .explicit_item_bounds(ty.skip_binder().def_id)
-            .iter_instantiated_copied(cx.tcx, ty.skip_binder().args)
-            .map(|(pred, _)| pred)
-            .collect::<Vec<_>>();
-        return clean_middle_opaque_bounds(cx, bounds);
+        return clean_middle_opaque_bounds(cx, ty.skip_binder().def_id, ty.skip_binder().args);
     }
 
     let trait_ = clean_trait_ref_with_constraints(
@@ -2243,13 +2237,7 @@ pub(crate) fn clean_middle_ty<'tcx>(
                 *cx.current_type_aliases.entry(def_id).or_insert(0) += 1;
                 // Grab the "TraitA + TraitB" from `impl TraitA + TraitB`,
                 // by looking up the bounds associated with the def_id.
-                let bounds = cx
-                    .tcx
-                    .explicit_item_bounds(def_id)
-                    .iter_instantiated_copied(cx.tcx, args)
-                    .map(|(bound, _)| bound)
-                    .collect::<Vec<_>>();
-                let ty = clean_middle_opaque_bounds(cx, bounds);
+                let ty = clean_middle_opaque_bounds(cx, def_id, args);
                 if let Some(count) = cx.current_type_aliases.get_mut(&def_id) {
                     *count -= 1;
                     if *count == 0 {
@@ -2272,12 +2260,20 @@ pub(crate) fn clean_middle_ty<'tcx>(
 
 fn clean_middle_opaque_bounds<'tcx>(
     cx: &mut DocContext<'tcx>,
-    bounds: Vec<ty::Clause<'tcx>>,
+    impl_trait_def_id: DefId,
+    args: ty::GenericArgsRef<'tcx>,
 ) -> Type {
     let mut has_sized = false;
+
+    let bounds: Vec<_> = cx
+        .tcx
+        .explicit_item_bounds(impl_trait_def_id)
+        .iter_instantiated_copied(cx.tcx, args)
+        .collect();
+
     let mut bounds = bounds
         .iter()
-        .filter_map(|bound| {
+        .filter_map(|(bound, _)| {
             let bound_predicate = bound.kind();
             let trait_ref = match bound_predicate.skip_binder() {
                 ty::ClauseKind::Trait(tr) => bound_predicate.rebind(tr.trait_ref),
@@ -2296,7 +2292,7 @@ fn clean_middle_opaque_bounds<'tcx>(
 
             let bindings: ThinVec<_> = bounds
                 .iter()
-                .filter_map(|bound| {
+                .filter_map(|(bound, _)| {
                     if let ty::ClauseKind::Projection(proj) = bound.kind().skip_binder() {
                         if proj.projection_term.trait_ref(cx.tcx) == trait_ref.skip_binder() {
                             Some(AssocItemConstraint {
@@ -2336,6 +2332,10 @@ fn clean_middle_opaque_bounds<'tcx>(
         bounds.insert(0, GenericBound::sized(cx));
     }
 
+    if let Some(args) = cx.tcx.rendered_precise_capturing_args(impl_trait_def_id) {
+        bounds.push(GenericBound::Use(args.to_vec()));
+    }
+
     ImplTrait(bounds)
 }
 
diff --git a/tests/rustdoc/auxiliary/precise-capturing.rs b/tests/rustdoc/auxiliary/precise-capturing.rs
new file mode 100644
index 0000000000000..531d4dfdccc6e
--- /dev/null
+++ b/tests/rustdoc/auxiliary/precise-capturing.rs
@@ -0,0 +1,7 @@
+#![feature(precise_capturing)]
+
+pub fn cross_crate_empty() -> impl Sized + use<> {}
+
+pub fn cross_crate_missing() -> impl Sized {}
+
+pub fn cross_crate_args<'a, T, const N: usize>() -> impl Sized + use<'a, T, N> {}
diff --git a/tests/rustdoc/impl-trait-precise-capturing.rs b/tests/rustdoc/impl-trait-precise-capturing.rs
index d1987a555c151..a964a1f8518f5 100644
--- a/tests/rustdoc/impl-trait-precise-capturing.rs
+++ b/tests/rustdoc/impl-trait-precise-capturing.rs
@@ -1,6 +1,10 @@
+//@ aux-build:precise-capturing.rs
+
 #![crate_name = "foo"]
 #![feature(precise_capturing)]
 
+extern crate precise_capturing;
+
 //@ has foo/fn.two.html '//section[@id="main-content"]//pre' "-> impl Sized + use<'b, 'a>"
 pub fn two<'a, 'b, 'c>() -> impl Sized + use<'b, 'a /* no 'c */> {}
 
@@ -12,3 +16,12 @@ pub fn none() -> impl Sized + use<> {}
 
 //@ has foo/fn.first.html '//section[@id="main-content"]//pre' "-> impl use<> + Sized"
 pub fn first() -> impl use<> + Sized {}
+
+//@ has foo/fn.cross_crate_empty.html '//section[@id="main-content"]//pre' "-> impl Sized + use<>"
+pub use precise_capturing::cross_crate_empty;
+
+//@ matches foo/fn.cross_crate_missing.html '//section[@id="main-content"]//pre' "-> impl Sized$"
+pub use precise_capturing::cross_crate_missing;
+
+//@ has foo/fn.cross_crate_args.html '//section[@id="main-content"]//pre' "-> impl Sized + use<'a, T, N>"
+pub use precise_capturing::cross_crate_args;