Commit f727c84

Fridge003 authored and AkazaAkane committed
Bump Flashinfer to 0.2.5 (sgl-project#5870)
Co-authored-by: Yuhao Chen <[email protected]>
1 parent 311d1c5 commit f727c84

File tree

6 files changed: +138 −104 lines changed

.github/workflows/pr-test.yml

Lines changed: 0 additions & 2 deletions

@@ -96,8 +96,6 @@ jobs:
         uses: actions/checkout@v4
 
       - name: Install dependencies
-        env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh

docs/start/install.md

Lines changed: 1 addition & 1 deletion

@@ -164,4 +164,4 @@ sky status --endpoint 30000 sglang
 - [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub.
 - If you only need to use OpenAI models with the frontend language, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
 - The language frontend operates independently of the backend runtime. You can install the frontend locally without needing a GPU, while the backend can be set up on a GPU-enabled machine. To install the frontend, run `pip install sglang`, and for the backend, use `pip install sglang[srt]`. `srt` is the abbreviation of SGLang runtime.
-- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.3" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
+- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.5" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
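
Not part of the commit: a minimal Python sketch of the same two steps the updated bullet describes, confirming which flashinfer-python distribution is installed and clearing the JIT cache. The cache path is the one named in the doc; everything else is illustrative.

# Illustrative only: verify the reinstall picked up the new pin and clear the cache.
import shutil
from importlib.metadata import version
from pathlib import Path

print("flashinfer-python:", version("flashinfer-python"))  # expect 0.2.5 after the reinstall

cache_dir = Path.home() / ".cache" / "flashinfer"  # same directory as `rm -rf ~/.cache/flashinfer`
shutil.rmtree(cache_dir, ignore_errors=True)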

python/pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -37,7 +37,7 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
-    "torchao>=0.7.0",
+    "torchao>=0.9.0",
     "transformers==4.51.1",
     "uvicorn",
     "uvloop",
@@ -47,7 +47,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.1.0",
-    "flashinfer_python==0.2.3",
+    "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",
     "cuda-python",

python/sglang/srt/entrypoints/engine.py

Lines changed: 1 addition & 1 deletion

@@ -453,7 +453,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.3",
+            "0.2.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",

python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 107 additions & 82 deletions

@@ -15,6 +15,11 @@
 
 import torch
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
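
A minimal sketch, not from the commit, of how a launcher might use the new import-time guard above. The environment variable name comes from the hunk and the import target is the module shown in this diff; the rest is illustrative.

# Illustrative only: the guard runs when the module is imported, so the flag
# must be exported beforehand for the torch._dynamo error suppression to apply.
import os

os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = "1"

from sglang.srt.layers.attention import flashinfer_backend  # noqa: E402,F401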
@@ -82,8 +87,6 @@ def __init__(
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
-        self.kv_cache_dtype = model_runner.kv_cache_dtype
-        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -268,6 +271,12 @@ def init_cuda_graph_state(
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
+        # Ensure tensors are properly allocated
+        for i in range(self.num_wrappers):
+            # Force allocation by performing a small operation
+            if len(self.cuda_graph_kv_indices[i]) > 0:
+                self.cuda_graph_kv_indices[i][0] = 0
+
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -396,8 +405,6 @@ def forward_extend(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -414,7 +421,7 @@
                 assert v is not None
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, k_scale, v_scale
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
 
             o = prefill_wrapper_paged.forward(
@@ -424,8 +431,8 @@
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -452,7 +459,7 @@
 
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +473,6 @@ def forward_decode(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -481,16 +486,17 @@
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
+        # Call the wrapped function
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=k_scale,
-            v_scale=v_scale,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
     pos_encoding_mode: str = "NONE",
     window_left: int = -1,
     logits_soft_cap: Optional[float] = None,
-    data_type: Union[str, torch.dtype] = "float16",
     q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
@@ -1163,6 +1170,18 @@
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
     if self.use_tensor_cores:
         qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
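
Read on its own, the fallback order introduced in the hunk above can be restated as a small standalone helper: an explicit data_type fills in whichever of q_data_type / kv_data_type is still unset, otherwise the query dtype defaults to float16, and the KV dtype always falls back to the query dtype. The helper below is illustrative only and not part of the commit.

# Illustrative restatement of the dtype fallback logic from the hunk above.
from typing import Optional, Union

import torch

DTypeLike = Optional[Union[str, torch.dtype]]


def resolve_dtypes(data_type: DTypeLike, q_data_type: DTypeLike, kv_data_type: DTypeLike):
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"
    if kv_data_type is None:
        kv_data_type = q_data_type
    return q_data_type, kv_data_type


assert resolve_dtypes(None, None, None) == ("float16", "float16")
assert resolve_dtypes(torch.bfloat16, None, None) == (torch.bfloat16, torch.bfloat16)
assert resolve_dtypes(None, "float16", torch.float8_e4m3fn) == ("float16", torch.float8_e4m3fn)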
@@ -1178,85 +1197,91 @@
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-        # Skip these copies because we directly write to them during prepartion
-        # self._paged_kv_indptr_buf.copy_(indptr)
-        # self._paged_kv_indices_buf[: len(indices)] = indices
-        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
-        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
-
-    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
-    if not q_data_type:
-        q_data_type = data_type
-
-    if not hasattr(self, "empty_q_data"):
-        self.empty_q_data = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, q_data_type)
-                if isinstance(q_data_type, str)
-                else q_data_type
-            ),
-        )
-        self.empty_kv_cache = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, data_type) if isinstance(data_type, str) else data_type
-            ),
-        )
-        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+        if self.use_tensor_cores:
+            self._qo_indptr_buf = qo_indptr_host.to(
+                self.device, non_blocking=non_blocking
+            )
+
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+        ),
+        device=self.device,
+    )
+
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, kv_data_type)
+            if isinstance(kv_data_type, str)
+            else kv_data_type
+        ),
+        device=self.device,
+    )
 
     indptr_host = (
         global_override_indptr_cpu
         if global_override_indptr_cpu is not None
         else indptr.cpu()
     )
 
-    if self.use_tensor_cores:
-        kv_lens_arr_host = get_seq_lens(
-            indptr_host, self.last_page_len[:batch_size], page_size
-        )
-
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            indptr_host,
-            kv_lens_arr_host,
-            batch_size,  # total_num_rows
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim,
-            head_dim,
-            False,  # causal
-            torch.cuda.current_stream().cuda_stream,
-        )
-    else:
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            indptr_host,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            window_left,
-            logits_soft_cap,
-            head_dim,
-            head_dim,
-            self.empty_q_data,
-            self.empty_kv_cache,
-            torch.cuda.current_stream().cuda_stream,
-        )
+    with torch.cuda.device(self.device):
+
+        if self.use_tensor_cores:
+            # ALSO convert last_page_len to CPU
+            last_page_len_host = last_page_len.cpu()
+
+            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for tensor core version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    qo_indptr_host,
+                    indptr_host,
+                    kv_lens_arr_host,
+                    batch_size,  # total_num_rows
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    head_dim,
+                    head_dim,
+                    False,  # causal
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
+        else:
+            try:
+                # Make sure we pass exactly 15 arguments for standard version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    indptr_host,
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    window_left,
+                    logits_soft_cap,
+                    head_dim,
+                    head_dim,
+                    empty_q_data,
+                    empty_kv_cache,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
 
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left

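One detail in the hunk above that is easy to miss: empty_q_data and empty_kv_cache are zero-element tensors whose only job is to carry dtype (and device) information into the plan call. A standalone sketch of that pattern, illustrative only and not part of the commit:

# Illustrative only: a zero-element tensor conveys dtype/device without
# allocating real storage, mirroring empty_q_data / empty_kv_cache above.
from typing import Union

import torch


def dtype_placeholder(dtype: Union[str, torch.dtype], device: str = "cpu") -> torch.Tensor:
    resolved = getattr(torch, dtype) if isinstance(dtype, str) else dtype
    return torch.empty(0, dtype=resolved, device=device)


q_placeholder = dtype_placeholder("float16")
print(q_placeholder.dtype, q_placeholder.numel())  # torch.float16 0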