 import torch
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
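
Note on the hunk above: setting torch._dynamo.config.suppress_errors makes Dynamo log a failed graph compilation and fall back to eager execution instead of raising, so the env-var gate trades strictness for robustness when SGLANG_ENABLE_TORCH_COMPILE=1. A minimal sketch of the behavior; only the flag handling mirrors the diff, the compiled function is hypothetical:

    import os
    import torch

    os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = "1"  # simulate the server's setting

    if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
        import torch._dynamo

        # Failed graph compilations are logged and run eagerly instead of raising
        torch._dynamo.config.suppress_errors = True

    @torch.compile
    def double_plus_one(x):
        return x * 2 + 1

    print(double_plus_one(torch.ones(4)))
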
@@ -82,8 +87,6 @@ def __init__(
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
-        self.kv_cache_dtype = model_runner.kv_cache_dtype
-        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -268,6 +271,12 @@ def init_cuda_graph_state(
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
+        # Ensure tensors are properly allocated
+        for i in range(self.num_wrappers):
+            # Force allocation by performing a small operation
+            if len(self.cuda_graph_kv_indices[i]) > 0:
+                self.cuda_graph_kv_indices[i][0] = 0
+
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
                 (max_bs * self.max_context_len),
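
The small writes above are intended to guarantee that the cloned index buffers are fully materialized in device memory before CUDA graph capture begins, since capture cannot tolerate allocations from outside its private memory pool. A minimal sketch of the pattern, assuming a CUDA device is available (buffer count and size are hypothetical):

    import torch

    if torch.cuda.is_available():
        base = torch.zeros(1024, dtype=torch.int32, device="cuda")
        kv_indices = [base] + [base.clone() for _ in range(3)]

        # Touch each buffer so its storage is in place before capture
        for buf in kv_indices:
            if len(buf) > 0:
                buf[0] = 0

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            for buf in kv_indices:
                buf.add_(1)  # in-place ops reuse the pre-allocated storage

        graph.replay()
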
@@ -396,8 +405,6 @@ def forward_extend(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
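
This hunk (and its twin in forward_decode below) drops the local scale computation: previously a Python float (layer.k_scale_float) was passed only when the KV cache dtype was not "auto", and None otherwise; after the diff, layer.k_scale / layer.v_scale are forwarded unconditionally. A toy contrast of the two policies; the Layer stand-in and helper names are hypothetical, only the conditional mirrors the diff:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Layer:                  # stand-in for an attention layer
        k_scale: Optional[float]  # a tensor in the real code; a float here for brevity
        k_scale_float: float

    def old_k_scale(layer: Layer, kv_cache_dtype_str: str) -> Optional[float]:
        # pre-diff: pass a scale only for quantized (non-"auto") KV caches
        return layer.k_scale_float if kv_cache_dtype_str != "auto" else None

    def new_k_scale(layer: Layer) -> Optional[float]:
        # post-diff: always forward whatever the layer carries (possibly None)
        return layer.k_scale

    layer = Layer(k_scale=0.5, k_scale_float=0.5)
    print(old_k_scale(layer, "auto"), new_k_scale(layer))  # None 0.5
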
@@ -414,7 +421,7 @@ def forward_extend(
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
             o = prefill_wrapper_paged.forward(
@@ -424,8 +431,8 @@ def forward_extend(
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -452,7 +459,7 @@ def forward_extend(
 
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +473,6 @@ def forward_decode(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -481,16 +486,17 @@ def forward_decode(
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
+        # Run the decode kernel through the wrapper
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=k_scale,
-            v_scale=v_scale,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
     pos_encoding_mode: str = "NONE",
     window_left: int = -1,
     logits_soft_cap: Optional[float] = None,
-    data_type: Union[str, torch.dtype] = "float16",
     q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
@@ -1163,6 +1170,18 @@ def fast_decode_plan(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
     if self.use_tensor_cores:
         qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
@@ -1178,85 +1197,91 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-        # Skip these copies because we directly write to them during prepartion
-        # self._paged_kv_indptr_buf.copy_(indptr)
-        # self._paged_kv_indices_buf[: len(indices)] = indices
-        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
-        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
-
-    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
-    if not q_data_type:
-        q_data_type = data_type
-
-    if not hasattr(self, "empty_q_data"):
-        self.empty_q_data = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, q_data_type)
-                if isinstance(q_data_type, str)
-                else q_data_type
-            ),
-        )
-        self.empty_kv_cache = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, data_type) if isinstance(data_type, str) else data_type
-            ),
-        )
-        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+        if self.use_tensor_cores:
+            self._qo_indptr_buf = qo_indptr_host.to(
+                self.device, non_blocking=non_blocking
+            )
+
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+        ),
+        device=self.device,
+    )
+
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, kv_data_type)
+            if isinstance(kv_data_type, str)
+            else kv_data_type
+        ),
+        device=self.device,
+    )
 
     indptr_host = (
         global_override_indptr_cpu
         if global_override_indptr_cpu is not None
        else indptr.cpu()
     )
 
-    if self.use_tensor_cores:
-        kv_lens_arr_host = get_seq_lens(
-            indptr_host, self.last_page_len[:batch_size], page_size
-        )
-
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            indptr_host,
-            kv_lens_arr_host,
-            batch_size,  # total_num_rows
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim,
-            head_dim,
-            False,  # causal
-            torch.cuda.current_stream().cuda_stream,
-        )
-    else:
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            indptr_host,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            window_left,
-            logits_soft_cap,
-            head_dim,
-            head_dim,
-            self.empty_q_data,
-            self.empty_kv_cache,
-            torch.cuda.current_stream().cuda_stream,
-        )
+    with torch.cuda.device(self.device):
+
+        if self.use_tensor_cores:
+            # Also convert last_page_len to CPU
+            last_page_len_host = last_page_len.cpu()
+
+            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for the tensor core version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    qo_indptr_host,
+                    indptr_host,
+                    kv_lens_arr_host,
+                    batch_size,  # total_num_rows
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    head_dim,
+                    head_dim,
+                    False,  # causal
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in tensor core plan: {e}")
+        else:
+            try:
+                # Make sure we pass exactly 15 arguments for the standard version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    indptr_host,
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    window_left,
+                    logits_soft_cap,
+                    head_dim,
+                    head_dim,
+                    empty_q_data,
+                    empty_kv_cache,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
 
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
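
empty_q_data and empty_kv_cache above are zero-element tensors used purely to carry dtype (and, after this diff, device) information into the cached plan kernel; no real storage rides along. A minimal illustration of the placeholder trick (the helper name is ours):

    import torch

    def dtype_placeholder(data_type, device="cpu"):
        # A zero-element tensor conveys dtype/device metadata without storage
        dtype = getattr(torch, data_type) if isinstance(data_type, str) else data_type
        return torch.empty(0, dtype=dtype, device=device)

    q_info = dtype_placeholder("float16")
    kv_info = dtype_placeholder(torch.bfloat16)
    print(q_info.dtype, q_info.numel())    # torch.float16 0
    print(kv_info.dtype, kv_info.numel())  # torch.bfloat16 0
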