
[RL] support update_weights_from_distributed with different group and multiple weights #7292


Merged

Conversation

@zhuzilin (Collaborator) commented on Jun 18, 2025

Motivation

This PR adds support for initializing multiple model update groups and updating multiple weights concurrently.

Thank you for your time reviewing this PR :)

Modifications

Checklist

@gemini-code-assist bot left a comment

Summary of Changes

Hello @zhuzilin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant improvements to the distributed weight update mechanism. It allows for the concurrent update of multiple model parameters and the management of distinct, named weight update groups. These changes provide greater flexibility and efficiency for dynamic model updates in distributed environments.

Highlights

  • Enhanced Distributed Weight Updates: The update_weights_from_distributed functionality has been significantly enhanced to support updating multiple model weights concurrently in a single operation. Previously, it handled only one weight at a time.
  • Support for Multiple Update Groups: The system now allows for the initialization and management of multiple distinct weight update process groups. This enables more flexible and isolated weight updates for different parts of a distributed system or different models.
  • API and Data Structure Updates: The API for weight updates (update_weights_from_distributed) and its corresponding request input data structure (UpdateWeightsFromDistributedReqInput) have been modified to accept lists of weight names, data types, and shapes, along with a specific group_name for the update operation.
  • Asynchronous Broadcast for Efficiency: The underlying implementation in model_runner.py now leverages asynchronous torch.distributed.broadcast operations for each weight, improving efficiency when updating multiple parameters simultaneously within a specified distributed group.
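
To illustrate the asynchronous broadcast pattern described in the last highlight, here is a minimal sketch, not the PR's actual code: the helper name, the "cuda" device choice, and resolving dtype strings via getattr are assumptions for illustration.

import torch
import torch.distributed as dist

def recv_updated_weights(names, dtypes, shapes, group, src=0):
    """Receive one broadcast per weight, overlapping the transfers."""
    handles, tensors = [], []
    for name, dtype, shape in zip(names, dtypes, shapes):
        # Allocate an empty buffer matching the announced dtype/shape.
        buf = torch.empty(shape, dtype=getattr(torch, dtype), device="cuda")
        # async_op=True returns immediately with a work handle, so every
        # broadcast is in flight before we start waiting on any of them.
        handles.append(dist.broadcast(buf, src=src, group=group, async_op=True))
        tensors.append((name, buf))
    for handle in handles:
        handle.wait()
    # The (name, tensor) pairs can then be handed to the model's weight loader.
    return tensors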

@gemini-code-assist bot left a comment

Code Review

This pull request enhances the update_weights_from_distributed functionality by enabling the update of multiple model weights concurrently and allowing the specification of different process groups. This is achieved by modifying UpdateWeightsFromDistributedReqInput to accept lists for weight properties and updating ModelRunner to manage multiple process groups and handle batch weight updates efficiently using asynchronous broadcasts.

The changes are generally well-implemented. Key areas for improvement include:

  1. Type Hinting: Several new and modified parameters in method signatures lack type hints or have incorrect ones (e.g., names: str should be names: List[str]). Adding correct type hints will improve code clarity and help prevent errors.
  2. Documentation: Docstrings for the modified public methods (Engine.update_weights_from_distributed and ModelRunner.update_weights_from_distributed) should be updated to accurately describe the new parameters and their behavior, especially the handling of multiple weights and groups.
  3. Testing: The PR checklist indicates that unit tests have not been added. Given the distributed nature and critical functionality of weight updates, it's highly recommended to add comprehensive unit and integration tests. These tests should cover scenarios with single and multiple weights, different group names, and potential edge cases to ensure robustness.

Overall, the PR introduces valuable flexibility. Addressing the points above will further improve the code quality and maintainability.
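
As a concrete illustration of the testing point above, a test could exercise the new multi-weight path along these lines. This is a hedged sketch: the engine fixture, parameter names, shapes, and group name are placeholders, and the (success, message) return for the engine wrapper is assumed from the runner-side pattern visible in the diff; only the keyword names come from this PR.

# `engine` is assumed to be an already-initialized sglang Engine whose
# "rl_weight_group" update group was created beforehand.
names = ["layers.0.mlp.gate_proj.weight", "layers.0.mlp.up_proj.weight"]  # placeholders
dtypes = ["bfloat16", "bfloat16"]
shapes = [[14336, 4096], [14336, 4096]]  # illustrative shapes

success, message = engine.update_weights_from_distributed(
    names=names, dtypes=dtypes, shapes=shapes, group_name="rl_weight_group"
)
assert success, message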

@@ -417,12 +417,15 @@ def init_weights_update_group(
            self.tokenizer_manager.init_weights_update_group(obj, None)
        )

-    def update_weights_from_distributed(self, name: str, dtype, shape):
+    def update_weights_from_distributed(
+        self, names: str, dtypes, shapes, group_name: str


medium

The type hints for names, dtypes, and shapes in the method signature should be updated. Currently, names is incorrectly hinted as str, and dtypes and shapes are missing type hints.

Based on the usage and the definition in UpdateWeightsFromDistributedReqInput (in io_struct.py), these should be:

  • names: List[str]
  • dtypes: List[str]
  • shapes: List[List[int]]
from typing import List # Ensure List is imported if not already

# ...
def update_weights_from_distributed(
    self, names: List[str], dtypes: List[str], shapes: List[List[int]], group_name: str
):
# ...

Additionally, the docstring for this method (line 423) should be updated to describe the new parameters (names, dtypes, shapes, group_name) and their expected types and roles, especially since they now handle multiple weights. For example:

"""Update one or more weights from a distributed source using a specific process group.

Args:
    names: A list of names for the parameters to be updated.
    dtypes: A list of data types for the parameters.
    shapes: A list of shapes for the parameters.
    group_name: The name of the initialized process group to use for the broadcast.
"""
Suggested change
-        self, names: str, dtypes, shapes, group_name: str
+        self, names: List[str], dtypes: List[str], shapes: List[List[int]], group_name: str
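
For reference, the list-typed request struct this comment refers to would look roughly like the following. This is a sketch reconstructed from the field names cited above; the exact definition in io_struct.py may differ.

from dataclasses import dataclass
from typing import List

@dataclass
class UpdateWeightsFromDistributedReqInput:
    # One entry per weight being updated in this request.
    names: List[str]
    dtypes: List[str]
    shapes: List[List[int]]
    # Which previously initialized process group to broadcast over.
    group_name: str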

@@ -720,7 +721,7 @@ def init_weights_update_group(
            logger.error(message)
            return False, message

-    def update_weights_from_distributed(self, name, dtype, shape):
+    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):


medium

The parameters names, dtypes, and shapes in the method signature are missing type hints. Based on their usage and the changes in io_struct.py, they should be:

  • names: List[str]
  • dtypes: List[str]
  • shapes: List[List[int]]
from typing import List # Ensure List is imported if not already

# ...
def update_weights_from_distributed(self, names: List[str], dtypes: List[str], shapes: List[List[int]], group_name: str):
# ...

The docstring for this method (lines 725-733) also needs to be updated to reflect these changes:

  • The description should mention updating one or more parameters.
  • Parameter name should be names (list of parameter names).
  • Parameter dtype should be dtypes (list of data types).
  • Parameter shape should be shapes (list of shapes).
  • Add documentation for the new group_name parameter.

Example updated docstring:

"""
Update one or more parameters in the model weights online
through the specified `_model_update_group` process group.

Args:
    names: A list of names for the parameters to be updated.
    dtypes: A list of data types for the parameters.
    shapes: A list of shapes for the parameters.
    group_name: The name of the process group to use for broadcasting weights.
"""
Suggested change
-    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
+    def update_weights_from_distributed(self, names: List[str], dtypes: List[str], shapes: List[List[int]], group_name: str):
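
Conceptually, supporting a group_name comes down to keeping one process group per name on the runner side. Below is a minimal sketch of that bookkeeping; the dict attribute and the new_group-based setup are assumptions for illustration (a cross-process trainer/inference setup would instead rendezvous via init_process_group with a TCP store), not the PR's actual implementation.

import torch.distributed as dist

class WeightUpdateGroups:
    def __init__(self):
        # group_name -> torch.distributed process group
        self._model_update_groups = {}

    def init_group(self, group_name, ranks, backend="nccl"):
        # Each named group is an independent subgroup, so updates on one
        # group cannot block or interleave with updates on another.
        self._model_update_groups[group_name] = dist.new_group(
            ranks=ranks, backend=backend
        )

    def get(self, group_name):
        return self._model_update_groups[group_name]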

@zhaochenyang20 (Collaborator) left a comment

LGTM, waiting for unit tests. And could we add a test for a small MoE model?

@zhuzilin changed the title from "[rl] support update_weights_from_distributed with different group and multiple weights" to "[RL] support update_weights_from_distributed with different group and multiple weights" on Jun 21, 2025
@zhaochenyang20 (Collaborator) commented

Could we add a unit test for MoE models?

@merrymercy (Contributor) commented

This looks good to me. We can merge it as long as it passes the existing weight sync tests.

@zhaochenyang20 (Collaborator) commented

@zhuzilin Increase the 2-GPU CI time limit to 35 min.

@zhuzilin force-pushed the feature/update_weights_from_distributed branch from 2f5e710 to c358da9 on July 2, 2025 at 07:53
@zhuzilin requested a review from zhaochenyang20 on July 2, 2025 at 10:06
@zhyncs merged commit 0626f67 into sgl-project:main on Jul 3, 2025
95 of 108 checks passed