Skip to content

Commit c023eb4

Browse files
[NFC][SYCL] Use raw context_impl & in event_impl::[set|get]Context (#19007)
Continuation of the refactoring in #18795 #18877 #18966 #18979 #18980 #18981
1 parent e2bd09d commit c023eb4

File tree

10 files changed

+30
-33
lines changed

10 files changed

+30
-33
lines changed

sycl/source/detail/event_impl.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ void event_impl::initContextIfNeeded() {
3838
return;
3939

4040
const device SyclDevice;
41-
this->setContextImpl(
42-
detail::queue_impl::getDefaultOrNew(*detail::getSyclObjImpl(SyclDevice)));
41+
MIsHostEvent = false;
42+
MContext =
43+
detail::queue_impl::getDefaultOrNew(*detail::getSyclObjImpl(SyclDevice));
44+
assert(MContext);
4345
}
4446

4547
event_impl::~event_impl() {
@@ -140,9 +142,10 @@ void event_impl::setHandle(const ur_event_handle_t &UREvent) {
140142
MEvent.store(UREvent);
141143
}
142144

143-
const ContextImplPtr &event_impl::getContextImpl() {
145+
context_impl &event_impl::getContextImpl() {
144146
initContextIfNeeded();
145-
return MContext;
147+
assert(MContext && "Trying to get context from a host event!");
148+
return *MContext;
146149
}
147150

148151
const AdapterPtr &event_impl::getAdapter() {
@@ -152,9 +155,9 @@ const AdapterPtr &event_impl::getAdapter() {
152155

153156
void event_impl::setStateIncomplete() { MState = HES_NotComplete; }
154157

155-
void event_impl::setContextImpl(const ContextImplPtr &Context) {
156-
MIsHostEvent = Context == nullptr;
157-
MContext = Context;
158+
void event_impl::setContextImpl(context_impl &Context) {
159+
MIsHostEvent = false;
160+
MContext = Context.shared_from_this();
158161
}
159162

160163
event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
@@ -178,7 +181,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
178181
event_impl::event_impl(queue_impl &Queue, private_tag)
179182
: MQueue{Queue.weak_from_this()},
180183
MIsProfilingEnabled{Queue.MIsProfilingEnabled} {
181-
this->setContextImpl(Queue.getContextImplPtr());
184+
this->setContextImpl(Queue.getContextImpl());
182185
MState.store(HES_Complete);
183186
}
184187

sycl/source/detail/event_impl.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,21 +173,17 @@ class event_impl : public std::enable_shared_from_this<event_impl> {
173173
void setHandle(const ur_event_handle_t &UREvent);
174174

175175
/// Returns context that is associated with this event.
176-
///
177-
/// \return a shared pointer to a valid context_impl.
178-
const ContextImplPtr &getContextImpl();
176+
context_impl &getContextImpl();
179177

180178
/// \return the Adapter associated with the context of this event.
181179
/// Should be called when this is not a Host Event.
182180
const AdapterPtr &getAdapter();
183181

184182
/// Associate event with the context.
185183
///
186-
/// Provided UrContext inside ContextImplPtr must be associated
184+
/// Provided UrContext inside Context must be associated
187185
/// with the UrEvent object stored in this class
188-
///
189-
/// @param Context is a shared pointer to an instance of valid context_impl.
190-
void setContextImpl(const ContextImplPtr &Context);
186+
void setContextImpl(context_impl &Context);
191187

192188
/// Clear the event state
193189
void setStateIncomplete();

sycl/source/detail/graph_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,
10371037

10381038
auto CreateNewEvent([&]() {
10391039
auto NewEvent = sycl::detail::event_impl::create_device_event(Queue);
1040-
NewEvent->setContextImpl(Queue.getContextImplPtr());
1040+
NewEvent->setContextImpl(Queue.getContextImpl());
10411041
NewEvent->setStateIncomplete();
10421042
return NewEvent;
10431043
});

sycl/source/detail/queue_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ queue_impl::get_backend_info<info::device::backend_version>() const {
121121
static event prepareSYCLEventAssociatedWithQueue(
122122
const std::shared_ptr<detail::queue_impl> &QueueImpl) {
123123
auto EventImpl = detail::event_impl::create_device_event(*QueueImpl);
124-
EventImpl->setContextImpl(detail::getSyclObjImpl(QueueImpl->get_context()));
124+
EventImpl->setContextImpl(QueueImpl->getContextImpl());
125125
EventImpl->setStateIncomplete();
126126
return detail::createSyclObjFromImpl<event>(EventImpl);
127127
}

sycl/source/detail/reduction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ __SYCL_EXPORT void
208208
addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
209209
std::shared_ptr<int> &Counter) {
210210
auto EventImpl = detail::event_impl::create_device_event(*Queue);
211-
EventImpl->setContextImpl(detail::getSyclObjImpl(Queue->get_context()));
211+
EventImpl->setContextImpl(Queue->getContextImpl());
212212
EventImpl->setStateIncomplete();
213213
ur_event_handle_t UREvent = nullptr;
214214
MemoryManager::fill_usm(Counter.get(), *Queue, sizeof(int), {0}, {},

sycl/source/detail/scheduler/commands.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -533,10 +533,8 @@ void Command::waitForEvents(queue_impl *Queue,
533533
RequiredEventsPerContext;
534534

535535
for (const EventImplPtr &Event : EventImpls) {
536-
ContextImplPtr Context = Event->getContextImpl();
537-
assert(Context.get() &&
538-
"Only non-host events are expected to be waited for here");
539-
RequiredEventsPerContext[Context.get()].push_back(Event);
536+
context_impl &Context = Event->getContextImpl();
537+
RequiredEventsPerContext[&Context].push_back(Event);
540538
}
541539

542540
for (auto &CtxWithEvents : RequiredEventsPerContext) {
@@ -576,7 +574,7 @@ Command::Command(
576574
MEvent->setSubmittedQueue(MWorkerQueue);
577575
MEvent->setCommand(this);
578576
if (MQueue)
579-
MEvent->setContextImpl(MQueue->getContextImplPtr());
577+
MEvent->setContextImpl(MQueue->getContextImpl());
580578
MEvent->setStateIncomplete();
581579
MEnqueueStatus = EnqueueResultT::SyclEnqueueReady;
582580

@@ -781,9 +779,9 @@ Command *Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep,
781779

782780
Command *ConnectionCmd = nullptr;
783781

784-
ContextImplPtr DepEventContext = DepEvent->getContextImpl();
782+
context_impl &DepEventContext = DepEvent->getContextImpl();
785783
// If contexts don't match we'll connect them using host task
786-
if (DepEventContext != WorkerContext && WorkerContext) {
784+
if (&DepEventContext != WorkerContext.get() && WorkerContext) {
787785
Scheduler::GraphBuilder &GB = Scheduler::getInstance().MGraphBuilder;
788786
ConnectionCmd = GB.connectDepEvent(this, DepEvent, Dep, ToCleanUp);
789787
} else
@@ -1298,7 +1296,7 @@ ur_result_t ReleaseCommand::enqueueImp() {
12981296

12991297
std::shared_ptr<event_impl> UnmapEventImpl =
13001298
event_impl::create_device_event(*Queue);
1301-
UnmapEventImpl->setContextImpl(Queue->getContextImplPtr());
1299+
UnmapEventImpl->setContextImpl(Queue->getContextImpl());
13021300
UnmapEventImpl->setStateIncomplete();
13031301
ur_event_handle_t UREvent = nullptr;
13041302

@@ -1516,7 +1514,7 @@ MemCpyCommand::MemCpyCommand(Requirement SrcReq,
15161514
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
15171515
MDstReq(std::move(DstReq)), MDstAllocaCmd(DstAllocaCmd) {
15181516
if (MSrcQueue) {
1519-
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
1517+
MEvent->setContextImpl(MSrcQueue->getContextImpl());
15201518
}
15211519

15221520
MWorkerQueue = !MQueue ? MSrcQueue : MQueue;
@@ -1689,7 +1687,7 @@ MemCpyCommandHost::MemCpyCommandHost(Requirement SrcReq,
16891687
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
16901688
MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
16911689
if (MSrcQueue) {
1692-
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
1690+
MEvent->setContextImpl(MSrcQueue->getContextImpl());
16931691
}
16941692

16951693
MWorkerQueue = !MQueue ? MSrcQueue : MQueue;

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {
12211221
Command *Scheduler::GraphBuilder::connectDepEvent(
12221222
Command *const Cmd, const EventImplPtr &DepEvent, const DepDesc &Dep,
12231223
std::vector<Command *> &ToCleanUp) {
1224-
assert(Cmd->getWorkerContext() != DepEvent->getContextImpl());
1224+
assert(Cmd->getWorkerContext().get() != &DepEvent->getContextImpl());
12251225

12261226
// construct Host Task type command manually and make it depend on DepEvent
12271227
ExecCGCommand *ConnectCmd = nullptr;

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ bool Scheduler::CheckEventReadiness(context_impl &Context,
699699
return SyclEventImplPtr->isCompleted();
700700
}
701701
// Cross-context dependencies can't be passed to the backend directly.
702-
if (SyclEventImplPtr->getContextImpl().get() != &Context)
702+
if (&SyclEventImplPtr->getContextImpl() != &Context)
703703
return false;
704704

705705
// A nullptr here means that the commmand does not produce a UR event or it

sycl/source/handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ event handler::finalize() {
610610
detail::queue_impl &Queue = impl->get_queue();
611611
LastEventImpl->setQueue(Queue);
612612
LastEventImpl->setWorkerQueue(Queue.weak_from_this());
613-
LastEventImpl->setContextImpl(impl->get_context().shared_from_this());
613+
LastEventImpl->setContextImpl(impl->get_context());
614614
LastEventImpl->setStateIncomplete();
615615
LastEventImpl->setSubmissionTime();
616616

sycl/unittests/scheduler/QueueFlushing.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ TEST_F(SchedulerTest, QueueFlushing) {
150150
access::mode::read_write};
151151
std::shared_ptr<detail::event_impl> DepEvent =
152152
detail::event_impl::create_device_event(QueueImplB);
153-
DepEvent->setContextImpl(QueueImplB.getContextImplPtr());
153+
DepEvent->setContextImpl(QueueImplB.getContextImpl());
154154

155155
ur_event_handle_t UREvent = mock::createDummyHandle<ur_event_handle_t>();
156156

@@ -170,7 +170,7 @@ TEST_F(SchedulerTest, QueueFlushing) {
170170
queue TempQueue{Ctx, default_selector_v};
171171
detail::queue_impl &TempQueueImpl = *detail::getSyclObjImpl(TempQueue);
172172
DepEvent = detail::event_impl::create_device_event(TempQueueImpl);
173-
DepEvent->setContextImpl(TempQueueImpl.getContextImplPtr());
173+
DepEvent->setContextImpl(TempQueueImpl.getContextImpl());
174174

175175
ur_event_handle_t UREvent = mock::createDummyHandle<ur_event_handle_t>();
176176

0 commit comments

Comments
 (0)