@@ -37,6 +37,9 @@ def __init__(self, model_runner: ModelRunner):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )

         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
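This mirrors the existing `kv_indptr` buffer: a CSR-style offset array, preallocated once at `max_bs + 1` entries and sliced down to the live batch on each forward pass. A minimal sketch of the indptr convention, using made-up lengths (not part of the diff):

```python
import torch

# Entry i holds the start offset of request i's tokens in the flattened
# batch; entry i + 1 holds its end offset (exclusive).
max_bs = 8
qo_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32)

extend_seq_lens = torch.tensor([3, 5, 2])  # hypothetical per-request lengths
bs = extend_seq_lens.numel()
qo_indptr[1 : bs + 1] = torch.cumsum(extend_seq_lens, dim=0)

print(qo_indptr[: bs + 1].tolist())  # [0, 3, 8, 10]
```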
@@ -54,6 +57,9 @@ def __init__(self, model_runner: ModelRunner):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+
         if forward_batch.forward_mode.is_decode():
             attn_logits = torch.empty(
                 (
@@ -68,31 +74,59 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

             max_extend_len = None

-            kv_indptr = self.kv_indptr
-            bs = len(forward_batch.req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
-                forward_batch.req_to_token_pool.req_to_token,
+                self.req_to_token,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 kv_indptr,
                 None,
                 kv_indices,
-                forward_batch.req_to_token_pool.req_to_token.stride(0),
+                self.req_to_token.stride(0),
             )

+            qo_indptr = None
+            custom_mask = None
         else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(
+                forward_batch.extend_prefix_lens, dim=0
+            )
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

-            kv_indptr = None
-            kv_indices = None
-
-        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+        )

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
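The new `else` branch gives the extend (prefill) path the same CSR-style metadata the decode path already had: `kv_indptr`/`kv_indices` cover only the cached prefix tokens (`extend_prefix_lens`), while `qo_indptr` covers the newly extended tokens (`extend_seq_lens`). `custom_mask` is set but left as `None` in both branches, so the tuple layout is fixed even before any masking is wired in. Below is a hedged, pure-PyTorch sketch of the gather that the `create_flashinfer_kv_indices_triton` launch performs on the GPU; the loop and all tensor values are illustrative stand-ins, not the real kernel:

```python
import torch

def gather_kv_indices(req_to_token, req_pool_indices, lens, kv_indptr):
    # Concatenate each request's first lens[i] KV-cache slot ids into one
    # flat index array, in batch order (what the triton kernel produces).
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i, req_idx in enumerate(req_pool_indices.tolist()):
        length = int(lens[i])
        start = int(kv_indptr[i])
        kv_indices[start : start + length] = req_to_token[req_idx, :length]
    return kv_indices

req_to_token = torch.arange(32, dtype=torch.int32).reshape(4, 8)  # toy pool
req_pool_indices = torch.tensor([2, 0])  # pool rows owned by this batch
prefix_lens = torch.tensor([3, 5])       # stand-in for extend_prefix_lens
kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)

print(gather_kv_indices(req_to_token, req_pool_indices, prefix_lens, kv_indptr))
# tensor([16, 17, 18,  0,  1,  2,  3,  4], dtype=torch.int32)
```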
@@ -144,6 +178,8 @@ def init_forward_metadata_capture_cuda_graph(
             None,
             kv_indptr,
             kv_indices,
+            None,
+            None,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -197,19 +233,19 @@ def forward_extend(
             layer, forward_batch.out_cache_loc, k, v
         )

-        _, max_extend_len, _, _ = self.forward_metadata
+        _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
+            self.forward_metadata
+        )
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.extend_seq_lens,
-            forward_batch.extend_start_loc,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
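`forward_extend` now unpacks the six-slot metadata and hands the kernel flattened CSR metadata instead of the raw `req_to_token` pool plus per-request length tensors. A small sketch of the invariants this calling convention relies on; the assertions are illustrative, not code from the repo:

```python
import torch

def check_extend_metadata(q, qo_indptr, kv_indptr, kv_indices):
    # Flattened new tokens across the batch end exactly at the last qo offset.
    assert q.shape[0] == int(qo_indptr[-1])
    # kv_indices concatenates every request's prefix slots, so its length
    # must equal the last kv offset.
    assert kv_indices.shape[0] == int(kv_indptr[-1])
    # Both offset arrays describe the same batch: bs + 1 entries each.
    assert qo_indptr.numel() == kv_indptr.numel()
```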
@@ -235,7 +271,7 @@ def forward_decode(
         else:
             o = torch.empty_like(q)

-        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(