@@ -82,6 +82,61 @@ def dequantize_per_token(tensor, inv_scale, dtype):
         dequantize_per_token(ref_y, scale, dtype),
     )

+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_scaled_fp8_quant_with_padding(dtype) -> None:
+    original_rows = 5
+    x = (torch.randn(size=(original_rows, 16), device="cuda") * 13).to(dtype)
+
+    padding_size = 10
+
+    # Test with dynamic quantization
+    y_dynamic, scale_dynamic = scaled_fp8_quant(
+        x, None, num_token_padding=padding_size
+    )
+
+    # Verify output shape has the padded size
+    assert y_dynamic.shape[0] == padding_size
+    assert y_dynamic.shape[1] == x.shape[1]
+
+    # Verify that the actual data in the non-padded region is correctly quantized
+    y_without_padding, scale_without_padding = scaled_fp8_quant(x, None)
+    torch.testing.assert_close(y_dynamic[:original_rows], y_without_padding)
+
+    # Test with static quantization
+    # First get a scale
+    _, scale = scaled_fp8_quant(x, None)
+
+    # Then use it for static quantization with padding
+    y_static, _ = scaled_fp8_quant(x, scale, num_token_padding=padding_size)
+
+    # Verify output shape has the padded size
+    assert y_static.shape[0] == padding_size
+    assert y_static.shape[1] == x.shape[1]
+
+    # Verify that the actual data in the non-padded region is correctly quantized
+    y_static_without_padding, _ = scaled_fp8_quant(x, scale)
+    torch.testing.assert_close(y_static[:original_rows], y_static_without_padding)
+
+    # Test with per-token dynamic quantization
+    y_per_token, scale_per_token = scaled_fp8_quant(
+        x, None, num_token_padding=padding_size, use_per_token_if_dynamic=True
+    )
+
+    # Verify output shape has the padded size
+    assert y_per_token.shape[0] == padding_size
+    assert y_per_token.shape[1] == x.shape[1]
+
+    # Verify that the actual data in the non-padded region is correctly quantized
+    y_per_token_without_padding, scale_per_token_without_padding = scaled_fp8_quant(
+        x, None, use_per_token_if_dynamic=True
+    )
+    torch.testing.assert_close(
+        y_per_token[:original_rows], y_per_token_without_padding
+    )
+    torch.testing.assert_close(
+        scale_per_token[:original_rows], scale_per_token_without_padding
+    )
+

 if __name__ == "__main__":
     # Run the specific test function directly
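For context on the num_token_padding semantics these assertions rely on, here is a minimal pure-PyTorch sketch of the expected behavior for the dynamic per-tensor case. ref_scaled_fp8_quant and FP8_MAX are hypothetical names introduced for illustration, and zero-filling the padded rows is an assumption; this is not vLLM's fused scaled_fp8_quant kernel.

    import torch

    # Max representable magnitude in float8_e4m3fn (448.0).
    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

    def ref_scaled_fp8_quant(x, num_token_padding=None):
        # Dynamic per-tensor scale: map the tensor's absolute max onto the fp8 range.
        scale = x.abs().max().float() / FP8_MAX
        y = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        if num_token_padding is not None and num_token_padding > y.shape[0]:
            # Assumed behavior: extra token rows are zero-filled; only the first
            # x.shape[0] rows carry real data, matching the [:original_rows] slices.
            padded = torch.zeros(
                num_token_padding, y.shape[1], dtype=y.dtype, device=y.device
            )
            padded[: y.shape[0]] = y
            y = padded
        return y, scale

    x = (torch.randn(5, 16) * 13).to(torch.float16)
    y, scale = ref_scaled_fp8_quant(x, num_token_padding=10)
    assert y.shape == (10, 16)
    y_unpadded, _ = ref_scaled_fp8_quant(x)
    torch.testing.assert_close(y[:5].float(), y_unpadded.float())

Under this sketch, slicing the padded output back to original_rows recovers exactly the unpadded quantization, which is what the assert_close checks in the test above verify for all three quantization modes.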