@@ -82,58 +82,59 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
 
-        if forward_batch.forward_mode == ForwardMode.DECODE and self.skip_prefill:
-            metadata.cu_seqlens_q = torch.arange(
-                0, batch_size * self.topk + 1, dtype=torch.int32, device=device
-            )
-            seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
-            metadata.cache_seqlens_int32 = (
-                (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
-            )
-            metadata.cu_seqlens_k = torch.nn.functional.pad(
-                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
-                (1, 0),
-            )
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                self.step_id + 1
-            )
-            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
-                forward_batch.req_pool_indices, : metadata.max_seq_len_k
-            ]  # (bsz, max_seq_len)
-            metadata.page_table = metadata.page_table.repeat_interleave(
-                self.topk, dim=0
-            )
-            cache_loc = forward_batch.out_cache_loc.view(
-                self.speculative_num_steps, -1
-            ).T
-
-            # page table indices to update
-            # [bsz, topk]
-            row_indices = torch.arange(
-                batch_size * self.topk, device=device, dtype=torch.int32
-            ).view(batch_size, self.topk)
-            # [max_seq_len : max_seq_len + step_id + 1]
-            col_indices = torch.arange(
-                forward_batch.seq_lens_cpu.max().item(),
-                metadata.max_seq_len_k,
-                device=device,
-                dtype=torch.int32,
-            )
-            # mask for all valid page table indices
-            valid_mask = (col_indices.view(1, -1) >= seqlens_in_batch.view(-1, 1)) & (
-                col_indices.view(1, -1) < seq_lens_with_decode.view(-1, 1)
-            )
-
-            # cache indices to read
-            cache_indices = torch.arange(
-                self.step_id + 1, device=device, dtype=torch.int32
-            )
-
-            metadata.page_table[row_indices, col_indices] = torch.where(
-                valid_mask,
-                cache_loc[row_indices, cache_indices].to(torch.int32),
-                metadata.page_table[row_indices, col_indices],
-            )
+        if forward_batch.forward_mode == ForwardMode.DECODE:
+            if self.skip_prefill:
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size * self.topk + 1, dtype=torch.int32, device=device
+                )
+                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+                metadata.cache_seqlens_int32 = (
+                    (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                # .repeat_interleave(self.topk)  # tensor([7, 7, 7, 8, 8, 8])
+                # .repeat(self.topk)  # tensor([7, 8, 7, 8, 7, 8])
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.step_id + 1
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]  # (bsz, max_seq_len)
+                metadata.page_table = metadata.page_table.repeat_interleave(
+                    self.topk, dim=0
+                )
+                cache_loc = forward_batch.out_cache_loc.view(
+                    self.speculative_num_steps, -1
+                ).T
+
+                for idx, single_seq_len in enumerate(seq_lens_with_decode):
+                    real_bsz_start_idx = idx * self.topk
+                    real_bsz_end_idx = (idx + 1) * self.topk
+                    metadata.page_table[
+                        real_bsz_start_idx:real_bsz_end_idx,
+                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
+                    ] = cache_loc[
+                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
+                    ]
+            else:
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                # Precompute page table
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
         elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY:
             draft_token_num = forward_batch.spec_info.draft_token_num
 
@@ -182,11 +183,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
             # Precompute cumulative sequence lengths
-            if forward_batch.forward_mode == ForwardMode.DECODE:
-                metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
-                )
-            elif (
+            if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
             ):
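
For context on the `repeat_interleave` vs. `repeat` comment restored in the decode branch above, here is a minimal standalone sketch (hypothetical values, not part of the patch) of how per-request sequence lengths are expanded across the `topk` draft branches: `repeat_interleave` keeps each request's entries contiguous, which matches the row layout of the expanded page table.

```python
import torch

# Hypothetical values for illustration only: two requests with cached lengths
# 7 and 8, expanded across topk = 3 draft branches.
seq_lens = torch.tensor([7, 8], dtype=torch.int32)
topk = 3

# Contiguous per-request expansion, matching the rows produced by
# metadata.page_table.repeat_interleave(self.topk, dim=0).
print(seq_lens.repeat_interleave(topk))  # tensor([7, 7, 7, 8, 8, 8], dtype=torch.int32)

# Interleaved expansion, which would pair rows with the wrong lengths here.
print(seq_lens.repeat(topk))  # tensor([7, 8, 7, 8, 7, 8], dtype=torch.int32)
```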