Commit 7d009c7

q10 authored and avbokovoy committed

Fold ops registration code, pt 3 (pytorch#3641)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/717
Pull Request resolved: pytorch#3641

- Fold out ops registration code in SLL ops

Reviewed By: sryap

Differential Revision: D68911389

fbshipit-source-id: 49649316b92f064063dee7079fe0c83a39c850ee

1 parent 31a6a92 · commit 7d009c7
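
The diffs below collapse FBGEMM's per-op lib.register(...) calls into two dispatch dictionaries, sll_cpu_registrations and sll_gpu_registrations, which are iterated at import time (GPU entries only when CUDA is available). FBGEMM's lib.register helper itself is not shown in this diff, so the following is only a minimal sketch of the same dict-driven pattern using the public torch.library API; the sll_example namespace, the sll_scale op, and its kernels are hypothetical stand-ins, not FBGEMM code.

import torch

lib = torch.library.Library("sll_example", "DEF")
lib.define("sll_scale(Tensor x, float alpha) -> Tensor")


def cpu_scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Plain CPU kernel.
    return x * alpha


def meta_scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Shape/dtype-only kernel for the Meta backend.
    return torch.empty_like(x)


# Registrations grouped per op and keyed by dispatch key, mirroring
# sll_cpu_registrations / sll_gpu_registrations in the diff below.
cpu_registrations = {
    "sll_scale": {"CPU": cpu_scale, "Meta": meta_scale},
}
gpu_registrations = {
    "sll_scale": {"CUDA": cpu_scale},  # stand-in; a real CUDA kernel would go here
}

for op_name, dispatches in cpu_registrations.items():
    for dispatch_key, fn in dispatches.items():
        lib.impl(op_name, fn, dispatch_key)

# GPU keys are only registered when CUDA is present, as in the diff.
if torch.cuda.is_available():
    for op_name, dispatches in gpu_registrations.items():
        for dispatch_key, fn in dispatches.items():
            lib.impl(op_name, fn, dispatch_key)

print(torch.ops.sll_example.sll_scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])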

3 files changed, +158 -169 lines

fbgemm_gpu/fbgemm_gpu/sll/__init__.py

Lines changed: 86 additions & 97 deletions

@@ -7,6 +7,8 @@
 
 # pyre-strict
 
+import torch
+
 from fbgemm_gpu.sll.cpu_sll import (  # noqa F401
     cpu_array_jagged_bmm_jagged_out,
     cpu_dense_jagged_cat_jagged_out,
@@ -21,14 +23,14 @@
     cpu_jagged_jagged_bmm_jagged_out,
     cpu_jagged_self_substraction_jagged_out,
     cpu_jagged_softmax,
-    meta_jagged_dense_elementwise_mul_jagged_out,
-    meta_jagged_self_substraction_jagged_out,
 )
 
 from fbgemm_gpu.sll.meta_sll import (  # noqa F401
     meta_array_jagged_bmm_jagged_out,
     meta_jagged2_softmax,
+    meta_jagged_dense_elementwise_mul_jagged_out,
     meta_jagged_jagged_bmm_jagged_out,
+    meta_jagged_self_substraction_jagged_out,
 )
 
 from fbgemm_gpu.sll.triton_sll import (  # noqa F401
@@ -208,144 +210,131 @@
     """
 )
 
-# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same function
-# however, this is not ideal because in the inference case, we don't need the autograd forward
-# to save the context because we don't need to do backward.
-lib.register(
-    "sll_jagged_dense_bmm",
-    {
-        "CUDA": jagged_dense_bmm,
-        "AutogradCUDA": jagged_dense_bmm,
+# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same
+# function however, this is not ideal because in the inference case, we don't
+# need the autograd forward to save the context because we don't need to do
+# backward.
+
+# pyre-ignore[5]
+sll_cpu_registrations = {
+    "sll_jagged_dense_bmm": {
         "CPU": cpu_jagged_dense_bmm,
         "AutogradCPU": cpu_jagged_dense_bmm,
     },
-)
-
-lib.register(
-    "sll_jagged_jagged_bmm",
-    {
-        "CUDA": jagged_jagged_bmm,
-        "AutogradCUDA": jagged_jagged_bmm,
+    "sll_jagged_jagged_bmm": {
         "CPU": cpu_jagged_jagged_bmm,
         "AutogradCPU": cpu_jagged_jagged_bmm,
     },
-)
-
-lib.register(
-    "sll_dense_jagged_cat_jagged_out",
-    {
-        "CUDA": dense_jagged_cat_jagged_out,
+    "sll_dense_jagged_cat_jagged_out": {
         "CPU": cpu_dense_jagged_cat_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_self_substraction_jagged_out",
-    {
-        "CUDA": triton_jagged_self_substraction_jagged_out,
+    "sll_jagged_self_substraction_jagged_out": {
        "CPU": cpu_jagged_self_substraction_jagged_out,
        "Meta": meta_jagged_self_substraction_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged2_to_padded_dense",
-    {
-        "CUDA": jagged2_to_padded_dense,
-        "AutogradCUDA": jagged2_to_padded_dense,
+    "sll_jagged2_to_padded_dense": {
         "CPU": cpu_jagged2_to_padded_dense,
         "AutogradCPU": cpu_jagged2_to_padded_dense,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_elementwise_mul_jagged_out",
-    {
-        "CUDA": jagged_dense_elementwise_mul_jagged_out,
-        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
         "CPU": cpu_jagged_dense_elementwise_mul_jagged_out,
         "AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out,
         "Meta": meta_jagged_dense_elementwise_mul_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_softmax",
-    {
-        "CUDA": jagged_softmax,
-        "AutogradCUDA": jagged_softmax,
+    "sll_jagged_softmax": {
         "CPU": cpu_jagged_softmax,
         "AutogradCPU": cpu_jagged_softmax,
     },
-)
-
-lib.register(
-    "sll_jagged2_softmax",
-    {
-        "CUDA": jagged2_softmax,
-        "AutogradCUDA": jagged2_softmax,
+    "sll_jagged2_softmax": {
         "CPU": cpu_jagged2_softmax,
         "AutogradCPU": cpu_jagged2_softmax,
         "AutogradMeta": meta_jagged2_softmax,
     },
-)
-
-lib.register(
-    "sll_array_jagged_bmm_jagged_out",
-    {
-        "CUDA": array_jagged_bmm_jagged_out,
-        "AutogradCUDA": array_jagged_bmm_jagged_out,
+    "sll_array_jagged_bmm_jagged_out": {
         "CPU": cpu_array_jagged_bmm_jagged_out,
         "AutogradCPU": cpu_array_jagged_bmm_jagged_out,
         "AutogradMeta": meta_array_jagged_bmm_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_jagged_bmm_jagged_out",
-    {
-        "CUDA": jagged_jagged_bmm_jagged_out,
-        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
+    "sll_jagged_jagged_bmm_jagged_out": {
         "CPU": cpu_jagged_jagged_bmm_jagged_out,
         "AutogradCPU": cpu_jagged_jagged_bmm_jagged_out,
         "AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_flash_attention_basic",
-    {
-        "CUDA": jagged_flash_attention_basic,
-        "AutogradCUDA": jagged_flash_attention_basic,
+    "sll_jagged_flash_attention_basic": {
         "CPU": cpu_jagged_flash_attention_basic,
         "AutogradCPU": cpu_jagged_flash_attention_basic,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_elementwise_add",
-    {
-        "CUDA": jagged_dense_elementwise_add,
-        "AutogradCUDA": jagged_dense_elementwise_add,
+    "sll_jagged_dense_elementwise_add": {
         "CPU": cpu_jagged_dense_elementwise_add,
         "AutogradCPU": cpu_jagged_dense_elementwise_add,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_flash_attention",
-    {
-        "CUDA": jagged_dense_flash_attention,
-        "AutogradCUDA": jagged_dense_flash_attention,
+    "sll_jagged_dense_flash_attention": {
         "CPU": cpu_jagged_dense_flash_attention,
         "AutogradCPU": cpu_jagged_dense_flash_attention,
     },
-)
+}
 
-lib.register(
-    "sll_multi_head_jagged_flash_attention",
-    {
+# pyre-ignore[5]
+sll_gpu_registrations = {
+    "sll_jagged_dense_bmm": {
+        "CUDA": jagged_dense_bmm,
+        "AutogradCUDA": jagged_dense_bmm,
+    },
+    "sll_jagged_jagged_bmm": {
+        "CUDA": jagged_jagged_bmm,
+        "AutogradCUDA": jagged_jagged_bmm,
+    },
+    "sll_dense_jagged_cat_jagged_out": {
+        "CUDA": dense_jagged_cat_jagged_out,
+    },
+    "sll_jagged_self_substraction_jagged_out": {
+        "CUDA": triton_jagged_self_substraction_jagged_out,
+    },
+    "sll_jagged2_to_padded_dense": {
+        "CUDA": jagged2_to_padded_dense,
+        "AutogradCUDA": jagged2_to_padded_dense,
+    },
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
+        "CUDA": jagged_dense_elementwise_mul_jagged_out,
+        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
+    },
+    "sll_jagged_softmax": {
+        "CUDA": jagged_softmax,
+        "AutogradCUDA": jagged_softmax,
+    },
+    "sll_jagged2_softmax": {
+        "CUDA": jagged2_softmax,
+        "AutogradCUDA": jagged2_softmax,
+    },
+    "sll_array_jagged_bmm_jagged_out": {
+        "CUDA": array_jagged_bmm_jagged_out,
+        "AutogradCUDA": array_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_jagged_bmm_jagged_out": {
+        "CUDA": jagged_jagged_bmm_jagged_out,
+        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_flash_attention_basic": {
+        "CUDA": jagged_flash_attention_basic,
+        "AutogradCUDA": jagged_flash_attention_basic,
+    },
+    "sll_jagged_dense_elementwise_add": {
+        "CUDA": jagged_dense_elementwise_add,
+        "AutogradCUDA": jagged_dense_elementwise_add,
+    },
+    "sll_jagged_dense_flash_attention": {
+        "CUDA": jagged_dense_flash_attention,
+        "AutogradCUDA": jagged_dense_flash_attention,
+    },
+    "sll_multi_head_jagged_flash_attention": {
         "CUDA": multi_head_jagged_flash_attention,
         "AutogradCUDA": multi_head_jagged_flash_attention,
     },
-)
+}
+
+for op_name, dispatches in sll_cpu_registrations.items():
+    lib.register(op_name, dispatches)
+
+if torch.cuda.is_available():
+    for op_name, dispatches in sll_gpu_registrations.items():
+        lib.register(op_name, dispatches)
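
The NOTE retained in the hunk above flags that registering the same Python function for both the backend key (CPU/CUDA) and the Autograd key (AutogradCPU/AutogradCUDA) is suboptimal for inference, where nothing needs to be saved for backward. Below is a minimal sketch of that distinction with the public torch.library API; the sll_demo namespace, the sll_double op, and its kernels are hypothetical, not FBGEMM code.

import torch

lib = torch.library.Library("sll_demo", "DEF")
lib.define("sll_double(Tensor x) -> Tensor")


def double_inference(x: torch.Tensor) -> torch.Tensor:
    # Backend-key kernel: no autograd bookkeeping, nothing saved for backward.
    return x * 2


class DoubleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Autograd-key kernel: saves context so backward can run.
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output * 2


lib.impl("sll_double", double_inference, "CPU")
lib.impl("sll_double", DoubleFunction.apply, "AutogradCPU")

x = torch.ones(3, requires_grad=True)
torch.ops.sll_demo.sll_double(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])

Calls on tensors that do not require grad hit the plain CPU kernel and skip the context save entirely, which is the inference path the NOTE is referring to.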

fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py

Lines changed: 0 additions & 72 deletions

@@ -213,19 +213,6 @@ def cpu_jagged_self_substraction_jagged_out(
     return jagged_B
 
 
-def meta_jagged_self_substraction_jagged_out(
-    jagged_A: torch.Tensor,
-    offsets_a: torch.Tensor,
-    offsets_b: torch.Tensor,
-    max_seq_len: int,
-) -> torch.Tensor:
-    return torch.empty(
-        [torch.library.get_ctx().new_dynamic_size()],
-        dtype=jagged_A.dtype,
-        device=jagged_A.device,
-    )
-
-
 def cpu_jagged2_to_padded_dense(
     values: torch.Tensor,
     offsets: torch.Tensor,
@@ -352,65 +339,6 @@ def cpu_jagged_dense_elementwise_mul_jagged_out(
     )
 
 
-class MetaJaggedDenseElementwiseMul(torch.autograd.Function):
-    @staticmethod
-    # pyre-fixme
-    def forward(
-        ctx,  # pyre-ignore [2]
-        x: torch.Tensor,
-        y: torch.Tensor,
-        x_seq_lengths: torch.Tensor,
-        x_offsets: torch.Tensor,
-        max_seq_len: int,
-    ) -> torch.Tensor:
-        ctx.max_seq_len = max_seq_len
-
-        ctx.save_for_backward(
-            x,
-            y,
-            x_seq_lengths,
-            x_offsets,
-        )
-
-        total_L = x.size(0)
-        jagged_C = torch.zeros((total_L), device=x.device, dtype=x.dtype)
-
-        return jagged_C
-
-    @staticmethod
-    # pyre-fixme
-    def backward(ctx, grad_output: torch.Tensor):
-        (
-            x,
-            y,
-            x_seq_lengths,
-            x_offsets,
-        ) = ctx.saved_tensors
-
-        total_L = grad_output.size(0)
-        jagged_C = torch.zeros(
-            (total_L), device=grad_output.device, dtype=grad_output.dtype
-        )
-
-        return jagged_C, None, None, None, None
-
-
-def meta_jagged_dense_elementwise_mul_jagged_out(
-    x: torch.Tensor,
-    y: torch.Tensor,
-    x_seq_lengths: torch.Tensor,
-    x_offsets: torch.Tensor,
-    max_seq_len: int,
-) -> torch.Tensor:
-    return MetaJaggedDenseElementwiseMul.apply(
-        x,
-        y,
-        x_seq_lengths,
-        x_offsets,
-        max_seq_len,
-    )
-
-
 class JaggedSoftmaxCPU(torch.autograd.Function):
     @staticmethod
     # pyre-fixme
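
Both functions deleted above move out of cpu_sll.py; the import changes in __init__.py show them now coming from fbgemm_gpu.sll.meta_sll. The pattern in meta_jagged_self_substraction_jagged_out, allocating a data-dependent output length via torch.library.get_ctx().new_dynamic_size(), applies to any op whose output size cannot be derived from input shapes alone. A minimal sketch of the same idea under the newer torch.library.custom_op API; the sll_demo::flatten_ragged op and its kernels are hypothetical.

import torch


@torch.library.custom_op("sll_demo::flatten_ragged", mutates_args=())
def flatten_ragged(values: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Real kernel: the output length depends on the *data* in `lengths`.
    return values[: int(lengths.sum())].clone()


@flatten_ragged.register_fake
def _(values: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Fake/meta kernel: no data is available, so ask the library context for a
    # fresh symbolic size instead of reading `lengths`.
    out_len = torch.library.get_ctx().new_dynamic_size()
    return torch.empty([out_len], dtype=values.dtype, device=values.device)

Under fake-tensor tracing (for example torch.compile), the fake kernel runs and the unbacked size flows through shape propagation; eager calls still hit the real kernel.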
