Description
Motivation
When switching between inference and training frameworks, weight conversion is required. The current implementation relies on resharding operations tailored to specific framework pairs, which works well while the number of frameworks is small. However, as framework choices grow, the cost of maintaining pairwise resharding rises sharply. A decoupled approach is therefore needed: implement a unified resharding function on each of the inference and training sides. These functions can be managed through a unified `AutoShardingManager` context manager with `__enter__()` and `__exit__()` methods, handling weights in the format `iter[tuple(str, torch.nn.module)]`.
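
To make the contract concrete, here is a minimal sketch of what a training-side weight generator could look like; the function name `gen_actor_weights` is illustrative and not part of any existing framework API.

```python
from typing import Iterator, Tuple

import torch


# Illustrative only: the shared contract is a lazy stream of
# (parameter name, parameter) pairs produced by the training side.
def gen_actor_weights(model: torch.nn.Module) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    for name, param in model.named_parameters():
        yield name, param
```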
Design:
```python
# Assumed import path for verl's existing sharding-manager base class.
from verl.workers.sharding_manager.base import BaseShardingManager


class AutoShardingManager(BaseShardingManager):
    def __init__(self, actor, rollout):
        self.actor = actor
        self.rollout = rollout

    def __enter__(self):
        # Training side exposes its weights as a lazy iter[tuple(str, torch.nn.module)] stream.
        generator = self.actor.gen_generator()
        # Inference side wakes up, consumes the stream, then rebuilds its KV cache.
        self.rollout.wakeup_weight()
        self.rollout.upload_weight(generator)
        self.rollout.wakeup_kv()

    def __exit__(self, exc_type, exc_value, traceback):
        # Release inference-side resources once rollout is done.
        self.rollout.sleep()
```
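
For context, a possible call site in the training loop might look like the following; `actor`, `rollout`, `prompts`, and `generate_sequences` are placeholders rather than confirmed verl APIs.

```python
# Hypothetical driver code: weights are streamed to the inference engine on
# entering the context, and inference-side resources are released on exit.
with AutoShardingManager(actor, rollout):
    # Rollout runs against the freshly uploaded weights.
    rollout_batch = rollout.generate_sequences(prompts)

# Outside the context the rollout engine is asleep again and training
# continues on rollout_batch.
```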
Feasibility Analysis:
- The vLLM side has already implemented a generator-based `iter[tuple(str, torch.nn.module)]` interface, which can be used directly (see the sketch after this list).
- For sglang, I submitted a PR over the weekend, and the team reached a consensus to align with vLLM by adopting the same `iter[tuple(str, torch.nn.module)]` input format.
- In verl, the generator approach for FSDP is already widely used.
- Megatron’s current implementation also employs a generator-based weight update mechanism.
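
As a rough sketch of the inference-side half of the contract, `upload_weight` could simply forward the stream to the engine. This assumes a vLLM-style `load_weights` that accepts an iterable of `(name, tensor)` pairs; the `RolloutEngine` wrapper and its attributes are hypothetical.

```python
from typing import Iterable, Tuple

import torch


class RolloutEngine:
    """Hypothetical wrapper around an inference engine such as vLLM."""

    def __init__(self, model):
        # `model` stands in for the engine's model object, assumed to expose a
        # vLLM-style load_weights(Iterable[Tuple[str, torch.Tensor]]) method.
        self.model = model

    def upload_weight(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        # Stream (name, tensor) pairs straight into the engine; no
        # framework-pair-specific resharding code lives here.
        self.model.load_weights(weights)
```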