Skip to content

Commit 4376962

Browse files
authored
Support model registration flow using aibrix runtime api (#580)
* Introduce RuntimeConfig to all controllers * Refactor the logic to construct URLs based on different envs * Leverage runtime api to manage lora load & unload * Fix several bugs 1. huggingface protocol shadow assignment bug 2. wrong runtime port 3. wrong host used in BuildURLs 4. cannot forward entire headers due to content length mismatch * Format files * Address code review feedback --------- Signed-off-by: Jiaxin Shan <[email protected]>
1 parent 1417a45 commit 4376962

File tree

16 files changed

+314
-77
lines changed

16 files changed

+314
-77
lines changed

cmd/controllers/main.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import (
4646
"sigs.k8s.io/controller-runtime/pkg/webhook"
4747

4848
"github.com/aibrix/aibrix/pkg/cache"
49+
"github.com/aibrix/aibrix/pkg/config"
4950
"github.com/aibrix/aibrix/pkg/controller"
5051
//+kubebuilder:scaffold:imports
5152
)
@@ -105,6 +106,8 @@ func main() {
105106
var leaderElectionResourceLock string
106107
var leaderElectionId string
107108
var controllers string
109+
var enableRuntimeSidecar bool
110+
var debugMode bool
108111
flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.")
109112
flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
110113
flag.BoolVar(&enableLeaderElection, "leader-elect", false,
@@ -126,6 +129,10 @@ func main() {
126129
flag.StringVar(&leaderElectionId, "leader-election-id", "aibrix-controller-manager",
127130
"leader-election-id determines the name of the resource that leader election will use for holding the leader lock, Default is aibrix-controller-manager.")
128131
flag.StringVar(&controllers, "controllers", "*", "Comma-separated list of controllers to enable or disable, default value is * which indicates all controllers should be started.")
132+
flag.BoolVar(&enableRuntimeSidecar, "enable-runtime-sidecar", false,
133+
"If set, Runtime management API will be enabled for the metrics, model adapter and model downloading interactions, control plane will not talk to engine directly anymore")
134+
flag.BoolVar(&debugMode, "debug-mode", false,
135+
"If set, control plane will talk to localhost nodePort for testing purpose")
129136

130137
// Initialize the klog
131138
klog.InitFlags(flag.CommandLine)
@@ -164,6 +171,8 @@ func main() {
164171
tlsOpts = append(tlsOpts, disableHTTP2)
165172
}
166173

174+
runtimeConfig := config.NewRuntimeConfig(enableRuntimeSidecar, debugMode)
175+
167176
webhookServer := webhook.NewServer(webhook.Options{
168177
TLSOpts: tlsOpts,
169178
})
@@ -229,7 +238,7 @@ func main() {
229238

230239
// Kind controller registration is encapsulated inside the pkg/controller/controller.go
231240
// So here we can use more clean registration flow and there's no need to change logics in future.
232-
if err = controller.SetupWithManager(mgr); err != nil {
241+
if err = controller.SetupWithManager(mgr, runtimeConfig); err != nil {
233242
setupLog.Error(err, "unable to setup controller")
234243
os.Exit(1)
235244
}

config/manager/manager.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ spec:
3030
- --leader-elect
3131
- --health-probe-bind-address=:8081
3232
- --metrics-bind-address=0
33+
- --enable-runtime-sidecar
3334
image: controller:latest
3435
name: manager
3536
securityContext:

pkg/config/config.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
Copyright 2024 The Aibrix Team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package config
18+
19+
// RuntimeConfig carries the process-wide switches that are parsed from the
// controller-manager command line and handed to every controller at setup.
type RuntimeConfig struct {
20+
// EnableRuntimeSidecar is set by the --enable-runtime-sidecar flag. When true,
// the control plane uses the runtime management API instead of talking to the
// inference engine directly.
EnableRuntimeSidecar bool
21+
// DebugMode is set by the --debug-mode flag. When true, the control plane
// talks to a localhost nodePort; intended for testing only.
DebugMode bool
22+
}
23+
24+
// NewRuntimeConfig creates a new RuntimeConfig with specified settings.
// enableRuntimeSidecar and debugMode are copied verbatim into the
// EnableRuntimeSidecar and DebugMode fields; no validation is performed.
25+
func NewRuntimeConfig(enableRuntimeSidecar, debugMode bool) RuntimeConfig {
26+
return RuntimeConfig{
27+
EnableRuntimeSidecar: enableRuntimeSidecar,
28+
DebugMode: debugMode,
29+
}
30+
}

pkg/controller/controller.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package controller
1818

1919
import (
20+
"github.com/aibrix/aibrix/pkg/config"
2021
"github.com/aibrix/aibrix/pkg/controller/modeladapter"
2122
"github.com/aibrix/aibrix/pkg/controller/modelrouter"
2223
"github.com/aibrix/aibrix/pkg/controller/podautoscaler"
@@ -33,7 +34,7 @@ import (
3334
// Reason: We have single controller-manager as well and use the controller-runtime libraries.
3435
// Instead of registering every controller in the main.go, kruise's registration flow is much cleaner.
3536

36-
var controllerAddFuncs []func(manager.Manager) error
37+
var controllerAddFuncs []func(manager.Manager, config.RuntimeConfig) error
3738

3839
func Initialize() {
3940
if features.IsControllerEnabled(features.PodAutoscalerController) {
@@ -56,9 +57,9 @@ func Initialize() {
5657
}
5758

5859
// SetupWithManager sets up the controller with the Manager.
59-
func SetupWithManager(m manager.Manager) error {
60+
func SetupWithManager(m manager.Manager, runtimeConfig config.RuntimeConfig) error {
6061
for _, f := range controllerAddFuncs {
61-
if err := f(m); err != nil {
62+
if err := f(m, runtimeConfig); err != nil {
6263
if kindMatchErr, ok := err.(*meta.NoKindMatchError); ok {
6364
klog.InfoS("CRD is not installed, its controller will perform noops!", "CRD", kindMatchErr.GroupKind)
6465
continue

pkg/controller/modeladapter/modeladapter_controller.go

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929

3030
modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
3131
"github.com/aibrix/aibrix/pkg/cache"
32+
"github.com/aibrix/aibrix/pkg/config"
3233
"github.com/aibrix/aibrix/pkg/controller/modeladapter/scheduling"
3334
"github.com/aibrix/aibrix/pkg/utils"
3435
corev1 "k8s.io/api/core/v1"
@@ -83,6 +84,18 @@ const (
8384
ModelAdapterAvailable = "ModelAdapterAvailable"
8485
// ModelAdapterUnavailable is added in a ModelAdapter when it doesn't have any pod hosting it.
8586
ModelAdapterUnavailable = "ModelAdapterUnavailable"
87+
88+
// Inference Service path and ports
89+
DefaultInferenceEnginePort = "8000"
90+
DefaultDebugInferenceEnginePort = "30081"
91+
DefaultRuntimeAPIPort = "8080"
92+
93+
ModelListPath = "/v1/models"
94+
ModelListRuntimeAPIPath = "/v1/models"
95+
LoadLoraAdapterPath = "/v1/load_lora_adapter"
96+
LoadLoraRuntimeAPIPath = "/v1/lora_adapter/load"
97+
UnloadLoraAdapterPath = "/v1/unload_lora_adapter"
98+
UnloadLoraRuntimeAPIPath = "/v1/lora_adapter/unload"
8699
)
87100

88101
var (
@@ -92,18 +105,25 @@ var (
92105
defaultRequeueDuration = 3 * time.Second
93106
)
94107

108+
// URLConfig groups the fully-qualified endpoint URLs that the ModelAdapter
// controller calls on a target pod. Values are produced by BuildURLs.
type URLConfig struct {
109+
// BaseURL is the "http://host:port" prefix shared by the other URLs.
BaseURL string
110+
// ListModelsURL is the endpoint queried to check whether an adapter is
// already loaded.
ListModelsURL string
111+
// LoadAdapterURL is the endpoint that receives the POST request to load a
// LoRA adapter.
LoadAdapterURL string
112+
// UnloadAdapterURL is the endpoint that receives the POST request to unload
// a LoRA adapter.
UnloadAdapterURL string
113+
}
114+
95115
// Add creates a new ModelAdapter Controller and adds it to the Manager with default RBAC.
96116
// The Manager will set fields on the Controller and Start it when the Manager is Started.
97-
func Add(mgr manager.Manager) error {
98-
r, err := newReconciler(mgr)
117+
func Add(mgr manager.Manager, runtimeConfig config.RuntimeConfig) error {
118+
r, err := newReconciler(mgr, runtimeConfig)
99119
if err != nil {
100120
return err
101121
}
102122
return add(mgr, r)
103123
}
104124

105125
// newReconciler returns a new reconcile.Reconciler
106-
func newReconciler(mgr manager.Manager) (reconcile.Reconciler, error) {
126+
func newReconciler(mgr manager.Manager, runtimeConfig config.RuntimeConfig) (reconcile.Reconciler, error) {
107127
cacher := mgr.GetCache()
108128

109129
podInformer, err := cacher.GetInformer(context.TODO(), &corev1.Pod{})
@@ -146,6 +166,7 @@ func newReconciler(mgr manager.Manager) (reconcile.Reconciler, error) {
146166
EndpointSliceLister: endpointSliceLister,
147167
Recorder: mgr.GetEventRecorderFor(controllerName),
148168
scheduler: scheduler,
169+
RuntimeConfig: runtimeConfig,
149170
}
150171
return reconciler, nil
151172
}
@@ -227,6 +248,7 @@ type ModelAdapterReconciler struct {
227248
ServiceLister corelisters.ServiceLister
228249
// EndpointSliceLister is able to list/get services from a shared informer's cache store
229250
EndpointSliceLister discoverylisters.EndpointSliceLister
251+
RuntimeConfig config.RuntimeConfig
230252
}
231253

232254
//+kubebuilder:rbac:groups=discovery.k8s.io,resources=endpointslices,verbs=get;list;watch;create;update;patch;delete
@@ -517,17 +539,10 @@ func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance
517539
return nil
518540
}
519541

520-
// Define the key you want to check
521-
key := "DEBUG_MODE"
522-
value, exists := getEnvKey(key)
523-
host := fmt.Sprintf("http://%s:8000", targetPod.Status.PodIP)
524-
if exists && value == "on" {
525-
// 30080 is the nodePort of the base model service.
526-
host = fmt.Sprintf("http://%s:30081", "localhost")
527-
}
542+
urls := BuildURLs(targetPod.Status.PodIP, r.RuntimeConfig)
528543

529544
// Check if the model is already loaded
530-
exists, err = r.modelAdapterExists(host, instance)
545+
exists, err := r.modelAdapterExists(urls.ListModelsURL, instance)
531546
if err != nil {
532547
return err
533548
}
@@ -537,7 +552,7 @@ func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance
537552
}
538553

539554
// Load the Model adapter
540-
err = r.loadModelAdapter(host, instance)
555+
err = r.loadModelAdapter(urls.LoadAdapterURL, instance)
541556
if err != nil {
542557
return err
543558
}
@@ -546,10 +561,7 @@ func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance
546561
}
547562

548563
// Separate method to check if the model already exists
549-
func (r *ModelAdapterReconciler) modelAdapterExists(host string, instance *modelv1alpha1.ModelAdapter) (bool, error) {
550-
// TODO: /v1/models is the vllm entrypoints, let's support multiple engine in future
551-
url := fmt.Sprintf("%s/v1/models", host)
552-
564+
func (r *ModelAdapterReconciler) modelAdapterExists(url string, instance *modelv1alpha1.ModelAdapter) (bool, error) {
553565
req, err := http.NewRequest("GET", url, nil)
554566
if err != nil {
555567
return false, err
@@ -599,10 +611,11 @@ func (r *ModelAdapterReconciler) modelAdapterExists(host string, instance *model
599611
}
600612

601613
// Separate method to load the LoRA adapter
602-
func (r *ModelAdapterReconciler) loadModelAdapter(host string, instance *modelv1alpha1.ModelAdapter) error {
614+
func (r *ModelAdapterReconciler) loadModelAdapter(url string, instance *modelv1alpha1.ModelAdapter) error {
603615
artifactURL := instance.Spec.ArtifactURL
604616
if strings.HasPrefix(instance.Spec.ArtifactURL, "huggingface://") {
605-
artifactURL, err := extractHuggingFacePath(instance.Spec.ArtifactURL)
617+
var err error
618+
artifactURL, err = extractHuggingFacePath(instance.Spec.ArtifactURL)
606619
if err != nil {
607620
// Handle error, e.g., log it and return
608621
klog.ErrorS(err, "Invalid artifact URL", "artifactURL", artifactURL)
@@ -620,7 +633,6 @@ func (r *ModelAdapterReconciler) loadModelAdapter(host string, instance *modelv1
620633
return err
621634
}
622635

623-
url := fmt.Sprintf("%s/v1/load_lora_adapter", host)
624636
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
625637
if err != nil {
626638
return err
@@ -682,15 +694,8 @@ func (r *ModelAdapterReconciler) unloadModelAdapter(instance *modelv1alpha1.Mode
682694
return err
683695
}
684696

685-
url := fmt.Sprintf("http://%s:%d/v1/unload_lora_adapter", targetPod.Status.PodIP, 8000)
686-
key := "DEBUG_MODE"
687-
value, exists := getEnvKey(key)
688-
if exists && value == "on" {
689-
// 30080 is the nodePort of the base model service.
690-
url = "http://localhost:30081/v1/unload_lora_adapter"
691-
}
692-
693-
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
697+
urls := BuildURLs(targetPod.Status.PodIP, r.RuntimeConfig)
698+
req, err := http.NewRequest("POST", urls.UnloadAdapterURL, bytes.NewBuffer(payloadBytes))
694699
if err != nil {
695700
return err
696701
}

pkg/controller/modeladapter/utils.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"strings"
2525

2626
modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
27+
"github.com/aibrix/aibrix/pkg/config"
2728
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2829
)
2930

@@ -141,3 +142,30 @@ func NewCondition(condType string, status metav1.ConditionStatus, reason, msg st
141142
Message: msg,
142143
}
143144
}
145+
146+
func BuildURLs(podIP string, config config.RuntimeConfig) URLConfig {
147+
var host string
148+
if config.DebugMode {
149+
host = fmt.Sprintf("http://%s:%s", "localhost", DefaultDebugInferenceEnginePort)
150+
} else if config.EnableRuntimeSidecar {
151+
host = fmt.Sprintf("http://%s:%s", podIP, DefaultRuntimeAPIPort)
152+
} else {
153+
host = fmt.Sprintf("http://%s:%s", podIP, DefaultInferenceEnginePort)
154+
}
155+
156+
apiPath := ModelListPath
157+
loadPath := LoadLoraAdapterPath
158+
unloadPath := UnloadLoraAdapterPath
159+
if config.EnableRuntimeSidecar {
160+
apiPath = ModelListRuntimeAPIPath
161+
loadPath = LoadLoraRuntimeAPIPath
162+
unloadPath = UnloadLoraRuntimeAPIPath
163+
}
164+
165+
return URLConfig{
166+
BaseURL: host,
167+
ListModelsURL: fmt.Sprintf("%s%s", host, apiPath),
168+
LoadAdapterURL: fmt.Sprintf("%s%s", host, loadPath),
169+
UnloadAdapterURL: fmt.Sprintf("%s%s", host, unloadPath),
170+
}
171+
}

pkg/controller/modeladapter/utils_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package modeladapter
1818

1919
import (
20+
"fmt"
2021
"os"
2122
"testing"
2223

@@ -25,6 +26,7 @@ import (
2526
"k8s.io/utils/ptr"
2627

2728
modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
29+
"github.com/aibrix/aibrix/pkg/config"
2830
)
2931

3032
// Test for validateModelAdapter function
@@ -236,3 +238,73 @@ func TestExtractHuggingFacePath(t *testing.T) {
236238
})
237239
}
238240
}
241+
242+
func TestBuildURLs(t *testing.T) {
243+
tests := []struct {
244+
name string
245+
podIP string
246+
config config.RuntimeConfig
247+
expectedURLs URLConfig
248+
expectError bool
249+
}{
250+
{
251+
name: "Debug mode enabled",
252+
podIP: "192.168.1.1",
253+
config: config.RuntimeConfig{
254+
DebugMode: true,
255+
EnableRuntimeSidecar: false,
256+
},
257+
expectedURLs: URLConfig{
258+
BaseURL: fmt.Sprintf("http://%s:%s", "localhost", DefaultDebugInferenceEnginePort),
259+
ListModelsURL: fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, ModelListPath),
260+
LoadAdapterURL: fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, LoadLoraAdapterPath),
261+
UnloadAdapterURL: fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, UnloadLoraAdapterPath),
262+
},
263+
expectError: false,
264+
},
265+
{
266+
name: "Runtime sidecar enabled",
267+
podIP: "192.168.1.2",
268+
config: config.RuntimeConfig{
269+
DebugMode: false,
270+
EnableRuntimeSidecar: true,
271+
},
272+
expectedURLs: URLConfig{
273+
BaseURL: fmt.Sprintf("http://%s:%s", "192.168.1.2", DefaultRuntimeAPIPort),
274+
ListModelsURL: fmt.Sprintf("http://%s:%s%s", "192.168.1.2", DefaultRuntimeAPIPort, ModelListRuntimeAPIPath),
275+
LoadAdapterURL: fmt.Sprintf("http://%s:%s%s", "192.168.1.2", DefaultRuntimeAPIPort, LoadLoraRuntimeAPIPath),
276+
UnloadAdapterURL: fmt.Sprintf("http://%s:%s%s", "192.168.1.2", DefaultRuntimeAPIPort, UnloadLoraRuntimeAPIPath),
277+
},
278+
expectError: false,
279+
},
280+
{
281+
name: "Default mode",
282+
podIP: "192.168.1.3",
283+
config: config.RuntimeConfig{
284+
DebugMode: false,
285+
EnableRuntimeSidecar: false,
286+
},
287+
expectedURLs: URLConfig{
288+
BaseURL: fmt.Sprintf("http://%s:%s", "192.168.1.3", DefaultInferenceEnginePort),
289+
ListModelsURL: fmt.Sprintf("http://%s:%s%s", "192.168.1.3", DefaultInferenceEnginePort, ModelListPath),
290+
LoadAdapterURL: fmt.Sprintf("http://%s:%s%s", "192.168.1.3", DefaultInferenceEnginePort, LoadLoraAdapterPath),
291+
UnloadAdapterURL: fmt.Sprintf("http://%s:%s%s", "192.168.1.3", DefaultInferenceEnginePort, UnloadLoraAdapterPath),
292+
},
293+
expectError: false,
294+
},
295+
}
296+
297+
for _, tt := range tests {
298+
t.Run(tt.name, func(t *testing.T) {
299+
urls := BuildURLs(tt.podIP, tt.config)
300+
301+
if tt.expectError {
302+
t.Fatalf("Expected error but got none")
303+
} else {
304+
if urls != tt.expectedURLs {
305+
t.Errorf("Expected URLs %+v but got %+v", tt.expectedURLs, urls)
306+
}
307+
}
308+
})
309+
}
310+
}

0 commit comments

Comments
 (0)