Commit 2ed96c7

fix flashmla bug (#5272)

1 parent 2aa3f5e commit 2ed96c7

File tree

1 file changed (+8, -11 lines)

python/sglang/srt/layers/attention/flashmla_backend.py

Lines changed: 8 additions & 11 deletions
@@ -68,9 +68,6 @@ def __init__(
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         )
         mla_metadata, num_splits = get_mla_metadata(
             forward_batch.seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.forward_metadata = FlashMLADecodeMetadata(
             mla_metadata,
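Why this is the fix (an inference from the diff, not stated in the commit message): FlashMLA's get_mla_metadata takes (cache_seqlens, num_q_tokens_per_kv_head, num_kv_heads). MLA decoding attends against a single compressed latent KV head, so dividing the query count by the model's GQA KV-head count and passing that count as the number of KV heads sized the split-scheduling metadata for heads the latent cache does not have. Below is a minimal sketch of the corrected call, assuming the flash_mla package's public API as documented in the FlashMLA README; the head count and sequence lengths are hypothetical, Q_LEN is assumed to be 1 (one query token per decode step), and it needs a FlashMLA-capable GPU.

# Hedged sketch of the corrected call; import per the FlashMLA README,
# all numeric values hypothetical.
import torch
from flash_mla import get_mla_metadata

Q_LEN = 1            # decode query length per step (assumed)
num_q_heads = 16     # query heads on this TP rank (hypothetical)

seq_lens = torch.tensor([37, 512, 4096], dtype=torch.int32, device="cuda")

mla_metadata, num_splits = get_mla_metadata(
    seq_lens,                # cache_seqlens: int32, one entry per request
    Q_LEN * num_q_heads,     # query tokens per KV head; no // num_kv_heads
    1,                       # num_kv_heads: the single latent KV head of MLA
)
# num_splits has one entry per request plus one, matching the
# [: bs + 1] slicing in the CUDA-graph hunks below.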
@@ -141,8 +138,8 @@ def init_cuda_graph_state(

         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices

@@ -171,8 +168,8 @@ def init_forward_metadata_capture_cuda_graph(
         )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
         self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
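The CUDA-graph hunks pair up: init_cuda_graph_state allocates the metadata once at max_bs with dummy sequence lengths of 1, and the capture/replay paths then overwrite those buffers in place via copy_, so the tensor addresses recorded by the captured graph never move. A schematic of that preallocate-then-copy_ pattern, with hypothetical sizes and the same assumed flash_mla API as in the sketch above:

# Schematic only; sizes hypothetical. Buffers are created once at the
# maximum batch size, then refreshed in place for each real batch.
import torch
from flash_mla import get_mla_metadata

max_bs, Q_LEN, num_q_heads = 8, 1, 16    # assumed values

# Allocation-time call with dummy lengths of 1, as in the hunk above.
graph_mla_metadata, graph_num_splits = get_mla_metadata(
    torch.ones(max_bs, dtype=torch.int32, device="cuda"), Q_LEN * num_q_heads, 1
)

# Capture/replay-time refresh for an actual batch of bs requests.
bs = 3
seq_lens = torch.tensor([37, 512, 4096], dtype=torch.int32, device="cuda")
mla_metadata, num_splits = get_mla_metadata(seq_lens, Q_LEN * num_q_heads, 1)
graph_mla_metadata.copy_(mla_metadata)          # in place: address stays stable
graph_num_splits[: bs + 1].copy_(num_splits)    # only the live prefix changes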
@@ -221,8 +218,8 @@ def init_forward_metadata_replay_cuda_graph(
        )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
         self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
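All four call sites change identically, which is presumably why self.num_kv_heads could be deleted from __init__ outright: after this fix, nothing in the backend consults the model's GQA KV-head count, since the MLA latent cache it schedules against always has exactly one KV head.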
