Cuda graph implementation on trt backend #1071
base: master
Conversation
@lightvector Hi. Do you think this branch is ready to merge? Thanks!
Thanks for the work! Got around to taking a look at this and left some comments. Aside from those, a question: has this code also been tested with cuda graphs disabled, to sanity-check that the changes don't break the original path that has no cuda graphs?
Additionally, does anything about the plan cache or similar things need to change with cuda graphs or any of these changes? Should we be bumping the salt for that, as well as adding the cuda graph status into the hash, so that cached plans built with and without cuda graphs are kept distinct?
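(For illustration only, a sketch of the kind of key change meant here, assuming the cache key is built from a salt plus a model hash; makePlanCacheKey, modelHash, and usingCudaGraphs are hypothetical names, not identifiers from this PR:)

#include <string>

// Hypothetical sketch: bump the salt for the format change, and fold the
// cuda graph status into the key so plans built with and without cuda
// graphs can never be confused with one another in the cache.
std::string makePlanCacheKey(const std::string& modelHash, bool usingCudaGraphs) {
  static const std::string planCacheSalt = "2";  // bumped from "1"
  return planCacheSalt + "_" + modelHash + (usingCudaGraphs ? "_cudagraph" : "_nograph");
}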
cudaMallocHost((void**)&maskInputs, maxBatchSize * singleMaskElts * sizeof(float));
cudaMallocHost((void**)&spatialInputs, maxBatchSize * singleInputElts * sizeof(float));
cudaMallocHost((void**)&globalInputs, maxBatchSize * singleInputGlobalElts * sizeof(float));
cudaMallocHost((void**)&metaInputs, maxBatchSize * singleInputMetaElts * sizeof(float));
cudaMallocHost((void**)&policyPassResults, maxBatchSize * singlePolicyPassResultElts * sizeof(float));
cudaMallocHost((void**)&policyResults, maxBatchSize * singlePolicyResultElts * sizeof(float));
cudaMallocHost((void**)&valueResults, maxBatchSize * singleValueResultElts * sizeof(float));
cudaMallocHost((void**)&scoreValueResults, maxBatchSize * singleScoreValueResultElts * sizeof(float));
cudaMallocHost((void**)&ownershipResults, maxBatchSize * singleOwnershipResultElts * sizeof(float));
Since we're changing all of these to raw pointers, does this require a corresponding free operation somewhere?
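(For concreteness, a sketch of the matching cleanup; assuming these pointers belong to a buffers struct, its destructor would be the natural place. The struct name here is illustrative:)

#include <cuda_runtime.h>

// Sketch: each cudaMallocHost above needs a matching cudaFreeHost, or the
// pinned host memory leaks.
struct InputBuffers {
  float* maskInputs;
  float* spatialInputs;
  float* globalInputs;
  float* metaInputs;
  float* policyPassResults;
  float* policyResults;
  float* valueResults;
  float* scoreValueResults;
  float* ownershipResults;

  ~InputBuffers() {
    cudaFreeHost(maskInputs);
    cudaFreeHost(spatialInputs);
    cudaFreeHost(globalInputs);
    cudaFreeHost(metaInputs);
    cudaFreeHost(policyPassResults);
    cudaFreeHost(policyResults);
    cudaFreeHost(valueResults);
    cudaFreeHost(scoreValueResults);
    cudaFreeHost(ownershipResults);
  }
};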
{
  int gpuId;
  cudaGetDevice(&gpuId);
  auto& mutex = mutexPerGpu[gpuId];
  mutex.lock();
  planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));
  mutex.unlock();
}
Shouldn't this code be gated behind the TENSORRT_CUDA_GRAPH define? Is there any other code that should be gated that isn't as well?
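(To make the question concrete, a sketch of the shape the gating could take, reusing the PR's own code; whether the per-GPU mutex is needed at all outside the cuda graph path is exactly the open question:)

#ifdef TENSORRT_CUDA_GRAPH
  // Serialize the build per GPU, as in the current PR code.
  int gpuId;
  cudaGetDevice(&gpuId);
  auto& mutex = mutexPerGpu[gpuId];
  mutex.lock();
  planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));
  mutex.unlock();
#else
  // Original path, unchanged when cuda graphs are compiled out.
  planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));
#endif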
Also, is auto& mutex = mutexPerGpu[gpuId]; safe?
This is a global map that starts uninitialized, so the map itself is subject to race conditions, since retrieving a value out of it involves a mutation, right?
Nitpick: for the mutex lock/unlock, normally I think we would use a std::lock_guard so that RAII guarantees unlock even if the build of the network raises. (although admittedly generally katago is written so that exceptions in this kind of code are fatal anyways). Is that easy to do?
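(A minimal sketch of both fixes together; the getMutexForGpu name is mine, not from the PR. Guard the map with its own mutex, since std::map::operator[] may insert and therefore mutate, and use std::lock_guard at the call site so unlocking is guaranteed by RAII:)

#include <map>
#include <mutex>

static std::map<int, std::mutex> mutexPerGpu;
static std::mutex mutexMapMutex;

// Serialize access to the map itself; references into a std::map remain
// valid across later insertions, so handing out a reference here is safe.
static std::mutex& getMutexForGpu(int gpuId) {
  std::lock_guard<std::mutex> mapLock(mutexMapMutex);
  return mutexPerGpu[gpuId];
}

// Call site, with RAII unlocking even if the build throws:
// int gpuId;
// cudaGetDevice(&gpuId);
// std::lock_guard<std::mutex> lock(getMutexForGpu(gpuId));
// planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));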
int gpuId;
cudaGetDevice(&gpuId);
auto& mutex = gpuHandle->mutexPerGpu[gpuId];
mutex.lock();
If we're going to have repeat code that does cudaGetDevice and grabs a mutex, maybe that should be factored out as a helper, so that it can be written once safely (with unsafe access to mutexPerGpu itself fixed).
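(For instance, a sketch building on the getMutexForGpu idea above; the class name is illustrative:)

#include <cuda_runtime.h>
#include <mutex>

// RAII helper: the constructor looks up the current device and takes that
// GPU's build mutex; the destructor releases it automatically.
class GpuBuildLock {
public:
  GpuBuildLock() {
    int gpuId = 0;
    cudaGetDevice(&gpuId);
    lock = std::unique_lock<std::mutex>(getMutexForGpu(gpuId));
  }
private:
  std::unique_lock<std::mutex> lock;
};

// Usage, replacing each manual lock()/unlock() pair:
// GpuBuildLock buildLock;
// planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config));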