support MTP for Furion PD-disaggregation #8

Open
wants to merge 20 commits into base: feature/pd-disaggregation
35 changes: 29 additions & 6 deletions python/sglang/srt/managers/kv_transfer_agent.py
@@ -114,13 +114,13 @@ def __init__(self,
attn_tp_cpu_group: torch.distributed.ProcessGroup = None,
device: str = "cpu"):
self.kv_buffer = {}
self.spec_info_buffer = {}
self.layer_num = layer_num
self.tp_rank = tp_rank
self.attn_tp_cpu_group = attn_tp_cpu_group
self.role = server_args.kv_transfer_config.role
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

self.attn_tp_rank, self.attn_tp_size, _ = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
@@ -166,11 +166,24 @@ def set_kv_buffer(self, req: Req) -> int:
[self.token_to_kv_pool_allocator.get_kvcache().get_key_buffer(i)[kv_indices]
for i in range(self.layer_num)]
)
kv_cache = safetensors_save({"tensor": flatten.to(self.device)})
self.kv_buffer[req.rid] = kv_cache
return len(kv_cache)
if req.speculative_algorithm is not None:
kv_cache_and_spec_info = safetensors_save({"kv_cache": flatten.to(self.device),
"top_k": req.top_k.to(self.device) if req.top_k is not None else None,
"top_k_index": req.top_k_index.to(self.device) if req.top_k_index is not None else None,
"hidden_states": req.hidden_states_spec.to(self.device) if req.hidden_states_spec is not None else None,
"verified_id": req.verified_id.to(self.device) if req.verified_id is not None else None})
logger.debug(f"kv_transfer_agent send top_k {req.top_k.shape if req.top_k is not None else None} \n"
f"top_k_index {req.top_k_index.shape if req.top_k_index is not None else None} \n"
f"hidden_states {req.hidden_states_spec.shape if req.hidden_states_spec is not None else None} \n"
f"verified_id {req.verified_id.shape if req.verified_id is not None else None}")
self.kv_buffer[req.rid] = kv_cache_and_spec_info
return len(kv_cache_and_spec_info)
else:
kv_cache = safetensors_save({"kv_cache": flatten.to(self.device)})
self.kv_buffer[req.rid] = kv_cache
return len(kv_cache)

def get_kv_buffer(self, req: Req) -> torch.Tensor:
def get_kv_buffer(self, req: Req) -> Union[torch.Tensor, dict]:
if self.attn_tp_rank == 0:
dst_ptr = self._allocate_transfer_kv_buffer(req.kv_cache_length)
# fetch kv buffer until transfer done
@@ -193,7 +206,17 @@ def get_kv_buffer(self, req: Req) -> torch.Tensor:
# free buffer
self._free_transfer_kv_buffer(dst_ptr, req.kv_cache_length)
# load to device
loaded_tensor = safetensors_load(kv_cache)["tensor"]
loaded_data = safetensors_load(kv_cache)
if len(loaded_data) > 1: # Has speculative info
loaded_tensor = {
"kv_cache": loaded_data["kv_cache"],
"top_k": loaded_data["top_k"],
"top_k_index": loaded_data["top_k_index"],
"hidden_states": loaded_data["hidden_states"],
"verified_id": loaded_data["verified_id"]
}
else:
loaded_tensor = loaded_data["kv_cache"]
else:
loaded_tensor = None
if self.attn_tp_size > 1:
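
Note: below is a minimal sketch of the serialization round trip that the new set_kv_buffer / get_kv_buffer path relies on, assuming the repo's safetensors_save / safetensors_load helpers behave like safetensors.torch.save / load (dict of tensors to bytes and back). The shapes, key values, and the st_save / st_load aliases are illustrative assumptions, not the actual Furion API.

```python
import torch
from safetensors.torch import save as st_save, load as st_load

layer_num, tokens, kv_dim, topk = 4, 16, 128, 8

# Prefill side: one flattened KV tensor plus the per-request speculative (MTP) tensors.
payload = {
    "kv_cache": torch.randn(layer_num, tokens, kv_dim),
    "top_k": torch.rand(topk),                        # draft top-k probabilities
    "top_k_index": torch.randint(0, 32000, (topk,)),  # draft top-k token ids
    "hidden_states": torch.randn(kv_dim),             # hidden state fed to the draft model
    "verified_id": torch.tensor([123]),               # last verified token id
}
blob = st_save(payload)  # bytes that travel over the KV-transfer channel

# Decode side: deserialize and branch on whether speculative info is attached,
# mirroring the `len(loaded_data) > 1` check in get_kv_buffer.
loaded = st_load(blob)
kv = loaded["kv_cache"]
if len(loaded) > 1:
    spec_info = {k: loaded[k] for k in ("top_k", "top_k_index", "hidden_states", "verified_id")}
else:
    spec_info = None

assert torch.equal(kv, payload["kv_cache"])
```

One caveat worth checking separately: plain safetensors cannot serialize None values, so the `else None` branches in set_kv_buffer presumably only matter if the helper tolerates or drops missing entries.
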
8 changes: 8 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -297,6 +297,13 @@ def __init__(
self.kv_cache_length = kv_cache_length
self.kv_cache_restored = False

# speculative decoding
self.top_k: torch.Tensor = None
self.top_k_index: torch.Tensor = None
self.hidden_states_spec: torch.Tensor = None
self.verified_id: torch.Tensor = None
self.speculative_algorithm = None

# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
@@ -845,6 +852,7 @@ def recover_for_decode(self, origin_output_ids: List[List[int]]):
extend_input_logprob_token_ids = []

pt = 0
logger.info(f"recover for decode origin_output_ids {len(origin_output_ids)} reqs {len(reqs)}")
for i, req in enumerate(reqs):
req.req_pool_idx = req_pool_indices[i]
req.output_ids = origin_output_ids[i]
57 changes: 50 additions & 7 deletions python/sglang/srt/managers/scheduler.py
@@ -119,6 +119,9 @@
suppress_other_loggers,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

from sglang.srt.speculative.eagle_utils import EagleDraftInput

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
@@ -903,7 +906,7 @@ def log_prefill_stats(
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
self.metrics_collector.log_stats(self.stats)

def log_recover_stats(
self,
adder: PrefillAdder,
@@ -1053,7 +1056,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
if self.server_args.kv_transfer_config is not None:
if self.last_batch:
self.running_batch = self.last_batch

if self.server_args.kv_transfer_config.role == "decode":
ret = self.recover_new_prefilled_batch()
if not self.running_batch.is_empty():
@@ -1067,12 +1070,12 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
ret = None
else:
ret = self.get_new_batch_prefill()

if self.server_args.enable_dp_attention:
ret, _ = self.prepare_dp_attn_batch(ret)

return ret

# Merge the prefill batch into the running batch
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.chunked_req:
@@ -1115,13 +1118,13 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
ret, _ = self.prepare_dp_attn_batch(ret)

return ret

def recover_new_prefilled_batch(self) -> Optional[ScheduleBatch]:
if (
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
) and self.chunked_req is None:
return None

running_bs = len(self.running_batch.reqs) if self.running_batch else 0
if running_bs >= self.max_running_requests:
self.running_batch.batch_is_full = True
@@ -1216,12 +1219,42 @@ def recover_new_prefilled_batch(self) -> Optional[ScheduleBatch]:
new_batch.recover_for_decode(origin_output_ids)
# Recover kv cache from kv_transfer_agent
pt = 0
top_k = None
top_k_index = None
hidden_states = None
verified_id = None
for i in range(new_batch.batch_size()):
req = new_batch.reqs[i]
if req.kv_cache_restored:
pt += new_batch.extend_lens[i]
continue
flattened_kv_buffer = self.kv_transfer_agent.get_kv_buffer(req).to(self.device)
flattened_buffer = self.kv_transfer_agent.get_kv_buffer(req)
if new_batch.spec_algorithm is not None:
assert isinstance(flattened_buffer, dict)
flattened_kv_buffer = flattened_buffer["kv_cache"].to(self.device)
flattened_topk_buffer = flattened_buffer["top_k"].to(self.device)
flattened_topk_index_buffer = flattened_buffer["top_k_index"].to(self.device)
flattened_hidden_states_buffer = flattened_buffer["hidden_states"].to(self.device)
flattened_verified_id_buffer = flattened_buffer["verified_id"].to(self.device)
if top_k is None or top_k_index is None or hidden_states is None:
top_k = torch.zeros((new_batch.batch_size(),) + tuple(flattened_topk_buffer.shape))
top_k_index = torch.zeros((new_batch.batch_size(),) + tuple(flattened_topk_index_buffer.shape))
hidden_states = torch.zeros((new_batch.batch_size(),) + tuple(flattened_hidden_states_buffer.shape))
verified_id = torch.zeros((new_batch.batch_size(),))
req.top_k = flattened_topk_buffer
req.top_k_index = flattened_topk_index_buffer
req.hidden_states = flattened_hidden_states_buffer
req.verified_id = flattened_verified_id_buffer
top_k[i] = flattened_topk_buffer
top_k_index[i] = flattened_topk_index_buffer
hidden_states[i] = flattened_hidden_states_buffer
verified_id[i] = flattened_verified_id_buffer
logger.debug(f"{self.tp_rank} recover_new_prefilled_batch spec info top k: {req.top_k.shape} {req.top_k.device},\n"+
f" top_k_index: {req.top_k_index.shape} {req.top_k_index.device},\n " +
f" hidden_states: {req.hidden_states.shape} {req.hidden_states.device} \n"
f" verified_id: {req.verified_id.shape} {req.verified_id.device}")
else:
flattened_kv_buffer = flattened_buffer.to(self.device)
layer_kv_buffers = torch.unbind(flattened_kv_buffer, dim=0)
kv_cache_pool = self.token_to_kv_pool_allocator.get_kvcache()
for layer_id, layer_kv_buffer in enumerate(layer_kv_buffers):
@@ -1234,6 +1267,16 @@ def recover_new_prefilled_batch(self) -> Optional[ScheduleBatch]:
)
req.kv_cache_restored = True
pt += new_batch.extend_lens[i]
draft_input = EagleDraftInput()
draft_input.hidden_states = hidden_states.to(self.device)
draft_input.topk_p = top_k.to(self.device)
draft_input.topk_index = top_k_index.to(self.device)
draft_input.verified_id = verified_id.to(self.device)
logging.debug(f"\n\ndraft_input spec info top k {draft_input.topk_p.shape} {draft_input.topk_p.device} \n"
f" topk_index spec info topk_index {draft_input.topk_index.shape} {draft_input.topk_index.device}\n"
f" hidden_states: {draft_input.hidden_states.shape} {draft_input.hidden_states.device} \n"
f" verified_id: {draft_input.verified_id.shape} {draft_input.verified_id.device}")
new_batch.spec_info = draft_input
return new_batch

def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
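
Note: a compact sketch of what recover_new_prefilled_batch does with the transferred speculative tensors: collect one row per request, assemble batch-level tensors, and attach them as the batch's spec_info. The shapes, dtypes, and the SimpleNamespace stand-in for EagleDraftInput are assumptions for illustration.

```python
import torch
from types import SimpleNamespace

batch_size, topk, hidden = 3, 8, 128
device = "cpu"  # stand-in for self.device

# One dict per request, as returned by kv_transfer_agent.get_kv_buffer(req).
per_req = [
    {
        "top_k": torch.rand(topk),
        "top_k_index": torch.randint(0, 32000, (topk,)),
        "hidden_states": torch.randn(hidden),
        "verified_id": torch.tensor([7]),
    }
    for _ in range(batch_size)
]

# Stack per-request rows into batch tensors; this stands in for the
# preallocate-with-torch.zeros-and-fill-row-i pattern in the diff and
# keeps the original dtypes of the transferred tensors.
top_k = torch.stack([r["top_k"] for r in per_req]).to(device)
top_k_index = torch.stack([r["top_k_index"] for r in per_req]).to(device)
hidden_states = torch.stack([r["hidden_states"] for r in per_req]).to(device)
verified_id = torch.cat([r["verified_id"] for r in per_req]).to(device)

draft_input = SimpleNamespace(  # stand-in for EagleDraftInput
    topk_p=top_k,
    topk_index=top_k_index,
    hidden_states=hidden_states,
    verified_id=verified_id,
)
print(draft_input.topk_p.shape, draft_input.topk_index.dtype, draft_input.verified_id.shape)
```

One thing that may be worth double-checking in the diff itself: torch.zeros defaults to float32, so the preallocated top_k_index and verified_id buffers will silently cast integer token ids to float; stacking the per-request tensors as above, or passing an explicit dtype, avoids that.
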
2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek_v2.py
@@ -1084,7 +1084,7 @@ def forward(
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:

hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

if self.dp_size != 1:
12 changes: 12 additions & 0 deletions python/sglang/srt/speculative/eagle_worker.py
@@ -508,6 +508,18 @@ def forward_draft_extend(
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
# restore spec info for pd disaggregation
logger.debug(f"\n\nspec info batch.spec_info.verified_id.shape {batch.spec_info.verified_id.shape} forward_batch.spec_info.verified_id.shape {forward_batch.spec_info.verified_id.shape} \n\n")
for i,req in enumerate(batch.reqs):
req.speculative_algorithm = batch.spec_algorithm
req.top_k = batch.spec_info.topk_p[i]
req.top_k_index = batch.spec_info.topk_index[i]
req.hidden_states_spec = batch.spec_info.hidden_states[i]
req.verified_id = batch.spec_info.verified_id[i:i+1]
logger.debug(f"\n\nforward_draft_extend top_k {req.top_k.shape if req.top_k is not None else 0} \n"
f"top_k_index {req.top_k_index.shape if req.top_k_index is not None else 0} \n"
f"hidden_states {req.hidden_states_spec.shape if req.hidden_states_spec is not None else None} \n"
f"verified_id {req.verified_id.shape if req.verified_id is not None else None} \n\n")

def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
# Backup fields that will be modified in-place
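
Note: a small sketch (batch size and shapes are assumed) of the slicing convention forward_draft_extend uses when fanning the batch-level EagleDraftInput back out to individual requests: `topk_p[i]` drops the batch dimension, while `verified_id[i:i+1]` keeps a length-1 tensor rather than a 0-d scalar, which is the form the decode-side recovery code assigns into its per-batch buffer.

```python
import torch

batch_size, topk, hidden = 2, 8, 128
topk_p = torch.rand(batch_size, topk)
topk_index = torch.randint(0, 32000, (batch_size, topk))
hidden_states = torch.randn(batch_size, hidden)
verified_id = torch.arange(batch_size)

i = 1
req_top_k = topk_p[i]                # shape (topk,)   -> req.top_k
req_top_k_index = topk_index[i]      # shape (topk,)   -> req.top_k_index
req_hidden = hidden_states[i]        # shape (hidden,) -> req.hidden_states_spec
req_verified = verified_id[i:i + 1]  # shape (1,)      -> req.verified_id

print(req_top_k.shape, req_verified.shape)  # torch.Size([8]) torch.Size([1])
```

Together with the scheduler change above, this closes the loop: the prefill-side EAGLE worker attaches per-request draft state, the KV-transfer agent ships it alongside the KV cache, and the decode-side scheduler rebuilds an EagleDraftInput for the recovered batch.
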