Skip to content

Commit fac72c8

Browse files
authored
Add resume_agile to allow coroutine to resume in any apartment (#1356)
1 parent 23c4ced commit fac72c8

File tree

4 files changed

+112
-57
lines changed

4 files changed

+112
-57
lines changed

strings/base_coroutine_foundation.h

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -99,44 +99,49 @@ namespace winrt::impl
9999
return async.GetResults();
100100
}
101101

102-
template<typename Awaiter>
103-
struct disconnect_aware_handler
102+
struct ignore_apartment_context {};
103+
104+
template<bool preserve_context, typename Awaiter>
105+
struct disconnect_aware_handler : private std::conditional_t<preserve_context, resume_apartment_context, ignore_apartment_context>
104106
{
105107
disconnect_aware_handler(Awaiter* awaiter, coroutine_handle<> handle) noexcept
106108
: m_awaiter(awaiter), m_handle(handle) { }
107109

108-
disconnect_aware_handler(disconnect_aware_handler&& other) noexcept
109-
: m_context(std::move(other.m_context))
110-
, m_awaiter(std::exchange(other.m_awaiter, {}))
111-
, m_handle(std::exchange(other.m_handle, {})) { }
110+
disconnect_aware_handler(disconnect_aware_handler&& other) = default;
112111

113112
~disconnect_aware_handler()
114113
{
115-
if (m_handle) Complete();
114+
if (m_handle.value) Complete();
116115
}
117116

118117
template<typename Async>
119118
void operator()(Async&&, Windows::Foundation::AsyncStatus status)
120119
{
121-
m_awaiter->status = status;
120+
m_awaiter.value->status = status;
122121
Complete();
123122
}
124123

125124
private:
126-
resume_apartment_context m_context;
127-
Awaiter* m_awaiter;
128-
coroutine_handle<> m_handle;
125+
movable_primitive<Awaiter*> m_awaiter;
126+
movable_primitive<coroutine_handle<>, nullptr> m_handle;
129127

130128
void Complete()
131129
{
132-
if (m_awaiter->suspending.exchange(false, std::memory_order_release))
130+
if (m_awaiter.value->suspending.exchange(false, std::memory_order_release))
133131
{
134-
m_handle = nullptr; // resumption deferred to await_suspend
132+
m_handle.value = nullptr; // resumption deferred to await_suspend
135133
}
136134
else
137135
{
138-
auto handle = std::exchange(m_handle, {});
139-
if (!resume_apartment(m_context, handle, &m_awaiter->failure))
136+
auto handle = m_handle.detach();
137+
if constexpr (preserve_context)
138+
{
139+
if (!resume_apartment(*this, handle, &m_awaiter.value->failure))
140+
{
141+
handle.resume();
142+
}
143+
}
144+
else
140145
{
141146
handle.resume();
142147
}
@@ -145,7 +150,7 @@ namespace winrt::impl
145150
};
146151

147152
#ifdef WINRT_IMPL_COROUTINES
148-
template <typename Async>
153+
template <typename Async, bool preserve_context = true>
149154
struct await_adapter : cancellable_awaiter<await_adapter<Async>>
150155
{
151156
await_adapter(Async const& async) : async(async) { }
@@ -185,7 +190,7 @@ namespace winrt::impl
185190
private:
186191
bool register_completed_callback(coroutine_handle<> handle)
187192
{
188-
async.Completed(disconnect_aware_handler(this, handle));
193+
async.Completed(disconnect_aware_handler<preserve_context, await_adapter>(this, handle));
189194
return suspending.exchange(false, std::memory_order_acquire);
190195
}
191196

@@ -249,6 +254,15 @@ namespace winrt::impl
249254
}
250255

251256
#ifdef WINRT_IMPL_COROUTINES
257+
WINRT_EXPORT namespace winrt
258+
{
259+
template<typename Async, typename = std::enable_if_t<std::is_convertible_v<Async, winrt::Windows::Foundation::IAsyncInfo>>>
260+
inline impl::await_adapter<Async, false> resume_agile(Async const& async)
261+
{
262+
return { async };
263+
};
264+
}
265+
252266
WINRT_EXPORT namespace winrt::Windows::Foundation
253267
{
254268
inline impl::await_adapter<IAsyncAction> operator co_await(IAsyncAction const& async)

strings/base_coroutine_threadpool.h

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,23 +52,14 @@ namespace winrt::impl
5252
{
5353
resume_apartment_context() = default;
5454
resume_apartment_context(std::nullptr_t) : m_context(nullptr), m_context_type(-1) {}
55-
resume_apartment_context(resume_apartment_context const&) = default;
56-
resume_apartment_context(resume_apartment_context&& other) noexcept :
57-
m_context(std::move(other.m_context)), m_context_type(std::exchange(other.m_context_type, -1)) {}
58-
resume_apartment_context& operator=(resume_apartment_context const&) = default;
59-
resume_apartment_context& operator=(resume_apartment_context&& other) noexcept
60-
{
61-
m_context = std::move(other.m_context);
62-
m_context_type = std::exchange(other.m_context_type, -1);
63-
return *this;
64-
}
55+
6556
bool valid() const noexcept
6657
{
67-
return m_context_type >= 0;
58+
return m_context_type.value >= 0;
6859
}
6960

7061
com_ptr<IContextCallback> m_context = try_capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext);
71-
int32_t m_context_type = get_apartment_type().first;
62+
movable_primitive<int32_t, -1> m_context_type = get_apartment_type().first;
7263
};
7364

7465
inline int32_t __stdcall resume_apartment_callback(com_callback_args* args) noexcept
@@ -124,7 +115,7 @@ namespace winrt::impl
124115
{
125116
return false;
126117
}
127-
else if (context.m_context_type == 1 /* APTTYPE_MTA */)
118+
else if (context.m_context_type.value == 1 /* APTTYPE_MTA */)
128119
{
129120
resume_background(handle);
130121
return true;

strings/base_meta.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,25 @@ namespace winrt::impl
193193
}
194194
}
195195

196+
template<typename T, auto empty_value = T{}>
197+
struct movable_primitive
198+
{
199+
T value = empty_value;
200+
movable_primitive() = default;
201+
movable_primitive(T const& init) : value(init) {}
202+
movable_primitive(movable_primitive const&) = default;
203+
movable_primitive(movable_primitive&& other) :
204+
value(other.detach()) {}
205+
movable_primitive& operator=(movable_primitive const&) = default;
206+
movable_primitive& operator=(movable_primitive&& other)
207+
{
208+
value = other.detach();
209+
return *this;
210+
}
211+
212+
T detach() { return std::exchange(value, empty_value); }
213+
};
214+
196215
template <typename T, typename Enable = void>
197216
struct arg
198217
{

test/test/await_adapter.cpp

Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,9 @@ namespace
1818

1919
static handle signal{ CreateEventW(nullptr, false, false, nullptr) };
2020

21-
IAsyncAction OtherForegroundAsync()
21+
IAsyncAction OtherForegroundAsync(DispatcherQueue dispatcher)
2222
{
23-
// Simple coroutine that completes on a unique STA thread.
24-
25-
auto controller = DispatcherQueueController::CreateOnDedicatedThread();
26-
auto dispatcher = controller.DispatcherQueue();
27-
23+
// Simple coroutine that completes on the specified STA thread.
2824
co_await resume_foreground(dispatcher);
2925
}
3026

@@ -35,37 +31,37 @@ namespace
3531
co_await resume_background();
3632
}
3733

38-
IAsyncAction ForegroundAsync(DispatcherQueue dispatcher)
34+
// Coroutine that completes on dispatcher1, while potentially blocking dispatcher2.
35+
IAsyncAction ForegroundAsync(DispatcherQueue dispatcher1, DispatcherQueue dispatcher2)
3936
{
4037
REQUIRE(!is_sta());
41-
co_await resume_foreground(dispatcher);
38+
co_await resume_foreground(dispatcher1);
4239
REQUIRE(is_sta());
4340

4441
// This exercises one STA thread waiting on another thus one context callback
4542
// completing on another.
4643
uint32_t id = GetCurrentThreadId();
47-
co_await OtherForegroundAsync();
44+
co_await OtherForegroundAsync(dispatcher2);
4845
REQUIRE(id == GetCurrentThreadId());
4946

50-
// This just avoids the ForegroundAsync coroutine completing before
51-
// BackgroundAsync waits on the result, forcing the Completed handler
52-
// to be called on the foreground thread. This just makes the test
53-
// success/failure more predictable.
47+
// This Sleep() makes it more likely that the caller will actually suspend in await_suspend,
48+
// so that the Completed handler triggers a resumption from the dispatcher1 thread.
5449
Sleep(100);
5550
}
5651

57-
fire_and_forget SignalFromForeground(DispatcherQueue dispatcher)
52+
fire_and_forget SignalFromForeground(DispatcherQueue dispatcher1)
5853
{
5954
REQUIRE(!is_sta());
60-
co_await resume_foreground(dispatcher);
55+
co_await resume_foreground(dispatcher1);
6156
REQUIRE(is_sta());
6257

63-
// Previously, this signal was never raised because the foreground thread
64-
// was always blocked waiting for ContextCallback to return.
58+
// Previously, we never got here because of a deadlock:
59+
// The dispatcher1 thread was blocked waiting for ContextCallback to return,
60+
// but the ContextCallback is waiting for this event to get signaled.
6561
REQUIRE(SetEvent(signal.get()));
6662
}
6763

68-
IAsyncAction BackgroundAsync(DispatcherQueue dispatcher)
64+
IAsyncAction BackgroundAsync(DispatcherQueue dispatcher1, DispatcherQueue dispatcher2)
6965
{
7066
// Switch to a background (MTA) thread.
7167
co_await resume_background();
@@ -76,19 +72,19 @@ namespace
7672
co_await OtherBackgroundAsync();
7773
REQUIRE(!is_sta());
7874

79-
// Wait for a coroutine that completes on a foreground (STA) thread.
80-
co_await ForegroundAsync(dispatcher);
75+
// Wait for a coroutine that completes on a the dispatcher1 thread (STA).
76+
co_await ForegroundAsync(dispatcher1, dispatcher2);
8177

8278
// Resumption should automatically switch to a background (MTA) thread
83-
// without blocking the Completed handler (which would in turn block the foreground thread).
79+
// without blocking the Completed handler (which would in turn block the dispatcher1 thread).
8480
REQUIRE(!is_sta());
8581

86-
// Attempt to signal from the foreground thread under the assumption
87-
// that the foreground thread is not blocked.
88-
SignalFromForeground(dispatcher);
82+
// Attempt to signal from the dispatcher1 thread under the assumption
83+
// that the dispatcher1 thread is not blocked.
84+
SignalFromForeground(dispatcher1);
8985

90-
// Block the background (MTA) thread indefinitely until the signal is raied.
91-
// Previously this would deadlock.
86+
// Block the background (MTA) thread indefinitely until the signal is raised.
87+
// Previously this would hang because the signal never got raised.
9288
REQUIRE(WAIT_OBJECT_0 == WaitForSingleObject(signal.get(), INFINITE));
9389
}
9490
}
@@ -99,9 +95,44 @@ TEST_CASE("await_adapter", "[.clang-crash]")
9995
#else
10096
TEST_CASE("await_adapter")
10197
#endif
98+
{
99+
auto controller1 = DispatcherQueueController::CreateOnDedicatedThread();
100+
auto controller2 = DispatcherQueueController::CreateOnDedicatedThread();
101+
102+
BackgroundAsync(controller1.DispatcherQueue(), controller2.DispatcherQueue()).get();
103+
controller1.ShutdownQueueAsync().get();
104+
controller2.ShutdownQueueAsync().get();
105+
}
106+
107+
namespace
108+
{
109+
IAsyncAction OtherBackgroundDelayAsync()
110+
{
111+
// Simple coroutine that completes on some MTA thread after a brief delay
112+
// to ensure that the caller has suspended.
113+
114+
co_await resume_after(100ms);
115+
}
116+
117+
IAsyncAction AgileAsync(DispatcherQueue dispatcher)
118+
{
119+
// Switch to the STA.
120+
co_await resume_foreground(dispatcher);
121+
REQUIRE(is_sta());
122+
123+
// Ask for agile resumption of a coroutine that finishes on a background thread.
124+
// Add a 100ms delay to ensure we suspend.
125+
co_await resume_agile(OtherBackgroundDelayAsync());
126+
// We should be on the background thread now.
127+
REQUIRE(!is_sta());
128+
}
129+
}
130+
131+
TEST_CASE("await_adapter_agile")
102132
{
103133
auto controller = DispatcherQueueController::CreateOnDedicatedThread();
104134
auto dispatcher = controller.DispatcherQueue();
105135

106-
BackgroundAsync(dispatcher).get();
136+
AgileAsync(dispatcher).get();
137+
controller.ShutdownQueueAsync().get();
107138
}

0 commit comments

Comments
 (0)