1 file changed: +6 -3 lines changed
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,10 @@ def merge_attn_states_kernel(
66
66
max_lse = tl .maximum (p_lse , s_lse )
67
67
p_lse = p_lse - max_lse
68
68
s_lse = s_lse - max_lse
69
- out_se = (tl .exp (p_lse ) + tl .exp (s_lse ))
69
+ # Compute the exponentials once here; they are reused below for the scale factors.
70
+ p_se = tl .exp (p_lse )
71
+ s_se = tl .exp (s_lse )
72
+ out_se = (p_se + s_se )
70
73
71
74
if OUTPUT_LSE :
72
75
out_lse = tl .log (out_se ) + max_lse
@@ -84,8 +87,8 @@ def merge_attn_states_kernel(
84
87
# NOTE(woosuk): Be careful with the numerical stability.
85
88
# We should compute the scale first, and then multiply it with the output.
86
89
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
87
- p_scale = tl . exp ( p_lse ) / out_se
88
- s_scale = tl . exp ( s_lse ) / out_se
90
+ p_scale = p_se / out_se
91
+ s_scale = s_se / out_se
89
92
out = p_out * p_scale + s_out * s_scale
90
93
tl .store (output + token_idx * num_heads * HEAD_SIZE +
91
94
head_idx * HEAD_SIZE + head_arange ,
You can’t perform that action at this time.
0 commit comments