Fix the hang issue in some TBE GPU optimizers #2509
Closed
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
Previously, some TBE optimizer unit tests hung indefinitely causing
the unit tests to timeout. We were able to reproduce this problem
consistently by using the config in D50612178 (composed by ezyang).
The main characteristics of this config are (1) the optimizer is
PARTIAL_ROWWISE_ADAM, (2) the embedding dimension is less than 32, and
(3) it contains long segments (i.e., some indices are repeated with
extremely high counts).
Upon our investigation, we identified that the value reduction in
PARTIAL_ROWWISE_ADAM was implemented incorrectly. The optimizer
intended to perform a value reduction within a sub-warp (i.e., a group
of threads in a warp) instead of an entire warp. (Note that sub-warp
reduction is done when the embedding dimension is smaller than the
warp size). However, it did not pass a correct
shfl_sync
mask. Thewrong mask expected an entire warp to perform the reduction. When the
segment length is long (> 32), only one sub-warp would perform the
reduction. Such the warp divergence caused the kernel execution to
freeze. (Note that the reduction is a collective operation).
This diff fixes the issue by passing a correct mask when invoking the
reduction function.
Reviewed By: shintaro-iwasaki
Differential Revision: D56223375