Commit 9ddc9fe

fzyzcjy authored and DiweiSun committed
Let bench_one_batch support enable_dp_attention (sgl-project#4058)
1 parent 314884e commit 9ddc9fe
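With this change, the standalone benchmark exercises the same DP-attention batch preparation path as the server. Assuming bench_one_batch keeps taking its configuration from the standard ServerArgs CLI flags (which is how the script is normally driven), a DP-attention run would look something like `python -m sglang.bench_one_batch --model-path <model> --tp-size 2 --dp-size 2 --enable-dp-attention`, with the model path left to the user.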

2 files changed: +49 −7


python/sglang/bench_one_batch.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -60,6 +60,7 @@
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -184,6 +185,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return input_ids, reqs
@@ -199,6 +201,7 @@ def prepare_extend_inputs_for_correctness_test(
             i, : bench_args.cut_len
         ]
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
     return reqs
 
 
@@ -220,6 +223,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return reqs
@@ -238,6 +242,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,13 +254,28 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
 
 
+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
+
+
 def correctness_test(
     server_args,
     port_args,
```
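The benchmark can call `Scheduler.prepare_dp_attn_batch_raw` without constructing a `Scheduler` because this commit (see the scheduler.py diff below) moves the logic into a `@staticmethod` that takes the former `self.*` state as explicit parameters. A minimal sketch of that refactoring pattern, with illustrative names rather than the real Scheduler API:

```python
# Sketch only: `Worker`, `prepare`, and `prepare_raw` are illustrative names,
# not the real sglang Scheduler API.
class Worker:
    def __init__(self, dp_size: int):
        self.dp_size = dp_size

    def prepare(self, batch: list) -> list:
        # Thin wrapper: forwards instance state into the static implementation.
        return Worker.prepare_raw(batch, dp_size=self.dp_size)

    @staticmethod
    def prepare_raw(batch: list, dp_size: int) -> list:
        # All state arrives via parameters, so callers that hold no Worker
        # instance (here: the benchmark script) can reuse the same logic.
        return batch[:dp_size]


# A caller without a Worker instance, as in bench_one_batch:
trimmed = Worker.prepare_raw([1, 2, 3, 4], dp_size=2)
```

The instance method keeps its old signature, so existing scheduler call sites are untouched while the raw entry point becomes reusable.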

python/sglang/srt/managers/scheduler.py

Lines changed: 29 additions & 7 deletions
```diff
@@ -1466,14 +1466,36 @@ def process_batch_result(
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
 
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_dp_attn_batch_raw(
+            local_batch,
+            dp_size=self.server_args.dp_size,
+            attn_tp_size=self.attn_tp_size,
+            tp_cpu_group=self.tp_cpu_group,
+            get_idle_batch=self.get_idle_batch,
+            disable_cuda_graph=self.server_args.disable_cuda_graph,
+            spec_algorithm=self.spec_algorithm,
+            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+        )
+
+    @staticmethod
+    def prepare_dp_attn_batch_raw(
+        local_batch: ScheduleBatch,
+        dp_size,
+        attn_tp_size: int,
+        tp_cpu_group,
+        get_idle_batch,
+        disable_cuda_graph: bool,
+        spec_algorithm,
+        speculative_num_draft_tokens,
+    ):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
             global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
-                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+                num_tokens = num_tokens * speculative_num_draft_tokens
             global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1492,7 +1514,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
         else:
             can_cuda_graph = 0
 
-        if not self.spec_algorithm.is_none():
+        if not spec_algorithm.is_none():
             # TODO(sang): Support cuda graph when idle batch is there.
             if local_batch is None or local_batch.forward_mode.is_idle():
                 can_cuda_graph = 0
@@ -1510,28 +1532,28 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 4),
+            (dp_size, attn_tp_size, 4),
             dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=self.tp_cpu_group,
+            group=tp_cpu_group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
         global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
         is_extend_in_batch = global_info[:, 0, 3].tolist()
 
         if local_batch is None and max(global_num_tokens) > 0:
-            local_batch = self.get_idle_batch()
+            local_batch = get_idle_batch()
 
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens
             local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
 
             # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
+            if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
         return local_batch, any(is_extend_in_batch)
```
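The core of `prepare_dp_attn_batch_raw` is the metadata all-gather: every DP rank publishes its local token count (plus CUDA-graph, logprob, and extend-mode flags) so that all ranks, including idle ones, agree on the global batch shape before the forward pass. A minimal sketch of that synchronization step, assuming an already-initialized `torch.distributed` process group whose world size equals `dp_size`:

```python
import torch
import torch.distributed as dist


def gather_num_tokens(num_tokens: int, dp_size: int, group=None) -> list:
    # Each rank contributes one int64 value; all_gather_into_tensor fills
    # `global_info` with every rank's value in rank order.
    local_info = torch.tensor([num_tokens], dtype=torch.int64)
    global_info = torch.empty((dp_size,), dtype=torch.int64)
    dist.all_gather_into_tensor(global_info, local_info, group=group)
    # A rank with zero local tokens can now see max(counts) > 0 and build an
    # idle batch so it still participates in the collective forward pass.
    return global_info.tolist()
```

In the actual method the gathered tensor has shape `(dp_size, attn_tp_size, 4)`, so a single collective also synchronizes the CUDA-graph flag, the logprob token count, and the extend-mode bit, and it runs over the CPU group (`tp_cpu_group`).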
