fbgemm_gpu/fbgemm_gpu/sll — 4 files changed, +873 −853 lines (two of the four diffs are shown below).

The first file drops jagged_dense_flash_attention from a multi-line import and removes its entry from the op-registration table:

@@ -41,7 +41,6 @@
     jagged_dense_bmm,
     jagged_dense_elementwise_add,
     jagged_dense_elementwise_mul_jagged_out,
-    jagged_dense_flash_attention,
     jagged_flash_attention_basic,
     jagged_jagged_bmm,
     jagged_jagged_bmm_jagged_out,
@@ -321,10 +320,6 @@
         "CUDA": jagged_dense_elementwise_add,
         "AutogradCUDA": jagged_dense_elementwise_add,
     },
-    "sll_jagged_dense_flash_attention": {
-        "CUDA": jagged_dense_flash_attention,
-        "AutogradCUDA": jagged_dense_flash_attention,
-    },
 }

 for op_name, dispatches in sll_cpu_registrations.items():
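The loop body is truncated in the diff, so the actual registration call isn't visible. For orientation, here is a minimal sketch of how a table of this shape is commonly consumed; the "fbgemm" operator namespace and the torch.library.impl call are assumptions, not something the diff confirms:

# Hypothetical sketch; the real loop body is cut off above.
import torch

def register_sll_ops(registrations: dict) -> None:
    # registrations maps op name -> {dispatch key -> callable}, exactly
    # the shape of the table in the diff.
    for op_name, dispatches in registrations.items():
        for dispatch_key, impl_fn in dispatches.items():
            # Bind impl_fn to one dispatch key ("CUDA", "AutogradCUDA", ...)
            # of an already-declared operator schema. The "fbgemm::"
            # namespace is an assumption.
            torch.library.impl(f"fbgemm::{op_name}", dispatch_key, func=impl_fn)

Registering the same callable under both "CUDA" and "AutogradCUDA" is the usual way to let a single autograd-aware wrapper serve both dispatch paths.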
The second file (the Triton subpackage __init__, judging by the import paths) picks the registration up, importing the Triton implementation and adding the op to its own table:

@@ -8,12 +8,22 @@
 # pyre-strict


+from fbgemm_gpu.sll.triton.jagged_dense_flash_attention import (  # noqa F401
+    jagged_dense_flash_attention,
+    JaggedDenseFlashAttention,  # noqa F401
+)
+
 from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import (  # noqa F401
     multi_head_jagged_flash_attention,
-    MultiHeadJaggedFlashAttention,
+    MultiHeadJaggedFlashAttention,  # noqa F401
 )

+# pyre-ignore[5]
 op_registrations = {
+    "sll_jagged_dense_flash_attention": {
+        "CUDA": jagged_dense_flash_attention,
+        "AutogradCUDA": jagged_dense_flash_attention,
+    },
     "sll_multi_head_jagged_flash_attention": {
         "CUDA": multi_head_jagged_flash_attention,
         "AutogradCUDA": multi_head_jagged_flash_attention,