@@ -1466,14 +1466,36 @@ def process_batch_result(
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
 
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_dp_attn_batch_raw(
+            local_batch,
+            dp_size=self.server_args.dp_size,
+            attn_tp_size=self.attn_tp_size,
+            tp_cpu_group=self.tp_cpu_group,
+            get_idle_batch=self.get_idle_batch,
+            disable_cuda_graph=self.server_args.disable_cuda_graph,
+            spec_algorithm=self.spec_algorithm,
+            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+        )
+
+    @staticmethod
+    def prepare_dp_attn_batch_raw(
+        local_batch: ScheduleBatch,
+        dp_size,
+        attn_tp_size: int,
+        tp_cpu_group,
+        get_idle_batch,
+        disable_cuda_graph: bool,
+        spec_algorithm,
+        speculative_num_draft_tokens,
+    ):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
             global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
-                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+                num_tokens = num_tokens * speculative_num_draft_tokens
             global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1492,7 +1514,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
         else:
             can_cuda_graph = 0
 
-        if not self.spec_algorithm.is_none():
+        if not spec_algorithm.is_none():
             # TODO(sang): Support cuda graph when idle batch is there.
             if local_batch is None or local_batch.forward_mode.is_idle():
                 can_cuda_graph = 0
@@ -1510,28 +1532,28 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 4),
+            (dp_size, attn_tp_size, 4),
             dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=self.tp_cpu_group,
+            group=tp_cpu_group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
         global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
         is_extend_in_batch = global_info[:, 0, 3].tolist()
 
         if local_batch is None and max(global_num_tokens) > 0:
-            local_batch = self.get_idle_batch()
+            local_batch = get_idle_batch()
 
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens
             local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
 
             # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
+            if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
         return local_batch, any(is_extend_in_batch)
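
The following standalone sketch is not part of this PR; it only illustrates the metadata exchange that prepare_dp_attn_batch_raw performs across DP workers: each rank packs four int64 scalars and all-gathers them into a (dp_size, attn_tp_size, 4) CPU tensor over a gloo group, matching the tp_cpu_group usage in the diff. The helper name exchange_batch_info and the single-process group setup are illustrative assumptions, not part of the scheduler's API.

# Standalone sketch of the all-gather pattern used by prepare_dp_attn_batch_raw.
# Assumptions: exchange_batch_info is a hypothetical helper; the single-process
# gloo group exists only to make the example runnable end to end.
import torch
import torch.distributed as dist


def exchange_batch_info(
    num_tokens: int,
    can_cuda_graph: int,
    num_tokens_for_logprob: int,
    is_extend_in_batch: int,
    dp_size: int,
    attn_tp_size: int,
    group,
):
    # Pack the per-rank metadata into four int64 scalars, as the scheduler does.
    local_info = torch.tensor(
        [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch],
        dtype=torch.int64,
    )
    # One row per DP rank, one sub-row per attention-TP rank, four fields each.
    global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
    # CPU tensors over a gloo process group, mirroring group=tp_cpu_group above.
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)
    # Read field 0 (token count) and field 1 (cuda-graph flag) of the first
    # attention-TP rank of every DP rank.
    global_num_tokens = global_info[:, 0, 0].tolist()
    can_cuda_graph_all = min(global_info[:, 0, 1].tolist())
    return global_num_tokens, can_cuda_graph_all


if __name__ == "__main__":
    # Single-process "cluster": dp_size * attn_tp_size must equal world_size.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    print(
        exchange_batch_info(
            8, 1, 8, 0, dp_size=1, attn_tp_size=1, group=dist.group.WORLD
        )
    )
    dist.destroy_process_group()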