
Commit bb022c0

[Feature] SPMD for SGLang + Verl (sgl-project#3852)

fzyzcjy authored and ShenAo1111 committed
1 parent: a5c774a

File tree: 19 files changed, +718 −132 lines

.github/workflows/pr-test.yml
Lines changed: 6 additions & 0 deletions

@@ -149,6 +149,12 @@ jobs:
           cd test/srt
           python3 test_update_weights_from_distributed.py

+      - name: Test VerlEngine
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_verl_engine.py
+
       - name: Test expert parallelism (EP=2)
         timeout-minutes: 10
         run: |
offline_batch_inference_torchrun.py (new file)
Lines changed: 81 additions & 0 deletions

import datetime
import os
import sys

from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine


def run():
    """
    Example command:
    ```
    torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
    ```
    """

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    def _log(text):
        t = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"[{t}] [rank={rank}] {text}")

    _log(
        f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
    )

    tp_size = 4
    dp_size = 2
    assert world_size == tp_size * dp_size

    device_mesh_kwargs = dict(
        mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
    )
    device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
    _log(f"{device_mesh_cpu=}")

    tp_rank = device_mesh_cpu.get_local_rank("tp")
    dp_rank = device_mesh_cpu.get_local_rank("dp")
    _log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")

    model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
    # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9  # test large models
    # model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8

    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        if k in os.environ:
            del os.environ[k]

    fragment = VerlEngine(
        model_path=model_name,
        mem_fraction_static=mem_fraction_static,
        device_mesh_cpu=device_mesh_cpu["tp"],
        base_gpu_id=dp_rank,
        gpu_id_step=dp_size,
        port=30000,
        # for DeepSeek-V2-Lite + DP Attention
        # enable_dp_attention=True, port=30000 + dp_rank * 100,
    )
    _log(f"{fragment=}")

    prompt_all = [
        ["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
        ["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
    ]
    prompt = prompt_all[dp_rank]

    output = fragment.generate(
        prompt=prompt,
        sampling_params=dict(max_new_tokens=16, temperature=0.0),
    )
    _log(f"{prompt=} {output=}")

    fragment.shutdown()
    _log(f"End script")


if __name__ == "__main__":
    run()
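For reference, a minimal sketch (not part of the commit) of the GPU layout the example requests via `base_gpu_id=dp_rank` and `gpu_id_step=dp_size`, following the `gpu_id` formula added in engine.py below: the two DP replicas interleave across the node's 8 GPUs instead of overlapping.

    # Sketch: reproduce gpu_id = base_gpu_id + (tp_rank % tp_size_per_node) * gpu_id_step
    # for the example's tp_size=4, dp_size=2 on a single 8-GPU node.
    tp_size, dp_size = 4, 2
    for dp_rank in range(dp_size):
        base_gpu_id, gpu_id_step = dp_rank, dp_size
        gpus = [base_gpu_id + tp_rank * gpu_id_step for tp_rank in range(tp_size)]
        print(f"dp_rank={dp_rank} -> GPUs {gpus}")
    # dp_rank=0 -> GPUs [0, 2, 4, 6]
    # dp_rank=1 -> GPUs [1, 3, 5, 7]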

python/sglang/srt/entrypoints/engine.py
Lines changed: 15 additions & 4 deletions
@@ -271,10 +271,18 @@ def update_weights_from_distributed(self, name: str, dtype, shape):
             self.tokenizer_manager.update_weights_from_distributed(obj, None)
         )

-    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
-        """Update weights from distributed source."""
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],
+        load_format: Optional[str] = None,
+        flush_cache: bool = True,
+    ):
+        """Update weights from in-memory tensor data. If there are going to be more
+        updates, set `flush_cache` to False to avoid duplicated operations such as
+        clearing the cache."""
         obj = UpdateWeightsFromTensorReqInput(
-            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
+            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
+            load_format=load_format,
+            flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
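The new `flush_cache` flag exists so that a sequence of updates clears the radix cache only once. A hedged sketch of that usage pattern, assuming `engine` is an `Engine` instance; it mirrors the loop `VerlEngine.update_weights_from_tensor` uses below.

    from typing import List, Tuple
    import torch

    def stream_weight_updates(engine, named_tensors: List[Tuple[str, torch.Tensor]]):
        # Push tensors one at a time; flush the cache only on the final call.
        for i, (name, tensor) in enumerate(named_tensors):
            engine.update_weights_from_tensor(
                named_tensors=[(name, tensor)],
                flush_cache=(i == len(named_tensors) - 1),
            )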
@@ -384,7 +392,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
         )
         for tp_rank in tp_rank_range:
             reader, writer = mp.Pipe(duplex=False)
-            gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
+            gpu_id = (
+                server_args.base_gpu_id
+                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+            )
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, None, writer),
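The new `server_args.gpu_id_step` (its `ServerArgs` definition is among the 19 changed files but not shown above) presumably defaults to 1, which reduces the new formula to the old `base_gpu_id + tp_rank % tp_size_per_node` mapping; a quick illustrative check:

    # Illustrative check: with gpu_id_step=1 the new and old formulas agree.
    base_gpu_id, tp_size_per_node, gpu_id_step = 0, 8, 1
    for tp_rank in range(16):
        old = base_gpu_id + tp_rank % tp_size_per_node
        new = base_gpu_id + (tp_rank % tp_size_per_node) * gpu_id_step
        assert old == new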
python/sglang/srt/entrypoints/verl_engine.py (new file)
Lines changed: 145 additions & 0 deletions
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor

from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.server import Engine
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj


class VerlEngine:
    def __init__(
        self,
        device_mesh_cpu: DeviceMesh,
        nnodes: int = 1,
        **kwargs,
    ):
        self._device_mesh_cpu = device_mesh_cpu
        self._tp_rank = device_mesh_cpu.get_local_rank()
        self._tp_size = device_mesh_cpu.size()
        tp_size_per_node = self._tp_size // nnodes
        node_rank = self._tp_rank // tp_size_per_node
        first_rank_in_node = self._tp_rank % tp_size_per_node == 0

        if first_rank_in_node:
            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
            self._engine = Engine(
                **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
            )
        else:
            self._engine = None

        dist.barrier(group=self._device_mesh_cpu.get_group())

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a URL, or a base64-encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
    ) -> Dict:
        """
        The arguments of this function are the same as
        `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.
        """
        if self._tp_rank == 0:
            output = self._engine.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                input_ids=input_ids,
                image_data=image_data,
                return_logprob=return_logprob,
                logprob_start_len=logprob_start_len,
                top_logprobs_num=top_logprobs_num,
                lora_path=lora_path,
                custom_logit_processor=custom_logit_processor,
            )
        else:
            output = None

        # Most naive implementation; can extract the tensor and send via gloo if too slow.
        [output] = broadcast_pyobj(
            data=[output],
            rank=self._tp_rank,
            dist_group=self._device_mesh_cpu.get_group(),
            src=self._device_mesh_cpu.mesh[0].item(),
        )

        return output

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]],
        load_format: Optional[str] = None,
    ):
        # Most naive implementation; can optimize a lot if this becomes a bottleneck.
        for tensor_index, (name, tensor) in enumerate(named_tensors):
            serialized_tensor = MultiprocessingSerializer.serialize(
                _preprocess_tensor_for_update_weights(tensor)
            )

            if self._tp_rank == 0:
                gathered_serialized_tensors = [None for _ in range(self._tp_size)]
            else:
                gathered_serialized_tensors = None
            dist.gather_object(
                obj=serialized_tensor,
                object_gather_list=gathered_serialized_tensors,
                dst=self._device_mesh_cpu.mesh.tolist()[0],
                group=self._device_mesh_cpu.get_group(),
            )

            if self._tp_rank == 0:
                self._engine.update_weights_from_tensor(
                    named_tensors=[
                        (
                            name,
                            LocalSerializedTensor(values=gathered_serialized_tensors),
                        )
                    ],
                    load_format=load_format,
                    flush_cache=tensor_index == len(named_tensors) - 1,
                )

    def release_memory_occupation(self):
        if self._tp_rank == 0:
            self._engine.release_memory_occupation()

    def resume_memory_occupation(self):
        if self._tp_rank == 0:
            self._engine.resume_memory_occupation()

    def shutdown(self):
        if self._engine is not None:
            self._engine.shutdown()


def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
    if isinstance(tensor, DTensor):
        return tensor.full_tensor()
    return tensor
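The class above is SPMD: every TP rank constructs a VerlEngine, but only the first rank on each node hosts a real Engine. `generate()` results are broadcast back to the other ranks, and weight updates gather each rank's serialized values before loading. A hypothetical sketch of the trainer side, with `sync_weights` and `state_dict` as assumed names not taken from this commit:

    from typing import Dict
    import torch
    from sglang.srt.entrypoints.verl_engine import VerlEngine

    def sync_weights(engine: VerlEngine, state_dict: Dict[str, torch.Tensor]) -> None:
        # Call on every TP rank (SPMD): the gather inside
        # update_weights_from_tensor is a collective, and DTensor values are
        # expanded to full tensors by _preprocess_tensor_for_update_weights.
        engine.update_weights_from_tensor(named_tensors=list(state_dict.items()))

Because the engine flushes its radix cache only after the final tensor, syncing a whole state dict costs a single cache flush.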

python/sglang/srt/managers/data_parallel_controller.py
Lines changed: 6 additions & 2 deletions
@@ -121,7 +121,7 @@ def launch_dp_schedulers(self, server_args, port_args):
                 args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
             threads.append(thread)
-            base_gpu_id += server_args.tp_size
+            base_gpu_id += server_args.tp_size * server_args.gpu_id_step

         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:

@@ -177,7 +177,11 @@ def launch_tensor_parallel_group(
             rank_port_args.nccl_port = port_args.nccl_port

             reader, writer = mp.Pipe(duplex=False)
-            gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
+            gpu_id = (
+                server_args.base_gpu_id
+                + base_gpu_id
+                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+            )
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
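A worked example (illustrative, not from the commit) of the resulting assignment when the DP controller launches two TP groups of size 2 with `gpu_id_step=2` on one node: `base_gpu_id` grows by `tp_size * gpu_id_step` per DP rank, and each TP rank within a group is offset by `gpu_id_step`.

    # Mirrors launch_dp_schedulers + launch_tensor_parallel_group on one node.
    tp_size, dp_size, gpu_id_step, server_base_gpu_id = 2, 2, 2, 0
    base_gpu_id = 0
    for dp_rank in range(dp_size):
        gpus = [
            server_base_gpu_id + base_gpu_id + tp_rank * gpu_id_step
            for tp_rank in range(tp_size)
        ]
        print(f"dp_rank={dp_rank}: GPUs {gpus}")
        base_gpu_id += tp_size * gpu_id_step
    # dp_rank=0: GPUs [0, 2]
    # dp_rank=1: GPUs [4, 6]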

python/sglang/srt/managers/io_struct.py
Lines changed: 2 additions & 0 deletions

@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
 @dataclass
 class UpdateWeightsFromTensorReqInput:
     serialized_named_tensors: bytes  # a serialized Dict[str, torch.Tensor]
+    load_format: Optional[str]
+    flush_cache: bool


 @dataclass
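For reference, a sketch of how the two new fields travel inside the request, mirroring the engine.py change above; `named_tensors` is an assumed `List[Tuple[str, torch.Tensor]]`:

    from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
    from sglang.srt.utils import MultiprocessingSerializer

    # Sketch: Engine.update_weights_from_tensor builds the request like this.
    def build_request(named_tensors):
        return UpdateWeightsFromTensorReqInput(
            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
            load_format=None,   # optional loader hint forwarded to the model runner
            flush_cache=True,   # set False when more updates will follow
        )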

python/sglang/srt/managers/scheduler.py
Lines changed: 3 additions & 2 deletions

@@ -1760,8 +1760,9 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.tp_worker.update_weights_from_tensor(recv_req)
         # TODO: extract common code between update_weights_from_distributed and update_weights_from_tensor later
         if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromTensorReqOutput(success, message)

python/sglang/srt/managers/tp_worker.py
Lines changed: 4 additions & 1 deletion

@@ -205,7 +205,10 @@ def update_weights_from_distributed(

     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
-            MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
+            named_tensors=MultiprocessingSerializer.deserialize(
+                recv_req.serialized_named_tensors
+            ),
+            load_format=recv_req.load_format,
         )
         return success, message
