
Bump FA2 to 2.7.4.post1 #1728


Merged
merged 26 commits on Mar 12, 2025
21 changes: 21 additions & 0 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -80,6 +80,27 @@ def __init__(
additional_eval_metrics: Optional[list] = None,
should_save_peft_only: bool = True,
):
is_mpt_model = pretrained_model_name_or_path.startswith('mosaicml/mpt')
# Check for attn_impl in config_overrides
is_flash_requested = use_flash_attention_2
if config_overrides and 'attn_config' in config_overrides:
attn_config = config_overrides.get('attn_config', {})
if isinstance(attn_config,
dict) and attn_config.get('attn_impl') == 'flash':
is_flash_requested = True

if is_mpt_model and is_flash_requested:
import importlib.metadata

from packaging import version
flash_version = importlib.metadata.version('flash-attn')
if version.parse(flash_version) > version.parse('2.6.9'):
raise ValueError(
f'Flash Attention version {flash_version} (>2.6) is not supported with MPT models. '
+
'Please use Flash Attention version 2.6 or earlier, or use a different attention implementation.',
)

super().__init__(
pretrained_model_name_or_path,
tokenizer=tokenizer,
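For context on the guard added above: flash attention counts as requested either via the HF `use_flash_attention_2` flag or via an MPT-style `attn_config` override. A minimal standalone sketch of that decision logic (the function name and signature here are illustrative, not part of the llm-foundry API):

```python
from typing import Any, Optional


def flash_requested(
    use_flash_attention_2: bool,
    config_overrides: Optional[dict[str, Any]] = None,
) -> bool:
    """Mirror the detection logic from the diff above (illustrative only)."""
    if use_flash_attention_2:
        return True
    attn_config = (config_overrides or {}).get('attn_config', {})
    return isinstance(attn_config, dict) and attn_config.get('attn_impl') == 'flash'


print(flash_requested(True))                                             # True
print(flash_requested(False, {'attn_config': {'attn_impl': 'flash'}}))   # True
print(flash_requested(False, {'attn_config': {'attn_impl': 'torch'}}))   # False
```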
2 changes: 1 addition & 1 deletion llmfoundry/models/layers/blocks.py
@@ -231,7 +231,7 @@ def apply_ffn(
if not self.use_pad_tok_in_ffn and attention_mask is not None:
assert unpad_input is not None
attention_mask = self.slice_attention_mask(attention_mask, seq_len)
m, indices, _, _ = unpad_input(m, attention_mask)
m, indices, *_ = unpad_input(m, attention_mask)
n = self.ffn(m)
if not self.use_pad_tok_in_ffn and attention_mask is not None:
assert pad_input is not None
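The switch to `*_` above appears to track a change in the return arity of `unpad_input` in newer flash-attn releases (an extra trailing value in 2.7.x, per this bump). A small illustration with stand-in functions, not the real flash-attn API, of why starred unpacking tolerates either shape:

```python
def unpad_4(m, mask):
    # Stand-in for the older 4-value return (assumption about pre-2.7 behavior).
    return m, 'indices', 'cu_seqlens', 'max_seqlen'


def unpad_5(m, mask):
    # Stand-in for the newer 5-value return (assumption about 2.7.x behavior).
    return m, 'indices', 'cu_seqlens', 'max_seqlen', 'seqused'


for unpad in (unpad_4, unpad_5):
    m, indices, *_ = unpad('hidden_states', 'attention_mask')
    print(indices, 'trailing values ignored:', len(_))
```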
6 changes: 3 additions & 3 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -289,15 +289,15 @@ def gen_flash_attn_padding_info(
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

_, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
_, indices_q, cu_seqlens_q, max_seqlen_q, *_ = unpadding_function(
torch.empty(bsz, S, 1, device=device),
query_padding_mask,
)
_, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
_, indices_k, cu_seqlens_k, max_seqlen_k, *_ = unpadding_function(
torch.empty(bsz, past_key_len + S, 1, device=device),
key_padding_mask,
)
_, indices_v, _, _ = unpadding_function(
_, indices_v, *_ = unpadding_function(
torch.empty(bsz, past_key_len + S, 1, device=device),
key_padding_mask,
)
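A minimal sketch of the padding-info pattern used in `gen_flash_attn_padding_info`, assuming a working flash-attn install; with 2.7.x the trailing `*_` absorbs the extra return value, and on older releases it is simply empty:

```python
import torch
from flash_attn import bert_padding

bsz, S = 2, 8
# Mark the last two positions of the second sequence as padding.
attention_mask = torch.ones(bsz, S, dtype=torch.bool)
attention_mask[1, -2:] = False

_, indices_q, cu_seqlens_q, max_seqlen_q, *_ = bert_padding.unpad_input(
    torch.empty(bsz, S, 1),  # dummy tensor: only its shape matters here
    attention_mask,
)
print(indices_q.shape, cu_seqlens_q.tolist(), max_seqlen_q)
```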
2 changes: 1 addition & 1 deletion setup.py
@@ -104,7 +104,7 @@

# Flash 2 group kept for backwards compatibility
extra_deps['gpu-flash2'] = [
'flash-attn==2.6.3',
'flash-attn==2.7.4.post1',
]

extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2'])
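After bumping the pin, the same `importlib.metadata` / `packaging` pattern used by the new guard in `hf_causal_lm.py` gives a quick sanity check that the environment actually picked up the new wheel (sketch, assuming both packages are importable):

```python
import importlib.metadata

from packaging import version

installed = version.parse(importlib.metadata.version('flash-attn'))
expected = version.parse('2.7.4.post1')
assert installed >= expected, (
    f'flash-attn {installed} found, but setup.py now pins {expected}'
)
print(f'flash-attn {installed} satisfies the new pin')
```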