#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import unittest
from typing import Any

import hypothesis.strategies as st
import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)
from hypothesis import given, settings, Verbosity

from ..common import gpu_unavailable

common_st = {
    "D": st.integers(min_value=1, max_value=512),
}

common_settings = {
    "verbosity": Verbosity.verbose,
    "max_examples": 4,
    "deadline": None,
}

MAX_INT32 = 2147483647
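# MAX_INT32 == torch.iinfo(torch.int32).max (2**31 - 1); embedding tables with
# numel at or above this value need 64-bit arithmetic for element offsets.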


class ForwardBackwardInt32OverflowTest(unittest.TestCase):
    def _execute_forward_backward_large_emb(
        self,
        weights_precision: SparseType,
        indices_dtype: torch.dtype,
        D: int = 1,
    ) -> None:
        """
        Execute the forward and backward tests for a large embedding table
        (numel >= MAX_INT32).

        The test fails if a runtime error, such as an illegal memory access,
        is caught.
        """
        weight_dtype_bytes = weights_precision.bit_rate() // 8

        # Embedding dimension
        D = D * 4
        row_bytes = D * weight_dtype_bytes
        # Hash size
        # Compute the number of rows in the embedding table by
        # div_up(MAX_INT32, D), and add enough extra rows to cover at least 32
        # more bytes, so that the table spans past MAX_INT32 elements and an
        # int32 offset overflow surfaces as an illegal memory access (IMA)
        E = math.ceil(MAX_INT32 / D) + math.ceil(32 / row_bytes)

        assert E * D >= MAX_INT32
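        # Worked example (smallest Hypothesis draw, FP32 weights):
        # D = 1 * 4 = 4 and row_bytes = 16, so
        # E = ceil(2147483647 / 4) + ceil(32 / 16) = 536870912 + 2 = 536870914
        # rows and E * D = 2147483656 > MAX_INT32. The FP32 weights then occupy
        # roughly 8 GiB, which is why the free-memory check below may skip the
        # test.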

        # Compute total weight bytes
        weight_bytes = E * D * weight_dtype_bytes
        assert weight_bytes > 0

        # Compute free memory
        total_memory = torch.cuda.get_device_properties().total_memory
        reserved_memory = torch.cuda.memory_reserved()
        free_memory = total_memory - reserved_memory
        if free_memory < weight_bytes:
            self.skipTest(
                f"Skip test_forward_backward_large_emb: Free memory "
                f"({free_memory}) < weight_bytes ({weight_bytes})"
            )

        # Get device
        device = torch.cuda.current_device()

        # Instantiate a TBE op
        op = SplitTableBatchedEmbeddingBagsCodegen(
            embedding_specs=[(E, D, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)],
            output_dtype=SparseType.FP32,
            device=device,
        )

        # Generate inputs
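        # A single bag that looks up the last row (index E - 1). That row's
        # flat element offset, (E - 1) * D, is at or beyond MAX_INT32, so the
        # lookup exercises weight addressing that does not fit in a signed
        # 32-bit integer.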
        indices = torch.as_tensor([E - 1], dtype=indices_dtype, device=device)
        offsets = torch.as_tensor([0, 1], dtype=indices_dtype, device=device)
        per_sample_weights = torch.as_tensor([0.9], dtype=torch.float, device=device)

        # Test both weighted and unweighted
        for weighted in [False, True]:
            try:
                # Run forward
                out = op(
                    indices=indices,
                    offsets=offsets,
                    per_sample_weights=per_sample_weights if weighted else None,
                )
                torch.cuda.synchronize()
            except RuntimeError as e:
                raise AssertionError(f"Forward error: {weighted=} {e}")

            grad = out.clone().detach()

            try:
                # Run backward
                out.backward(grad)
                torch.cuda.synchronize()
            except RuntimeError as e:
                raise AssertionError(f"Backward error: {weighted=} {e}")

        # Delete the op to save space
        del op

    @unittest.skipIf(*gpu_unavailable)
    @given(**common_st)
    @settings(**common_settings)
    def test_forward_backward_large_fp32_emb_int32_indices(self, **kwargs: Any) -> None:
        """
        Test forward and backward TBE with a large FP32 embedding table and
        INT32 indices and offsets
        """
        self._execute_forward_backward_large_emb(
            weights_precision=SparseType.FP32,
            indices_dtype=torch.int,
            **kwargs,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(**common_st)
    @settings(**common_settings)
    def test_forward_backward_large_fp16_emb_int32_indices(self, **kwargs: Any) -> None:
        """
        Test forward and backward TBE with a large FP16 embedding table and
        INT32 indices and offsets
        """
        self._execute_forward_backward_large_emb(
            weights_precision=SparseType.FP16,
            indices_dtype=torch.int,
            **kwargs,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(**common_st)
    @settings(**common_settings)
    def test_forward_backward_large_fp32_emb_int64_indices(self, **kwargs: Any) -> None:
        """
        Test forward and backward TBE with a large FP32 embedding table and
        INT64 indices and offsets
        """
        self._execute_forward_backward_large_emb(
            weights_precision=SparseType.FP32,
            indices_dtype=torch.long,
            **kwargs,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(**common_st)
    @settings(**common_settings)
    def test_forward_backward_large_fp16_emb_int64_indices(self, **kwargs: Any) -> None:
        """
        Test forward and backward TBE with a large FP16 embedding table and
        INT64 indices and offsets
        """
        self._execute_forward_backward_large_emb(
            weights_precision=SparseType.FP16,
            indices_dtype=torch.long,
            **kwargs,
        )