Skip to content

Commit c27f2a4

Browse files
njhill authored and Mu Huai committed
[MLA] Simplification to batch P/D reordering (vllm-project#16673)
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent 3d87be7 commit c27f2a4

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -415,20 +415,18 @@ def reorder_batch(self, input_batch: "InputBatch",
415415
# the above loop
416416
num_decodes = len(decodes)
417417
num_prefills = len(prefills)
418-
first_prefill = 0
419418
modified_batch = False
420419

421420
for i in range(1, min(num_decodes, num_prefills) + 1):
422421
# If the decode is at the "back" of the batch, i, we can swap it
423422
# with the prefill closest to the front of the batch
424-
if decodes[num_decodes - i] >= num_decodes:
425-
input_batch.swap_states(prefills[first_prefill],
426-
decodes[num_decodes - i])
427-
first_prefill += 1
428-
modified_batch = True
429-
else:
423+
decode_idx = decodes[num_decodes - i]
424+
if decode_idx < num_decodes:
430425
break
431426

427+
input_batch.swap_states(prefills[i - 1], decode_idx)
428+
modified_batch = True
429+
432430
# Save for next `build` call
433431
# TODO(lucas): this is a bit of a hack, we should probably have a
434432
# better way of doing this

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
458458
if removed_req_indices:
459459
self.input_batch.condense(removed_req_indices)
460460

461-
if batch_changed:
461+
# Some attention backends (namely MLA) may want to separate requests
462+
# based on if the attention computation will be compute-bound or
463+
# memory-bound. This gives them a hook to do that.
464+
batch_reordered = self.attn_metadata_builder.reorder_batch(
465+
self.input_batch, scheduler_output)
466+
467+
if batch_changed or batch_reordered:
462468
self.input_batch.refresh_sampling_metadata()
463469

464470
def _prepare_inputs(
@@ -471,14 +477,6 @@ def _prepare_inputs(
471477
num_reqs = self.input_batch.num_reqs
472478
assert num_reqs > 0
473479

474-
# Some attention backends (namely MLA) may want to separate requests
475-
# based on if the attention computation will be compute-bound or
476-
# memory-bound. This gives them a hook to do that.
477-
modified_batch = self.attn_metadata_builder.reorder_batch(
478-
self.input_batch, scheduler_output)
479-
if modified_batch:
480-
self.input_batch.refresh_sampling_metadata()
481-
482480
# OPTIMIZATION: Start copying the block table first.
483481
# This way, we can overlap the copy with the following CPU operations.
484482
self.input_batch.block_table.commit(num_reqs)

0 commit comments

Comments (0)