Commit 327e039

BBuf authored and jimoosciuc committed
sgl scaled_fp8_quant support output padding (sgl-project#4861)
1 parent 105505f commit 327e039

File tree

3 files changed: +61 -4 lines changed

python/sglang/srt/custom_op.py

Lines changed: 5 additions & 0 deletions

@@ -50,6 +50,7 @@ def dispatch_forward(self):
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
     use_per_token_if_dynamic: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
@@ -59,6 +60,8 @@ def scaled_fp8_quant(
         input (torch.Tensor): Input tensor to be quantized
         scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
             If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
         use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
             determines the quantization granularity:
             - True: compute scale per token
@@ -75,6 +78,8 @@ def scaled_fp8_quant(
     assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
     shape = input.shape
     out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
     if scale is None:

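The only functional change here is to the allocated output shape: the first (token) dimension is grown to at least num_token_padding, and the output is never truncated below the real row count. A minimal standalone sketch of that shape rule, assuming only PyTorch (the helper name padded_output_shape is made up for illustration; the quantization itself is omitted):

from typing import Optional

import torch


def padded_output_shape(
    input: torch.Tensor, num_token_padding: Optional[int] = None
) -> tuple[int, int]:
    # Mirrors the shape logic in the diff above: grow dim 0 up to
    # num_token_padding, leave the hidden dimension untouched.
    rows, cols = input.shape
    if num_token_padding:
        rows = max(num_token_padding, rows)
    return rows, cols


x = torch.randn(5, 16)
print(padded_output_shape(x))                        # (5, 16)
print(padded_output_shape(x, num_token_padding=10))  # (10, 16)
print(padded_output_shape(x, num_token_padding=3))   # (5, 16) -- never shrinks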
python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 1 addition & 4 deletions

@@ -457,12 +457,9 @@ def apply(
             qinput, x_scale = sgl_scaled_fp8_quant(
                 input_2d,
                 input_scale,
+                num_token_padding=self.output_padding,
                 use_per_token_if_dynamic=use_per_token_if_dynamic,
             )
-            if self.output_padding:
-                pad_size = max(self.output_padding - qinput.shape[0], 0)
-                if pad_size > 0:
-                    qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
         else:
             qinput, x_scale = ops.scaled_fp8_quant(
                 input_2d,

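The four removed lines did the same padding after the fact with torch.nn.functional.pad; passing num_token_padding to sgl_scaled_fp8_quant moves it into the quant call. A condensed sketch of the removed call-site logic (pad_rows_after_quant is a hypothetical name, not something in the codebase):

from typing import Optional

import torch
import torch.nn.functional as F


def pad_rows_after_quant(
    qinput: torch.Tensor, output_padding: Optional[int]
) -> torch.Tensor:
    # Equivalent of the deleted lines: append rows until dim 0 reaches
    # output_padding; a no-op if the tensor is already tall enough.
    if output_padding:
        pad_size = max(output_padding - qinput.shape[0], 0)
        if pad_size > 0:
            # For a 2D tensor F.pad takes (left, right, top, bottom),
            # so (0, 0, 0, pad_size) adds pad_size rows at the bottom.
            qinput = F.pad(qinput, (0, 0, 0, pad_size))
    return qinput


q = torch.zeros(5, 16)
print(pad_rows_after_quant(q, 10).shape)  # torch.Size([10, 16])

One difference worth noting: F.pad fills the extra rows with zeros, while the new path allocates the padded output with torch.empty in custom_op.py, so whether the padding rows need defined contents is left to the downstream consumer.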
python/sglang/test/test_custom_ops.py

Lines changed: 55 additions & 0 deletions

@@ -82,6 +82,61 @@ def dequantize_per_token(tensor, inv_scale, dtype):
         dequantize_per_token(ref_y, scale, dtype),
     )
 
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_scaled_fp8_quant_with_padding(dtype) -> None:
+    original_rows = 5
+    x = (torch.randn(size=(original_rows, 16), device="cuda") * 13).to(dtype)
+
+    padding_size = 10
+
+    # Test with dynamic quantization
+    y_dynamic, scale_dynamic = scaled_fp8_quant(
+        x, None, num_token_padding=padding_size
+    )
+
+    # Verify output shape has the padded size
+    assert y_dynamic.shape[0] == padding_size
+    assert y_dynamic.shape[1] == x.shape[1]
+
+    # Verify that the actual data in the non-padded region is correctly quantized
+    y_without_padding, scale_without_padding = scaled_fp8_quant(x, None)
+    torch.testing.assert_close(y_dynamic[:original_rows], y_without_padding)
+
+    # Test with static quantization
+    # First get a scale
+    _, scale = scaled_fp8_quant(x, None)
+
+    # Then use it for static quantization with padding
+    y_static, _ = scaled_fp8_quant(x, scale, num_token_padding=padding_size)
+
+    # Verify output shape has the padded size
+    assert y_static.shape[0] == padding_size
+    assert y_static.shape[1] == x.shape[1]
+
+    # Verify that the actual data in the non-padded region is correctly quantized
+    y_static_without_padding, _ = scaled_fp8_quant(x, scale)
+    torch.testing.assert_close(y_static[:original_rows], y_static_without_padding)
+
+    # Test with per-token dynamic quantization
+    y_per_token, scale_per_token = scaled_fp8_quant(
+        x, None, num_token_padding=padding_size, use_per_token_if_dynamic=True
+    )
+
+    # Verify output shape has the padded size
+    assert y_per_token.shape[0] == padding_size
+    assert y_per_token.shape[1] == x.shape[1]
+
+    # Verify that the actual data in the non-padded region is correctly quantized
+    y_per_token_without_padding, scale_per_token_without_padding = scaled_fp8_quant(
+        x, None, use_per_token_if_dynamic=True
+    )
+    torch.testing.assert_close(
+        y_per_token[:original_rows], y_per_token_without_padding
+    )
+    torch.testing.assert_close(
+        scale_per_token[:original_rows], scale_per_token_without_padding
+    )
+
 
 if __name__ == "__main__":
     # Run the specific test function directly

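To run just the new test locally, something like the following should work (an illustrative invocation, not part of the commit; it needs a CUDA device and an sglang development install):

import pytest

# Select only the padding test from the modified module.
pytest.main([
    "python/sglang/test/test_custom_ops.py",
    "-k", "test_scaled_fp8_quant_with_padding",
])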