Skip to content

[Feature] support sequence parallelism using compilation pass #16155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ steps:
commands:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py

- label: PyTorch Fullgraph Smoke Test # 9min
source_file_dependencies:
Expand Down Expand Up @@ -583,6 +584,8 @@ steps:
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
# test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py
# this test fails consistently.
# TODO: investigate and fix
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
Expand Down
14 changes: 8 additions & 6 deletions tests/compile/test_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, VllmConfig

from .backend import TestBackend

Expand Down Expand Up @@ -49,13 +49,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
do_fusion: bool):
torch.set_default_device("cuda")

config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config= \
CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)

passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
func_pass = FixFunctionalizationPass(config)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)

Expand Down
9 changes: 5 additions & 4 deletions tests/compile/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,

vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \
CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)

backend = TestBackend(noop_pass, fusion_pass)
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
Expand Down
9 changes: 5 additions & 4 deletions tests/compile/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig
from vllm.config import VllmConfig


# dummy custom pass that doesn't inherit
Expand All @@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph):

# Should fail to add directly to the pass manager
def test_bad_callable():
config = CompilationConfig().pass_config
config = VllmConfig()

pass_manager = PostGradPassManager()
pass_manager.configure(config)
Expand All @@ -43,7 +43,7 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None:
],
)
def test_pass_manager_uuid(callable):
config = CompilationConfig().pass_config
config = VllmConfig()

pass_manager = PostGradPassManager()
pass_manager.configure(config)
Expand All @@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable):

# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.enable_fusion = not config2.enable_fusion
config2.compilation_config.pass_config.enable_fusion = not \
config2.compilation_config.pass_config.enable_fusion
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)
Expand Down
190 changes: 190 additions & 0 deletions tests/compile/test_sequence_parallelism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn,
find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..utils import multi_gpu_test
from .backend import TestBackend

OPS_IN_MODEL_BEFORE = [
torch.ops.vllm.all_reduce.default,
]

OPS_IN_MODEL_AFTER = [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default,
]

OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]


class TestModel(torch.nn.Module):

def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)))
self.norm = RMSNorm(hidden_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)

def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph

Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer

Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)

#matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)

# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)

# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)

return norm_output, residual_output


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
num_processes = 2

def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(num_processes, batch_size, seq_len,
hidden_size, dtype),
nprocs=nprocs)

run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)


def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
batch_size: int, seq_len: int,
hidden_size: int,
dtype: torch.dtype):
current_platform.seed_everything(0)

device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)

update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})

# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)

# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(
enable_sequence_parallelism=True, ), )
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype=dtype,
seed=42)

sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
backend_no_func = TestBackend(sequence_parallelism_pass)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(sequence_parallelism_pass, func_pass)

model = TestModel(hidden_size, hidden_size * 2)
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)

# Check substitution worked
pre_nodes = backend_no_func.graph_pre_pass.nodes
post_nodes = backend_no_func.graph_post_pass.nodes

# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
for op in OPS_IN_MODEL_BEFORE:
find_specified_fn(pre_nodes, op)
for op in OPS_IN_MODEL_AFTER:
assert find_specified_fn_maybe(pre_nodes, op) is None

# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
for op in OPS_IN_MODEL_AFTER:
find_specified_fn(post_nodes, op)
for op in OPS_IN_MODEL_BEFORE:
assert find_specified_fn_maybe(post_nodes, op) is None

# check if the functionalization pass is applied
for op in OPS_IN_MODEL:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
op) is None # noqa: E501

# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in OPS_IN_MODEL:
if is_func(node, op):
found[op] = True
assert all(found[op] for op in OPS_IN_MODEL)
31 changes: 30 additions & 1 deletion tests/distributed/test_comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)

from ..utils import init_test_distributed_environment, multi_process_parallel

Expand Down Expand Up @@ -47,6 +48,34 @@ def all_reduce_test_worker(
torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tp_size)
]

index = rank % tp_size
partition_size = num_elements // tp_size
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
t = all_tensors[index]
t = tensor_model_parallel_reduce_scatter(t, 0)
torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(
monkeypatch: pytest.MonkeyPatch,
Expand Down
Loading