
Commit 603d269

Chanh Nguyen authored and committed
[Core] Speed up decode by removing a synchronizing operation in the sampler (vllm-project#16436)
Signed-off-by: Chanh Nguyen <[email protected]>
Co-authored-by: Chanh Nguyen <[email protected]>
1 parent 9340e1b commit 603d269

File tree

1 file changed: +9 -4 lines


vllm/model_executor/layers/utils.py

Lines changed: 9 additions & 4 deletions
@@ -47,10 +47,15 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
         output_tokens_tensor, vocab_size, num_seqs)
     repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
         1, vocab_size)
-    logits[logits > 0] /= torch.where(prompt_mask | output_mask,
-                                      repetition_penalties, 1.0)[logits > 0]
-    logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
-                                       repetition_penalties, 1.0)[logits <= 0]
+
+    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
+                            1.0)
+
+    # If logits are positive, divide by penalty, otherwise multiply by penalty.
+    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
+    logits *= scaling
+
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
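The speedup comes from replacing data-dependent boolean indexing such as `logits[logits > 0]`, which must evaluate the mask and count its nonzero elements (a host/device synchronization point on CUDA), with fixed-shape `torch.where` arithmetic that can stay asynchronous on the device. A minimal CPU sketch (the function names are illustrative, not part of vLLM) showing the two formulations produce the same result:

```python
import torch


def apply_penalty_old(logits: torch.Tensor, mask: torch.Tensor,
                      repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Pre-change style: boolean-mask indexing (data-dependent shapes)."""
    logits = logits.clone()
    p = torch.where(mask, repetition_penalties, 1.0)
    logits[logits > 0] /= p[logits > 0]    # forces a sync on GPU
    logits[logits <= 0] *= p[logits <= 0]  # forces another sync
    return logits


def apply_penalty_new(logits: torch.Tensor, mask: torch.Tensor,
                      repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Post-change style: same math, only fixed-shape elementwise ops."""
    p = torch.where(mask, repetition_penalties, 1.0)
    # Dividing a positive logit by p equals multiplying it by 1/p,
    # so one torch.where picks the right scale factor per element.
    scaling = torch.where(logits > 0, 1.0 / p, p)
    return logits * scaling


logits = torch.tensor([[2.0, -1.0, 0.0, 3.0]])
mask = torch.tensor([[True, True, False, False]])  # penalized positions
penalties = torch.full_like(logits, 2.0)

old = apply_penalty_old(logits, mask, penalties)
new = apply_penalty_new(logits, mask, penalties)
assert torch.allclose(old, new)  # [[1.0, -2.0, 0.0, 3.0]] both ways
```

On CPU the two versions cost about the same; the win appears on CUDA, where the old version blocks the Python thread at each masked-index operation while the new one enqueues kernels without waiting.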
