
[Core][V1][TPU] Enable structured decoding on TPU V1 #16499

Merged — 11 commits, Apr 23, 2025
2 changes: 2 additions & 0 deletions .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -41,6 +41,8 @@ docker run --privileged --net host --shm-size=16G -it \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
&& echo TEST_9 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
&& echo TEST_10 \
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \


# TODO: This test fails because it uses RANDOM_SEED sampling
4 changes: 2 additions & 2 deletions vllm/platforms/tpu.py
@@ -146,9 +146,9 @@ def validate_request(
) -> None:
"""Raises if this request is unsupported on this platform"""
if isinstance(params, SamplingParams):
if params.guided_decoding is not None:
if params.guided_decoding is not None and not envs.VLLM_USE_V1:
raise ValueError("Structured output is not supported on "
f"{cls.device_name}.")
f"{cls.device_name} V0.")
if params.sampling_type == SamplingType.RANDOM_SEED:
raise ValueError(
"Torch XLA does not support per-request seed.")
154 changes: 133 additions & 21 deletions vllm/v1/worker/tpu_model_runner.py
@@ -688,9 +688,17 @@ def execute_model(
)
hidden_states = self.select_hidden_states(hidden_states,
logits_indices)
logits = self.compute_logits(hidden_states)
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, padded_num_reqs, self.device)
selected_token_ids = self.sample_from_hidden(hidden_states,
if scheduler_output.grammar_bitmask is not None:
require_struct_decoding, grammar_bitmask_padded = \
self.prepare_structured_decoding_input(logits, scheduler_output)
logits = self.structured_decode(
require_struct_decoding.to(logits.device),
grammar_bitmask_padded.to(logits.device), logits,
torch.arange(0, 32).to(logits.device))
selected_token_ids = self.sample_from_logits(logits,
tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
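For readers following the new control flow: the hunk above first computes logits, then (only when the scheduler shipped a grammar bitmask) masks them via structured_decode, and finally samples. A toy, self-contained sketch of that ordering, with dummy tensors in place of real model state and a greedy pick standing in for sample_from_logits:

from typing import Optional

import torch


def decode_step(hidden_states: torch.Tensor, lm_head: torch.Tensor,
                disallowed_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # 1) hidden states -> vocabulary logits (compute_logits)
    logits = hidden_states @ lm_head              # [num_reqs, vocab_size]
    # 2) grammar masking, applied only when structured output was requested;
    #    the mask here is assumed to be already unpacked to [num_reqs, vocab_size]
    if disallowed_mask is not None:
        logits = logits.masked_fill(disallowed_mask, -float("inf"))
    # 3) greedy pick stands in for sample_from_logits
    return torch.argmax(logits, dim=-1, keepdim=True)


# toy usage: 2 requests, hidden size 4, vocabulary of 6 tokens
hs = torch.randn(2, 4)
w = torch.randn(4, 6)
mask = torch.zeros(2, 6, dtype=torch.bool)
mask[0, :3] = True    # request 0 may only emit tokens 3..5
print(decode_step(hs, w, mask))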
@@ -862,7 +870,7 @@ def _precompile_backbone(self) -> None:
self._dummy_run(num_tokens)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("model backbone")

def _precompile_select_hidden_states(self) -> None:
@@ -886,19 +894,64 @@ def _precompile_select_hidden_states(self) -> None:
logger.info(" -- num_tokens: %d", num_tokens)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("select_hidden_states")

def _precompile_sample_from_hidden(self) -> None:
logger.info("Compiling sampling with different input shapes.")
def _precompile_compute_logits(self) -> None:
logger.info(
"Compiling compute_logits with different input shapes.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
for num_reqs in self.num_reqs_paddings:
dummy_hidden = torch.zeros((num_reqs, hsize),
device=self.device,
dtype=self._hidden_states_dtype)
# The first dimension of dummy_hidden cannot be mark_dynamic because
# some operations in the sampler require it to be static.
torch._dynamo.mark_dynamic(dummy_hidden, 0)
self.compute_logits(dummy_hidden)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("compute_logits")

def _precompile_structured_decoding(self) -> None:
logger.info(
"Compiling structured_decoding with different input shapes.")
Comment on lines +1072 to +1073

russellb (Member) commented on Apr 16, 2025:
Do you want to keep this and other logs in this method? They seem more like debug logs to me, so perhaps logger.debug if you want to keep them?

Chenyaaang (Contributor, Author) replied on Apr 16, 2025:
Actually I'm fine with either, but since _get_token_paddings also uses logger.info, I'd like to keep it as it is. In my understanding, it logs the preparation steps.

start = time.perf_counter()
vocab_size = self.model_config.get_vocab_size()
for num_reqs in self.num_reqs_paddings:
dummy_logits = torch.zeros((num_reqs, vocab_size),
device=self.device,
dtype=self._hidden_states_dtype)
dummy_require_struct_decoding = torch.zeros(
num_reqs, dtype=torch.bool, device=self.device).unsqueeze(1)
dummy_grammar_bitmask = torch.zeros(
(num_reqs, cdiv(vocab_size, 32)),
dtype=torch.int32,
device=self.device)
# The first dimension of the above 3 dummy tensors cannot be
# mark_dynamic because some operations in structured_decode require
# them to be static.
arange = torch.arange(0, 32).to(self.device)
self.structured_decode(dummy_require_struct_decoding,
dummy_grammar_bitmask, dummy_logits, arange)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("structured_decoding")

def _precompile_sample_from_logits(self) -> None:
logger.info(
"Compiling sample_from_logits with different input shapes.")
start = time.perf_counter()
vocab_size = self.model_config.get_vocab_size()
for num_reqs in self.num_reqs_paddings:
dummy_logits = torch.zeros((num_reqs, vocab_size),
device=self.device,
dtype=self._hidden_states_dtype)
# The first dimension of dummy_logits cannot be mark_dynamic
# because some operations in the sampler require it to be static.
for all_greedy in [False, True]:
generate_params_if_all_greedy = not all_greedy
sampling_metadata = (
Expand All @@ -909,12 +962,12 @@ def _precompile_sample_from_hidden(self) -> None:
generate_params_if_all_greedy,
))
sampling_metadata.all_greedy = all_greedy
self.sample_from_hidden(dummy_hidden, sampling_metadata)
self.sample_from_logits(dummy_logits, sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("sampling")
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("sample_from_logits")

def capture_model(self) -> None:
"""
@@ -923,7 +976,9 @@ def capture_model(self) -> None:
# TODO: precompile encoder
self._precompile_backbone()
self._precompile_select_hidden_states()
self._precompile_sample_from_hidden()
self._precompile_compute_logits()
self._precompile_structured_decoding()
self._precompile_sample_from_logits()

def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
@@ -980,29 +1035,86 @@ def select_hidden_states(self, hidden_states, indices_do_sample):
return hidden_states[indices_do_sample]

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_hidden(
self,
sample_hidden_states: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
logits = self.model.compute_logits(sample_hidden_states, None)
def compute_logits(self,
sample_hidden_states: torch.Tensor) -> torch.Tensor:
return self.model.compute_logits(sample_hidden_states, None)

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_logits(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
if sampling_metadata.all_greedy:
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
else:
out_tokens = self.sampler(logits,
sampling_metadata).sampled_token_ids
return out_tokens

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def structured_decode(self, require_struct_decoding: torch.Tensor,
grammar_bitmask: torch.Tensor, logits: torch.Tensor,
arange: torch.Tensor) -> torch.Tensor:
return torch.where(
require_struct_decoding,
self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
logits)

def apply_grammar_bitmask(self, logits: torch.Tensor,
grammar_bitmask: torch.Tensor,
arange: torch.Tensor):
assert (logits.shape[0] == grammar_bitmask.shape[0])
vocab_size = logits.shape[1]
logits_cloned = logits.clone()
for i in range(logits.shape[0]):
unpacked_bitmask = (torch.bitwise_right_shift(
grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0
unpacked_bitmask = unpacked_bitmask.reshape(-1)[:vocab_size]
logits_cloned[i] = logits_cloned[i].masked_fill(
unpacked_bitmask, -float("inf"))
return logits_cloned
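To make the packing convention concrete: per the code above, bit j of word w is 1 when token w*32+j is allowed, and a zero bit masks the token to -inf. A tiny self-contained example with an 8-token vocabulary and one packed int32 word (values chosen purely for illustration):

import torch

vocab_size = 8
arange = torch.arange(0, 32)

# 0b0000_1010 == 10 -> only tokens 1 and 3 are allowed
packed = torch.tensor([10], dtype=torch.int32)    # shape [cdiv(vocab_size, 32)] == [1]

disallowed = (torch.bitwise_right_shift(packed[:, None], arange[None, :]) & 1) == 0
disallowed = disallowed.reshape(-1)[:vocab_size]  # per-token bool mask

logits = torch.zeros(vocab_size)
masked = logits.masked_fill(disallowed, -float("inf"))
print(masked)  # tokens 1 and 3 keep 0.0, every other entry becomes -inf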

def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)

def get_input_embeddings(self, *args, **kwargs):
return self.model.get_input_embeddings(*args, **kwargs)

def prepare_structured_decoding_input(
self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> tuple[torch.Tensor, torch.Tensor]:
grammar_bitmask = scheduler_output.grammar_bitmask
assert grammar_bitmask is not None
num_reqs, vocab_size = logits.shape

# Pad grammar bitmask so that it won't cause recompilation.
grammar_bitmask_padded = torch.zeros((num_reqs, cdiv(vocab_size, 32)),
dtype=torch.int32,
device="cpu")

# We receive the structured output bitmask from the scheduler, but the
# indices of the requests in the batch may not match the indices of
# the bitmask since the scheduler doesn't know how the tpu runner is
# ordering the requests in the batch. We need to match the order of
# the bitmask with the order of the requests.
struct_out_indices: list[int] = []
mask_indices: list[int] = []
for req_id in self.input_batch.req_ids:
mask_index = scheduler_output.structured_output_request_ids.get(
req_id)
if mask_index is None:
continue
batch_index = self.input_batch.req_id_to_index[req_id]
struct_out_indices.append(batch_index)
mask_indices.append(mask_index)
grammar_bitmask_padded[struct_out_indices] = torch.from_numpy(
grammar_bitmask[mask_indices])
# It's not guaranteed that all requests in this batch require
# structured output, so create a bool tensor to represent
# the requests that need structured output.
require_struct_out = torch.zeros(num_reqs, dtype=torch.bool)
struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
require_struct_out[struct_out_indices] = True
return require_struct_out.unsqueeze(1), grammar_bitmask_padded
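A small illustration of the index matching above, with hypothetical request ids and a toy two-row bitmask (none of this is real scheduler output; it only shows how batch order and bitmask order get reconciled):

import numpy as np
import torch

# scheduler side: bitmask rows keyed by request id ("req-a" has no grammar)
structured_output_request_ids = {"req-b": 0, "req-c": 1}
grammar_bitmask = np.array([[0b01], [0b10]], dtype=np.int32)  # one word per row

# runner side: how this TPU batch happens to be ordered
req_ids = ["req-a", "req-b", "req-c"]
req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}

num_reqs, num_words = 3, 1
padded = torch.zeros((num_reqs, num_words), dtype=torch.int32)

struct_out_indices: list[int] = []
mask_indices: list[int] = []
for req_id in req_ids:
    mask_index = structured_output_request_ids.get(req_id)
    if mask_index is None:
        continue
    struct_out_indices.append(req_id_to_index[req_id])
    mask_indices.append(mask_index)

padded[struct_out_indices] = torch.from_numpy(grammar_bitmask[mask_indices])
require = torch.zeros(num_reqs, dtype=torch.bool)
require[torch.tensor(struct_out_indices, dtype=torch.long)] = True
print(padded)   # row 0 stays zero; rows 1 and 2 carry the reordered bitmask rows
print(require)  # tensor([False,  True,  True])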


def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
logger.info("Preparing request paddings:")