
Commit d6edaab

q10 authored and facebook-github-bot committed
Optimize the cache fetch for forward split, pt. 2A (pytorch#2289)
Summary:
Pull Request resolved: pytorch#2289

This follows up the work in D51865590 by plumbing the `uvm_cache_stats` argument up to the Python API level. `local_uvm_cache_stats` is now zeroed out before the prefetch step instead of after, so that the data can be passed into the forward step.

This is a re-attempt of landing D51995949 with additions copied from D52670550.

Reviewed By: spcyppt

Differential Revision: D53113564

fbshipit-source-id: 211bf3d1c35994ebf2346e9abf004cdb85fee69e
1 parent 9b2fa10 commit d6edaab
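
The key mechanism in the C++ template change below is defaulting the new optional argument to an empty int32 tensor, so downstream kernels always receive a valid tensor even when the caller supplies no cache stats. A minimal standalone sketch of that fallback pattern follows; the helper name `resolve_uvm_cache_stats` is illustrative only and not part of the generated code.

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Illustrative helper (hypothetical, not in the actual template): mirrors the
// value_or() fallback applied to `uvm_cache_stats` in the diff below.
at::Tensor resolve_uvm_cache_stats(
    const c10::optional<at::Tensor>& uvm_cache_stats,
    const at::Tensor& uvm_weights) {
  // If the caller passed no stats tensor, substitute an empty int32 tensor
  // allocated with the same options (device, layout) as the UVM weights.
  return uvm_cache_stats.value_or(
      at::empty({0}, uvm_weights.options().dtype(at::kInt)));
}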

2 files changed (+18, −13 lines)


fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp

Lines changed: 16 additions & 11 deletions
@@ -167,6 +167,7 @@ class {{ autograd_func }} :
 const c10::optional<Tensor>& feature_requires_grad,
 {%- endif %}
 const Tensor& lxu_cache_locations,
+c10::optional<Tensor> uvm_cache_stats,
 {%- if optimizer != "none" %}
 const bool gradient_clipping,
 const double max_gradient,
@@ -196,6 +197,11 @@ class {{ autograd_func }} :
 const auto max_B_ = offsets.sym_size(0) / T;
 {%- endif %}
 
+// NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t
+// TODO: Hook up with frontend code
+const auto uvm_cache_stats_ = uvm_cache_stats
+    .value_or(at::empty({0}, uvm_weights.options().dtype(at::kInt)));
+
 // TODO: don't guard here
 auto [info_B_num_bits, info_B_mask] = adjust_info_B_num_bits(max_B_.guard_int(__FILE__, __LINE__), T.guard_int(__FILE__, __LINE__));
 
@@ -283,13 +289,6 @@ class {{ autograd_func }} :
 const auto& flatten_dev_weights = dev_weights;
 {%- endif %}
 
-
-
-
-const auto uvm_cache_stats = at::empty({0}, uvm_weights.options().dtype(at::kInt));
-
-
-
 {%- if not nobag %}
 {%- for weighted in [False, True] %}
 {%- set wdesc = "weighted" if weighted else "unweighted" %}
@@ -324,7 +323,7 @@ class {{ autograd_func }} :
 *indice_weights,
 {%- endif %}
 lxu_cache_locations,
-uvm_cache_stats,
+uvm_cache_stats_,
 output_dtype,
 {%- if vbe %}
 vbe_row_output_offsets,
@@ -355,7 +354,7 @@ class {{ autograd_func }} :
 indices,
 offsets,
 lxu_cache_locations,
-uvm_cache_stats,
+uvm_cache_stats_,
 output_dtype,
 /*is_experimental=*/false
 )
@@ -555,6 +554,7 @@ class {{ autograd_func }} :
 grad_indice_weights, // indice_weights
 Variable(), // feature_requires_grad
 Variable(), // lxu_cache_locations
+Variable(), // uvm_cache_stats
 {%- if optimizer != "none" %}
 Variable(), // gradient_clipping
 Variable(), // max_gradient
@@ -628,6 +628,7 @@ class {{ autograd_func }} :
 Variable(), // indices
 Variable(), // offsets
 Variable(), // lxu_cache_locations
+Variable(), // uvm_cache_stats
 {%- if optimizer != "none" %}
 Variable(), // gradient_clipping
 Variable(), // max_gradient
@@ -688,7 +689,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
 const int64_t vbe_output_size = -1,
 const bool is_experimental = false,
 const bool use_uniq_cache_locations_bwd = false,
-const bool use_homogeneous_placements = false
+const bool use_homogeneous_placements = false,
+const c10::optional<Tensor>& uvm_cache_stats = c10::optional<Tensor>()
 ) {
 {%- if has_gpu_support %}
 {%- for vbe in ([True, False] if has_vbe_support else [False]) %}
@@ -738,6 +740,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
 feature_requires_grad,
 {%- endif %}
 lxu_cache_locations,
+uvm_cache_stats,
 {%- if optimizer != "none" %}
 gradient_clipping,
 max_gradient,
@@ -802,7 +805,9 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
 " int vbe_output_size=-1, "
 " bool is_experimental=False, "
 " bool use_uniq_cache_locations_bwd=False, "
-" bool use_homogeneous_placements=False) -> Tensor",
+" bool use_homogeneous_placements=False, "
+" Tensor? uvm_cache_stats=None"
+") -> Tensor",
 {PT2_COMPLIANT_TAG});
 // We're playing a funny trick here: we're using the autograd
 // implementation of the operator at all the dispatch keys. This is OK
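
The last hunk above extends the registered operator schema with a trailing `Tensor? uvm_cache_stats=None` entry. As a rough illustration of how such an optional tensor maps between a schema string and the C++ signature, here is a toy registration; the namespace `toy_ns` and op `toy_lookup` are assumptions for illustration, not the FBGEMM registration.

#include <ATen/ATen.h>
#include <torch/library.h>

// Toy example only: a trailing "Tensor? ...=None" schema entry corresponds to
// c10::optional<at::Tensor> in C++, so existing callers that omit the new
// argument continue to resolve against the same schema.
at::Tensor toy_lookup(
    const at::Tensor& weights,
    const c10::optional<at::Tensor>& uvm_cache_stats) {
  if (uvm_cache_stats.has_value()) {
    // A real kernel would record cache hit/miss counters into the stats buffer.
    TORCH_CHECK(uvm_cache_stats->scalar_type() == at::kInt,
                "expected int32 uvm_cache_stats");
  }
  return weights.clone();  // placeholder result
}

TORCH_LIBRARY_FRAGMENT(toy_ns, m) {
  m.def(
      "toy_lookup(Tensor weights, Tensor? uvm_cache_stats=None) -> Tensor",
      toy_lookup);
}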

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 2 additions & 2 deletions
@@ -1282,8 +1282,8 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
 
         if self.gather_uvm_cache_stats:
             # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
-            # We may wanna do this accumulation atomically, but as it's only for monitoring,
-            # slightly inaccurate result may be acceptable.
+            # We may want to do this accumulation atomically, but as it's only
+            # for monitoring, slightly inaccurate result may be acceptable.
             self.uvm_cache_stats = torch.add(
                 self.uvm_cache_stats, self.local_uvm_cache_stats
             )
