-
-
Notifications
You must be signed in to change notification settings - Fork 8.5k
[Feature] support sequence parallelism using compilation pass #16155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 7 commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
9f4dd67
add reduce scatter op and register all gather
cascade812 09caae6
replace all reduce with reduce scatter and all gather
cascade812 84f4360
match first embedding
cascade812 165216d
update embedding replace pattern
cascade812 4318d65
compile graph only for specific shapes
cascade812 abd2953
clean code
cascade812 ca7fcb1
add test and rename
cascade812 ffb2e24
address comments
cascade812 4695110
update
cascade812 662e698
pass in dtype and device
cascade812 f60a871
Merge branch 'main' into sp_pass
tlrmchlsmth 9a72e10
enable rms_norm automatically if enable_sequence_parallelism=True
cascade812 552857c
add test for sq pass
cascade812 629e942
fix failed tests
cascade812 1a60865
fix failed tests
cascade812 0736045
fix failed tests
cascade812 534af36
address comments
cascade812 c16a197
minor fix
cascade812 82527a1
update test
cascade812 5b12ce5
test FixFunctionalizationPass with SequenceParallelismPass
cascade812 230ee3c
remove redundant code
cascade812 57d684d
Merge remote-tracking branch 'origin' into sp_pass
cascade812 8dc0422
remove the singleton pattern to support two LLM instances.
cascade812 b251ad5
nit
cascade812 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,296 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
""" | ||
WARNING: This test runs in both single-node (4 GPUs) and multi-node | ||
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is | ||
important to set the distributed backend to "mp" to avoid Ray scheduling | ||
all workers in a node other than the head node, which can cause the test | ||
to fail. | ||
""" | ||
import json | ||
import os | ||
from dataclasses import dataclass | ||
from typing import Literal, NamedTuple, Optional | ||
|
||
import pytest | ||
|
||
from vllm.config import TaskOption | ||
from vllm.logger import init_logger | ||
|
||
from ..models.registry import HF_EXAMPLE_MODELS | ||
from ..utils import compare_two_settings, create_new_process_for_each_test | ||
|
||
logger = init_logger("test_sequence_parallel") | ||
|
||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" | ||
|
||
|
||
class ParallelSetup(NamedTuple): | ||
tp_size: int | ||
sp_enabled: bool | ||
eager_mode: bool | ||
chunked_prefill: bool | ||
|
||
|
||
class SPTestOptions(NamedTuple): | ||
multi_node_only: bool | ||
load_format: Optional[str] = None | ||
|
||
|
||
@dataclass | ||
class SPTestSettings: | ||
parallel_setups: list[ParallelSetup] | ||
# NOTE: the length of distributed_backends and | ||
# vllm_major_versions should be the same, and they | ||
# are first zipped together to iterate over all | ||
# test settings. | ||
distributed_backends: list[str] | ||
# vllm major version: "0" for V0, "1" for V1 | ||
vllm_major_versions: list[str] | ||
task: TaskOption | ||
test_options: SPTestOptions | ||
|
||
def __post_init__(self): | ||
if len(self.distributed_backends) != len(self.vllm_major_versions): | ||
raise ValueError( | ||
f"Length mismatch: distributed_backends " | ||
f"({len(self.distributed_backends)}) != " | ||
f"vllm_major_versions ({len(self.vllm_major_versions)})") | ||
|
||
@staticmethod | ||
def detailed( | ||
*, | ||
tp_base: int = 2, | ||
multi_node_only: bool = False, | ||
task: TaskOption = "auto", | ||
load_format: Optional[str] = None, | ||
): | ||
return SPTestSettings( | ||
parallel_setups=[ | ||
ParallelSetup(tp_size=tp_base, | ||
sp_enabled=True, | ||
eager_mode=False, | ||
chunked_prefill=False), | ||
ParallelSetup(tp_size=tp_base, | ||
sp_enabled=True, | ||
eager_mode=False, | ||
chunked_prefill=True), | ||
ParallelSetup(tp_size=tp_base, | ||
sp_enabled=True, | ||
eager_mode=True, | ||
chunked_prefill=False), | ||
ParallelSetup(tp_size=tp_base, | ||
sp_enabled=True, | ||
eager_mode=True, | ||
chunked_prefill=True) | ||
], | ||
distributed_backends=["mp", "ray"], | ||
vllm_major_versions=["1", "1"], | ||
task=task, | ||
test_options=SPTestOptions(multi_node_only=multi_node_only, | ||
load_format=load_format), | ||
) | ||
|
||
@staticmethod | ||
def fast( | ||
*, | ||
tp_base: int = 2, | ||
task: TaskOption = "auto", | ||
multi_node_only: bool = False, | ||
load_format: Optional[str] = None, | ||
): | ||
return SPTestSettings( | ||
parallel_setups=[ | ||
ParallelSetup(tp_size=tp_base, | ||
sp_enabled=True, | ||
eager_mode=False, | ||
chunked_prefill=False), | ||
], | ||
distributed_backends=["mp", "ray"], | ||
vllm_major_versions=["1", "1"], | ||
task=task, | ||
test_options=SPTestOptions(multi_node_only=multi_node_only, | ||
load_format=load_format), | ||
) | ||
|
||
def iter_params(self, model_id: str): | ||
opts = self.test_options | ||
|
||
for parallel_setup in self.parallel_setups: | ||
for backend, vllm_major_version in zip(self.distributed_backends, | ||
self.vllm_major_versions): | ||
yield (model_id, parallel_setup, backend, vllm_major_version, | ||
self.task, opts) | ||
|
||
|
||
def _compare_sp( | ||
model_id: str, | ||
parallel_setup: ParallelSetup, | ||
distributed_backend: str, | ||
vllm_major_version: str, | ||
task: TaskOption, | ||
test_options: SPTestOptions, | ||
num_gpus_available: int, | ||
*, | ||
method: Literal["generate", "encode"], | ||
is_multimodal: bool, | ||
): | ||
( | ||
tp_size, | ||
sp_enabled, | ||
eager_mode, | ||
chunked_prefill, | ||
) = parallel_setup | ||
|
||
multi_node_only, load_format = test_options | ||
|
||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) | ||
model_info.check_transformers_version(on_fail="skip") | ||
|
||
trust_remote_code = model_info.trust_remote_code | ||
tokenizer_mode = model_info.tokenizer_mode | ||
hf_overrides = model_info.hf_overrides | ||
|
||
if load_format == "dummy": | ||
# Avoid OOM | ||
text_overrides = { | ||
"num_hidden_layers": 4, | ||
"hidden_size": 512, | ||
"intermediate_size": 800, | ||
"num_attention_heads": 4, | ||
"num_key_value_heads": 1, | ||
} | ||
|
||
if is_multimodal: | ||
hf_overrides.update({"text_config": text_overrides}) | ||
else: | ||
hf_overrides.update(text_overrides) | ||
else: | ||
model_info.check_available_online(on_fail="skip") | ||
|
||
pp_size = 1 | ||
if num_gpus_available < tp_size * pp_size: | ||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") | ||
if VLLM_MULTI_NODE and distributed_backend == "mp": | ||
pytest.skip("Skipping multi-node pipeline parallel test for " | ||
"multiprocessing distributed backend") | ||
if multi_node_only and not VLLM_MULTI_NODE: | ||
pytest.skip("Not in multi-node setting") | ||
|
||
common_args = [ | ||
# use half precision for speed and memory savings in CI environment | ||
"--dtype", | ||
"float16", | ||
"--max-model-len", | ||
"2048", | ||
"--max-num-seqs", | ||
"8", | ||
] | ||
if chunked_prefill: | ||
common_args.append("--enable-chunked-prefill") | ||
if eager_mode: | ||
common_args.append("--enforce-eager") | ||
if task != "auto": | ||
common_args.extend(["--task", task]) | ||
if trust_remote_code: | ||
common_args.append("--trust-remote-code") | ||
if tokenizer_mode: | ||
common_args.extend(["--tokenizer-mode", tokenizer_mode]) | ||
if load_format: | ||
common_args.extend(["--load-format", load_format]) | ||
if hf_overrides: | ||
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) | ||
|
||
compilation_config = { | ||
'level': 3, | ||
'custom_ops': ["+rms_norm"], | ||
'compile_sizes': [4, 8], | ||
'splitting_ops': [], | ||
'pass_config': { | ||
'enable_sequence_parallism': sp_enabled, | ||
'enable_noop': True, | ||
'enable_fusion': True, | ||
}, | ||
} | ||
|
||
tp_sp_env = tp_env = { | ||
"VLLM_USE_V1": vllm_major_version, | ||
} | ||
|
||
tp_sp_args = [ | ||
*common_args, | ||
"--tensor-parallel-size", | ||
str(tp_size), | ||
"--distributed-executor-backend", | ||
distributed_backend, | ||
"--compilation_config", | ||
str(compilation_config), | ||
] | ||
|
||
tp_env = { | ||
"VLLM_USE_V1": vllm_major_version, | ||
} | ||
tp_args = [ | ||
*common_args, | ||
"--tensor-parallel-size", | ||
str(tp_size), | ||
"--distributed-executor-backend", | ||
"mp", | ||
] | ||
|
||
try: | ||
compare_two_settings(model_id, | ||
tp_sp_args, | ||
tp_args, | ||
tp_sp_env, | ||
tp_env, | ||
method=method) | ||
except Exception: | ||
testing_ray_compiled_graph = tp_sp_env is not None | ||
if testing_ray_compiled_graph and vllm_major_version == "0": | ||
# Ray Compiled Graph tests are flaky for V0, | ||
# so we don't want to fail the test | ||
logger.exception("Ray Compiled Graph tests failed") | ||
else: | ||
raise | ||
|
||
|
||
SP_TEXT_GENERATION_MODELS = { | ||
# [Decoder-only] | ||
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(), | ||
} | ||
|
||
SP_TEST_MODELS = [ | ||
# TODO support other models | ||
# [LANGUAGE GENERATION] | ||
"meta-llama/Llama-3.2-1B-Instruct", | ||
] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", | ||
"task", "test_options"), | ||
[ | ||
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() | ||
for params in settings.iter_params(model_id) | ||
if model_id in SP_TEST_MODELS | ||
], | ||
) | ||
@create_new_process_for_each_test() | ||
def test_tp_sp_generation( | ||
model_id: str, | ||
parallel_setup: ParallelSetup, | ||
distributed_backend: str, | ||
vllm_major_version: str, | ||
task: TaskOption, | ||
test_options: SPTestOptions, | ||
num_gpus_available, | ||
): | ||
_compare_sp(model_id, | ||
parallel_setup, | ||
distributed_backend, | ||
vllm_major_version, | ||
task, | ||
test_options, | ||
num_gpus_available, | ||
method="generate", | ||
is_multimodal=False) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.