
Commit f373926

q10 authored and avbokovoy committed
Re-organize SLL ops, pt 1 (pytorch#3642)
Summary:
Pull Request resolved: pytorch#3642

X-link: https://github.com/facebookresearch/FBGEMM/pull/718

- Re-organize SLL ops, pt 1

Reviewed By: sryap

Differential Revision: D68915217

fbshipit-source-id: 6208ef53c1740c5dbf89534cc2301c10242b82ea
1 parent 4739458 commit f373926

File tree

6 files changed: +795 -743 lines changed


fbgemm_gpu/fbgemm_gpu/sll/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -46,7 +46,6 @@
     jagged_jagged_bmm,
     jagged_jagged_bmm_jagged_out,
     jagged_softmax,
-    multi_head_jagged_flash_attention,
     triton_jagged_self_substraction_jagged_out,
 )
@@ -326,15 +325,16 @@
         "CUDA": jagged_dense_flash_attention,
         "AutogradCUDA": jagged_dense_flash_attention,
     },
-    "sll_multi_head_jagged_flash_attention": {
-        "CUDA": multi_head_jagged_flash_attention,
-        "AutogradCUDA": multi_head_jagged_flash_attention,
-    },
 }
 
 for op_name, dispatches in sll_cpu_registrations.items():
     lib.register(op_name, dispatches)
 
 if torch.cuda.is_available():
+    from fbgemm_gpu.sll.triton import op_registrations
+
+    for op_name, dispatches in op_registrations.items():
+        lib.register(op_name, dispatches)
+
     for op_name, dispatches in sll_gpu_registrations.items():
         lib.register(op_name, dispatches)
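
The registration dicts consumed by lib.register above map a dispatch key (e.g. "CUDA", "AutogradCUDA") to an implementation. As a rough illustration of that pattern, here is a minimal, self-contained sketch using PyTorch's torch.library API directly; the "toy_sll" namespace, the toy_scale op, and its schema are assumptions made for the example, not part of FBGEMM.

import torch

# Hypothetical namespace and op, used only to illustrate the
# {op_name: {dispatch_key: impl}} registration pattern.
lib = torch.library.Library("toy_sll", "DEF")
lib.define("toy_scale(Tensor x, float alpha) -> Tensor")

def toy_scale_impl(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha

op_registrations = {
    "toy_scale": {
        "CPU": toy_scale_impl,
        "CUDA": toy_scale_impl,
    },
}

for op_name, dispatches in op_registrations.items():
    for dispatch_key, fn in dispatches.items():
        lib.impl(op_name, fn, dispatch_key)

# The op is then callable through the dispatcher:
# y = torch.ops.toy_sll.toy_scale(torch.ones(3), 2.0)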
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import (  # noqa F401
+    multi_head_jagged_flash_attention,
+    MultiHeadJaggedFlashAttention,
+)
+
+op_registrations = {
+    "sll_multi_head_jagged_flash_attention": {
+        "CUDA": multi_head_jagged_flash_attention,
+        "AutogradCUDA": multi_head_jagged_flash_attention,
+    },
+}
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+
+
+def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
+    if not x.is_contiguous():
+        return x.contiguous()
+    else:
+        return x
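
expect_contiguous is a small guard that copies a tensor only when its layout is not already contiguous. A hypothetical usage sketch follows; the dense_row_sums wrapper is illustrative only and not part of this commit, but it mirrors how a kernel wrapper might normalize layout before relying on row-major strides.

import torch

def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
    if not x.is_contiguous():
        return x.contiguous()
    else:
        return x

def dense_row_sums(values: torch.Tensor) -> torch.Tensor:
    # Kernels that do pointer arithmetic from .stride() generally assume
    # a contiguous buffer, so copy only if the input is a strided view.
    values = expect_contiguous(values)
    return values.sum(dim=-1)

# A transposed view is non-contiguous and gets copied before use.
x = torch.arange(12, dtype=torch.float32).reshape(3, 4).t()
print(dense_row_sums(x))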
