
CUDA graph implementation on TRT backend #1071



Open: zjuwyz wants to merge 8 commits into master

Conversation

zjuwyz commented Jun 6, 2025

No description provided.


zjuwyz commented Jun 6, 2025

See https://discord.com/channels/417022162348802048/583775968804732928/1380433026093289596


zsqdx commented Jun 19, 2025

@lightvector Hi. Do you think this branch is ready to merge? Thanks!

@lightvector (Owner) left a comment:


Thanks for the work! I got around to taking a look at this and left some comments. Aside from those, one question: has this code also been tested with CUDA graphs disabled, to sanity-check that the changes don't break the original option that has no CUDA graphs?

Additionally, does anything about the plan cache or similar things need to change with CUDA graphs or any of these changes? Should we bump the salt for that, and also add the CUDA graph status into the hash, so that the cache is differentiated between CUDA-graph and non-CUDA-graph plans?
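For illustration, a minimal sketch of that idea with hypothetical names (KataGo's actual plan cache code will differ):

#include <sstream>
#include <string>

// Hypothetical sketch: fold the CUDA graph setting and a bumped salt into the
// plan cache key so graph and non-graph plans can never collide in the cache.
std::string makePlanCacheKey(const std::string& modelHash, int maxBatchSize, bool usingCudaGraphs) {
  static const char* salt = "cudagraph-v2"; // bumped alongside this change
  std::ostringstream out;
  out << modelHash << "|batch=" << maxBatchSize
      << "|cudagraph=" << (usingCudaGraphs ? 1 : 0)
      << "|salt=" << salt;
  return out.str();
}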

Comment on lines +1617 to +1625
cudaMallocHost((void**)&maskInputs, maxBatchSize * singleMaskElts * sizeof(float));
cudaMallocHost((void**)&spatialInputs, maxBatchSize * singleInputElts * sizeof(float));
cudaMallocHost((void**)&globalInputs, maxBatchSize * singleInputGlobalElts * sizeof(float));
cudaMallocHost((void**)&metaInputs, maxBatchSize * singleInputMetaElts * sizeof(float));
cudaMallocHost((void**)&policyPassResults, maxBatchSize * singlePolicyPassResultElts * sizeof(float));
cudaMallocHost((void**)&policyResults, maxBatchSize * singlePolicyResultElts * sizeof(float));
cudaMallocHost((void**)&valueResults, maxBatchSize * singleValueResultElts * sizeof(float));
cudaMallocHost((void**)&scoreValueResults, maxBatchSize * singleScoreValueResultElts * sizeof(float));
cudaMallocHost((void**)&ownershipResults, maxBatchSize * singleOwnershipResultElts * sizeof(float));
@lightvector (Owner):

Since we're changing all of these to raw pointers, does this require a corresponding free operation somewhere?
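For reference, a sketch of the matching cleanup, assuming these buffers are released in the owning object's destructor (the exact location in the PR is a guess):

// One cudaFreeHost per cudaMallocHost above, e.g. in the destructor:
cudaFreeHost(maskInputs);
cudaFreeHost(spatialInputs);
cudaFreeHost(globalInputs);
cudaFreeHost(metaInputs);
cudaFreeHost(policyPassResults);
cudaFreeHost(policyResults);
cudaFreeHost(valueResults);
cudaFreeHost(scoreValueResults);
cudaFreeHost(ownershipResults);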

Comment on lines +1325 to +1333
{
int gpuId;
cudaGetDevice(&gpuId);
auto& mutex = mutexPerGpu[gpuId];
mutex.lock();
planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));
mutex.unlock();
}

@lightvector (Owner):

Shouldn't this code be gated behind the TENSORRT_CUDA_GRAPH define? Is there any other code that should be gated but isn't?
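A sketch of the gating being asked about, wrapping the snippet above (the #ifdef placement is the suggestion, not code from the PR):

#ifdef TENSORRT_CUDA_GRAPH
{
  int gpuId;
  cudaGetDevice(&gpuId);
  auto& mutex = mutexPerGpu[gpuId];
  mutex.lock();
  planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));
  mutex.unlock();
}
#else
// Original non-graph path, with no per-GPU serialization.
planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));
#endif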

@lightvector (Owner):

Also, is auto& mutex = mutexPerGpu[gpuId]; safe?

This is a global map that starts uninitialized, so the map itself is subject to race conditions, since retrieving a value out of it involves a mutation, right?

Nitpick: for the mutex lock/unlock, normally I think we would use a std::lock_guard so that RAII guarantees the unlock even if the build of the network throws (although admittedly, KataGo is generally written so that exceptions in this kind of code are fatal anyway). Is that easy to do?
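One possible shape for both fixes, sketched with illustrative names. A fixed-size array of mutexes avoids mutating a shared map on lookup, and std::lock_guard gives the RAII unlock:

#include <array>
#include <mutex>

// Fixed-size array: lookups never insert, so concurrent access to the
// container itself is safe. Assumes gpuId < 64.
static std::array<std::mutex, 64> mutexPerGpu;

{
  int gpuId;
  cudaGetDevice(&gpuId);
  // Unlocks on scope exit, even if buildSerializedNetwork throws.
  std::lock_guard<std::mutex> lock(mutexPerGpu[gpuId]);
  planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));
}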

Comment on lines +1718 to +1721
int gpuId;
cudaGetDevice(&gpuId);
auto& mutex = gpuHandle->mutexPerGpu[gpuId];
mutex.lock();
@lightvector (Owner):

If we're going to have repeated code that does cudaGetDevice and grabs a mutex, maybe that should be factored out into a helper, so that it can be written once, safely (with the unsafe access to mutexPerGpu itself fixed).
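For example, a helper along these lines (names hypothetical) would keep the device lookup and the locking in one place:

#include <mutex>

// Hypothetical helper: look up the current device and return an owning lock.
// std::unique_lock is movable, so the caller holds the lock until it leaves scope.
std::unique_lock<std::mutex> lockCurrentGpu() {
  int gpuId;
  cudaGetDevice(&gpuId);
  return std::unique_lock<std::mutex>(mutexPerGpu[gpuId]);
}

// At each call site:
// auto lock = lockCurrentGpu();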
