
Support splitting one batch into two micro-batches #4965


Closed
wants to merge 13 commits
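The PR title refers to two-batch overlap (TBO): a batch is split at a sequence index so that the two resulting micro-batches carry a roughly equal number of tokens and can be executed in an overlapped fashion. The sketch below only illustrates that splitting idea; the PR's actual logic lives in `sglang/srt/two_batch_overlap.py` (`compute_split_seq_index`), which is not part of this diff, so the function name and exact behavior here are assumptions.

```python
# Illustration only: one way to choose a split sequence index so that the two
# micro-batches carry a roughly equal number of tokens. The PR's real logic is
# in sglang.srt.two_batch_overlap.compute_split_seq_index (not shown in this diff).
from typing import List, Optional


def compute_split_seq_index_sketch(
    extend_lens: Optional[List[int]], num_tokens: int
) -> Optional[int]:
    if extend_lens is None:
        # Decode: every sequence contributes one token, so split the batch in half.
        return num_tokens // 2 if num_tokens >= 2 else None
    if len(extend_lens) < 2:
        return None  # nothing to split
    # Extend/prefill: pick the boundary whose prefix token count is closest to half.
    total = sum(extend_lens)
    best_index, best_gap = None, None
    prefix = 0
    for i in range(1, len(extend_lens)):
        prefix += extend_lens[i - 1]
        gap = abs(2 * prefix - total)
        if best_gap is None or gap < best_gap:
            best_index, best_gap = i, gap
    return best_index


# Example: requests of 3, 5, and 4 extend tokens split after the second request.
assert compute_split_seq_index_sketch([3, 5, 4], 12) == 2
```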
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -2,6 +2,8 @@

from enum import Enum, auto

from sglang.srt.distributed import get_tensor_model_parallel_rank

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -671,6 +673,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# For DP attention
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
tbo_split_seq_index: Optional[int] = None
can_run_dp_cuda_graph: bool = False

# For processing logprobs
@@ -1465,6 +1468,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
token_ids_logprobs=self.token_ids_logprobs,
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
tbo_split_seq_index=self.tbo_split_seq_index,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
seq_lens_cpu=seq_lens_cpu,
extend_num_tokens=self.extend_num_tokens,
@@ -1542,6 +1546,7 @@ class ModelWorkerBatch:
# For DP attention
global_num_tokens: Optional[List[int]]
global_num_tokens_for_logprob: Optional[List[int]]
tbo_split_seq_index: Optional[int]
can_run_dp_cuda_graph: bool

# For extend
68 changes: 61 additions & 7 deletions python/sglang/srt/managers/scheduler.py
@@ -35,6 +35,7 @@
from torch.distributed import barrier

from sglang.global_config import global_config
from sglang.srt import two_batch_overlap
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
@@ -50,6 +51,7 @@
DisaggregationMode,
ReqToMetadataIdxAllocator,
)
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -1456,14 +1458,38 @@ def process_batch_result(
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
return self.prepare_dp_attn_batch_raw(
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
)

@staticmethod
def prepare_dp_attn_batch_raw(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
tp_cpu_group,
get_idle_batch,
disable_cuda_graph: bool,
enable_two_batch_overlap: bool,
spec_algorithm,
speculative_num_draft_tokens,
):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
global_num_tokens_for_logprob = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
num_tokens = num_tokens * speculative_num_draft_tokens
global_num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
@@ -1482,46 +1508,70 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
else:
can_cuda_graph = 0

if not self.spec_algorithm.is_none():
if not spec_algorithm.is_none():
# TODO(sang): Support cuda graph when idle batch is there.
if local_batch is None or local_batch.forward_mode.is_idle():
can_cuda_graph = 0

is_extend_in_batch = (
local_batch.forward_mode.is_extend() if local_batch else False
)

if local_batch is not None:
local_tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
forward_mode=local_batch.forward_mode,
num_tokens=local_batch.input_ids.shape[0],
extend_lens=local_batch.extend_lens,
)
else:
local_tbo_split_seq_index = None
local_can_run_tbo = local_tbo_split_seq_index is not None

local_info = torch.tensor(
[
num_tokens,
can_cuda_graph,
global_num_tokens_for_logprob,
is_extend_in_batch,
local_can_run_tbo,
local_batch.forward_mode.value if local_batch is not None else -1,
],
dtype=torch.int64,
)
global_info = torch.empty(
(self.server_args.dp_size, self.attn_tp_size, 4),
(dp_size, attn_tp_size, 6),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
global_info.flatten(),
local_info,
group=self.tp_cpu_group,
group=tp_cpu_group,
)
global_num_tokens = global_info[:, 0, 0].tolist()
can_cuda_graph = min(global_info[:, 0, 1].tolist())
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
is_extend_in_batch = global_info[:, 0, 3].tolist()
local_can_run_tbo_aggregated = min(global_info[:, 0, 4].tolist())
forward_mode_same = _is_all_same(global_info[:, 0, 5].tolist())

can_run_tbo = (
enable_two_batch_overlap
and local_can_run_tbo_aggregated
and forward_mode_same
)

if local_batch is None and max(global_num_tokens) > 0:
local_batch = self.get_idle_batch()
local_batch = get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
local_batch.tbo_split_seq_index = (
local_tbo_split_seq_index if can_run_tbo else None
)

# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
if not disable_cuda_graph:
local_batch.can_run_dp_cuda_graph = can_cuda_graph

return local_batch, any(is_extend_in_batch)
@@ -1962,6 +2012,10 @@ def _import_static_state(model, static_params):
self_named_buffers[name][...] = tensor


def _is_all_same(x):
return all(value == x[0] for value in x)


def run_scheduler_process(
server_args: ServerArgs,
port_args: PortArgs,
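The new `prepare_dp_attn_batch_raw` above packs six per-rank values (token count, CUDA-graph eligibility, logprob token count, extend flag, TBO eligibility, forward mode) into `local_info`, all-gathers them across the attention-DP ranks, and enables the split only when every rank can split and all ranks share the same forward mode. Below is a minimal sketch of just that consensus step; it assumes a CPU (gloo) process group whose world size equals `dp_size`, whereas the PR gathers a `(dp_size, attn_tp_size, 6)` tensor to also cover the attention-TP ranks.

```python
# Sketch of the rank-consensus step (assumption: `group` is a gloo/CPU process
# group with exactly dp_size ranks; the PR uses a (dp_size, attn_tp_size, 6) layout).
import torch
import torch.distributed as dist


def agree_on_tbo(local_can_run_tbo: bool, forward_mode_value: int,
                 dp_size: int, group) -> bool:
    local_info = torch.tensor(
        [int(local_can_run_tbo), forward_mode_value], dtype=torch.int64
    )
    global_info = torch.empty((dp_size, 2), dtype=torch.int64)
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)
    every_rank_can_split = bool(global_info[:, 0].min().item() == 1)
    same_forward_mode = bool((global_info[:, 1] == global_info[0, 1]).all().item())
    # Two-batch overlap runs only when both conditions hold on every rank.
    return every_rank_can_split and same_forward_mode
```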
163 changes: 161 additions & 2 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -29,15 +29,20 @@

from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.srt import two_batch_overlap
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend

@@ -210,6 +215,7 @@ class ForwardBatch:
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
gathered_buffer: Optional[torch.Tensor] = None
tbo_split_seq_index: Optional[int] = None
can_run_dp_cuda_graph: bool = False

# Speculative decoding
@@ -223,6 +229,9 @@
# For Qwen2-VL
mrope_positions: torch.Tensor = None

tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List["ForwardBatch"]] = None

@classmethod
def init_new(
cls,
@@ -251,6 +260,7 @@ def init_new(
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
tbo_split_seq_index=batch.tbo_split_seq_index,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
@@ -332,6 +342,8 @@ def init_new(
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)

ret.prepare_tbo()

return ret

def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
@@ -443,6 +455,153 @@ def _compute_mrope_positions(
)
self.mrope_positions = self.mrope_positions.to(torch.int64)

def prepare_tbo(self):
if self.tbo_split_seq_index is None:
return

tbo_split_token_index = two_batch_overlap.compute_split_token_index(
split_seq_index=self.tbo_split_seq_index,
forward_mode=self.forward_mode,
extend_seq_lens=self.extend_seq_lens,
)

from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

assert isinstance(self.attn_backend, TboAttnBackend)
attn_backend_child_a, attn_backend_child_b = self.attn_backend.children

child_a = self.filter_batch(
start_token_index=0,
end_token_index=tbo_split_token_index,
start_seq_index=0,
end_seq_index=self.tbo_split_seq_index,
output_attn_backend=attn_backend_child_a,
)
child_b = self.filter_batch(
start_token_index=tbo_split_token_index,
end_token_index=self.input_ids.shape[0],
start_seq_index=self.tbo_split_seq_index,
end_seq_index=self.batch_size,
output_attn_backend=attn_backend_child_b,
)

assert self.tbo_children is None
self.tbo_children = [child_a, child_b]

def filter_batch(
self,
*,
start_token_index: int,
end_token_index: int,
start_seq_index: int,
end_seq_index: int,
output_attn_backend: AttentionBackend,
):
num_tokens = self.input_ids.shape[0]
num_seqs = self.batch_size

output_dict = dict()

for key in [
"input_ids",
"positions",
"out_cache_loc",
]:
old_value = getattr(self, key)
assert (
old_value.shape[0] == num_tokens
), f"{key=} {old_value=} {num_tokens=} {self=}"
output_dict[key] = old_value[start_token_index:end_token_index]

for key in [
"req_pool_indices",
"seq_lens",
"seq_lens_cpu",
"extend_seq_lens",
"extend_prefix_lens",
"extend_start_loc",
"extend_prefix_lens_cpu",
"extend_seq_lens_cpu",
"extend_logprob_start_lens_cpu",
"lora_paths",
]:
old_value = getattr(self, key)
if old_value is None:
continue
assert (
len(old_value) == num_seqs
), f"{key=} {old_value=} {num_seqs=} {self=}"
output_dict[key] = old_value[start_seq_index:end_seq_index]

for key in [
"forward_mode",
"return_logprob",
"req_to_token_pool",
"token_to_kv_pool",
"can_run_dp_cuda_graph",
"spec_info",
"spec_algorithm",
"capture_hidden_mode",
"padded_static_len",
"mrope_positions", # only used by qwen2-vl, thus not care
]:
output_dict[key] = getattr(self, key)

assert (
_compute_extend_num_tokens(self.input_ids, self.forward_mode)
== self.extend_num_tokens
), f"{self=}"
extend_num_tokens = _compute_extend_num_tokens(
output_dict["input_ids"], output_dict["forward_mode"]
)

output_dict.update(
dict(
batch_size=end_seq_index - start_seq_index,
seq_lens_sum=output_dict["seq_lens"].sum().item(),
extend_num_tokens=extend_num_tokens,
attn_backend=output_attn_backend,
tbo_split_seq_index=None,
tbo_parent_token_range=(start_token_index, end_token_index),
tbo_children=None,
global_num_tokens_gpu=None,
global_num_tokens_cpu=None,
gathered_buffer=None,
global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None,
sampling_info=None,
# For logits and logprobs post processing, thus we do not care
temp_scaled_logprobs=False,
temperature=None,
top_p_normalized_logprobs=False,
top_p=None,
mm_inputs=None,
)
)

errors = []
for field in dataclasses.fields(ForwardBatch):
if getattr(self, field.name) is not None and field.name not in output_dict:
errors.append(
f"Field {field.name} has value, but is not yet supported (value={getattr(self, field.name)} self={self})"
)
if len(errors) > 0:
raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))

return ForwardBatch(**output_dict)

@property
def can_run_tbo(self):
return self.tbo_split_seq_index is not None


def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
if forward_mode.is_extend():
return input_ids.shape[0]
elif forward_mode.is_decode():
return None
raise NotImplementedError


def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
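`ForwardBatch.filter_batch` above splits a batch by applying the token range to per-token tensors (`input_ids`, `positions`, `out_cache_loc`), the sequence range to per-sequence fields (`req_pool_indices`, `seq_lens`, `extend_seq_lens`, ...), and passing shared scalars through unchanged. The standalone sketch below reproduces only that slicing rule on a plain dict so it can be run in isolation; the field names are a small subset of the real dataclass.

```python
# Standalone illustration of the slicing rule used by ForwardBatch.filter_batch:
# per-token tensors are cut on the token range, per-sequence fields on the
# sequence range, and derived scalars are recomputed for the child batch.
import torch


def split_batch_fields(batch: dict, token_range: tuple, seq_range: tuple) -> dict:
    t0, t1 = token_range
    s0, s1 = seq_range
    child = {}
    for key in ("input_ids", "positions", "out_cache_loc"):  # per-token fields
        child[key] = batch[key][t0:t1]
    for key in ("req_pool_indices", "seq_lens", "extend_seq_lens"):  # per-sequence fields
        child[key] = batch[key][s0:s1]
    child["batch_size"] = s1 - s0
    child["seq_lens_sum"] = int(child["seq_lens"].sum())
    return child


# Usage: a 2-request extend batch with 3 + 2 tokens, split at seq index 1 / token index 3.
batch = dict(
    input_ids=torch.arange(5), positions=torch.arange(5), out_cache_loc=torch.arange(5),
    req_pool_indices=torch.tensor([7, 9]), seq_lens=torch.tensor([3, 2]),
    extend_seq_lens=torch.tensor([3, 2]),
)
child_a = split_batch_fields(batch, (0, 3), (0, 1))  # first request only
child_b = split_batch_fields(batch, (3, 5), (1, 2))  # second request only
```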