[Core][V1][TPU] Enable structured decoding on TPU V1 #16499
Changes from 3 commits
@@ -688,9 +688,17 @@ def execute_model(
        )
        hidden_states = self.select_hidden_states(hidden_states,
                                                   logits_indices)
        logits = self.compute_logits(hidden_states)
        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
            from_input_batch(self.input_batch, padded_num_reqs, self.device)
        selected_token_ids = self.sample_from_hidden(hidden_states,
        if scheduler_output.grammar_bitmask is not None:
            require_struct_decoding, grammar_bitmask_padded = \
                self.prepare_structured_decoding_input(logits, scheduler_output)
            logits = self.structured_decode(
                require_struct_decoding.to(logits.device),
                grammar_bitmask_padded.to(logits.device), logits,
                torch.arange(0, 32).to(logits.device))
        selected_token_ids = self.sample_from_logits(logits,
                                                     tpu_sampling_metadata)
        # Remove padding on cpu and keep dynamic op outside of xla graph.
        selected_token_ids = selected_token_ids.cpu()[:num_reqs]

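Note on the block above: `prepare_structured_decoding_input` returns `require_struct_decoding` as a `(num_reqs, 1)` boolean column, so inside `structured_decode` it broadcasts across the vocabulary dimension of `logits` and only the rows that actually require structured output pick up the grammar-masked values. Below is a minimal sketch of that selection with toy shapes and made-up values, not the runner's real data:

```python
import torch

# Two requests, a three-token vocabulary; values are made up for illustration.
logits = torch.tensor([[0.1, 0.2, 0.3],
                       [1.0, 2.0, 3.0]])

# Pretend the grammar allows only token 2 for request 1.
masked = torch.full_like(logits, float("-inf"))
masked[1, 2] = logits[1, 2]

# (num_reqs, 1) column, like require_struct_decoding above.
require_struct_decoding = torch.tensor([[False], [True]])

# The column broadcasts across the vocab dim: row 0 keeps its raw logits,
# row 1 takes the grammar-masked logits.
out = torch.where(require_struct_decoding, masked, logits)
assert torch.equal(out[0], logits[0])
assert torch.isinf(out[1, 0]) and out[1, 2] == logits[1, 2]
```
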
@@ -862,7 +870,7 @@ def _precompile_backbone(self) -> None:
            self._dummy_run(num_tokens)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in in %.2f [secs].", end - start)
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("model backbone")

    def _precompile_select_hidden_states(self) -> None:

@@ -886,19 +894,64 @@ def _precompile_select_hidden_states(self) -> None:
            logger.info(" -- num_tokens: %d", num_tokens)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in in %.2f [secs].", end - start)
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("select_hidden_states")

    def _precompile_sample_from_hidden(self) -> None:
        logger.info("Compiling sampling with different input shapes.")
    def _precompile_compute_logits(self) -> None:
        logger.info(
            "Compiling compute_logits with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        for num_reqs in self.num_reqs_paddings:
            dummy_hidden = torch.zeros((num_reqs, hsize),
                                       device=self.device,
                                       dtype=self._hidden_states_dtype)
            # The first dimension of dummy_hidden cannot be mark_dynamic because
            # some operations in the sampler require it to be static.
            torch._dynamo.mark_dynamic(dummy_hidden, 0)
            self.compute_logits(dummy_hidden)
            logger.info(" -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("compute_logits")

    def _precompile_structured_decoding(self) -> None:
        logger.info(
            "Compiling structured_decoding with different input shapes.")

Comment on lines +1072 to +1073

Do you want to keep this and other logs in this method? They seem more like debug logs to me, so perhaps …

actually I'm fine with both, but since …

        start = time.perf_counter()
        vocab_size = self.model_config.get_vocab_size()
        for num_reqs in self.num_reqs_paddings:
            dummy_logits = torch.zeros((num_reqs, vocab_size),
                                       device=self.device,
                                       dtype=self._hidden_states_dtype)
            dummy_require_struct_decoding = torch.zeros(
                num_reqs, dtype=torch.bool, device=self.device).unsqueeze(1)
            dummy_grammar_bitmask = torch.zeros(
                (num_reqs, cdiv(vocab_size, 32)),
                dtype=torch.int32,
                device=self.device)
            # The first dimension of the above 3 dummy tensors cannot be
            # mark_dynamic because some operations in structured_decode require
            # them to be static.
            arange = torch.arange(0, 32).to(self.device)
            self.structured_decode(dummy_require_struct_decoding,
                                   dummy_grammar_bitmask, dummy_logits, arange)
            logger.info(" -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("structured_decoding")

    def _precompile_sample_from_logits(self) -> None:
        logger.info(
            "Compiling sample_from_logits with different input shapes.")
        start = time.perf_counter()
        vocab_size = self.model_config.get_vocab_size()
        for num_reqs in self.num_reqs_paddings:
            dummy_logits = torch.zeros((num_reqs, vocab_size),
                                       device=self.device,
                                       dtype=self._hidden_states_dtype)
            # The first dimension of dummy_logits cannot be mark_dynamic
            # because some operations in the sampler require it to be static.
            for all_greedy in [False, True]:
                generate_params_if_all_greedy = not all_greedy
                sampling_metadata = (

@@ -909,12 +962,12 @@ def _precompile_sample_from_hidden(self) -> None:
                        generate_params_if_all_greedy,
                    ))
                sampling_metadata.all_greedy = all_greedy
                self.sample_from_hidden(dummy_hidden, sampling_metadata)
                self.sample_from_logits(dummy_logits, sampling_metadata)
            logger.info(" -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in in %.2f [secs].", end - start)
        self._update_num_xla_graphs("sampling")
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("sample_from_logits")

    def capture_model(self) -> None:
        """

@@ -923,7 +976,9 @@ def capture_model(self) -> None:
        # TODO: precompile encoder
        self._precompile_backbone()
        self._precompile_select_hidden_states()
        self._precompile_sample_from_hidden()
        self._precompile_compute_logits()
        self._precompile_structured_decoding()
        self._precompile_sample_from_logits()

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """

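The updated `capture_model` warms up `compute_logits`, `structured_decode`, and `sample_from_logits` for every entry in `num_reqs_paddings` because XLA traces one graph per distinct input shape; at run time the batch is padded up to one of those precompiled sizes so serving never triggers a fresh compilation. A rough sketch of that bucketing idea follows; the helper name and bucket values are hypothetical, not vLLM's actual padding policy:

```python
# Hypothetical illustration of shape bucketing; the real padding set is
# produced by _get_req_paddings in the TPU model runner.
def pad_to_bucket(num_reqs: int, paddings: list[int]) -> int:
    for padded in paddings:
        if num_reqs <= padded:
            return padded
    raise ValueError(f"{num_reqs} exceeds the largest precompiled batch size")

buckets = [8, 16, 32, 64]  # assumed example paddings
assert pad_to_bucket(3, buckets) == 8
assert pad_to_bucket(20, buckets) == 32
# Each compiled helper is traced once per bucket, so no new XLA graph is
# built while requests are being served.
```
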
@@ -980,29 +1035,86 @@ def select_hidden_states(self, hidden_states, indices_do_sample):
        return hidden_states[indices_do_sample]

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def sample_from_hidden(
        self,
        sample_hidden_states: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """
        Sample with xla-friendly function. This function is to be traced
        separately from `forward` for lighter compilation overhead.
        """
        logits = self.model.compute_logits(sample_hidden_states, None)
    def compute_logits(self,
                       sample_hidden_states: torch.Tensor) -> torch.Tensor:
        return self.model.compute_logits(sample_hidden_states, None)

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def sample_from_logits(
            self, logits: torch.Tensor,
            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
        if sampling_metadata.all_greedy:
            out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            out_tokens = self.sampler(logits,
                                      sampling_metadata).sampled_token_ids
        return out_tokens

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def structured_decode(self, require_struct_decoding: torch.Tensor,
                          grammar_bitmask: torch.Tensor, logits: torch.Tensor,
                          arange: torch.Tensor) -> torch.Tensor:
        return torch.where(
            require_struct_decoding,
            self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
            logits)

    def apply_grammar_bitmask(self, logits: torch.Tensor,
                              grammar_bitmask: torch.Tensor,
                              arange: torch.Tensor):
        assert (logits.shape[0] == grammar_bitmask.shape[0])
        vocab_size = logits.shape[1]
        logits_cloned = logits.clone()
        for i in range(logits.shape[0]):
            unpacked_bitmask = (torch.bitwise_right_shift(
                grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0
            unpacked_bitmask = unpacked_bitmask.reshape(-1)[:vocab_size]
            logits_cloned[i] = logits_cloned[i].masked_fill(
                unpacked_bitmask, -float("inf"))
        return logits_cloned

    def get_multimodal_embeddings(self, *args, **kwargs):
        return self.model.get_multimodal_embeddings(*args, **kwargs)

    def get_input_embeddings(self, *args, **kwargs):
        return self.model.get_input_embeddings(*args, **kwargs)

    def prepare_structured_decoding_input(
        self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
    ) -> tuple[torch.Tensor, torch.Tensor]:
        grammar_bitmask = scheduler_output.grammar_bitmask
        assert grammar_bitmask is not None
        num_reqs, vocab_size = logits.shape

        # Pad the grammar bitmask so that it won't cause recompilation.
        grammar_bitmask_padded = torch.zeros((num_reqs, cdiv(vocab_size, 32)),
                                             dtype=torch.int32,
                                             device="cpu")

        # We receive the structured output bitmask from the scheduler, but the
        # indices of the requests in the batch may not match the indices of
        # the bitmask since the scheduler doesn't know how the TPU runner is
        # ordering the requests in the batch. We need to match the order of
        # the bitmask with the order of the requests.
        struct_out_indices: list[int] = []
        mask_indices: list[int] = []
        for req_id in self.input_batch.req_ids:
            mask_index = scheduler_output.structured_output_request_ids.get(
                req_id)
            if mask_index is None:
                continue
            batch_index = self.input_batch.req_id_to_index[req_id]
            struct_out_indices.append(batch_index)
            mask_indices.append(mask_index)
        grammar_bitmask_padded[struct_out_indices] = torch.from_numpy(
            grammar_bitmask[mask_indices])
        # It's not guaranteed that all requests in this batch require
        # structured output, so create a bool tensor to represent
        # the requests that need structured output.
        require_struct_out = torch.zeros(num_reqs, dtype=torch.bool)
        struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
        require_struct_out[struct_out_indices] = True
        return require_struct_out.unsqueeze(1), grammar_bitmask_padded


def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
    logger.info("Preparing request paddings:")

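For reference, the grammar bitmask handled in `apply_grammar_bitmask` packs 32 vocabulary entries into each int32 word, which is why the padded bitmask has `cdiv(vocab_size, 32)` columns and why a fixed `arange(0, 32)` supplies the bit offsets for unpacking. Below is a self-contained sketch of the same unpack-and-mask step, assuming the LSB-first packing implied by the right shift above and using made-up sizes:

```python
import torch

vocab_size = 40                       # made-up vocabulary size
allowed = [3, 5, 37]                  # token ids the grammar would allow

# Pack the allowed set into int32 words, 32 tokens per word, LSB first.
packed = torch.zeros((vocab_size + 31) // 32, dtype=torch.int32)
for tok in allowed:
    packed[tok // 32] |= 1 << (tok % 32)

# Same unpacking as apply_grammar_bitmask: a cleared bit means "ban this token".
arange = torch.arange(0, 32)
ban = (torch.bitwise_right_shift(packed[:, None], arange[None, :]) & 1) == 0
ban = ban.reshape(-1)[:vocab_size]

logits = torch.zeros(vocab_size)
masked = logits.masked_fill(ban, -float("inf"))
assert torch.isfinite(masked).nonzero().flatten().tolist() == allowed
```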