
Commit d9f861b

yuguo68 authored and facebook-github-bot committed
TBE UVM cache prefetch pipeline (pytorch#1893)
Summary:
Pull Request resolved: pytorch#1893

This diff enables the cache prefetch pipeline of TBE, so that the prefetch of batch_{i+1} can overlap with the forward/backward of batch_i. Because the cache can be evicted by prefetch and the weights can be updated by the backward pass, we need to carefully guard against a few scenarios that result in cache invalidation.

## 1. Prevent premature cache eviction: cache gets evicted while it is being used by the forward pass

Since prefetch can overlap with the forward/backward pass, prefetch may try to evict a cache line that is still being used by the forward/backward pass. The fix is to use the `lxu_cache_locking_counter` from D46172802/pytorch#1883 to check whether a cache slot is in use when an eviction is attempted.

## 2. Prevent a dirty cache: the weight is being updated while it is loaded into the cache

If prefetch overlaps with the TBE backward pass, the backward may write to UVM (idx not in cache) while, at the same time, prefetch (idx is inserted into cache) loads the weight from UVM into the cache. We sync the streams so that the TBE backward pass does not overlap with prefetch. The backward of the rest of the module can still overlap with the prefetch of TBE. The stream sync looks like:

```
# backward(batch_i) waits for prefetch(batch_{i+1})
backward pre_hook: cur_stream.wait_stream(prefetch_stream)

# backward(batch_i)
TBE.backward()

# prefetch(batch_{i+2}) waits for backward(batch_i)
backward hook: prefetch_stream.wait_stream(cur_stream)
```

## 3. Prevent cache inconsistency: the weight gets updated after it is loaded into the cache

With the pipeline, if the same index is not inserted into the cache in batch_i but is inserted in batch_{i+1}, the cache can be invalid in the sense that the cached weight for this index does not have the backward update of batch_i. An example of the issue:

- idx is in batch_i and batch_{i+1}
- prefetch(batch_i): fails to insert idx into the cache; cache_locations_batch_i of idx is -1 (cache miss)
- forward(batch_i)
- prefetch(batch_{i+1}): inserts idx into the cache; the cache is loaded from host memory
- backward(batch_i): cache_locations_batch_i of idx is -1, so the host memory is updated
- forward(batch_{i+1}): OUTPUT IS WRONG. The weight for idx is fetched from the cache, but the cache is outdated.

The fix for this cache invalidation is to update cache_locations_batch_i before the backward of batch_i, so that the cache gets updated correctly by the backward pass of TBE.

Reviewed By: sryap

Differential Revision: D47418650

fbshipit-source-id: 84811c423ef30fec82282702be181c00310c4e84
1 parent 99f2287 commit d9f861b
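To make the intended usage concrete, below is a minimal sketch (not part of this commit) of a training loop that drives the prefetch pipeline from a dedicated CUDA stream. Only the `prefetch(indices, offsets, forward_stream=...)` signature and the `prefetch_pipeline=True` constructor flag come from this diff; the `tbe` module construction and the `batches` list are assumptions for illustration.

```python
# Hedged sketch: overlapping TBE prefetch of batch_{i+1} with compute of batch_i.
# `tbe` is assumed to be a SplitTableBatchedEmbeddingBagsCodegen built with
# prefetch_pipeline=True and a UVM-cached EmbeddingLocation; `batches` is a
# hypothetical list of (indices, offsets) CUDA tensors.
import torch

forward_stream = torch.cuda.current_stream()  # stream running forward/backward
prefetch_stream = torch.cuda.Stream()         # dedicated stream for prefetch

# Warm up the pipeline: prefetch the first batch before the loop starts.
with torch.cuda.stream(prefetch_stream):
    tbe.prefetch(*batches[0], forward_stream=forward_stream)

for i, (indices, offsets) in enumerate(batches):
    if i + 1 < len(batches):
        # Launch prefetch of batch_{i+1} on the side stream; it overlaps with
        # forward/backward of batch_i. TBE's backward hooks insert the stream
        # syncs described in the summary above.
        with torch.cuda.stream(prefetch_stream):
            tbe.prefetch(*batches[i + 1], forward_stream=forward_stream)

    out = tbe(indices, offsets)  # forward(batch_i), uses locations prefetched for batch_i
    out.sum().backward()         # backward(batch_i); placeholder loss for illustration
```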

File tree: 2 files changed, +410 -29 lines


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 173 additions & 12 deletions
@@ -279,6 +279,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     record_cache_metrics: RecordCacheMetrics
     uvm_cache_stats: torch.Tensor
     local_uvm_cache_stats: torch.Tensor
+    linear_cache_indices_list: List[Tensor]

     def __init__(  # noqa C901
         self,
@@ -323,13 +324,22 @@ def __init__( # noqa C901
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
         uvm_non_rowwise_momentum: bool = False,  # place non-rowwise momentum on UVM
         use_experimental_tbe: bool = False,  # set to True to use TBE v2 (only support NVIDIA GPUs)
+        # set to True to enable prefetch pipeline, currently only supports LRU cache policy.
+        # If a separate stream is used for prefetch, the optional forward_stream arg of prefetch function
+        # should be set.
+        prefetch_pipeline: bool = False,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
+        assert (
+            not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
+        ), "Only LRU cache policy supports prefetch_pipeline."
+        self.prefetch_pipeline: bool = prefetch_pipeline
+        self.lock_cache_line: bool = self.prefetch_pipeline

         if record_cache_metrics is not None:
             self.record_cache_metrics = record_cache_metrics
@@ -919,10 +929,10 @@ def forward( # noqa: C901
             )
         self.step += 1
         if len(self.timesteps_prefetched) == 0:
-            self.prefetch(indices, offsets)
+            self._prefetch(indices, offsets)

         self.timesteps_prefetched.pop(0)
-        lxu_cache_locations = (
+        self.lxu_cache_locations = (
             self.lxu_cache_locations_empty
             if len(self.lxu_cache_locations_list) == 0
             else self.lxu_cache_locations_list.pop(0)
@@ -945,7 +955,7 @@ def forward( # noqa: C901
             pooling_mode=self.pooling_mode,
             indice_weights=per_sample_weights,
             feature_requires_grad=feature_requires_grad,
-            lxu_cache_locations=lxu_cache_locations,
+            lxu_cache_locations=self.lxu_cache_locations,
             output_dtype=self.output_dtype,
             vbe_metadata=vbe_metadata,
             is_experimental=self.is_experimental,
@@ -1114,7 +1124,23 @@ def print_uvm_cache_stats(self) -> None:
             f"unique misses / requested indices: {uvm_cache_stats[3]/uvm_cache_stats[1]}\n"
         )

-    def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
+    def prefetch(
+        self,
+        indices: Tensor,
+        offsets: Tensor,
+        forward_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        if self.prefetch_stream is None and forward_stream is not None:
+            self.prefetch_stream = torch.cuda.current_stream()
+            assert (
+                self.prefetch_stream != forward_stream
+            ), "prefetch_stream and forward_stream should not be the same stream"
+
+        self._prefetch(indices, offsets)
+        if forward_stream is not None:
+            self._prefetch_tensors_record_stream(forward_stream)
+
+    def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         self.timestep += 1
         self.timesteps_prefetched.append(self.timestep)
         if not self.lxu_cache_weights.numel():
@@ -1163,6 +1189,8 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
                 self.stochastic_rounding,
                 self.gather_uvm_cache_stats,
                 self.local_uvm_cache_stats,
+                self.lock_cache_line,
+                self.lxu_cache_locking_counter,
             )
         elif self.cache_algorithm == CacheAlgorithm.LFU:
             torch.ops.fbgemm.lfu_cache_populate(
@@ -1182,15 +1210,19 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         assert (
             len(self.lxu_cache_locations_list) < self.max_prefetch_depth
         ), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
-        self.lxu_cache_locations_list.append(
-            torch.ops.fbgemm.lxu_cache_lookup(
-                linear_cache_indices,
-                self.lxu_cache_state,
-                self.total_cache_hash_size,
-                self.gather_uvm_cache_stats,
-                self.local_uvm_cache_stats,
-            )
+
+        lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
+            linear_cache_indices,
+            self.lxu_cache_state,
+            self.total_cache_hash_size,
+            self.gather_uvm_cache_stats,
+            self.local_uvm_cache_stats,
         )
+
+        self.lxu_cache_locations_list.append(lxu_cache_locations)
+        if self.prefetch_pipeline:
+            self.linear_cache_indices_list.append(linear_cache_indices)
+
         if self.gather_uvm_cache_stats:
             # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
             # We may wanna do this accumulation atomically, but as it's only for monitoring,
@@ -1200,6 +1232,20 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
             )
             self.local_uvm_cache_stats.zero_()

+    def _prefetch_tensors_record_stream(
+        self, forward_stream: torch.cuda.Stream
+    ) -> None:
+        # Record the tensors created by prefetch stream and consumed by forward/backward
+        # to the forward stream. In PyTorch, each backward CUDA op runs on the same
+        # stream that was used for its corresponding forward op.
+
+        for t in self.lxu_cache_locations_list:
+            # pyre-fixme[6]: For 1st param expected `_C.Stream` but got `streams.Stream`
+            t.record_stream(forward_stream)
+        for t in self.linear_cache_indices_list:
+            # pyre-fixme[6]: For 1st param expected `_C.Stream` but got `streams.Stream`
+            t.record_stream(forward_stream)
+
     def _update_cache_miss_counter(
         self,
         lxu_cache_locations: Tensor,
@@ -1521,6 +1567,9 @@ def _apply_cache_state(
         self.lxu_cache_locations_empty = torch.empty(
             0, device=self.current_device, dtype=torch.int32
         ).fill_(-1)
+        self.lxu_cache_locations = self.lxu_cache_locations_empty
+        self.prefetch_stream: Optional[torch.cuda.Stream] = None
+        self.linear_cache_indices_list = []

         self._init_uvm_cache_stats()

@@ -1561,6 +1610,7 @@ def _apply_cache_state(
                 torch.tensor([0, 0], dtype=torch.int64),
                 persistent=False,
             )
+            self._init_uvm_cache_counter(cache_sets, persistent=False)
             return

         assert cache_load_factor > 0
@@ -1648,13 +1698,124 @@ def _apply_cache_state(
             "cache_miss_counter",
             torch.tensor([0, 0], device=self.current_device, dtype=torch.int64),
         )
+        self._init_uvm_cache_counter(cache_sets, persistent=True)
+        if self.prefetch_pipeline:
+            # using the placeholder_autograd_tensor to make sure
+            # the hook is executed after the backward pass
+            # not using register_module_full_backward_hook
+            # due to https://github.com/pytorch/pytorch/issues/100528
+            self.placeholder_autograd_tensor.register_hook(
+                self._sync_stream_post_backward
+            )
+            self.register_full_backward_pre_hook(
+                self._update_cache_counter_and_locations
+            )

         if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
             raise ValueError(
                 f"cache_algorithm must be {CacheAlgorithm.LRU} "
                 f"or {CacheAlgorithm.LFU}"
             )

+    def _sync_stream_post_backward(
+        self,
+        grad: Tensor,
+    ) -> None:
+        """
+        backward hook function when prefetch_pipeline is enabled.
+
+        With the pipeline, prefetch(batch_{i+2}) may overlap with backward(batch_{i}).
+        There is race condition that backward(batch_i) writes to UVM memory and
+        at the same time prefetch(batch_{i+2}) loads UVM memory to cache. This stream sync forces
+        backward(batch_i) to finish before prefetch(batch_{i+2}).
+        """
+        if self.prefetch_stream is not None:
+            self.prefetch_stream.wait_stream(torch.cuda.current_stream())
+
+    def _update_cache_counter_and_locations(
+        self,
+        module: nn.Module,
+        grad_input: Union[Tuple[Tensor, ...], Tensor],
+    ) -> None:
+        """
+        Backward prehook function when prefetch_pipeline is enabled.
+
+        This function does 3 things:
+        1. backward stream waits for prefetch stream to finish.
+           Otherwise the prefetch(batch_{i+1}) might overlap with backward(batch_i).
+           If an idx is not in cache in batch_i, but it is being inserted in batch_{i+1},
+           there is race condition that backward(batch_i) writes to UVM memory and
+           at the same time prefetch(batch_{i+1}) loads UVM memory to cache.
+
+        2. decrement the lxu_cache_locking_counter to indicate the current batch is finished.
+           The lxu_cache_locking_counter is updated in both prefetch and TBE backward.
+           As there is no overlap between prefetch and backward, we can decrement either before or
+           after backward. It's better to decrement before lxu_cache_locations gets updated.
+
+        3. update lxu_cache_locations to address the cache inconsistency issue.
+           In the case that the same index is not inserted into cache in batch_i,
+           but it is inserted in batch_{i+1}, the cache can be invalid in
+           the sense that the cached weight for this index does not have the
+           backward update of batch_i.
+
+           Example of the issue is as follows:
+             idx is in batch_i, batch_{i+1}
+             prefetch(batch_i)
+               - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss)
+             forward(batch_i)
+             prefetch(batch_{i+1})
+               - insert idx into cache, cache is loaded from host memory
+             backward(batch_i)
+               - cache_locations_batch_i of idx is -1, the host memory is updated
+             forward(batch_{i+1})
+               - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated.
+
+           The fix to this cache inconsistency is to update the cache_locations_batch_i before backward of batch_i,
+           so that the cache gets updated correctly by the backward pass of TBE.
+        """
+
+        if self.prefetch_stream is not None:
+            # need to wait for the prefetch of next batch,
+            # so that cache states are valid
+            torch.cuda.current_stream().wait_stream(self.prefetch_stream)
+
+        torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
+            self.lxu_cache_locking_counter,
+            self.lxu_cache_locations,
+        )
+
+        linear_cache_indices = self.linear_cache_indices_list.pop(0)
+        lxu_cache_locations_new = torch.ops.fbgemm.lxu_cache_lookup(
+            linear_cache_indices,
+            self.lxu_cache_state,
+            self.total_cache_hash_size,
+            False,  # not collecting cache stats
+            self.local_uvm_cache_stats,
+        )
+        # self.lxu_cache_locations is updated inplace
+        torch.ops.fbgemm.lxu_cache_locations_update(
+            self.lxu_cache_locations,
+            lxu_cache_locations_new,
+        )
+
+    def _init_uvm_cache_counter(self, cache_sets: int, persistent: bool) -> None:
+        if self.prefetch_pipeline and persistent:
+            self.register_buffer(
+                "lxu_cache_locking_counter",
+                torch.zeros(
+                    cache_sets,
+                    DEFAULT_ASSOC,
+                    device=self.current_device,
+                    dtype=torch.int32,
+                ),
+            )
+        else:
+            self.register_buffer(
+                "lxu_cache_locking_counter",
+                torch.zeros([0, 0], dtype=torch.int32, device=self.current_device),
+                persistent=persistent,
+            )
+
     def _init_uvm_cache_stats(self) -> None:
         if not self.gather_uvm_cache_stats:
             # If uvm_cache_stats is not enabled, register stub entries via buffer to state_dict for TorchScript to JIT properly.
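As a side note, the cache-inconsistency scenario that `_update_cache_counter_and_locations` addresses can be reproduced with a toy model that does not involve any fbgemm ops: a single embedding row that lives in UVM ("host") and may or may not occupy a cache slot. The sketch below is purely illustrative; only the idea of refreshing the locations before the backward pass comes from this diff.

```python
# Toy illustration (not fbgemm code) of the stale-location problem fixed by
# recomputing lxu_cache_locations before the TBE backward. "host" stands in
# for the UVM copy of a row, "cache" for an on-device cache slot; the location
# is -1 on a cache miss.
import torch


def run_backward(locations, host, cache, grad=0.5):
    # The backward pass writes to the cache slot if the row is cached, else to UVM.
    slot = int(locations[0])
    if slot >= 0:
        cache[slot] -= grad
    else:
        host[0] -= grad
    return host, cache


# prefetch(batch_i) missed idx 0 (location -1); prefetch(batch_{i+1}) then
# inserted idx 0 into cache slot 0 by copying the UVM row.
host, cache = torch.tensor([1.0]), torch.tensor([1.0])

# Without the fix: backward(batch_i) uses the stale location (-1), updates UVM
# only, and forward(batch_{i+1}) would read an outdated cache slot.
h1, c1 = run_backward(torch.tensor([-1]), host.clone(), cache.clone())
assert h1[0] == 0.5 and c1[0] == 1.0  # cache is stale

# With the fix (this diff): the refreshed location (slot 0) routes the update
# into the cache slot, keeping the cached weight consistent.
h2, c2 = run_backward(torch.tensor([0]), host.clone(), cache.clone())
assert c2[0] == 0.5
```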
