Skip to content

Commit 30ca18f

Browse files
yuan-luoluoyuan.luo
andauthored
Refactor group_concurrent_contiguous in NIXL (#6214)
Co-authored-by: luoyuan.luo <[email protected]>
1 parent 0388691 commit 30ca18f

File tree

1 file changed

+9
-19
lines changed
  • python/sglang/srt/disaggregation/nixl

1 file changed

+9
-19
lines changed

python/sglang/srt/disaggregation/nixl/conn.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,19 @@
3535
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
3636

3737

38-
# From Mooncake backend.
3938
def group_concurrent_contiguous(
4039
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
4140
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
42-
src_groups = []
43-
dst_groups = []
44-
current_src = [src_indices[0]]
45-
current_dst = [dst_indices[0]]
46-
47-
for i in range(1, len(src_indices)):
48-
src_contiguous = src_indices[i] == src_indices[i - 1] + 1
49-
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
50-
if src_contiguous and dst_contiguous:
51-
current_src.append(src_indices[i])
52-
current_dst.append(dst_indices[i])
53-
else:
54-
src_groups.append(current_src)
55-
dst_groups.append(current_dst)
56-
current_src = [src_indices[i]]
57-
current_dst = [dst_indices[i]]
41+
"""Vectorised NumPy implementation."""
42+
if src_indices.size == 0:
43+
return [], []
44+
45+
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
46+
src_groups = np.split(src_indices, brk)
47+
dst_groups = np.split(dst_indices, brk)
5848

59-
src_groups.append(current_src)
60-
dst_groups.append(current_dst)
49+
src_groups = [g.tolist() for g in src_groups]
50+
dst_groups = [g.tolist() for g in dst_groups]
6151

6252
return src_groups, dst_groups
6353

0 commit comments

Comments
 (0)