diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index c88cf33d1b..00f85a6b4b 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -231,7 +231,7 @@ def apply_ffn(
         if not self.use_pad_tok_in_ffn and attention_mask is not None:
             assert unpad_input is not None
             attention_mask = self.slice_attention_mask(attention_mask, seq_len)
-            m, indices, _, _ = unpad_input(m, attention_mask)
+            m, indices, *_ = unpad_input(m, attention_mask)
         n = self.ffn(m)
         if not self.use_pad_tok_in_ffn and attention_mask is not None:
             assert pad_input is not None
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 1e9e043f84..c2dc4d390c 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -289,15 +289,15 @@ def gen_flash_attn_padding_info(
         query_padding_mask = attention_mask_in_length
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
 
-    _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+    _, indices_q, cu_seqlens_q, max_seqlen_q, *_ = unpadding_function(
         torch.empty(bsz, S, 1, device=device),
         query_padding_mask,
     )
-    _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
+    _, indices_k, cu_seqlens_k, max_seqlen_k, *_ = unpadding_function(
         torch.empty(bsz, past_key_len + S, 1, device=device),
         key_padding_mask,
     )
-    _, indices_v, _, _ = unpadding_function(
+    _, indices_v, *_ = unpadding_function(
         torch.empty(bsz, past_key_len + S, 1, device=device),
         key_padding_mask,
     )
diff --git a/setup.py b/setup.py
index 2578403c96..3621d4e5d6 100644
--- a/setup.py
+++ b/setup.py
@@ -104,7 +104,7 @@
 
 # Flash 2 group kept for backwards compatibility
 extra_deps['gpu-flash2'] = [
-    'flash-attn==2.6.3',
+    'flash-attn==2.7.4.post1',
 ]
 
 extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2'])
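
Note (not part of the diff): the unpacking changes above accompany the version bump because flash-attn 2.7.x's `bert_padding.unpad_input` returns an extra trailing value beyond the 4-tuple returned by 2.6.x, and starred unpacking tolerates both shapes. A minimal sketch of that compatibility pattern, assuming only that return-shape difference; the tensors and mask below are made-up example inputs, not values from the repo:

```python
# Sketch only: why `*_` unpacking works against either installed flash-attn
# version (4 return values in 2.6.x, 5 in 2.7.x) of bert_padding.unpad_input.
import torch
from flash_attn import bert_padding

# Hypothetical example inputs: (batch, seqlen, hidden) activations and a padding mask.
hidden_states = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Starred unpacking keeps the leading values the caller needs and discards any
# trailing extras, so the same line runs unchanged under both versions.
unpadded, indices, cu_seqlens, max_seqlen, *_ = bert_padding.unpad_input(
    hidden_states,
    attention_mask,
)
print(unpadded.shape)  # (total_non_pad_tokens, hidden) -> (5, 8) for this mask
```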