
Commit 406ef0f

yuguo68 authored and facebook-github-bot committed
TBE UVM cache prefetch pipeline
Summary:
This diff enables the cache prefetch pipeline for TBE, so that the prefetch of batch_{i+1} can overlap with the forward/backward of batch_i. Because the cache can be evicted by prefetch while the weights are updated by the backward pass, we need to carefully guard against a few scenarios that can lead to cache invalidation.

## 1. Prevent premature cache eviction: the cache gets evicted while it is still being used by the forward pass

Since prefetch can overlap with the forward/backward pass, prefetch may try to evict a cache slot that the forward/backward pass is still using. The fix is to use the `lxu_cache_locking_counter` introduced in D46172802/pytorch#1883 to check whether a cache slot is in use when an eviction is attempted.

## 2. Prevent a dirty cache: a weight is updated while it is being loaded into the cache

If prefetch overlaps with the TBE backward pass, the backward pass may write to UVM (for an idx not in the cache) while prefetch (which inserts that idx into the cache) simultaneously loads the weight from UVM into the cache. We sync the streams so that the TBE backward pass does not overlap with prefetch. The backward pass of the rest of the module can still overlap with TBE prefetch. The stream sync looks like:

```
# backward(batch_i) waits for prefetch(batch_{i+1})
backward pre_hook: cur_stream.wait_stream(prefetch_stream)

# backward(batch_i)
TBE.backward()

# prefetch(batch_{i+2}) waits for backward(batch_i)
backward hook: prefetch_stream.wait_stream(cur_stream)
```

## 3. Prevent cache inconsistency: a weight is updated after it has been loaded into the cache

With the pipeline, if an index is not inserted into the cache in batch_i but is inserted in batch_{i+1}, the cache can be invalid in the sense that the cached weight for this index does not include the backward update of batch_i. An example of the issue:

idx is in batch_i and batch_{i+1}
prefetch(batch_i)
  - failed to insert idx into the cache; cache_locations_batch_i of idx is -1 (cache miss)
forward(batch_i)
prefetch(batch_{i+1})
  - inserts idx into the cache; the cache line is loaded from host memory
backward(batch_i)
  - cache_locations_batch_i of idx is -1, so host memory is updated
forward(batch_{i+1})
  - OUTPUT IS WRONG: the weight for idx is fetched from the cache, but the cache is outdated

The fix for this cache inconsistency is to update cache_locations_batch_i before the backward pass of batch_i, so that the cache is updated correctly by the TBE backward pass.

Reviewed By: sryap

Differential Revision: D47418650

fbshipit-source-id: 05081a3b61d924238884e4263396847fe4fac4ed
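The hook-based stream synchronization described in scenario 2 can be illustrated with a minimal, generic PyTorch sketch on a plain `nn.Module` (the module and hook names here are illustrative, not the TBE implementation; the diff itself attaches the post-backward hook to a placeholder autograd tensor, for the reason noted in its comments):

```python
import torch
import torch.nn as nn

# Assumes a CUDA device; `prefetch_stream` is the stream that runs prefetch.
prefetch_stream = torch.cuda.Stream()

def backward_pre_hook(module: nn.Module, grad_output) -> None:
    # backward(batch_i) waits for prefetch(batch_{i+1}) to finish
    torch.cuda.current_stream().wait_stream(prefetch_stream)

def post_backward_hook(grad: torch.Tensor) -> torch.Tensor:
    # prefetch(batch_{i+2}) waits for backward(batch_i) to finish
    prefetch_stream.wait_stream(torch.cuda.current_stream())
    return grad

module = nn.Linear(16, 16).cuda()
module.register_full_backward_pre_hook(backward_pre_hook)
# A hook on a parameter fires once that parameter's gradient has been computed,
# i.e. effectively after this module's backward pass.
module.weight.register_hook(post_backward_hook)
```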
1 parent 99f2287 commit 406ef0f

File tree

2 files changed: +397, -27 lines


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 158 additions & 10 deletions
```diff
@@ -279,6 +279,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     record_cache_metrics: RecordCacheMetrics
     uvm_cache_stats: torch.Tensor
     local_uvm_cache_stats: torch.Tensor
+    linear_cache_indices_list: List[Tensor]
 
     def __init__( # noqa C901
         self,
@@ -323,13 +324,23 @@ def __init__( # noqa C901
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
         uvm_non_rowwise_momentum: bool = False, # place non-rowwise momentum on UVM
         use_experimental_tbe: bool = False, # set to True to use TBE v2 (only support NVIDIA GPUs)
+        # set to True to enable prefetch pipeline, currently only supports LRU cache policy.
+        # If a separate stream is used for prefetch, user is responsible to call
+        # set_prefetch_stream after module initialization and prefetch_tensors_record_stream
+        # after each prefetch call.
+        prefetch_pipeline: bool = False,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
 
         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
+        assert (
+            not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
+        ), "Only LRU cache policy supports prefetch_pipeline."
+        self.prefetch_pipeline: bool = prefetch_pipeline
+        self.lock_cache_line: bool = self.prefetch_pipeline
 
         if record_cache_metrics is not None:
             self.record_cache_metrics = record_cache_metrics
```
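The constructor comment above asks callers using a separate prefetch stream to call `set_prefetch_stream` once and `prefetch_tensors_record_stream` after each prefetch. A minimal usage sketch of that pattern follows; `emb` is assumed to be an already-constructed module with `prefetch_pipeline=True`, and `loader`, `model`, `criterion`, and `optimizer` are placeholders (not part of the diff):

```python
import torch

# `emb` is assumed to be a SplitTableBatchedEmbeddingBagsCodegen constructed with
# prefetch_pipeline=True (and therefore an LRU cache); batches are assumed to
# arrive as (indices, offsets, labels).
prefetch_stream = torch.cuda.Stream()
forward_stream = torch.cuda.current_stream()
emb.set_prefetch_stream(prefetch_stream)

def prefetch_next_batch(indices: torch.Tensor, offsets: torch.Tensor) -> None:
    # Run prefetch on its own stream so it can overlap with compute of the
    # current batch, then record the tensors it produced on the forward stream
    # so they are not reclaimed while forward/backward still reads them.
    with torch.cuda.stream(prefetch_stream):
        emb.prefetch(indices, offsets)
    emb.prefetch_tensors_record_stream(forward_stream)

# Pipelined loop: prefetch(batch_{i+1}) is issued before the compute of batch_i.
batches = iter(loader)
indices, offsets, labels = next(batches)
prefetch_next_batch(indices, offsets)
for next_indices, next_offsets, next_labels in batches:
    prefetch_next_batch(next_indices, next_offsets)
    loss = criterion(model(indices, offsets), labels)  # forward(batch_i)
    loss.backward()                                    # backward(batch_i)
    optimizer.step()
    optimizer.zero_grad()
    indices, offsets, labels = next_indices, next_offsets, next_labels
```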
```diff
@@ -922,7 +933,7 @@ def forward( # noqa: C901
             self.prefetch(indices, offsets)
 
         self.timesteps_prefetched.pop(0)
-        lxu_cache_locations = (
+        self.lxu_cache_locations = (
             self.lxu_cache_locations_empty
             if len(self.lxu_cache_locations_list) == 0
             else self.lxu_cache_locations_list.pop(0)
@@ -945,7 +956,7 @@ def forward( # noqa: C901
             pooling_mode=self.pooling_mode,
             indice_weights=per_sample_weights,
             feature_requires_grad=feature_requires_grad,
-            lxu_cache_locations=lxu_cache_locations,
+            lxu_cache_locations=self.lxu_cache_locations,
             output_dtype=self.output_dtype,
             vbe_metadata=vbe_metadata,
             is_experimental=self.is_experimental,
@@ -1163,6 +1174,8 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
                 self.stochastic_rounding,
                 self.gather_uvm_cache_stats,
                 self.local_uvm_cache_stats,
+                self.lock_cache_line,
+                self.lxu_cache_locking_counter,
             )
         elif self.cache_algorithm == CacheAlgorithm.LFU:
             torch.ops.fbgemm.lfu_cache_populate(
```
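For intuition on the `lock_cache_line` / `lxu_cache_locking_counter` arguments threaded into `lru_cache_populate` above: the counter tracks, per cache slot, how many in-flight batches still reference it, and eviction is only allowed when the count is zero. Below is a rough CPU-side sketch of that bookkeeping; it is purely illustrative (the real logic lives in the fbgemm CUDA kernels), and the shapes and flattened-location convention here are assumptions:

```python
import torch

CACHE_SETS = 8
DEFAULT_ASSOC = 32  # slots per set (illustrative values)

# One counter per cache slot, mirroring the [cache_sets, DEFAULT_ASSOC] buffer.
locking_counter = torch.zeros(CACHE_SETS, DEFAULT_ASSOC, dtype=torch.int32)

def lock(locations: torch.Tensor) -> None:
    # Prefetch path: bump the counter for every slot the batch will read.
    # `locations` holds flattened slot ids (set * DEFAULT_ASSOC + way), -1 = miss.
    hits = locations[locations >= 0].long()
    locking_counter.view(-1).scatter_add_(
        0, hits, torch.ones_like(hits, dtype=torch.int32)
    )

def unlock(locations: torch.Tensor) -> None:
    # Backward pre-hook path: the batch is done with these slots.
    hits = locations[locations >= 0].long()
    locking_counter.view(-1).scatter_add_(
        0, hits, -torch.ones_like(hits, dtype=torch.int32)
    )

def may_evict(flat_slot: int) -> bool:
    # Eviction during prefetch is skipped while any batch still holds the slot.
    return bool(locking_counter.view(-1)[flat_slot] == 0)
```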
```diff
@@ -1182,15 +1195,19 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         assert (
             len(self.lxu_cache_locations_list) < self.max_prefetch_depth
         ), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
-        self.lxu_cache_locations_list.append(
-            torch.ops.fbgemm.lxu_cache_lookup(
-                linear_cache_indices,
-                self.lxu_cache_state,
-                self.total_cache_hash_size,
-                self.gather_uvm_cache_stats,
-                self.local_uvm_cache_stats,
-            )
+
+        lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
+            linear_cache_indices,
+            self.lxu_cache_state,
+            self.total_cache_hash_size,
+            self.gather_uvm_cache_stats,
+            self.local_uvm_cache_stats,
         )
+
+        self.lxu_cache_locations_list.append(lxu_cache_locations)
+        if self.prefetch_pipeline:
+            self.linear_cache_indices_list.append(linear_cache_indices)
+
         if self.gather_uvm_cache_stats:
             # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
             # We may wanna do this accumulation atomically, but as it's only for monitoring,
@@ -1521,6 +1538,9 @@ def _apply_cache_state(
         self.lxu_cache_locations_empty = torch.empty(
             0, device=self.current_device, dtype=torch.int32
         ).fill_(-1)
+        self.lxu_cache_locations = self.lxu_cache_locations_empty
+        self.prefetch_stream: Optional[torch.cuda.Stream] = None
+        self.linear_cache_indices_list = []
 
         self._init_uvm_cache_stats()
 
@@ -1561,6 +1581,7 @@ def _apply_cache_state(
                 torch.tensor([0, 0], dtype=torch.int64),
                 persistent=False,
             )
+            self._init_uvm_cache_counter(cache_sets, persistent=False)
             return
 
         assert cache_load_factor > 0
@@ -1648,13 +1669,137 @@ def _apply_cache_state(
             "cache_miss_counter",
             torch.tensor([0, 0], device=self.current_device, dtype=torch.int64),
         )
+        self._init_uvm_cache_counter(cache_sets, persistent=True)
+        if self.prefetch_pipeline:
+            # using the placeholder_autograd_tensor to make sure
+            # the hook is executed after the backward pass
+            # not using register_module_full_backward_hook
+            # due to https://github.com/pytorch/pytorch/issues/100528
+            self.placeholder_autograd_tensor.register_hook(
+                self._sync_stream_post_backward
+            )
+            self.register_full_backward_pre_hook(
+                self._update_cache_counter_and_locations
+            )
 
         if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
             raise ValueError(
                 f"cache_algorithm must be {CacheAlgorithm.LRU} "
                 f"or {CacheAlgorithm.LFU}"
             )
 
+    def prefetch_tensors_record_stream(self, stream: torch.cuda.Stream) -> None:
+        # Record the tensors created by prefetch stream and consumed by forward/backward
+        # to the forward stream. In PyTorch, each backward CUDA op runs on the same
+        # stream that was used for its corresponding forward op.
+        if self.prefetch_stream is None:
+            return
+        for t in self.lxu_cache_locations_list:
+            # pyre-fixme[6]: For 1st param expected `_C.Stream` but got `streams.Stream`
+            t.record_stream(stream)
+        for t in self.linear_cache_indices_list:
+            # pyre-fixme[6]: For 1st param expected `_C.Stream` but got `streams.Stream`
+            t.record_stream(stream)
+
+    def _sync_stream_post_backward(
+        self,
+        grad: Tensor,
+    ) -> None:
+        """
+        backward hook function when prefetch_pipeline is enabled.
+
+        With the pipeline, prefetch(batch_{i+2}) may overlap with backward(batch_{i}).
+        There is race condition that backward(batch_i) writes to UVM memory and
+        at the same time prefetch(batch_{i+2}) loads UVM memory to cache. This stream sync forces
+        backward(batch_i) to finish before prefetch(batch_{i+2}).
+        """
+        if self.prefetch_stream is not None:
+            self.prefetch_stream.wait_stream(torch.cuda.current_stream())
+
+    def _update_cache_counter_and_locations(
+        self,
+        module: nn.Module,
+        grad_input: Union[Tuple[Tensor, ...], Tensor],
+    ) -> None:
+        """
+        Backward prehook function when prefetch_pipeline is enabled.
+
+        This function does 3 things:
+        1. backward stream waits for prefetch stream to finish.
+        Otherwise the prefetch(batch_{i+1}) might overlap with backward(batch_i).
+        If an idx is not in cache in batch_i, but it is being inserted in batch_{i+1},
+        there is race condition that backward(batch_i) writes to UVM memory and
+        at the same time prefetch(batch_{i+1}) loads UVM memory to cache.
+
+        2. decrement the lxu_cache_locking_counter to indicate the current batch is finished.
+        The lxu_cache_locking_counter is updated in both prefetch and TBE backward.
+        As there is no overlap between prefetch and backward, we can decrement either before or
+        after backward. It's better to decrement before lxu_cache_locations gets updated.
+
+        3. update lxu_cache_locations to address the cache inconsistency issue.
+        In the case that the same index is not inserted into cache in batch_i,
+        but it is inserted in batch_{i+1}, the cache can be invalid in
+        the sense that the cached weight for this index does not have the
+        backward update of batch_i.
+
+        Example of the issue is as follows:
+        idx is in batch_i, batch_{i+1}
+        prefetch(batch_i)
+          - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss)
+        forward(batch_i)
+        prefetch(batch_{i+1})
+          - insert idx into cache, cache is loaded from host memory
+        backward(batch_i)
+          - cache_locations_batch_i of idx is -1, the host memory is updated
+        forward(batch_{i+1})
+          - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated.
+
+        The fix to this cache inconsistency is to update the cache_locations_batch_i before backward of batch_i,
+        so that the cache gets updated correctly by the backward pass of TBE.
+        """
+
+        if self.prefetch_stream is not None:
+            # need to wait for the prefetch of next batch,
+            # so that cache states are valid
+            torch.cuda.current_stream().wait_stream(self.prefetch_stream)
+
+        torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
+            self.lxu_cache_locking_counter,
+            self.lxu_cache_locations,
+        )
+
+        linear_cache_indices = self.linear_cache_indices_list.pop(0)
+        lxu_cache_locations_new = torch.ops.fbgemm.lxu_cache_lookup(
+            linear_cache_indices,
+            self.lxu_cache_state,
+            self.total_cache_hash_size,
+            False, # not collecting cache stats
+            self.local_uvm_cache_stats,
+        )
+        # self.lxu_cache_locations is updated inplace
+        torch.ops.fbgemm.lxu_cache_locations_update(
+            self.lxu_cache_locations,
+            lxu_cache_locations_new,
+        )
+
+    def _init_uvm_cache_counter(self, cache_sets: int, persistent: bool) -> None:
+        if self.prefetch_pipeline and persistent:
+            self.register_buffer(
+                "lxu_cache_locking_counter",
+                torch.zeros(
+                    cache_sets,
+                    DEFAULT_ASSOC,
+                    device=self.current_device,
+                    dtype=torch.int32,
+                ),
+            )
+        else:
+            self.register_buffer(
+                "lxu_cache_locking_counter",
+                torch.zeros([0, 0], dtype=torch.int32, device=self.current_device),
+                persistent=persistent,
+            )
+
     def _init_uvm_cache_stats(self) -> None:
         if not self.gather_uvm_cache_stats:
             # If uvm_cache_stats is not enabled, register stub entries via buffer to state_dict for TorchScript to JIT properly.
@@ -1696,6 +1841,9 @@ def _init_uvm_cache_stats(self) -> None:
             )
             self.reset_uvm_cache_stats()
 
+    def set_prefetch_stream(self, prefetch_stream: torch.cuda.Stream) -> None:
+        self.prefetch_stream = prefetch_stream
+
     def reset_cache_states(self) -> None:
         if not self.lxu_cache_weights.numel():
             return
```