[RL] support update_weights_from_distributed with different group and multiple weights #7292
Conversation
Summary of Changes
Hello @zhuzilin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces significant improvements to the distributed weight update mechanism. It allows for the concurrent update of multiple model parameters and the management of distinct, named weight update groups. These changes provide greater flexibility and efficiency for dynamic model updates in distributed environments.
Highlights
- **Enhanced Distributed Weight Updates:** The `update_weights_from_distributed` functionality has been significantly enhanced to support updating multiple model weights concurrently in a single operation. Previously, it handled only one weight at a time.
- **Support for Multiple Update Groups:** The system now allows for the initialization and management of multiple distinct weight update process groups. This enables more flexible and isolated weight updates for different parts of a distributed system or different models.
- **API and Data Structure Updates:** The API for weight updates (`update_weights_from_distributed`) and its corresponding request input data structure (`UpdateWeightsFromDistributedReqInput`) have been modified to accept lists of weight names, data types, and shapes, along with a specific `group_name` for the update operation.
- **Asynchronous Broadcast for Efficiency:** The underlying implementation in `model_runner.py` now leverages asynchronous `torch.distributed.broadcast` operations for each weight, improving efficiency when updating multiple parameters simultaneously within a specified distributed group (a sketch of this pattern follows the list).
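To make the asynchronous-broadcast pattern concrete, here is a minimal sketch of what the receiving side could look like. The `update_groups` registry, the `src=0` rank, and the device choice are illustrative assumptions, not the PR's actual code:

```python
import torch
import torch.distributed as dist

# Hypothetical registry mapping group_name -> an initialized ProcessGroup.
update_groups: dict[str, dist.ProcessGroup] = {}

def recv_weights(names, dtypes, shapes, group_name):
    """Start one async broadcast per weight, then wait for all of them."""
    group = update_groups[group_name]
    handles, weights = [], []
    for name, dtype, shape in zip(names, dtypes, shapes):
        # Pre-allocate a buffer matching the incoming tensor.
        buf = torch.empty(shape, dtype=getattr(torch, dtype), device="cuda")
        # async_op=True returns a work handle, so the broadcasts can overlap.
        handles.append(dist.broadcast(buf, src=0, group=group, async_op=True))
        weights.append((name, buf))
    for handle in handles:
        handle.wait()
    return weights  # e.g., to be fed into model.load_weights(weights)
```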
Code Review
This pull request enhances the `update_weights_from_distributed` functionality by enabling the update of multiple model weights concurrently and allowing the specification of different process groups. This is achieved by modifying `UpdateWeightsFromDistributedReqInput` to accept lists for weight properties and updating `ModelRunner` to manage multiple process groups and handle batch weight updates efficiently using asynchronous broadcasts.
The changes are generally well-implemented. Key areas for improvement include:
- **Type Hinting:** Several new and modified parameters in method signatures lack type hints or have incorrect ones (e.g., `names: str` should be `names: List[str]`). Adding correct type hints will improve code clarity and help prevent errors.
- **Documentation:** Docstrings for the modified public methods (`Engine.update_weights_from_distributed` and `ModelRunner.update_weights_from_distributed`) should be updated to accurately describe the new parameters and their behavior, especially the handling of multiple weights and groups.
- **Testing:** The PR checklist indicates that unit tests have not been added. Given the distributed nature and critical functionality of weight updates, it's highly recommended to add comprehensive unit and integration tests covering single and multiple weights, different group names, and potential edge cases to ensure robustness (a test sketch follows below).
Overall, the PR introduces valuable flexibility. Addressing the points above will further improve the code quality and maintainability.
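As a concrete starting point for the testing recommendation above, a parametrized skeleton might look like the following. The `engine` fixture and `get_weight` helper are placeholders rather than real sglang test utilities:

```python
import pytest

@pytest.mark.parametrize(
    "names, group_name",
    [
        (["lm_head.weight"], "default_group"),                      # single weight
        (["layers.0.mlp.w1", "layers.0.mlp.w2"], "default_group"),  # multiple weights
        (["lm_head.weight"], "aux_group"),                          # non-default group
    ],
)
def test_update_weights_from_distributed(engine, names, group_name):
    # `engine` is a placeholder fixture wrapping a running engine;
    # `get_weight` is a hypothetical accessor used only to derive shapes.
    dtypes = ["bfloat16"] * len(names)
    shapes = [list(engine.get_weight(name).shape) for name in names]
    success, message = engine.update_weights_from_distributed(
        names, dtypes, shapes, group_name
    )
    assert success, message
```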
```diff
@@ -417,12 +417,15 @@ def init_weights_update_group(
             self.tokenizer_manager.init_weights_update_group(obj, None)
         )

-    def update_weights_from_distributed(self, name: str, dtype, shape):
+    def update_weights_from_distributed(
+        self, names: str, dtypes, shapes, group_name: str
```
The type hints for `names`, `dtypes`, and `shapes` in the method signature should be updated. Currently, `names` is incorrectly hinted as `str`, and `dtypes` and `shapes` are missing type hints. Based on the usage and the definition in `UpdateWeightsFromDistributedReqInput` (in `io_struct.py`), these should be:

- `names: List[str]`
- `dtypes: List[str]`
- `shapes: List[List[int]]`
```python
from typing import List  # Ensure List is imported if not already

# ...
def update_weights_from_distributed(
    self, names: List[str], dtypes: List[str], shapes: List[List[int]], group_name: str
):
    # ...
```
Additionally, the docstring for this method (line 423) should be updated to describe the new parameters (`names`, `dtypes`, `shapes`, `group_name`) and their expected types and roles, especially since they now handle multiple weights. For example:
"""Update one or more weights from a distributed source using a specific process group.
Args:
names: A list of names for the parameters to be updated.
dtypes: A list of data types for the parameters.
shapes: A list of shapes for the parameters.
group_name: The name of the initialized process group to use for the broadcast.
"""
Suggested change:

```diff
-        self, names: str, dtypes, shapes, group_name: str
+        self, names: List[str], dtypes: List[str], shapes: List[List[int]], group_name: str
```
```diff
@@ -720,7 +721,7 @@ def init_weights_update_group(
             logger.error(message)
             return False, message

-    def update_weights_from_distributed(self, name, dtype, shape):
+    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
```
The parameters `names`, `dtypes`, and `shapes` in the method signature are missing type hints. Based on their usage and the changes in `io_struct.py`, they should be:

- `names: List[str]`
- `dtypes: List[str]`
- `shapes: List[List[int]]`
```python
from typing import List  # Ensure List is imported if not already

# ...
def update_weights_from_distributed(self, names: List[str], dtypes: List[str], shapes: List[List[int]], group_name: str):
    # ...
```
The docstring for this method (lines 725-733) also needs to be updated to reflect these changes:

- The description should mention updating one or more parameters.
- Parameter `name` should become `names` (a list of parameter names).
- Parameter `dtype` should become `dtypes` (a list of data types).
- Parameter `shape` should become `shapes` (a list of shapes).
- Add documentation for the new `group_name` parameter.
Example updated docstring:
"""
Update one or more parameters in the model weights online
through the specified `_model_update_group` process group.
Args:
names: A list of names for the parameters to be updated.
dtypes: A list of data types for the parameters.
shapes: A list of shapes for the parameters.
group_name: The name of the process group to use for broadcasting weights.
"""
Suggested change:

```diff
-    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
+    def update_weights_from_distributed(self, names: List[str], dtypes: List[str], shapes: List[List[int]], group_name: str):
```
LGTM, waiting for unit tests. Could we also add a test for a small MoE model?
Could we add a unit test for MoE models?
This looks good to me. We can merge it as long as it passes the existing weight sync tests.
@zhuzilin Please increase the CI time for the 2-GPU job to 35 min.
Force-pushed from 2f5e710 to c358da9.
Motivation
This PR adds support for initializing multiple model update groups and updating multiple weights concurrently.
Thank you for your time reviewing this PR :)
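For illustration, a trainer-side caller of the extended API might drive a batched update roughly like this (a sketch; `engine` and `trainer_state` are placeholders, and the group name must match one created via `init_weights_update_group`):

```python
import torch

# Placeholder: a dict of trained parameter tensors to push to the engine.
trainer_state: dict[str, torch.Tensor] = ...

names = list(trainer_state.keys())
tensors = [trainer_state[name] for name in names]

# One call now updates several weights over a named process group.
engine.update_weights_from_distributed(
    names=names,
    dtypes=[str(tensor.dtype).removeprefix("torch.") for tensor in tensors],
    shapes=[list(tensor.shape) for tensor in tensors],
    group_name="rl_update_group",
)
```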
Modifications
Checklist