diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
index a774c5ff42..adbe3ac4cd 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -175,6 +175,7 @@ def tbe_input_combine_with_length_abstract(
     torch._check(len(indices_list) == len(offsets_list))
     torch._check(len(indices_list) == len(per_sample_weights))
     total_indices = 0
+    total_offsets = 0
     need_weight = False
     for index, offset, weight in zip(indices_list, offsets_list, per_sample_weights):
         torch._check(index.dtype == torch.int or index.dtype == torch.long)
@@ -184,12 +185,12 @@
         torch._check(index.is_contiguous())
         torch._check(offset.is_contiguous())
         total_indices = total_indices + index.numel()
+        total_offsets = total_offsets + offset.numel()
         if weight.numel() > 0:
             torch._check(weight.dim() == 1)
             torch._check(weight.numel() == index.numel())
             torch._check(weight.is_contiguous())
             need_weight = True
-    total_offsets = torch.library.get_ctx().new_dynamic_size()
     combined_indices = indices_list[0].new_empty([total_indices], dtype=torch.int)
     combined_offsets = offsets_list[0].new_empty([total_offsets], dtype=torch.int)
     if need_weight:
@@ -197,7 +198,7 @@
             [total_indices], dtype=torch.float
         )
     else:
-        combined_weights = torch.empty(0)
+        combined_weights = torch.empty(0, device=indices_list[0].device)
     return combined_indices, combined_offsets, combined_weights