Skip to content

[PDDisaggregation] Async migration #3610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions lmdeploy/pytorch/disagg/backend/dlslime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import json
import os
from typing import Dict, List

from dlslime import Assignment as DLSlimeAssignment
Expand All @@ -14,6 +16,22 @@

logger = get_logger('lmdeploy')

# Feature toggle read once at import time: when the LMDEPLOY_USE_ASYNC_MIGRATION
# environment variable is set, p2p_migrate takes the asyncio callback path
# (read_batch_coroutine) instead of the blocking chunked read_batch path.
# NOTE(review): any non-empty string is truthy here, so
# LMDEPLOY_USE_ASYNC_MIGRATION=0 still enables async migration — confirm intended.
LMDEPLOY_USE_ASYNC_MIGRATION = os.environ.get('LMDEPLOY_USE_ASYNC_MIGRATION', None)


async def read_batch_coroutine(endpoint: RDMAEndpoint, batch: List[DLSlimeAssignment]):
    """Issue one batched RDMA read and await its completion without blocking.

    Bridges the endpoint's callback-style completion into asyncio: a future is
    created on the running loop and resolved from the completion callback.

    Args:
        endpoint: RDMA endpoint that performs the batched read.
        batch: assignments describing the memory regions to read.

    Returns:
        int: the completion status passed to the endpoint's callback.
        (Previously the status was awaited but discarded, so failed reads
        were indistinguishable from successful ones; callers may now check it.)
    """
    loop = asyncio.get_running_loop()
    future = loop.create_future()

    def _completion_handler(status: int):
        # The callback may fire on a non-asyncio backend thread, so hop back
        # onto the event-loop thread before touching the future.
        loop.call_soon_threadsafe(future.set_result, status)

    endpoint.read_batch_with_callback(
        batch,
        _completion_handler,
    )
    return await future


class DLSlimeMigrationManagement:

Expand Down Expand Up @@ -45,15 +63,7 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage
def connect(self, connect_request: DistServeConnectionRequest):
    """Connect the RDMA endpoint for the request's protocol to the remote
    peer, using the JSON-encoded endpoint info carried in the request."""
    self.endpoint[connect_request.protocol].connect(json.loads(connect_request.remote_endpoint_info))

def p2p_migrate(self, assignment: MigrationAssignment, async_op=False):
MAX_NUM_READ_BATCH = 4096

def split(batch: List[DLSlimeAssignment]):
batch_split = []
for i in range(0, len(batch), MAX_NUM_READ_BATCH):
batch_split.append(batch[i:i + MAX_NUM_READ_BATCH])
return batch_split

async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False):
batch = [
DLSlimeAssignment(
mr_key=assign.mr_key,
Expand All @@ -62,9 +72,21 @@ def split(batch: List[DLSlimeAssignment]):
length=assign.length,
) for assign in assignment.batch
]
batch_splited = split(batch)
for b_split in batch_splited:
self.endpoint[assignment.protocol].read_batch(b_split)

if not LMDEPLOY_USE_ASYNC_MIGRATION:
MAX_NUM_READ_BATCH = 4096

def split(batch: List[DLSlimeAssignment]):
batch_split = []
for i in range(0, len(batch), MAX_NUM_READ_BATCH):
batch_split.append(batch[i:i + MAX_NUM_READ_BATCH])
return batch_split

batch_splited = split(batch)
for b_split in batch_splited:
self.endpoint[assignment.protocol].read_batch(b_split)
else:
await read_batch_coroutine(self.endpoint[assignment.protocol], batch)


@MIGRATION_BACKENDS.register_module(MigrationBackend.DLSlime.name)
Expand All @@ -85,8 +107,8 @@ def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):
def p2p_connect(self, conn_req: DistServeConnectionRequest):
    """Establish a p2p connection on the link registered for the remote
    engine named in the request."""
    self.links[conn_req.remote_engine_id].connect(conn_req)

def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op)
async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
    """Delegate a KV-cache migration to the link for the assignment's remote
    engine and await its completion; `async_op` is forwarded unchanged."""
    await self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op)

def store(self, assignment: MigrationAssignment, async_op: bool = False):
    """Store-based migration is not supported by this backend."""
    raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def p2p_initialize(self, migration_init_request: DistServeInitRequest):
def p2p_connect(self, migration_conn_request: DistServeConnectionRequest):
    """Connect this rank's migration backend, picking the per-TP-rank entry
    out of the (indexable) connection request."""
    self.migration_backend_impl.p2p_connect(migration_conn_request[self.tp_rank])

def migrate(self, migration_execution_inputs: MigrationExecutionBatch):
async def migrate(self, migration_execution_inputs: MigrationExecutionBatch):

def get_assignment_len():
head_dim = self.model_config.get_head_size()
Expand Down Expand Up @@ -369,7 +369,7 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote
assignment_batch.extend(
get_assignment_batch(str(i), blocks_to_migration, assignment_len, layer_stride,
remote_layer_stride))
self.migration_backend_impl.p2p_migrate(
await self.migration_backend_impl.p2p_migrate(
MigrationAssignment(
protocol=migration_execution_inputs.protocol,
remote_engine_id=remote_engine_id,
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,6 @@ def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):
return self.model_agent.cache_engine.p2p_connect(conn_request)

async def migrate(self, inputs: MigrationExecutionBatch):
return self.model_agent.cache_engine.migrate(inputs)
return await self.model_agent.cache_engine.migrate(inputs)

""" PD Disaggregation API End """
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/executor/uni_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,6 @@ def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):

async def migrate(self, batch: MigrationExecutionBatch):
"""KV Cache Migration."""
return self.model_agent.cache_engine.migrate(batch)
return await self.model_agent.cache_engine.migrate(batch)

""" PD Disaggregation API End """