@@ -43,6 +43,7 @@ const (
4343 SGLangBootstrapPort int64 = 8998
4444 SGLangBootstrapPortIdentifier string = "model.aibrix.ai/sglang-bootstrap-port"
4545 LLMEngineIdentifier string = constants .ModelLabelEngine
46+ PDRoleSetIdentifier string = "roleset-name"
4647 PDRoleIdentifier string = "role-name"
4748 RoleReplicaIndex string = "stormservice.orchestration.aibrix.ai/role-replica-index"
4849 PodGroupIndex string = "stormservice.orchestration.aibrix.ai/pod-group-index"
@@ -90,12 +91,18 @@ func (r pdRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (
9091 return "" , err
9192 }
9293
93- if err = r .doPrefillRequest (ctx , prefillPods , getLLMEngine (prefillPods [0 ], LLMEngineIdentifier , VLLMEngine )); err != nil {
94+ prefillPod , err := r .doPrefillRequest (ctx , prefillPods , getLLMEngine (prefillPods [0 ], LLMEngineIdentifier , VLLMEngine ))
95+ if err != nil {
9496 klog .ErrorS (err , "prefill request failed" , "request_id" , ctx .RequestID )
9597 return "" , err
9698 }
9799
98- decodePod := r .selectDecodePod (decodePods )
100+ decodePod := r .selectDecodePod (prefillPod , decodePods )
101+ if decodePod == nil {
102+ return "" , fmt .Errorf ("decode pod not found" )
103+ }
104+
105+ klog .InfoS ("P/D" , "prefill_pod" , prefillPod .Name , "decode_pod" , decodePod .Name )
99106
100107 ctx .SetTargetPod (decodePod )
101108 return ctx .TargetAddress (), nil
@@ -148,15 +155,30 @@ func (r *pdRouter) evaluatePrefixCache(ctx *types.RoutingContext, prefillPods []
148155 return prefillPod , prefixHashes , err
149156}
150157
151- func (r * pdRouter ) selectDecodePod (decodePods []* v1.Pod ) * v1.Pod {
152- decodePod , _ := utils .SelectRandomPod (decodePods , rand .Intn )
158+ func (r * pdRouter ) selectDecodePod (prefillPod * v1.Pod , decodePods []* v1.Pod ) * v1.Pod {
159+ prefillRoleSet , ok := prefillPod .Labels [PDRoleSetIdentifier ]
160+ if ! ok {
161+ return nil
162+ }
163+
164+ filteredDecodePods := []* v1.Pod {}
165+ for _ , pod := range decodePods {
166+ if podRoleSet , exists := pod .Labels [PDRoleSetIdentifier ]; exists && podRoleSet == prefillRoleSet {
167+ filteredDecodePods = append (filteredDecodePods , pod )
168+ }
169+ }
170+ if len (filteredDecodePods ) == 0 {
171+ return nil
172+ }
173+
174+ decodePod , _ := utils .SelectRandomPod (filteredDecodePods , rand .Intn )
153175 return decodePod
154176}
155177
156- func (r * pdRouter ) doPrefillRequest (routingCtx * types.RoutingContext , prefillPods []* v1.Pod , llmEngine string ) error {
178+ func (r * pdRouter ) doPrefillRequest (routingCtx * types.RoutingContext , prefillPods []* v1.Pod , llmEngine string ) ( * v1. Pod , error ) {
157179 prefillPod , prefixHashes , err := r .evaluatePrefixCache (routingCtx , prefillPods )
158180 if err != nil {
159- return err
181+ return nil , err
160182 }
161183 defer func () {
162184 if len (prefixHashes ) > 0 {
@@ -167,7 +189,7 @@ func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPod
167189 // Prepare prefill request payload
168190 payload , err := r .preparePrefillPayload (routingCtx , prefillPod , llmEngine )
169191 if err != nil {
170- return fmt .Errorf ("failed to prepare prefill payload: %w" , err )
192+ return nil , fmt .Errorf ("failed to prepare prefill payload: %w" , err )
171193 }
172194
173195 // Execute HTTP request
@@ -183,20 +205,32 @@ func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPod
183205
184206 if llmEngine == SGLangEngine {
185207 go func () {
186- if err := r .executeHTTPRequest (apiURL , routingCtx , payload ); err != nil {
208+ if _ , err := r .executeHTTPRequest (apiURL , routingCtx , payload ); err != nil {
187209 klog .ErrorS (err , "prefill request for sglang failed" , "request_id" , routingCtx .RequestID )
188210 return
189211 }
190212 klog .InfoS ("prefill_request_complete" , "request_id" , routingCtx .RequestID )
191213 }()
214+ } else if llmEngine == VLLMEngine {
215+ responseData , err := r .executeHTTPRequest (apiURL , routingCtx , payload )
216+ if err != nil {
217+ return nil , fmt .Errorf ("failed to execute prefill request: %w" , err )
218+ }
219+
220+ // Update routing context with KV transfer params from prefill response for vLLM
221+ if err := r .updateRoutingContextWithKVTransferParams (routingCtx , responseData , prefillPod ); err != nil {
222+ return nil , fmt .Errorf ("failed to update routing context with KV transfer params: %w" , err )
223+ }
224+
225+ klog .InfoS ("prefill_request_complete" , "request_id" , routingCtx .RequestID , "prefill_pod_ip" , prefillPod .Status .PodIP )
192226 } else {
193- if err := r .executeHTTPRequest (apiURL , routingCtx , payload ); err != nil {
194- return fmt .Errorf ("failed to execute prefill request: %w" , err )
227+ if _ , err := r .executeHTTPRequest (apiURL , routingCtx , payload ); err != nil {
228+ return nil , fmt .Errorf ("failed to execute prefill request: %w" , err )
195229 }
196- klog .InfoS ("prefill_request_complete" , "request_id" , routingCtx .RequestID )
230+ klog .InfoS ("prefill_request_complete" , "request_id" , routingCtx .RequestID , "prefill_pod_ip" , prefillPod . Status . PodIP )
197231 }
198232
199- return nil
233+ return prefillPod , nil
200234}
201235
202236func (r * pdRouter ) preparePrefillPayload (routingCtx * types.RoutingContext , pod * v1.Pod , llmEngine string ) ([]byte , error ) {
@@ -221,6 +255,18 @@ func (r *pdRouter) preparePrefillPayload(routingCtx *types.RoutingContext, pod *
221255 routingCtx .ReqBody = bodyCopy
222256 }
223257
258+ // Add nixl-specific kv_transfer_params for vLLM prefill requests only
259+ if llmEngine == VLLMEngine {
260+ completionRequest ["kv_transfer_params" ] = map [string ]any {
261+ "do_remote_decode" : true ,
262+ "do_remote_prefill" : false ,
263+ "remote_engine_id" : nil ,
264+ "remote_block_ids" : nil ,
265+ "remote_host" : nil ,
266+ "remote_port" : nil ,
267+ }
268+ }
269+
224270 // Set prefill-specific parameters
225271 completionRequest ["max_tokens" ] = 1
226272 completionRequest ["max_completion_tokens" ] = 1
@@ -230,36 +276,83 @@ func (r *pdRouter) preparePrefillPayload(routingCtx *types.RoutingContext, pod *
230276 return json .Marshal (completionRequest )
231277}
232278
233- func (r * pdRouter ) executeHTTPRequest (url string , routingCtx * types.RoutingContext , payload []byte ) error {
279+ func (r * pdRouter ) executeHTTPRequest (url string , routingCtx * types.RoutingContext , payload []byte ) ( map [ string ] any , error ) {
234280 // Create request with context
235281 req , err := http .NewRequest ("POST" , url , bytes .NewBuffer (payload ))
236282 if err != nil {
237- return fmt .Errorf ("failed to create http prefill request: %w" , err )
283+ return nil , fmt .Errorf ("failed to create http prefill request: %w" , err )
238284 }
239285
240286 // Set headers
241287 for key , value := range routingCtx .ReqHeaders {
242288 req .Header .Set (key , value )
243289 }
244290 req .Header .Set ("content-type" , "application/json" )
245- req .Header .Set ("content-length " , strconv . Itoa ( len ( payload )) )
291+ req .Header .Set ("X-Request-Id " , routingCtx . RequestID )
246292
247- // Execute with timeout
248- client := & http.Client {Timeout : time .Duration (prefillRequestTimeout ) * time .Second }
293+ client := & http.Client {
294+ Timeout : time .Duration (prefillRequestTimeout ) * time .Second ,
295+ }
249296 resp , err := client .Do (req )
250297 if err != nil {
251- return fmt .Errorf ("failed to execute http prefill request: %w" , err )
298+ return nil , fmt .Errorf ("failed to execute http prefill request: %w" , err )
252299 }
253300 defer func () {
254301 _ = resp .Body .Close ()
255302 }()
256303
304+ // Read response body
305+ body , err := io .ReadAll (resp .Body )
306+ if err != nil {
307+ return nil , fmt .Errorf ("failed to read prefill response body: %w" , err )
308+ }
309+
257310 // Check response status
258311 if resp .StatusCode != http .StatusOK {
259- body , _ := io .ReadAll (resp .Body )
260- return fmt .Errorf ("http prefill request failed with status %d: %s" , resp .StatusCode , string (body ))
312+ return nil , fmt .Errorf ("http prefill request failed with status %d: %s" , resp .StatusCode , string (body ))
313+ }
314+
315+ // Parse response JSON
316+ var responseData map [string ]any
317+ if err := json .Unmarshal (body , & responseData ); err != nil {
318+ return nil , fmt .Errorf ("failed to unmarshal prefill response: %w" , err )
319+ }
320+
321+ return responseData , nil
322+ }
323+
324+ func (r * pdRouter ) updateRoutingContextWithKVTransferParams (routingCtx * types.RoutingContext , responseData map [string ]any , prefillPod * v1.Pod ) error {
325+ // Extract kv_transfer_params from prefill response
326+ kvTransferParams , exists := responseData ["kv_transfer_params" ]
327+ if ! exists {
328+ klog .InfoS ("no kv_transfer_params in prefill response" , "request_id" , routingCtx .RequestID )
329+ return nil
261330 }
262331
332+ // Parse the original request body
333+ var originalRequest map [string ]any
334+ if err := json .Unmarshal (routingCtx .ReqBody , & originalRequest ); err != nil {
335+ return fmt .Errorf ("failed to unmarshal original request body: %w" , err )
336+ }
337+
338+ // Update request body with KV transfer params from prefill response
339+ originalRequest ["kv_transfer_params" ] = kvTransferParams
340+
341+ // Add prefill host information following the Python pattern
342+ if kvTransferParamsMap , ok := kvTransferParams .(map [string ]any ); ok {
343+ kvTransferParamsMap ["remote_host" ] = prefillPod .Status .PodIP
344+ }
345+
346+ // Marshal the updated request body
347+ updatedReqBody , err := json .Marshal (originalRequest )
348+ if err != nil {
349+ return fmt .Errorf ("failed to marshal updated request body: %w" , err )
350+ }
351+
352+ // Update routing context with new request body
353+ routingCtx .ReqBody = updatedReqBody
354+
355+ klog .InfoS ("updated routing context with kv_transfer_params" , "request_id" , routingCtx .RequestID , "prefill_host" , prefillPod .Status .PodIP )
263356 return nil
264357}
265358
0 commit comments