
Commit 754830b

l1cacheDell authored and gzy19990617 committed
Revert "finish building"
This reverts commit b1ce0a4.
1 parent b1ce0a4 · commit 754830b

File tree: 3 files changed, +12 −12 lines

csrc/dispatch_utils.h
patch_test/test_dsk.py
sageattention/core.py

csrc/dispatch_utils.h

Lines changed: 2 additions & 2 deletions
@@ -40,8 +40,8 @@
   } else if (head_dim == 128) {            \
     constexpr int HEAD_DIM = 128;          \
     __VA_ARGS__                            \
-  } else if (head_dim == 256) {            \
-    constexpr int HEAD_DIM = 256;          \
+  } else if (head_dim == 192) {            \
+    constexpr int HEAD_DIM = 192;          \
     __VA_ARGS__                            \
   } else {                                 \
     std::ostringstream err_msg;            \
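The macro above resolves a runtime head_dim to a fixed set of specialized branches and errors out otherwise; this revert swaps the 256 branch back to 192. As an illustration only (not repository code), the same branch structure expressed in Python looks like the sketch below; supported sizes other than 128 and 192 are not visible in this hunk and are therefore omitted.

def select_head_dim(head_dim: int) -> int:
    # Illustrative Python analogue of the C macro's branch structure after
    # the revert: only sizes with a dedicated branch are accepted.
    if head_dim == 128:
        return 128   # corresponds to the HEAD_DIM = 128 instantiation
    elif head_dim == 192:
        return 192   # corresponds to the HEAD_DIM = 192 instantiation
    raise ValueError(f"Unsupported head_dim: {head_dim}")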

patch_test/test_dsk.py

Lines changed: 2 additions & 2 deletions
@@ -60,7 +60,7 @@ def precision_cmp_torch(t1: torch.Tensor, t2: torch.Tensor):
 torch.cuda.synchronize()

 sim, l1, max_diff = precision_cmp_torch(o_torch_fa2.transpose(2, 1), o_sa)
-print(f"Sim and Diff of Sage Attn & torch SDPA: {sim}, {max_diff}")
+print(f"Sim and Diff of Sage Attn: {sim}, {max_diff}")

 sim, l1, max_diff = precision_cmp_torch(o_torch_fa2, o_torch_sdpa)
-print(f"Sim and Diff of Flash Attn & torch SDPA: {sim}, {max_diff}")
+print(f"Sim and Diff of Flash Attn: {sim}, {max_diff}")

sageattention/core.py

Lines changed: 8 additions & 8 deletions
@@ -925,7 +925,7 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
     torch.cuda.set_device(v.device)

     _tensor_layout = 0 if tensor_layout == "NHD" else 1
-    _is_causal = 1 if is_causal else 0
+    _is_caual = 1 if is_causal else 0
     _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
     _return_lse = 1 if return_lse else 0

@@ -939,11 +939,11 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
         q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
         k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
         v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
-    elif head_dim_og > 128 and head_dim_og < 256:
-        q = torch.nn.functional.pad(q, (0, 256 - head_dim_og))
-        k = torch.nn.functional.pad(k, (0, 256 - head_dim_og))
+    elif head_dim_og > 128 and head_dim_og < 192:
+        q = torch.nn.functional.pad(q, (0, 192 - head_dim_og))
+        k = torch.nn.functional.pad(k, (0, 192 - head_dim_og))
         # v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
-    elif head_dim_og > 256:
+    elif head_dim_og > 192:
         raise ValueError(f"Unsupported head_dim: {head_dim_og}")

     assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

@@ -977,10 +977,10 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
     v = torch.cat([v, torch.zeros(v.size(0), v_pad_len, v.size(2), v.size(3), dtype=v.dtype, device=v.device)], dim=1)

     v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
-    q_int8_nope, q_int8_pe, _ = torch.split(q_int8, [128, 64, 64], dim=-1)
-    k_int8_nope, k_int8_pe, _ = torch.split(k_int8, [128, 64, 64], dim=-1)
+    q_int8_nope, q_int8_pe = torch.split(q_int8, [128, 64], dim=-1)
+    k_int8_nope, k_int8_pe = torch.split(k_int8, [128, 64], dim=-1)

-    lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90(q_int8_nope, k_int8_nope, q_int8_pe, k_int8_pe, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
+    lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90(q_int8_nope, k_int8_nope, q_int8_pe, k_int8_pe, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)

     head_dim_og = v.shape[-1]
     o = o[..., :head_dim_og]
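The padding and split logic restored above is consistent with a DeepSeek-style head layout in which q/k carry 192 dimensions (a 128-dim "nope" part plus a 64-dim positional "pe" part) while v keeps a smaller head dim. Below is a minimal, self-contained sketch of that path under those assumptions; it uses a placeholder cast instead of the per-block INT8 quantization that core.py performs elsewhere, and the tensor names mirror the diff rather than the full function.

import torch

# Hypothetical q/k head dim strictly between 128 and 192.
head_dim_og = 160
q = torch.randn(2, 8, 256, head_dim_og)
k = torch.randn(2, 8, 256, head_dim_og)

if 128 < head_dim_og < 192:
    # Zero-pad the last dim up to 192, as in the reverted branch.
    q = torch.nn.functional.pad(q, (0, 192 - head_dim_og))
    k = torch.nn.functional.pad(k, (0, 192 - head_dim_og))
elif head_dim_og > 192:
    raise ValueError(f"Unsupported head_dim: {head_dim_og}")

# Placeholder for the real INT8 quantization step.
q_int8 = q.to(torch.int8)
k_int8 = k.to(torch.int8)

# Split the 192-dim quantized q/k into the 128-dim "nope" part and the
# 64-dim "pe" part, which the dsk_sm90 kernel takes as separate arguments.
q_int8_nope, q_int8_pe = torch.split(q_int8, [128, 64], dim=-1)
k_int8_nope, k_int8_pe = torch.split(k_int8, [128, 64], dim=-1)
print(q_int8_nope.shape, q_int8_pe.shape)  # (2, 8, 256, 128) and (2, 8, 256, 64)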
