diff --git a/pkg/providers/v1/describe_instance_batch.go b/pkg/providers/v1/describe_instance_batch.go index 3001827e0d..9059c0fcc6 100644 --- a/pkg/providers/v1/describe_instance_batch.go +++ b/pkg/providers/v1/describe_instance_batch.go @@ -57,6 +57,9 @@ func (b *describeInstanceBatcher) DescribeInstances(ctx context.Context, input * return nil, fmt.Errorf("expected to receive a single instance only, found %d", len(input.InstanceIds)) } result := b.batcher.Add(ctx, input) + if result.Output == nil { + return nil, result.Err + } return []*ec2types.Instance{result.Output}, result.Err } @@ -74,7 +77,8 @@ func describeInstanceHasher(ctx context.Context, input *ec2.DescribeInstancesInp func execDescribeInstanceBatch(ec2api iface.EC2) batcher.BatchExecutor[ec2.DescribeInstancesInput, ec2types.Instance] { return func(ctx context.Context, inputs []*ec2.DescribeInstancesInput) []batcher.Result[ec2types.Instance] { results := make([]batcher.Result[ec2types.Instance], len(inputs)) - firstInput := inputs[0] + + firstInput := *inputs[0] // aggregate instanceIDs into 1 input for _, input := range inputs[1:] { firstInput.InstanceIds = append(firstInput.InstanceIds, input.InstanceIds...) @@ -92,7 +96,7 @@ func execDescribeInstanceBatch(ec2api iface.EC2) batcher.BatchExecutor[ec2.Descr go func(input *ec2.DescribeInstancesInput) { defer wg.Done() out, err := ec2api.DescribeInstances(ctx, input) - if err != nil { + if err != nil || len(out) == 0 { results[idx] = batcher.Result[ec2types.Instance]{Output: nil, Err: err} return } @@ -104,7 +108,11 @@ func execDescribeInstanceBatch(ec2api iface.EC2) batcher.BatchExecutor[ec2.Descr instanceIDToOutputMap := map[string]ec2types.Instance{} lo.ForEach(output, func(o ec2types.Instance, _ int) { instanceIDToOutputMap[lo.FromPtr(o.InstanceId)] = o }) for idx, input := range inputs { - o := instanceIDToOutputMap[input.InstanceIds[0]] + o, ok := instanceIDToOutputMap[input.InstanceIds[0]] + if !ok { + results[idx] = batcher.Result[ec2types.Instance]{Output: nil} + continue + } results[idx] = batcher.Result[ec2types.Instance]{Output: &o} } } diff --git a/pkg/providers/v1/instances_v2_test.go b/pkg/providers/v1/instances_v2_test.go index 6a98a433aa..eaf500b4da 100644 --- a/pkg/providers/v1/instances_v2_test.go +++ b/pkg/providers/v1/instances_v2_test.go @@ -370,6 +370,142 @@ func TestDescribeInstanceBatching(t *testing.T) { mockedEC2API.AssertNumberOfCalls(t, "DescribeInstances", 1) } +// TestDescribeInstanceBatchingWithInstanceDoesntExist will test where one of the instance doesn't exist +func TestDescribeInstanceBatchingWithInstanceDoesntExist(t *testing.T) { + mockedEC2API := newMockedEC2API() + batcher := newdescribeInstanceBatcher(context.Background(), &awsSdkEC2{ec2: mockedEC2API}) + + mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{ + Reservations: []ec2types.Reservation{ + { + Instances: []ec2types.Instance{ + { + InstanceId: aws.String("Test-1"), + }, + { + InstanceId: aws.String("Test-2"), + }, + { + InstanceId: aws.String("Test-3"), + }, + }, + }, + }, + }, nil) + + type result struct { + input string + output []*ec2types.Instance + err error + isInstanceNotFound bool + } + + // Add extra space to channel so that we can ensure there were only 3 responses + resCh := make(chan result, 5) + helper := func(wg *sync.WaitGroup, input string, isInstanceNotFound bool) { + defer wg.Done() + res, err := batcher.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{InstanceIds: []string{input}}) + resCh <- result{input: input, output: res, err: err, isInstanceNotFound: isInstanceNotFound} + } + + wg := sync.WaitGroup{} + wg.Add(3) + go helper(&wg, "Test-1", false) + go helper(&wg, "Test-2", false) + go helper(&wg, "Test-4", true) + wg.Wait() + close(resCh) + + assert.Len(t, resCh, 3) + for res := range resCh { + assert.NoError(t, res.err) + if !res.isInstanceNotFound { + assert.Len(t, res.output, 1) + assert.Equal(t, res.input, *res.output[0].InstanceId) + } else { + assert.Len(t, res.output, 0) + } + } + + mockedEC2API.AssertNumberOfCalls(t, "DescribeInstances", 1) +} + +// TestDescribeInstanceBatchingWithBatchedRequestFail will test where batched request fails but individual request succeeds +func TestDescribeInstanceBatchingWithBatchedRequestFail(t *testing.T) { + mockedEC2API := newMockedEC2API() + batcher := newdescribeInstanceBatcher(context.Background(), &awsSdkEC2{ec2: mockedEC2API}) + var nilResp *ec2.DescribeInstancesOutput + mockedEC2API.On("DescribeInstances", &ec2.DescribeInstancesInput{ + InstanceIds: []string{"Test-1", "Test-2"}, + }).Return(nilResp, errors.New("instances not found")) + mockedEC2API.On("DescribeInstances", &ec2.DescribeInstancesInput{ + InstanceIds: []string{"Test-2", "Test-1"}, + }).Return(nilResp, errors.New("instances not found")) + mockedEC2API.On("DescribeInstances", + &ec2.DescribeInstancesInput{ + InstanceIds: []string{"Test-1"}, + }, + ).Return( + &ec2.DescribeInstancesOutput{ + Reservations: []ec2types.Reservation{ + { + Instances: []ec2types.Instance{ + { + InstanceId: aws.String("Test-1"), + }, + }, + }, + }, + }, + nil, + ) + mockedEC2API.On("DescribeInstances", + &ec2.DescribeInstancesInput{ + InstanceIds: []string{"Test-2"}, + }, + ).Return( + &ec2.DescribeInstancesOutput{ + Reservations: []ec2types.Reservation{}, + }, + nil, + ) + + type result struct { + input string + output []*ec2types.Instance + err error + isInstanceNotFound bool + } + + // Add extra space to channel so that we can ensure there were only 3 responses + resCh := make(chan result, 5) + helper := func(wg *sync.WaitGroup, input string, isInstanceNotFound bool) { + defer wg.Done() + res, err := batcher.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{InstanceIds: []string{input}}) + resCh <- result{input: input, output: res, err: err, isInstanceNotFound: isInstanceNotFound} + } + + wg := sync.WaitGroup{} + wg.Add(2) + go helper(&wg, "Test-1", false) + go helper(&wg, "Test-2", true) + wg.Wait() + close(resCh) + + assert.Len(t, resCh, 2) + for res := range resCh { + assert.NoError(t, res.err) + if !res.isInstanceNotFound { + assert.Len(t, res.output, 1) + assert.Equal(t, res.input, *res.output[0].InstanceId) + } else { + assert.Len(t, res.output, 0) + } + } + + mockedEC2API.AssertNumberOfCalls(t, "DescribeInstances", 3) +} + func getCloudWithMockedDescribeInstances(instanceExists bool, instanceState ec2types.InstanceStateName, instanceID string) *Cloud { mockedEC2API := newMockedEC2API() c := &Cloud{ec2: &awsSdkEC2{ec2: mockedEC2API}, describeInstanceBatcher: newdescribeInstanceBatcher(context.Background(), &awsSdkEC2{ec2: mockedEC2API})}