@@ -13,6 +13,7 @@
 # ==============================================================================
 import json
 import logging
+import random
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional
@@ -205,10 +206,10 @@ def _init_raw(
         logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
         logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
             logical_to_all_physical_map=logical_to_all_physical_map,
-            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             num_gpus=ep_size,
             num_physical_experts=num_physical_experts,
-            ep_rank=torch.distributed.get_rank(),
+            # TODO improve when we have real EP rank
+            ep_rank=torch.distributed.get_rank() % ep_size,
         ),
     )
 
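A note on the `ep_rank=torch.distributed.get_rank() % ep_size` stopgap above: the modulo simply folds the global rank into the range `[0, ep_size)` until the real EP rank is plumbed through, as the TODO in the hunk says. A tiny illustrative sketch, with a made-up world size of 8 and `ep_size` of 4:

ep_size = 4
# Global ranks 0..7 fold into EP ranks 0,1,2,3,0,1,2,3; this is only the
# placeholder arithmetic, not a statement about how EP groups are laid out.
print([rank % ep_size for rank in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]
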
@@ -296,49 +297,82 @@ def _pad_nested_array(arr, pad_value):
     return padded
 
 
-# TODO use more sophisticated approaches
+# TODO optimize performance (rewrite and/or run in separate process with overlap)
 def compute_logical_to_rank_dispatch_physical_map(
     logical_to_all_physical_map: torch.Tensor,
-    logical_to_all_physical_map_num_valid: torch.Tensor,
     num_gpus: int,
     num_physical_experts: int,
     ep_rank: int,
-    base_seed: int = 42,
+    seed: int = 42,
 ):
-    device = logical_to_all_physical_map.device
+    r = random.Random(seed)
 
     num_local_physical_experts = num_physical_experts // num_gpus
     num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
+    dtype = logical_to_all_physical_map.dtype
 
-    g = torch.Generator(device=device)
-    g.manual_seed(base_seed + ep_rank)
-
-    output_shape = (num_layers, num_logical_experts)
-    chosen_index = (
-        torch.randint(
-            0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
-        )
-        % logical_to_all_physical_map_num_valid
+    logical_to_rank_dispatch_physical_map = torch.full(
+        size=(num_gpus, num_layers, num_logical_experts),
+        fill_value=-1,
+        dtype=dtype,
     )
-    logical_to_rank_dispatch_physical_map = torch.gather(
-        logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
-    ).squeeze(-1)
-    assert logical_to_rank_dispatch_physical_map.shape == output_shape
-
-    for index in range(logical_to_all_physical_map_num_valid.max().item()):
-        partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
-        is_valid = partial_logical_to_all_physical_map != -1
-        is_same_gpu = (
-            partial_logical_to_all_physical_map // num_local_physical_experts
-        ) == ep_rank
-        logical_to_rank_dispatch_physical_map = torch.where(
-            is_valid & is_same_gpu,
-            partial_logical_to_all_physical_map,
-            logical_to_rank_dispatch_physical_map,
-        )
+
+    for layer_id in range(num_layers):
+        for logical_expert_id in range(num_logical_experts):
+            candidate_physical_expert_ids = _logical_to_all_physical_raw(
+                logical_to_all_physical_map, layer_id, logical_expert_id
+            )
+            output_partial = logical_to_rank_dispatch_physical_map[
+                :, layer_id, logical_expert_id
+            ]
+
+            for gpu_id in range(num_gpus):
+                same_gpu_physical_expert_ids = [
+                    physical_expert_id
+                    for physical_expert_id in candidate_physical_expert_ids
+                    if _compute_gpu_id_of_physical_expert(
+                        physical_expert_id, num_local_physical_experts
+                    )
+                    == gpu_id
+                ]
+                if len(same_gpu_physical_expert_ids) > 0:
+                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
+
+            num_remain = torch.sum(output_partial == -1).item()
+            output_partial[output_partial == -1] = torch.tensor(
+                _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
+                dtype=dtype,
+            )
 
     assert torch.all(logical_to_rank_dispatch_physical_map != -1)
-    return logical_to_rank_dispatch_physical_map
+
+    device = logical_to_all_physical_map.device
+    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
+
+
+def _logical_to_all_physical_raw(
+    logical_to_all_physical_map, layer_id: int, logical_expert_id: int
+) -> List[int]:
+    return [
+        physical_expert_id
+        for physical_expert_id in logical_to_all_physical_map[
+            layer_id, logical_expert_id
+        ].tolist()
+        if physical_expert_id != -1
+    ]
+
+
+def _compute_gpu_id_of_physical_expert(
+    physical_expert_id: int, num_local_physical_experts: int
+) -> int:
+    return physical_expert_id // num_local_physical_experts
+
+
+def _fair_choices(arr: List, k: int, r: random.Random) -> List:
+    quotient, remainder = divmod(k, len(arr))
+    ans = arr * quotient + r.sample(arr, k=remainder)
+    r.shuffle(ans)
+    return ans
 
 
 @dataclass
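To make the new dispatch rule concrete, here is a minimal standalone sketch that mirrors the logic of `compute_logical_to_rank_dispatch_physical_map` above, using plain Python lists instead of tensors. The toy map, the GPU count, and the `fair_choices` helper name are illustrative assumptions, not part of the patch:

import random

# Toy setup: 1 layer, 2 logical experts, 4 physical experts on 2 GPUs
# (so 2 local physical experts per GPU); -1 marks an unused replica slot.
logical_to_all_physical = [
    [[0, 2], [3, -1]],  # layer 0
]
num_gpus = 2
num_local_physical_experts = 2


def fair_choices(arr, k, r):
    # Same idea as _fair_choices in the patch: spread k picks evenly, then shuffle.
    quotient, remainder = divmod(k, len(arr))
    ans = arr * quotient + r.sample(arr, k=remainder)
    r.shuffle(ans)
    return ans


r = random.Random(42)
dispatch = {}  # (layer_id, logical_expert_id) -> chosen physical expert per GPU
for layer_id, layer_map in enumerate(logical_to_all_physical):
    for logical_id, raw_candidates in enumerate(layer_map):
        candidates = [p for p in raw_candidates if p != -1]
        chosen = [-1] * num_gpus
        # Prefer a replica that already lives on the dispatching GPU.
        for gpu_id in range(num_gpus):
            local = [p for p in candidates if p // num_local_physical_experts == gpu_id]
            if local:
                chosen[gpu_id] = local[0]
        # Fill the remaining GPUs fairly across all replicas.
        missing = [i for i, c in enumerate(chosen) if c == -1]
        for i, pick in zip(missing, fair_choices(candidates, len(missing), r)):
            chosen[i] = pick
        dispatch[(layer_id, logical_id)] = chosen

print(dispatch)
# {(0, 0): [0, 2], (0, 1): [3, 3]}: logical expert 0 stays local on both GPUs;
# logical expert 1 only exists on GPU 1, so GPU 0 falls back to the fair fill.

The per-GPU preference keeps dispatch local whenever a replica is available on that GPU, and the fair fill spreads the remaining ranks evenly across replicas instead of always picking the first one.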