Skip to content

Commit 65581a3

Browse files
authored
Add capture support for unconventional result types (#1301)
1 parent c3b7fcf commit 65581a3

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

strings/base_com_ptr.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,27 @@ WINRT_EXPORT namespace winrt
77

88
namespace winrt::impl
99
{
10+
struct capture_decay
11+
{
12+
void** result;
13+
14+
template <typename T>
15+
operator T** ()
16+
{
17+
return reinterpret_cast<T**>(result);
18+
}
19+
};
20+
1021
template <typename T, typename F, typename...Args>
1122
int32_t capture_to(void**result, F function, Args&& ...args)
1223
{
13-
return function(args..., guid_of<T>(), result);
24+
return function(args..., guid_of<T>(), capture_decay{ result });
1425
}
1526

1627
template <typename T, typename O, typename M, typename...Args, std::enable_if_t<std::is_class_v<O> || std::is_union_v<O>, int> = 0>
1728
int32_t capture_to(void** result, O* object, M method, Args&& ...args)
1829
{
19-
return (object->*method)(args..., guid_of<T>(), result);
30+
return (object->*method)(args..., guid_of<T>(), capture_decay{ result });
2031
}
2132

2233
template <typename T, typename O, typename M, typename...Args>
@@ -343,7 +354,7 @@ namespace winrt::impl
343354
template <typename T, typename O, typename M, typename...Args>
344355
int32_t capture_to(void** result, com_ptr<O> const& object, M method, Args&& ...args)
345356
{
346-
return (object.get()->*(method))(args..., guid_of<T>(), result);
357+
return (object.get()->*(method))(args..., guid_of<T>(), capture_decay{ result });
347358
}
348359
}
349360

test/old_tests/UnitTests/capture.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ struct DECLSPEC_UUID("5fb96f8d-409c-42a9-99a7-8a95c1459dbd") ICapture : ::IUnkno
88
{
99
virtual int32_t __stdcall GetValue() noexcept = 0;
1010
virtual int32_t __stdcall CreateMemberCapture(int32_t value, GUID const& iid, void** object) noexcept = 0;
11+
virtual int32_t __stdcall CreateMemberCapture2(int32_t value, GUID const& iid, ::IUnknown** object) noexcept = 0;
1112
};
1213

1314
#ifdef __CRT_UUID_DECL
@@ -33,6 +34,12 @@ struct Capture : implements<Capture, ICapture>
3334
auto capture = make<Capture>(value);
3435
return capture->QueryInterface(iid, object);
3536
}
37+
38+
int32_t __stdcall CreateMemberCapture2(int32_t value, GUID const& iid, ::IUnknown** object) noexcept override
39+
{
40+
auto capture = make<Capture>(value);
41+
return capture->QueryInterface(iid, reinterpret_cast<void**>(object));
42+
}
3643
};
3744

3845
HRESULT __stdcall CreateCapture(int value, GUID const& iid, void** object) noexcept
@@ -41,6 +48,12 @@ HRESULT __stdcall CreateCapture(int value, GUID const& iid, void** object) noexc
4148
return capture->QueryInterface(iid, object);
4249
}
4350

51+
HRESULT __stdcall CreateCapture2(int value, GUID const& iid, ::IInspectable** object) noexcept
52+
{
53+
auto capture = make<Capture>(value);
54+
return capture->QueryInterface(iid, reinterpret_cast<void**>(object));
55+
}
56+
4457
TEST_CASE("capture")
4558
{
4659
// Capture from global function.
@@ -67,6 +80,19 @@ TEST_CASE("capture")
6780

6881
com_ptr<IDispatch> d;
6982

83+
// Capture with an unconventional result type.
84+
auto e = capture<ICapture>(a, &ICapture::CreateMemberCapture2, 30);
85+
REQUIRE(e->GetValue() == 30);
86+
e = nullptr;
87+
e.capture(a, &ICapture::CreateMemberCapture2, 40);
88+
REQUIRE(e->GetValue() == 40);
89+
90+
com_ptr<ICapture> f = capture<ICapture>(CreateCapture2, 10);
91+
REQUIRE(f->GetValue() == 10);
92+
f = nullptr;
93+
f.capture(CreateCapture2, 20);
94+
REQUIRE(a->GetValue() == 20);
95+
7096
REQUIRE_THROWS_AS(capture<IDispatch>(CreateCapture, 0), hresult_no_interface);
7197
REQUIRE_THROWS_AS(d.capture(CreateCapture, 0), hresult_no_interface);
7298
REQUIRE_THROWS_AS(capture<IDispatch>(a, &ICapture::CreateMemberCapture, 0), hresult_no_interface);
@@ -104,6 +130,19 @@ TEST_CASE("try_capture")
104130

105131
com_ptr<IDispatch> d;
106132

133+
// Capture with an unconventional result type.
134+
auto e = try_capture<ICapture>(a, &ICapture::CreateMemberCapture2, 30);
135+
REQUIRE(e->GetValue() == 30);
136+
e = nullptr;
137+
REQUIRE(e.try_capture(a, &ICapture::CreateMemberCapture2, 40));
138+
REQUIRE(e->GetValue() == 40);
139+
140+
com_ptr<ICapture> f = try_capture<ICapture>(CreateCapture2, 10);
141+
REQUIRE(f->GetValue() == 10);
142+
f = nullptr;
143+
REQUIRE(f.try_capture(CreateCapture2, 20));
144+
REQUIRE(f->GetValue() == 20);
145+
107146
REQUIRE(!try_capture<IDispatch>(CreateCapture, 0));
108147
REQUIRE(!d.try_capture(CreateCapture, 0));
109148
REQUIRE(!try_capture<IDispatch>(a, &ICapture::CreateMemberCapture, 0));

0 commit comments

Comments
 (0)