Skip to content

[MLA] Simplification to batch P/D reordering #16673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ def __init__(self, runner: "GPUModelRunner"):
self.runner = runner

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False
scheduler_output: "SchedulerOutput") -> None:
pass

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
Expand Down
16 changes: 5 additions & 11 deletions vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def __init__(self,
self.page_size = self.runner.block_size

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
scheduler_output: "SchedulerOutput") -> None:
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
Expand Down Expand Up @@ -415,20 +415,16 @@ def reorder_batch(self, input_batch: "InputBatch",
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
first_prefill = 0
modified_batch = False

for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill],
decodes[num_decodes - i])
first_prefill += 1
modified_batch = True
else:
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break

input_batch.swap_states(prefills[i - 1], decode_idx)

# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
Expand All @@ -437,8 +433,6 @@ def reorder_batch(self, input_batch: "InputBatch",
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens

return modified_batch

def _build_decode(self, input_positions: torch.Tensor,
block_table: torch.Tensor, seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
Expand Down
14 changes: 6 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
self.input_batch.condense(removed_req_indices)

if batch_changed:
# Some attention backends (namely MLA) may want to separate
# requests based on if the attention computation will be
# compute-bound or memory-bound. This gives them a hook to do that.
self.attn_metadata_builder.reorder_batch(self.input_batch,
scheduler_output)

self.input_batch.refresh_sampling_metadata()

def _prepare_inputs(
Expand All @@ -471,14 +477,6 @@ def _prepare_inputs(
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0

# Some attention backends (namely MLA) may want to separate requests
# based on if the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
modified_batch = self.attn_metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
if modified_batch:
self.input_batch.refresh_sampling_metadata()

# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit(num_reqs)
Expand Down