Skip to content

Commit 8bd7ead

Browse files
Jeffwanvarungup90googs1025DwyaneShiHaiyang Shi
authored
Cherry picks #1409 #1412 #1425 #1436 #1429 #1427 #1442 #1441 to release-0.4 branch (#1468)
* Select PD workers in same roleset (#1409) * Select PD workers in same roleset * nit * update ut --------- Signed-off-by: Varun Gupta <[email protected]> * [Bug] fix webhook config output when using make manifests (#1412) fix webhook config output when using make manifests Signed-off-by: googs1025 <[email protected]> * [Fix] Fix vLLM NIXL-based P/D samples (#1425) Signed-off-by: Haiyang Shi <[email protected]> Co-authored-by: Haiyang Shi <[email protected]> * [Fix] Disable GGA in NIXL samples (#1436) [Fix] Fix NIXL samples Explicitly set UCX_TLS to let UCX not use GGA (GPU Direct) transport Signed-off-by: Haiyang Shi <[email protected]> Co-authored-by: Haiyang Shi <[email protected]> * Fix P/D disaggregation router to follow Nixl kv_transfer_params (#1429) - Add kv_transfer_params configuration to prefill requests and decode requests Signed-off-by: Jiaxin Shan <[email protected]> * [Bug] Corrected naming convention for AIBRIX_MODEL_GPU_PROFILE_CACHING_FLAG (#1427) Corrected naming convention for AIBRIX_MODEL_GPU_PROFILE_CACHING_FLAG Signed-off-by: Jonathon Shea <[email protected]> * [Bug] stormservice's headless service not set ownerRef (#1442) * fix: stormservice's headless service not set ownerRef Signed-off-by: dajun.cui <[email protected]> * fix: patch ut test for service sync Signed-off-by: dajun.cui <[email protected]> --------- Signed-off-by: dajun.cui <[email protected]> * [Bug] stormservice's headless service need set PublishNotReadyAddresses (#1441) * fix: stormservice's headless service need set PublishNotReadyAddresses Signed-off-by: dajun.cui <[email protected]> * fix: isServiceEqual check PublishNotReadyAddresses Signed-off-by: dajun.cui <[email protected]> --------- Signed-off-by: dajun.cui <[email protected]> --------- Signed-off-by: Varun Gupta <[email protected]> Signed-off-by: googs1025 <[email protected]> Signed-off-by: Haiyang Shi <[email protected]> Signed-off-by: Jiaxin Shan <[email protected]> Signed-off-by: Jonathon Shea <[email protected]> Signed-off-by: dajun.cui <[email protected]> Co-authored-by: Varun Gupta <[email protected]> Co-authored-by: CYJiang <[email protected]> Co-authored-by: Haiyang Shi <[email protected]> Co-authored-by: Haiyang Shi <[email protected]> Co-authored-by: Jonathon Shea <[email protected]> Co-authored-by: cuidajun <[email protected]>
1 parent d5bb4d9 commit 8bd7ead

File tree

17 files changed

+589
-59
lines changed

17 files changed

+589
-59
lines changed

config/webhook/manifests.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,37 +56,37 @@ webhooks:
5656
service:
5757
name: webhook-service
5858
namespace: system
59-
path: /validate-model-aibrix-ai-v1alpha1-modeladapter
59+
path: /validate-orchestration-aibrix-ai-v1alpha1-kvcache
6060
failurePolicy: Fail
61-
name: vmodeladapter.kb.io
61+
name: vkvcache-v1alpha1.kb.io
6262
rules:
6363
- apiGroups:
64-
- model.aibrix.ai
64+
- orchestration.aibrix.ai
6565
apiVersions:
6666
- v1alpha1
6767
operations:
6868
- CREATE
6969
- UPDATE
7070
resources:
71-
- modeladapters
71+
- kvcaches
7272
sideEffects: None
7373
- admissionReviewVersions:
7474
- v1
7575
clientConfig:
7676
service:
7777
name: webhook-service
7878
namespace: system
79-
path: /validate-orchestration-aibrix-ai-v1alpha1-kvcache
79+
path: /validate-model-aibrix-ai-v1alpha1-modeladapter
8080
failurePolicy: Fail
81-
name: vkvcache-v1alpha1.kb.io
81+
name: vmodeladapter.kb.io
8282
rules:
8383
- apiGroups:
84-
- orchestration.aibrix.ai
84+
- model.aibrix.ai
8585
apiVersions:
8686
- v1alpha1
8787
operations:
8888
- CREATE
8989
- UPDATE
9090
resources:
91-
- kvcaches
91+
- modeladapters
9292
sideEffects: None

pkg/cache/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ Kubernetes informers for watching:
9494

9595
**Performance:**
9696
- `AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS`: Metric refresh interval
97-
- `AIBRIX_Model_GPU_PROFILE_CACHING_FLAG`: Enable GPU profile caching
97+
- `AIBRIX_MODEL_GPU_PROFILE_CACHING_FLAG`: Enable GPU profile caching
9898

9999
## Usage Example
100100

pkg/cache/model_gpu_profile.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ const defaultModelGPUProfileRefreshInterval = 10 * time.Second
4747
var enableModelGPUProfileCaching = getModelGPUProfileCachingFlag()
4848

4949
func getModelGPUProfileCachingFlag() bool {
50-
value := utils.LoadEnv("AIBRIX_Model_GPU_PROFILE_CACHING_FLAG", "true")
50+
value := utils.LoadEnv("AIBRIX_MODEL_GPU_PROFILE_CACHING_FLAG", "true")
5151
boolVal, err := strconv.ParseBool(value)
5252
if err != nil || !boolVal {
5353
return false

pkg/controller/stormservice/sync.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,15 @@ func (r *StormServiceReconciler) syncHeadlessService(ctx context.Context, servic
7676
Name: service.Name,
7777
Namespace: service.Namespace,
7878
Labels: service.Labels,
79+
OwnerReferences: []metav1.OwnerReference{
80+
*metav1.NewControllerRef(service, orchestrationv1alpha1.SchemeGroupVersion.WithKind(orchestrationv1alpha1.StormServiceKind)),
81+
},
7982
},
8083
Spec: corev1.ServiceSpec{
81-
Type: corev1.ServiceTypeClusterIP,
82-
ClusterIP: corev1.ClusterIPNone,
83-
Selector: map[string]string{constants.StormServiceNameLabelKey: service.Name},
84+
Type: corev1.ServiceTypeClusterIP,
85+
ClusterIP: corev1.ClusterIPNone,
86+
Selector: map[string]string{constants.StormServiceNameLabelKey: service.Name},
87+
PublishNotReadyAddresses: true,
8488
},
8589
}
8690

pkg/controller/stormservice/sync_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,28 @@ func TestSyncHeadlessService(t *testing.T) {
173173
},
174174
wantError: false,
175175
},
176+
{
177+
name: "service already exists with PublishNotReadyAddresses false",
178+
stormService: &orchestrationv1alpha1.StormService{
179+
ObjectMeta: metav1.ObjectMeta{
180+
Name: "test-storm",
181+
Namespace: "default",
182+
},
183+
},
184+
existingService: &corev1.Service{
185+
ObjectMeta: metav1.ObjectMeta{
186+
Name: "test-storm",
187+
Namespace: "default",
188+
},
189+
Spec: corev1.ServiceSpec{
190+
Type: corev1.ServiceTypeClusterIP,
191+
ClusterIP: corev1.ClusterIPNone,
192+
Selector: map[string]string{constants.StormServiceNameLabelKey: "test-storm"},
193+
PublishNotReadyAddresses: false, // should be updated to true
194+
},
195+
},
196+
wantError: false,
197+
},
176198
}
177199

178200
for _, tt := range tests {
@@ -216,6 +238,17 @@ func TestSyncHeadlessService(t *testing.T) {
216238
t.Errorf("Expected ClusterIP to be None, got %s", service.Spec.ClusterIP)
217239
}
218240

241+
if tt.existingService == nil {
242+
if len(service.OwnerReferences) == 0 {
243+
t.Error("Expected service to have an owner reference")
244+
} else {
245+
ownerRef := service.OwnerReferences[0]
246+
if ownerRef.Kind != orchestrationv1alpha1.StormServiceKind || ownerRef.UID != service.UID {
247+
t.Errorf("Expected owner reference to be %s %s, got %s %s", orchestrationv1alpha1.StormServiceKind, service.UID, ownerRef.Kind, ownerRef.UID)
248+
}
249+
}
250+
}
251+
219252
expectedSelector := map[string]string{constants.StormServiceNameLabelKey: tt.stormService.Name}
220253
if !reflect.DeepEqual(service.Spec.Selector, expectedSelector) {
221254
t.Errorf("Expected selector %v, got %v", expectedSelector, service.Spec.Selector)
@@ -224,6 +257,10 @@ func TestSyncHeadlessService(t *testing.T) {
224257
if service.Spec.Type != corev1.ServiceTypeClusterIP {
225258
t.Errorf("Expected service type ClusterIP, got %v", service.Spec.Type)
226259
}
260+
261+
if service.Spec.PublishNotReadyAddresses != true {
262+
t.Errorf("Expected PublishNotReadyAddresses to be true, got %v", service.Spec.PublishNotReadyAddresses)
263+
}
227264
})
228265
}
229266
}

pkg/controller/stormservice/utils.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,5 +237,6 @@ func sortRoleSetByRevision(roleSets []*orchestrationv1alpha1.RoleSet, updatedRev
237237
func isServiceEqual(a, b *corev1.Service) bool {
238238
return a.Spec.Type == b.Spec.Type &&
239239
apiequality.Semantic.DeepEqual(a.Spec.Selector, b.Spec.Selector) &&
240-
a.Spec.ClusterIP == b.Spec.ClusterIP
240+
a.Spec.ClusterIP == b.Spec.ClusterIP &&
241+
a.Spec.PublishNotReadyAddresses == b.Spec.PublishNotReadyAddresses
241242
}

pkg/plugins/gateway/algorithms/pd_disaggregation.go

Lines changed: 113 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ const (
4343
SGLangBootstrapPort int64 = 8998
4444
SGLangBootstrapPortIdentifier string = "model.aibrix.ai/sglang-bootstrap-port"
4545
LLMEngineIdentifier string = constants.ModelLabelEngine
46+
PDRoleSetIdentifier string = "roleset-name"
4647
PDRoleIdentifier string = "role-name"
4748
RoleReplicaIndex string = "stormservice.orchestration.aibrix.ai/role-replica-index"
4849
PodGroupIndex string = "stormservice.orchestration.aibrix.ai/pod-group-index"
@@ -90,12 +91,18 @@ func (r pdRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (
9091
return "", err
9192
}
9293

93-
if err = r.doPrefillRequest(ctx, prefillPods, getLLMEngine(prefillPods[0], LLMEngineIdentifier, VLLMEngine)); err != nil {
94+
prefillPod, err := r.doPrefillRequest(ctx, prefillPods, getLLMEngine(prefillPods[0], LLMEngineIdentifier, VLLMEngine))
95+
if err != nil {
9496
klog.ErrorS(err, "prefill request failed", "request_id", ctx.RequestID)
9597
return "", err
9698
}
9799

98-
decodePod := r.selectDecodePod(decodePods)
100+
decodePod := r.selectDecodePod(prefillPod, decodePods)
101+
if decodePod == nil {
102+
return "", fmt.Errorf("decode pod not found")
103+
}
104+
105+
klog.InfoS("P/D", "prefill_pod", prefillPod.Name, "decode_pod", decodePod.Name)
99106

100107
ctx.SetTargetPod(decodePod)
101108
return ctx.TargetAddress(), nil
@@ -148,15 +155,30 @@ func (r *pdRouter) evaluatePrefixCache(ctx *types.RoutingContext, prefillPods []
148155
return prefillPod, prefixHashes, err
149156
}
150157

151-
func (r *pdRouter) selectDecodePod(decodePods []*v1.Pod) *v1.Pod {
152-
decodePod, _ := utils.SelectRandomPod(decodePods, rand.Intn)
158+
func (r *pdRouter) selectDecodePod(prefillPod *v1.Pod, decodePods []*v1.Pod) *v1.Pod {
159+
prefillRoleSet, ok := prefillPod.Labels[PDRoleSetIdentifier]
160+
if !ok {
161+
return nil
162+
}
163+
164+
filteredDecodePods := []*v1.Pod{}
165+
for _, pod := range decodePods {
166+
if podRoleSet, exists := pod.Labels[PDRoleSetIdentifier]; exists && podRoleSet == prefillRoleSet {
167+
filteredDecodePods = append(filteredDecodePods, pod)
168+
}
169+
}
170+
if len(filteredDecodePods) == 0 {
171+
return nil
172+
}
173+
174+
decodePod, _ := utils.SelectRandomPod(filteredDecodePods, rand.Intn)
153175
return decodePod
154176
}
155177

156-
func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPods []*v1.Pod, llmEngine string) error {
178+
func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPods []*v1.Pod, llmEngine string) (*v1.Pod, error) {
157179
prefillPod, prefixHashes, err := r.evaluatePrefixCache(routingCtx, prefillPods)
158180
if err != nil {
159-
return err
181+
return nil, err
160182
}
161183
defer func() {
162184
if len(prefixHashes) > 0 {
@@ -167,7 +189,7 @@ func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPod
167189
// Prepare prefill request payload
168190
payload, err := r.preparePrefillPayload(routingCtx, prefillPod, llmEngine)
169191
if err != nil {
170-
return fmt.Errorf("failed to prepare prefill payload: %w", err)
192+
return nil, fmt.Errorf("failed to prepare prefill payload: %w", err)
171193
}
172194

173195
// Execute HTTP request
@@ -183,20 +205,32 @@ func (r *pdRouter) doPrefillRequest(routingCtx *types.RoutingContext, prefillPod
183205

184206
if llmEngine == SGLangEngine {
185207
go func() {
186-
if err := r.executeHTTPRequest(apiURL, routingCtx, payload); err != nil {
208+
if _, err := r.executeHTTPRequest(apiURL, routingCtx, payload); err != nil {
187209
klog.ErrorS(err, "prefill request for sglang failed", "request_id", routingCtx.RequestID)
188210
return
189211
}
190212
klog.InfoS("prefill_request_complete", "request_id", routingCtx.RequestID)
191213
}()
214+
} else if llmEngine == VLLMEngine {
215+
responseData, err := r.executeHTTPRequest(apiURL, routingCtx, payload)
216+
if err != nil {
217+
return nil, fmt.Errorf("failed to execute prefill request: %w", err)
218+
}
219+
220+
// Update routing context with KV transfer params from prefill response for vLLM
221+
if err := r.updateRoutingContextWithKVTransferParams(routingCtx, responseData, prefillPod); err != nil {
222+
return nil, fmt.Errorf("failed to update routing context with KV transfer params: %w", err)
223+
}
224+
225+
klog.InfoS("prefill_request_complete", "request_id", routingCtx.RequestID, "prefill_pod_ip", prefillPod.Status.PodIP)
192226
} else {
193-
if err := r.executeHTTPRequest(apiURL, routingCtx, payload); err != nil {
194-
return fmt.Errorf("failed to execute prefill request: %w", err)
227+
if _, err := r.executeHTTPRequest(apiURL, routingCtx, payload); err != nil {
228+
return nil, fmt.Errorf("failed to execute prefill request: %w", err)
195229
}
196-
klog.InfoS("prefill_request_complete", "request_id", routingCtx.RequestID)
230+
klog.InfoS("prefill_request_complete", "request_id", routingCtx.RequestID, "prefill_pod_ip", prefillPod.Status.PodIP)
197231
}
198232

199-
return nil
233+
return prefillPod, nil
200234
}
201235

202236
func (r *pdRouter) preparePrefillPayload(routingCtx *types.RoutingContext, pod *v1.Pod, llmEngine string) ([]byte, error) {
@@ -221,6 +255,18 @@ func (r *pdRouter) preparePrefillPayload(routingCtx *types.RoutingContext, pod *
221255
routingCtx.ReqBody = bodyCopy
222256
}
223257

258+
// Add nixl-specific kv_transfer_params for vLLM prefill requests only
259+
if llmEngine == VLLMEngine {
260+
completionRequest["kv_transfer_params"] = map[string]any{
261+
"do_remote_decode": true,
262+
"do_remote_prefill": false,
263+
"remote_engine_id": nil,
264+
"remote_block_ids": nil,
265+
"remote_host": nil,
266+
"remote_port": nil,
267+
}
268+
}
269+
224270
// Set prefill-specific parameters
225271
completionRequest["max_tokens"] = 1
226272
completionRequest["max_completion_tokens"] = 1
@@ -230,36 +276,83 @@ func (r *pdRouter) preparePrefillPayload(routingCtx *types.RoutingContext, pod *
230276
return json.Marshal(completionRequest)
231277
}
232278

233-
func (r *pdRouter) executeHTTPRequest(url string, routingCtx *types.RoutingContext, payload []byte) error {
279+
func (r *pdRouter) executeHTTPRequest(url string, routingCtx *types.RoutingContext, payload []byte) (map[string]any, error) {
234280
// Create request with context
235281
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
236282
if err != nil {
237-
return fmt.Errorf("failed to create http prefill request: %w", err)
283+
return nil, fmt.Errorf("failed to create http prefill request: %w", err)
238284
}
239285

240286
// Set headers
241287
for key, value := range routingCtx.ReqHeaders {
242288
req.Header.Set(key, value)
243289
}
244290
req.Header.Set("content-type", "application/json")
245-
req.Header.Set("content-length", strconv.Itoa(len(payload)))
291+
req.Header.Set("X-Request-Id", routingCtx.RequestID)
246292

247-
// Execute with timeout
248-
client := &http.Client{Timeout: time.Duration(prefillRequestTimeout) * time.Second}
293+
client := &http.Client{
294+
Timeout: time.Duration(prefillRequestTimeout) * time.Second,
295+
}
249296
resp, err := client.Do(req)
250297
if err != nil {
251-
return fmt.Errorf("failed to execute http prefill request: %w", err)
298+
return nil, fmt.Errorf("failed to execute http prefill request: %w", err)
252299
}
253300
defer func() {
254301
_ = resp.Body.Close()
255302
}()
256303

304+
// Read response body
305+
body, err := io.ReadAll(resp.Body)
306+
if err != nil {
307+
return nil, fmt.Errorf("failed to read prefill response body: %w", err)
308+
}
309+
257310
// Check response status
258311
if resp.StatusCode != http.StatusOK {
259-
body, _ := io.ReadAll(resp.Body)
260-
return fmt.Errorf("http prefill request failed with status %d: %s", resp.StatusCode, string(body))
312+
return nil, fmt.Errorf("http prefill request failed with status %d: %s", resp.StatusCode, string(body))
313+
}
314+
315+
// Parse response JSON
316+
var responseData map[string]any
317+
if err := json.Unmarshal(body, &responseData); err != nil {
318+
return nil, fmt.Errorf("failed to unmarshal prefill response: %w", err)
319+
}
320+
321+
return responseData, nil
322+
}
323+
324+
func (r *pdRouter) updateRoutingContextWithKVTransferParams(routingCtx *types.RoutingContext, responseData map[string]any, prefillPod *v1.Pod) error {
325+
// Extract kv_transfer_params from prefill response
326+
kvTransferParams, exists := responseData["kv_transfer_params"]
327+
if !exists {
328+
klog.InfoS("no kv_transfer_params in prefill response", "request_id", routingCtx.RequestID)
329+
return nil
261330
}
262331

332+
// Parse the original request body
333+
var originalRequest map[string]any
334+
if err := json.Unmarshal(routingCtx.ReqBody, &originalRequest); err != nil {
335+
return fmt.Errorf("failed to unmarshal original request body: %w", err)
336+
}
337+
338+
// Update request body with KV transfer params from prefill response
339+
originalRequest["kv_transfer_params"] = kvTransferParams
340+
341+
// Add prefill host information following the Python pattern
342+
if kvTransferParamsMap, ok := kvTransferParams.(map[string]any); ok {
343+
kvTransferParamsMap["remote_host"] = prefillPod.Status.PodIP
344+
}
345+
346+
// Marshal the updated request body
347+
updatedReqBody, err := json.Marshal(originalRequest)
348+
if err != nil {
349+
return fmt.Errorf("failed to marshal updated request body: %w", err)
350+
}
351+
352+
// Update routing context with new request body
353+
routingCtx.ReqBody = updatedReqBody
354+
355+
klog.InfoS("updated routing context with kv_transfer_params", "request_id", routingCtx.RequestID, "prefill_host", prefillPod.Status.PodIP)
263356
return nil
264357
}
265358

0 commit comments

Comments
 (0)