-
Notifications
You must be signed in to change notification settings - Fork 501
[Fix] Fix vLLM NIXL-based P/D samples #1425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| import os | ||
| import random | ||
| import time | ||
| import uuid | ||
| from contextlib import asynccontextmanager | ||
|
|
||
| import httpx | ||
|
|
@@ -99,20 +100,46 @@ def parse_args(): | |
| return args | ||
|
|
||
|
|
||
| async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, req_data: dict): | ||
| req_data = req_data.copy() | ||
| req_data["max_tokens"] = 1 | ||
| if "max_completion_tokens" in req_data: | ||
| req_data["max_completion_tokens"] = 1 | ||
|
|
||
| headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} | ||
| response = await client.post(endpoint, json=req_data, headers=headers) | ||
| async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, req_id: str, req_data: dict): | ||
| req_data_copy = req_data.copy() | ||
| # nixl-specific kv_transfer_params for prefillers | ||
| req_data_copy['kv_transfer_params'] = { | ||
| "do_remote_decode": True, | ||
| "do_remote_prefill": False, | ||
| "remote_engine_id": None, | ||
| "remote_block_ids": None, | ||
| "remote_host": None, | ||
| "remote_port": None | ||
| } | ||
| # disable streaming for prefillers | ||
| req_data_copy["stream"] = False | ||
| if "stream_options" in req_data_copy: | ||
| del req_data_copy["stream_options"] | ||
| req_data_copy["max_tokens"] = 1 | ||
| if "max_completion_tokens" in req_data_copy: | ||
| req_data_copy["max_completion_tokens"] = 1 | ||
|
|
||
| headers = { | ||
| "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", | ||
| "X-Request-Id": req_id | ||
| } | ||
| response = await client.post(endpoint, json=req_data_copy, headers=headers) | ||
| response.raise_for_status() | ||
| # extract nixl-specific kv_transfer_params returned from prefillers and | ||
| # attach to the req_data for decode clients | ||
| response_json = response.json() | ||
| kv_transfer_params = response_json.get('kv_transfer_params', {}) | ||
| if kv_transfer_params: | ||
| req_data["kv_transfer_params"] = kv_transfer_params | ||
| req_data["kv_transfer_params"]["remote_host"] = client.base_url.host | ||
| return response | ||
|
|
||
|
|
||
| async def stream_service_response(client: httpx.AsyncClient, endpoint: str, req_data: dict): | ||
| headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} | ||
| async def stream_service_response(client: httpx.AsyncClient, endpoint: str, req_id: str, req_data: dict): | ||
| headers = { | ||
| "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", | ||
| "X-Request-Id": req_id | ||
| } | ||
| async with client.stream("POST", endpoint, json=req_data, headers=headers) as response: | ||
| response.raise_for_status() | ||
| async for chunk in response.aiter_bytes(): | ||
|
|
@@ -141,15 +168,16 @@ async def handle_completions(request: Request): | |
| st = time.time() | ||
|
|
||
| try: | ||
| req_id = str(uuid.uuid4()) | ||
| req_data = await request.json() | ||
| prefill_client, decode_client = select_random_clients() | ||
|
|
||
| await send_request_to_service(prefill_client, "/completions", req_data) | ||
| await send_request_to_service(prefill_client, "/completions", req_id, req_data) | ||
|
Contributor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To accompany the suggested change in `send_request_to_service`, update the call site to use the returned value:
kv_transfer_params = await send_request_to_service(prefill_client, "/completions", req_id, req_data)
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params |
||
| et = time.time() | ||
| stats_calculator.add(et - st) | ||
|
|
||
| async def generate_stream(): | ||
| async for chunk in stream_service_response(decode_client, "/completions", req_data): | ||
| async for chunk in stream_service_response(decode_client, "/completions", req_id, req_data): | ||
| yield chunk | ||
|
|
||
| return StreamingResponse(generate_stream(), media_type="text/event-stream") | ||
|
|
@@ -169,15 +197,16 @@ async def handle_chat_completions(request: Request): | |
| st = time.time() | ||
|
|
||
| try: | ||
| req_id = str(uuid.uuid4()) | ||
| req_data = await request.json() | ||
| prefill_client, decode_client = select_random_clients() | ||
|
|
||
| await send_request_to_service(prefill_client, "/chat/completions", req_data) | ||
| await send_request_to_service(prefill_client, "/chat/completions", req_id, req_data) | ||
|
Contributor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To accompany the suggested change in `send_request_to_service`, update the call site to use the returned value:
kv_transfer_params = await send_request_to_service(prefill_client, "/chat/completions", req_id, req_data)
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params |
||
| et = time.time() | ||
| stats_calculator.add(et - st) | ||
|
|
||
| async def generate_stream(): | ||
| async for chunk in stream_service_response(decode_client, "/chat/completions", req_data): | ||
| async for chunk in stream_service_response(decode_client, "/chat/completions", req_id, req_data): | ||
| yield chunk | ||
|
|
||
| return StreamingResponse(generate_stream(), media_type="text/event-stream") | ||
|
|
@@ -195,4 +224,4 @@ async def generate_stream(): | |
| global_args = parse_args() | ||
|
|
||
| import uvicorn | ||
| uvicorn.run(app, host=global_args.host, port=global_args.port) | ||
| uvicorn.run(app, host=global_args.host, port=global_args.port) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function modifies the `req_data` dictionary in-place, which is a side effect that can make the code harder to reason about. A cleaner approach is to return the `kv_transfer_params` and let the caller update `req_data`. This makes the data flow explicit. Since the `response` object is not used by the callers, the function's return value can be changed to facilitate this refactoring.