[V1][Frontend] Improve Shutdown And Logs #11737

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle an Error in model forward and shutdown."""

import asyncio

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError


def evil_forward(self, *args, **kwargs):
    """Evil forward method that raises an exception after 10 calls."""
    NUMBER_OF_GOOD_PASSES = 10

    if not hasattr(self, "num_calls"):
        self.num_calls = 0

    if (self.num_calls == NUMBER_OF_GOOD_PASSES
            and get_tensor_model_parallel_rank() == 0):
        raise Exception("Simulated illegal memory access on Rank 0!")
    self.num_calls += 1

    return self.model(*args, **kwargs, intermediate_tensors=None)


@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size):

    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Monkeypatch an error in the model.
        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

Review comment: Since you created a monkeypatch.context() as m, a suggested change here is to use m.setattr rather than monkeypatch.setattr. From reading the docs, it sounds like both will result in the same behavior, though I would find m.setattr clearer.
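
To illustrate the reviewer's point, here is a minimal, self-contained sketch (the Dummy class and fake_forward are illustrative and not part of the PR): patches made through the context object m are undone as soon as the with block exits, while patches made through the outer monkeypatch fixture last until test teardown, so inside the block both behave the same.

class Dummy:
    """Hypothetical stand-in for a model class."""

    def forward(self):
        return "real"


def fake_forward(self):
    return "patched"


def test_context_scoped_patch(monkeypatch):
    with monkeypatch.context() as m:
        # Patch via the context object: reverted when the block exits.
        m.setattr(Dummy, "forward", fake_forward)
        assert Dummy().forward() == "patched"
    # The context's patches are already undone here.
    assert Dummy().forward() == "real"
    # A patch via the outer fixture persists until test teardown instead.
    monkeypatch.setattr(Dummy, "forward", fake_forward)
    assert Dummy().forward() == "patched"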

        engine_args = AsyncEngineArgs(
            model="meta-llama/Llama-3.2-1B",
            enforce_eager=True,
            tensor_parallel_size=tensor_parallel_size)
        async_llm = AsyncLLM.from_engine_args(engine_args)

        async def generate(request_id: str):
            generator = async_llm.generate("Hello my name is",
                                           request_id=request_id,
                                           sampling_params=SamplingParams())
            try:
                async for _ in generator:
                    pass
            except Exception as e:
                return e

        NUM_REQS = 3
        tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
        outputs = await asyncio.gather(*tasks)

        # Every request should get an EngineDeadError.
        for output in outputs:
            assert isinstance(output, EngineDeadError)

        # AsyncLLM should be errored.
        assert async_llm.errored

        # We should not be able to make another request.
        with pytest.raises(EngineDeadError):
            async for _ in async_llm.generate(
                    "Hello my name is",
                    request_id="abc",
                    sampling_params=SamplingParams()):
                raise Exception("We should not get here.")

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

        # NOTE: shutdown is normally handled by the API Server when an
        # exception occurs, so it is expected that we call it explicitly here.
        async_llm.shutdown()

Review comment (on threshold_bytes=2 * 2**30): This number kind of looks like magic. It would be great to get it put in a constant with a comment explaining it somewhere.

Review comment (reply): We do have a […]
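
One way to address this (a sketch only; the constant name and the comment wording are assumptions, not taken from the PR) is to name the threshold once and reuse it in each wait_for_gpu_memory_to_clear call:

# Allow up to 2 GiB of residual GPU memory after shutdown; anything above
# this suggests the engine processes were not cleaned up.
MEMORY_CLEAR_THRESHOLD_BYTES = 2 * 2**30  # 2 GiB

wait_for_gpu_memory_to_clear(
    devices=list(range(tensor_parallel_size)),
    threshold_bytes=MEMORY_CLEAR_THRESHOLD_BYTES,
    timeout_s=60,
)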
@pytest.mark.parametrize("enable_multiprocessing", [True, False]) | ||||||||||
@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) | ||||||||||
def test_llm_model_error(monkeypatch, tensor_parallel_size, | ||||||||||
enable_multiprocessing): | ||||||||||
|
||||||||||
if cuda_device_count_stateless() < tensor_parallel_size: | ||||||||||
pytest.skip(reason="Not enough CUDA devices") | ||||||||||
|
||||||||||
with monkeypatch.context() as m: | ||||||||||
m.setenv("VLLM_USE_V1", "1") | ||||||||||
|
||||||||||
MP_VALUE = "1" if enable_multiprocessing else "0" | ||||||||||
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) | ||||||||||
|
||||||||||
# Monkeypatch an error in the model. | ||||||||||
m.setattr(LlamaForCausalLM, "forward", evil_forward) | ||||||||||
|
||||||||||
llm = LLM(model="meta-llama/Llama-3.2-1B", | ||||||||||
enforce_eager=True, | ||||||||||
tensor_parallel_size=tensor_parallel_size) | ||||||||||
|
||||||||||
with pytest.raises(EngineDeadError): | ||||||||||
llm.generate("Hello my name is Robert and I") | ||||||||||
|
||||||||||
# Confirm all the processes are cleaned up. | ||||||||||
wait_for_gpu_memory_to_clear( | ||||||||||
devices=list(range(tensor_parallel_size)), | ||||||||||
threshold_bytes=2 * 2**30, | ||||||||||
timeout_s=60, | ||||||||||
) |

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
"""Test error handling in Processor. Should not impact other reqs."""

import asyncio

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineGenerateError


@pytest.mark.asyncio
async def test_async_llm_processor_error(monkeypatch):

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
                                      enforce_eager=True)
        async_llm = AsyncLLM.from_engine_args(engine_args)

        async def generate(request_id: str):
            # [] is not allowed and will raise a ValueError in Processor.
            generator = async_llm.generate(TokensPrompt([]),
                                           request_id=request_id,
                                           sampling_params=SamplingParams())
            try:
                async for _ in generator:
                    pass
            except Exception as e:
                return e

        NUM_REQS = 3
        tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
        outputs = await asyncio.gather(*tasks)

        # Every request should get an EngineGenerateError.
        for output in outputs:
            with pytest.raises(EngineGenerateError):
                raise output

        # AsyncLLM should not be errored.
        assert not async_llm.errored

        # This should be no problem.
        EXPECTED_TOKENS = 5
        outputs = []
        async for out in async_llm.generate(
                "Hello my name is",
                request_id="abc",
                sampling_params=SamplingParams(
                    max_tokens=EXPECTED_TOKENS,
                    output_kind=RequestOutputKind.DELTA)):
            outputs.append(out)

        generated_tokens = []
        for out in outputs:
            generated_tokens.extend(out.outputs[0].token_ids)
        assert len(generated_tokens) == EXPECTED_TOKENS

        async_llm.shutdown()

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM


def evil_forward(self, *args, **kwargs):
    """Evil forward method that raises an exception."""

    if get_tensor_model_parallel_rank() == 0:
        raise Exception("Simulated Error in startup!")

    return self.model(*args, **kwargs, intermediate_tensors=None)


MODELS = [
    "meta-llama/Llama-3.2-1B",  # Raises on first fwd pass.
    "mistralai/Mixtral-8x22B-Instruct-v0.1"  # Causes OOM.
]

Review comment: Is this going to download the model and get a real OOM? From a quick look, it doesn't look like this is used elsewhere, so that'd be a net-new model to download during tests. If so, that doesn't seem worth the cost, especially given how unreliable HF has been in CI lately. Maybe I'm misunderstanding, though!

Review comment (author): Good idea. I do think that it is important to flex both of the cases here, since there is a subtle difference: […] I will instead do a monkeypatch to raise an error on load_weights for case […]

Review comment: monkeypatch sounds good if the error encountered is clear enough!
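
A rough sketch of the alternative the author describes, patching weight loading instead of downloading a large model; the function name evil_load_weights and the error message are illustrative assumptions, not code from this PR:

def evil_load_weights(self, *args, **kwargs):
    """Simulate a failure while loading weights, standing in for a real OOM."""
    raise Exception("Simulated failure during load_weights!")

# Applied the same way as the forward-pass patch above, for example:
#     monkeypatch.setattr(LlamaForCausalLM, "load_weights", evil_load_weights)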

@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
def test_async_llm_startup_error(monkeypatch, model, tensor_parallel_size):

    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Monkeypatch an error in the model.
        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

        engine_args = AsyncEngineArgs(
            model=model,
            enforce_eager=True,
            tensor_parallel_size=tensor_parallel_size)

        # Confirm we get an exception.
        with pytest.raises(Exception, match="initialization failed"):
            _ = AsyncLLM.from_engine_args(engine_args)

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True, False])
def test_llm_startup_error(monkeypatch, model, tensor_parallel_size,
                           enable_multiprocessing):

    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model.
        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

        with pytest.raises(Exception, match="initialization failed"):
            _ = LLM(model="meta-llama/Llama-3.2-1B",
                    enforce_eager=True,
                    tensor_parallel_size=tensor_parallel_size)

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

Review comment: Just sort of a side note, but it seems like updating these commands would be really easy to miss when adding new tests in a new directory.