Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions pkg/providers/v1/describe_instance_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ func (b *describeInstanceBatcher) DescribeInstances(ctx context.Context, input *
return nil, fmt.Errorf("expected to receive a single instance only, found %d", len(input.InstanceIds))
}
result := b.batcher.Add(ctx, input)
if result.Output == nil {
return nil, result.Err
}
return []*ec2types.Instance{result.Output}, result.Err
}

Expand All @@ -74,7 +77,8 @@ func describeInstanceHasher(ctx context.Context, input *ec2.DescribeInstancesInp
func execDescribeInstanceBatch(ec2api iface.EC2) batcher.BatchExecutor[ec2.DescribeInstancesInput, ec2types.Instance] {
return func(ctx context.Context, inputs []*ec2.DescribeInstancesInput) []batcher.Result[ec2types.Instance] {
results := make([]batcher.Result[ec2types.Instance], len(inputs))
firstInput := inputs[0]

firstInput := *inputs[0]
// aggregate instanceIDs into 1 input
for _, input := range inputs[1:] {
firstInput.InstanceIds = append(firstInput.InstanceIds, input.InstanceIds...)
Expand All @@ -92,7 +96,7 @@ func execDescribeInstanceBatch(ec2api iface.EC2) batcher.BatchExecutor[ec2.Descr
go func(input *ec2.DescribeInstancesInput) {
defer wg.Done()
out, err := ec2api.DescribeInstances(ctx, input)
if err != nil {
if err != nil || len(out) == 0 {
results[idx] = batcher.Result[ec2types.Instance]{Output: nil, Err: err}
return
}
Expand All @@ -104,7 +108,11 @@ func execDescribeInstanceBatch(ec2api iface.EC2) batcher.BatchExecutor[ec2.Descr
instanceIDToOutputMap := map[string]ec2types.Instance{}
lo.ForEach(output, func(o ec2types.Instance, _ int) { instanceIDToOutputMap[lo.FromPtr(o.InstanceId)] = o })
for idx, input := range inputs {
o := instanceIDToOutputMap[input.InstanceIds[0]]
o, ok := instanceIDToOutputMap[input.InstanceIds[0]]
if !ok {
results[idx] = batcher.Result[ec2types.Instance]{Output: nil}
continue
}
results[idx] = batcher.Result[ec2types.Instance]{Output: &o}
}
}
Expand Down
136 changes: 136 additions & 0 deletions pkg/providers/v1/instances_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,142 @@ func TestDescribeInstanceBatching(t *testing.T) {
mockedEC2API.AssertNumberOfCalls(t, "DescribeInstances", 1)
}

// TestDescribeInstanceBatchingWithInstanceDoesntExist will test where one of the instance doesn't exist
func TestDescribeInstanceBatchingWithInstanceDoesntExist(t *testing.T) {
mockedEC2API := newMockedEC2API()
batcher := newdescribeInstanceBatcher(context.Background(), &awsSdkEC2{ec2: mockedEC2API})

mockedEC2API.On("DescribeInstances", mock.Anything).Return(&ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{
{
InstanceId: aws.String("Test-1"),
},
{
InstanceId: aws.String("Test-2"),
},
{
InstanceId: aws.String("Test-3"),
},
},
},
},
}, nil)

type result struct {
input string
output []*ec2types.Instance
err error
isInstanceNotFound bool
}

// Add extra space to channel so that we can ensure there were only 3 responses
resCh := make(chan result, 5)
helper := func(wg *sync.WaitGroup, input string, isInstanceNotFound bool) {
defer wg.Done()
res, err := batcher.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{InstanceIds: []string{input}})
resCh <- result{input: input, output: res, err: err, isInstanceNotFound: isInstanceNotFound}
}

wg := sync.WaitGroup{}
wg.Add(3)
go helper(&wg, "Test-1", false)
go helper(&wg, "Test-2", false)
go helper(&wg, "Test-4", true)
wg.Wait()
close(resCh)

assert.Len(t, resCh, 3)
for res := range resCh {
assert.NoError(t, res.err)
if !res.isInstanceNotFound {
assert.Len(t, res.output, 1)
assert.Equal(t, res.input, *res.output[0].InstanceId)
} else {
assert.Len(t, res.output, 0)
}
}

mockedEC2API.AssertNumberOfCalls(t, "DescribeInstances", 1)
}

// TestDescribeInstanceBatchingWithBatchedRequestFail will test where batched request fails but individual request succeeds
func TestDescribeInstanceBatchingWithBatchedRequestFail(t *testing.T) {
mockedEC2API := newMockedEC2API()
batcher := newdescribeInstanceBatcher(context.Background(), &awsSdkEC2{ec2: mockedEC2API})
var nilResp *ec2.DescribeInstancesOutput
mockedEC2API.On("DescribeInstances", &ec2.DescribeInstancesInput{
InstanceIds: []string{"Test-1", "Test-2"},
}).Return(nilResp, errors.New("instances not found"))
mockedEC2API.On("DescribeInstances", &ec2.DescribeInstancesInput{
InstanceIds: []string{"Test-2", "Test-1"},
}).Return(nilResp, errors.New("instances not found"))
mockedEC2API.On("DescribeInstances",
&ec2.DescribeInstancesInput{
InstanceIds: []string{"Test-1"},
},
).Return(
&ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{
{
InstanceId: aws.String("Test-1"),
},
},
},
},
},
nil,
)
mockedEC2API.On("DescribeInstances",
&ec2.DescribeInstancesInput{
InstanceIds: []string{"Test-2"},
},
).Return(
&ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{},
},
nil,
)

type result struct {
input string
output []*ec2types.Instance
err error
isInstanceNotFound bool
}

// Add extra space to channel so that we can ensure there were only 3 responses
resCh := make(chan result, 5)
helper := func(wg *sync.WaitGroup, input string, isInstanceNotFound bool) {
defer wg.Done()
res, err := batcher.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{InstanceIds: []string{input}})
resCh <- result{input: input, output: res, err: err, isInstanceNotFound: isInstanceNotFound}
}

wg := sync.WaitGroup{}
wg.Add(2)
go helper(&wg, "Test-1", false)
go helper(&wg, "Test-2", true)
wg.Wait()
close(resCh)

assert.Len(t, resCh, 2)
for res := range resCh {
assert.NoError(t, res.err)
if !res.isInstanceNotFound {
assert.Len(t, res.output, 1)
assert.Equal(t, res.input, *res.output[0].InstanceId)
} else {
assert.Len(t, res.output, 0)
}
}

mockedEC2API.AssertNumberOfCalls(t, "DescribeInstances", 3)
}

func getCloudWithMockedDescribeInstances(instanceExists bool, instanceState ec2types.InstanceStateName, instanceID string) *Cloud {
mockedEC2API := newMockedEC2API()
c := &Cloud{ec2: &awsSdkEC2{ec2: mockedEC2API}, describeInstanceBatcher: newdescribeInstanceBatcher(context.Background(), &awsSdkEC2{ec2: mockedEC2API})}
Expand Down
Loading