Skip to content

Commit 0626f67

Browse files
authored
[RL] support update_weights_from_distributed with different group and multiple weights (#7292)
1 parent 09e699b commit 0626f67

File tree

6 files changed

+73
-38
lines changed

6 files changed

+73
-38
lines changed

python/sglang/srt/entrypoints/engine.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,21 @@ def init_weights_update_group(
418418
self.tokenizer_manager.init_weights_update_group(obj, None)
419419
)
420420

421-
def update_weights_from_distributed(self, name: str, dtype, shape):
421+
def update_weights_from_distributed(
422+
self,
423+
names: list[str],
424+
dtypes: list[str],
425+
shapes: list[list[int]],
426+
group_name: str = "weight_update_group",
427+
flush_cache: bool = True,
428+
):
422429
"""Update weights from distributed source."""
423430
obj = UpdateWeightsFromDistributedReqInput(
424-
name=name,
425-
dtype=dtype,
426-
shape=shape,
431+
names=names,
432+
dtypes=dtypes,
433+
shapes=shapes,
434+
group_name=group_name,
435+
flush_cache=flush_cache,
427436
)
428437
loop = asyncio.get_event_loop()
429438
return loop.run_until_complete(

python/sglang/srt/managers/io_struct.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput:
752752

753753
@dataclass
754754
class UpdateWeightsFromDistributedReqInput:
755-
name: str
756-
dtype: str
757-
shape: List[int]
755+
names: List[str]
756+
dtypes: List[str]
757+
shapes: List[List[int]]
758+
# The group name
759+
group_name: str = "weight_update_group"
760+
# Whether to flush the cache after updating weights
761+
flush_cache: bool = True
758762

759763

760764
@dataclass

python/sglang/srt/managers/scheduler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,8 +2303,9 @@ def update_weights_from_distributed(
23032303
"""Update the online model parameter."""
23042304
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
23052305
if success:
2306-
flush_cache_success = self.flush_cache()
2307-
assert flush_cache_success, "Cache flush failed after updating weights"
2306+
if recv_req.flush_cache:
2307+
flush_cache_success = self.flush_cache()
2308+
assert flush_cache_success, "Cache flush failed after updating weights"
23082309
else:
23092310
logger.error(message)
23102311
return UpdateWeightsFromDistributedReqOutput(success, message)

python/sglang/srt/managers/tp_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def update_weights_from_distributed(
259259
self, recv_req: UpdateWeightsFromDistributedReqInput
260260
):
261261
success, message = self.model_runner.update_weights_from_distributed(
262-
recv_req.name, recv_req.dtype, recv_req.shape
262+
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
263263
)
264264
return success, message
265265

python/sglang/srt/model_executor/model_runner.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def __init__(
225225
self.support_pp = (
226226
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
227227
)
228+
self._model_update_group = {}
228229

229230
def initialize(self, min_per_gpu_memory: float):
230231
server_args = self.server_args
@@ -744,7 +745,7 @@ def init_weights_update_group(
744745
)
745746

746747
try:
747-
self._model_update_group = init_custom_process_group(
748+
self._model_update_group[group_name] = init_custom_process_group(
748749
backend=backend,
749750
init_method=f"tcp://{master_address}:{master_port}",
750751
world_size=world_size,
@@ -757,7 +758,7 @@ def init_weights_update_group(
757758
logger.error(message)
758759
return False, message
759760

760-
def update_weights_from_distributed(self, name, dtype, shape):
761+
def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
761762
"""
762763
Update specific parameter in the model weights online
763764
through `_model_update_group` process group.
@@ -767,19 +768,34 @@ def update_weights_from_distributed(self, name, dtype, shape):
767768
dtype: the data type of the parameter to be updated.
768769
shape: the shape of the parameter to be updated.
769770
"""
770-
target_dtype = (
771-
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
772-
)
773771

774-
assert (
775-
self._model_update_group is not None
776-
), "model update group must be initialized"
772+
assert group_name in self._model_update_group, (
773+
f"Group {group_name} not in {list(self._model_update_group.keys())}. "
774+
"Please call `init_weights_update_group` first."
775+
)
777776

778777
try:
779-
weights = torch.empty(shape, dtype=target_dtype, device=self.device)
780-
torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
781-
self.model.load_weights([(name, weights)])
782-
return True, f"Succeeded to update parameter {name} online."
778+
weights = []
779+
handles = []
780+
for name, dtype, shape in zip(names, dtypes, shapes):
781+
target_dtype = (
782+
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
783+
)
784+
weight = torch.empty(shape, dtype=target_dtype, device=self.device)
785+
handles.append(
786+
torch.distributed.broadcast(
787+
weight,
788+
src=0,
789+
group=self._model_update_group[group_name],
790+
async_op=True,
791+
)
792+
)
793+
weights.append((name, weight))
794+
for handle in handles:
795+
handle.wait()
796+
797+
self.model.load_weights(weights)
798+
return True, f"Succeeded to update parameter online."
783799

784800
except Exception as e:
785801
error_msg = (

test/srt/test_update_weights_from_distributed.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -294,22 +294,27 @@ def init_process_sgl(
294294
update_parameters.remove("lm_head.weight")
295295

296296
# Get weights from the training engine and update the inference engine.
297-
for parameter_name in update_parameters:
298-
if backend == "Engine":
299-
engine.update_weights_from_distributed(
300-
parameter_name,
301-
dtype=torch.bfloat16,
302-
shape=state_dict_key_to_shape[parameter_name],
303-
)
304-
else:
305-
requests.post(
306-
f"{url}/update_weights_from_distributed",
307-
json={
308-
"name": parameter_name,
309-
"dtype": "bfloat16",
310-
"shape": state_dict_key_to_shape[parameter_name],
311-
},
312-
)
297+
names = [parameter_name for parameter_name in update_parameters]
298+
dtypes = [torch.bfloat16 if backend == "Engine" else "bfloat16"] * len(names)
299+
shapes = [state_dict_key_to_shape[parameter_name] for parameter_name in names]
300+
301+
if backend == "Engine":
302+
engine.update_weights_from_distributed(
303+
names,
304+
dtypes=dtypes,
305+
shapes=shapes,
306+
group_name="test_parameter_update_group",
307+
)
308+
else:
309+
requests.post(
310+
f"{url}/update_weights_from_distributed",
311+
json={
312+
"names": names,
313+
"dtypes": dtypes,
314+
"shapes": shapes,
315+
"group_name": "test_parameter_update_group",
316+
},
317+
)
313318
torch.cuda.synchronize()
314319
time_end_update = time.perf_counter()
315320

0 commit comments

Comments (0)