Skip to content

Commit 54f0c1a

Browse files
committed
revert
1 parent: 3dcd436 · commit: 54f0c1a

File tree

1 file changed: +54 −57 lines changed

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 54 additions & 57 deletions
Original file line number · Diff line number · Diff line change
@@ -82,58 +82,59 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
8282
batch_size = len(seqlens_in_batch)
8383
device = seqlens_in_batch.device
8484

85-
if forward_batch.forward_mode == ForwardMode.DECODE and self.skip_prefill:
86-
metadata.cu_seqlens_q = torch.arange(
87-
0, batch_size * self.topk + 1, dtype=torch.int32, device=device
88-
)
89-
seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
90-
metadata.cache_seqlens_int32 = (
91-
(seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
92-
)
93-
metadata.cu_seqlens_k = torch.nn.functional.pad(
94-
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
95-
(1, 0),
96-
)
97-
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
98-
self.step_id + 1
99-
)
100-
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
101-
forward_batch.req_pool_indices, : metadata.max_seq_len_k
102-
] # (bsz, max_seq_len)
103-
metadata.page_table = metadata.page_table.repeat_interleave(
104-
self.topk, dim=0
105-
)
106-
cache_loc = forward_batch.out_cache_loc.view(
107-
self.speculative_num_steps, -1
108-
).T
109-
110-
# page table indices to update
111-
# [bsz, topk]
112-
row_indices = torch.arange(
113-
batch_size * self.topk, device=device, dtype=torch.int32
114-
).view(batch_size, self.topk)
115-
# [max_seq_len : max_seq_len + step_id + 1]
116-
col_indices = torch.arange(
117-
forward_batch.seq_lens_cpu.max().item(),
118-
metadata.max_seq_len_k,
119-
device=device,
120-
dtype=torch.int32,
121-
)
122-
# mask for all valid page table indices
123-
valid_mask = (col_indices.view(1, -1) >= seqlens_in_batch.view(-1, 1)) & (
124-
col_indices.view(1, -1) < seq_lens_with_decode.view(-1, 1)
125-
)
126-
127-
# cache indices to read
128-
cache_indices = torch.arange(
129-
self.step_id + 1, device=device, dtype=torch.int32
130-
)
131-
132-
metadata.page_table[row_indices, col_indices] = torch.where(
133-
valid_mask,
134-
cache_loc[row_indices, cache_indices].to(torch.int32),
135-
metadata.page_table[row_indices, col_indices],
136-
)
85+
if forward_batch.forward_mode == ForwardMode.DECODE:
86+
if self.skip_prefill:
87+
metadata.cu_seqlens_q = torch.arange(
88+
0, batch_size * self.topk + 1, dtype=torch.int32, device=device
89+
)
90+
seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
91+
metadata.cache_seqlens_int32 = (
92+
(seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
93+
)
94+
metadata.cu_seqlens_k = torch.nn.functional.pad(
95+
torch.cumsum(
96+
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
97+
),
98+
(1, 0),
99+
)
100+
# .repeat_interleave(self.topk) # tensor([7, 7, 7, 8, 8, 8])
101+
# .repeat(self.topk) # tensor([7, 8, 7, 8, 7, 8])
102+
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
103+
self.step_id + 1
104+
)
105+
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
106+
forward_batch.req_pool_indices, : metadata.max_seq_len_k
107+
] # (bsz, max_seq_len)
108+
metadata.page_table = metadata.page_table.repeat_interleave(
109+
self.topk, dim=0
110+
)
111+
cache_loc = forward_batch.out_cache_loc.view(
112+
self.speculative_num_steps, -1
113+
).T
114+
115+
for idx, single_seq_len in enumerate(seq_lens_with_decode):
116+
real_bsz_start_idx = idx * self.topk
117+
real_bsz_end_idx = (idx + 1) * self.topk
118+
metadata.page_table[
119+
real_bsz_start_idx:real_bsz_end_idx,
120+
(single_seq_len - (self.step_id + 1)) : single_seq_len,
121+
] = cache_loc[
122+
real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
123+
]
124+
else:
125+
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
126+
metadata.cu_seqlens_k = torch.nn.functional.pad(
127+
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
128+
)
129+
# Precompute maximum sequence length
130+
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
131+
# Precompute page table
132+
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
133+
forward_batch.req_pool_indices, : metadata.max_seq_len_k
134+
]
135+
metadata.cu_seqlens_q = torch.arange(
136+
0, batch_size + 1, dtype=torch.int32, device=device
137+
)
137138
elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY:
138139
draft_token_num = forward_batch.spec_info.draft_token_num
139140

@@ -182,11 +183,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
182183
forward_batch.req_pool_indices, : metadata.max_seq_len_k
183184
]
184185
# Precompute cumulative sequence lengths
185-
if forward_batch.forward_mode == ForwardMode.DECODE:
186-
metadata.cu_seqlens_q = torch.arange(
187-
0, batch_size + 1, dtype=torch.int32, device=device
188-
)
189-
elif (
186+
if (
190187
any(forward_batch.extend_prefix_lens_cpu)
191188
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
192189
):

0 commit comments

Comments
 (0)