1 file changed (+8, −0)
@@ -636,6 +636,14 @@ def apply_patches(self) -> None:
         if torch.cuda.get_device_capability() >= (9, 0):
             # FA3 is only available on Hopper GPUs and newer
             use_fa3 = True
+            if not importlib.util.find_spec("flash_attn_interface"):
+                use_fa3 = False
+        if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
+            # this can happen when use_flash_attention_3 is explicitly set to True
+            # and flash_attn_interface is not installed
+            raise ModuleNotFoundError(
+                "Please install the flash_attn_interface library to use Flash Attention 3.x"
+            )
         if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
             from flash_attn_interface import flash_attn_func as flash_attn_func_v3
             from flash_attn_interface import (
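
Read as a whole, the hunk does two things: it only auto-enables FA3 on Hopper-class GPUs when flash_attn_interface is importable, and it raises if FA3 was requested while the package is missing. Below is a minimal, runnable sketch of that gating logic pulled out of apply_patches for clarity; the resolve_use_fa3 helper and its use_flash_attention_3 argument are illustrative assumptions, not part of the patch, while the capability check, the find_spec() probe, and the error message come from the diff above.

    # Sketch of the FA3 gating logic, under the assumptions stated above.
    import importlib.util

    import torch


    def resolve_use_fa3(use_flash_attention_3: bool | None = None) -> bool:
        """Decide whether Flash Attention 3 should be used on this machine."""
        fa3_installed = importlib.util.find_spec("flash_attn_interface") is not None

        if use_flash_attention_3:
            # Explicit opt-in: fail loudly if the package is missing.
            if not fa3_installed:
                raise ModuleNotFoundError(
                    "Please install the flash_attn_interface library to use Flash Attention 3.x"
                )
            return True

        if use_flash_attention_3 is False:
            # Explicit opt-out.
            return False

        # Auto-detection: FA3 is only available on Hopper GPUs (SM 9.0) and newer,
        # and only when flash_attn_interface is importable.
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability() >= (9, 0) and fa3_installed

Raising ModuleNotFoundError only on the explicit opt-in keeps the auto-detection path quiet on machines that simply don't have the package installed.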