@@ -7,6 +7,8 @@
 
 # pyre-strict
 
+import torch
+
 from fbgemm_gpu.sll.cpu_sll import (  # noqa F401
     cpu_array_jagged_bmm_jagged_out,
     cpu_dense_jagged_cat_jagged_out,
@@ -21,14 +23,14 @@
     cpu_jagged_jagged_bmm_jagged_out,
     cpu_jagged_self_substraction_jagged_out,
     cpu_jagged_softmax,
-    meta_jagged_dense_elementwise_mul_jagged_out,
-    meta_jagged_self_substraction_jagged_out,
 )
 
 from fbgemm_gpu.sll.meta_sll import (  # noqa F401
     meta_array_jagged_bmm_jagged_out,
     meta_jagged2_softmax,
+    meta_jagged_dense_elementwise_mul_jagged_out,
     meta_jagged_jagged_bmm_jagged_out,
+    meta_jagged_self_substraction_jagged_out,
 )
 
 from fbgemm_gpu.sll.triton_sll import (  # noqa F401
@@ -208,144 +210,131 @@
 """
 )
 
-# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same function
-# however, this is not ideal because in the inference case, we don't need the autograd forward
-# to save the context because we don't need to do backward.
-lib.register(
-    "sll_jagged_dense_bmm",
-    {
-        "CUDA": jagged_dense_bmm,
-        "AutogradCUDA": jagged_dense_bmm,
+# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same
+# function. However, this is not ideal because in the inference case, we don't
+# need the autograd forward to save the context, since we don't need to do
+# backward.
+
+# pyre-ignore[5]
+sll_cpu_registrations = {
+    "sll_jagged_dense_bmm": {
         "CPU": cpu_jagged_dense_bmm,
         "AutogradCPU": cpu_jagged_dense_bmm,
     },
-)
-
-lib.register(
-    "sll_jagged_jagged_bmm",
-    {
-        "CUDA": jagged_jagged_bmm,
-        "AutogradCUDA": jagged_jagged_bmm,
+    "sll_jagged_jagged_bmm": {
         "CPU": cpu_jagged_jagged_bmm,
         "AutogradCPU": cpu_jagged_jagged_bmm,
     },
-)
-
-lib.register(
-    "sll_dense_jagged_cat_jagged_out",
-    {
-        "CUDA": dense_jagged_cat_jagged_out,
+    "sll_dense_jagged_cat_jagged_out": {
         "CPU": cpu_dense_jagged_cat_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_self_substraction_jagged_out",
-    {
-        "CUDA": triton_jagged_self_substraction_jagged_out,
+    "sll_jagged_self_substraction_jagged_out": {
         "CPU": cpu_jagged_self_substraction_jagged_out,
         "Meta": meta_jagged_self_substraction_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged2_to_padded_dense",
-    {
-        "CUDA": jagged2_to_padded_dense,
-        "AutogradCUDA": jagged2_to_padded_dense,
+    "sll_jagged2_to_padded_dense": {
         "CPU": cpu_jagged2_to_padded_dense,
         "AutogradCPU": cpu_jagged2_to_padded_dense,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_elementwise_mul_jagged_out",
-    {
-        "CUDA": jagged_dense_elementwise_mul_jagged_out,
-        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
         "CPU": cpu_jagged_dense_elementwise_mul_jagged_out,
         "AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out,
         "Meta": meta_jagged_dense_elementwise_mul_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_softmax",
-    {
-        "CUDA": jagged_softmax,
-        "AutogradCUDA": jagged_softmax,
+    "sll_jagged_softmax": {
         "CPU": cpu_jagged_softmax,
         "AutogradCPU": cpu_jagged_softmax,
     },
-)
-
-lib.register(
-    "sll_jagged2_softmax",
-    {
-        "CUDA": jagged2_softmax,
-        "AutogradCUDA": jagged2_softmax,
+    "sll_jagged2_softmax": {
         "CPU": cpu_jagged2_softmax,
         "AutogradCPU": cpu_jagged2_softmax,
         "AutogradMeta": meta_jagged2_softmax,
     },
-)
-
-lib.register(
-    "sll_array_jagged_bmm_jagged_out",
-    {
-        "CUDA": array_jagged_bmm_jagged_out,
-        "AutogradCUDA": array_jagged_bmm_jagged_out,
+    "sll_array_jagged_bmm_jagged_out": {
         "CPU": cpu_array_jagged_bmm_jagged_out,
         "AutogradCPU": cpu_array_jagged_bmm_jagged_out,
         "AutogradMeta": meta_array_jagged_bmm_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_jagged_bmm_jagged_out",
-    {
-        "CUDA": jagged_jagged_bmm_jagged_out,
-        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
+    "sll_jagged_jagged_bmm_jagged_out": {
         "CPU": cpu_jagged_jagged_bmm_jagged_out,
         "AutogradCPU": cpu_jagged_jagged_bmm_jagged_out,
         "AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_flash_attention_basic",
-    {
-        "CUDA": jagged_flash_attention_basic,
-        "AutogradCUDA": jagged_flash_attention_basic,
+    "sll_jagged_flash_attention_basic": {
         "CPU": cpu_jagged_flash_attention_basic,
         "AutogradCPU": cpu_jagged_flash_attention_basic,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_elementwise_add",
-    {
-        "CUDA": jagged_dense_elementwise_add,
-        "AutogradCUDA": jagged_dense_elementwise_add,
+    "sll_jagged_dense_elementwise_add": {
         "CPU": cpu_jagged_dense_elementwise_add,
         "AutogradCPU": cpu_jagged_dense_elementwise_add,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_flash_attention",
-    {
-        "CUDA": jagged_dense_flash_attention,
-        "AutogradCUDA": jagged_dense_flash_attention,
+    "sll_jagged_dense_flash_attention": {
         "CPU": cpu_jagged_dense_flash_attention,
         "AutogradCPU": cpu_jagged_dense_flash_attention,
     },
-)
+}
 
-lib.register(
-    "sll_multi_head_jagged_flash_attention",
-    {
+# pyre-ignore[5]
+sll_gpu_registrations = {
+    "sll_jagged_dense_bmm": {
+        "CUDA": jagged_dense_bmm,
+        "AutogradCUDA": jagged_dense_bmm,
+    },
+    "sll_jagged_jagged_bmm": {
+        "CUDA": jagged_jagged_bmm,
+        "AutogradCUDA": jagged_jagged_bmm,
+    },
+    "sll_dense_jagged_cat_jagged_out": {
+        "CUDA": dense_jagged_cat_jagged_out,
+    },
+    "sll_jagged_self_substraction_jagged_out": {
+        "CUDA": triton_jagged_self_substraction_jagged_out,
+    },
+    "sll_jagged2_to_padded_dense": {
+        "CUDA": jagged2_to_padded_dense,
+        "AutogradCUDA": jagged2_to_padded_dense,
+    },
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
+        "CUDA": jagged_dense_elementwise_mul_jagged_out,
+        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
+    },
+    "sll_jagged_softmax": {
+        "CUDA": jagged_softmax,
+        "AutogradCUDA": jagged_softmax,
+    },
+    "sll_jagged2_softmax": {
+        "CUDA": jagged2_softmax,
+        "AutogradCUDA": jagged2_softmax,
+    },
+    "sll_array_jagged_bmm_jagged_out": {
+        "CUDA": array_jagged_bmm_jagged_out,
+        "AutogradCUDA": array_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_jagged_bmm_jagged_out": {
+        "CUDA": jagged_jagged_bmm_jagged_out,
+        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_flash_attention_basic": {
+        "CUDA": jagged_flash_attention_basic,
+        "AutogradCUDA": jagged_flash_attention_basic,
+    },
+    "sll_jagged_dense_elementwise_add": {
+        "CUDA": jagged_dense_elementwise_add,
+        "AutogradCUDA": jagged_dense_elementwise_add,
+    },
+    "sll_jagged_dense_flash_attention": {
+        "CUDA": jagged_dense_flash_attention,
+        "AutogradCUDA": jagged_dense_flash_attention,
+    },
+    "sll_multi_head_jagged_flash_attention": {
         "CUDA": multi_head_jagged_flash_attention,
         "AutogradCUDA": multi_head_jagged_flash_attention,
     },
-)
+}
+
+for op_name, dispatches in sll_cpu_registrations.items():
+    lib.register(op_name, dispatches)
+
+if torch.cuda.is_available():
+    for op_name, dispatches in sll_gpu_registrations.items():
+        lib.register(op_name, dispatches)
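
For readers unfamiliar with this registration pattern, below is a minimal, self-contained sketch (not the fbgemm_gpu implementation) of the same idea written against torch.library directly: per-op dispatch tables for CPU and CUDA, with the CUDA table applied only when torch.cuda.is_available(), presumably so CPU-only environments never register the Triton/CUDA kernels. The namespace toy_sll, the op toy_scale, and both kernels are hypothetical and exist only for illustration.

import torch
from torch.library import Library

# Hypothetical operator library used only for this sketch.
toy_lib = Library("toy_sll", "DEF")
toy_lib.define("toy_scale(Tensor x, float alpha) -> Tensor")


def toy_scale_cpu(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Reference CPU kernel.
    return x * alpha


def toy_scale_cuda(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Stand-in for a real Triton/CUDA kernel; same math, different dispatch key.
    return x * alpha


# Registration tables keyed by op name, mirroring sll_cpu_registrations /
# sll_gpu_registrations in the diff above.
cpu_registrations = {"toy_scale": {"CPU": toy_scale_cpu}}
gpu_registrations = {"toy_scale": {"CUDA": toy_scale_cuda}}

for op_name, dispatches in cpu_registrations.items():
    for dispatch_key, kernel in dispatches.items():
        toy_lib.impl(op_name, kernel, dispatch_key)

# Only register CUDA kernels when a usable GPU build is present.
if torch.cuda.is_available():
    for op_name, dispatches in gpu_registrations.items():
        for dispatch_key, kernel in dispatches.items():
            toy_lib.impl(op_name, kernel, dispatch_key)

print(torch.ops.toy_sll.toy_scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])

Registering the same function under both CUDA and AutogradCUDA, as the real tables do, keeps one code path for training and inference; the NOTE in the diff points out that a dedicated non-autograd forward could skip saving the context when only inference is needed.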