From 09ec08d42d96ceb5d578e6111da4dad73d3aac56 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 15 Apr 2025 09:21:30 -0700
Subject: [PATCH 1/2] [MLA] Simplification to batch P/D reordering

I noticed that we're unnecessarily re-creating the sampling metadata twice
when reordering the batch requests into prefill and decode groups for MLA.

This moves the reorder op from the start of the _prepare_inputs() method to
the end of the _update_states() method (which is called right before it).

Signed-off-by: Nick Hill
---
 vllm/v1/attention/backends/flash_attn.py |  4 ++--
 vllm/v1/attention/backends/mla/common.py | 16 +++++-----------
 vllm/v1/worker/gpu_model_runner.py       | 14 ++++++--------
 3 files changed, 13 insertions(+), 21 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index b4c7708daab..9ac10dbb4af 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -280,8 +280,8 @@ def __init__(self, runner: "GPUModelRunner"):
         self.runner = runner
 
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
+                      scheduler_output: "SchedulerOutput") -> None:
+        pass
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 8c7179ba0a8..acaa058cfa0 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -378,7 +378,7 @@ def __init__(self,
             self.page_size = self.runner.block_size
 
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+                      scheduler_output: "SchedulerOutput") -> None:
         # We now want to reorder the batch so that the "decode" requests are at
         # the front and the "prefill" requests are at the back using the least
         # amount of swaps possible. (NOTE for now we loosely use "decode" to mean requests
@@ -415,20 +415,16 @@ def reorder_batch(self, input_batch: "InputBatch",
         # the above loop
         num_decodes = len(decodes)
         num_prefills = len(prefills)
-        first_prefill = 0
-        modified_batch = False
 
         for i in range(1, min(num_decodes, num_prefills) + 1):
             # If the decode is at the "back" of the batch, i, we can swap it
             # with the prefill closest to the front of the batch
-            if decodes[num_decodes - i] >= num_decodes:
-                input_batch.swap_states(prefills[first_prefill],
-                                        decodes[num_decodes - i])
-                first_prefill += 1
-                modified_batch = True
-            else:
+            decode_idx = decodes[num_decodes - i]
+            if decode_idx < num_decodes:
                 break
 
+            input_batch.swap_states(prefills[i - 1], decode_idx)
+
         # Save for next `build` call
         # TODO(lucas): this is a bit of a hack, we should probably have a
         # better way of doing this
@@ -437,8 +433,6 @@ def reorder_batch(self, input_batch: "InputBatch",
         self._num_decode_tokens = num_decode_tokens
         self._num_prefill_tokens = num_prefill_tokens
 
-        return modified_batch
-
     def _build_decode(self, input_positions: torch.Tensor,
                       block_table: torch.Tensor, seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c3d84ab3773..60c5666b3d1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -459,6 +459,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.input_batch.condense(removed_req_indices)
 
         if batch_changed:
+            # Some attention backends (namely MLA) may want to separate
+            # requests based on if the attention computation will be
+            # compute-bound or memory-bound. This gives them a hook to do that.
+            self.attn_metadata_builder.reorder_batch(self.input_batch,
+                                                     scheduler_output)
+
             self.input_batch.refresh_sampling_metadata()
 
     def _prepare_inputs(
@@ -471,14 +477,6 @@ def _prepare_inputs(
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
-        # Some attention backends (namely MLA) may want to separate requests
-        # based on if the attention computation will be compute-bound or
-        # memory-bound. This gives them a hook to do that.
-        modified_batch = self.attn_metadata_builder.reorder_batch(
-            self.input_batch, scheduler_output)
-        if modified_batch:
-            self.input_batch.refresh_sampling_metadata()
-
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit(num_reqs)
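
Note on the simplified loop above: requests scheduled for a single token are treated as "decodes" and are moved to the front of the batch with the fewest possible swaps, by pairing the i-th decode counted from the back of the batch with the i-th prefill counted from the front. The standalone Python sketch below is only an illustration of that idea, not the vLLM implementation: swap_states is stood in for by a plain swap callback, and the single-token test for "decode" is an assumption made for the example.

    # Minimal sketch of the reordering idea from the patch above (illustrative
    # only, not the vLLM code). A request counts as "decode" here when it is
    # scheduled for a single token; anything else counts as "prefill".

    def reorder_for_mla(requests, swap):
        """Move decode requests to the front using the fewest swaps.

        `requests` is a list of per-request scheduled token counts; `swap(i, j)`
        exchanges the state of two batch slots (playing the role of
        InputBatch.swap_states). Returns True if any slots were swapped.
        """
        decodes = [i for i, n in enumerate(requests) if n == 1]
        prefills = [i for i, n in enumerate(requests) if n > 1]
        num_decodes = len(decodes)

        modified = False
        for i in range(1, min(num_decodes, len(prefills)) + 1):
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                # Every remaining decode already sits in the first num_decodes
                # slots, so nothing else needs to move.
                break
            # The i-th decode from the back trades places with the i-th prefill
            # from the front, so each out-of-place request moves exactly once.
            swap(prefills[i - 1], decode_idx)
            modified = True
        return modified


    if __name__ == "__main__":
        # Example: slots 0 and 3 are prefills (many tokens); 1, 2 and 4 decode.
        batch = [17, 1, 1, 9, 1]

        def swap(i, j):
            batch[i], batch[j] = batch[j], batch[i]

        print(reorder_for_mla(batch, swap))  # True
        print(batch)  # decodes now occupy the first three slots

The example also shows why the separate first_prefill counter could be dropped: by construction, the i-th swap always consumes prefills[i - 1].
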
From db9f9a67ae5eaa94fe944b35f0905c73033f6557 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 15 Apr 2025 10:14:00 -0700
Subject: [PATCH 2/2] Always call reorder

Signed-off-by: Nick Hill
---
 vllm/v1/attention/backends/flash_attn.py |  4 ++--
 vllm/v1/attention/backends/mla/common.py |  6 +++++-
 vllm/v1/worker/gpu_model_runner.py       | 12 ++++++------
 3 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 9ac10dbb4af..b4c7708daab 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -280,8 +280,8 @@ def __init__(self, runner: "GPUModelRunner"):
         self.runner = runner
 
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> None:
-        pass
+                      scheduler_output: "SchedulerOutput") -> bool:
+        return False
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index acaa058cfa0..b77e9525219 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -378,7 +378,7 @@ def __init__(self,
             self.page_size = self.runner.block_size
 
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> None:
+                      scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
         # the front and the "prefill" requests are at the back using the least
         # amount of swaps possible. (NOTE for now we loosely use "decode" to mean requests
@@ -415,6 +415,7 @@ def reorder_batch(self, input_batch: "InputBatch",
         # the above loop
         num_decodes = len(decodes)
         num_prefills = len(prefills)
+        modified_batch = False
 
         for i in range(1, min(num_decodes, num_prefills) + 1):
             # If the decode is at the "back" of the batch, i, we can swap it
@@ -424,6 +425,7 @@ def reorder_batch(self, input_batch: "InputBatch",
                 break
 
             input_batch.swap_states(prefills[i - 1], decode_idx)
+            modified_batch = True
 
         # Save for next `build` call
         # TODO(lucas): this is a bit of a hack, we should probably have a
@@ -433,6 +435,8 @@ def reorder_batch(self, input_batch: "InputBatch",
         self._num_decode_tokens = num_decode_tokens
         self._num_prefill_tokens = num_prefill_tokens
 
+        return modified_batch
+
     def _build_decode(self, input_positions: torch.Tensor,
                       block_table: torch.Tensor, seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 60c5666b3d1..f3f87213258 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -458,13 +458,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
 
-        if batch_changed:
-            # Some attention backends (namely MLA) may want to separate
-            # requests based on if the attention computation will be
-            # compute-bound or memory-bound. This gives them a hook to do that.
-            self.attn_metadata_builder.reorder_batch(self.input_batch,
-                                                     scheduler_output)
+        # Some attention backends (namely MLA) may want to separate requests
+        # based on if the attention computation will be compute-bound or
+        # memory-bound. This gives them a hook to do that.
+        batch_reordered = self.attn_metadata_builder.reorder_batch(
+            self.input_batch, scheduler_output)
 
+        if batch_changed or batch_reordered:
             self.input_batch.refresh_sampling_metadata()
 
     def _prepare_inputs(
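
Taken together with PATCH 1/2, the runner now always gives the attention backend a chance to reorder the batch, and the sampling metadata is refreshed at most once per step. The self-contained sketch below illustrates only that gating, not the actual GPUModelRunner code: the stub classes and the update_states helper are invented for the demo, while reorder_batch, refresh_sampling_metadata and the batch_changed/batch_reordered names come from the diffs.

    # Illustrative-only sketch of the gating introduced by this patch series.

    class StubInputBatch:
        def __init__(self):
            self.refresh_count = 0

        def refresh_sampling_metadata(self):
            self.refresh_count += 1


    class StubBuilder:
        def __init__(self, will_reorder: bool):
            self.will_reorder = will_reorder

        def reorder_batch(self, input_batch, scheduler_output) -> bool:
            # Pretend we either swapped some requests or left the batch alone.
            return self.will_reorder


    def update_states(input_batch, builder, scheduler_output, batch_changed: bool):
        # The reorder hook is now attempted unconditionally and reports whether
        # it actually moved anything ...
        batch_reordered = builder.reorder_batch(input_batch, scheduler_output)
        # ... but the sampling metadata is rebuilt at most once per step, and
        # only when the scheduler or the reorder really changed the batch.
        if batch_changed or batch_reordered:
            input_batch.refresh_sampling_metadata()


    if __name__ == "__main__":
        batch = StubInputBatch()
        update_states(batch, StubBuilder(will_reorder=True), None, batch_changed=True)
        update_states(batch, StubBuilder(will_reorder=False), None, batch_changed=False)
        print(batch.refresh_count)  # 1 -- only the first step needed a refresh

Compared to the original flow, refresh_sampling_metadata() can no longer run twice in one step, because the reorder result is folded into the same condition that already guarded the scheduler-driven refresh.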