Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 43051ca

Browse files
committedSep 7, 2024·
Prereq6 for async drop - templated coroutine processing and layout
1 parent 175820b commit 43051ca

File tree

12 files changed

+271
-40
lines changed

12 files changed

+271
-40
lines changed
 

‎compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs‎

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,8 +669,7 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
669669
_ => unreachable!(),
670670
};
671671

672-
let coroutine_layout =
673-
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
672+
let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.args).unwrap();
674673

675674
let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
676675
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);

‎compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs‎

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,8 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
146146
DIFlags::FlagZero,
147147
),
148148
|cx, coroutine_type_di_node| {
149-
let coroutine_layout = cx
150-
.tcx
151-
.coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
152-
.unwrap();
149+
let coroutine_layout =
150+
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args).unwrap();
153151

154152
let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
155153
coroutine_type_and_layout.variants

‎compiler/rustc_middle/src/arena.rs‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ macro_rules! arena_types {
99
($macro:path) => (
1010
$macro!([
1111
[] layout: rustc_target::abi::LayoutS<rustc_target::abi::FieldIdx, rustc_target::abi::VariantIdx>,
12+
[] proxy_coroutine_layout: rustc_middle::mir::CoroutineLayout<'tcx>,
1213
[] fn_abi: rustc_target::abi::call::FnAbi<'tcx, rustc_middle::ty::Ty<'tcx>>,
1314
// AdtDef are interned and compared by address
1415
[decode] adt_def: rustc_middle::ty::AdtDefData,

‎compiler/rustc_middle/src/query/mod.rs‎

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,13 @@ rustc_queries! {
506506
desc { |tcx| "elaborating drops for `{}`", tcx.def_path_str(key) }
507507
}
508508

509+
query templated_mir_drops_elaborated_and_const_checked(ty: Ty<'tcx>)
510+
-> &'tcx Steal<mir::Body<'tcx>>
511+
{
512+
no_hash
513+
desc { |tcx| "elaborating drops for templated mir `{}`", ty }
514+
}
515+
509516
query mir_for_ctfe(
510517
key: DefId
511518
) -> &'tcx mir::Body<'tcx> {
@@ -570,6 +577,11 @@ rustc_queries! {
570577
desc { |tcx| "checking for `#[coverage(..)]` on `{}`", tcx.def_path_str(key) }
571578
}
572579

580+
/// MIR for templated coroutine after our optimization passes have run.
581+
query templated_optimized_mir(ty: Ty<'tcx>) -> &'tcx mir::Body<'tcx> {
582+
desc { |tcx| "optimizing templated MIR for `{}`", ty }
583+
}
584+
573585
/// Summarizes coverage IDs inserted by the `InstrumentCoverage` MIR pass
574586
/// (for compiler option `-Cinstrument-coverage`), after MIR optimizations
575587
/// have had a chance to potentially remove some of them.
@@ -1161,7 +1173,11 @@ rustc_queries! {
11611173
/// Generates a MIR body for the shim.
11621174
query mir_shims(key: ty::InstanceKind<'tcx>) -> &'tcx mir::Body<'tcx> {
11631175
arena_cache
1164-
desc { |tcx| "generating MIR shim for `{}`", tcx.def_path_str(key.def_id()) }
1176+
desc {
1177+
|tcx| "generating MIR shim for `{}`, instance={:?}",
1178+
tcx.def_path_str(key.def_id()),
1179+
key
1180+
}
11651181
}
11661182

11671183
/// The `symbol_name` query provides the symbol name for calling a

‎compiler/rustc_middle/src/ty/layout.rs‎

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -896,22 +896,57 @@ where
896896
i,
897897
),
898898

899-
ty::Coroutine(def_id, args) => match this.variants {
900-
Variants::Single { index } => TyMaybeWithLayout::Ty(
901-
args.as_coroutine()
902-
.state_tys(def_id, tcx)
903-
.nth(index.as_usize())
904-
.unwrap()
905-
.nth(i)
906-
.unwrap(),
907-
),
908-
Variants::Multiple { tag, tag_field, .. } => {
909-
if i == tag_field {
910-
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
899+
ty::Coroutine(def_id, args) => {
900+
// layout of `async_drop_in_place<T>::{closure}` in case,
901+
// when T is a coroutine, contains this internal coroutine's ref
902+
if tcx.is_templated_coroutine(def_id) {
903+
fn find_impl_coroutine<'tcx>(
904+
tcx: TyCtxt<'tcx>,
905+
mut cor_ty: Ty<'tcx>,
906+
) -> Ty<'tcx> {
907+
let mut ty = cor_ty;
908+
loop {
909+
if let ty::Coroutine(def_id, args) = ty.kind() {
910+
cor_ty = ty;
911+
if tcx.is_templated_coroutine(*def_id) {
912+
ty = args.first().unwrap().expect_ty();
913+
continue;
914+
} else {
915+
return cor_ty;
916+
}
917+
} else {
918+
return cor_ty;
919+
}
920+
}
921+
}
922+
let arg_cor_ty = args.first().unwrap().expect_ty();
923+
if arg_cor_ty.is_coroutine() {
924+
assert!(i == 0);
925+
let impl_cor_ty = find_impl_coroutine(tcx, arg_cor_ty);
926+
return TyMaybeWithLayout::Ty(Ty::new_mut_ref(
927+
tcx,
928+
tcx.lifetimes.re_static,
929+
impl_cor_ty,
930+
));
911931
}
912-
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
913932
}
914-
},
933+
match this.variants {
934+
Variants::Single { index } => TyMaybeWithLayout::Ty(
935+
args.as_coroutine()
936+
.state_tys(def_id, tcx)
937+
.nth(index.as_usize())
938+
.unwrap()
939+
.nth(i)
940+
.unwrap(),
941+
),
942+
Variants::Multiple { tag, tag_field, .. } => {
943+
if i == tag_field {
944+
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
945+
}
946+
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
947+
}
948+
}
949+
}
915950

916951
ty::Tuple(tys) => TyMaybeWithLayout::Ty(tys[i]),
917952

‎compiler/rustc_middle/src/ty/mod.rs‎

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use rustc_errors::{Diag, ErrorGuaranteed, StashKey};
3636
use rustc_hir::def::{CtorKind, CtorOf, DefKind, DocLinkResMap, LifetimeRes, Res};
3737
use rustc_hir::def_id::{CrateNum, DefId, DefIdMap, LocalDefId, LocalDefIdMap};
3838
use rustc_hir::LangItem;
39+
use rustc_index::bit_set::BitMatrix;
3940
use rustc_index::IndexVec;
4041
use rustc_macros::{
4142
extension, Decodable, Encodable, HashStable, TyDecodable, TyEncodable, TypeFoldable,
@@ -110,7 +111,7 @@ pub use self::IntVarValue::*;
110111
use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason};
111112
use crate::metadata::ModChild;
112113
use crate::middle::privacy::EffectiveVisibilities;
113-
use crate::mir::{Body, CoroutineLayout};
114+
use crate::mir::{Body, CoroutineLayout, CoroutineSavedLocal, CoroutineSavedTy, SourceInfo};
114115
use crate::query::Providers;
115116
use crate::traits::{self, Reveal};
116117
use crate::ty;
@@ -1771,7 +1772,7 @@ impl<'tcx> TyCtxt<'tcx> {
17711772
| ty::InstanceKind::FnPtrAddrShim(..)
17721773
| ty::InstanceKind::AsyncDropGlueCtorShim(..) => self.mir_shims(instance),
17731774
// async drop glue should be processed specifically, as a templated coroutine
1774-
ty::InstanceKind::AsyncDropGlue(_, _ty) => todo!(),
1775+
ty::InstanceKind::AsyncDropGlue(_, ty) => self.templated_optimized_mir(ty),
17751776
}
17761777
}
17771778

@@ -1851,16 +1852,17 @@ impl<'tcx> TyCtxt<'tcx> {
18511852
self.def_kind(trait_def_id) == DefKind::TraitAlias
18521853
}
18531854

1854-
/// Returns layout of a coroutine. Layout might be unavailable if the
1855+
/// Returns layout of a non-templated coroutine. Layout might be unavailable if the
18551856
/// coroutine is tainted by errors.
18561857
///
18571858
/// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
18581859
/// e.g. `args.as_coroutine().kind_ty()`.
1859-
pub fn coroutine_layout(
1860+
pub fn ordinary_coroutine_layout(
18601861
self,
18611862
def_id: DefId,
18621863
coroutine_kind_ty: Ty<'tcx>,
18631864
) -> Option<&'tcx CoroutineLayout<'tcx>> {
1865+
debug_assert_ne!(Some(def_id), self.lang_items().async_drop_in_place_poll_fn());
18641866
let mir = self.optimized_mir(def_id);
18651867
// Regular coroutine
18661868
if coroutine_kind_ty.is_unit() {
@@ -1890,6 +1892,66 @@ impl<'tcx> TyCtxt<'tcx> {
18901892
}
18911893
}
18921894

1895+
/// Returns layout of a templated coroutine. Layout might be unavailable if the
1896+
/// coroutine is tainted by errors. Atm, the only templated coroutine is
1897+
/// `async_drop_in_place<T>::{closure}` returned from `async fn async_drop_in_place<T>(..)`.
1898+
pub fn templated_coroutine_layout(self, ty: Ty<'tcx>) -> Option<&'tcx CoroutineLayout<'tcx>> {
1899+
self.templated_optimized_mir(ty).coroutine_layout_raw()
1900+
}
1901+
1902+
/// Returns layout of a templated (or not) coroutine. Layout might be unavailable if the
1903+
/// coroutine is tainted by errors.
1904+
pub fn coroutine_layout(
1905+
self,
1906+
def_id: DefId,
1907+
args: GenericArgsRef<'tcx>,
1908+
) -> Option<&'tcx CoroutineLayout<'tcx>> {
1909+
if Some(def_id) == self.lang_items().async_drop_in_place_poll_fn() {
1910+
fn find_impl_coroutine<'tcx>(tcx: TyCtxt<'tcx>, mut cor_ty: Ty<'tcx>) -> Ty<'tcx> {
1911+
let mut ty = cor_ty;
1912+
loop {
1913+
if let ty::Coroutine(def_id, args) = ty.kind() {
1914+
cor_ty = ty;
1915+
if tcx.is_templated_coroutine(*def_id) {
1916+
ty = args.first().unwrap().expect_ty();
1917+
continue;
1918+
} else {
1919+
return cor_ty;
1920+
}
1921+
} else {
1922+
return cor_ty;
1923+
}
1924+
}
1925+
}
1926+
// layout of `async_drop_in_place<T>::{closure}` in case,
1927+
// when T is a coroutine, contains this internal coroutine's ref
1928+
let arg_cor_ty = args.first().unwrap().expect_ty();
1929+
if arg_cor_ty.is_coroutine() {
1930+
let impl_cor_ty = find_impl_coroutine(self, arg_cor_ty);
1931+
let impl_ref = Ty::new_mut_ref(self, self.lifetimes.re_static, impl_cor_ty);
1932+
let span = self.def_span(def_id);
1933+
let source_info = SourceInfo::outermost(span);
1934+
let proxy_layout = CoroutineLayout {
1935+
field_tys: [CoroutineSavedTy {
1936+
ty: impl_ref,
1937+
source_info,
1938+
ignore_for_traits: true,
1939+
}]
1940+
.into(),
1941+
field_names: [None].into(),
1942+
variant_fields: [IndexVec::from([CoroutineSavedLocal::ZERO])].into(),
1943+
variant_source_info: [source_info].into(),
1944+
storage_conflicts: BitMatrix::new(1, 1),
1945+
};
1946+
return Some(self.arena.alloc(proxy_layout));
1947+
} else {
1948+
self.templated_coroutine_layout(Ty::new_coroutine(self, def_id, args))
1949+
}
1950+
} else {
1951+
self.ordinary_coroutine_layout(def_id, args.as_coroutine().kind_ty())
1952+
}
1953+
}
1954+
18931955
/// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
18941956
/// If it implements no trait, returns `None`.
18951957
pub fn trait_id_of_impl(self, def_id: DefId) -> Option<DefId> {

‎compiler/rustc_middle/src/ty/sty.rs‎

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
7878
#[inline]
7979
fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
8080
// FIXME requires optimized MIR
81-
FIRST_VARIANT
82-
..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
81+
FIRST_VARIANT..tcx.coroutine_layout(def_id, self.args).unwrap().variant_fields.next_index()
8382
}
8483

8584
/// The discriminant for the given variant. Panics if the `variant_index` is
@@ -139,10 +138,14 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
139138
def_id: DefId,
140139
tcx: TyCtxt<'tcx>,
141140
) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
142-
let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
141+
let layout = tcx.coroutine_layout(def_id, self.args).unwrap();
143142
layout.variant_fields.iter().map(move |variant| {
144143
variant.iter().map(move |field| {
145-
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
144+
if tcx.is_templated_coroutine(def_id) {
145+
layout.field_tys[*field].ty
146+
} else {
147+
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
148+
}
146149
})
147150
})
148151
}

‎compiler/rustc_mir_dataflow/src/value_analysis.rs‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,9 @@ impl<'tcx> Map<'tcx> {
792792
if exclude.contains(local) {
793793
continue;
794794
}
795+
if decl.ty.is_templated_coroutine(tcx) {
796+
continue;
797+
}
795798

796799
// Create a place for the local.
797800
debug_assert!(self.locals[local].is_none());

‎compiler/rustc_mir_transform/src/known_panics_lint.rs‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,11 @@ impl CanConstProp {
883883
};
884884
for (local, val) in cpv.can_const_prop.iter_enumerated_mut() {
885885
let ty = body.local_decls[local].ty;
886-
if ty.is_union() {
886+
if ty.is_templated_coroutine(tcx) {
887+
// No const propagation for templated coroutine (AsyncDropGlue)
888+
*val = ConstPropMode::NoPropagation;
889+
continue;
890+
} else if ty.is_union() {
887891
// Unions are incompatible with the current implementation of
888892
// const prop because Rust has no concept of an active
889893
// variant of a union

‎compiler/rustc_mir_transform/src/lib.rs‎

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use rustc_middle::mir::{
2929
MirPhase, Operand, Place, ProjectionElem, Promoted, RuntimePhase, Rvalue, SourceInfo,
3030
Statement, StatementKind, TerminatorKind, START_BLOCK,
3131
};
32-
use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt};
32+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
3333
use rustc_middle::util::Providers;
3434
use rustc_middle::{bug, query, span_bug};
3535
use rustc_span::source_map::Spanned;
@@ -121,9 +121,11 @@ pub fn provide(providers: &mut Providers) {
121121
mir_const_qualif,
122122
mir_promoted,
123123
mir_drops_elaborated_and_const_checked,
124+
templated_mir_drops_elaborated_and_const_checked,
124125
mir_for_ctfe,
125126
mir_coroutine_witnesses: coroutine::mir_coroutine_witnesses,
126127
optimized_mir,
128+
templated_optimized_mir,
127129
is_mir_available,
128130
is_ctfe_mir_available: is_mir_available,
129131
mir_callgraph_reachable: inline::cycle::mir_callgraph_reachable,
@@ -459,6 +461,21 @@ fn mir_drops_elaborated_and_const_checked(tcx: TyCtxt<'_>, def: LocalDefId) -> &
459461
tcx.alloc_steal_mir(body)
460462
}
461463

464+
/// mir_drops_elaborated_and_const_checked simplified analog for templated coroutine
465+
fn templated_mir_drops_elaborated_and_const_checked<'tcx>(
466+
tcx: TyCtxt<'tcx>,
467+
ty: Ty<'tcx>,
468+
) -> &'tcx Steal<Body<'tcx>> {
469+
let ty::Coroutine(def_id, _) = ty.kind() else {
470+
bug!();
471+
};
472+
assert!(ty.is_templated_coroutine(tcx));
473+
474+
let instance = ty::InstanceKind::AsyncDropGlue(*def_id, ty);
475+
let body = tcx.mir_shims(instance).clone();
476+
tcx.alloc_steal_mir(body)
477+
}
478+
462479
// Made public such that `mir_drops_elaborated_and_const_checked` can be overridden
463480
// by custom rustc drivers, running all the steps by themselves.
464481
pub fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
@@ -511,6 +528,7 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
511528
&reveal_all::RevealAll, // has to be done before drop elaboration, since we need to drop opaque types, too.
512529
&add_subtyping_projections::Subtyper, // calling this after reveal_all ensures that we don't deal with opaque types
513530
&elaborate_drops::ElaborateDrops,
531+
&reveal_all::RevealAll,
514532
// This will remove extraneous landing pads which are no longer
515533
// necessary as well as forcing any call in a non-unwinding
516534
// function calling a possibly-unwinding function to abort the process.
@@ -523,6 +541,7 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
523541
&add_retag::AddRetag,
524542
&elaborate_box_derefs::ElaborateBoxDerefs,
525543
&coroutine::StateTransform,
544+
&reveal_all::RevealAll,
526545
&Lint(known_panics_lint::KnownPanicsLint),
527546
];
528547
pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial)));
@@ -623,6 +642,11 @@ fn optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> &Body<'_> {
623642
tcx.arena.alloc(inner_optimized_mir(tcx, did))
624643
}
625644

645+
/// Optimize the templated MIR and prepare it for codegen.
646+
fn templated_optimized_mir<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> &'tcx Body<'tcx> {
647+
tcx.arena.alloc(inner_templated_optimized_mir(tcx, ty))
648+
}
649+
626650
fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
627651
if tcx.is_constructor(did.to_def_id()) {
628652
// There's no reason to run all of the MIR passes on constructors when
@@ -667,6 +691,29 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
667691
body
668692
}
669693

694+
fn inner_templated_optimized_mir<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Body<'tcx> {
695+
debug!("about to call templated_mir_drops_elaborated...");
696+
let body = tcx.templated_mir_drops_elaborated_and_const_checked(ty).steal();
697+
let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::NotConst);
698+
699+
if body.tainted_by_errors.is_some() {
700+
return body;
701+
}
702+
703+
// If `mir_drops_elaborated_and_const_checked` found that the current body has unsatisfiable
704+
// predicates, it will shrink the MIR to a single `unreachable` terminator.
705+
// More generally, if MIR is a lone `unreachable`, there is nothing to optimize.
706+
if let TerminatorKind::Unreachable = body.basic_blocks[START_BLOCK].terminator().kind
707+
&& body.basic_blocks[START_BLOCK].statements.is_empty()
708+
{
709+
return body;
710+
}
711+
712+
run_optimization_passes(tcx, &mut body);
713+
714+
body
715+
}
716+
670717
/// Fetch all the promoteds of an item and prepare their MIR bodies to be ready for
671718
/// constant evaluation once all generic parameters become known.
672719
fn promoted_mir(tcx: TyCtxt<'_>, def: LocalDefId) -> &IndexVec<Promoted, Body<'_>> {

‎compiler/rustc_mir_transform/src/validate.rs‎

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -716,12 +716,13 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
716716
// args of the coroutine. Otherwise, we prefer to use this body
717717
// since we may be in the process of computing this MIR in the
718718
// first place.
719-
let layout = if def_id == self.caller_body.source.def_id() {
720-
// FIXME: This is not right for async closures.
721-
self.caller_body.coroutine_layout_raw()
722-
} else {
723-
self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty())
724-
};
719+
let layout = (def_id == self.caller_body.source.def_id())
720+
.then(
721+
// FIXME: This is not right for async closures.
722+
|| self.caller_body.coroutine_layout_raw(),
723+
)
724+
.flatten()
725+
.or_else(|| self.tcx.coroutine_layout(def_id, args));
725726

726727
let Some(layout) = layout else {
727728
self.fail(

‎compiler/rustc_ty_utils/src/layout.rs‎

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::fmt::Debug;
22
use std::iter;
3+
use std::ops::Bound;
34

45
use hir::def_id::DefId;
56
use rustc_hir as hir;
@@ -112,8 +113,60 @@ fn layout_of_uncached<'tcx>(
112113
if let Err(guar) = ty.error_reported() {
113114
return Err(error(cx, LayoutError::ReferencesError(guar)));
114115
}
115-
116116
let tcx = cx.tcx;
117+
118+
// layout of `async_drop_in_place<T>::{closure}` in case,
119+
// when T is a coroutine, contains this internal coroutine's ref
120+
if let ty::Coroutine(cor_def, cor_args) = ty.kind()
121+
&& tcx.is_templated_coroutine(*cor_def)
122+
{
123+
let arg_cor_ty = cor_args.first().unwrap().expect_ty();
124+
if arg_cor_ty.is_coroutine() {
125+
fn find_impl_coroutine<'tcx>(tcx: TyCtxt<'tcx>, mut cor_ty: Ty<'tcx>) -> Ty<'tcx> {
126+
let mut ty = cor_ty;
127+
loop {
128+
if let ty::Coroutine(def_id, args) = ty.kind() {
129+
cor_ty = ty;
130+
if tcx.is_templated_coroutine(*def_id) {
131+
ty = args.first().unwrap().expect_ty();
132+
continue;
133+
} else {
134+
return cor_ty;
135+
}
136+
} else {
137+
return cor_ty;
138+
}
139+
}
140+
}
141+
let repr = ReprOptions {
142+
int: None,
143+
align: None,
144+
pack: None,
145+
flags: ReprFlags::empty(),
146+
field_shuffle_seed: 0,
147+
};
148+
let impl_cor = find_impl_coroutine(tcx, arg_cor_ty);
149+
let impl_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, impl_cor);
150+
let ref_layout = cx.layout_of(impl_ref)?.layout;
151+
let variants: IndexVec<VariantIdx, IndexVec<FieldIdx, Layout<'tcx>>> =
152+
[IndexVec::from([ref_layout])].into();
153+
let Some(layout) = cx.layout_of_struct_or_enum(
154+
&repr,
155+
&variants,
156+
false,
157+
false,
158+
(Bound::Unbounded, Bound::Unbounded),
159+
|_, _| (Integer::I128, false),
160+
None::<(VariantIdx, i128)>.into_iter(),
161+
false,
162+
true,
163+
) else {
164+
return Err(error(cx, LayoutError::SizeOverflow(ty)));
165+
};
166+
return Ok(tcx.mk_layout(layout));
167+
}
168+
}
169+
117170
let param_env = cx.param_env;
118171
let dl = cx.data_layout();
119172
let scalar_unit = |value: Primitive| {
@@ -816,9 +869,18 @@ fn coroutine_layout<'tcx>(
816869
) -> Result<Layout<'tcx>, &'tcx LayoutError<'tcx>> {
817870
use SavedLocalEligibility::*;
818871
let tcx = cx.tcx;
872+
let layout = if tcx.is_templated_coroutine(def_id) {
873+
// layout of `async_drop_in_place<T>::{closure}` in case,
874+
// when T is a coroutine, contains this internal coroutine's ref
875+
// and must be proceed above, in layout_of_uncached
876+
assert!(!args.first().unwrap().expect_ty().is_coroutine());
877+
tcx.templated_coroutine_layout(ty)
878+
} else {
879+
tcx.ordinary_coroutine_layout(def_id, args.as_coroutine().kind_ty())
880+
};
819881
let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args);
820882

821-
let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else {
883+
let Some(info) = layout else {
822884
return Err(error(cx, LayoutError::Unknown(ty)));
823885
};
824886
let (ineligible_locals, assignments) = coroutine_saved_local_eligibility(info);
@@ -1145,7 +1207,7 @@ fn variant_info_for_coroutine<'tcx>(
11451207
return (vec![], None);
11461208
};
11471209

1148-
let coroutine = cx.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()).unwrap();
1210+
let coroutine = cx.tcx.coroutine_layout(def_id, args).unwrap();
11491211
let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id);
11501212

11511213
let mut upvars_size = Size::ZERO;

0 commit comments

Comments
 (0)
Please sign in to comment.