[Feature] kv_transfer/kv_connector: Add aibrix_pd_reuse_connector to support PD + reuse #1852
base: main
Conversation
Code Review
This pull request introduces a new SHFSConnector for a shared-file-system L2 cache and an AIBrixPDReuseConnector to support prefiller-decoder separation with KV cache reuse. The overall implementation is good, but there are several areas for improvement. I've identified a bug in directory creation within the SHFSConnector, significant performance issues in the mget/mput implementations, and several instances of redundant or unclear code. My review comments provide specific suggestions to address these points, aiming to improve correctness, performance, and maintainability.
```python
def open(self) -> Status:
    """Open a connection by ensuring the root directory exists."""
    try:
        ensure_dir_exist(str(self.root_path))
```
The ensure_dir_exist function creates the parent directory of the given path, not the path itself. Since self.root_path is the directory that needs to be created, you should use os.makedirs(self.root_path, exist_ok=True) instead. This is a bug as the method does not behave as its docstring suggests.
Suggested change:

```diff
-        ensure_dir_exist(str(self.root_path))
+        os.makedirs(self.root_path, exist_ok=True)
```
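A minimal sketch of why the suggested fix behaves as the docstring expects: `os.makedirs` creates the target directory itself (including any missing parents), and with `exist_ok=True` a repeated call is a harmless no-op. The temporary-directory path here is illustrative only.

```python
import os
import tempfile

# Hypothetical root path, standing in for self.root_path.
root = os.path.join(tempfile.mkdtemp(), "kv_cache_ol", "shfs")

os.makedirs(root, exist_ok=True)  # creates root and all missing parents
os.makedirs(root, exist_ok=True)  # second call is safe: no FileExistsError

print(os.path.isdir(root))
```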
```python
statuses = []

for i, (key, mr) in enumerate(zip(keys, mrs)):
    status = await self.get(key, mr)
    statuses.append(status)
    if not status.is_ok() and not status.is_not_found():
        logger.error(f"SHFS mget[{i}] failed: {status}")

return statuses
```
The current implementation of mget iterates and awaits get calls sequentially. This negates the benefit of an async mget method, which should perform operations in parallel. Using asyncio.gather will execute the file reads concurrently, significantly improving performance for multiple keys.
Suggested change:

```diff
-statuses = []
-for i, (key, mr) in enumerate(zip(keys, mrs)):
-    status = await self.get(key, mr)
-    statuses.append(status)
-    if not status.is_ok() and not status.is_not_found():
-        logger.error(f"SHFS mget[{i}] failed: {status}")
-return statuses
+import asyncio
+tasks = [self.get(key, mr) for key, mr in zip(keys, mrs)]
+statuses = await asyncio.gather(*tasks)
+for i, status in enumerate(statuses):
+    if not status.is_ok() and not status.is_not_found():
+        logger.error(f"SHFS mget[{i}] failed: {status}")
+return list(statuses)
```
```python
statuses = []

for i, (key, mr) in enumerate(zip(keys, mrs)):
    status = await self.put(key, mr)
    statuses.append(status)
    if not status.is_ok():
        logger.error(f"SHFS mput[{i}] failed: {status}")

return statuses
```
Similar to mget, the mput method is implemented as a sequential loop. This is inefficient. To leverage the asynchronicity and improve performance, you should use asyncio.gather to perform the file write operations concurrently.
Suggested change:

```diff
-statuses = []
-for i, (key, mr) in enumerate(zip(keys, mrs)):
-    status = await self.put(key, mr)
-    statuses.append(status)
-    if not status.is_ok():
-        logger.error(f"SHFS mput[{i}] failed: {status}")
-return statuses
+import asyncio
+tasks = [self.put(key, mr) for key, mr in zip(keys, mrs)]
+statuses = await asyncio.gather(*tasks)
+for i, status in enumerate(statuses):
+    if not status.is_ok():
+        logger.error(f"SHFS mput[{i}] failed: {status}")
+return list(statuses)
```
```python
AIBRIX_KV_CACHE_OL_SHFS_ROOT: str = os.path.expanduser(
    os.path.join(os.path.expanduser("~"), ".kv_cache_ol", "shfs")
)
```
The outer os.path.expanduser is redundant here. os.path.join(os.path.expanduser("~"), ...) already produces an absolute path, so wrapping it in another os.path.expanduser has no effect. Removing it will make the code cleaner.
Suggested change:

```diff
-AIBRIX_KV_CACHE_OL_SHFS_ROOT: str = os.path.expanduser(
-    os.path.join(os.path.expanduser("~"), ".kv_cache_ol", "shfs")
-)
+AIBRIX_KV_CACHE_OL_SHFS_ROOT: str = os.path.join(
+    os.path.expanduser("~"), ".kv_cache_ol", "shfs"
+)
```
```python
root = envs.AIBRIX_KV_CACHE_OL_SHFS_ROOT

# Create full path: root/conn_id
full_path = os.path.join(os.path.expanduser(root), conn_id)
```
```python
def get_batches(
    self,
    keys: Sequence[bytes],
    mrs: Sequence[MemoryRegion | Sequence[MemoryRegion]],
    batch_size: int,
) -> Sequence[
    Sequence[tuple[bytes, MemoryRegion | Sequence[MemoryRegion]]]
]:
    """Get batches for mput/mget operations."""
    batches = []
    current_batch = []

    for key, mr in zip(keys, mrs):
        current_batch.append((key, mr))
        if len(current_batch) >= batch_size:
            batches.append(current_batch)
            current_batch = []

    if current_batch:
        batches.append(current_batch)

    return batches
```
```diff
+import os
+from aibrix_kvcache import envs
```
```diff
+if is_prefiller_with_remote_decode:
+    # Prefiller: If entire KV cache exists in SHFS, skip loading (decoder will load it)
+    if exists_status.is_ok() and num_existing_tokens >= aligned_query_len:
+        return 0  # Skip loading, decoder will handle it
+elif is_decoder_with_remote_prefill:
+    # Decoder: Always try to load from SHFS if exists (this is the main path for PD separation)
+    if exists_status.is_ok() and num_existing_tokens >= aligned_query_len:
+        # Continue to acquire (will load from SHFS)
+        pass
+else:
+    # KV cache reuse only (no PD separation): Use threshold to avoid loading very small chunks
+    threshold = max(
+        OFFLOADING_CONNECTOR_SKIP_THRESHOLD * self.engine_block_ntokens,
+        self.cache_block_ntokens,
+    )
+    if aligned_query_len < threshold:
+        return 0
+    # For kvcache reuse, proceed to load if exists
+    if exists_status.is_ok() and num_existing_tokens >= aligned_query_len:
+        # KV cache reuse, entire cache exists
+        pass
```
The conditional logic here can be simplified. The if statements that only contain a pass statement are redundant because the code would fall through to the loading logic anyway. Removing them and restructuring the conditions will make the code more readable and easier to maintain.
```python
if is_prefiller_with_remote_decode:
    # Prefiller: If entire KV cache exists in SHFS, skip loading (decoder will load it)
    if exists_status.is_ok() and num_existing_tokens >= aligned_query_len:
        return 0  # Skip loading, decoder will handle it
elif not is_decoder_with_remote_prefill:
    # KV cache reuse only (no PD separation): Use threshold to avoid loading very small chunks
    threshold = max(
        OFFLOADING_CONNECTOR_SKIP_THRESHOLD * self.engine_block_ntokens,
        self.cache_block_ntokens,
    )
    if aligned_query_len < threshold:
        return 0
```
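The simplified branching can be exercised as a pure function to confirm it keeps the three behaviors: a prefiller with a full cache hit skips, a decoder always proceeds, and reuse-only mode applies the small-chunk threshold. All names and constants below are local stand-ins for the connector's state (the real threshold values differ); `cache_hit` abbreviates `exists_status.is_ok() and num_existing_tokens >= aligned_query_len`.

```python
SKIP_THRESHOLD = 0.5  # illustrative, not the real OFFLOADING_CONNECTOR_SKIP_THRESHOLD

def tokens_to_load(is_prefiller: bool, is_decoder: bool, cache_hit: bool,
                   aligned_query_len: int,
                   engine_block_ntokens: int = 16,
                   cache_block_ntokens: int = 16) -> int:
    if is_prefiller:
        if cache_hit:
            return 0  # skip loading; the decoder will load it
    elif not is_decoder:
        # Reuse-only mode: avoid loading very small chunks.
        threshold = max(SKIP_THRESHOLD * engine_block_ntokens, cache_block_ntokens)
        if aligned_query_len < threshold:
            return 0
    return aligned_query_len  # fall through to the loading path

print(tokens_to_load(True, False, True, 128))   # → 0 (prefiller, full hit)
print(tokens_to_load(False, True, True, 8))     # → 8 (decoder always loads)
print(tokens_to_load(False, False, True, 8))    # → 0 (reuse-only, below threshold)
```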
Force-pushed ccdeb16 to e78cfae.
@dczhu there are some format issues; please check the errors in the failed workflow action and fix them.
```diff
+self._scheduler_meta = AIBrixPDReuseConnectorMetadata({})
+
+# Track requests that need PD transfer
+self._reqs_need_send: dict[str, float] = {}  # req_id -> expiration_time
```
seems unused?
```diff
+assert config.kv_transfer_config.engine_id is not None
+self.engine_id = config.kv_transfer_config.engine_id
+
+self.side_channel_host = getattr(vllm.envs, 'VLLM_NIXL_SIDE_CHANNEL_HOST', '127.0.0.1')
```
In this kvcache-oriented architecture, since P and D will not talk directly, do we still need the side channel as NIXL does?
```diff
+if not params:
+    return
+
+# Handle PD separation: update metadata if needed
```
PD separation -> PD disaggregation
```diff
+from aibrix_kvcache import envs
+l2_backend = os.getenv("AIBRIX_KV_CACHE_OL_L2_CACHE_BACKEND", "").strip().upper()
+
+needs_async_load = l2_backend not in ["SHFS", ""]
```
Right now the decoder also relies on start_load_kv_before_update to load the kvcache, which is not async no matter which L2 backend we are using, so let's remove the L2 backend check here for now.
Force-pushed 253804d to 5f24073.
Thanks @DwyaneShi for the detailed review! I updated the connector file and also the test image (which is reflected in the test YAML files).
Signed-off-by: Dengcheng Zhu <[email protected]>
… + reuse Signed-off-by: Dengcheng Zhu <[email protected]>
Signed-off-by: Dengcheng Zhu <[email protected]>
Force-pushed 5f24073 to 95e4967.
Regarding the prefiller's partial block at the tail, I'll need to add support for st/ld, probably in a separate feature enhancement PR. @DwyaneShi
Pull Request Description
Related Issues
Resolves: #[Insert issue number(s)]
Important: Before submitting, please complete the description above and review the checklist below.
Contribution Guidelines (Expand for Details)
We appreciate your contribution to aibrix! To ensure a smooth review process and maintain high code quality, please adhere to the following guidelines:
Pull Request Title Format
Your PR title should start with one of these prefixes to indicate the nature of the change:
- `[Bug]`: Corrections to existing functionality
- `[CI]`: Changes to build process or CI pipeline
- `[Docs]`: Updates or additions to documentation
- `[API]`: Modifications to aibrix's API or interface
- `[CLI]`: Changes or additions to the Command Line Interface
- `[Misc]`: For changes not covered above (use sparingly)

Note: For changes spanning multiple categories, use multiple prefixes in order of importance.
Submission Checklist
By submitting this PR, you confirm that you've read these guidelines and your changes align with the project's contribution standards.