Skip to content

Commit eed25a6

Browse files
DefTruthlk-chen
authored andcommitted
[Kernel] Remove redundant Exp calculations (vllm-project#16123)
Signed-off-by: DefTruth <[email protected]>
1 parent 74d126a commit eed25a6

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

vllm/attention/ops/triton_merge_attn_states.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ def merge_attn_states_kernel(
6666
max_lse = tl.maximum(p_lse, s_lse)
6767
p_lse = p_lse - max_lse
6868
s_lse = s_lse - max_lse
69-
out_se = (tl.exp(p_lse) + tl.exp(s_lse))
69+
# Will reuse precomputed Exp values for scale factor computation.
70+
p_se = tl.exp(p_lse)
71+
s_se = tl.exp(s_lse)
72+
out_se = (p_se + s_se)
7073

7174
if OUTPUT_LSE:
7275
out_lse = tl.log(out_se) + max_lse
@@ -84,8 +87,8 @@ def merge_attn_states_kernel(
8487
# NOTE(woosuk): Be careful with the numerical stability.
8588
# We should compute the scale first, and then multiply it with the output.
8689
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
87-
p_scale = tl.exp(p_lse) / out_se
88-
s_scale = tl.exp(s_lse) / out_se
90+
p_scale = p_se / out_se
91+
s_scale = s_se / out_se
8992
out = p_out * p_scale + s_out * s_scale
9093
tl.store(output + token_idx * num_heads * HEAD_SIZE +
9194
head_idx * HEAD_SIZE + head_arange,

0 commit comments

Comments
 (0)