Skip to content

feat: fa3 custom ops for compatibility with PT Compile #1590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 184 additions & 73 deletions hopper/flash_attn_interface.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
Expand All @@ -15,41 +14,46 @@
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _flash_attn_forward(
q,
k,
v,
k_new,
v_new,
qv,
out,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_k_new,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
page_table,
kv_batch_idx,
leftpad_k,
rotary_cos,
rotary_sin,
seqlens_rotary,
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=1,
pack_gqa=None,
sm_margin=0):
def round_multiple(x, m):
return (x + m - 1) // m * m

@torch.library.custom_op("flash_attn::_hopper_flash_attn_forward", mutates_args=('out',), device_types="cuda")
def _hopper_flash_attn_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
k_new: Optional[torch.Tensor],
v_new: Optional[torch.Tensor],
qv: Optional[torch.Tensor],
out: Optional[torch.Tensor],
cu_seqlens_q: Optional[torch.Tensor],
cu_seqlens_k: Optional[torch.Tensor],
cu_seqlens_k_new: Optional[torch.Tensor],
seqused_q: Optional[torch.Tensor],
seqused_k: Optional[torch.Tensor],
max_seqlen_q: Optional[int],
max_seqlen_k: Optional[int],
page_table: Optional[torch.Tensor],
kv_batch_idx: Optional[torch.Tensor],
leftpad_k: Optional[torch.Tensor],
rotary_cos: Optional[torch.Tensor],
rotary_sin: Optional[torch.Tensor],
seqlens_rotary: Optional[torch.Tensor],
q_descale: Optional[torch.Tensor],
k_descale: Optional[torch.Tensor],
v_descale: Optional[torch.Tensor],
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
attention_chunk: int = 0,
softcap: float = 0.0,
rotary_interleaved: bool =True,
scheduler_metadata: Optional[torch.Tensor] = None,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
sm_margin: int = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
Expand Down Expand Up @@ -87,8 +91,8 @@ def _flash_attn_forward(
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
window_size_left,
window_size_right,
attention_chunk,
softcap,
rotary_interleaved,
Expand All @@ -99,30 +103,94 @@ def _flash_attn_forward(
)
return out, softmax_lse, *rest


def _flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
sequed_q,
sequed_k,
max_seqlen_q,
max_seqlen_k,
dq,
dk,
dv,
softmax_scale,
causal,
window_size=(-1, -1),
softcap=0.0,
deterministic=False,
sm_margin=0,
):
@torch.library.register_fake("flash_attn::_hopper_flash_attn_forward")
def _hopper_flash_attn_forward_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
k_new: Optional[torch.Tensor],
v_new: Optional[torch.Tensor],
qv: Optional[torch.Tensor],
out: Optional[torch.Tensor],
cu_seqlens_q: Optional[torch.Tensor],
cu_seqlens_k: Optional[torch.Tensor],
cu_seqlens_k_new: Optional[torch.Tensor],
seqused_q: Optional[torch.Tensor],
seqused_k: Optional[torch.Tensor],
max_seqlen_q: Optional[int],
max_seqlen_k: Optional[int],
page_table: Optional[torch.Tensor],
kv_batch_idx: Optional[torch.Tensor],
leftpad_k: Optional[torch.Tensor],
rotary_cos: Optional[torch.Tensor],
rotary_sin: Optional[torch.Tensor],
seqlens_rotary: Optional[torch.Tensor],
q_descale: Optional[torch.Tensor],
k_descale: Optional[torch.Tensor],
v_descale: Optional[torch.Tensor],
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
attention_chunk: int = 0,
softcap: float = 0.0,
rotary_interleaved: bool =True,
scheduler_metadata: Optional[torch.Tensor] = None,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
sm_margin: int = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

out = torch.empty_like(q)

is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or leftpad_k is not None
if is_varlen:
total_q, num_heads, _ = q.shape
softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
else:
batch_size, seqlen_q, num_heads, _ = q.shape
softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
if num_splits > 1:
head_size_v = v.shape[-1]
if is_varlen:
total_q, num_heads, _ = q.shape
out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device, layout=q.layout)
softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
else:
batch_size, seqlen_q, num_heads, _ = q.shape
out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device, layout=q.layout)
softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
else:
out_accum = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
softmax_lse_accum = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
return out, softmax_lse, out_accum, softmax_lse_accum

@torch.library.custom_op("flash_attn::_hopper_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _hopper_flash_attn_backward(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor],
cu_seqlens_k: Optional[torch.Tensor],
sequed_q: Optional[torch.Tensor],
sequed_k: Optional[torch.Tensor],
max_seqlen_q: Optional[int],
max_seqlen_k: Optional[int],
dq: Optional[torch.Tensor],
dk: Optional[torch.Tensor],
dv: Optional[torch.Tensor],
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
softcap: float =0.0,
deterministic: bool = False,
sm_margin: int = 0,
) -> None:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
Expand All @@ -143,14 +211,50 @@ def _flash_attn_backward(
max_seqlen_k,
softmax_scale,
causal,
window_size[0],
window_size[1],
window_size_left,
window_size_right,
softcap,
deterministic,
sm_margin,
)
return dq, dk, dv, softmax_d

@torch.library.register_fake("flash_attn::_hopper_flash_attn_backward")
def _hopper_flash_attn_backward_fake(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor],
cu_seqlens_k: Optional[torch.Tensor],
sequed_q: Optional[torch.Tensor],
sequed_k: Optional[torch.Tensor],
max_seqlen_q: Optional[int],
max_seqlen_k: Optional[int],
dq: Optional[torch.Tensor],
dk: Optional[torch.Tensor],
dv: Optional[torch.Tensor],
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
softcap: float =0.0,
deterministic: bool = False,
sm_margin: int = 0,
) -> None:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if dq is None:
dq = torch.empty_like(q)
if dk is None:
dk = torch.empty_like(k)
if dv is None:
dv = torch.empty_like(v)
return None

# Forward Compatibility
_flash_attn_forward = torch.ops.flash_attn._hopper_flash_attn_forward
_flash_attn_backward = torch.ops.flash_attn._hopper_flash_attn_backward

class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
Expand Down Expand Up @@ -192,7 +296,8 @@ def forward(
q_descale, k_descale, v_descale,
softmax_scale,
causal=causal,
window_size=window_size,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=attention_chunk,
softcap=softcap,
)
Expand Down Expand Up @@ -237,7 +342,8 @@ def backward(ctx, dout, *args):
dv,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.deterministic,
)
Expand Down Expand Up @@ -283,7 +389,8 @@ def forward(
q_descale, k_descale, v_descale,
softmax_scale,
causal=causal,
window_size=window_size,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=attention_chunk,
softcap=softcap,
num_splits=num_splits,
Expand Down Expand Up @@ -321,7 +428,8 @@ def backward(ctx, dout, *args):
dv,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.deterministic,
ctx.sm_margin,
Expand Down Expand Up @@ -380,7 +488,8 @@ def forward(
q_descale, k_descale, v_descale,
softmax_scale,
causal=causal,
window_size=window_size,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=attention_chunk,
softcap=softcap,
num_splits=num_splits,
Expand Down Expand Up @@ -423,7 +532,8 @@ def backward(ctx, dout, *args):
dv,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.deterministic,
ctx.sm_margin,
Expand Down Expand Up @@ -771,7 +881,8 @@ def flash_attn_with_kvcache(
q_descale, k_descale, v_descale,
softmax_scale,
causal=causal,
window_size=window_size,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=attention_chunk,
softcap=softcap,
rotary_interleaved=rotary_interleaved,
Expand Down