Skip to content

Commit 2c3b71d

Browse files
authored
Improve EPLB logical to physical dispatch map (sgl-project#6727)
1 parent 51cdd81 commit 2c3b71d

File tree

1 file changed

+66
-32
lines changed

1 file changed

+66
-32
lines changed

python/sglang/srt/managers/expert_location.py

Lines changed: 66 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# ==============================================================================
1414
import json
1515
import logging
16+
import random
1617
from dataclasses import dataclass
1718
from pathlib import Path
1819
from typing import List, Optional
@@ -205,10 +206,10 @@ def _init_raw(
205206
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
206207
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
207208
logical_to_all_physical_map=logical_to_all_physical_map,
208-
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
209209
num_gpus=ep_size,
210210
num_physical_experts=num_physical_experts,
211-
ep_rank=torch.distributed.get_rank(),
211+
# TODO improve when we have real EP rank
212+
ep_rank=torch.distributed.get_rank() % ep_size,
212213
),
213214
)
214215

@@ -296,49 +297,82 @@ def _pad_nested_array(arr, pad_value):
296297
return padded
297298

298299

299-
# TODO use more sophisticated approaches
300+
# TODO optimize performance (rewrite and/or run in separate process with overlap)
300301
def compute_logical_to_rank_dispatch_physical_map(
301302
logical_to_all_physical_map: torch.Tensor,
302-
logical_to_all_physical_map_num_valid: torch.Tensor,
303303
num_gpus: int,
304304
num_physical_experts: int,
305305
ep_rank: int,
306-
base_seed: int = 42,
306+
seed: int = 42,
307307
):
308-
device = logical_to_all_physical_map.device
308+
r = random.Random(seed)
309309

310310
num_local_physical_experts = num_physical_experts // num_gpus
311311
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
312+
dtype = logical_to_all_physical_map.dtype
312313

313-
g = torch.Generator(device=device)
314-
g.manual_seed(base_seed + ep_rank)
315-
316-
output_shape = (num_layers, num_logical_experts)
317-
chosen_index = (
318-
torch.randint(
319-
0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
320-
)
321-
% logical_to_all_physical_map_num_valid
314+
logical_to_rank_dispatch_physical_map = torch.full(
315+
size=(num_gpus, num_layers, num_logical_experts),
316+
fill_value=-1,
317+
dtype=dtype,
322318
)
323-
logical_to_rank_dispatch_physical_map = torch.gather(
324-
logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
325-
).squeeze(-1)
326-
assert logical_to_rank_dispatch_physical_map.shape == output_shape
327-
328-
for index in range(logical_to_all_physical_map_num_valid.max().item()):
329-
partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
330-
is_valid = partial_logical_to_all_physical_map != -1
331-
is_same_gpu = (
332-
partial_logical_to_all_physical_map // num_local_physical_experts
333-
) == ep_rank
334-
logical_to_rank_dispatch_physical_map = torch.where(
335-
is_valid & is_same_gpu,
336-
partial_logical_to_all_physical_map,
337-
logical_to_rank_dispatch_physical_map,
338-
)
319+
320+
for layer_id in range(num_layers):
321+
for logical_expert_id in range(num_logical_experts):
322+
candidate_physical_expert_ids = _logical_to_all_physical_raw(
323+
logical_to_all_physical_map, layer_id, logical_expert_id
324+
)
325+
output_partial = logical_to_rank_dispatch_physical_map[
326+
:, layer_id, logical_expert_id
327+
]
328+
329+
for gpu_id in range(num_gpus):
330+
same_gpu_physical_expert_ids = [
331+
physical_expert_id
332+
for physical_expert_id in candidate_physical_expert_ids
333+
if _compute_gpu_id_of_physical_expert(
334+
physical_expert_id, num_local_physical_experts
335+
)
336+
== gpu_id
337+
]
338+
if len(same_gpu_physical_expert_ids) > 0:
339+
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
340+
341+
num_remain = torch.sum(output_partial == -1).item()
342+
output_partial[output_partial == -1] = torch.tensor(
343+
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
344+
dtype=dtype,
345+
)
339346

340347
assert torch.all(logical_to_rank_dispatch_physical_map != -1)
341-
return logical_to_rank_dispatch_physical_map
348+
349+
device = logical_to_all_physical_map.device
350+
return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
351+
352+
353+
def _logical_to_all_physical_raw(
354+
logical_to_all_physical_map, layer_id: int, logical_expert_id: int
355+
) -> List[int]:
356+
return [
357+
physical_expert_id
358+
for physical_expert_id in logical_to_all_physical_map[
359+
layer_id, logical_expert_id
360+
].tolist()
361+
if physical_expert_id != -1
362+
]
363+
364+
365+
def _compute_gpu_id_of_physical_expert(
366+
physical_expert_id: int, num_local_physical_experts: int
367+
) -> int:
368+
return physical_expert_id // num_local_physical_experts
369+
370+
371+
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
372+
quotient, remainder = divmod(k, len(arr))
373+
ans = arr * quotient + r.sample(arr, k=remainder)
374+
r.shuffle(ans)
375+
return ans
342376

343377

344378
@dataclass

0 commit comments

Comments
 (0)