Commit 0c042bd

sryap authored and facebook-github-bot committed
Fix IMA in TBE grad indices kernel for int32 indices (pytorch#967)
Summary: This diff forces the weight tensor index to be cast to `overflow_safe_int_t` (which is `int64_t`) to address a 32-bit integer overflow that occurs when the TBE `indices` tensor is `int32_t`.

Before this diff, `idx_j` and `D_emb` were both `int32_t` when 32-bit `indices` were used. When accessing the embedding `weights` tensor, the index was computed as the product of `idx_j` and `D_emb`, which could exceed the maximum value of `int32_t`. The resulting integer overflow led to an illegal memory access (IMA). Casting `idx_j` and `D_emb` to `int64_t` prevents the overflow; using `int64_t` is safe because its maximum value far exceeds the memory capacity of modern GPUs and CPUs.

Unit tests for this issue are also added.

X-link: pytorch#3877

Pull Request resolved: facebookresearch/FBGEMM#967

**Facebook:** This is the fix for S498528. The full root-cause details are in https://fburl.com/gdoc/lhmvenw3.

Reviewed By: brad-mengchi, spcyppt

Differential Revision: D71796826

fbshipit-source-id: df7c8e06d36e2f06585a5812df8ae40863ea6253
1 parent 26781a2 commit 0c042bd

2 files changed: +183 −2 lines

fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu

Lines changed: 12 additions & 2 deletions

@@ -217,7 +217,12 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
       D_emb += kINT8QparamsBytes;
     }
     auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
-        const_cast<emb_t*>(&weights[idx_j * D_emb]),
+        const_cast<emb_t*>(
+            &weights[
+                static_cast<overflow_safe_int_t>(idx_j)
+                * static_cast<overflow_safe_int_t>(D_emb)
+            ]
+        ),
         nullptr,
         D);
     float2 qparams;
@@ -237,7 +242,12 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
       D_emb += kINT8QparamsBytes;
     }
     auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
-        const_cast<emb_t*>(&weights[idx_j * D_emb]),
+        const_cast<emb_t*>(
+            &weights[
+                static_cast<overflow_safe_int_t>(idx_j)
+                * static_cast<overflow_safe_int_t>(D_emb)
+            ]
+        ),
         nullptr,
         D);
     float2 qparams;
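
The two hunks above are the whole kernel change: the row offset into `weights` is now computed in 64-bit arithmetic before the pointer is formed. A minimal host-side sketch of the before/after arithmetic, assuming `overflow_safe_int_t` is `int64_t` as stated in the summary (the values and the `main` driver below are illustrative, not kernel code):

```cpp
#include <climits>
#include <cstdint>
#include <cstdio>

// Assumption from the commit summary: overflow_safe_int_t is int64_t.
using overflow_safe_int_t = int64_t;

int main() {
  // With int32 `indices`, idx_j and D_emb are both 32-bit in the kernel.
  const int32_t idx_j = 536870913;  // row index a bit above INT32_MAX / 4
  const int32_t D_emb = 4;          // row width in elements

  // Before the fix, idx_j * D_emb was evaluated in 32-bit arithmetic, so any
  // product above INT32_MAX wrapped around and produced a bogus offset into
  // the weights tensor, hence the illegal memory access.
  // The fix widens both operands first and multiplies in 64 bits:
  const overflow_safe_int_t offset =
      static_cast<overflow_safe_int_t>(idx_j) *
      static_cast<overflow_safe_int_t>(D_emb);

  std::printf("offset = %lld (INT32_MAX = %d)\n",
              static_cast<long long>(offset), INT_MAX);
  return 0;
}
```

The printed offset is 2147483652, which cannot be represented in 32 bits.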
Lines changed: 171 additions & 0 deletions

@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import unittest
from typing import Any

import hypothesis.strategies as st
import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)
from hypothesis import given, settings, Verbosity

from ..common import gpu_unavailable

common_st = {
    "D": st.integers(min_value=1, max_value=512),
}

common_settings = {
    "verbosity": Verbosity.verbose,
    "max_examples": 4,
    "deadline": None,
}

MAX_INT32 = 2147483647


class ForwardBackwardInt32OverflowTest(unittest.TestCase):
    def _execute_forward_backward_large_emb(
        self,
        weights_precision: SparseType,
        indices_dtype: torch.dtype,
        D: int = 1,
    ) -> None:
        """
        Execute the forward and backward tests for a large embedding table
        (numel >= MAX_INT32)

        The test will fail if a runtime error, such as illegal memory access,
        is caught
        """
        weight_dtype_bytes = weights_precision.bit_rate() // 8

        # Embedding dimension
        D = D * 4
        row_bytes = D * weight_dtype_bytes
        # Hash size
        # Compute the number of rows in the embedding table by
        # div_up(MAX_INT32, D) and add 32 extra bytes to ensure that IMA
        E = math.ceil(MAX_INT32 / D) + math.ceil(32 / row_bytes)

        assert E * D >= MAX_INT32

        # Compute total weight bytes
        weight_bytes = E * D * weight_dtype_bytes
        assert weight_bytes > 0

        # Compute free memory
        total_memory = torch.cuda.get_device_properties().total_memory
        reserved_memory = torch.cuda.memory_reserved()
        free_memory = total_memory - reserved_memory
        if free_memory < weight_bytes:
            self.skipTest(
                f"Skip test_forward_backward_large_emb: Free memory "
                f"({free_memory}) < weight_bytes ({weight_bytes})"
            )

        # Get device
        device = torch.cuda.current_device()

        # Instantiate a TBE op
        op = SplitTableBatchedEmbeddingBagsCodegen(
            embedding_specs=[(E, D, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)],
            output_dtype=SparseType.FP32,
            device=device,
        )

        # Generate inputs
        indices = torch.as_tensor([E - 1], dtype=indices_dtype, device=device)
        offsets = torch.as_tensor([0, 1], dtype=indices_dtype, device=device)
        per_sample_weights = torch.as_tensor([0.9], dtype=torch.float, device=device)

        # Test both weighted and unweighted
        for weighted in [False, True]:
            try:
                # Run forward
                out = op(
                    indices=indices,
                    offsets=offsets,
                    per_sample_weights=per_sample_weights if weighted else None,
                )
                torch.cuda.synchronize()
            except RuntimeError as e:
                raise AssertionError(f"Forward error: {weighted=} {e}")

            grad = out.clone().detach()

            try:
                # Run backward
                out.backward(grad)
                torch.cuda.synchronize()
            except RuntimeError as e:
                raise AssertionError(f"Backward error: {weighted=} {e}")

        # Delete the op to save space
        del op

    @unittest.skipIf(*gpu_unavailable)
    @given(**common_st)
    @settings(**common_settings)
    def test_forward_backward_large_fp32_emb_int32_indices(self, **kwargs: Any) -> None:
        """
        Test forward and backward TBE with a large FP32 embedding table and
        INT32 indices and offsets
        """
        self._execute_forward_backward_large_emb(
            weights_precision=SparseType.FP32,
            indices_dtype=torch.int,
            **kwargs,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(**common_st)
    @settings(**common_settings)
    def test_forward_backward_large_fp16_emb_int32_indices(self, **kwargs: Any) -> None:
        """
        Test forward and backward TBE with a large FP16 embedding table and
        INT32 indices and offsets
        """
        self._execute_forward_backward_large_emb(
            weights_precision=SparseType.FP16,
            indices_dtype=torch.int,
            **kwargs,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(**common_st)
    @settings(**common_settings)
    def test_forward_backward_large_fp32_emb_int64_indices(self, **kwargs: Any) -> None:
        """
        Test forward and backward TBE with a large FP32 embedding table and
        INT64 indices and offsets
        """
        self._execute_forward_backward_large_emb(
            weights_precision=SparseType.FP32,
            indices_dtype=torch.long,
            **kwargs,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(**common_st)
    @settings(**common_settings)
    def test_forward_backward_large_fp16_emb_int64_indices(self, **kwargs: Any) -> None:
        """
        Test forward and backward TBE with a large FP16 embedding table and
        INT64 indices and offsets
        """
        self._execute_forward_backward_large_emb(
            weights_precision=SparseType.FP16,
            indices_dtype=torch.long,
            **kwargs,
        )
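
The table sizing in `_execute_forward_backward_large_emb` is what makes the last row reproduce the overflow. Below is a standalone sketch of that arithmetic for FP32 weights and the smallest hypothesis draw (D = 1, scaled to 4); the `div_up` helper and `main` driver are illustrative, not part of FBGEMM:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the sizing math in the test above for FP32 weights and D = 4.
static int64_t div_up(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t kMaxInt32 = 2147483647;
  const int64_t D = 4;                   // embedding dim after the D * 4 scaling
  const int64_t weight_dtype_bytes = 4;  // FP32
  const int64_t row_bytes = D * weight_dtype_bytes;

  // E = div_up(MAX_INT32, D) + div_up(32, row_bytes), as in the test.
  const int64_t E = div_up(kMaxInt32, D) + div_up(32, row_bytes);

  // The test looks up indices = [E - 1]; the element offset of that row in
  // `weights` exceeds INT32_MAX, which is exactly the product the kernel used
  // to compute in 32-bit arithmetic when the indices dtype is int32.
  const int64_t last_row_offset = (E - 1) * D;

  std::printf("E = %lld, (E - 1) * D = %lld, INT32_MAX = %lld\n",
              (long long)E, (long long)last_row_offset, (long long)kMaxInt32);
  return 0;
}
```

With these numbers, E = 536870914 and (E − 1) * D = 2147483652, so the int32-indices tests exercise the overflow path while the int64-indices tests act as a control.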
