|
35 | 35 | NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
|
36 | 36 |
|
37 | 37 |
|
38 |
| -# From Mooncake backend. |
39 | 38 | def group_concurrent_contiguous(
|
40 | 39 | src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
41 | 40 | ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
42 |
| - src_groups = [] |
43 |
| - dst_groups = [] |
44 |
| - current_src = [src_indices[0]] |
45 |
| - current_dst = [dst_indices[0]] |
46 |
| - |
47 |
| - for i in range(1, len(src_indices)): |
48 |
| - src_contiguous = src_indices[i] == src_indices[i - 1] + 1 |
49 |
| - dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1 |
50 |
| - if src_contiguous and dst_contiguous: |
51 |
| - current_src.append(src_indices[i]) |
52 |
| - current_dst.append(dst_indices[i]) |
53 |
| - else: |
54 |
| - src_groups.append(current_src) |
55 |
| - dst_groups.append(current_dst) |
56 |
| - current_src = [src_indices[i]] |
57 |
| - current_dst = [dst_indices[i]] |
| 41 | + """Vectorised NumPy implementation.""" |
| 42 | + if src_indices.size == 0: |
| 43 | + return [], [] |
| 44 | + |
| 45 | + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 |
| 46 | + src_groups = np.split(src_indices, brk) |
| 47 | + dst_groups = np.split(dst_indices, brk) |
58 | 48 |
|
59 |
| - src_groups.append(current_src) |
60 |
| - dst_groups.append(current_dst) |
| 49 | + src_groups = [g.tolist() for g in src_groups] |
| 50 | + dst_groups = [g.tolist() for g in dst_groups] |
61 | 51 |
|
62 | 52 | return src_groups, dst_groups
|
63 | 53 |
|
|
0 commit comments