@@ -68,9 +68,6 @@ def __init__(
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         )
         mla_metadata, num_splits = get_mla_metadata(
             forward_batch.seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.forward_metadata = FlashMLADecodeMetadata(
             mla_metadata,
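
The same two-argument change repeats in the three CUDA-graph hunks below. A minimal sketch of the new call, assuming FlashMLA's documented signature get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k); Q_LEN, num_q_heads, and seq_lens here are illustrative stand-ins, not values taken from this diff:

import torch
from flash_mla import get_mla_metadata  # assumed import path

Q_LEN = 1         # decode handles one query token per request (assumption)
num_q_heads = 16  # query heads per tensor-parallel rank (assumption)
seq_lens = torch.tensor([512, 1024], dtype=torch.int32, device="cuda")

# MLA keeps a single shared latent KV head, so num_heads_k is pinned to 1
# and all Q_LEN * num_q_heads query heads fold into the per-KV-head factor.
mla_metadata, num_splits = get_mla_metadata(
    seq_lens,
    Q_LEN * num_q_heads,  # num_heads_per_head_k
    1,                    # num_heads_k
)

Previously the call passed Q_LEN * self.num_q_heads // self.num_kv_heads and self.num_kv_heads; since the latent KV cache has one head regardless of the config's KV head count, hardcoding 1 matches the cache layout and makes the self.num_kv_heads attribute removed above unnecessary.
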
@@ -141,8 +138,8 @@ def init_cuda_graph_state(
 
         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
 
@@ -171,8 +168,8 @@ def init_forward_metadata_capture_cuda_graph(
         )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
         self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@ def init_forward_metadata_replay_cuda_graph(
         )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
         self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
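
The capture and replay hunks share a CUDA-graph-friendly pattern: metadata is recomputed for the current batch, then copied in place into buffers allocated in init_cuda_graph_state, so the captured graph keeps seeing the same tensor addresses. A hedged sketch of that step, reusing the hypothetical names from the example above:

# Recompute scheduling metadata for the live sequence lengths, then copy it
# into the preallocated buffers; copy_ mutates contents without changing the
# addresses the captured graph recorded.
mla_metadata, num_splits = get_mla_metadata(
    seq_lens.to(torch.int32),
    Q_LEN * num_q_heads,
    1,
)
cuda_graph_mla_metadata.copy_(mla_metadata)
cuda_graph_num_splits[: bs + 1].copy_(num_splits)  # only the first bs + 1 entries are live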