Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cmd/controllers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook"

"github.com/aibrix/aibrix/pkg/cache"
"github.com/aibrix/aibrix/pkg/config"
"github.com/aibrix/aibrix/pkg/controller"
//+kubebuilder:scaffold:imports
)
Expand Down Expand Up @@ -105,6 +106,8 @@ func main() {
var leaderElectionResourceLock string
var leaderElectionId string
var controllers string
var enableRuntimeSidecar bool
var debugMode bool
flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.")
flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
flag.BoolVar(&enableLeaderElection, "leader-elect", false,
Expand All @@ -126,6 +129,10 @@ func main() {
flag.StringVar(&leaderElectionId, "leader-election-id", "aibrix-controller-manager",
"leader-election-id determines the name of the resource that leader election will use for holding the leader lock, Default is aibrix-controller-manager.")
flag.StringVar(&controllers, "controllers", "*", "Comma-separated list of controllers to enable or disable, default value is * which indicates all controllers should be started.")
flag.BoolVar(&enableRuntimeSidecar, "enable-runtime-sidecar", false,
"If set, Runtime management API will be enabled for the metrics, model adapter and model downloading interactions, control plane will not talk to engine directly anymore")
flag.BoolVar(&debugMode, "debug-mode", false,
"If set, control plane will talk to localhost nodePort for testing purpose")

// Initialize the klog
klog.InitFlags(flag.CommandLine)
Expand Down Expand Up @@ -164,6 +171,8 @@ func main() {
tlsOpts = append(tlsOpts, disableHTTP2)
}

runtimeConfig := config.NewRuntimeConfig(enableRuntimeSidecar, debugMode)

webhookServer := webhook.NewServer(webhook.Options{
TLSOpts: tlsOpts,
})
Expand Down Expand Up @@ -229,7 +238,7 @@ func main() {

// Kind controller registration is encapsulated inside pkg/controller/controller.go,
// so we can use a cleaner registration flow here and won't need to change this logic in the future.
if err = controller.SetupWithManager(mgr); err != nil {
if err = controller.SetupWithManager(mgr, runtimeConfig); err != nil {
setupLog.Error(err, "unable to setup controller")
os.Exit(1)
}
Expand Down
1 change: 1 addition & 0 deletions config/manager/manager.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ spec:
- --leader-elect
- --health-probe-bind-address=:8081
- --metrics-bind-address=0
- --enable-runtime-sidecar
image: controller:latest
name: manager
securityContext:
Expand Down
30 changes: 30 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
Copyright 2024 The Aibrix Team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package config

// RuntimeConfig carries the process-wide switches that are parsed from the
// controller-manager command line and handed to each controller at setup time.
type RuntimeConfig struct {
	// EnableRuntimeSidecar mirrors the --enable-runtime-sidecar flag: when set,
	// the control plane goes through the Runtime management API instead of
	// talking to the inference engine directly.
	EnableRuntimeSidecar bool
	// DebugMode mirrors the --debug-mode flag: when set, the control plane
	// talks to a localhost nodePort, which is intended for testing.
	DebugMode bool
}

// NewRuntimeConfig returns a RuntimeConfig whose fields are set from the
// given flag values.
func NewRuntimeConfig(enableRuntimeSidecar, debugMode bool) RuntimeConfig {
	var cfg RuntimeConfig
	cfg.EnableRuntimeSidecar = enableRuntimeSidecar
	cfg.DebugMode = debugMode
	return cfg
}
7 changes: 4 additions & 3 deletions pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package controller

import (
"github.com/aibrix/aibrix/pkg/config"
"github.com/aibrix/aibrix/pkg/controller/modeladapter"
"github.com/aibrix/aibrix/pkg/controller/modelrouter"
"github.com/aibrix/aibrix/pkg/controller/podautoscaler"
Expand All @@ -33,7 +34,7 @@ import (
// Reason: We have single controller-manager as well and use the controller-runtime libraries.
// Instead of registering every controller in the main.go, kruise's registration flow is much cleaner.

var controllerAddFuncs []func(manager.Manager) error
var controllerAddFuncs []func(manager.Manager, config.RuntimeConfig) error

func Initialize() {
if features.IsControllerEnabled(features.PodAutoscalerController) {
Expand All @@ -56,9 +57,9 @@ func Initialize() {
}

// SetupWithManager sets up the controller with the Manager.
func SetupWithManager(m manager.Manager) error {
func SetupWithManager(m manager.Manager, runtimeConfig config.RuntimeConfig) error {
for _, f := range controllerAddFuncs {
if err := f(m); err != nil {
if err := f(m, runtimeConfig); err != nil {
if kindMatchErr, ok := err.(*meta.NoKindMatchError); ok {
klog.InfoS("CRD is not installed, its controller will perform noops!", "CRD", kindMatchErr.GroupKind)
continue
Expand Down
63 changes: 34 additions & 29 deletions pkg/controller/modeladapter/modeladapter_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
"github.com/aibrix/aibrix/pkg/cache"
"github.com/aibrix/aibrix/pkg/config"
"github.com/aibrix/aibrix/pkg/controller/modeladapter/scheduling"
"github.com/aibrix/aibrix/pkg/utils"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -83,6 +84,18 @@ const (
ModelAdapterAvailable = "ModelAdapterAvailable"
// ModelAdapterUnavailable is added in a ModelAdapter when it doesn't have any pod hosting it.
ModelAdapterUnavailable = "ModelAdapterUnavailable"

// Inference Service path and ports
DefaultInferenceEnginePort = "8000"
DefaultDebugInferenceEnginePort = "30081"
DefaultRuntimeAPIPort = "8080"

ModelListPath = "/v1/models"
ModelListRuntimeAPIPath = "/v1/models"
LoadLoraAdapterPath = "/v1/load_lora_adapter"
LoadLoraRuntimeAPIPath = "/v1/lora_adapter/load"
UnloadLoraAdapterPath = "/v1/unload_lora_adapter"
UnloadLoraRuntimeAPIPath = "/v1/lora_adapter/unload"
)

var (
Expand All @@ -92,18 +105,25 @@ var (
defaultRequeueDuration = 3 * time.Second
)

// URLConfig groups the fully-qualified HTTP endpoints the model adapter
// controller uses for one target, as produced by BuildURLs.
type URLConfig struct {
	// BaseURL is the "http://<host>:<port>" prefix shared by the URLs below.
	BaseURL string
	// ListModelsURL is the endpoint queried to list the currently loaded models.
	ListModelsURL string
	// LoadAdapterURL is the endpoint that receives the load-LoRA-adapter request.
	LoadAdapterURL string
	// UnloadAdapterURL is the endpoint that receives the unload-LoRA-adapter request.
	UnloadAdapterURL string
}

// Add creates a new ModelAdapter Controller and adds it to the Manager with default RBAC.
// The Manager will set fields on the Controller and Start it when the Manager is Started.
func Add(mgr manager.Manager) error {
r, err := newReconciler(mgr)
func Add(mgr manager.Manager, runtimeConfig config.RuntimeConfig) error {
r, err := newReconciler(mgr, runtimeConfig)
if err != nil {
return err
}
return add(mgr, r)
}

// newReconciler returns a new reconcile.Reconciler
func newReconciler(mgr manager.Manager) (reconcile.Reconciler, error) {
func newReconciler(mgr manager.Manager, runtimeConfig config.RuntimeConfig) (reconcile.Reconciler, error) {
cacher := mgr.GetCache()

podInformer, err := cacher.GetInformer(context.TODO(), &corev1.Pod{})
Expand Down Expand Up @@ -146,6 +166,7 @@ func newReconciler(mgr manager.Manager) (reconcile.Reconciler, error) {
EndpointSliceLister: endpointSliceLister,
Recorder: mgr.GetEventRecorderFor(controllerName),
scheduler: scheduler,
RuntimeConfig: runtimeConfig,
}
return reconciler, nil
}
Expand Down Expand Up @@ -227,6 +248,7 @@ type ModelAdapterReconciler struct {
ServiceLister corelisters.ServiceLister
// EndpointSliceLister is able to list/get services from a shared informer's cache store
EndpointSliceLister discoverylisters.EndpointSliceLister
RuntimeConfig config.RuntimeConfig
}

//+kubebuilder:rbac:groups=discovery.k8s.io,resources=endpointslices,verbs=get;list;watch;create;update;patch;delete
Expand Down Expand Up @@ -517,17 +539,10 @@ func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance
return nil
}

// Define the key you want to check
key := "DEBUG_MODE"
value, exists := getEnvKey(key)
host := fmt.Sprintf("http://%s:8000", targetPod.Status.PodIP)
if exists && value == "on" {
// 30080 is the nodePort of the base model service.
host = fmt.Sprintf("http://%s:30081", "localhost")
}
urls := BuildURLs(targetPod.Status.PodIP, r.RuntimeConfig)

// Check if the model is already loaded
exists, err = r.modelAdapterExists(host, instance)
exists, err := r.modelAdapterExists(urls.ListModelsURL, instance)
if err != nil {
return err
}
Expand All @@ -537,7 +552,7 @@ func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance
}

// Load the Model adapter
err = r.loadModelAdapter(host, instance)
err = r.loadModelAdapter(urls.LoadAdapterURL, instance)
if err != nil {
return err
}
Expand All @@ -546,10 +561,7 @@ func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance
}

// Separate method to check if the model already exists
func (r *ModelAdapterReconciler) modelAdapterExists(host string, instance *modelv1alpha1.ModelAdapter) (bool, error) {
// TODO: /v1/models is the vllm entrypoints, let's support multiple engine in future
url := fmt.Sprintf("%s/v1/models", host)

func (r *ModelAdapterReconciler) modelAdapterExists(url string, instance *modelv1alpha1.ModelAdapter) (bool, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return false, err
Expand Down Expand Up @@ -599,10 +611,11 @@ func (r *ModelAdapterReconciler) modelAdapterExists(host string, instance *model
}

// Separate method to load the LoRA adapter
func (r *ModelAdapterReconciler) loadModelAdapter(host string, instance *modelv1alpha1.ModelAdapter) error {
func (r *ModelAdapterReconciler) loadModelAdapter(url string, instance *modelv1alpha1.ModelAdapter) error {
artifactURL := instance.Spec.ArtifactURL
if strings.HasPrefix(instance.Spec.ArtifactURL, "huggingface://") {
artifactURL, err := extractHuggingFacePath(instance.Spec.ArtifactURL)
var err error
artifactURL, err = extractHuggingFacePath(instance.Spec.ArtifactURL)
if err != nil {
// Handle error, e.g., log it and return
klog.ErrorS(err, "Invalid artifact URL", "artifactURL", artifactURL)
Expand All @@ -620,7 +633,6 @@ func (r *ModelAdapterReconciler) loadModelAdapter(host string, instance *modelv1
return err
}

url := fmt.Sprintf("%s/v1/load_lora_adapter", host)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
if err != nil {
return err
Expand Down Expand Up @@ -682,15 +694,8 @@ func (r *ModelAdapterReconciler) unloadModelAdapter(instance *modelv1alpha1.Mode
return err
}

url := fmt.Sprintf("http://%s:%d/v1/unload_lora_adapter", targetPod.Status.PodIP, 8000)
key := "DEBUG_MODE"
value, exists := getEnvKey(key)
if exists && value == "on" {
// 30080 is the nodePort of the base model service.
url = "http://localhost:30081/v1/unload_lora_adapter"
}

req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
urls := BuildURLs(targetPod.Status.PodIP, r.RuntimeConfig)
req, err := http.NewRequest("POST", urls.UnloadAdapterURL, bytes.NewBuffer(payloadBytes))
if err != nil {
return err
}
Expand Down
28 changes: 28 additions & 0 deletions pkg/controller/modeladapter/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"

modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
"github.com/aibrix/aibrix/pkg/config"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

Expand Down Expand Up @@ -141,3 +142,30 @@ func NewCondition(condType string, status metav1.ConditionStatus, reason, msg st
Message: msg,
}
}

// BuildURLs assembles the HTTP endpoints used to list models and to load or
// unload a LoRA adapter for the pod with the given IP.
//
// The host is chosen in priority order:
//   - DebugMode: localhost on DefaultDebugInferenceEnginePort (podIP is ignored);
//   - EnableRuntimeSidecar: podIP on DefaultRuntimeAPIPort;
//   - otherwise: podIP on DefaultInferenceEnginePort.
//
// The request paths depend only on EnableRuntimeSidecar: the runtime API paths
// are used when it is set, the inference engine paths otherwise.
//
// The parameter is named runtimeConfig rather than config so that it does not
// shadow the imported config package inside the function body.
func BuildURLs(podIP string, runtimeConfig config.RuntimeConfig) URLConfig {
	var host string
	switch {
	case runtimeConfig.DebugMode:
		host = fmt.Sprintf("http://%s:%s", "localhost", DefaultDebugInferenceEnginePort)
	case runtimeConfig.EnableRuntimeSidecar:
		host = fmt.Sprintf("http://%s:%s", podIP, DefaultRuntimeAPIPort)
	default:
		host = fmt.Sprintf("http://%s:%s", podIP, DefaultInferenceEnginePort)
	}

	// NOTE(review): when DebugMode and EnableRuntimeSidecar are both set, the
	// host points at the debug inference engine port while the paths below are
	// the runtime API ones — confirm this combination is intended.
	listPath := ModelListPath
	loadPath := LoadLoraAdapterPath
	unloadPath := UnloadLoraAdapterPath
	if runtimeConfig.EnableRuntimeSidecar {
		listPath = ModelListRuntimeAPIPath
		loadPath = LoadLoraRuntimeAPIPath
		unloadPath = UnloadLoraRuntimeAPIPath
	}

	return URLConfig{
		BaseURL:          host,
		ListModelsURL:    fmt.Sprintf("%s%s", host, listPath),
		LoadAdapterURL:   fmt.Sprintf("%s%s", host, loadPath),
		UnloadAdapterURL: fmt.Sprintf("%s%s", host, unloadPath),
	}
}
72 changes: 72 additions & 0 deletions pkg/controller/modeladapter/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package modeladapter

import (
"fmt"
"os"
"testing"

Expand All @@ -25,6 +26,7 @@ import (
"k8s.io/utils/ptr"

modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
"github.com/aibrix/aibrix/pkg/config"
)

// Test for validateModelAdapter function
Expand Down Expand Up @@ -236,3 +238,73 @@ func TestExtractHuggingFacePath(t *testing.T) {
})
}
}

// TestBuildURLs checks that BuildURLs picks the expected host, port and request
// paths for every combination of DebugMode and EnableRuntimeSidecar.
//
// BuildURLs has no error return, so the table carries no "expect error" column.
func TestBuildURLs(t *testing.T) {
	tests := []struct {
		name          string
		podIP         string
		runtimeConfig config.RuntimeConfig
		expectedURLs  URLConfig
	}{
		{
			name:  "Debug mode enabled",
			podIP: "192.168.1.1",
			runtimeConfig: config.RuntimeConfig{
				DebugMode:            true,
				EnableRuntimeSidecar: false,
			},
			expectedURLs: URLConfig{
				BaseURL:          fmt.Sprintf("http://%s:%s", "localhost", DefaultDebugInferenceEnginePort),
				ListModelsURL:    fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, ModelListPath),
				LoadAdapterURL:   fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, LoadLoraAdapterPath),
				UnloadAdapterURL: fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, UnloadLoraAdapterPath),
			},
		},
		{
			name:  "Runtime sidecar enabled",
			podIP: "192.168.1.2",
			runtimeConfig: config.RuntimeConfig{
				DebugMode:            false,
				EnableRuntimeSidecar: true,
			},
			expectedURLs: URLConfig{
				BaseURL:          fmt.Sprintf("http://%s:%s", "192.168.1.2", DefaultRuntimeAPIPort),
				ListModelsURL:    fmt.Sprintf("http://%s:%s%s", "192.168.1.2", DefaultRuntimeAPIPort, ModelListRuntimeAPIPath),
				LoadAdapterURL:   fmt.Sprintf("http://%s:%s%s", "192.168.1.2", DefaultRuntimeAPIPort, LoadLoraRuntimeAPIPath),
				UnloadAdapterURL: fmt.Sprintf("http://%s:%s%s", "192.168.1.2", DefaultRuntimeAPIPort, UnloadLoraRuntimeAPIPath),
			},
		},
		{
			name:  "Default mode",
			podIP: "192.168.1.3",
			runtimeConfig: config.RuntimeConfig{
				DebugMode:            false,
				EnableRuntimeSidecar: false,
			},
			expectedURLs: URLConfig{
				BaseURL:          fmt.Sprintf("http://%s:%s", "192.168.1.3", DefaultInferenceEnginePort),
				ListModelsURL:    fmt.Sprintf("http://%s:%s%s", "192.168.1.3", DefaultInferenceEnginePort, ModelListPath),
				LoadAdapterURL:   fmt.Sprintf("http://%s:%s%s", "192.168.1.3", DefaultInferenceEnginePort, LoadLoraAdapterPath),
				UnloadAdapterURL: fmt.Sprintf("http://%s:%s%s", "192.168.1.3", DefaultInferenceEnginePort, UnloadLoraAdapterPath),
			},
		},
		{
			// Debug mode decides the host; the sidecar flag still decides the paths.
			name:  "Debug mode and runtime sidecar enabled",
			podIP: "192.168.1.4",
			runtimeConfig: config.RuntimeConfig{
				DebugMode:            true,
				EnableRuntimeSidecar: true,
			},
			expectedURLs: URLConfig{
				BaseURL:          fmt.Sprintf("http://%s:%s", "localhost", DefaultDebugInferenceEnginePort),
				ListModelsURL:    fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, ModelListRuntimeAPIPath),
				LoadAdapterURL:   fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, LoadLoraRuntimeAPIPath),
				UnloadAdapterURL: fmt.Sprintf("http://%s:%s%s", "localhost", DefaultDebugInferenceEnginePort, UnloadLoraRuntimeAPIPath),
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			urls := BuildURLs(tt.podIP, tt.runtimeConfig)
			if urls != tt.expectedURLs {
				t.Errorf("Expected URLs %+v but got %+v", tt.expectedURLs, urls)
			}
		})
	}
}
Loading
Loading