
Commit 99d388a

robertgshaw2-redhat (rshaw@neuralmagic.com), DarkLight1337, russellb, and afeldman-nm authored and committed
[V1][Frontend] Improve Shutdown And Logs (vllm-project#11737)
Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: Andrew Feldman <[email protected]>
Co-authored-by: afeldman-nm <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent f546ba3 · commit 99d388a

File tree

16 files changed: +1047 / -363 lines

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions

@@ -552,6 +552,7 @@ steps:
   # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
+  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown

 - label: Plugin Tests (2 GPUs) # 40min
   working_dir: "/vllm-workspace/tests"

tests/v1/shutdown/test_delete.py

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
                                     SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM

MODELS = ["meta-llama/Llama-3.2-1B"]


@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("send_one_request", [False, True])
async def test_async_llm_delete(model: str, tensor_parallel_size: int,
                                send_one_request: bool) -> None:
    """Test that AsyncLLM frees GPU memory upon deletion.
    AsyncLLM always uses an MP client.

    Args:
        model: model under test
        tensor_parallel_size: degree of tensor parallelism
        send_one_request: send one request to engine before deleting
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    engine_args = AsyncEngineArgs(model=model,
                                  enforce_eager=True,
                                  tensor_parallel_size=tensor_parallel_size)

    # Instantiate AsyncLLM; make request to complete any deferred
    # initialization; then delete instance
    async_llm = AsyncLLM.from_engine_args(engine_args)
    if send_one_request:
        async for _ in async_llm.generate(
                "Hello my name is",
                request_id="abc",
                sampling_params=SamplingParams(
                    max_tokens=1, output_kind=RequestOutputKind.DELTA)):
            pass
    del async_llm

    # Confirm all the processes are cleaned up.
    wait_for_gpu_memory_to_clear(
        devices=list(range(tensor_parallel_size)),
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
    )


@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("send_one_request", [False, True])
def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int,
                    enable_multiprocessing: bool,
                    send_one_request: bool) -> None:
    """Test that LLM frees GPU memory upon deletion.
    TODO(andy) - LLM without multiprocessing.

    Args:
        model: model under test
        tensor_parallel_size: degree of tensor parallelism
        enable_multiprocessing: enable workers in separate process(es)
        send_one_request: send one request to engine before deleting
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Instantiate LLM; make request to complete any deferred
        # initialization; then delete instance
        llm = LLM(model=model,
                  enforce_eager=True,
                  tensor_parallel_size=tensor_parallel_size)
        if send_one_request:
            llm.generate("Hello my name is",
                         sampling_params=SamplingParams(max_tokens=1))
        del llm

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
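
Both delete tests rely on two constants imported from tests/v1/shutdown/utils.py, a helper module that is not shown in this excerpt. A minimal sketch of what it could contain, assuming values consistent with how the tests use them (the 2 GiB threshold mirrors the inline 2 * 2**30 value used in the model-error test further below; the exact committed numbers may differ):

# SPDX-License-Identifier: Apache-2.0
# Hypothetical sketch of tests/v1/shutdown/utils.py; values are assumptions.

# Residual GPU memory allowed per device after shutdown before
# wait_for_gpu_memory_to_clear reports a failure.
SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 2**30  # 2 GiB

# Upper bound applied via @pytest.mark.timeout so a hung shutdown
# fails the test instead of stalling CI.
SHUTDOWN_TEST_TIMEOUT_SEC = 120
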
Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle an Error in model forward and shutdown."""

import asyncio

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
                                     SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError

MODELS = ["meta-llama/Llama-3.2-1B"]


def evil_forward(self, *args, **kwargs):
    """Evil forward method that raises an exception after 10 calls."""
    NUMBER_OF_GOOD_PASSES = 10

    if not hasattr(self, "num_calls"):
        self.num_calls = 0

    if (self.num_calls == NUMBER_OF_GOOD_PASSES
            and get_tensor_model_parallel_rank() == 0):
        raise Exception("Simulated illegal memory access on Rank 0!")
    self.num_calls += 1

    return self.model(*args, **kwargs)


@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int,
                                     model: str) -> None:
    """Test that AsyncLLM propagates a forward pass error and frees memory.

    AsyncLLM always uses an MP client.
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    # Monkeypatch an error in the model.
    monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

    engine_args = AsyncEngineArgs(model=model,
                                  enforce_eager=True,
                                  tensor_parallel_size=tensor_parallel_size)
    async_llm = AsyncLLM.from_engine_args(engine_args)

    async def generate(request_id: str):
        generator = async_llm.generate("Hello my name is",
                                       request_id=request_id,
                                       sampling_params=SamplingParams())
        try:
            async for _ in generator:
                pass
        except Exception as e:
            return e

    NUM_REQS = 3
    tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
    outputs = await asyncio.gather(*tasks)

    # Every request should get an EngineDeadError.
    for output in outputs:
        assert isinstance(output, EngineDeadError)

    # AsyncLLM should be errored.
    assert async_llm.errored

    # We should not be able to make another request.
    with pytest.raises(EngineDeadError):
        async for _ in async_llm.generate("Hello my name is",
                                          request_id="abc",
                                          sampling_params=SamplingParams()):
            raise Exception("We should not get here.")

    # Confirm all the processes are cleaned up.
    wait_for_gpu_memory_to_clear(
        devices=list(range(tensor_parallel_size)),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    # NOTE: shutdown is handled by the API Server if an exception
    # occurs, so it is expected that we would need to call this.
    async_llm.shutdown()


@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
def test_llm_model_error(monkeypatch, tensor_parallel_size: int,
                         enable_multiprocessing: bool, model: str) -> None:
    """Test that LLM propagates a forward pass error and frees memory.
    TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
    and >1 rank
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:

        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model.
        m.setattr(LlamaForCausalLM, "forward", evil_forward)

        llm = LLM(model=model,
                  enforce_eager=True,
                  tensor_parallel_size=tensor_parallel_size)

        with pytest.raises(
                EngineDeadError if enable_multiprocessing else Exception):
            llm.generate("Hello my name is Robert and I")

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
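
The assertions above spell out the contract a caller gets from AsyncLLM once the engine core dies: in-flight and subsequent generate() calls raise EngineDeadError, errored flips to True, and shutdown() must still be called when no API server is managing the instance. A minimal caller-side sketch of that contract, using the same public calls as the test; the helper name and prompt are illustrative, not part of this commit:

import asyncio
from typing import Optional

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError


async def serve_one_prompt(async_llm: AsyncLLM, prompt: str) -> Optional[str]:
    """Hypothetical helper: return generated text, or None if the engine died."""
    try:
        final = None
        async for out in async_llm.generate(prompt,
                                            request_id="demo",
                                            sampling_params=SamplingParams(max_tokens=8)):
            final = out
        return final.outputs[0].text if final else None
    except EngineDeadError:
        # The engine core process is gone; per the test above, every further
        # request would also raise, so stop submitting and shut down.
        return None


async def main() -> None:
    async_llm = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="meta-llama/Llama-3.2-1B", enforce_eager=True))
    try:
        print(await serve_one_prompt(async_llm, "Hello my name is"))
    finally:
        # As the test notes, shutdown() is still needed when an exception
        # escapes and no API server is handling cleanup.
        async_llm.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
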
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Test error handling in Processor. Should not impact other reqs."""
3+
4+
import asyncio
5+
6+
import pytest
7+
8+
from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC
9+
from vllm import SamplingParams
10+
from vllm.engine.arg_utils import AsyncEngineArgs
11+
from vllm.inputs.data import TokensPrompt
12+
from vllm.sampling_params import RequestOutputKind
13+
from vllm.v1.engine.async_llm import AsyncLLM
14+
from vllm.v1.engine.exceptions import EngineGenerateError
15+
16+
MODELS = ["meta-llama/Llama-3.2-1B"]
17+
18+
19+
@pytest.mark.asyncio
20+
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
21+
@pytest.mark.parametrize("model", MODELS)
22+
async def test_async_llm_processor_error(model: str) -> None:
23+
"""Test that AsyncLLM propagates a processor error.
24+
Test empty tokens prompt (failure) and non-empty prompt (no failure.)
25+
AsyncLLM always uses an MP client.
26+
"""
27+
engine_args = AsyncEngineArgs(model=model, enforce_eager=True)
28+
async_llm = AsyncLLM.from_engine_args(engine_args)
29+
30+
async def generate(request_id: str):
31+
# [] is not allowed and will raise a ValueError in Processor.
32+
generator = async_llm.generate(TokensPrompt([]),
33+
request_id=request_id,
34+
sampling_params=SamplingParams())
35+
try:
36+
async for _ in generator:
37+
pass
38+
except Exception as e:
39+
return e
40+
41+
NUM_REQS = 3
42+
tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
43+
outputs = await asyncio.gather(*tasks)
44+
45+
# Every request should have get an EngineGenerateError.
46+
for output in outputs:
47+
with pytest.raises(EngineGenerateError):
48+
raise output
49+
50+
# AsyncLLM should be errored.
51+
assert not async_llm.errored
52+
53+
# This should be no problem.
54+
EXPECTED_TOKENS = 5
55+
outputs = []
56+
async for out in async_llm.generate(
57+
"Hello my name is",
58+
request_id="abc",
59+
sampling_params=SamplingParams(
60+
max_tokens=EXPECTED_TOKENS,
61+
output_kind=RequestOutputKind.DELTA)):
62+
outputs.append(out)
63+
64+
generated_tokens = []
65+
for out in outputs:
66+
generated_tokens.extend(out.outputs[0].token_ids)
67+
assert len(generated_tokens) == EXPECTED_TOKENS
68+
69+
async_llm.shutdown()
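
Taken together with the model-error test above, this test draws the line between the two failure modes: EngineGenerateError marks a single bad request while the engine stays healthy, whereas EngineDeadError means nothing further will succeed. A short sketch of how a caller might branch on that distinction, reusing the calls shown in these tests (the helper is illustrative, not part of this commit):

import asyncio
from typing import Optional

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError


async def try_generate(async_llm: AsyncLLM, prompt, request_id: str) -> Optional[str]:
    """Illustrative helper: separate per-request failures from a dead engine."""
    try:
        final = None
        async for out in async_llm.generate(prompt,
                                            request_id=request_id,
                                            sampling_params=SamplingParams(max_tokens=5)):
            final = out
        return final.outputs[0].text if final else None
    except EngineGenerateError:
        # Bad input (e.g. an empty TokensPrompt): only this request fails;
        # async_llm.errored stays False and other requests keep running.
        return None
    except EngineDeadError:
        # The engine core is gone; no further request will succeed.
        raise


async def main() -> None:
    async_llm = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="meta-llama/Llama-3.2-1B", enforce_eager=True))
    try:
        assert await try_generate(async_llm, TokensPrompt([]), "bad-req") is None
        print(await try_generate(async_llm, "Hello my name is", "good-req"))
    finally:
        async_llm.shutdown()


if __name__ == "__main__":
    asyncio.run(main())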
