@@ -925,7 +925,7 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
925
925
torch .cuda .set_device (v .device )
926
926
927
927
_tensor_layout = 0 if tensor_layout == "NHD" else 1
928
- _is_causal = 1 if is_causal else 0
928
+ _is_caual = 1 if is_causal else 0
929
929
_qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
930
930
_return_lse = 1 if return_lse else 0
931
931
@@ -939,11 +939,11 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
939
939
q = torch .nn .functional .pad (q , (0 , 128 - head_dim_og ))
940
940
k = torch .nn .functional .pad (k , (0 , 128 - head_dim_og ))
941
941
v = torch .nn .functional .pad (v , (0 , 128 - head_dim_og ))
942
- elif head_dim_og > 128 and head_dim_og < 256 :
943
- q = torch .nn .functional .pad (q , (0 , 256 - head_dim_og ))
944
- k = torch .nn .functional .pad (k , (0 , 256 - head_dim_og ))
942
+ elif head_dim_og > 128 and head_dim_og < 192 :
943
+ q = torch .nn .functional .pad (q , (0 , 192 - head_dim_og ))
944
+ k = torch .nn .functional .pad (k , (0 , 192 - head_dim_og ))
945
945
# v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
946
- elif head_dim_og > 256 :
946
+ elif head_dim_og > 192 :
947
947
raise ValueError (f"Unsupported head_dim: { head_dim_og } " )
948
948
949
949
assert q .stride (- 1 ) == 1 and k .stride (- 1 ) == 1 and v .stride (- 1 ) == 1 , "Last dim of qkv must be contiguous."
@@ -977,10 +977,10 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
977
977
v = torch .cat ([v , torch .zeros (v .size (0 ), v_pad_len , v .size (2 ), v .size (3 ), dtype = v .dtype , device = v .device )], dim = 1 )
978
978
979
979
v_fp8 , v_scale , _ = per_channel_fp8 (v , tensor_layout = tensor_layout , smooth_v = False )
980
- q_int8_nope , q_int8_pe , _ = torch .split (q_int8 , [128 , 64 , 64 ], dim = - 1 )
981
- k_int8_nope , k_int8_pe , _ = torch .split (k_int8 , [128 , 64 , 64 ], dim = - 1 )
980
+ q_int8_nope , q_int8_pe = torch .split (q_int8 , [128 , 64 ], dim = - 1 )
981
+ k_int8_nope , k_int8_pe = torch .split (k_int8 , [128 , 64 ], dim = - 1 )
982
982
983
- lse = _qattn_sm90 .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90 (q_int8_nope , k_int8_nope , q_int8_pe , k_int8_pe , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_causal , _qk_quant_gran , sm_scale , _return_lse )
983
+ lse = _qattn_sm90 .qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90 (q_int8_nope , k_int8_nope , q_int8_pe , k_int8_pe , v_fp8 , o , q_scale , k_scale , v_scale , _tensor_layout , _is_caual , _qk_quant_gran , sm_scale , _return_lse )
984
984
985
985
head_dim_og = v .shape [- 1 ]
986
986
o = o [..., :head_dim_og ]
0 commit comments