10
10
import uuid
11
11
from collections import defaultdict
12
12
from functools import cache
13
- from typing import Dict , List , Optional , Tuple , Union
13
+ from typing import Dict , List , Optional , Set , Tuple , TypeAlias , Union
14
14
15
15
import numpy as np
16
16
import numpy .typing as npt
32
32
33
33
logger = logging .getLogger (__name__ )
34
34
35
+ NixlEngineInfo : TypeAlias = Dict [str , Union [str , int ]]
36
+
35
37
36
38
@dataclasses .dataclass
37
39
class TransferInfo :
@@ -45,19 +47,36 @@ class TransferInfo:
45
47
dst_aux_index : int
46
48
dst_gpu_id : int
47
49
50
+ def is_dummy (self ):
51
+ return self .endpoint == ""
52
+
48
53
@classmethod
49
54
def from_zmq (cls , msg : List [bytes ]):
50
- return cls (
51
- room = int (msg [0 ].decode ("ascii" )),
52
- endpoint = msg [1 ].decode ("ascii" ),
53
- dst_port = int (msg [2 ].decode ("ascii" )),
54
- agent_metadata = msg [3 ],
55
- dst_kv_ptrs = list (struct .unpack (f"{ len (msg [4 ])// 8 } Q" , msg [4 ])),
56
- dst_kv_indices = np .frombuffer (msg [5 ], dtype = np .int64 ),
57
- dst_aux_ptrs = list (struct .unpack (f"{ len (msg [6 ])// 8 } Q" , msg [6 ])),
58
- dst_aux_index = int (msg [7 ].decode ("ascii" )),
59
- dst_gpu_id = int (msg [8 ].decode ("ascii" )),
60
- )
55
+ if len (msg ) == 1 :
56
+ # dummy msg
57
+ return cls (
58
+ room = int (msg [0 ].decode ("ascii" )),
59
+ endpoint = "" ,
60
+ dst_port = 0 ,
61
+ agent_metadata = b"" ,
62
+ dst_kv_ptrs = [],
63
+ dst_kv_indices = np .array ([], dtype = np .int64 ),
64
+ dst_aux_ptrs = [],
65
+ dst_aux_index = 0 ,
66
+ dst_gpu_id = 0 ,
67
+ )
68
+ else :
69
+ return cls (
70
+ room = int (msg [0 ].decode ("ascii" )),
71
+ endpoint = msg [1 ].decode ("ascii" ),
72
+ dst_port = int (msg [2 ].decode ("ascii" )),
73
+ agent_metadata = msg [3 ],
74
+ dst_kv_ptrs = list (struct .unpack (f"{ len (msg [4 ])// 8 } Q" , msg [4 ])),
75
+ dst_kv_indices = np .frombuffer (msg [5 ], dtype = np .int64 ),
76
+ dst_aux_ptrs = list (struct .unpack (f"{ len (msg [6 ])// 8 } Q" , msg [6 ])),
77
+ dst_aux_index = int (msg [7 ].decode ("ascii" )),
78
+ dst_gpu_id = int (msg [8 ].decode ("ascii" )),
79
+ )
61
80
62
81
63
82
@dataclasses .dataclass
@@ -98,6 +117,19 @@ def __init__(
98
117
# for p/d multi node infer
99
118
self .bootstrap_port = server_args .disaggregation_bootstrap_port
100
119
self .dist_init_addr = server_args .dist_init_addr
120
+ self .tp_size = server_args .tp_size
121
+
122
+ self .tp_rank = args .engine_rank
123
+ self .enable_dp_attention = server_args .enable_dp_attention
124
+ if self .enable_dp_attention :
125
+ assert (
126
+ server_args .dp_size > 1
127
+ ), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
128
+ self .dp_size = server_args .dp_size
129
+ self .tp_size_of_dp = server_args .tp_size // server_args .dp_size
130
+ self .attn_tp_rank = args .engine_rank % self .tp_size_of_dp
131
+ self .dp_rank = args .engine_rank // self .tp_size_of_dp
132
+
101
133
self .rank_port = None
102
134
self .server_socket = zmq .Context ().socket (zmq .PULL )
103
135
self .register_buffer_to_engine ()
@@ -110,7 +142,10 @@ def __init__(
110
142
self ._start_bootstrap_thread ()
111
143
self ._register_to_bootstrap ()
112
144
elif self .disaggregation_mode == DisaggregationMode .DECODE :
113
- self .connection_pool : Dict [str , Dict [str , Union [str , int ]]] = {}
145
+ # bootstrap key -> (engine_rank -> real source remote, engine_rank -> dummy remote)
146
+ self .prefill_peer_infos : Dict [
147
+ str , Tuple [Dict [int , NixlEngineInfo ], Dict [int , NixlEngineInfo ]]
148
+ ] = {}
114
149
self .transfer_statuses : Dict [int , TransferStatus ] = defaultdict (
115
150
TransferStatus
116
151
)
@@ -180,7 +215,7 @@ def send_kvcache(
180
215
src_descs ,
181
216
dst_descs ,
182
217
peer_name ,
183
- notif .encode ("ascii" ),
218
+ notif .encode ("ascii" ), # type: ignore
184
219
)
185
220
if not xfer_handle :
186
221
raise Exception ("KVSender failed to create transfer" )
@@ -213,7 +248,7 @@ def send_aux(
213
248
src_descs ,
214
249
dst_descs ,
215
250
peer_name ,
216
- notif .encode ("ascii" ),
251
+ notif .encode ("ascii" ), # type: ignore
217
252
)
218
253
if not xfer_handle :
219
254
raise Exception ("KVSender failed to create transfer" )
@@ -240,6 +275,9 @@ def add_transfer_request(
240
275
req = self .transfer_infos [bootstrap_room ]
241
276
assert bootstrap_room == req .room
242
277
278
+ if req .is_dummy ():
279
+ return []
280
+
243
281
peer_name = self ._add_remote (bootstrap_room , req .agent_metadata )
244
282
chunked_dst_kv_indice = req .dst_kv_indices [index_slice ]
245
283
assert len (chunked_dst_kv_indice ) == len (kv_indices )
@@ -256,6 +294,7 @@ def add_transfer_request(
256
294
handles = [kv_xfer_handle ]
257
295
# Only the last chunk we need to send the aux data.
258
296
if is_last :
297
+ assert aux_index is not None
259
298
aux_xfer_handle = self .send_aux (
260
299
peer_name ,
261
300
aux_index ,
@@ -372,14 +411,13 @@ def send(
372
411
373
412
def poll(self) -> KVPoll:
    """Report the sender's progress from the state of its transfer handles.

    Returns WaitingForInput until something has been sent, Success once
    every handle reports "DONE", and raises if any handle reports "ERR".
    """
    if not self.has_sent:
        return KVPoll.WaitingForInput  # type: ignore

    # Query every handle up front so each one is checked on every poll.
    states = []
    for handle in self.xfer_handles:
        states.append(self.kv_mgr.agent.check_xfer_state(handle))

    # Completion is tested before errors, so an empty handle list counts as done.
    if all(state == "DONE" for state in states):
        return KVPoll.Success  # type: ignore
    if any(state == "ERR" for state in states):
        raise Exception("KVSender transfer encountered an error.")
    return KVPoll.WaitingForInput  # type: ignore
383
421
384
422
def failure_exception (self ):
385
423
raise Exception ("Fake KVSender Exception" )
@@ -401,7 +439,7 @@ def __init__(
401
439
# NOTE: key distinguished by bootstrap_addr and engine_rank
402
440
bootstrap_key = f"{ self .bootstrap_addr } _{ self .kv_mgr .kv_args .engine_rank } "
403
441
404
- if bootstrap_key not in self .kv_mgr .connection_pool :
442
+ if bootstrap_key not in self .kv_mgr .prefill_peer_infos :
405
443
self .bootstrap_info = self ._get_bootstrap_info_from_server (
406
444
self .kv_mgr .kv_args .engine_rank
407
445
)
@@ -410,25 +448,76 @@ def __init__(
410
448
f"Could not fetch bootstrap info for engine rank: { self .kv_mgr .kv_args .engine_rank } "
411
449
)
412
450
else :
413
- self .kv_mgr .connection_pool [bootstrap_key ] = self .bootstrap_info
451
+ self .kv_mgr .prefill_peer_infos [bootstrap_key ] = self .bootstrap_info
414
452
else :
415
- self .bootstrap_info = self .kv_mgr .connection_pool [bootstrap_key ]
416
-
453
+ self .bootstrap_info = self .kv_mgr .prefill_peer_infos [bootstrap_key ]
417
454
assert self .bootstrap_info is not None
418
455
419
def _get_bootstrap_info_from_server(
    self, engine_rank
) -> Optional[Tuple[Dict[int, NixlEngineInfo], Dict[int, NixlEngineInfo]]]:
    """Fetch the bootstrap info from the bootstrap server.

    Args:
        engine_rank: Rank queried on the bootstrap server when dp attention
            is disabled. It is not used for the query when dp attention is
            enabled, because the full rank table is fetched instead.

    Returns:
        A ``(sources, dummies)`` pair of ``rank -> engine info`` dicts, or
        ``None`` when the request fails, returns a non-200 status, or any
        error (including a failed assertion) occurs while processing it.

        * Without dp attention: ``({engine_rank: info}, {})``.
        * With dp attention: the prefill ranks are split into consecutive
          groups of ``prefill_tp_size // tp_size_of_dp`` ranks. This rank
          manages the group selected by ``attn_tp_rank``; the first rank of
          that group is the real source and the rest are dummy remotes.
    """
    try:
        if not self.kv_mgr.enable_dp_attention:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
            response = requests.get(url)
            if response.status_code != 200:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
            return {engine_rank: response.json()}, {}

        # No engine_rank in the query: the server answers with every
        # registered prefill rank.
        url = f"http://{self.bootstrap_addr}/route"
        response = requests.get(url)
        if response.status_code != 200:
            logger.error(
                f"Failed to get prefill server info: {response.status_code}, {response.text}"
            )
            return None

        bootstrap_info = response.json()
        assert isinstance(bootstrap_info, dict)
        # JSON object keys arrive as strings; ranks are used as ints below.
        bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}

        # Work out which prefill ranks need to send to this rank.
        # Currently, for the dpsk mla model, those ranks share the same
        # latent cache, so one of them is picked as the real source.
        prefill_tp_size = len(bootstrap_info)

        assert (
            prefill_tp_size >= self.kv_mgr.tp_size_of_dp
        ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"

        num_remote_tp_rank_we_managed = (
            prefill_tp_size // self.kv_mgr.tp_size_of_dp
        )

        # We handle [num * attn_tp_rank, num * attn_tp_rank + num).
        # The groups must advance by the group size; stepping by
        # tp_size_of_dp produced overlapping or missing groups whenever
        # the two values differed.
        first_rank = num_remote_tp_rank_we_managed * self.kv_mgr.attn_tp_rank
        remote_tp_ranks = list(range(prefill_tp_size))
        managed_ranks = remote_tp_ranks[
            first_rank : first_rank + num_remote_tp_rank_we_managed
        ]
        picked_rank = managed_ranks[0]

        assert len(managed_ranks) == num_remote_tp_rank_we_managed

        logger.debug(
            f"Rank {self.kv_mgr.kv_args.engine_rank} managed {managed_ranks}, picked {picked_rank} as real source"
        )

        # Ranks absent from the table are left out of the dummy set, as before.
        dummy_remotes = {
            rk: bootstrap_info[rk]
            for rk in managed_ranks
            if rk != picked_rank and rk in bootstrap_info
        }
        return {picked_rank: bootstrap_info[picked_rank]}, dummy_remotes
    except Exception as e:
        logger.error(f"Error fetching prefill info from bootstrap: {e}")
        return None
@@ -440,11 +529,20 @@ def _connect(self, endpoint: str):
440
529
return socket
441
530
442
531
def init (self , kv_indices : npt .NDArray [np .int64 ], aux_index : Optional [int ] = None ):
443
- self .prefill_server_url = (
444
- f"{ self .bootstrap_info ['rank_ip' ]} :{ self .bootstrap_info ['rank_port' ]} "
445
- )
532
+
533
+ assert self .bootstrap_info is not None
534
+
535
+ sources = self .bootstrap_info [0 ]
536
+ dummy = self .bootstrap_info [1 ]
537
+
538
+ assert len (sources ) == 1 , "Only support one source now"
539
+
540
+ remote_rank = list (self .bootstrap_info [0 ].keys ())[0 ]
541
+
542
+ self .prefill_server_url = f"{ self .bootstrap_info [0 ][remote_rank ]['rank_ip' ]} :{ self .bootstrap_info [0 ][remote_rank ]['rank_port' ]} "
543
+
446
544
logger .debug (
447
- f"Fetched bootstrap info: { self . bootstrap_info } for engine rank: { self .kv_mgr .kv_args .engine_rank } "
545
+ f"Fetched bootstrap info for engine rank: { self .kv_mgr .kv_args .engine_rank } , source: { self . bootstrap_info [ 0 ]. keys () } , dummy: { self . bootstrap_info [ 1 ]. keys () } "
448
546
)
449
547
450
548
packed_kv_data_ptrs = b"" .join (
@@ -466,17 +564,26 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
466
564
str (self .kv_mgr .kv_args .gpu_id ).encode ("ascii" ),
467
565
]
468
566
)
567
+
568
+ for dummy_rank , dummy_info in dummy .items ():
569
+ dummy_url = f"{ dummy_info ['rank_ip' ]} :{ dummy_info ['rank_port' ]} "
570
+ self ._connect ("tcp://" + dummy_url ).send_multipart (
571
+ [
572
+ str (self .bootstrap_room ).encode ("ascii" ),
573
+ ]
574
+ )
575
+
469
576
self .started_transfer = True
470
577
471
578
def poll(self) -> KVPoll:
    """Report the receiver's progress for its bootstrap room.

    Returns WaitingForInput before the transfer has started. After that,
    it refreshes the manager's transfer statuses and returns Success once
    the manager reports this room as done.
    """
    if self.started_transfer:
        self.kv_mgr.update_transfer_status()
        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
            return KVPoll.Success  # type: ignore
    return KVPoll.WaitingForInput  # type: ignore
480
587
481
588
def failure_exception (self ):
482
589
raise Exception ("Fake KVReceiver Exception" )
@@ -564,13 +671,13 @@ async def _handle_route_put(self, request: web.Request):
564
671
engine_rank = int (data ["engine_rank" ])
565
672
agent_name = data ["agent_name" ]
566
673
567
- # Add lock to make sure thread-safe
568
674
if role == "Prefill" :
569
- self .prefill_port_table [engine_rank ] = {
570
- "rank_ip" : rank_ip ,
571
- "rank_port" : rank_port ,
572
- "agent_name" : agent_name ,
573
- }
675
+ async with self .lock :
676
+ self .prefill_port_table [engine_rank ] = {
677
+ "rank_ip" : rank_ip ,
678
+ "rank_port" : rank_port ,
679
+ "agent_name" : agent_name ,
680
+ }
574
681
logger .info (
575
682
f"Registered Prefill boostrap: { engine_rank } with rank_ip: { rank_ip } and rank_port: { rank_port } and name: { agent_name } "
576
683
)
@@ -580,7 +687,13 @@ async def _handle_route_put(self, request: web.Request):
580
687
async def _handle_route_get (self , request : web .Request ):
581
688
engine_rank = request .query .get ("engine_rank" )
582
689
if not engine_rank :
583
- return web .Response (text = "Missing rank" , status = 400 )
690
+ logger .debug (
691
+ f"No engine_rank specified, return all { len (self .prefill_port_table )} engine infos as a dict"
692
+ )
693
+ # Return a dict of all engine_rank
694
+ async with self .lock :
695
+ bootstrap_info = self .prefill_port_table
696
+ return web .json_response (bootstrap_info , status = 200 )
584
697
585
698
# Find corresponding prefill info
586
699
async with self .lock :
0 commit comments