Skip to content

Commit bbbd081

Browse files
[NFC][SYCL] Update memory_manager to pass context_impl by raw pointers (#18966)
Part of the ongoing refactoring to prefer raw ptr/ref for SYCL RT objects by default with explicit `shared_from_this` when lifetimes need to be extended.
1 parent 382f37f commit bbbd081

13 files changed

+114
-114
lines changed

sycl/source/detail/buffer_impl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace detail {
2121
#ifdef XPTI_ENABLE_INSTRUMENTATION
2222
uint8_t GBufferStreamID;
2323
#endif
24-
void *buffer_impl::allocateMem(ContextImplPtr Context, bool InitFromUserData,
24+
void *buffer_impl::allocateMem(context_impl *Context, bool InitFromUserData,
2525
void *HostPtr,
2626
ur_event_handle_t &OutEventToWait) {
2727
bool HostPtrReadOnly = false;
@@ -30,9 +30,9 @@ void *buffer_impl::allocateMem(ContextImplPtr Context, bool InitFromUserData,
3030
"Internal error. Allocating memory on the host "
3131
"while having use_host_ptr property");
3232
return MemoryManager::allocateMemBuffer(
33-
std::move(Context), this, HostPtr, HostPtrReadOnly,
34-
BaseT::getSizeInBytes(), BaseT::MInteropEvent, BaseT::MInteropContext,
35-
MProps, OutEventToWait);
33+
Context, this, HostPtr, HostPtrReadOnly, BaseT::getSizeInBytes(),
34+
BaseT::MInteropEvent, BaseT::MInteropContext.get(), MProps,
35+
OutEventToWait);
3636
}
3737
void buffer_impl::constructorNotification(const detail::code_location &CodeLoc,
3838
void *UserObj, const void *HostObj,

sycl/source/detail/buffer_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ class buffer_impl final : public SYCLMemObjT {
129129
: BaseT(MemObject, SyclContext, OwnNativeHandle,
130130
std::move(AvailableEvent), std::move(Allocator)) {}
131131

132-
void *allocateMem(ContextImplPtr Context, bool InitFromUserData,
133-
void *HostPtr, ur_event_handle_t &OutEventToWait) override;
132+
void *allocateMem(context_impl *Context, bool InitFromUserData, void *HostPtr,
133+
ur_event_handle_t &OutEventToWait) override;
134134
void constructorNotification(const detail::code_location &CodeLoc,
135135
void *UserObj, const void *HostObj,
136136
const void *Type, uint32_t Dim,

sycl/source/detail/device_global_map_entry.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,19 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(queue_impl &QueueImpl) {
4646
assert(!MIsDeviceImageScopeDecorated &&
4747
"USM allocations should not be acquired for device_global with "
4848
"device_image_scope property.");
49-
const std::shared_ptr<context_impl> &CtxImpl = QueueImpl.getContextImplPtr();
49+
context_impl &CtxImpl = QueueImpl.getContextImpl();
5050
const device_impl &DevImpl = QueueImpl.getDeviceImpl();
5151
std::lock_guard<std::mutex> Lock(MDeviceToUSMPtrMapMutex);
5252

53-
auto DGUSMPtr = MDeviceToUSMPtrMap.find({&DevImpl, CtxImpl.get()});
53+
auto DGUSMPtr = MDeviceToUSMPtrMap.find({&DevImpl, &CtxImpl});
5454
if (DGUSMPtr != MDeviceToUSMPtrMap.end())
5555
return DGUSMPtr->second;
5656

5757
void *NewDGUSMPtr = detail::usm::alignedAllocInternal(
58-
0, MDeviceGlobalTSize, CtxImpl.get(), &DevImpl, sycl::usm::alloc::device);
58+
0, MDeviceGlobalTSize, &CtxImpl, &DevImpl, sycl::usm::alloc::device);
5959

6060
auto NewAllocIt = MDeviceToUSMPtrMap.emplace(
61-
std::piecewise_construct, std::forward_as_tuple(&DevImpl, CtxImpl.get()),
61+
std::piecewise_construct, std::forward_as_tuple(&DevImpl, &CtxImpl),
6262
std::forward_as_tuple(NewDGUSMPtr));
6363
assert(NewAllocIt.second &&
6464
"USM allocation for device and context already happened.");
@@ -83,7 +83,7 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(queue_impl &QueueImpl) {
8383
NewAlloc.MInitEvent = InitEvent;
8484
}
8585

86-
CtxImpl->addAssociatedDeviceGlobal(MDeviceGlobalPtr);
86+
CtxImpl.addAssociatedDeviceGlobal(MDeviceGlobalPtr);
8787
return NewAlloc;
8888
}
8989

@@ -92,22 +92,20 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
9292
assert(!MIsDeviceImageScopeDecorated &&
9393
"USM allocations should not be acquired for device_global with "
9494
"device_image_scope property.");
95-
const std::shared_ptr<context_impl> &CtxImpl = getSyclObjImpl(Context);
95+
context_impl &CtxImpl = *getSyclObjImpl(Context);
9696
const std::shared_ptr<device_impl> &DevImpl =
97-
getSyclObjImpl(CtxImpl->getDevices().front());
97+
getSyclObjImpl(CtxImpl.getDevices().front());
9898
std::lock_guard<std::mutex> Lock(MDeviceToUSMPtrMapMutex);
9999

100-
auto DGUSMPtr = MDeviceToUSMPtrMap.find({DevImpl.get(), CtxImpl.get()});
100+
auto DGUSMPtr = MDeviceToUSMPtrMap.find({DevImpl.get(), &CtxImpl});
101101
if (DGUSMPtr != MDeviceToUSMPtrMap.end())
102102
return DGUSMPtr->second;
103103

104104
void *NewDGUSMPtr = detail::usm::alignedAllocInternal(
105-
0, MDeviceGlobalTSize, CtxImpl.get(), DevImpl.get(),
106-
sycl::usm::alloc::device);
105+
0, MDeviceGlobalTSize, &CtxImpl, DevImpl.get(), sycl::usm::alloc::device);
107106

108107
auto NewAllocIt = MDeviceToUSMPtrMap.emplace(
109-
std::piecewise_construct,
110-
std::forward_as_tuple(DevImpl.get(), CtxImpl.get()),
108+
std::piecewise_construct, std::forward_as_tuple(DevImpl.get(), &CtxImpl),
111109
std::forward_as_tuple(NewDGUSMPtr));
112110
assert(NewAllocIt.second &&
113111
"USM allocation for device and context already happened.");
@@ -123,9 +121,9 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
123121
reinterpret_cast<const void *>(
124122
reinterpret_cast<uintptr_t>(MDeviceGlobalPtr) +
125123
sizeof(MDeviceGlobalPtr)),
126-
CtxImpl, MDeviceGlobalTSize, NewAlloc.MPtr);
124+
&CtxImpl, MDeviceGlobalTSize, NewAlloc.MPtr);
127125

128-
CtxImpl->addAssociatedDeviceGlobal(MDeviceGlobalPtr);
126+
CtxImpl.addAssociatedDeviceGlobal(MDeviceGlobalPtr);
129127
return NewAlloc;
130128
}
131129

sycl/source/detail/image_impl.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,9 @@ image_channel_type convertChannelType(ur_image_channel_type_t Type) {
259259
}
260260

261261
template <typename T>
262-
static void getImageInfo(const ContextImplPtr &Context, ur_image_info_t Info,
263-
T &Dest, ur_mem_handle_t InteropMemObject) {
264-
const AdapterPtr &Adapter = Context->getAdapter();
262+
static void getImageInfo(context_impl &Context, ur_image_info_t Info, T &Dest,
263+
ur_mem_handle_t InteropMemObject) {
264+
const AdapterPtr &Adapter = Context.getAdapter();
265265
Adapter->call<UrApiKind::urMemImageGetInfo>(InteropMemObject, Info, sizeof(T),
266266
&Dest, nullptr);
267267
}
@@ -274,8 +274,8 @@ image_impl::image_impl(cl_mem MemObject, const context &SyclContext,
274274
std::move(Allocator)),
275275
MDimensions(Dimensions), MRange({0, 0, 0}) {
276276
ur_mem_handle_t Mem = ur::cast<ur_mem_handle_t>(BaseT::MInteropMemObject);
277-
const ContextImplPtr &Context = getSyclObjImpl(SyclContext);
278-
const AdapterPtr &Adapter = Context->getAdapter();
277+
detail::context_impl &Context = *getSyclObjImpl(SyclContext);
278+
const AdapterPtr &Adapter = Context.getAdapter();
279279
Adapter->call<UrApiKind::urMemGetInfo>(Mem, UR_MEM_INFO_SIZE, sizeof(size_t),
280280
&(BaseT::MSizeInBytes), nullptr);
281281

@@ -323,7 +323,7 @@ image_impl::image_impl(ur_native_handle_t MemObject, const context &SyclContext,
323323
setPitches(); // sets MRowPitch, MSlice and BaseT::MSizeInBytes
324324
}
325325

326-
void *image_impl::allocateMem(ContextImplPtr Context, bool InitFromUserData,
326+
void *image_impl::allocateMem(context_impl *Context, bool InitFromUserData,
327327
void *HostPtr,
328328
ur_event_handle_t &OutEventToWait) {
329329
bool HostPtrReadOnly = false;
@@ -338,13 +338,13 @@ void *image_impl::allocateMem(ContextImplPtr Context, bool InitFromUserData,
338338
"The check an image format failed.");
339339

340340
return MemoryManager::allocateMemImage(
341-
std::move(Context), this, HostPtr, HostPtrReadOnly,
342-
BaseT::getSizeInBytes(), Desc, Format, BaseT::MInteropEvent,
343-
BaseT::MInteropContext, MProps, OutEventToWait);
341+
Context, this, HostPtr, HostPtrReadOnly, BaseT::getSizeInBytes(), Desc,
342+
Format, BaseT::MInteropEvent, BaseT::MInteropContext.get(), MProps,
343+
OutEventToWait);
344344
}
345345

346346
bool image_impl::checkImageDesc(const ur_image_desc_t &Desc,
347-
ContextImplPtr Context, void *UserPtr) {
347+
context_impl *Context, void *UserPtr) {
348348
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE1D, UR_MEM_TYPE_IMAGE1D_ARRAY,
349349
UR_MEM_TYPE_IMAGE2D_ARRAY, UR_MEM_TYPE_IMAGE2D) &&
350350
!checkImageValueRange<info::device::image2d_max_width>(
@@ -409,7 +409,7 @@ bool image_impl::checkImageDesc(const ur_image_desc_t &Desc,
409409
}
410410

411411
bool image_impl::checkImageFormat(const ur_image_format_t &Format,
412-
ContextImplPtr Context) {
412+
context_impl *Context) {
413413
(void)Context;
414414
if (checkAny(Format.channelOrder, UR_IMAGE_CHANNEL_ORDER_INTENSITY,
415415
UR_IMAGE_CHANNEL_ORDER_LUMINANCE) &&
@@ -451,7 +451,7 @@ bool image_impl::checkImageFormat(const ur_image_format_t &Format,
451451
return true;
452452
}
453453

454-
std::vector<device> image_impl::getDevices(const ContextImplPtr Context) {
454+
std::vector<device> image_impl::getDevices(context_impl *Context) {
455455
if (!Context)
456456
return {};
457457
return Context->get_info<info::context::devices>();

sycl/source/detail/image_impl.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ class image_impl final : public SYCLMemObjT {
254254
std::abort();
255255
}
256256

257-
void *allocateMem(ContextImplPtr Context, bool InitFromUserData,
258-
void *HostPtr, ur_event_handle_t &OutEventToWait) override;
257+
void *allocateMem(context_impl *Context, bool InitFromUserData, void *HostPtr,
258+
ur_event_handle_t &OutEventToWait) override;
259259

260260
MemObjType getType() const override { return MemObjType::Image; }
261261

@@ -298,7 +298,7 @@ class image_impl final : public SYCLMemObjT {
298298
void unsampledImageDestructorNotification(void *UserObj);
299299

300300
private:
301-
std::vector<device> getDevices(const ContextImplPtr Context);
301+
std::vector<device> getDevices(context_impl *Context);
302302

303303
ur_mem_type_t getImageType() {
304304
if (MDimensions == 1)
@@ -330,7 +330,7 @@ class image_impl final : public SYCLMemObjT {
330330
return Desc;
331331
}
332332

333-
bool checkImageDesc(const ur_image_desc_t &Desc, ContextImplPtr Context,
333+
bool checkImageDesc(const ur_image_desc_t &Desc, context_impl *Context,
334334
void *UserPtr);
335335

336336
ur_image_format_t getImageFormat() {
@@ -340,8 +340,7 @@ class image_impl final : public SYCLMemObjT {
340340
return Format;
341341
}
342342

343-
bool checkImageFormat(const ur_image_format_t &Format,
344-
ContextImplPtr Context);
343+
bool checkImageFormat(const ur_image_format_t &Format, context_impl *Context);
345344

346345
uint8_t MDimensions = 0;
347346
bool MIsArrayImage = false;

0 commit comments

Comments
 (0)