
Commit c60c986

q10 authored and facebook-github-bot committed
Re-organize SLL ops, pt 2 (pytorch#719)
Summary:
X-link: pytorch#3643
Pull Request resolved: facebookresearch/FBGEMM#719

- Re-organize `jagged_dense_flash_attention`

Reviewed By: sryap

Differential Revision: D68916405

fbshipit-source-id: 688f41ae4b96684697af8538611f9f6a800e7ff2
1 parent f02bb7d commit c60c986

File tree

4 files changed: +873 -853 lines

fbgemm_gpu/fbgemm_gpu/sll/__init__.py

Lines changed: 0 additions & 5 deletions
@@ -41,7 +41,6 @@
     jagged_dense_bmm,
     jagged_dense_elementwise_add,
     jagged_dense_elementwise_mul_jagged_out,
-    jagged_dense_flash_attention,
     jagged_flash_attention_basic,
     jagged_jagged_bmm,
     jagged_jagged_bmm_jagged_out,
@@ -321,10 +320,6 @@
         "CUDA": jagged_dense_elementwise_add,
         "AutogradCUDA": jagged_dense_elementwise_add,
     },
-    "sll_jagged_dense_flash_attention": {
-        "CUDA": jagged_dense_flash_attention,
-        "AutogradCUDA": jagged_dense_flash_attention,
-    },
 }

 for op_name, dispatches in sll_cpu_registrations.items():
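
Both registration dicts touched by this commit follow the same pattern: an op name mapped to its implementation per dispatch key, consumed by a loop like the `for op_name, dispatches in ...` line above. Below is a minimal, self-contained sketch of that pattern using `torch.library`; the example namespace, schema, and no-op kernel are hypothetical, and FBGEMM's actual registration helpers are not shown in this diff.

import torch

# Hypothetical dict in the same shape as the registration dicts above:
# op name -> {dispatch key: implementation}.
def _identity_cpu(x: torch.Tensor) -> torch.Tensor:
    return x

example_registrations = {
    "sll_example_op": {
        "CPU": _identity_cpu,
        "AutogradCPU": _identity_cpu,
    },
}

# Assumed registration loop (sketch only): define a schema once, then
# register each implementation under its dispatch key.
lib = torch.library.Library("fbgemm_example", "DEF")
lib.define("sll_example_op(Tensor x) -> Tensor")

for op_name, dispatches in example_registrations.items():
    for dispatch_key, fn in dispatches.items():
        lib.impl(op_name, fn, dispatch_key)

# The op is then callable through the dispatcher:
out = torch.ops.fbgemm_example.sll_example_op(torch.ones(2))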

fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -8,12 +8,22 @@
 # pyre-strict


+from fbgemm_gpu.sll.triton.jagged_dense_flash_attention import (  # noqa F401
+    jagged_dense_flash_attention,
+    JaggedDenseFlashAttention,  # noqa F401
+)
+
 from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import (  # noqa F401
     multi_head_jagged_flash_attention,
-    MultiHeadJaggedFlashAttention,
+    MultiHeadJaggedFlashAttention,  # noqa F401
 )

+# pyre-ignore[5]
 op_registrations = {
+    "sll_jagged_dense_flash_attention": {
+        "CUDA": jagged_dense_flash_attention,
+        "AutogradCUDA": jagged_dense_flash_attention,
+    },
     "sll_multi_head_jagged_flash_attention": {
         "CUDA": multi_head_jagged_flash_attention,
         "AutogradCUDA": multi_head_jagged_flash_attention,
