
feat: two batch overlap #35


Draft · wants to merge 9 commits into base: main
python/sglang/srt/layers/attention/flashattention_backend.py (128 changes: 75 additions & 53 deletions)
@@ -18,6 +18,10 @@

from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

import logging

logger = logging.getLogger(__name__)


@dataclass
class FlashAttentionMetadata:
@@ -300,8 +304,8 @@ def __init__(
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
self.page_size = model_runner.page_size
self.use_mla = (
model_runner.model_config.attention_arch == AttentionArch.MLA
) and (not global_server_args_dict["disable_mla"])
model_runner.model_config.attention_arch == AttentionArch.MLA
) and (not global_server_args_dict["disable_mla"])
self.skip_prefill = skip_prefill

self.topk = topk
@@ -344,8 +348,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
else:
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -357,8 +361,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
elif forward_batch.forward_mode.is_target_verify():
metadata.cache_seqlens_int32 = (
forward_batch.seq_lens + self.speculative_num_draft_tokens
@@ -380,8 +384,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]

elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -390,8 +394,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]

if (
any(forward_batch.extend_prefix_lens_cpu)
@@ -421,8 +425,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
# Make sure effective_chunk_size is divisible by page_size
effective_chunk_size = (
effective_chunk_size // self.page_size
) * self.page_size
effective_chunk_size // self.page_size
) * self.page_size
if effective_chunk_size < self.page_size:
effective_chunk_size = self.page_size

@@ -464,16 +468,16 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
)
metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
]
forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
]

# Currently only support forward_batch.encoder_lens.numel() == 1
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices,
metadata.encoder_max_seq_len_k : (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]
forward_batch.req_pool_indices,
metadata.encoder_max_seq_len_k: (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]

# Convert the page table to a strided format which is needed by FA3 API
if self.page_size > 1:
@@ -642,7 +646,7 @@ def forward_extend(
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
k_rope = kv_cache[:, :, layer.v_head_dim:]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
@@ -655,7 +659,7 @@
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
q_rope = q_all[:, :, layer.v_head_dim:]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
@@ -770,7 +774,7 @@ def forward_decode(
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
k_rope = kv_cache[:, :, layer.v_head_dim:]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
@@ -784,7 +788,7 @@

q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
q_rope = q_all[:, :, layer.v_head_dim:]

# #logger.debug(f"layer[{layer.layer_id}], mode[{forward_batch.forward_mode}], flash_attn_with_kvcache args:\n"
# f"q.shape: {q_rope.shape}\n"
# f"k_cache.shape: {k_rope_cache.shape}\n"
# f"v_cache.shape: {c_kv_cache.shape}\n"
# f"qv.shape: {q_nope.shape}\n"
# f"page_table: {metadata.page_table.shape}\n"
# f"cache_seqlens: {metadata.cache_seqlens_int32.shape}\n"
# f"cu_seqlens_q: {metadata.cu_seqlens_q.shape}\n"
# f"cu_seqlens_k: {metadata.cu_seqlens_k.shape}\n"
# f"softmax_scale: {layer.scaling}\n"
# f"softcap: {layer.logit_cap}\n"
# f"k_descale: {k_descale}\n"
# f"v_descale: {v_descale}")

# import time
# time.sleep(0.5)

o = flash_attn_with_kvcache(
q=q_rope,
@@ -802,6 +823,7 @@ def forward_decode(
k_descale=k_descale,
v_descale=v_descale,
)
##logger.debug(f"layer[{layer.layer_id}], mode[{forward_batch.forward_mode}], flash_attn_with_kvcache with o: {o.shape}")
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

def init_cuda_graph_state(self, max_bs: int):
@@ -889,23 +911,23 @@ def init_forward_metadata_capture_cuda_graph(
if spec_info is not None:
# Draft Decode
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
"cache_seqlens"
][:bs]
"cache_seqlens"
][:bs]
metadata.max_seq_len_k = seq_lens.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
: bs + 1
]
: bs + 1
]
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = self.decode_cuda_graph_metadata[
"page_table_draft_decode"
][req_pool_indices, :]
"page_table_draft_decode"
][req_pool_indices, :]
else:
# Normal Decode
# Get sequence information
@@ -919,17 +941,17 @@ def init_forward_metadata_capture_cuda_graph(
metadata.max_seq_len_k = seq_lens.max().item()
# Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
req_pool_indices, :
]
req_pool_indices, :
]
# Precompute cumulative sequence lengths
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
self.decode_cuda_graph_metadata[bs] = metadata
elif forward_mode.is_target_verify():
metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
:bs
]
:bs
]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
)
@@ -948,27 +970,27 @@ def init_forward_metadata_capture_cuda_graph(
)

metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
: (bs + 1)
]
: (bs + 1)
]

metadata.page_table = self.target_verify_metadata["page_table"][
req_pool_indices, :
]
req_pool_indices, :
]

self.target_verify_metadata[bs] = metadata

if encoder_lens is not None:
encoder_bs = encoder_lens.numel()
metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
:encoder_bs
]
:encoder_bs
]
metadata.encoder_cu_seqlens_k = self.encoder_metadata[
"encoder_cu_seqlens_k"
][: (encoder_bs + 1)]
"encoder_cu_seqlens_k"
][: (encoder_bs + 1)]

metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
req_pool_indices, :
]
req_pool_indices, :
]

self.forward_metadata = metadata

@@ -1010,8 +1032,8 @@ def init_forward_metadata_replay_cuda_graph(
)

page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
req_pool_indices, : metadata.max_seq_len_k
]

metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
else:
@@ -1025,12 +1047,12 @@ def init_forward_metadata_replay_cuda_graph(
)

max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
None, :
None, :
],
]
page_indices //= self.page_size
@@ -1074,11 +1096,11 @@ def init_forward_metadata_replay_cuda_graph(

# Update the regular page table
page_table = self.req_to_token[
req_pool_indices,
metadata.encoder_max_seq_len_k : (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]
req_pool_indices,
metadata.encoder_max_seq_len_k: (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

self.forward_metadata = metadata
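Note on the logging added above: the diff introduces a module-level logger for flashattention_backend.py but leaves the per-layer logger.debug calls commented out, presumably because the f-strings are costly to build on every layer of every decode step. Below is a minimal sketch, not taken from this PR, of guarding such a call so the message is only formatted when DEBUG logging is active; the names layer, metadata, q_rope, k_rope_cache, and c_kv_cache are assumed to be the same objects used around the flash_attn_with_kvcache call in forward_decode.

import logging

logger = logging.getLogger(__name__)


def log_decode_attn_args(layer, metadata, q_rope, k_rope_cache, c_kv_cache):
    # Build the message only when DEBUG is enabled, so the diagnostics
    # sketched in the commented-out code above could be turned on without
    # paying the string-formatting cost in the common case.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "layer[%s] q=%s k_cache=%s v_cache=%s page_table=%s cache_seqlens=%s",
            layer.layer_id,
            tuple(q_rope.shape),
            tuple(k_rope_cache.shape),
            tuple(c_kv_cache.shape),
            tuple(metadata.page_table.shape),
            tuple(metadata.cache_seqlens_int32.shape),
        )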
python/sglang/srt/layers/moe/ep_moe/layer.py (40 changes: 32 additions & 8 deletions)
@@ -254,7 +254,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
BLOCK_SIZE=512,
)

seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
@@ -431,7 +431,7 @@ def weight_loader(
elif shard_id == "w1":
param.data[expert_id][: self.intermediate_size, :] = loaded_weight
elif shard_id == "w3":
param.data[expert_id][self.intermediate_size :, :] = loaded_weight
param.data[expert_id][self.intermediate_size:, :] = loaded_weight
else:
raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")

@@ -463,11 +463,11 @@ def _load_fp8_scale(
block_n, block_k = self.block_shape[0], self.block_shape[1]
if shard_id == "w1":
param_data[expert_id][
: (self.intermediate_size + block_n - 1) // block_n, :
: (self.intermediate_size + block_n - 1) // block_n, :
] = loaded_weight
elif shard_id == "w3":
param_data[expert_id][
(self.intermediate_size + block_n - 1) // block_n :, :
(self.intermediate_size + block_n - 1) // block_n:, :
] = loaded_weight
else: # w2
param_data[expert_id] = loaded_weight
@@ -844,12 +844,15 @@ def forward(
masked_m: torch.Tensor,
expected_m: int,
forward_mode: ForwardMode,
output_tensor: Optional[torch.Tensor] = None,
):
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
# logger.debug(
# f"DeepEPMoE forward: deepep_mode: {self.deepep_mode}, forward_mode: {forward_mode}, resolved_deepep_mode: {resolved_deepep_mode}")
if resolved_deepep_mode == DeepEPMode.normal:
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
elif resolved_deepep_mode == DeepEPMode.low_latency:
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m, output_tensor)
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

@@ -971,6 +974,7 @@ def forward_deepgemm_masked(
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
masked_m: torch.Tensor,
expected_m: int,
output_tensor: Optional[torch.Tensor] = None,
):
assert self.quant_method is not None
assert self.activation == "silu"
@@ -1023,11 +1027,31 @@ def forward_deepgemm_masked(
down_input,
get_col_major_tma_aligned_tensor(down_input_scale),
)
down_output = torch.empty(
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
# if get_tensor_model_parallel_rank()==0:
# logger.debug(f"[forward_deepgemm_masked],{get_gpu_memory_info(down_input.device)}")
# logger.debug(f"[forward_deepgemm_masked],memory_summary: {torch.cuda.memory_summary(device=0)}")
if output_tensor is None:
down_output = torch.empty(
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
else:
down_output = output_tensor
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
)

# del tensor to avoid out of memory
# del gateup_output
# del down_input
# del down_input_scale
# del down_input_fp8
# del hidden_states_fp8

# if get_tensor_model_parallel_rank()==0:
# logger.debug(f"[forward_deepgemm_masked after del tensor],memory_summary: {torch.cuda.memory_summary(device=0)}")
# logger.debug(f"[forward_deepgemm_masked after del tensor],{get_gpu_memory_info(down_output.device)}")
return down_output


def get_gpu_memory_info(device) -> str:
return f"device[{device}],cached:{torch.cuda.memory_reserved(device=device)},[[[allocated:{torch.cuda.memory_allocated(device=device)},]]]free:{torch.cuda.memory_reserved(device=device) - torch.cuda.memory_allocated(device=device)}"
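Note on the new output_tensor argument: forward_deepgemm_masked now writes the masked grouped GEMM result into a caller-provided buffer when one is passed, and only falls back to allocating a fresh bf16 tensor when it is None. Below is a minimal usage sketch, not taken from this PR; the shapes, the moe_layer name, and the loop structure are illustrative assumptions about how a caller might preallocate one output buffer and reuse it across decode steps.

import torch

# Illustrative shapes: experts per rank, max tokens per expert, hidden size.
num_groups, max_m, hidden = 8, 128, 4096
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocate the masked-GEMM output once instead of on every forward call.
reusable_out = torch.empty(
    (num_groups, max_m, hidden), device=device, dtype=torch.bfloat16
)

# Inside the decode loop (moe_layer is assumed to be a DeepEPMoE instance,
# and the other arguments mirror forward_deepgemm_masked's signature):
# out = moe_layer.forward_deepgemm_masked(
#     hidden_states_fp8, masked_m, expected_m, output_tensor=reusable_out
# )

Whether a single buffer can safely be reused across the two overlapped micro-batches depends on when the previous step's output is consumed; this sketch does not address that scheduling question.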