[deepseek_r1] refine _schedule_prefills for prompts with large length range #1511

Open
wants to merge 1 commit into deepseek_r1
Conversation


yangulei commented Jul 2, 2025

The prefill batch size is usually controlled by max_num_batched_tokens, which currently must be at least max_model_len. For a set of requests whose lengths vary widely, max_num_batched_tokens must therefore be at least as large as max(seq_lens), which leads to a large prefill batch size for the short requests and hurts TTFT.
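
As a rough illustration of why this hurts TTFT (the numbers below are hypothetical and not taken from this PR):

# Back-of-the-envelope illustration of the old behavior; the numbers are
# hypothetical and not taken from this PR.
max_model_len = 4096
max_num_batched_tokens = 4096   # old rule: must be >= max_model_len
max_num_seqs = 64
short_seq_len = 512

# Up to 8 short prompts get packed into one prefill batch, so the first of
# them only receives its first token after the whole batch has finished.
prefill_batch_size = min(max_num_seqs, max_num_batched_tokens // short_seq_len)
print(prefill_batch_size)  # -> 8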

In this PR:

  • the limitation max_num_batched_tokens >= max_model_len is removed.
  • for requests with seq_len > max_num_batched_tokens, the prefill batch size will be 1.
  • requests with seq_len > max_num_batched_tokens are skipped if prefills are already scheduled, and the following requests are then checked and scheduled if possible (see the sketch after this list).

In this way, max_num_batched_tokens should be chosen as the minimum length that can fully utilize the device; the recommended value is 8192.
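
Below is a minimal sketch of the refined scheduling loop, using hypothetical names (Request, schedule_prefills). It is not the actual _schedule_prefills code, which operates on sequence groups, checks KV-cache space, and also counts already-running sequences against max_num_seqs:

# Minimal sketch of the prefill scheduling policy described above.
# Hypothetical types and names; not the actual vLLM scheduler code.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Request:
    request_id: str
    seq_len: int


def schedule_prefills(waiting: List[Request],
                      max_num_batched_tokens: int,
                      max_num_seqs: int) -> Tuple[List[Request], List[Request]]:
    """Pick the next prefill batch; return (scheduled, still_waiting)."""
    scheduled: List[Request] = []
    still_waiting: List[Request] = []
    token_budget = max_num_batched_tokens

    for i, req in enumerate(waiting):
        if req.seq_len > max_num_batched_tokens:
            if not scheduled:
                # Long request and nothing scheduled yet: run it alone,
                # i.e. with prefill batch size 1.
                scheduled.append(req)
                still_waiting.extend(waiting[i + 1:])
                break
            # Prefills already scheduled: skip the long request for now
            # and keep checking the following (shorter) requests.
            still_waiting.append(req)
            continue
        if len(scheduled) >= max_num_seqs or req.seq_len > token_budget:
            # The batch is full: leave this and all later requests waiting.
            still_waiting.extend(waiting[i:])
            break
        scheduled.append(req)
        token_budget -= req.seq_len

    return scheduled, still_waiting

Because running sequences also count against max_num_seqs, this is likely why the third 1024-token prompt in the profile below is prefilled alone: three sequences are already running and max_num_seqs=4.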

This PR was tested with the following example code:

# SPDX-License-Identifier: Apache-2.0
# Gaudi/HPU and vLLM runtime settings used for this experiment
# (lazy mode, skip warmup, bucketing, profiling, recipe cache).
import os
os.environ["PT_HPU_LAZY_MODE"] = "1"
os.environ["VLLM_SKIP_WARMUP"] = "true"
os.environ["VLLM_EXPONENTIAL_BUCKETING"] = "false"
os.environ["VLLM_PROFILER_ENABLED"] = "true"
os.environ["PT_HPU_RECIPE_CACHE_CONFIG"] = "/models/.hpu_cache,false,4096"
os.environ["VLLM_PROMPT_BS_BUCKET_STEP"] = "1"
os.environ["VLLM_DECODE_BS_BUCKET_STEP"] = "1"


from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hi " * (3072-2),
    "Hi " * (1024-2),
    "Hi " * (1024-2),
    "Hi " * (1024-2),
    "Hi " * (512-2),
    "Hi " * (3072-2),
    "Hi " * (512-2),
    "Hi " * (512-2),
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=2)


def main():
    # Create an LLM.
    llm = LLM(
        model="/models/Llama-2-7b-chat-hf",
        max_num_batched_tokens=2048,
        max_model_len=4096,
        max_num_seqs=4
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt:    {prompt!r}")
        print(f"Output:    {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()

The results from the high-level profile:

phase   bucket_batch_size  batch_size  bucket_seq_len  seq_len
prompt  1                  1           3072            3072
prompt  2                  2           1024            1024
prompt  1                  1           1024            1024
decode  4                  4           128             3073
decode  4                  4           128             3074
prompt  3                  3           512             512
decode  3                  3           128             513
decode  3                  3           128             514
prompt  1                  1           3072            3072
decode  1                  1           128             3073
decode  1                  1           128             3074

yangulei commented Jul 2, 2025

@czhu15 @Wei-Lin-Intel
Please help to test and review, thanks!

@Wei-Lin-Intel

LGTM, let me try on Tencent's dataset
