
Commit 9894293

q10 authored and facebook-github-bot committed
Re-organize SLL ops, pt 3 (pytorch#728)
Summary:
Pull Request resolved: facebookresearch/FBGEMM#728
X-link: pytorch#3652

- Re-organize `jagged_dense_elementwise_add`

Reviewed By: sryap

Differential Revision: D68923208

fbshipit-source-id: 1c3f22d9588f11a664fb02843b457ea33bb40f9c
1 parent c60c986 commit 9894293

7 files changed, +64 -51 lines changed

fbgemm_gpu/fbgemm_gpu/sll/__init__.py

Lines changed: 0 additions & 5 deletions
@@ -39,7 +39,6 @@
     jagged2_softmax,
     jagged2_to_padded_dense,
     jagged_dense_bmm,
-    jagged_dense_elementwise_add,
     jagged_dense_elementwise_mul_jagged_out,
     jagged_flash_attention_basic,
     jagged_jagged_bmm,
@@ -316,10 +315,6 @@
         "CUDA": jagged_flash_attention_basic,
         "AutogradCUDA": jagged_flash_attention_basic,
     },
-    "sll_jagged_dense_elementwise_add": {
-        "CUDA": jagged_dense_elementwise_add,
-        "AutogradCUDA": jagged_dense_elementwise_add,
-    },
 }

 for op_name, dispatches in sll_cpu_registrations.items():

fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -8,18 +8,26 @@
 # pyre-strict


-from fbgemm_gpu.sll.triton.jagged_dense_flash_attention import ( # noqa F401
+from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import ( # noqa F401
+    jagged_dense_elementwise_add,
+    JaggedDenseAdd, # noqa F401
+)
+from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import ( # noqa F401
     jagged_dense_flash_attention,
     JaggedDenseFlashAttention, # noqa F401
 )

-from fbgemm_gpu.sll.triton.multi_head_jagged_flash_attention import ( # noqa F401
+from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import ( # noqa F401
     multi_head_jagged_flash_attention,
     MultiHeadJaggedFlashAttention, # noqa F401
 )

 # pyre-ignore[5]
 op_registrations = {
+    "sll_jagged_dense_elementwise_add": {
+        "CUDA": jagged_dense_elementwise_add,
+        "AutogradCUDA": jagged_dense_elementwise_add,
+    },
     "sll_jagged_dense_flash_attention": {
         "CUDA": jagged_dense_flash_attention,
         "AutogradCUDA": jagged_dense_flash_attention,
fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+
+from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
+    dense_to_jagged,
+    jagged_to_dense,
+)
+
+
+class JaggedDenseAdd(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx, x: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, max_seq_len: int
+    ):
+        ctx.save_for_backward(x_offsets)
+        ctx.max_seq_len = max_seq_len
+        # TODO: what should be the correct behavior when jagged values has length > max seq len?
+        # current behavior is to not truncate jagged values
+        # similar for backward grad_output
+        return dense_to_jagged(
+            y, [x_offsets], operation_function="add", operation_jagged_values=x
+        )[0]
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        (offsets,) = ctx.saved_tensors
+        grad_dense = jagged_to_dense(grad_output, [offsets], [ctx.max_seq_len])
+        return grad_output, None, grad_dense, None
+
+
+def jagged_dense_elementwise_add(
+    x: torch.Tensor,
+    x_offsets: torch.Tensor,
+    y: torch.Tensor,
+    max_seq_len: int,
+    use_fbgemm_kernel: bool = True,
+):
+    if use_fbgemm_kernel:
+        return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
+            x, [x_offsets], y
+        )[0]
+    else:
+        return JaggedDenseAdd.apply(x, x_offsets, y, max_seq_len)
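
The relocated module keeps both call paths: `use_fbgemm_kernel=True` routes to `torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output`, while `use_fbgemm_kernel=False` goes through the `JaggedDenseAdd` autograd function built on `dense_to_jagged`/`jagged_to_dense`. A minimal usage sketch follows (not part of the commit); the shapes, the `asynchronous_complete_cumsum` call used to build the offsets, and the CUDA device are illustrative assumptions.

# Minimal usage sketch (illustrative, not from the commit).
import torch
import fbgemm_gpu.sll  # noqa F401  # loads the fbgemm_gpu ops used by the kernel path
from fbgemm_gpu.sll.triton import jagged_dense_elementwise_add

# Three jagged sequences of lengths 2, 0, and 3 with embedding dimension 4.
lengths = torch.tensor([2, 0, 3], dtype=torch.int64, device="cuda")
x_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)  # [0, 2, 2, 5]
max_seq_len, D = 3, 4

x = torch.randn(int(lengths.sum()), D, device="cuda")            # jagged values
y = torch.randn(lengths.numel(), max_seq_len, D, device="cuda")  # dense [B, N, D]

# Returns jagged values with the same row layout as x; passing
# use_fbgemm_kernel=False would exercise the JaggedDenseAdd path instead.
out = jagged_dense_elementwise_add(x, x_offsets, y, max_seq_len)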

fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py

Lines changed: 0 additions & 43 deletions
@@ -12,11 +12,6 @@
 import triton
 import triton.language as tl

-from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
-    dense_to_jagged,
-    jagged_to_dense,
-)
-

 def set_block_size(N: int) -> int:
     if N > 64:
@@ -2591,41 +2586,3 @@ def jagged_flash_attention_basic(
     )

     return jagged_O
-
-
-class JaggedDenseAdd(torch.autograd.Function):
-    @staticmethod
-    # pyre-fixme
-    def forward(
-        ctx, x: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, max_seq_len: int
-    ):
-        ctx.save_for_backward(x_offsets)
-        ctx.max_seq_len = max_seq_len
-        # TODO: what should be the correct behavior when jagged values has length > max seq len?
-        # current behavior is to not truncate jagged values
-        # similar for backward grad_output
-        return dense_to_jagged(
-            y, [x_offsets], operation_function="add", operation_jagged_values=x
-        )[0]
-
-    @staticmethod
-    # pyre-fixme
-    def backward(ctx, grad_output: torch.Tensor):
-        (offsets,) = ctx.saved_tensors
-        grad_dense = jagged_to_dense(grad_output, [offsets], [ctx.max_seq_len])
-        return grad_output, None, grad_dense, None
-
-
-def jagged_dense_elementwise_add(
-    x: torch.Tensor,
-    x_offsets: torch.Tensor,
-    y: torch.Tensor,
-    max_seq_len: int,
-    use_fbgemm_kernel: bool = True,
-):
-    if use_fbgemm_kernel:
-        return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
-            x, [x_offsets], y
-        )[0]
-    else:
-        return JaggedDenseAdd.apply(x, x_offsets, y, max_seq_len)

fbgemm_gpu/test/sll/jagged_dense_elementwise_add_test.py

Lines changed: 2 additions & 1 deletion
@@ -9,9 +9,10 @@

 import unittest

+import fbgemm_gpu.sll # noqa F401
 import hypothesis.strategies as st
 import torch
-from fbgemm_gpu.sll.triton_sll import jagged_dense_elementwise_add # noqa
+
 from hypothesis import given, settings

 from .common import open_source
