diff --git a/balancer/endpointsharding/endpointsharding.go b/balancer/endpointsharding/endpointsharding.go index 421c4fecc999..cc606f4dae4e 100644 --- a/balancer/endpointsharding/endpointsharding.go +++ b/balancer/endpointsharding/endpointsharding.go @@ -73,7 +73,7 @@ func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions, childBuilde esOpts: esOpts, childBuilder: childBuilder, } - es.children.Store(resolver.NewEndpointMap()) + es.children.Store(resolver.NewEndpointMap[*balancerWrapper]()) return es } @@ -90,7 +90,7 @@ type endpointSharding struct { // calls into a child. To avoid deadlocks, do not acquire childMu while // holding mu. childMu sync.Mutex - children atomic.Pointer[resolver.EndpointMap] // endpoint -> *balancerWrapper + children atomic.Pointer[resolver.EndpointMap[*balancerWrapper]] // inhibitChildUpdates is set during UpdateClientConnState/ResolverError // calls (calls to children will each produce an update, only want one @@ -122,7 +122,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState var ret error children := es.children.Load() - newChildren := resolver.NewEndpointMap() + newChildren := resolver.NewEndpointMap[*balancerWrapper]() // Update/Create new children. for _, endpoint := range state.ResolverState.Endpoints { @@ -131,9 +131,8 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState // update. continue } - var childBalancer *balancerWrapper - if val, ok := children.Get(endpoint); ok { - childBalancer = val.(*balancerWrapper) + childBalancer, ok := children.Get(endpoint) + if ok { // Endpoint attributes may have changed, update the stored endpoint. es.mu.Lock() childBalancer.childState.Endpoint = endpoint @@ -166,7 +165,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState for _, e := range children.Keys() { child, _ := children.Get(e) if _, ok := newChildren.Get(e); !ok { - child.(*balancerWrapper).closeLocked() + child.closeLocked() } } es.children.Store(newChildren) @@ -189,7 +188,7 @@ func (es *endpointSharding) ResolverError(err error) { }() children := es.children.Load() for _, child := range children.Values() { - child.(*balancerWrapper).resolverErrorLocked(err) + child.resolverErrorLocked(err) } } @@ -202,7 +201,7 @@ func (es *endpointSharding) Close() { defer es.childMu.Unlock() children := es.children.Load() for _, child := range children.Values() { - child.(*balancerWrapper).closeLocked() + child.closeLocked() } } @@ -222,8 +221,7 @@ func (es *endpointSharding) updateState() { childStates := make([]ChildState, 0, children.Len()) for _, child := range children.Values() { - bw := child.(*balancerWrapper) - childState := bw.childState + childState := child.childState childStates = append(childStates, childState) childPicker := childState.State.Picker switch childState.State.ConnectivityState { diff --git a/balancer/leastrequest/leastrequest.go b/balancer/leastrequest/leastrequest.go index d25f9178b9d2..dd46dfa8faa4 100644 --- a/balancer/leastrequest/leastrequest.go +++ b/balancer/leastrequest/leastrequest.go @@ -88,7 +88,7 @@ func (bb) Name() string { func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &leastRequestBalancer{ ClientConn: cc, - endpointRPCCounts: resolver.NewEndpointMap(), + endpointRPCCounts: resolver.NewEndpointMap[*atomic.Int32](), } b.child = endpointsharding.NewBalancer(b, bOpts, balancer.Get(pickfirstleaf.Name).Build, endpointsharding.Options{}) b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b)) @@ -110,7 +110,7 @@ type leastRequestBalancer struct { choiceCount uint32 // endpointRPCCounts holds RPC counts to keep track for subsequent picker // updates. - endpointRPCCounts *resolver.EndpointMap // endpoint -> *atomic.Int32 + endpointRPCCounts *resolver.EndpointMap[*atomic.Int32] } func (lrb *leastRequestBalancer) Close() { @@ -164,7 +164,7 @@ func (lrb *leastRequestBalancer) UpdateState(state balancer.State) { } // Reconcile endpoints. - newEndpoints := resolver.NewEndpointMap() // endpoint -> nil + newEndpoints := resolver.NewEndpointMap[any]() for _, child := range readyEndpoints { newEndpoints.Set(child.Endpoint, nil) } @@ -179,13 +179,11 @@ func (lrb *leastRequestBalancer) UpdateState(state balancer.State) { // Copy refs to counters into picker. endpointStates := make([]endpointState, 0, len(readyEndpoints)) for _, child := range readyEndpoints { - var counter *atomic.Int32 - if val, ok := lrb.endpointRPCCounts.Get(child.Endpoint); !ok { + counter, ok := lrb.endpointRPCCounts.Get(child.Endpoint) + if !ok { // Create new counts if needed. counter = new(atomic.Int32) lrb.endpointRPCCounts.Set(child.Endpoint, counter) - } else { - counter = val.(*atomic.Int32) } endpointStates = append(endpointStates, endpointState{ picker: child.State.Picker, diff --git a/balancer/weightedroundrobin/balancer.go b/balancer/weightedroundrobin/balancer.go index acc86198766b..0ee707601198 100644 --- a/balancer/weightedroundrobin/balancer.go +++ b/balancer/weightedroundrobin/balancer.go @@ -105,7 +105,7 @@ func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Ba target: bOpts.Target.String(), metricsRecorder: cc.MetricsRecorder(), addressWeights: resolver.NewAddressMapV2[*endpointWeight](), - endpointToWeight: resolver.NewEndpointMap(), + endpointToWeight: resolver.NewEndpointMap[*endpointWeight](), scToWeight: make(map[balancer.SubConn]*endpointWeight), } @@ -155,17 +155,15 @@ func (bb) Name() string { // // Caller must hold b.mu. func (b *wrrBalancer) updateEndpointsLocked(endpoints []resolver.Endpoint) { - endpointSet := resolver.NewEndpointMap() + endpointSet := resolver.NewEndpointMap[*endpointWeight]() addressSet := resolver.NewAddressMapV2[*endpointWeight]() for _, endpoint := range endpoints { endpointSet.Set(endpoint, nil) for _, addr := range endpoint.Addresses { addressSet.Set(addr, nil) } - var ew *endpointWeight - if ewi, ok := b.endpointToWeight.Get(endpoint); ok { - ew = ewi.(*endpointWeight) - } else { + ew, ok := b.endpointToWeight.Get(endpoint) + if !ok { ew = &endpointWeight{ logger: b.logger, connectivityState: connectivity.Connecting, @@ -215,7 +213,7 @@ type wrrBalancer struct { locality string stopPicker *grpcsync.Event addressWeights *resolver.AddressMapV2[*endpointWeight] - endpointToWeight *resolver.EndpointMap // endpoint -> endpointWeight + endpointToWeight *resolver.EndpointMap[*endpointWeight] scToWeight map[balancer.SubConn]*endpointWeight } @@ -260,13 +258,12 @@ func (b *wrrBalancer) UpdateState(state balancer.State) { for _, childState := range childStates { if childState.State.ConnectivityState == connectivity.Ready { - ewv, ok := b.endpointToWeight.Get(childState.Endpoint) + ew, ok := b.endpointToWeight.Get(childState.Endpoint) if !ok { // Should never happen, simply continue and ignore this endpoint // for READY pickers. continue } - ew := ewv.(*endpointWeight) readyPickersWeight = append(readyPickersWeight, pickerWeightedEndpoint{ picker: childState.State.Picker, weightedEndpoint: ew, @@ -398,8 +395,7 @@ func (b *wrrBalancer) Close() { b.mu.Unlock() // Ensure any lingering OOB watchers are stopped. - for _, ewv := range b.endpointToWeight.Values() { - ew := ewv.(*endpointWeight) + for _, ew := range b.endpointToWeight.Values() { if ew.stopORCAListener != nil { ew.stopORCAListener() } diff --git a/resolver/map.go b/resolver/map.go index 4e8d8f4481d8..c3c15ac96f13 100644 --- a/resolver/map.go +++ b/resolver/map.go @@ -162,21 +162,21 @@ type endpointMapKey string // unordered set of address strings within an endpoint. This map is not thread // safe, thus it is unsafe to access concurrently. Must be created via // NewEndpointMap; do not construct directly. -type EndpointMap struct { - endpoints map[endpointMapKey]endpointData +type EndpointMap[T any] struct { + endpoints map[endpointMapKey]endpointData[T] } -type endpointData struct { +type endpointData[T any] struct { // decodedKey stores the original key to avoid decoding when iterating on // EndpointMap keys. decodedKey Endpoint - value any + value T } // NewEndpointMap creates a new EndpointMap. -func NewEndpointMap() *EndpointMap { - return &EndpointMap{ - endpoints: make(map[endpointMapKey]endpointData), +func NewEndpointMap[T any]() *EndpointMap[T] { + return &EndpointMap[T]{ + endpoints: make(map[endpointMapKey]endpointData[T]), } } @@ -196,25 +196,25 @@ func encodeEndpoint(e Endpoint) endpointMapKey { } // Get returns the value for the address in the map, if present. -func (em *EndpointMap) Get(e Endpoint) (value any, ok bool) { +func (em *EndpointMap[T]) Get(e Endpoint) (value T, ok bool) { val, found := em.endpoints[encodeEndpoint(e)] if found { return val.value, true } - return nil, false + return value, false } // Set updates or adds the value to the address in the map. -func (em *EndpointMap) Set(e Endpoint, value any) { +func (em *EndpointMap[T]) Set(e Endpoint, value T) { en := encodeEndpoint(e) - em.endpoints[en] = endpointData{ + em.endpoints[en] = endpointData[T]{ decodedKey: Endpoint{Addresses: e.Addresses}, value: value, } } // Len returns the number of entries in the map. -func (em *EndpointMap) Len() int { +func (em *EndpointMap[T]) Len() int { return len(em.endpoints) } @@ -223,7 +223,7 @@ func (em *EndpointMap) Len() int { // the unordered set of addresses. Thus, endpoint information returned is not // the full endpoint data (drops duplicated addresses and attributes) but can be // used for EndpointMap accesses. -func (em *EndpointMap) Keys() []Endpoint { +func (em *EndpointMap[T]) Keys() []Endpoint { ret := make([]Endpoint, 0, len(em.endpoints)) for _, en := range em.endpoints { ret = append(ret, en.decodedKey) @@ -232,8 +232,8 @@ func (em *EndpointMap) Keys() []Endpoint { } // Values returns a slice of all current map values. -func (em *EndpointMap) Values() []any { - ret := make([]any, 0, len(em.endpoints)) +func (em *EndpointMap[T]) Values() []T { + ret := make([]T, 0, len(em.endpoints)) for _, val := range em.endpoints { ret = append(ret, val.value) } @@ -241,7 +241,7 @@ func (em *EndpointMap) Values() []any { } // Delete removes the specified endpoint from the map. -func (em *EndpointMap) Delete(e Endpoint) { +func (em *EndpointMap[T]) Delete(e Endpoint) { en := encodeEndpoint(e) delete(em.endpoints, en) } diff --git a/resolver/map_test.go b/resolver/map_test.go index 37b817b462fa..33526839d228 100644 --- a/resolver/map_test.go +++ b/resolver/map_test.go @@ -72,11 +72,11 @@ func (s) TestAddressMap_Length(t *testing.T) { } func (s) TestAddressMap_Get(t *testing.T) { - addrMap := NewAddressMapV2[any]() + addrMap := NewAddressMapV2[int]() addrMap.Set(addr1, 1) - if got, ok := addrMap.Get(addr2); ok || got != nil { - t.Fatalf("addrMap.Get(addr1) = %v, %v; want nil, false", got, ok) + if got, ok := addrMap.Get(addr2); ok || got != 0 { + t.Fatalf("addrMap.Get(addr1) = %v, %v; want 0, false", got, ok) } addrMap.Set(addr2, 2) @@ -85,25 +85,25 @@ func (s) TestAddressMap_Get(t *testing.T) { addrMap.Set(addr5, 5) addrMap.Set(addr6, 6) addrMap.Set(addr7, 7) // aliases addr1 - if got, ok := addrMap.Get(addr1); !ok || got.(int) != 7 { + if got, ok := addrMap.Get(addr1); !ok || got != 7 { t.Fatalf("addrMap.Get(addr1) = %v, %v; want %v, true", got, ok, 7) } - if got, ok := addrMap.Get(addr2); !ok || got.(int) != 2 { + if got, ok := addrMap.Get(addr2); !ok || got != 2 { t.Fatalf("addrMap.Get(addr2) = %v, %v; want %v, true", got, ok, 2) } - if got, ok := addrMap.Get(addr3); !ok || got.(int) != 3 { + if got, ok := addrMap.Get(addr3); !ok || got != 3 { t.Fatalf("addrMap.Get(addr3) = %v, %v; want %v, true", got, ok, 3) } - if got, ok := addrMap.Get(addr4); !ok || got.(int) != 4 { + if got, ok := addrMap.Get(addr4); !ok || got != 4 { t.Fatalf("addrMap.Get(addr4) = %v, %v; want %v, true", got, ok, 4) } - if got, ok := addrMap.Get(addr5); !ok || got.(int) != 5 { + if got, ok := addrMap.Get(addr5); !ok || got != 5 { t.Fatalf("addrMap.Get(addr5) = %v, %v; want %v, true", got, ok, 5) } - if got, ok := addrMap.Get(addr6); !ok || got.(int) != 6 { + if got, ok := addrMap.Get(addr6); !ok || got != 6 { t.Fatalf("addrMap.Get(addr6) = %v, %v; want %v, true", got, ok, 6) } - if got, ok := addrMap.Get(addr7); !ok || got.(int) != 7 { + if got, ok := addrMap.Get(addr7); !ok || got != 7 { t.Fatalf("addrMap.Get(addr7) = %v, %v; want %v, true", got, ok, 7) } } @@ -132,7 +132,7 @@ func (s) TestAddressMap_Delete(t *testing.T) { } func (s) TestAddressMap_Keys(t *testing.T) { - addrMap := NewAddressMapV2[any]() + addrMap := NewAddressMapV2[int]() addrMap.Set(addr1, 1) addrMap.Set(addr2, 2) addrMap.Set(addr3, 3) @@ -153,7 +153,7 @@ func (s) TestAddressMap_Keys(t *testing.T) { } func (s) TestAddressMap_Values(t *testing.T) { - addrMap := NewAddressMapV2[any]() + addrMap := NewAddressMapV2[int]() addrMap.Set(addr1, 1) addrMap.Set(addr2, 2) addrMap.Set(addr3, 3) @@ -163,10 +163,7 @@ func (s) TestAddressMap_Values(t *testing.T) { addrMap.Set(addr7, 7) // aliases addr1 want := []int{2, 3, 4, 5, 6, 7} - var got []int - for _, v := range addrMap.Values() { - got = append(got, v.(int)) - } + got := addrMap.Values() sort.Ints(got) if diff := cmp.Diff(want, got); diff != "" { t.Fatalf("addrMap.Values returned unexpected elements (-want, +got):\n%v", diff) @@ -174,7 +171,7 @@ func (s) TestAddressMap_Values(t *testing.T) { } func (s) TestEndpointMap_Length(t *testing.T) { - em := NewEndpointMap() + em := NewEndpointMap[struct{}]() // Should be empty at creation time. if got := em.Len(); got != 0 { t.Fatalf("em.Len() = %v; want 0", got) @@ -196,7 +193,7 @@ func (s) TestEndpointMap_Length(t *testing.T) { } func (s) TestEndpointMap_Get(t *testing.T) { - em := NewEndpointMap() + em := NewEndpointMap[int]() em.Set(endpoint1, 1) // The second endpoint endpoint21 should override. em.Set(endpoint12, 1) @@ -207,28 +204,28 @@ func (s) TestEndpointMap_Get(t *testing.T) { em.Set(endpoint6, 6) em.Set(endpoint7, 7) - if got, ok := em.Get(endpoint1); !ok || got.(int) != 1 { + if got, ok := em.Get(endpoint1); !ok || got != 1 { t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 1) } - if got, ok := em.Get(endpoint12); !ok || got.(int) != 2 { + if got, ok := em.Get(endpoint12); !ok || got != 2 { t.Fatalf("em.Get(endpoint12) = %v, %v; want %v, true", got, ok, 2) } - if got, ok := em.Get(endpoint21); !ok || got.(int) != 2 { + if got, ok := em.Get(endpoint21); !ok || got != 2 { t.Fatalf("em.Get(endpoint21) = %v, %v; want %v, true", got, ok, 2) } - if got, ok := em.Get(endpoint3); !ok || got.(int) != 3 { + if got, ok := em.Get(endpoint3); !ok || got != 3 { t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 3) } - if got, ok := em.Get(endpoint4); !ok || got.(int) != 4 { + if got, ok := em.Get(endpoint4); !ok || got != 4 { t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 4) } - if got, ok := em.Get(endpoint5); !ok || got.(int) != 5 { + if got, ok := em.Get(endpoint5); !ok || got != 5 { t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 5) } - if got, ok := em.Get(endpoint6); !ok || got.(int) != 6 { + if got, ok := em.Get(endpoint6); !ok || got != 6 { t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 6) } - if got, ok := em.Get(endpoint7); !ok || got.(int) != 7 { + if got, ok := em.Get(endpoint7); !ok || got != 7 { t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 7) } if _, ok := em.Get(endpoint123); ok { @@ -237,7 +234,7 @@ func (s) TestEndpointMap_Get(t *testing.T) { } func (s) TestEndpointMap_Delete(t *testing.T) { - em := NewEndpointMap() + em := NewEndpointMap[struct{}]() // Initial state of system: [1, 2, 3, 12] em.Set(endpoint1, struct{}{}) em.Set(endpoint2, struct{}{}) @@ -267,7 +264,7 @@ func (s) TestEndpointMap_Delete(t *testing.T) { } func (s) TestEndpointMap_Values(t *testing.T) { - em := NewEndpointMap() + em := NewEndpointMap[int]() em.Set(endpoint1, 1) // The second endpoint endpoint21 should override. em.Set(endpoint12, 1) @@ -278,10 +275,7 @@ func (s) TestEndpointMap_Values(t *testing.T) { em.Set(endpoint6, 6) em.Set(endpoint7, 7) want := []int{1, 2, 3, 4, 5, 6, 7} - var got []int - for _, v := range em.Values() { - got = append(got, v.(int)) - } + got := em.Values() sort.Ints(got) if diff := cmp.Diff(want, got); diff != "" { t.Fatalf("em.Values() returned unexpected elements (-want, +got):\n%v", diff) @@ -292,7 +286,7 @@ func (s) TestEndpointMap_Values(t *testing.T) { // faster than O(n). This test doesn't run O(n) operations including listing // keys and values. func BenchmarkEndpointMap(b *testing.B) { - em := NewEndpointMap() + em := NewEndpointMap[any]() for i := range b.N { em.Set(Endpoint{ Addresses: []Address{{Addr: fmt.Sprintf("%d.%d.%d.%d", i, i, i, i)}}, diff --git a/xds/internal/balancer/outlierdetection/balancer.go b/xds/internal/balancer/outlierdetection/balancer.go index 0d60ab2e86f4..bc5cea40cc1d 100644 --- a/xds/internal/balancer/outlierdetection/balancer.go +++ b/xds/internal/balancer/outlierdetection/balancer.go @@ -68,7 +68,7 @@ func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Ba scUpdateCh: buffer.NewUnbounded(), pickerUpdateCh: buffer.NewUnbounded(), channelzParent: bOpts.ChannelzParent, - endpoints: resolver.NewEndpointMap(), + endpoints: resolver.NewEndpointMap[*endpointInfo](), } b.logger = prefixLogger(b) b.logger.Infof("Created") @@ -196,7 +196,7 @@ type outlierDetectionBalancer struct { // (within the context of a single goroutine). mu sync.Mutex // endpoints stores pointers to endpointInfo objects for each endpoint. - endpoints *resolver.EndpointMap // endpoint -> endpointInfo + endpoints *resolver.EndpointMap[*endpointInfo] // addrs stores pointers to endpointInfo objects for each address. Addresses // belonging to the same endpoint point to the same object. addrs map[string]*endpointInfo @@ -229,8 +229,7 @@ func (b *outlierDetectionBalancer) onIntervalConfig() { var interval time.Duration if b.timerStartTime.IsZero() { b.timerStartTime = time.Now() - for _, val := range b.endpoints.Values() { - epInfo := val.(*endpointInfo) + for _, epInfo := range b.endpoints.Values() { epInfo.callCounter.clear() } interval = time.Duration(b.cfg.Interval) @@ -253,8 +252,7 @@ func (b *outlierDetectionBalancer) onNoopConfig() { // do the following:" // "Unset the timer start timestamp." b.timerStartTime = time.Time{} - for _, val := range b.endpoints.Values() { - epInfo := val.(*endpointInfo) + for _, epInfo := range b.endpoints.Values() { // "Uneject all currently ejected endpoints." if !epInfo.latestEjectionTimestamp.IsZero() { b.unejectEndpoint(epInfo) @@ -298,7 +296,7 @@ func (b *outlierDetectionBalancer) UpdateClientConnState(s balancer.ClientConnSt b.updateUnconditionally = false b.cfg = lbCfg - newEndpoints := resolver.NewEndpointMap() + newEndpoints := resolver.NewEndpointMap[bool]() for _, ep := range s.ResolverState.Endpoints { newEndpoints.Set(ep, true) if _, ok := b.endpoints.Get(ep); !ok { @@ -315,8 +313,7 @@ func (b *outlierDetectionBalancer) UpdateClientConnState(s balancer.ClientConnSt // populate the addrs map. b.addrs = map[string]*endpointInfo{} for _, ep := range s.ResolverState.Endpoints { - val, _ := b.endpoints.Get(ep) - epInfo := val.(*endpointInfo) + epInfo, _ := b.endpoints.Get(ep) for _, addr := range ep.Addresses { if _, ok := b.addrs[addr.Addr]; ok { b.logger.Errorf("Endpoints contain duplicate address %q", addr.Addr) @@ -705,8 +702,7 @@ func (b *outlierDetectionBalancer) intervalTimerAlgorithm() { defer b.mu.Unlock() b.timerStartTime = time.Now() - for _, val := range b.endpoints.Values() { - epInfo := val.(*endpointInfo) + for _, epInfo := range b.endpoints.Values() { epInfo.callCounter.swap() } @@ -718,8 +714,7 @@ func (b *outlierDetectionBalancer) intervalTimerAlgorithm() { b.failurePercentageAlgorithm() } - for _, val := range b.endpoints.Values() { - epInfo := val.(*endpointInfo) + for _, epInfo := range b.endpoints.Values() { if epInfo.latestEjectionTimestamp.IsZero() && epInfo.ejectionTimeMultiplier > 0 { epInfo.ejectionTimeMultiplier-- continue @@ -751,8 +746,7 @@ func (b *outlierDetectionBalancer) intervalTimerAlgorithm() { // Caller must hold b.mu. func (b *outlierDetectionBalancer) endpointsWithAtLeastRequestVolume(requestVolume uint32) []*endpointInfo { var endpoints []*endpointInfo - for _, val := range b.endpoints.Values() { - epInfo := val.(*endpointInfo) + for _, epInfo := range b.endpoints.Values() { bucket1 := epInfo.callCounter.inactiveBucket rv := bucket1.numSuccesses + bucket1.numFailures if rv >= requestVolume { diff --git a/xds/internal/balancer/ringhash/ring.go b/xds/internal/balancer/ringhash/ring.go index c2e556bb1662..978facf14333 100644 --- a/xds/internal/balancer/ringhash/ring.go +++ b/xds/internal/balancer/ringhash/ring.go @@ -68,7 +68,7 @@ type ringEntry struct { // and first item with hash >= given hash will be returned. // // Must be called with a non-empty endpoints map. -func newRing(endpoints *resolver.EndpointMap, minRingSize, maxRingSize uint64, logger *grpclog.PrefixLogger) *ring { +func newRing(endpoints *resolver.EndpointMap[*endpointState], minRingSize, maxRingSize uint64, logger *grpclog.PrefixLogger) *ring { if logger.V(2) { logger.Infof("newRing: number of endpoints is %d, minRingSize is %d, maxRingSize is %d", endpoints.Len(), minRingSize, maxRingSize) } @@ -136,18 +136,17 @@ func newRing(endpoints *resolver.EndpointMap, minRingSize, maxRingSize uint64, l // The endpoints are sorted in ascending order to ensure consistent results. // // Must be called with a non-empty endpoints map. -func normalizeWeights(endpoints *resolver.EndpointMap) ([]endpointInfo, float64) { +func normalizeWeights(endpoints *resolver.EndpointMap[*endpointState]) ([]endpointInfo, float64) { var weightSum uint32 // Since attributes are explicitly ignored in the EndpointMap key, we need // to iterate over the values to get the weights. endpointVals := endpoints.Values() - for _, a := range endpointVals { - weightSum += a.(*endpointState).weight + for _, epState := range endpointVals { + weightSum += epState.weight } ret := make([]endpointInfo, 0, endpoints.Len()) min := 1.0 - for _, a := range endpointVals { - epState := a.(*endpointState) + for _, epState := range endpointVals { // (*endpointState).weight is set to 1 if the weight attribute is not // found on the endpoint. And since this function is guaranteed to be // called with a non-empty endpoints map, weightSum is guaranteed to be diff --git a/xds/internal/balancer/ringhash/ring_test.go b/xds/internal/balancer/ringhash/ring_test.go index 108955b9727a..1d28bccc4bd8 100644 --- a/xds/internal/balancer/ringhash/ring_test.go +++ b/xds/internal/balancer/ringhash/ring_test.go @@ -30,7 +30,7 @@ import ( ) var testEndpoints []resolver.Endpoint -var testEndpointStateMap *resolver.EndpointMap +var testEndpointStateMap *resolver.EndpointMap[*endpointState] func init() { testEndpoints = []resolver.Endpoint{ @@ -38,7 +38,7 @@ func init() { testEndpoint("b", 3), testEndpoint("c", 4), } - testEndpointStateMap = resolver.NewEndpointMap() + testEndpointStateMap = resolver.NewEndpointMap[*endpointState]() testEndpointStateMap.Set(testEndpoints[0], &endpointState{firstAddr: "a", weight: 3}) testEndpointStateMap.Set(testEndpoints[1], &endpointState{firstAddr: "b", weight: 3}) testEndpointStateMap.Set(testEndpoints[2], &endpointState{firstAddr: "c", weight: 4}) diff --git a/xds/internal/balancer/ringhash/ringhash.go b/xds/internal/balancer/ringhash/ringhash.go index 216e16c33a99..26623378d4b9 100644 --- a/xds/internal/balancer/ringhash/ringhash.go +++ b/xds/internal/balancer/ringhash/ringhash.go @@ -55,7 +55,7 @@ type bb struct{} func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { b := &ringhashBalancer{ ClientConn: cc, - endpointStates: resolver.NewEndpointMap(), + endpointStates: resolver.NewEndpointMap[*endpointState](), } esOpts := endpointsharding.Options{DisableAutoReconnect: true} b.child = endpointsharding.NewBalancer(b, opts, lazyPickFirstBuilder, esOpts) @@ -86,7 +86,7 @@ type ringhashBalancer struct { config *LBConfig inhibitChildUpdates bool shouldRegenerateRing bool - endpointStates *resolver.EndpointMap // Map from endpoint -> *endpointState + endpointStates *resolver.EndpointMap[*endpointState] // ring is always in sync with endpoints. When endpoints change, a new ring // is generated. Note that address weights updates also regenerates the @@ -108,14 +108,15 @@ func (b *ringhashBalancer) UpdateState(state balancer.State) { defer b.mu.Unlock() childStates := endpointsharding.ChildStatesFromPicker(state.Picker) // endpointsSet is the set converted from endpoints, used for quick lookup. - endpointsSet := resolver.NewEndpointMap() + endpointsSet := resolver.NewEndpointMap[bool]() for _, childState := range childStates { endpoint := childState.Endpoint endpointsSet.Set(endpoint, true) newWeight := getWeightAttribute(endpoint) - if val, ok := b.endpointStates.Get(endpoint); !ok { - es := &endpointState{ + es, ok := b.endpointStates.Get(endpoint) + if !ok { + es = &endpointState{ balancer: childState.Balancer, weight: newWeight, firstAddr: endpoint.Addresses[0].Addr, @@ -128,7 +129,6 @@ func (b *ringhashBalancer) UpdateState(state balancer.State) { // object for it. If the weight or the first address of the endpoint // has changed, update the endpoint state map with the new weight. // This will be used when a new ring is created. - es := val.(*endpointState) if oldWeight := es.weight; oldWeight != newWeight { b.shouldRegenerateRing = true es.weight = newWeight @@ -240,8 +240,8 @@ func (b *ringhashBalancer) updatePickerLocked() { // ensure `ExitIdle` is called on the same child, preventing unnecessary // connections. var endpointStates = make([]*endpointState, b.endpointStates.Len()) - for i, val := range b.endpointStates.Values() { - endpointStates[i] = val.(*endpointState) + for i, s := range b.endpointStates.Values() { + endpointStates[i] = s } sort.Slice(endpointStates, func(i, j int) bool { return endpointStates[i].firstAddr < endpointStates[j].firstAddr @@ -300,8 +300,7 @@ func (b *ringhashBalancer) ExitIdle() { // re-generated every time an endpoint state is updated. func (b *ringhashBalancer) newPickerLocked() *picker { states := make(map[string]balancer.State) - for _, val := range b.endpointStates.Values() { - epState := val.(*endpointState) + for _, epState := range b.endpointStates.Values() { states[epState.firstAddr] = epState.state } return &picker{ring: b.ring, logger: b.logger, endpointStates: states} @@ -324,8 +323,7 @@ func (b *ringhashBalancer) newPickerLocked() *picker { // failure to failover to the lower priority. func (b *ringhashBalancer) aggregatedStateLocked() connectivity.State { var nums [5]int - for _, val := range b.endpointStates.Values() { - es := val.(*endpointState) + for _, es := range b.endpointStates.Values() { nums[es.state.ConnectivityState]++ } diff --git a/xds/internal/balancer/ringhash/ringhash_test.go b/xds/internal/balancer/ringhash/ringhash_test.go index 27bae77a4213..5ee45018ca9e 100644 --- a/xds/internal/balancer/ringhash/ringhash_test.go +++ b/xds/internal/balancer/ringhash/ringhash_test.go @@ -656,7 +656,7 @@ func (s) TestAggregatedConnectivityState(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bal := &ringhashBalancer{endpointStates: resolver.NewEndpointMap()} + bal := &ringhashBalancer{endpointStates: resolver.NewEndpointMap[*endpointState]()} for i, cs := range tt.endpointStates { es := &endpointState{ state: balancer.State{ConnectivityState: cs},