fix: avoid KeyError when cancelling requests that have not been processed #233
Conversation
```diff
@@ -603,15 +603,15 @@ def _update_states(self, scheduler_output):

         # Continuous batching stuff
         for req_id in scheduler_output.finished_req_ids:
-            if req_id in self.req_ids2blocks:
+            # requests may be cancelled from the client side while in the queue
+            if req_id in self.requests:
```
Does this not still need the `if req_id in self.req_ids2blocks` check to avoid a potential missing-key error below (in `for freed_block in self.req_ids2blocks[req_id]`)?
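For context, the cleanup block in question looks roughly like this after the change (a paraphrase assembled from the diff and this discussion, not an exact copy of the PR):

```python
for req_id in scheduler_output.finished_req_ids:
    # requests may be cancelled from the client side while in the queue
    if req_id in self.requests:
        # Without an additional `req_id in self.req_ids2blocks` guard, this
        # lookup could still raise KeyError for a request that was cancelled
        # before any blocks were ever assigned to it.
        for freed_block in self.req_ids2blocks[req_id]:
            self.free_blocks.append(freed_block)
        del self.req_ids2blocks[req_id]
        del self.req_ids2left_pads[req_id]
        del self.requests[req_id]
```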
Yeah, that's a good question.

If we look at this code in isolation, each `del` on a map should be guarded by a check that the key is present before the removal (e.g. `req_ids2left_pads` should be checked too). But the request is added to all of the maps during the course of `_update_states`:

vllm-spyre/vllm_spyre/v1/worker/spyre_model_runner.py, lines 650 to 688 at 97d03d6:
```python
self.req_ids2left_pads[
    request_data.req_id] = self.tkv - len(prompt_tokens)
input_token_list.append(
    torch.tensor(prompt_tokens,
                 dtype=torch.long,
                 device=torch.device("cpu")))

# filling block table and slot mapping
block_table_i = []
slot_mapping_i = []
for pos_i in range(block_padding):
    if pos_i % self.BLOCK_SIZE == 0:
        block_number = self.free_blocks.popleft()
        block_table_i.append(block_number)
    block_offset = pos_i % self.BLOCK_SIZE
    slot = block_number * self.BLOCK_SIZE + block_offset
    slot_mapping_i.append(slot)
self.req_ids2blocks[request_data.req_id] = deque(block_table_i)
slot_mapping.append(slot_mapping_i)

# Add new requests to the cached states.
req_id = request_data.req_id
sampling_params = request_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
    generator = torch.Generator(device=self.device)
    generator.manual_seed(sampling_params.seed)
else:
    generator = None

req_state = CachedRequestState(
    req_id=req_id,
    prompt_token_ids=request_data.prompt_token_ids,
    sampling_params=sampling_params,
    generator=generator,
    output_token_ids=[],
)
self.requests[req_id] = req_state
self.input_batch.add_request(req_state)
self.prefill_batch.add_request(req_state)
```
So the idea is that if it is in `self.requests`, it is also in the other maps. But it's probably better to be safe and not make that assumption.
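For illustration only (not the code in this PR), `dict.pop` with a default makes each removal tolerant of a missing key, so no presence checks are needed:

```python
from collections import deque

# Toy demonstration of the defensive-removal pattern being discussed:
# pop() with a default never raises, even if the request was cancelled
# before it was ever added to some of the maps.
requests = {"r1": object()}   # request known to the runner...
req_ids2blocks = {}           # ...but never scheduled, so no blocks assigned
req_ids2left_pads = {}
free_blocks = deque()

for req_id in ["r1"]:         # pretend "r1" was cancelled from the client side
    for freed_block in req_ids2blocks.pop(req_id, ()):
        free_blocks.append(freed_block)
    req_ids2left_pads.pop(req_id, None)
    requests.pop(req_id, None)  # no KeyError anywhere
```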
Yeah, I'm also trying to understand the problem here. Could it be that there's a race condition between two threads, where one thread actually deletes the request from `self.requests` while the other is trying to do the same but fails? In that case Christian's comment above does make sense. Or maybe we can also think about using thread-safe mechanisms here?
There shouldn't be any threading in vllm though, as far as I know. The only concurrency should be with async or multiprocessing 😕
Elegant!
Since we've run into problems with request cancellation twice, I think we should have a test for it here. We can follow the pattern used in a couple of places in vLLM to exercise similar behavior with aborting requests in the engine: https://github.com/vllm-project/vllm/blob/e6aab5de2999187c6cf0206f2d63ab6d7a0b6964/tests/v1/engine/test_async_llm.py#L147

It shouldn't be too hard to whip up a similar async test with an `AsyncLLM` so that this stays covered as we keep making changes to the model runner.
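For reference, a rough sketch of what such a test could look like, modelled loosely on the linked vLLM test. The engine arguments, model name, and helper names below are placeholders and assumptions, not the test that was eventually added to this PR:

```python
# Sketch only: cancel a request mid-generation, then check that the engine
# still serves a fresh request (i.e. cleanup of the cancelled one did not
# raise a KeyError in the model runner).
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def consume(engine: AsyncLLM, request_id: str) -> int:
    """Drive one request to completion and count the outputs received."""
    count = 0
    async for _ in engine.generate(
            request_id=request_id,
            prompt="Hello, my name is",
            sampling_params=SamplingParams(max_tokens=64)):
        count += 1
        await asyncio.sleep(0.01)  # yield so a cancellation can land mid-stream
    return count


async def main() -> None:
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # placeholder model
    try:
        task = asyncio.create_task(consume(engine, "cancelled-req"))
        await asyncio.sleep(0.5)
        task.cancel()  # aborts the in-flight request via the async generator
        # A follow-up request should still complete normally.
        assert await consume(engine, "follow-up-req") > 0
    finally:
        engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```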
I think we will have to wait for #162 to be merged to get the tests to pass (they pass in my dev env if I include that PR).

Is it the
LGTM, once tests pass this is ready to be merged.
tests/e2e/test_spyre_async_llm.py (outdated):

```python
if cancel_after is not None and count >= cancel_after:
    return count, request_id

await asyncio.sleep(0.0)
```
Is `await asyncio.sleep(0.0)` enough, or do we need `await asyncio.sleep(x)` with `x > 0`? Background: arguments of `0` and `x > 0` might not behave the same (see here).
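As a toy illustration of the difference (not part of this PR): `asyncio.sleep(0)` is special-cased to yield control to the event loop exactly once without arming a timer, while a positive delay additionally suspends the coroutine for at least that long before it resumes.

```python
import asyncio
import time


async def spin(name: str, delay: float) -> None:
    for i in range(3):
        # sleep(0) only yields once to the event loop; a positive delay also
        # waits at least `delay` seconds, giving other work real time to run.
        await asyncio.sleep(delay)
        print(f"{name} step {i} at t={time.perf_counter() - start:.3f}s")


async def main() -> None:
    await asyncio.gather(spin("sleep(0)   ", 0.0), spin("sleep(0.01)", 0.01))


start = time.perf_counter()
asyncio.run(main())
```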
That is interesting that the behavior is different! Thanks for the link. The tests seem fine with `0.0`, but I'll add a small value just in case.
```python
del self.req_ids2left_pads[req_id]

del self.requests[req_id]
logger.debug("Finishing request id: %s", req_id)
```
logger.debug("Finishing request id: %s", req_id) | |
if req_id in self.req_ids2blocks: | |
logger.debug("Freeing request id: %s", req_id) |
This debug statement was specific to CB, to confirm that the blocks were actually freed. I would suggest keeping it that way. IMO it should not be a general statement that a request has finished; that probably belongs (or already exists) somewhere else in the code, not in the CB-specific part.
removing my block since there are tests now!
"""Test handling of cancelled requests""" | ||
|
||
if cb == 1 and backend != "eager": | ||
pytest.skip("CB requires eager") |
we can come back and fix this up to work on spyre
FIX #225