Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 133 additions & 26 deletions cmd/kvcache-watcher/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,29 @@ import (
"github.com/redis/go-redis/v9"
)

const DefaultKVCacheServerPort = 9600
const KVCacheLabelKeyIdentifier = "kvcache.orchestration.aibrix.ai/name"
const KVCacheLabelKeyRole = "kvcache.orchestration.aibrix.ai/role"
const KVCacheLabelValueRoleCache = "cache"
const KVCacheLabelValueRoleMetadata = "metadata"
const KVCacheLabelValueRoleKVWatcher = "kvwatcher"
const RedisNodeMemberKey = "hpkv_nodes"
const HPKVRedisNodeMemberKey = "hpkv_nodes"
const InfiniStoreRedisNodeMemberKey = "kvcache_nodes"

const networkStatusAnnotation = "k8s.volcengine.com/network-status"

var (
kvCacheServerRDMAPort = utils.LoadEnvInt("AIBRIX_KVCACHE_RDMA_PORT", 18512)
totalSlots = utils.LoadEnvInt("AIBRIX_KVCACHE_TOTAL_SLOTS", 4096)
virtualNodeCount = utils.LoadEnvInt("AIBRIX_KVCACHE_VIRTUAL_NODE_COUNT", 100)

config *rest.Config
clientset *kubernetes.Clientset
)

var (
kvCacheBackend string
kvCacheWatchNS string
kvCacheWatchClusterId string
kvCacheServerRDMAPort int
kvCacheServerAdminPort int
consistentHashingTotalSlots int
consistentHashingVirtualNodeCount int
)

type SlotRange struct {
Start int `json:"start"`
End int `json:"end"`
Expand Down Expand Up @@ -145,16 +149,50 @@ func (v VirtualNode) Index() int {
return -1
}

type KVCacheBackend interface {
GetRedisKey() string
ExtractIP(ctx context.Context, pod *corev1.Pod) (string, error)
}

type HPKVBackend struct{}

func (b HPKVBackend) GetRedisKey() string {
return HPKVRedisNodeMemberKey
}

func (b HPKVBackend) ExtractIP(ctx context.Context, pod *corev1.Pod) (string, error) {
return GetRDMAIP(ctx, pod)
}

type InfiniStoreBackend struct{}

func (b InfiniStoreBackend) GetRedisKey() string {
return InfiniStoreRedisNodeMemberKey
}

func (b InfiniStoreBackend) ExtractIP(ctx context.Context, pod *corev1.Pod) (string, error) {
return pod.Status.PodIP, nil
}

func NewKVCacheBackend(backend string) KVCacheBackend {
switch backend {
case "infinistore":
return InfiniStoreBackend{}
case "hpkv":
fallthrough
default:
return HPKVBackend{}
}
}

func main() {
ctx := context.Background()
parseFlags()

// read environment variables from env
namespace := utils.LoadEnv("WATCH_KVCACHE_NAMESPACE", "default")
kvClusterId := os.Getenv("WATCH_KVCACHE_CLUSTER") // e.g., "kvcache.aibrix.ai=llama4"
redisAddr := os.Getenv("REDIS_ADDR")
redisPass := os.Getenv("REDIS_PASSWORD")
redisDatabase := utils.LoadEnvInt("REDIS_DATABASE", 0)

rdb := redis.NewClient(&redis.Options{
Addr: redisAddr,
Password: redisPass,
Expand Down Expand Up @@ -189,10 +227,10 @@ func main() {

// Create informer factory
factory := informers.NewSharedInformerFactoryWithOptions(clientset, 15*time.Second,
informers.WithNamespace(namespace),
informers.WithNamespace(kvCacheWatchNS),
informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
if kvClusterId != "" {
kvClusterLabel := fmt.Sprintf("%s=%s", KVCacheLabelKeyIdentifier, kvClusterId)
if kvCacheWatchClusterId != "" {
kvClusterLabel := fmt.Sprintf("%s=%s", KVCacheLabelKeyIdentifier, kvCacheWatchClusterId)
kvClusterRoleLabel := fmt.Sprintf("%s=%s", KVCacheLabelKeyRole, KVCacheLabelValueRoleCache)
opts.LabelSelector = fmt.Sprintf("%s,%s", kvClusterLabel, kvClusterRoleLabel)
}
Expand All @@ -205,13 +243,13 @@ func main() {
podInformer := factory.Core().V1().Pods().Informer()
_, err = podInformer.AddEventHandler(&cache.ResourceEventHandlerFuncs{
AddFunc: func(obj interface{}) {
queue.Add(kvClusterId)
queue.Add(kvCacheWatchClusterId)
},
UpdateFunc: func(oldObj, newObj interface{}) {
queue.Add(kvClusterId)
queue.Add(kvCacheWatchClusterId)
},
DeleteFunc: func(obj interface{}) {
queue.Add(kvClusterId)
queue.Add(kvCacheWatchClusterId)
},
})
if err != nil {
Expand All @@ -225,6 +263,8 @@ func main() {
factory.Start(stopCh)
factory.WaitForCacheSync(stopCh)

backendImpl := NewKVCacheBackend(kvCacheBackend)

// Start queue worker in goroutine
go func() {
for {
Expand All @@ -236,7 +276,7 @@ func main() {
func(key string) {
defer queue.Done(key)

if err := syncPods(ctx, rdb, podInformer, key); err != nil {
if err := syncPods(ctx, rdb, podInformer, key, backendImpl); err != nil {
klog.Errorf("syncPods failed for %s: %v, retrying...", key, err)
queue.AddRateLimited(key)
} else {
Expand All @@ -249,7 +289,71 @@ func main() {
<-stopCh
}

func syncPods(ctx context.Context, rdb *redis.Client, informer cache.SharedIndexInformer, kvClusterId string) error {
func parseFlags() {
flag.StringVar(
&kvCacheBackend,
"kvcache-backend",
"hpkv",
"KV backend implementation to use. Supported: 'hpkv', 'infinistore'.",
)

flag.StringVar(
&kvCacheWatchNS,
"kvcache-watch-namespace",
utils.LoadEnv("AIBRIX_KVCACHE_WATCH_NAMESPACE", "default"),
"Kubernetes namespace to watch for KVCache pods.",
)

flag.StringVar(
&kvCacheWatchClusterId,
"kvcache-watch-cluster-id",
os.Getenv("AIBRIX_KVCACHE_WATCH_CLUSTER"),
"Value of the 'kvcache.orchestration.aibrix.ai/name' label to identify the KV cache cluster.",
)

flag.IntVar(
&kvCacheServerRDMAPort,
"kvcache-server-rdma-port",
utils.LoadEnvInt("AIBRIX_KVCACHE_RDMA_PORT", 18512),
"RDMA service port used by the KVCache data servers.",
)

flag.IntVar(
&kvCacheServerAdminPort,
"kvcache-server-admin-port",
utils.LoadEnvInt("AIBRIX_KVCACHE_ADMIN_PORT", 9100),
"Admin port used for control kv cache server.",
)

flag.IntVar(
&consistentHashingTotalSlots,
"consistent-hashing-total-slots",
4096,
"Total number of slots in the consistent hashing ring.",
)

flag.IntVar(
&consistentHashingVirtualNodeCount,
"consistent-hashing-virtual-node-count",
100,
"Number of virtual nodes per physical KVCache pod for consistent hashing.",
)

flag.Parse()

klog.Infof("=== Parsed Flags ===")
flag.VisitAll(func(f *flag.Flag) {
klog.Infof("%s: %s", f.Name, f.Value)
})
}

func syncPods(
ctx context.Context,
rdb *redis.Client,
informer cache.SharedIndexInformer,
kvClusterId string,
kvb KVCacheBackend,
) error {
pods := informer.GetStore().List()
klog.Infof("%d pods Found in kvcache cluster %s", len(pods), kvClusterId)

Expand All @@ -275,35 +379,38 @@ func syncPods(ctx context.Context, rdb *redis.Client, informer cache.SharedIndex
}
}

nodeSlots := calculateSlotDistribution(validPods, totalSlots, virtualNodeCount)
nodeSlots := calculateSlotDistribution(validPods, consistentHashingTotalSlots, consistentHashingVirtualNodeCount)
currentNodes := make([]NodeInfo, 0)
for _, pod := range validPods {
rdmaIP, err := GetRDMAIP(ctx, &pod)
ip, err := kvb.ExtractIP(ctx, &pod)
if err != nil {
klog.ErrorS(err, "Failed to get RDMA IP for pod", "pod", pod.Name)
continue
}

currentNodes = append(currentNodes, NodeInfo{
Name: pod.Name,
Addr: rdmaIP,
Addr: ip,
Port: kvCacheServerRDMAPort,
Slots: mergeSlots(nodeSlots[pod.Name], totalSlots),
Slots: mergeSlots(nodeSlots[pod.Name], consistentHashingTotalSlots),
})
}

redisKey := kvb.GetRedisKey()

// get existing nodes
val, err := rdb.Get(ctx, RedisNodeMemberKey).Result()
val, err := rdb.Get(ctx, redisKey).Result()
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("failed to get existing data from redis %v", err)
}
existingClusterNodes := ClusterNodes{}
_ = json.Unmarshal([]byte(val), &existingClusterNodes)
klog.Infof("redis get result: key %s, value %s", RedisNodeMemberKey, val)
klog.Infof("redis get result: key %s, value %s", redisKey, val)

needUpdate := !isNodeListEqual(currentNodes, existingClusterNodes.Nodes)
if !needUpdate {
klog.Infof("Node list unchanged, skipping update, current version: %d", existingClusterNodes.Version)
return nil
}

newVersion := int64(1)
Expand All @@ -323,7 +430,7 @@ func syncPods(ctx context.Context, rdb *redis.Client, informer cache.SharedIndex

// write to redis using pipeline
pipe := rdb.TxPipeline()
pipe.Set(ctx, RedisNodeMemberKey, jsonData, 0)
pipe.Set(ctx, redisKey, jsonData, 0)
if _, err := pipe.Exec(ctx); err != nil {
return fmt.Errorf("redis transaction failed: %v", err)
}
Expand Down
44 changes: 44 additions & 0 deletions pkg/constants/kvcache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
Copyright 2025 The Aibrix Team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package constants

const (
KVCacheLabelKeyIdentifier = "kvcache.orchestration.aibrix.ai/name"
KVCacheLabelKeyRole = "kvcache.orchestration.aibrix.ai/role"
KVCacheLabelKeyMetadataIndex = "kvcache.orchestration.aibrix.ai/etcd-index"
KVCacheLabelKeyBackend = "kvcache.orchestration.aibrix.ai/backend"

KVCacheAnnotationNodeAffinityKey = "kvcache.orchestration.aibrix.ai/node-affinity-key"
KVCacheAnnotationNodeAffinityGPUType = "kvcache.orchestration.aibrix.ai/node-affinity-gpu-type"
KVCacheAnnotationPodAffinityKey = "kvcache.orchestration.aibrix.ai/pod-affinity-workload"
KVCacheAnnotationPodAntiAffinity = "kvcache.orchestration.aibrix.ai/pod-anti-affinity"

KVCacheAnnotationNodeAffinityDefaultKey = "machine.cluster.vke.volcengine.com/gpu-name"

// This config will be deprecated in future, users should specify kvcache backend directly.
KVCacheAnnotationMode = "kvcache.orchestration.aibrix.ai/mode"
KVCacheAnnotationContainerRegistry = "kvcache.orchestration.aibrix.ai/container-registry"

KVCacheLabelValueRoleCache = "cache"
KVCacheLabelValueRoleMetadata = "metadata"
KVCacheLabelValueRoleKVWatcher = "kvwatcher"

KVCacheBackendVineyard = "vineyard"
KVCacheBackendHPKV = "hpkv"
KVCacheBackendInfinistore = "infinistore"
KVCacheBackendDefault = KVCacheBackendVineyard
)
Loading
Loading