diff --git a/pkg/providers/v1/aws_loadbalancer.go b/pkg/providers/v1/aws_loadbalancer.go index f9dd598450..62f68690ad 100644 --- a/pkg/providers/v1/aws_loadbalancer.go +++ b/pkg/providers/v1/aws_loadbalancer.go @@ -859,18 +859,7 @@ func (c *Cloud) ensureTargetGroup(ctx context.Context, targetGroup *elbv2types.T func (c *Cloud) ensureTargetGroupTargets(ctx context.Context, tgARN string, expectedTargets []*elbv2types.TargetDescription, actualTargets []*elbv2types.TargetDescription) error { targetsToRegister, targetsToDeregister := c.diffTargetGroupTargets(expectedTargets, actualTargets) - if len(targetsToRegister) > 0 { - targetsToRegisterChunks := c.chunkTargetDescriptions(targetsToRegister, defaultRegisterTargetsChunkSize) - for _, targetsChunk := range targetsToRegisterChunks { - req := &elbv2.RegisterTargetsInput{ - TargetGroupArn: aws.String(tgARN), - Targets: targetsChunk, - } - if _, err := c.elbv2.RegisterTargets(ctx, req); err != nil { - return fmt.Errorf("error trying to register targets in target group: %q", err) - } - } - } + // deregister targets prior to registering to allow instance replacements when the LB is at max instance capacity if len(targetsToDeregister) > 0 { targetsToDeregisterChunks := c.chunkTargetDescriptions(targetsToDeregister, defaultDeregisterTargetsChunkSize) for _, targetsChunk := range targetsToDeregisterChunks { @@ -883,6 +872,18 @@ func (c *Cloud) ensureTargetGroupTargets(ctx context.Context, tgARN string, expe } } } + if len(targetsToRegister) > 0 { + targetsToRegisterChunks := c.chunkTargetDescriptions(targetsToRegister, defaultRegisterTargetsChunkSize) + for _, targetsChunk := range targetsToRegisterChunks { + req := &elbv2.RegisterTargetsInput{ + TargetGroupArn: aws.String(tgARN), + Targets: targetsChunk, + } + if _, err := c.elbv2.RegisterTargets(ctx, req); err != nil { + return fmt.Errorf("error trying to register targets in target group: %q", err) + } + } + } return nil } diff --git a/pkg/providers/v1/aws_loadbalancer_test.go b/pkg/providers/v1/aws_loadbalancer_test.go index 39037205e8..1b19e93d33 100644 --- a/pkg/providers/v1/aws_loadbalancer_test.go +++ b/pkg/providers/v1/aws_loadbalancer_test.go @@ -1921,3 +1921,157 @@ func TestCloud_reconcileTargetGroupsAttributes(t *testing.T) { }) } } + +// Test-specific mock for ELB v2 client that embeds MockedFakeELBV2 +type mockELBV2ClientForEnsureTargetGroupTargets struct { + *MockedFakeELBV2 + + MaxTargets int + NumTargets int +} + +func (m *mockELBV2ClientForEnsureTargetGroupTargets) RegisterTargets(ctx context.Context, input *elbv2.RegisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) { + m.NumTargets += len(input.Targets) + + if m.NumTargets > m.MaxTargets { + return nil, fmt.Errorf("TooManyTargets") + } + + return m.MockedFakeELBV2.RegisterTargets(ctx, input, optFns...) +} + +func (m *mockELBV2ClientForEnsureTargetGroupTargets) DeregisterTargets(ctx context.Context, input *elbv2.DeregisterTargetsInput, optFns ...func(*elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) { + m.NumTargets -= len(input.Targets) + + return m.MockedFakeELBV2.DeregisterTargets(ctx, input, optFns...) +} + +func TestCloud_ensureTargetGroupTargets(t *testing.T) { + testTargetGroupArn := "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/test-tg/1234567890123456" + + tests := []struct { + name string + maxTargets int + expectedTargets []*elbv2types.TargetDescription + actualTargets []*elbv2types.TargetDescription + expectedError string + description string + }{ + { + name: "target replacement at max target limit should not fail", + maxTargets: 4, + expectedTargets: []*elbv2types.TargetDescription{ + { + Id: aws.String("i-abcdefg1"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg2"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg3"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg4"), + Port: aws.Int32(8080), + }, + }, + actualTargets: []*elbv2types.TargetDescription{ + { + Id: aws.String("i-replacement"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg2"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg3"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg4"), + Port: aws.Int32(8080), + }, + }, + description: "Function should succeed when replacing an instance for a LB at max capacity", + }, + { + name: "exceeding max target limit should fail", + maxTargets: 4, + expectedTargets: []*elbv2types.TargetDescription{ + { + Id: aws.String("i-abcdefg1"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg2"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg3"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg4"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg5"), + Port: aws.Int32(8080), + }, + }, + actualTargets: []*elbv2types.TargetDescription{ + { + Id: aws.String("i-replacement"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg2"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg3"), + Port: aws.Int32(8080), + }, + { + Id: aws.String("i-abcdefg4"), + Port: aws.Int32(8080), + }, + }, + expectedError: "TooManyTargets", + description: "Function should fail when adding an instance to a LB at max capacity", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &mockELBV2ClientForEnsureTargetGroupTargets{ + MockedFakeELBV2: &MockedFakeELBV2{ + LoadBalancers: []*elbv2types.LoadBalancer{}, + TargetGroups: []*elbv2types.TargetGroup{}, + Listeners: []*elbv2types.Listener{}, + LoadBalancerAttributes: make(map[string]map[string]string), + Tags: make(map[string][]elbv2types.Tag), + RegisteredInstances: make(map[string][]string), + }, + MaxTargets: tt.maxTargets, + NumTargets: len(tt.actualTargets), + } + c := &Cloud{ + elbv2: mockClient, + } + + err := c.ensureTargetGroupTargets(context.TODO(), testTargetGroupArn, tt.expectedTargets, tt.actualTargets) + + if len(tt.expectedError) > 0 { + assert.Error(t, err, "Expected error for test case: %s", tt.description) + assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text for test case: %s", tt.description) + } else { + assert.NoError(t, err, "Expected no error for test case: %s", tt.description) + } + }) + } +}