Skip to content

Perform Batch Tokenization. #5141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 21, 2025

Conversation

sundar24295s
Copy link
Contributor

@sundar24295s sundar24295s commented Apr 7, 2025

Motivation

  • The current tokenizer_manager processes batch workloads by sequentially tokenizing each prompt before sending it to the DP. This creates overhead in the tokenization phase.
  • We can improve tokenization performance by leveraging the batch tokenization capabilities of fast tokenizers to process multiple prompts simultaneously before distributing them across DPs.

Modifications

  • This PR introduces a batch tokenization option in TokenizerManager, controlled by server_args.enable_tokenizer_batch_encode.
    • When enabled, it tokenizes all text inputs in a single pass for generation requests, while other input types—such as input_ids, input_embeds, or image data—continue to follow the existing sequential process.
    • This optimization is particularly beneficial for prefill-heavy workloads with smaller batch sizes
  • Also, added couple of benchmarking scripts:
    • To benchmark the tokenizer with batch prompts.
    • To benchmark sending batch requests to the /generate endpoint.

Future Work

  • The current implementation tokenizes the entire batch at once which is suitable for usecases sending prompts in smaller batch sizes. But, which may not scale well for very large batch sizes (e.g., 1000).
  • Future improvements could include:
    • Splitting large batches into manageable chunks, tokenize the chunk and send them to the DPs
  • These enhancements will be explored in future iterations to support a wider range of use cases

Benchmarks

  • All benchmarks are performed on H100s.

Batch Tokenization

(venv) jobuser [ ~/sglang ]$ python3.10 benchmark/benchmark_batch/benchmark_tokenizer.py
Tokenizer Benchmark: Sequential vs Batch Processing
------------------------------------------------------------
Tokenizer: /shared/public/sharing/fait360brew/training/models/meta-llama/Llama-3.2-3B
Tokens per prompt: 20000
Number of runs per batch size: 5
------------------------------------------------------------
Generating 8 random prompts with 20000 tokens each...
  Prompt 0: 20905 tokens
  Prompt 1: 20867 tokens
  Prompt 2: 20889 tokens
  Prompt 3: 20882 tokens
  Prompt 4: 20786 tokens
  Prompt 5: 20891 tokens
  Prompt 6: 20876 tokens
  Prompt 7: 20835 tokens

Running benchmark...
.
.
.
============================================================
SUMMARY OF RESULTS
============================================================
Batch Size Sequential (ms)    Batch (ms)         Speedup   
------------------------------------------------------------
1          33.23 ms         33.15 ms         1.00x
2          67.28 ms         39.67 ms         1.70x
4          159.67 ms         57.98 ms         2.75x
8          351.50 ms         67.81 ms         5.18x

Bechmark Batch prefill

  • The following benchmark is to paint a picture on how much overhead we can save for a batched request if we perform batch tokenization.
  • Launch Server Command
  (venv) jobuser [ ~/sglang ]$ python -m sglang.launch_server --model-path /models/meta-llama/Llama-3.2-3B --port 30000 --host 0.0.0.0 --disable-radix-cache --disable-cuda-graph --max-prefill-tokens 131072 --chunked-prefill-size 131072 --tp 1 --dp 8

Results using exisitng Sequential Tokenization

(venv) jobuser [ ~/sglang ]$ python3.10 benchmark/benchmark_batch/benchmark_batch.py 
.
.
.
Generated 480 prompts with 32000 tokens each, grouped into 60 requests of 8 prompts.

Starting benchmark: NUM_TOKENS=32000, BATCH_SIZE=8, NUM_REQUESTS=60

[Request] Sending request 1/10 with 8 prompts at 1744061155472
.
.
.

Benchmark Summary:
  Total requests sent:         10
  Total prompts sent:          80
  Successful requests:         10
  Failed requests:             0
  Total latency (all requests): 23658.48 ms
  Avg per request latency:     2365.73 ms
  Avg per prompt latency:      295.72 ms
  Throughput:                  3.38 prompts/second

Results using exisitng Batch Tokenization

  • Launch Server Command
python -m sglang.launch_server --model-path /models/meta-llama/Llama-3.2-3B --port 30000 --host 0.0.0.0 --disable-radix-cache --disable-cuda-graph --max-prefill-tokens 131072 --chunked-prefill-size 131072 --tp 1 --dp 8 --enable-tokenizer-batch-encode
(venv) jobuser [ ~/sglang ]$ python3.10 benchmark/benchmark_batch/benchmark_batch.py 
.
.
.
Generated 480 prompts with 32000 tokens each, grouped into 60 requests of 8 prompts.

Starting benchmark: NUM_TOKENS=32000, BATCH_SIZE=8, NUM_REQUESTS=60

[Request] Sending request 1/10 with 8 prompts at 1744061155472
.
.
.
Benchmark Summary:
  Total requests sent:         60
  Total prompts sent:          480
  Successful requests:         60
  Failed requests:             0
  Total latency (all requests): 126336.63 ms
  Avg per request latency:     2105.17 ms
  Avg per prompt latency:      263.15 ms
  Throughput:                  3.80 prompts/second
  • From the above benchmark we can see a good chunk of 30ms saved per prompt in a batch as measured from the client side.

Checklist

@hebiao064
Copy link
Collaborator

Is CI down now @zhyncs

@hebiao064
Copy link
Collaborator

@sundar24295s qq about this:

This optimization is particularly beneficial for prefill-heavy workloads with smaller batch sizes

I feel like larger batch size will be more beneficial. Could you please explain a little bit? Thanks

@zhyncs
Copy link
Member

zhyncs commented Apr 7, 2025

Is CI down now @zhyncs

yeah just wait for @merrymercy

@sundar24295s
Copy link
Contributor Author

@hebiao064

@sundar24295s qq about this:

This optimization is particularly beneficial for prefill-heavy workloads with smaller batch sizes

I feel like larger batch size will be more beneficial. Could you please explain a little bit? Thanks

  • The current implementation tokenizes the entire batch. If the batch size is larger (say 400 or so - 400 is a rough number, it depends on the model, token length per prompt within a batch etc...), then we wait for entire batch to be tokenized before forwarding the request to GPU. During this time, GPU is idle.
  • The ideal implementation would be split the big batch into smaller batches for tokenizer, tokenize them and send to the DPs which I have called out in the future work.

@hebiao064
Copy link
Collaborator

@hebiao064

@sundar24295s qq about this:

This optimization is particularly beneficial for prefill-heavy workloads with smaller batch sizes

I feel like larger batch size will be more beneficial. Could you please explain a little bit? Thanks

  • The current implementation tokenizes the entire batch. If the batch size is larger (say 400 or so - 400 is a rough number, it depends on the model, token length per prompt within a batch etc...), then we wait for entire batch to be tokenized before forwarding the request to GPU. During this time, GPU is idle.
  • The ideal implementation would be split the big batch into smaller batches for tokenizer, tokenize them and send to the DPs which I have called out in the future work.

Thanks, very clear!

@sundar24295s
Copy link
Contributor Author

  • 3 tests failed, not related to the changes in the current PR, looks flaky.
    @zhyncs Can you take a look at the PR?

@sundar24295s
Copy link
Contributor Author

There were 3 unrelated unit test failures.

@zhyncs / @xiezhq-hermann Can you take a look at this PR?

@merrymercy merrymercy merged commit f081541 into sgl-project:main Apr 21, 2025
22 of 23 checks passed
tarinkk pushed a commit to Pb314314/sglang that referenced this pull request Apr 21, 2025
pi314ever pushed a commit to pi314ever/sglang that referenced this pull request May 16, 2025
* fix: update pr-test-sgl-kernel (sgl-project#5399)

* kernel: support slightly faster merge_state_v2 cuda kernel (sgl-project#5381)

* chore: bump sgl-kernel 0.0.9 (sgl-project#5400)

* chore: upgrade sgl-kernel 0.0.9 (sgl-project#5401)

* Tiny fix DeepseekScalingRotaryEmbedding always use forward_native (sgl-project#5406)

* Fix bench_serving with random-ids (sgl-project#5214)

* [misc] fix ci flaky case (sgl-project#5352)

* [FIX] Fix concatenation error in capture_bs when open --disable-cuda-graph-padding and without MTP (sgl-project#5412)

* Support dynamic connection and TP 16 (sgl-project#5351)

Co-authored-by: luoyuan.luo <[email protected]>

* Fix broadcast use cuda device lead to memory capacity unbalanced (sgl-project#5416)

* [PD] Fix dynamic port support and MLA buffer for Mooncake (sgl-project#5415)

Signed-off-by: Shangming Cai <[email protected]>
Co-authored-by: ybyang <[email protected]>

* Distinguish bootstrap key only in decode server (sgl-project#5422)

* [PD] Remove unused bootstrap param and fix port table type (sgl-project#5423)

* [minor] cleanup cmakelists.txt (sgl-project#5420)

* bugfix: fix merge_state_v2 cuda graph (sgl-project#5419)

* chore: bump sgl-kernel v0.0.9.post1 (sgl-project#5430)

* fix: solve release issue (sgl-project#5434)

* BLackwell cutlass mla: Add check for bad page size/block num combinations (sgl-project#5431)

* feat: update model_specific_adjustment (sgl-project#5344)

Co-authored-by: hebiao064 <[email protected]>

* chore: upgrade sgl-kernel 0.0.9.post1 (sgl-project#5436)

* Fix ignore_eos parameter when loading a chat template (sgl-project#5264)

* add attention backend supporting matrix in the doc (sgl-project#5211)

Co-authored-by: Stefan He <[email protected]>

* Support BNB quantization for llama/mllama (sgl-project#5038)

Co-authored-by: Yuhao Yang <[email protected]>

* [Docs] Update start/install.md (sgl-project#5398)

* [Minor] Move torch.compile patch to a better place (sgl-project#5397)

* [Bug fix] need record start time in pd mode (sgl-project#5425)

* Support MHA with chunked prefix cache for DeepSeek chunked prefill (sgl-project#5113)

* chore: bump v0.4.5.post1 (sgl-project#5445)

* Fix several minor issues in PD disaggregation (sgl-project#5444)

* [doc] Update benchmark_and_profiling.md (sgl-project#5449)

* Update cutlass dependency. (sgl-project#5447)

* add multi-lora feature in README.md (sgl-project#5463)

* Clean up imports (sgl-project#5467)

* [verl] Modify the update_weights func to align with verl's resharding (sgl-project#5345)

Co-authored-by: Chayenne <[email protected]>

* [Model Support] unsloth/Phi-4-mini bnb model (sgl-project#4982)

Co-authored-by: yhyang201 <[email protected]>
Co-authored-by: Liangsheng Yin <[email protected]>
Co-authored-by: Chayenne <[email protected]>
Co-authored-by: Yineng Zhang <[email protected]>

* Update attention_backend.md: plural form (sgl-project#5489)

* Add test for flash_attn_varlen_func kernel (sgl-project#5484)

* Deprecate disable-mla (sgl-project#5481)

* Deprecate enable-flashinfer-mla and enable-flashmla (sgl-project#5480)

* Feat/support encoder model (like bert) (sgl-project#4887)

* Enable local attention during decode (sgl-project#5479)

* Refactor DeepSeek decoder layer branches (sgl-project#5205)

* Fix a link in sgl-kernel/README.md (sgl-project#5493)

* [Bug fix] use correct func path in deepseek (sgl-project#5496)

Signed-off-by: Xuchun Shang <[email protected]>

* Doc: fix problems of the 'Execute Notebooks / run-all-notebooks' ci caused by the unstability of deepseek-ai/DeepSeek-R1-Distill-Qwen-7B (sgl-project#5503)

* [Feat] Update sgl-kernel flashinfer to latest main version (sgl-project#5500)

Co-authored-by: zhyncs <[email protected]>

* Fix: Incorrect parameters passed to forward_batch_generation (sgl-project#5506) (sgl-project#5511)

* Fix: fix the exception 'the memory capacity is unbalanced. Some GPUs … (sgl-project#5426)

Co-authored-by: ocss884 <[email protected]>

* [docs] Fix several consistency issues in sampling_params.md (sgl-project#5373)

Signed-off-by: windsonsea <[email protected]>
Co-authored-by: Baizhou Zhang <[email protected]>

* Configuration qwen2_moe.py - qkv_bias now in transformers (sgl-project#5512)

* Introduce moe_dense_tp_size to fix dense layer errors in DeepSeek V3 + 4x8xH100 (sgl-project#4836)

* Sgl kernel fused_moe_gate support n_shared_experts (sgl-project#5440)

* chore: bump sgl-kernel 0.0.9.post2 (sgl-project#5518)

* use sglang_per_token_group_quant_fp8 from sgl-kernel instead of trion kernel (sgl-project#5473)

Co-authored-by: Zhang Kaihong <[email protected]>

* fix kimi vl running bug after rebase main (sgl-project#5461)

* fix bug of VLLM_AVAILABLE not defined (sgl-project#5497)

* Avoid computing lse in Ragged Prefill when there's no prefix. (sgl-project#5476)

Co-authored-by: Baizhou Zhang <[email protected]>

* [Model] Adding Qwen3 and Qwen3MoE (sgl-project#4693)

* fix util import (sgl-project#5542)

* Revert "Avoid computing lse in Ragged Prefill when there's no prefix.… (sgl-project#5544)

* chore: upgrade sgl-kernel 0.0.9.post2 (sgl-project#5540)

* Fix DeepGEMM masked cannot be run on groups not being multiple or 4 (sgl-project#5340)

* Make profiler output file names consistent (sgl-project#5548)

* [PD] Tiny fix timeout error when generate (sgl-project#5545)

* [PD] Fix no cache connect for recevier (sgl-project#5534)

* feat: use flashinfer jit package (sgl-project#5547)

* [PD] Remove the requirement of config file for mooncake backend  (sgl-project#5460)

* restruct compressed_tensors_w8a8_fp8 (sgl-project#5475)

* simplify the control logic for using shared experts fusion (sgl-project#5504)

* Remove one kernel in per_tensor_quant_mla_fp8 (sgl-project#5549)

* Fix sampler nan check when calling top_k_top_p_sampling_from_probs (sgl-project#5546)

* [PD] Support page size > 1 (sgl-project#5561)

* fix hicache write back (sgl-project#5543)

* Minor update for ROCm variable style (sgl-project#5562)

* Fix bench_one_batch producing unnatural results for expert parallel (sgl-project#5149)

* [perf] introduce deep gemm group_gemm_masked as bmm (sgl-project#5432)

* [PD] Fix DeepSeek cannot be run on latest master (sgl-project#5568)

* Fix BumpAllocator error when no input_ids (sgl-project#5564)

* enable DeepSeek V3 shared_experts_fusion in sm90 (sgl-project#5571)

* [Fix] fix outlines and xgrammar (sgl-project#4947)

* [Doc]Add instruction for profiling with bench_one_batch (sgl-project#5581)

* Release v0.4.5.post2 (sgl-project#5582)

* Fix bench_serving fail when zero warmup requests (sgl-project#5574)

* Fix DeepEP cannot run on latest master (sgl-project#5567)

* Fix torch memory saver not enabled in DP scenario (sgl-project#5560)

* Super tiny fix typo (sgl-project#5559)

* Add document for LoRA serving (sgl-project#5521)

* Tiny improve error message (sgl-project#5526)

* [PD] Fix server crash when using batch requests (sgl-project#5531)

* [Feat] upgrade pytorch2.6 (sgl-project#5417)

* Fix enable chunked prefill for Llama4 (sgl-project#5575)

* fix: use fa3 for gemma2 (sgl-project#5586)

* Fix ChatCompletionMessageGenericParam to allow for None content (sgl-project#5452)

* [PD] Fix large page size + chunk prefill (sgl-project#5588)

* Add test config yamls for Deepseek v3 (sgl-project#5433)

* [Feature] Prefill assistant response - add continue_final_message parameter (sgl-project#4226)

Co-authored-by: Chayenne <[email protected]>

* add function call parser for DeepSeek V3 (sgl-project#5224)

* smaller and non gated models for docs (sgl-project#5378)

* Feat: Implement JSON Mode (response_format.type="json_object") (sgl-project#4733)

Co-authored-by: Kyle Pena <[email protected]>

* check marlin format before attempting conversion (sgl-project#4675)

* compressed_tensors: port w8a16 fp8 from vllm (sgl-project#4852)

* Fix one more issue reported by torchfix (sgl-project#4859)

* Add sanity check for max_running_requests (sgl-project#5016)

* Correct grafana heatmap. (sgl-project#5019)

* Perform Batch Tokenization. (sgl-project#5141)

* Speedup shared expert weight construction by avoid cloning (sgl-project#5188)

* Tiny add Engine.flush_cache API (sgl-project#5241)

* [misc] remove is_cuda_available (sgl-project#5319)

* Fix flush cache (sgl-project#5590)

* Add Speculative Decoding Eagle3 topk > 1 (sgl-project#5318)

Co-authored-by: Stefan He <[email protected]>
Co-authored-by: Yubo Wang <[email protected]>

* upstream hicache fixes (sgl-project#5570)

* Tiny add warning when cannot recognize bool env var (sgl-project#5348)

* Modify metrics service endpoint (sgl-project#3443)

* Update protocol.py to fix sgl-project#4589 (sgl-project#4590)

* [Feat.] Enable grafana to show metrics (sgl-project#4718)

Co-authored-by: zhaochenyang20 <[email protected]>

* [Fix] Enhance DP Attention for IPv6 Compatibility (sgl-project#4937)

* Support o1 model on Azure (sgl-project#4980)

Co-authored-by: Shan Yu <[email protected]>

* Tiny remove duplicated code (sgl-project#5021)

* Tiny update error hint (sgl-project#5037)

* Support PD bootstrap fields on /v1/chat/completions endpoint (sgl-project#5488)

* [PD] Fix generate endpoint of min_lb for PD (sgl-project#5598)

Signed-off-by: Shangming Cai <[email protected]>

* [PD] Fix edge case and simplify large page size + chunked prefill (sgl-project#5589)

* [PD] Add NIXL transfer backend  (sgl-project#5477)

* [PD] Support decode overlap schedule (sgl-project#5608)

* [PD] Support prefill overlap + Ensure no race condition (sgl-project#5609)

* Enhance GPU memory settings (sgl-project#5604)

* [feature] enable pre compile jit deep_gemm (sgl-project#5580)

* Clean up mem settings (sgl-project#5610)

* Support aiter RMSNorm in AMD (sgl-project#5510)

Co-authored-by: JieXin Liang <[email protected]>

* chore: bump v0.4.5.post3 (sgl-project#5611)

* Remove extra copy in deepseek forward absorb (sgl-project#5578)

Co-authored-by: saienduri <[email protected]>

* [Doc] Fix a 404 link to llama-405b (sgl-project#5615)

Signed-off-by: windsonsea <[email protected]>

* [fix] force use deepgemm in compile_deep_gemm (sgl-project#5618)

* [fix] fix compile_deep_gemm missing kv_b_proj (sgl-project#5620)

* fix: gemma 3 not use softcap (sgl-project#5622)

* Fix FA3 DeepSeek prefill performance regression (sgl-project#5624)

Co-authored-by: ispobock <[email protected]>

* [NFC] Remove duplicate `compressed-tensors` (sgl-project#5640)

* Fix shared experts fusion error without quantization (sgl-project#5632)

* [feature] Add H20 fp8_w8a8 FusedMoE config for --n-share-experts-fusion=16 (sgl-project#5641)

Co-authored-by: yuethe <[email protected]>

* fix flashmla bug (sgl-project#5272)

* [fix] reduce dp capture bs (sgl-project#5634)

Co-authored-by: alcanerian <[email protected]>

* Remove q concat in FA3 backend for DeepSeek decode (sgl-project#5638)

* Revert "Support aiter RMSNorm in AMD" (sgl-project#5646)

* fix: update bench_speculative (sgl-project#5649)

* Turn on DeepGemm By Default and Update Doc (sgl-project#5628)

* Fuse q_a_proj and kv_a_proj (sgl-project#5619)

* Remove unnecessary `torch.full` in DeepSeek (sgl-project#5601)

* [1/2] Add FP8 Blockscale MoE CUTLASS kernel for Blackwell (sgl-project#5281)

* fix sgl-kernel unit tests (sgl-project#5666)

* fix awq_dequantize import (sgl-project#5669)

* Integrating PD disaggregation with DP attention and DeepEP (sgl-project#5435)

Co-authored-by: Byron Hsu <[email protected]>

* fix gemma3 unit test (sgl-project#5670)

* fix torchvision::nms not exist (sgl-project#5671)

* [PD] Add support for dp attention with mooncake (sgl-project#5530)

Signed-off-by: Shangming Cai <[email protected]>

* tune the threshold of gemma-2-27b-it in test_nightly_gsm8k_eval.py (sgl-project#5677)

* [Doc] Fix two 404 links caused by sglang typo (sgl-project#5667)

Signed-off-by: windsonsea <[email protected]>

* fix: update truss bench_serving (sgl-project#5683)

* fix: only compile ApplyTokenBitmaskInplace cu124+ (sgl-project#5686)

* chore: bump sgl-kernel 0.1.0 (sgl-project#5688)

* vlm: enable radix cache for qwen-vl models (sgl-project#5349)

Co-authored-by: Xinyuan Tong <[email protected]>

* [BugFix] Fix combination of MTP and `--n-share-experts-fusion`with R1 (sgl-project#5707)

* Fix weight loading bug for Deepseek v3+nextn (sgl-project#5684)

* Add example to use sgl engine with fastapi (sgl-project#5648)

Co-authored-by: Ravi Theja Desetty <[email protected]>

* [Doc] Fix a link to Weilin Zhao (sgl-project#5706)

Signed-off-by: windsonsea <[email protected]>

* Add MMMU benchmark results (sgl-project#4491)

Co-authored-by: Ravi Theja Desetty <[email protected]>

* [Model] Support `ArcticForCausalLM` architecture (Snowflake/snowflake-arctic-instruct) (sgl-project#5078)

Co-authored-by: vincent-4 <[email protected]>

* [PD] Better logs (sgl-project#5715)

* [PD] Add kvargs table and thread pool for kvcache sender of mooncake (sgl-project#5738)

Signed-off-by: Shangming Cai <[email protected]>

* [PD]: Support Muti Prefill in one node (sgl-project#5704)

Co-authored-by: shuaills <[email protected]>

* Fix: deepseek forward absorb (sgl-project#5723)

Co-authored-by: ispobock <[email protected]>

* Pin torch audio to 2.6.0 (sgl-project#5750)

* Revert "[Model] Support `ArcticForCausalLM` architecture (Snowflake/snowflake-arctic-instruct)" (sgl-project#5754)

* Disable flaky eagle tests (sgl-project#5753)

* update triton 3.2.0 h200 fused moe triton config and add warning about triton fused_moe_kernel performance degradation due to different Triton versions. (sgl-project#5740)

* [Docs] Update runtime/engine/readme.md (sgl-project#5737)

Signed-off-by: windsonsea <[email protected]>

* Reorder loop in shared expert weight loading (sgl-project#5719)

* fix: fix one more bug from merging mm_inputs (sgl-project#5718)

Co-authored-by: Xinyuan Tong <[email protected]>
Co-authored-by: XinyuanTong <[email protected]>

* [Fix]: support deepseek-vl2-tiny model (sgl-project#5552)

Co-authored-by: bppps <[email protected]>

* Bugfix for minicpmo vision test (sgl-project#5760)

* [Minor] fix documentations (sgl-project#5756)

* Add an assertion to enhance the robustness of the operator (sgl-project#5736)

* fix: import vllm_rotary_embedding error when head_size not in 64, 128, 256, 512 (sgl-project#5733)

* Use device_id in dist init to reduce NCCL communicator warmup & creation overhead (sgl-project#5728)

* [fix] fix potential bumpy throughtput with deepgemm (sgl-project#5722)

* Resolves the `404 Not Found` error when running `compile_deep_gemm.py` in multi-node setups (sgl-project#5720)

* perf: update H20 fused_moe_triton kernel config to get higher throughput during prefilling (sgl-project#5716)

* we fix the non existent access of `decrypted_config_file` (sgl-project#5685)

* CI: rewrite test_vision_chunked_prefill to speedup (sgl-project#5682)

* Fuse MLA set kv cache kernel (sgl-project#5748)

* Update amd docker image to `sglang:v0.4.5.post3-rocm630`. (sgl-project#5697)

* [feature] support for roberta embedding models (sgl-project#5730)

* [fix] fix bench_one_batch_server (sgl-project#5607)

* support for the DeepSeek model by enabling streaming response parsing (sgl-project#5592)

* fix: Use `is not None` instead of `!= None` for None checks. (sgl-project#5687)

* Add Llama 4 to FA3 test (sgl-project#5509)

* [misc] more decode step log for batch_one_batch (sgl-project#5565)

* Handle JSONDecodeError while processing request data (sgl-project#5599)

* fix(srt): check if sample_indices is not None before usage. (sgl-project#5633)

* update llguidance to 0.7.11; adds StructTag (sgl-project#4870)

* Use sgl-kernel sgl_per_token_group_quant_int8 (sgl-project#4971)

* Add memory_saver check (sgl-project#4986)

Signed-off-by: Kebe <[email protected]>

* add switch to disable open api doc (sgl-project#3744)

Signed-off-by: congcongke <[email protected]>

* Revert "fix: import vllm_rotary_embedding error when head_size not in 64, 128, 256, 512" (sgl-project#5772)

* Fix eagle test case (sgl-project#5776)

* Split local attention test from fa3 test (sgl-project#5774)

* Revert "Revert "fix: import vllm_rotary_embedding error when head_size not in 64, 128, 256, 512"" (sgl-project#5777)

* Simplify FA3 tests (sgl-project#5779)

* Revert "[fix] fix bench_one_batch_server" (sgl-project#5785)

* Revert "Use device_id in dist init to reduce NCCL communicator warmup & creation overhead" (sgl-project#5786)

* [CI] Tune threshold (sgl-project#5787)

* [CI] fix port conflicts (sgl-project#5789)

* [CI] Fix ci tests (sgl-project#5769)

* [PD]Reduce kv transfer threads (sgl-project#5791)

* [CI] Fix test case (sgl-project#5790)

* Add 8-GPU Test for Deepseek-V3  (sgl-project#5691)

Co-authored-by: Lianmin Zheng <[email protected]>

* Release v0.4.6 (sgl-project#5795)

* Update nightly-test.yml (sgl-project#5797)

* [CI] Improve github summary & enable fa3 for more models (sgl-project#5796)

* [Docs] update grafana setup guide in production metrics (sgl-project#5643)

Co-authored-by: NoahM <[email protected]>

* [Misc] add structure logging, write to file and log tracing for SGL Router

* Improve overlap scheduling (sgl-project#5788)

* Add Cutlass MLA attention backend (sgl-project#5390)

* chore: upgrade sgl-kernel 0.1.0 (sgl-project#5690)

* Dockerfile.dev pip scikit_build_core (sgl-project#5807)

* Add a doc to fix sgl-kernel build link error in py39 with ccache (sgl-project#5809)

* Turn on overlap scheduler for multimodal models (sgl-project#5771)

* Tiny refactor DefaultModelLoader.Source (sgl-project#5482)

* [Docs] Replace lists with tables for cleanup and readability in server_arguments (sgl-project#5276)

* Revert "Tiny refactor DefaultModelLoader.Source" (sgl-project#5825)

* Feat: add support for thinking mode via chat_template_kwargs.enable_t… (sgl-project#5551)

Co-authored-by: shuaills <[email protected]>
Co-authored-by: Chayenne <[email protected]>
Co-authored-by: Lianmin Zheng <[email protected]>
Co-authored-by: Yineng Zhang <[email protected]>

* fix: fix the error where the content is None when reasoning and tool … (sgl-project#5838)

* feat: Add fused moe triton config for qwen3 moe on h100 (sgl-project#5833)

* fused moe triton tuning script support qwen3 (sgl-project#5842)

* feat: Add fused moe triton config for qwen3bf16 moe on h20 (sgl-project#5839)

* [PD] support pd fake transfer for warmup (sgl-project#5726)

* [config] qwen3moe_tune_h20 fp8 tp4 (sgl-project#5846)

* [Doc] Recover history of server_arguments.md (sgl-project#5851)

* feat: Add fused moe triton config for qwen3-30b-fp8 moe on h20 (sgl-project#5850)

* [CI] test chunked prefill more (sgl-project#5798)

* ROCm: update AITER (sgl-project#5816)

* [Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (sgl-project#5847)

Co-authored-by: sighingnow <[email protected]>

* [Fix] Missing bootstrap_port field (sgl-project#5823)

* feat: update is_fa3_default_architecture (sgl-project#5854)

* add fused moe config for qwen3moe fp8/bf16 (sgl-project#5849)

* chore: bump v0.4.6.post1 (sgl-project#5845)

* fix for hpu backend in model runner and server args

Signed-off-by: Mohit Sinha <[email protected]>

* rebase formatting issue

Signed-off-by: Mohit Sinha <[email protected]>

* [SW-228218]: Fix device mismatch in frequency penalty.

Ensure tensors in BatchedFrequencyPenalizer are on the same device by
moving output_ids and frequency_penalties to the device of
cumulated_frequency_penalties. This resolves a RuntimeError
caused by tensors on cpu and hpu:0 during logits subtraction.

---------

Signed-off-by: Shangming Cai <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: windsonsea <[email protected]>
Signed-off-by: Kebe <[email protected]>
Signed-off-by: congcongke <[email protected]>
Signed-off-by: Mohit Sinha <[email protected]>
Co-authored-by: Yineng Zhang <[email protected]>
Co-authored-by: DefTruth <[email protected]>
Co-authored-by: fzyzcjy <[email protected]>
Co-authored-by: Yuhong Guo <[email protected]>
Co-authored-by: JieXin Liang <[email protected]>
Co-authored-by: Zhaoyang Hao <[email protected]>
Co-authored-by: Yuan Luo <[email protected]>
Co-authored-by: luoyuan.luo <[email protected]>
Co-authored-by: lambert0312 <[email protected]>
Co-authored-by: shangmingc <[email protected]>
Co-authored-by: ybyang <[email protected]>
Co-authored-by: Liangsheng Yin <[email protected]>
Co-authored-by: Lianmin Zheng <[email protected]>
Co-authored-by: Trevor Morris <[email protected]>
Co-authored-by: hebiao064 <[email protected]>
Co-authored-by: Chang Su <[email protected]>
Co-authored-by: mRSun15 <[email protected]>
Co-authored-by: ryang <[email protected]>
Co-authored-by: Yuhao Yang <[email protected]>
Co-authored-by: Michael Yao <[email protected]>
Co-authored-by: ybyang <[email protected]>
Co-authored-by: Baizhou Zhang <[email protected]>
Co-authored-by: Cheng Wan <[email protected]>
Co-authored-by: Xiaoyu Zhang <[email protected]>
Co-authored-by: Elfie Guo <[email protected]>
Co-authored-by: Ying Sheng <[email protected]>
Co-authored-by: BearBiscuit <[email protected]>
Co-authored-by: Chayenne <[email protected]>
Co-authored-by: eigen <[email protected]>
Co-authored-by: yhyang201 <[email protected]>
Co-authored-by: Didier Durand <[email protected]>
Co-authored-by: woodx <[email protected]>
Co-authored-by: Xuchun Shang <[email protected]>
Co-authored-by: mlmz <[email protected]>
Co-authored-by: PGFLMG <[email protected]>
Co-authored-by: u4lr451 <[email protected]>
Co-authored-by: ocss884 <[email protected]>
Co-authored-by: Michael Feil <[email protected]>
Co-authored-by: strgrb <[email protected]>
Co-authored-by: Zhang Kaihong <[email protected]>
Co-authored-by: liwenju0 <[email protected]>
Co-authored-by: Wenxuan Tan <[email protected]>
Co-authored-by: yhyang201 <[email protected]>
Co-authored-by: Yubo Wang <[email protected]>
Co-authored-by: Byron Hsu <[email protected]>
Co-authored-by: Zhiqiang Xie <[email protected]>
Co-authored-by: Zhaoyi Li <[email protected]>
Co-authored-by: lukec <[email protected]>
Co-authored-by: tarinkk <[email protected]>
Co-authored-by: AmadeusW <[email protected]>
Co-authored-by: Adarsh Shirawalmath <[email protected]>
Co-authored-by: Yi Zhou <[email protected]>
Co-authored-by: simveit <[email protected]>
Co-authored-by: kyle-pena-kuzco <[email protected]>
Co-authored-by: Kyle Pena <[email protected]>
Co-authored-by: Enrique Shockwave <[email protected]>
Co-authored-by: Juwan Yoo <[email protected]>
Co-authored-by: Brayden Zhong <[email protected]>
Co-authored-by: mac0ne <[email protected]>
Co-authored-by: Sundara Raman Ramachandran <[email protected]>
Co-authored-by: Qingquan Song <[email protected]>
Co-authored-by: moontidef <[email protected]>
Co-authored-by: Huapeng Zhou <[email protected]>
Co-authored-by: Lucius <[email protected]>
Co-authored-by: Chuyue Sun <[email protected]>
Co-authored-by: Shan Yu <[email protected]>
Co-authored-by: Yongtong Wu <[email protected]>
Co-authored-by: michael-amd <[email protected]>
Co-authored-by: Ke Bao <[email protected]>
Co-authored-by: saienduri <[email protected]>
Co-authored-by: ispobock <[email protected]>
Co-authored-by: Connector Switch <[email protected]>
Co-authored-by: saltyfish66 <[email protected]>
Co-authored-by: yuethe <[email protected]>
Co-authored-by: alcanerian <[email protected]>
Co-authored-by: HAI <[email protected]>
Co-authored-by: Mick <[email protected]>
Co-authored-by: Xinyuan Tong <[email protected]>
Co-authored-by: Ravi Theja <[email protected]>
Co-authored-by: Ravi Theja Desetty <[email protected]>
Co-authored-by: vincent-4 <[email protected]>
Co-authored-by: IAN <[email protected]>
Co-authored-by: shuaills <[email protected]>
Co-authored-by: XinyuanTong <[email protected]>
Co-authored-by: ZXN <[email protected]>
Co-authored-by: bppps <[email protected]>
Co-authored-by: Yi Zhang <[email protected]>
Co-authored-by: Kyungmin Lee <[email protected]>
Co-authored-by: vzed <[email protected]>
Co-authored-by: DavidBao <[email protected]>
Co-authored-by: Frankey_8080 <[email protected]>
Co-authored-by: yan97ao <[email protected]>
Co-authored-by: aoshen524 <[email protected]>
Co-authored-by: Michał Moskal <[email protected]>
Co-authored-by: Kebe <[email protected]>
Co-authored-by: zhanweidu <[email protected]>
Co-authored-by: NoahM <[email protected]>
Co-authored-by: Simo Lin <[email protected]>
Co-authored-by: JiLi <[email protected]>
Co-authored-by: sighingnow <[email protected]>
Co-authored-by: XTY <[email protected]>
Co-authored-by: vikram singh shekhawat <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants