@@ -1043,8 +1043,8 @@ def __init__(
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
-        self.triton_fa_func = triton_attention
 
+        self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -1055,6 +1055,70 @@ def __init__(
             functools.partial(flash_attn_varlen_func,
                               fa_version=self.vllm_flash_attn_version)
 
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim for attention backends that do
+        # not support different headdims
+        # We don't need to pad V if we are on a hopper system with FA3
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
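+        # (Device capability major 9 is Hopper, e.g. H100; per the comment
+        # above, FA3 there handles differing q/k and v head dims natively.)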
+
+    def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
+                                          return_softmax_lse, **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
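+            # (torch.nn.functional.pad with a 2-element spec [0, n] pads only
+            # the last dim, appending n zeros so v's head dim matches q's.)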
+
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+                and not return_softmax_lse:
+            attn_out = self.triton_fa_func(
+                q,
+                k,
+                maybe_padded_v,
+                **kwargs,
+            )
+        elif is_vllm_fa:
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+        else:
+            # Use return_attn_probs instead of return_softmax_lse for RoCM
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_attn_probs=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+
+        # Unpack the output if there are multiple results:
+        # triton always returns (output, softmax_lse),
+        # vllm_flash_attn returns (output, softmax_lse) when
+        # `return_softmax_lse = True`, and
+        # flash_attn (RoCM) returns (output, softmax_lse, ...) when
+        # `return_attn_probs = True`
+        rest = None
+        if isinstance(attn_out, tuple):
+            attn_out, *rest = attn_out
+
+        # unpad if necessary
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
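+        # (This slice is lossless: the dropped columns came from zero-padded
+        # v, so the corresponding attn_out columns are exactly zero.)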
+
+        # Remain consistent with old `flash_attn_varlen_func` where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            assert rest is not None
+            return attn_out, rest[0]
+        return attn_out
+
 
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -1176,40 +1240,19 @@ def _compute_prefill_context(
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)
 
-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            if is_vllm_fa:
-                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_softmax_lse=True,
-                )
-            else:
-                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_attn_probs=True,
-                )
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
+                    q=q,
+                    k=k,
+                    v=v,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )
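+            # (The chunk's softmax LSE is returned so the partial outputs of
+            # successive context chunks can be merged numerically stably.)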
 
             if output is None:
                 output = attn_output
@@ -1252,58 +1295,22 @@ def _forward_prefill(
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
-            output = self.triton_fa_func(
-                q,
-                k,
-                v_padded,
-                None,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.max_prefill_seq_len,
-                prefill_metadata.max_prefill_seq_len,
-                True,  # causal
-                self.scale,
-                None,  # attn_mask is None unless applying ALiBi mask
-            )
-            ## triton flash attention always return 2 objects
-            if not has_context:
-                output = output[0]
-        elif is_vllm_fa:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_softmax_lse=has_context,
-            )
-        else:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_attn_probs=has_context,
-            )
+        output = self._flash_attn_varlen_diff_headdims(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=prefill_metadata.query_start_loc,
+            cu_seqlens_k=prefill_metadata.query_start_loc,
+            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            return_softmax_lse=has_context,
+        )
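+        # (With return_softmax_lse=has_context, `output` is an (out, lse)
+        # tuple only when cached context must be merged in below.)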
 
         if has_context:
             # ROCm flash_attn_varlen_func will return 3 objects instead of 2
-            suffix_output, suffix_lse, *rest = output
+            suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)
 
@@ -1316,12 +1323,7 @@ def _forward_prefill(
             suffix_lse=suffix_lse,
         )
 
-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output \
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
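+        # (flatten(start_dim=-2) merges (num_heads, v_head_dim) into one dim;
+        # the helper already removed any v padding, so no slicing is needed.)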
+        return self.o_proj(output.flatten(start_dim=-2))[0]
 
     @abstractmethod
     def _forward_decode(