Skip to content

Commit a626bc0

Browse files
committed
remove nil ptr dereferences
1 parent e955e76 commit a626bc0

5 files changed

Lines changed: 65 additions & 13 deletions

File tree

pkg/providers/v1/aws.go

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,7 @@ func (c *Cloud) GetZoneByProviderID(ctx context.Context, providerID string) (clo
10041004
}
10051005

10061006
instance, err := c.getInstanceByID(ctx, string(instanceID))
1007-
if err != nil {
1007+
if err != nil || instance == nil {
10081008
return cloudprovider.Zone{}, err
10091009
}
10101010
return c.getInstanceZone(instance), nil
@@ -1078,7 +1078,7 @@ func (c *Cloud) buildSelfAWSInstance(ctx context.Context) (*awsInstance, error)
10781078
defer instanceIDMetadata.Content.Close()
10791079

10801080
instance, err := c.getInstanceByID(ctx, string(instanceIDBytes))
1081-
if err != nil {
1081+
if err != nil || instance == nil {
10821082
return nil, fmt.Errorf("error finding instance %s: %q", string(instanceIDBytes), err)
10831083
}
10841084
return newAWSInstance(c.ec2, instance), nil
@@ -3146,7 +3146,7 @@ func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, serv
31463146
// Returns the instance with the specified ID
31473147
func (c *Cloud) getInstanceByID(ctx context.Context, instanceID string) (*ec2types.Instance, error) {
31483148
instances, err := c.getInstancesByIDs(ctx, []string{instanceID})
3149-
if err != nil {
3149+
if err != nil || instances == nil {
31503150
return nil, err
31513151
}
31523152

@@ -3347,13 +3347,6 @@ func (c *Cloud) getFullInstance(ctx context.Context, nodeName types.NodeName) (*
33473347
return awsInstance, instance, err
33483348
}
33493349

3350-
// extract private ip address from node name
3351-
func nodeNameToIPAddress(nodeName string) string {
3352-
nodeName = strings.TrimPrefix(nodeName, privateDNSNamePrefix)
3353-
nodeName = strings.Split(nodeName, ".")[0]
3354-
return strings.ReplaceAll(nodeName, "-", ".")
3355-
}
3356-
33573350
func (c *Cloud) nodeNameToInstanceID(nodeName types.NodeName) (InstanceID, error) {
33583351
if strings.HasPrefix(string(nodeName), rbnNamePrefix) {
33593352
// depending on if you use a RHEL (e.g. AL2) or Debian (e.g. standard Ubuntu) based distribution, the

pkg/providers/v1/aws_fakes.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"github.com/aws/aws-sdk-go-v2/service/ec2"
3333
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
3434
elb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing"
35+
elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types"
3536
elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
3637
"github.com/aws/aws-sdk-go-v2/service/kms"
3738
"k8s.io/klog/v2"
@@ -622,7 +623,13 @@ func (e *FakeELB) SetLoadBalancerPoliciesOfListener(ctx context.Context, input *
622623
// DescribeLoadBalancerPolicies is not implemented but is required for
623624
// interface conformance
624625
func (e *FakeELB) DescribeLoadBalancerPolicies(ctx context.Context, input *elb.DescribeLoadBalancerPoliciesInput, opts ...func(*elb.Options)) (*elb.DescribeLoadBalancerPoliciesOutput, error) {
625-
panic("Not implemented")
626+
if aws.ToString(input.LoadBalancerName) == "" {
627+
return nil, &elbtypes.LoadBalancerAttributeNotFoundException{}
628+
}
629+
if len(input.PolicyNames) == 0 || input.PolicyNames[0] == "k8s-SSLNegotiationPolicy-" {
630+
return nil, &elbtypes.PolicyNotFoundException{}
631+
}
632+
return &elb.DescribeLoadBalancerPoliciesOutput{}, nil
626633
}
627634

628635
// DescribeLoadBalancerAttributes is not implemented but is required for

pkg/providers/v1/aws_loadbalancer.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(ctx context.Context, lbName s
816816
{
817817
for sgID := range desiredSGIDs.Difference(sets.StringKeySet(clusterSGs)) {
818818
sg, err := c.findSecurityGroup(ctx, sgID)
819-
if err != nil {
819+
if err != nil || sg == nil {
820820
return fmt.Errorf("error finding instance group: %q", err)
821821
}
822822
clusterSGs[sgID] = sg
@@ -1512,6 +1512,7 @@ func (c *Cloud) ensureSSLNegotiationPolicy(ctx context.Context, loadBalancer *el
15121512
if !errors.As(err, &notFoundErr) {
15131513
return fmt.Errorf("error describing security policies on load balancer: %q", err)
15141514
}
1515+
return nil
15151516
}
15161517

15171518
if len(result.PolicyDescriptions) > 0 {

pkg/providers/v1/aws_loadbalancer_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,3 +1073,50 @@ func TestCloud_computeTargetGroupExpectedTargets(t *testing.T) {
10731073
})
10741074
}
10751075
}
1076+
1077+
// Make sure that errors returned by DescribeLoadBalancerPolicies are
1078+
// handled gracefully, and don't progress further into the function
1079+
func TestEnsureSSLNegotiationPolicyErrorHandling(t *testing.T) {
1080+
awsServices := NewFakeAWSServices(TestClusterID)
1081+
c, err := newAWSCloud(config.CloudConfig{}, awsServices)
1082+
if err != nil {
1083+
t.Errorf("Error building aws cloud: %v", err)
1084+
return
1085+
}
1086+
1087+
tests := []struct {
1088+
name string
1089+
loadBalancer *elbtypes.LoadBalancerDescription
1090+
policyName string
1091+
expectError bool
1092+
}{
1093+
{
1094+
name: "Expect LoadBalancerAttributeNotFoundException, error",
1095+
loadBalancer: &elbtypes.LoadBalancerDescription{
1096+
LoadBalancerName: aws.String(""),
1097+
},
1098+
policyName: "",
1099+
expectError: true,
1100+
},
1101+
{
1102+
name: "Expect PolicyNotFoundException, nil error",
1103+
loadBalancer: &elbtypes.LoadBalancerDescription{
1104+
LoadBalancerName: aws.String("test-lb"),
1105+
},
1106+
policyName: "",
1107+
expectError: false,
1108+
},
1109+
}
1110+
1111+
for _, test := range tests {
1112+
t.Run(test.name, func(t *testing.T) {
1113+
err := c.ensureSSLNegotiationPolicy(context.TODO(), test.loadBalancer, test.policyName)
1114+
if test.expectError && err == nil {
1115+
t.Errorf("Expected error but got none")
1116+
}
1117+
if !test.expectError && err != nil {
1118+
t.Errorf("Expected no error but got: %v", err)
1119+
}
1120+
})
1121+
}
1122+
}

pkg/providers/v1/instances_v2.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ package aws
2222

2323
import (
2424
"context"
25+
"errors"
2526
"fmt"
2627
"k8s.io/cloud-provider-aws/pkg/providers/v1/variant"
2728
"strconv"
@@ -33,6 +34,9 @@ import (
3334
)
3435

3536
func (c *Cloud) getProviderID(ctx context.Context, node *v1.Node) (string, error) {
37+
if node == nil {
38+
return "", errors.New("error getting provider id, node is nil")
39+
}
3640
if node.Spec.ProviderID != "" {
3741
return node.Spec.ProviderID, nil
3842
}
@@ -147,7 +151,7 @@ func (c *Cloud) InstanceMetadata(ctx context.Context, node *v1.Node) (*cloudprov
147151
}
148152
} else {
149153
instance, err := c.getInstanceByID(ctx, string(instanceID))
150-
if err != nil {
154+
if err != nil || instance == nil {
151155
return nil, fmt.Errorf("failed to get instance by ID %s: %w", instanceID, err)
152156
}
153157

0 commit comments

Comments
 (0)