@@ -167,6 +167,7 @@ class {{ autograd_func }} :
     const c10::optional<Tensor>& feature_requires_grad,
     {%- endif %}
     const Tensor& lxu_cache_locations,
+    c10::optional<Tensor> uvm_cache_stats,
     {%- if optimizer != "none" %}
     const bool gradient_clipping,
     const double max_gradient,
@@ -196,6 +197,11 @@ class {{ autograd_func }} :
     const auto max_B_ = offsets.sym_size(0) / T;
     {%- endif %}
 
+    // NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t
+    // TODO: Hook up with frontend code
+    const auto uvm_cache_stats_ = uvm_cache_stats
+        .value_or(at::empty({0}, uvm_weights.options().dtype(at::kInt)));
+
     // TODO: don't guard here
     auto [info_B_num_bits, info_B_mask] = adjust_info_B_num_bits(max_B_.guard_int(__FILE__, __LINE__), T.guard_int(__FILE__, __LINE__));
 
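As a reference for the fallback above, here is a minimal standalone sketch (not part of the generated template; `resolve_uvm_cache_stats` is a hypothetical helper name) of how `c10::optional<Tensor>::value_or` supplies an empty int32 tensor matching the options of `uvm_weights` when the caller passes nothing:

// Sketch only; resolve_uvm_cache_stats is a hypothetical helper, not generated code.
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

at::Tensor resolve_uvm_cache_stats(
    const c10::optional<at::Tensor>& uvm_cache_stats,
    const at::Tensor& uvm_weights) {
  // Fall back to a zero-element int32 tensor allocated with the same
  // TensorOptions (device, layout) as uvm_weights when no stats buffer
  // was supplied by the frontend.
  return uvm_cache_stats.value_or(
      at::empty({0}, uvm_weights.options().dtype(at::kInt)));
}

The zero-element tensor mirrors what the removed local allocation produced, presumably so downstream kernels can treat "no cache statistics requested" uniformly.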
@@ -283,13 +289,6 @@ class {{ autograd_func }} :
     const auto& flatten_dev_weights = dev_weights;
     {%- endif %}
 
-
-
-
-    const auto uvm_cache_stats = at::empty({0}, uvm_weights.options().dtype(at::kInt));
-
-
-
     {%- if not nobag %}
     {%- for weighted in [False, True] %}
     {%- set wdesc = "weighted" if weighted else "unweighted" %}
@@ -324,7 +323,7 @@ class {{ autograd_func }} :
         *indice_weights,
         {%- endif %}
         lxu_cache_locations,
-        uvm_cache_stats,
+        uvm_cache_stats_,
         output_dtype,
         {%- if vbe %}
         vbe_row_output_offsets,
@@ -355,7 +354,7 @@ class {{ autograd_func }} :
         indices,
         offsets,
         lxu_cache_locations,
-        uvm_cache_stats,
+        uvm_cache_stats_,
         output_dtype,
         /*is_experimental=*/false
     )
@@ -555,6 +554,7 @@ class {{ autograd_func }} :
         grad_indice_weights, // indice_weights
         Variable(), // feature_requires_grad
         Variable(), // lxu_cache_locations
+        Variable(), // uvm_cache_stats
         {%- if optimizer != "none" %}
         Variable(), // gradient_clipping
         Variable(), // max_gradient
@@ -628,6 +628,7 @@ class {{ autograd_func }} :
         Variable(), // indices
         Variable(), // offsets
         Variable(), // lxu_cache_locations
+        Variable(), // uvm_cache_stats
         {%- if optimizer != "none" %}
         Variable(), // gradient_clipping
         Variable(), // max_gradient
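The `Variable()` entries added above follow the `torch::autograd::Function` contract: `backward()` must return exactly one gradient per `forward()` input, in declaration order, with undefined tensors for non-differentiable inputs such as `uvm_cache_stats`. A minimal sketch, assuming an illustrative `DemoFunction` rather than the generated `{{ autograd_func }}`:

#include <torch/torch.h>

// Illustrative only: a tiny custom autograd function with one differentiable
// input and one non-differentiable optional input, mirroring the pattern above.
struct DemoFunction : public torch::autograd::Function<DemoFunction> {
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& weights,
      c10::optional<torch::Tensor> uvm_cache_stats) {
    ctx->save_for_backward({weights});
    return weights * 2;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    // One gradient slot per forward() input, in declaration order.
    return {
        grad_output[0] * 2,          // weights
        torch::autograd::Variable()  // uvm_cache_stats: no gradient
    };
  }
};

Invoking it via `DemoFunction::apply(weights, c10::nullopt)` then interoperates with autograd as usual.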
@@ -688,7 +689,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
     const int64_t vbe_output_size = -1,
     const bool is_experimental = false,
     const bool use_uniq_cache_locations_bwd = false,
-    const bool use_homogeneous_placements = false
+    const bool use_homogeneous_placements = false,
+    const c10::optional<Tensor>& uvm_cache_stats = c10::optional<Tensor>()
 ) {
   {%- if has_gpu_support %}
   {%- for vbe in ([True, False] if has_vbe_support else [False]) %}
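Because the new parameter is appended last with a default of `c10::optional<Tensor>()` (an empty optional), existing C++ call sites that pass only the older arguments keep compiling. A small sketch with a hypothetical `lookup_stub` standing in for the templated function:

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Hypothetical stand-in for the templated lookup function's signature; only
// the trailing parameters relevant to this change are shown.
at::Tensor lookup_stub(
    const at::Tensor& weights,
    const bool use_homogeneous_placements = false,
    const c10::optional<at::Tensor>& uvm_cache_stats = c10::optional<at::Tensor>()) {
  // A default-constructed optional holds no value, like c10::nullopt.
  return uvm_cache_stats.has_value() ? *uvm_cache_stats : weights;
}

void demo_calls(const at::Tensor& w) {
  lookup_stub(w);                                   // old-style call, still valid
  lookup_stub(w, true, at::zeros({1}, at::kInt));   // new-style call with stats
}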
@@ -738,6 +740,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
         feature_requires_grad,
         {%- endif %}
         lxu_cache_locations,
+        uvm_cache_stats,
         {%- if optimizer != "none" %}
         gradient_clipping,
         max_gradient,
@@ -802,7 +805,9 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
       " int vbe_output_size=-1, "
       " bool is_experimental=False, "
       " bool use_uniq_cache_locations_bwd=False, "
-      " bool use_homogeneous_placements=False) -> Tensor",
+      " bool use_homogeneous_placements=False, "
+      " Tensor? uvm_cache_stats=None"
+      ") -> Tensor",
       {PT2_COMPLIANT_TAG});
   // We're playing a funny trick here: we're using the autograd
   // implementation of the operator at all the dispatch keys. This is OK
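For completeness, a minimal sketch (using a hypothetical `demo` namespace and op, not the FBGEMM registration) of how a `Tensor? x=None` schema argument surfaces as `c10::optional<Tensor>` in the C++ kernel:

#include <torch/library.h>
#include <ATen/ATen.h>

// Illustrative only: "Tensor?" in the schema becomes c10::optional<Tensor> in
// the C++ kernel, and the "=None" default lets callers omit the argument.
at::Tensor demo_lookup(
    const at::Tensor& weights,
    const c10::optional<at::Tensor>& uvm_cache_stats) {
  // Callers that omit the argument observe c10::nullopt here.
  return uvm_cache_stats.has_value() ? weights + 1 : weights;
}

TORCH_LIBRARY(demo, m) {
  m.def("demo_lookup(Tensor weights, Tensor? uvm_cache_stats=None) -> Tensor");
  m.impl("demo_lookup", demo_lookup);
}

With the `=None` default in the schema, existing callers of the lookup op that never pass `uvm_cache_stats` continue to work unchanged.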