Skip to content

Commit 9bdf4b1

Browse files
committed
improve handling and raise an error if FA3 is requested but not installed
1 parent d6f64a3 commit 9bdf4b1

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/axolotl/utils/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,14 @@ def apply_patches(self) -> None:
636636
if torch.cuda.get_device_capability() >= (9, 0):
637637
# FA3 is only available on Hopper GPUs and newer
638638
use_fa3 = True
639+
if not importlib.util.find_spec("flash_attn_interface"):
640+
use_fa3 = False
641+
if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
642+
# this can happen when use_flash_attention_3 is explicity set to True
643+
# and flash_attn_interface is not installed
644+
raise ModuleNotFoundError(
645+
"Please install the flash_attn_interface library to use Flash Attention 3.x"
646+
)
639647
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
640648
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
641649
from flash_attn_interface import (

0 commit comments

Comments
 (0)