
Commit 7d342a6

fix bw compatibility issues

Signed-off-by: Austin Liu <[email protected]>

1 parent c350396

File tree

1 file changed: +8 -4 lines changed


src/liger_kernel/transformers/monkey_patch.py

Lines changed: 8 additions & 4 deletions
@@ -38,8 +38,9 @@

 logger = logging.getLogger(__name__)
 SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
+FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION = "4.48.0"
 TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
-FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Not support flex attention for this model yet"
+FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Flex attention is not supported."


 def _bind_method_to_module(module, method_name: str, new_method: Callable):
@@ -120,9 +121,12 @@ def apply_liger_kernel_to_llama(

     if flex_attn:
         # Patching HuggingFace default attn_impl from `toch.sdpa` to liger's `llama_flex_attention_forward``
-        modeling_llama.ALL_ATTENTION_FUNCTIONS.update(
-            {"sdpa": llama_flex_attention_forward, "flex_attention": llama_flex_attention_forward}
-        )
+        if transformer_version >= version.parse(FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION):
+            modeling_llama.ALL_ATTENTION_FUNCTIONS.update(
+                {"sdpa": llama_flex_attention_forward, "flex_attention": llama_flex_attention_forward}
+            )
+        else:
+            logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
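
For context, the sketch below illustrates the backward-compatibility pattern this commit introduces: apply the flex-attention patch only when the installed transformers version is at least FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION (4.48.0), and log a warning otherwise. This is a minimal, self-contained example, not liger_kernel code; the helper name maybe_patch_flex_attention and its arguments are hypothetical, while the version comparison via packaging.version mirrors the diff above.

import logging

import transformers
from packaging import version

logger = logging.getLogger(__name__)

# Constants taken from the diff above.
FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION = "4.48.0"
FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Flex attention is not supported."


def maybe_patch_flex_attention(attention_functions: dict, flex_forward) -> bool:
    # Hypothetical helper: register `flex_forward` for both "sdpa" and
    # "flex_attention" only when transformers is new enough, mirroring the
    # gated ALL_ATTENTION_FUNCTIONS.update(...) call in the diff.
    installed = version.parse(transformers.__version__)
    if installed >= version.parse(FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION):
        attention_functions.update({"sdpa": flex_forward, "flex_attention": flex_forward})
        return True
    # Older transformers: skip the patch and warn, as the commit does.
    logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
    return False

Returning a boolean here just makes the gate easy to exercise in isolation; the real change mutates modeling_llama.ALL_ATTENTION_FUNCTIONS in place inside apply_liger_kernel_to_llama.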
