
Commit f3c6ec2

KuuCi and Vincent Chen authored
Bump FA2 to 2.7.4.post1 (#1728)
Co-authored-by: Vincent Chen <[email protected]>
1 parent a0ae025 commit f3c6ec2

File tree

3 files changed: +5 additions, -5 deletions


llmfoundry/models/layers/blocks.py

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ def apply_ffn(
         if not self.use_pad_tok_in_ffn and attention_mask is not None:
             assert unpad_input is not None
             attention_mask = self.slice_attention_mask(attention_mask, seq_len)
-            m, indices, _, _ = unpad_input(m, attention_mask)
+            m, indices, *_ = unpad_input(m, attention_mask)
         n = self.ffn(m)
         if not self.use_pad_tok_in_ffn and attention_mask is not None:
             assert pad_input is not None
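
The motivation, as I read the flash-attn 2.7.0 release notes (not stated in this commit): `bert_padding.unpad_input` gained a fifth trailing return value in 2.7.x, so starred unpacking accepts both the old 4-tuple and the new 5-tuple. A minimal self-contained sketch of the pattern, using stand-in tuples rather than the real library:

# Stand-in return tuples mimicking the two flash-attn signatures
# (placeholder values, not real tensors):
ret_v26 = ('hidden', 'indices', 'cu_seqlens', 512)             # flash-attn <= 2.6.x
ret_v27 = ('hidden', 'indices', 'cu_seqlens', 512, 'seqused')  # flash-attn >= 2.7.0

for ret in (ret_v26, ret_v27):
    # `*_` absorbs zero or more trailing values, so the same
    # unpacking works for both arities.
    hidden, indices, *_ = ret
    assert (hidden, indices) == ('hidden', 'indices')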

llmfoundry/models/mpt/modeling_mpt.py

Lines changed: 3 additions & 3 deletions
@@ -289,15 +289,15 @@ def gen_flash_attn_padding_info(
         query_padding_mask = attention_mask_in_length
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

-    _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+    _, indices_q, cu_seqlens_q, max_seqlen_q, *_ = unpadding_function(
         torch.empty(bsz, S, 1, device=device),
         query_padding_mask,
     )
-    _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
+    _, indices_k, cu_seqlens_k, max_seqlen_k, *_ = unpadding_function(
         torch.empty(bsz, past_key_len + S, 1, device=device),
         key_padding_mask,
     )
-    _, indices_v, _, _ = unpadding_function(
+    _, indices_v, *_ = unpadding_function(
         torch.empty(bsz, past_key_len + S, 1, device=device),
         key_padding_mask,
     )
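
To make the returned fields concrete, here is a minimal sketch (my own illustration, not flash-attn's implementation) of what an unpadding function derives from a boolean padding mask: the flattened indices of non-pad tokens, the cumulative sequence boundaries `cu_seqlens`, and the max sequence length:

import torch

def sketch_unpad_info(padding_mask):
    # Tokens per sequence, e.g. [3, 2] for the mask below.
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)
    # Flattened positions of the non-pad tokens, e.g. [0, 1, 2, 4, 5].
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    # Cumulative boundaries with a leading 0, e.g. [0, 3, 5].
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max())

# Batch of 2, seq len 4; True marks real (non-pad) tokens.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool)
indices, cu_seqlens, max_seqlen = sketch_unpad_info(mask)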

setup.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@

 # Flash 2 group kept for backwards compatibility
 extra_deps['gpu-flash2'] = [
-    'flash-attn==2.6.3',
+    'flash-attn==2.7.4.post1',
 ]

 extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2'])
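
Since 'gpu' is a deep copy of 'gpu-flash2', both extras groups resolve to the same flash-attn pin. Assuming the distribution name llm-foundry (standard setuptools extras syntax; not shown in this diff):

pip install 'llm-foundry[gpu]'         # resolves flash-attn==2.7.4.post1
pip install 'llm-foundry[gpu-flash2]'  # legacy alias, same pin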
