Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose; otherwise, it will be closed.
- 5. Please use English, otherwise it will be closed.
Describe the bug
Llama-4-Scout-17B-16E-Instruct raises a CUDA out-of-memory (OOM) error during our image benchmark.
Reproduction
Server start command:
python3 -m sglang.launch_server --model meta-llama/Llama-4-Scout-17B-16E-Instruct --tp-size=4 --host=0.0.0.0 --mem-fraction-static=0.95 --context-length=196608 --enable-multimodal --tool-call-parser=pythonic --chat-template=examples/chat_template/tool_chat_template_llama4_pythonic.jinja --disable-radix-cache
Environment variable: SGLANG_VLM_CACHE_SIZE_MB=100
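As a rough sanity check on the launch flags: `--mem-fraction-static` is the fraction of GPU memory SGLang reserves for static allocations (model weights and the KV cache pool), so `--mem-fraction-static=0.95` leaves only about 5% of each GPU for everything else, including vision-encoder activations during prefill. A minimal sketch of that arithmetic, assuming the 79.44 GiB per-GPU capacity reported in the OOM message and ignoring CUDA context, NCCL buffers, and fragmentation:

```python
def dynamic_headroom_gib(total_gib: float, mem_fraction_static: float) -> float:
    """Approximate GiB left outside SGLang's static pool (weights + KV cache).

    Rough estimate only: ignores CUDA context, NCCL buffers, and fragmentation.
    """
    if not 0.0 < mem_fraction_static <= 1.0:
        raise ValueError("mem_fraction_static must be in (0, 1]")
    return total_gib * (1.0 - mem_fraction_static)


# 79.44 GiB is the per-GPU capacity reported in the OOM message (H100 80GB HBM3).
headroom = dynamic_headroom_gib(79.44, 0.95)
print(f"{headroom:.2f} GiB")
```

With these inputs the estimate is just under 4 GiB per GPU for all non-static allocations.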
Image benchmark:
We used our production benchmark tool. It sends requests that each contain one image, sweeping over image sizes and concurrency levels. Here is pseudocode of the logic:
for width, height in [(512, 512), (2048, 2048), (4096, 4096)]:
    # generate an OpenAI ChatCompletionRequest with a width x height image and a prompt asking questions about the image content
    for concurrency in [1, 4, 8, 16, 64, 128]:
        # use locust to simulate `concurrency` concurrent requests to the server
        # record metrics
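The sweep above can be sketched as a small runnable script. This is a minimal stand-in, not the reporters' production tool: it replaces locust with plain `asyncio`, and `make_sender` is a hypothetical factory that would return a coroutine posting one OpenAI ChatCompletion request (with an image of the given size) to the server. An `asyncio.Semaphore` caps the number of in-flight requests at each concurrency level.

```python
import asyncio
import time
from typing import Awaitable, Callable, Dict, List, Tuple

# Sweep parameters taken from the pseudocode above.
IMAGE_SIZES: List[Tuple[int, int]] = [(512, 512), (2048, 2048), (4096, 4096)]
CONCURRENCY_LEVELS: List[int] = [1, 4, 8, 16, 64, 128]

Sender = Callable[[], Awaitable[None]]


async def run_level(send: Sender, concurrency: int, total_requests: int) -> Dict[str, float]:
    """Issue `total_requests` calls to `send`, keeping at most `concurrency` in flight."""
    if total_requests < 1:
        raise ValueError("total_requests must be at least 1")
    sem = asyncio.Semaphore(concurrency)
    latencies: List[float] = []

    async def one_request() -> None:
        async with sem:  # blocks once `concurrency` requests are already running
            start = time.perf_counter()
            await send()
            latencies.append(time.perf_counter() - start)

    await asyncio.gather(*(one_request() for _ in range(total_requests)))
    return {
        "requests": float(len(latencies)),
        "mean_latency_s": sum(latencies) / len(latencies),
        "max_latency_s": max(latencies),
    }


async def sweep(make_sender: Callable[[Tuple[int, int]], Sender],
                requests_per_level: int) -> Dict[Tuple[Tuple[int, int], int], Dict[str, float]]:
    """Run every (image size, concurrency) combination and collect per-level metrics."""
    results = {}
    for size in IMAGE_SIZES:
        # Hypothetical: make_sender builds one ChatCompletion request for this image size.
        send = make_sender(size)
        for concurrency in CONCURRENCY_LEVELS:
            results[(size, concurrency)] = await run_level(send, concurrency, requests_per_level)
    return results
```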
At 4096*4096 with concurrency 64, the server ran out of GPU memory:
[2025-06-05 14:56:29] INFO: 127.0.0.1:40484 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-06-05 14:56:29 TP0] Decode batch. #running-req: 62, #token: 167777, token usage: 0.81, cuda graph: True, gen throughput (token/s): 896.73, #queue-req: 2
[2025-06-05 14:56:29 TP0] Prefill batch. #new-seq: 2, #new-token: 4992, #cached-token: 0, token usage: 0.81, #running-req: 61, #queue-req: 0
[2025-06-05 14:56:29] INFO: 127.0.0.1:40488 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-06-05 14:56:30 TP0] Prefill batch. #new-seq: 1, #new-token: 2489, #cached-token: 0, token usage: 0.83, #running-req: 63, #queue-req: 0
[2025-06-05 14:56:30 TP0] TpModelWorkerClient hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 118, in forward_thread_func
self.forward_thread_func_()
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 151, in forward_thread_func_
self.worker.forward_batch_generation(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 202, in forward_batch_generation
logits_output, can_run_cuda_graph = self.model_runner.forward(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1199, in forward
output = self._forward_raw(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1228, in _forward_raw
ret = self.forward_extend(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1167, in forward_extend
return self.model.forward(
File "/sgl-workspace/sglang/python/sglang/srt/models/mllama4.py", line 83, in forward
hs = general_mm_embed_routine(
File "/sgl-workspace/sglang/python/sglang/srt/managers/mm_utils.py", line 602, in general_mm_embed_routine
inputs_embeds = embed_mm_inputs(
File "/sgl-workspace/sglang/python/sglang/srt/managers/mm_utils.py", line 481, in embed_mm_inputs
embedding, mask = get_embedding_and_mask(
File "/sgl-workspace/sglang/python/sglang/srt/managers/mm_utils.py", line 393, in get_embedding_and_mask
embedding = _get_chunked_prefill_embedding(
File "/sgl-workspace/sglang/python/sglang/srt/managers/mm_utils.py", line 297, in _get_chunked_prefill_embedding
embedding_per_req = data_embedding_func(embedding_items_per_req)
File "/sgl-workspace/sglang/python/sglang/srt/models/mllama4.py", line 69, in get_image_feature
image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama4/modeling_llama4.py", line 1448, in forward
output = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama4/modeling_llama4.py", line 1284, in forward
layer_outputs = encoder_layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama4/modeling_llama4.py", line 1195, in forward
hidden_state, attn_weights = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama4/modeling_llama4.py", line 1140, in forward
attn_output, attn_weights = attention_interface(
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama4/modeling_llama4.py", line 275, in vision_eager_attention_forward
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.head_dim**-0.5
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 174.00 MiB. GPU 0 has a total capacity of 79.44 GiB of which 137.00 MiB is free. Process 235237 has 946.00 MiB memory in use. Process 235597 has 78.37 GiB memory in use. Of the allocated memory 74.79 GiB is allocated by PyTorch, with 38.04 MiB allocated in private pools (e.g., CUDA Graphs), and 811.92 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[2025-06-05 14:56:30] Received sigquit from a child process. It usually means the child failed.
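The failing allocation comes from `vision_eager_attention_forward`, where `torch.matmul(query, key_states.transpose(2, 3))` materializes the full attention-score tensor for the vision encoder. A back-of-the-envelope sketch of that tensor's size follows. The tile size, patch size, head count, class token, bf16 dtype, and 16 + 1 tiling of a 4096*4096 image are assumptions about the Llama 4 Scout vision config, not values read from this log.

```python
def eager_attn_scores_bytes(num_chunks: int, num_heads: int, seq_len: int,
                            bytes_per_elem: int = 2) -> int:
    """Bytes for the [num_chunks, num_heads, seq_len, seq_len] tensor from query @ key^T."""
    return num_chunks * num_heads * seq_len * seq_len * bytes_per_elem


# ASSUMED vision-encoder settings (verify against the checkpoint's vision_config):
TILE_SIZE = 336   # pixels per tile side
PATCH_SIZE = 14   # pixels per patch side
NUM_HEADS = 16    # attention heads in the vision encoder
SEQ_LEN = (TILE_SIZE // PATCH_SIZE) ** 2 + 1  # 576 patch tokens + 1 class token = 577
# ASSUMED: a 4096x4096 input is split into the maximum 16 tiles plus 1 global tile.
NUM_CHUNKS = 16 + 1

scores = eager_attn_scores_bytes(NUM_CHUNKS, NUM_HEADS, SEQ_LEN)  # bf16 -> 2 bytes/elem
print(f"{scores / 2**20:.1f} MiB per image per encoder layer")
```

Under those assumptions the score tensor is about 173 MiB per image per encoder layer, close to the 174.00 MiB allocation in the error, while only 137.00 MiB was free on GPU 0 at that point. The size grows linearly with the number of chunks passed to the encoder at once.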
Environment
Python: 3.10.16 (main, Dec 4 2024, 08:53:37) [GCC 9.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H100 80GB HBM3
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 560.35.05
PyTorch: 2.6.0+cu124
sglang: 0.4.6.post5
sgl_kernel: 0.1.5
flashinfer_python: 0.2.5
triton: 3.2.0
transformers: 4.52.3
torchao: 0.9.0
numpy: 1.26.4
aiohttp: 3.11.11
fastapi: 0.115.6
hf_transfer: 0.1.8
huggingface_hub: 0.30.1
interegular: 0.3.3
modelscope: 1.21.1
orjson: 3.10.13
outlines: 0.0.46
packaging: 24.2
psutil: 6.1.1
pydantic: 2.10.4
python-multipart: 0.0.20
pyzmq: 26.2.0
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.6.4.post1
xgrammar: 0.1.19
openai: 1.59.3
tiktoken: 0.7.0
anthropic: 0.42.0
litellm: 1.56.10
decord: 0.6.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 NIC5 NIC6 NIC7 NIC8 NIC9 NIC10 NIC11 NIC12 NIC13 NIC14 NIC15 NIC16 NIC17 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 PXB PXB NODE NODE NODE NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS 0-55,112-167 0 N/A
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 NODE NODE NODE PXB PXB NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS 0-55,112-167 0 N/A
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 NODE NODE NODE NODE NODE PXB PXB NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS 0-55,112-167 0 N/A
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 NODE NODE NODE NODE NODE NODE NODE PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS 0-55,112-167 0 N/A
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 SYS SYS SYS SYS SYS SYS SYS SYS SYS PXB PXB NODE NODE NODE NODE NODE NODE NODE 56-111,168-223 1 N/A
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE PXB PXB NODE NODE NODE NODE 56-111,168-223 1 N/A
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE NODE PXB PXB NODE NODE 56-111,168-223 1 N/A
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE NODE NODE NODE PXB PXB 56-111,168-223 1 N/A
NIC0 PXB NODE NODE NODE SYS SYS SYS SYS X PIX NODE NODE NODE NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC1 PXB NODE NODE NODE SYS SYS SYS SYS PIX X NODE NODE NODE NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC2 NODE NODE NODE NODE SYS SYS SYS SYS NODE NODE X NODE NODE NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC3 NODE PXB NODE NODE SYS SYS SYS SYS NODE NODE NODE X PIX NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC4 NODE PXB NODE NODE SYS SYS SYS SYS NODE NODE NODE PIX X NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC5 NODE NODE PXB NODE SYS SYS SYS SYS NODE NODE NODE NODE NODE X PIX NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC6 NODE NODE PXB NODE SYS SYS SYS SYS NODE NODE NODE NODE NODE PIX X NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC7 NODE NODE NODE PXB SYS SYS SYS SYS NODE NODE NODE NODE NODE NODE NODE X PIX SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC8 NODE NODE NODE PXB SYS SYS SYS SYS NODE NODE NODE NODE NODE NODE NODE PIX X SYS SYS SYS SYS SYS SYS SYS SYS SYS
NIC9 SYS SYS SYS SYS PXB NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS X PIX NODE NODE NODE NODE NODE NODE NODE
NIC10 SYS SYS SYS SYS PXB NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS PIX X NODE NODE NODE NODE NODE NODE NODE
NIC11 SYS SYS SYS SYS NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE X NODE NODE NODE NODE NODE NODE
NIC12 SYS SYS SYS SYS NODE PXB NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE X PIX NODE NODE NODE NODE
NIC13 SYS SYS SYS SYS NODE PXB NODE NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE PIX X NODE NODE NODE NODE
NIC14 SYS SYS SYS SYS NODE NODE PXB NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE NODE X PIX NODE NODE
NIC15 SYS SYS SYS SYS NODE NODE PXB NODE SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE NODE PIX X NODE NODE
NIC16 SYS SYS SYS SYS NODE NODE NODE PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE NODE NODE NODE X PIX
NIC17 SYS SYS SYS SYS NODE NODE NODE PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE NODE NODE NODE PIX X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_4
NIC5: mlx5_5
NIC6: mlx5_6
NIC7: mlx5_7
NIC8: mlx5_8
NIC9: mlx5_9
NIC10: mlx5_10
NIC11: mlx5_11
NIC12: mlx5_12
NIC13: mlx5_13
NIC14: mlx5_14
NIC15: mlx5_15
NIC16: mlx5_16
NIC17: mlx5_17
ulimit soft: 65535