SGLang + Verl #3852
Conversation
class VerlEngine:
    def __init__(
        self,
        device_mesh_cpu: DeviceMesh,
This device mesh has only one dimension. Can we use ProcessGroup instead?
I am personally OK with whatever API here, but the original feature request #2736 seems to pass in a 1D DeviceMesh, so my default is to align with that.
EDIT: Btw, I quickly searched, but ProcessGroup does not seem to have an API like device_mesh_cpu.mesh[0].
ProcessGroup does not seem to have API like device_mesh_cpu.mesh[0].
Can we use dist.get_global_rank(group, 0) or dist.get_process_group_ranks(group)[0]?
I feel that the SGLang community is more familiar with ProcessGroup. It would be great if we could keep that consistency.
Looks reasonable. I just realized another tiny issue: it seems the DTensor weights are in the FSDP DeviceMesh, so if we want to use DTensor.redistribute to the SGLang mesh, we may need a DeviceMesh object. (Currently I do full_tensor(), following Verl's vLLM path; redistribute across meshes is not yet supported in torch, so this is not in the code, and I will wait for profiling to avoid premature optimization.)
        (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
        for name, tensor in named_tensors
    ]
    # TODO should we name it "direct" or "megatron"?
Based on its implementation, I recommend "direct".
P.S. #2736 named it "megatron", while I feel "direct" may be a bit more suitable, so I leave the question here.
@@ -269,6 +212,79 @@ def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    @staticmethod
Is this refactor necessary?
(see below)
@@ -408,6 +374,84 @@ def __exit__(self, exc_type, exc_value, traceback):
        self.engine.shutdown()
        del self.engine

    @staticmethod
Is this refactor necessary?
(see below)
python/sglang/test/runners.py
Outdated
-        mem_fraction_static=mem_fraction_static,
-        trust_remote_code=False,
+        mem_fraction_static=0.65,
+        trust_remote_code=True,
Is this change necessary? Many other tests use this code. It would be better to keep the original version.
For the changes in test/runners.py:
Firstly, I am OK either with refactoring (to avoid code duplication) or with copying (to avoid changing existing code), though I personally slightly prefer refactoring, which is why I commented # TODO Ask: is it ok to refactor test code like this in the code. Indeed, zhaochenyang20 above seems to say LGTM.
Secondly, it is refactored because, in test_verl_engine.py, I added comparison tests to ensure HuggingFace outputs are the same as SGLang outputs. test_verl_engine.py roughly mimics adhoc_verl_torchrun.py, which is a minimal modification of guangming's Verl integration test script. This is quite similar to how comparison tests are done in test_generation_models.py, so the common logic is extracted.
For trust_remote_code, IIRC some model (maybe THUDM/glm-4-9b-chat?) requires it. I copied the list of models from test_generation_models.py into test_verl_engine.py and tested them, and this model comes from there.
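The comparison-test idea can be sketched independently of the harness. Everything below is a hypothetical stand-in (the helper name, tolerance, and sample values are made up), not the actual test code:

```python
def check_logprobs_close(hf_logprobs, sgl_logprobs, tol=5e-2):
    """Hypothetical helper: assert per-token logprobs from HuggingFace
    and SGLang agree within a tolerance, which is conceptually what the
    comparison tests in test_generation_models.py do."""
    assert len(hf_logprobs) == len(sgl_logprobs)
    max_diff = max(abs(a - b) for a, b in zip(hf_logprobs, sgl_logprobs))
    assert max_diff <= tol, f"max logprob diff {max_diff} exceeds {tol}"
    return max_diff

# Made-up sample values for illustration only.
hf = [-0.12, -1.03, -2.47]
sgl = [-0.11, -1.05, -2.44]
diff = check_logprobs_close(hf, sgl)
print(round(diff, 3))  # -> 0.03
```

Extracting one such helper into runners.py is the kind of deduplication being debated here.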
I leave the decision to @zhaochenyang20 as he is more knowledgeable about this refactor's potential impact.
To me, I don't think we need to change trust_remote_code. THUDM/glm-4-9b-chat is not a widely used LLM. If we need to change the parameter for this model, we'd better delete this model from the test cases.
To me, DRY (Don't Repeat Yourself). I agree with tom's refactor.
I copied it from test_generation_models.py (the ALL_OTHER_MODELS section) - not sure whether we are allowed to delete an existing test.
python/sglang/test/runners.py
Outdated
@@ -130,7 +130,7 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        self.base_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
-           trust_remote_code=False,
+           trust_remote_code=True,
Is this change necessary? Many other tests use this code. It would be better to keep the original version.
(see above)
-        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
+        self.tokenizer = get_tokenizer(
+            model_path, torch_dtype=torch.dtype, trust_remote_code=True
+        )
Is this change necessary? Many other tests use this code. It would be better to keep the original version.
(see above)
python/sglang/test/runners.py
Outdated
-        mem_fraction_static=mem_fraction_static,
-        trust_remote_code=False,
+        mem_fraction_static=0.65,
+        trust_remote_code=True,
To me, DRY (Don't Repeat Yourself). I agree with tom's refactor.
dist.gather_object(
    obj=serialized_tensor,
    object_gather_list=gathered_serialized_tensors,
    dst=self._device_mesh_cpu.mesh.tolist()[0],
nitpick: (I am not familiar with device mesh, so this might be a stupid question.) Does self._device_mesh_cpu.mesh.tolist()[0] return the global rank for local_rank=0? Would it be clearer to use group_dst=0?
Yes, it gets the first global rank in the group. About group_dst, I am a bit confused - it seems gather_object does not provide that parameter.
Oh, it was introduced in PyTorch 2.6. Maybe we can just stick to dst.
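Sticking with dst, the call pattern looks roughly like this (single-process gloo sketch with a placeholder payload; in the PR, dst is the first global rank of the SGLang group rather than the literal 0 used here):

```python
import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29513", rank=0, world_size=1
    )

serialized_tensor = b"fake-serialized-weight"  # placeholder payload

# dst takes a *global* rank (PyTorch < 2.6 has no group_dst kwarg);
# only the destination rank passes object_gather_list.
gathered = [None] * dist.get_world_size()
dist.gather_object(
    obj=serialized_tensor,
    object_gather_list=gathered,
    dst=0,
)
assert gathered == [serialized_tensor]
```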
LGTM
LGTM
I approve this PR. Avoiding full_tensor can be explored in future PRs.
Great. All set.
Great!!
Motivation
Still WIP, marked as "ready for review" just to check CI. Update: ready for review.
Modifications
Checklist