Skip to content

Commit d65498f

Browse files
Allow FBGEMM_TBE_BOUNDS_CHECK_MODE to take effect when using mode 4,5,6 (#3838)
Summary: X-link: facebookresearch/FBGEMM#925 This diff allows to set V2 check bound mode via env var `FBGEMM_TBE_BOUNDS_CHECK_MODE`, by setting its value to 4 (V2_IGNORE), 5 (V2_WARNING), 6 (V2_FATAL) Previously, we can only get v2 check bound mode if bounds_check_mode is set to mode prefixed with V2. Reviewed By: sryap Differential Revision: D71344486
1 parent 2e638e9 commit d65498f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,10 @@ def __init__( # noqa C901
687687
self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
688688
# If environment variable is set, it overwrites the default bounds check mode.
689689
self.bounds_check_version: int = 1
690+
self.bounds_check_mode_int: int = int(
691+
os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
692+
)
693+
bounds_check_mode = BoundsCheckMode(self.bounds_check_mode_int)
690694
if bounds_check_mode.name.startswith("V2_"):
691695
self.bounds_check_version = 2
692696
if bounds_check_mode == BoundsCheckMode.V2_IGNORE:
@@ -700,9 +704,6 @@ def __init__( # noqa C901
700704
f"Did not recognize V2 bounds check mode: {bounds_check_mode}"
701705
)
702706

703-
self.bounds_check_mode_int: int = int(
704-
os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
705-
)
706707
self.weights_precision = weights_precision
707708

708709
if torch.cuda.is_available() and torch.version.hip:

0 commit comments

Comments
 (0)