Skip to content

Commit 1137413

Browse files
authored
Merge pull request #1155 from fletcherw/fletcherw/session_token
Support ServiceAccountToken in ecr-credential-provider
2 parents b47d2cf + cecce90 commit 1137413

5 files changed

Lines changed: 292 additions & 27 deletions

File tree

cmd/ecr-credential-provider/main.go

Lines changed: 98 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/aws/aws-sdk-go-v2/config"
3232
"github.com/aws/aws-sdk-go-v2/service/ecr"
3333
"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
34+
"github.com/aws/aws-sdk-go-v2/service/sts"
3435
"github.com/spf13/cobra"
3536

3637
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -54,9 +55,15 @@ type ECRPublic interface {
5455
GetAuthorizationToken(ctx context.Context, params *ecrpublic.GetAuthorizationTokenInput, optFns ...func(*ecrpublic.Options)) (*ecrpublic.GetAuthorizationTokenOutput, error)
5556
}
5657

58+
// STS abstracts the calls we make to aws-sdk for testing purposes
59+
type STS interface {
60+
AssumeRoleWithWebIdentity(context.Context, *sts.AssumeRoleWithWebIdentityInput, ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error)
61+
}
62+
5763
type ecrPlugin struct {
5864
ecr ECR
5965
ecrPublic ECRPublic
66+
sts STS
6067
}
6168

6269
func defaultECRProvider(ctx context.Context, region string) (ECR, error) {
@@ -91,12 +98,30 @@ func publicECRProvider(ctx context.Context) (ECRPublic, error) {
9198
return ecrpublic.NewFromConfig(cfg), nil
9299
}
93100

101+
func stsProvider(ctx context.Context, region string) (STS, error) {
102+
var cfg aws.Config
103+
var err error
104+
if region != "" {
105+
cfg, err = config.LoadDefaultConfig(ctx,
106+
config.WithRegion(region),
107+
)
108+
} else {
109+
klog.Warningf("No region found in the image reference, the default region will be used. Please refer to AWS SDK documentation for configuration purpose.")
110+
cfg, err = config.LoadDefaultConfig(ctx)
111+
}
112+
113+
if err != nil {
114+
return nil, err
115+
}
116+
return sts.NewFromConfig(cfg), nil
117+
}
118+
94119
type credsData struct {
95120
authToken *string
96121
expiresAt *time.Time
97122
}
98123

99-
func (e *ecrPlugin) getPublicCredsData(ctx context.Context) (*credsData, error) {
124+
func (e *ecrPlugin) getPublicCredsData(ctx context.Context, optFns ...func(*ecrpublic.Options)) (*credsData, error) {
100125
klog.Infof("Getting creds for public registry")
101126
var err error
102127

@@ -107,7 +132,7 @@ func (e *ecrPlugin) getPublicCredsData(ctx context.Context) (*credsData, error)
107132
return nil, err
108133
}
109134

110-
output, err := e.ecrPublic.GetAuthorizationToken(ctx, &ecrpublic.GetAuthorizationTokenInput{})
135+
output, err := e.ecrPublic.GetAuthorizationToken(ctx, &ecrpublic.GetAuthorizationTokenInput{}, optFns...)
111136
if err != nil {
112137
return nil, err
113138
}
@@ -126,7 +151,7 @@ func (e *ecrPlugin) getPublicCredsData(ctx context.Context) (*credsData, error)
126151
}, nil
127152
}
128153

129-
func (e *ecrPlugin) getPrivateCredsData(ctx context.Context, imageHost string, image string) (*credsData, error) {
154+
func (e *ecrPlugin) getPrivateCredsData(ctx context.Context, imageHost string, image string, optFns ...func(*ecr.Options)) (*credsData, error) {
130155
klog.Infof("Getting creds for private image %s", image)
131156
var err error
132157

@@ -137,7 +162,8 @@ func (e *ecrPlugin) getPrivateCredsData(ctx context.Context, imageHost string, i
137162
return nil, err
138163
}
139164
}
140-
output, err := e.ecr.GetAuthorizationToken(ctx, &ecr.GetAuthorizationTokenInput{})
165+
166+
output, err := e.ecr.GetAuthorizationToken(ctx, &ecr.GetAuthorizationTokenInput{}, optFns...)
141167
if err != nil {
142168
return nil, err
143169
}
@@ -153,19 +179,83 @@ func (e *ecrPlugin) getPrivateCredsData(ctx context.Context, imageHost string, i
153179
}, nil
154180
}
155181

156-
func (e *ecrPlugin) GetCredentials(ctx context.Context, image string, args []string) (*v1.CredentialProviderResponse, error) {
182+
func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.CredentialProviderRequest, imageHost string) (aws.CredentialsProvider, error) {
183+
var err error
184+
185+
arn, ok := request.ServiceAccountAnnotations["eks.amazonaws.com/ecr-role-arn"]
186+
if !ok {
187+
arn = os.Getenv("AWS_ECR_ROLE_ARN")
188+
}
189+
if arn == "" {
190+
return nil, errors.New("no arn provided, cannot assume role using ServiceAccountToken")
191+
}
192+
193+
if e.sts == nil {
194+
region := ""
195+
if imageHost != ecrPublicHost {
196+
region = parseRegionFromECRPrivateHost(imageHost)
197+
}
198+
e.sts, err = stsProvider(ctx, region)
199+
}
200+
if err != nil {
201+
return nil, err
202+
}
203+
204+
return aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) {
205+
assumeOutput, err := e.sts.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{
206+
RoleArn: aws.String(arn),
207+
RoleSessionName: aws.String("ecr-credential-provider"),
208+
WebIdentityToken: aws.String(request.ServiceAccountToken),
209+
})
210+
if err != nil {
211+
return aws.Credentials{}, fmt.Errorf("failed to assume role: %w", err)
212+
}
213+
return aws.Credentials{
214+
AccessKeyID: *assumeOutput.Credentials.AccessKeyId,
215+
SecretAccessKey: *assumeOutput.Credentials.SecretAccessKey,
216+
SessionToken: *assumeOutput.Credentials.SessionToken,
217+
}, nil
218+
}),
219+
nil
220+
}
221+
222+
func (e *ecrPlugin) GetCredentials(ctx context.Context, request *v1.CredentialProviderRequest, args []string) (*v1.CredentialProviderResponse, error) {
157223
var creds *credsData
158224
var err error
159225

160-
imageHost, err := parseHostFromImageReference(image)
226+
if request.Image == "" {
227+
return nil, errors.New("image in plugin request was empty")
228+
}
229+
230+
imageHost, err := parseHostFromImageReference(request.Image)
161231
if err != nil {
162232
return nil, err
163233
}
164234

235+
var credentialsProvider aws.CredentialsProvider = nil
236+
if request.ServiceAccountToken != "" {
237+
credentialsProvider, err = e.buildCredentialsProvider(ctx, request, imageHost)
238+
if err != nil {
239+
return nil, err
240+
}
241+
}
242+
165243
if imageHost == ecrPublicHost {
166-
creds, err = e.getPublicCredsData(ctx)
244+
var optFns = []func(*ecrpublic.Options){}
245+
if credentialsProvider != nil {
246+
optFns = append(optFns, func(o *ecrpublic.Options) {
247+
o.Credentials = credentialsProvider
248+
})
249+
}
250+
creds, err = e.getPublicCredsData(ctx, optFns...)
167251
} else {
168-
creds, err = e.getPrivateCredsData(ctx, imageHost, image)
252+
var optFns = []func(*ecr.Options){}
253+
if credentialsProvider != nil {
254+
optFns = append(optFns, func(o *ecr.Options) {
255+
o.Credentials = credentialsProvider
256+
})
257+
}
258+
creds, err = e.getPrivateCredsData(ctx, imageHost, request.Image, optFns...)
169259
}
170260

171261
if err != nil {

cmd/ecr-credential-provider/main_test.go

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import (
2929
"github.com/aws/aws-sdk-go-v2/service/ecr/types"
3030
"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
3131
publictypes "github.com/aws/aws-sdk-go-v2/service/ecrpublic/types"
32+
"github.com/aws/aws-sdk-go-v2/service/sts"
33+
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
3234
"github.com/stretchr/testify/mock"
3335
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3436
v1 "k8s.io/kubelet/pkg/apis/credentialprovider/v1"
@@ -40,6 +42,15 @@ type MockedECR struct {
4042

4143
func (m *MockedECR) GetAuthorizationToken(ctx context.Context, params *ecr.GetAuthorizationTokenInput, optFns ...func(*ecr.Options)) (*ecr.GetAuthorizationTokenOutput, error) {
4244
args := m.Called(ctx, params)
45+
46+
opts := ecr.Options{}
47+
for _, fn := range optFns {
48+
fn(&opts)
49+
}
50+
if opts.Credentials != nil {
51+
opts.Credentials.Retrieve(ctx)
52+
}
53+
4354
if args.Get(1) != nil {
4455
return args.Get(0).(*ecr.GetAuthorizationTokenOutput), args.Get(1).(error)
4556
}
@@ -59,6 +70,18 @@ func (m *MockedECRPublic) GetAuthorizationToken(ctx context.Context, params *ecr
5970
return args.Get(0).(*ecrpublic.GetAuthorizationTokenOutput), nil
6071
}
6172

73+
type MockedSTS struct {
74+
mock.Mock
75+
}
76+
77+
func (m *MockedSTS) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
78+
args := m.Called(ctx, params)
79+
if args.Get(1) != nil {
80+
return args.Get(0).(*sts.AssumeRoleWithWebIdentityOutput), args.Get(1).(error)
81+
}
82+
return args.Get(0).(*sts.AssumeRoleWithWebIdentityOutput), nil
83+
}
84+
6285
func generatePrivateGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecr.GetAuthorizationTokenOutput {
6386
creds := []byte(fmt.Sprintf("%s:%s", user, password))
6487
data := types.AuthorizationData{
@@ -159,7 +182,7 @@ func Test_GetCredentials_Private(t *testing.T) {
159182
p := &ecrPlugin{ecr: &mockECR}
160183
mockECR.On("GetAuthorizationToken", mock.Anything, mock.Anything).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError)
161184

162-
creds, err := p.GetCredentials(context.TODO(), testcase.image, testcase.args)
185+
creds, err := p.GetCredentials(context.TODO(), &v1.CredentialProviderRequest{Image: testcase.image}, testcase.args)
163186

164187
if testcase.expectedError != nil && (testcase.expectedError.Error() != err.Error()) {
165188
t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error())
@@ -182,7 +205,101 @@ func Test_GetCredentials_Private(t *testing.T) {
182205
}
183206
}
184207

185-
func generatePublicGetAuthorizationTokenOutput(user string, password string, proxy string, expiration *time.Time) *ecrpublic.GetAuthorizationTokenOutput {
208+
func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) {
209+
testcases := []struct {
210+
name string
211+
request *v1.CredentialProviderRequest
212+
args []string
213+
expectedAssumeArn string
214+
getAuthorizationTokenOutput *ecr.GetAuthorizationTokenOutput
215+
getAuthorizationTokenError error
216+
assumeRoleWithWebIdentityOutput *sts.AssumeRoleWithWebIdentityOutput
217+
assumeRoleWithWebIdentityError error
218+
response *v1.CredentialProviderResponse
219+
expectedError error
220+
}{
221+
{
222+
name: "success",
223+
request: &v1.CredentialProviderRequest{Image: "123456789123.dkr.ecr.us-west-2.amazonaws.com", ServiceAccountToken: "DEADBEEF=", ServiceAccountAnnotations: map[string]string{"eks.amazonaws.com/ecr-role-arn": "arn:expected"}},
224+
expectedAssumeArn: "arn:expected",
225+
getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil),
226+
assumeRoleWithWebIdentityOutput: &sts.AssumeRoleWithWebIdentityOutput{
227+
Credentials: &ststypes.Credentials{
228+
AccessKeyId: aws.String("access-key-id"),
229+
SecretAccessKey: aws.String("secret-access-key"),
230+
SessionToken: aws.String("session-token"),
231+
},
232+
},
233+
response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"),
234+
},
235+
{
236+
name: "no arn provided",
237+
request: &v1.CredentialProviderRequest{Image: "123456789123.dkr.ecr.us-west-2.amazonaws.com", ServiceAccountToken: "DEADBEEF="},
238+
expectedAssumeArn: "arn:expected",
239+
getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil),
240+
assumeRoleWithWebIdentityOutput: &sts.AssumeRoleWithWebIdentityOutput{
241+
Credentials: &ststypes.Credentials{
242+
AccessKeyId: aws.String("access-key-id"),
243+
SecretAccessKey: aws.String("secret-access-key"),
244+
SessionToken: aws.String("session-token"),
245+
},
246+
},
247+
response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"),
248+
expectedError: errors.New("no arn provided, cannot assume role using ServiceAccountToken"),
249+
},
250+
{
251+
name: "assume error",
252+
request: &v1.CredentialProviderRequest{Image: "123456789123.dkr.ecr.us-west-2.amazonaws.com", ServiceAccountToken: "DEADBEEF=", ServiceAccountAnnotations: map[string]string{"eks.amazonaws.com/ecr-role-arn": "arn:expected"}},
253+
expectedAssumeArn: "arn:expected",
254+
getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil),
255+
assumeRoleWithWebIdentityError: errors.New("injected error"),
256+
response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"),
257+
expectedError: errors.New("injected error"),
258+
},
259+
}
260+
for _, testcase := range testcases {
261+
t.Run(testcase.name, func(t *testing.T) {
262+
mockECR := MockedECR{}
263+
mockSTS := MockedSTS{}
264+
p := &ecrPlugin{ecr: &mockECR, sts: &mockSTS}
265+
mockECR.On("GetAuthorizationToken", mock.Anything, mock.Anything).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError)
266+
267+
expectedInput := sts.AssumeRoleWithWebIdentityInput{
268+
RoleArn: aws.String(testcase.expectedAssumeArn),
269+
RoleSessionName: aws.String("ecr-credential-provider"),
270+
WebIdentityToken: aws.String(testcase.request.ServiceAccountToken),
271+
}
272+
mockSTS.On("AssumeRoleWithWebIdentity", mock.Anything, &expectedInput).Return(testcase.assumeRoleWithWebIdentityOutput, testcase.assumeRoleWithWebIdentityError)
273+
creds, err := p.GetCredentials(context.TODO(), testcase.request, testcase.args)
274+
if err != nil {
275+
if testcase.expectedError == nil {
276+
t.Fatalf("got unexpected error %s", err.Error())
277+
278+
}
279+
280+
if testcase.expectedError.Error() != err.Error() {
281+
t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error())
282+
}
283+
}
284+
285+
if testcase.expectedError == nil {
286+
if creds.CacheKeyType != testcase.response.CacheKeyType {
287+
t.Fatalf("Unexpected CacheKeyType. Expected: %s, got: %s", testcase.response.CacheKeyType, creds.CacheKeyType)
288+
}
289+
290+
if creds.Auth[testcase.request.Image] != testcase.response.Auth[testcase.request.Image] {
291+
t.Fatalf("Unexpected Auth. Expected: %s, got: %s", testcase.response.Auth[testcase.request.Image], creds.Auth[testcase.request.Image])
292+
}
293+
294+
if creds.CacheDuration.Duration != testcase.response.CacheDuration.Duration {
295+
t.Fatalf("Unexpected CacheDuration. Expected: %s, got: %s", testcase.response.CacheDuration.Duration, creds.CacheDuration.Duration)
296+
}
297+
}
298+
})
299+
}
300+
}
301+
302+
func generatePublicGetAuthorizationTokenOutput(user string, password string, expiration *time.Time) *ecrpublic.GetAuthorizationTokenOutput {
186303
creds := []byte(fmt.Sprintf("%s:%s", user, password))
187304
data := &publictypes.AuthorizationData{
188305
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString(creds)),
@@ -207,9 +324,16 @@ func Test_GetCredentials_Public(t *testing.T) {
207324
{
208325
name: "success",
209326
image: "public.ecr.aws",
210-
getAuthorizationTokenOutput: generatePublicGetAuthorizationTokenOutput("user", "pass", "", nil),
327+
getAuthorizationTokenOutput: generatePublicGetAuthorizationTokenOutput("user", "pass", nil),
211328
response: generateResponse("public.ecr.aws", "user", "pass"),
212329
},
330+
{
331+
name: "empty image",
332+
image: "",
333+
getAuthorizationTokenOutput: &ecrpublic.GetAuthorizationTokenOutput{},
334+
getAuthorizationTokenError: nil,
335+
expectedError: errors.New("image in plugin request was empty"),
336+
},
213337
{
214338
name: "empty authorization data",
215339
image: "public.ecr.aws",
@@ -257,7 +381,7 @@ func Test_GetCredentials_Public(t *testing.T) {
257381
p := &ecrPlugin{ecrPublic: &mockECRPublic}
258382
mockECRPublic.On("GetAuthorizationToken", mock.Anything, mock.Anything).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError)
259383

260-
creds, err := p.GetCredentials(context.TODO(), testcase.image, testcase.args)
384+
creds, err := p.GetCredentials(context.TODO(), &v1.CredentialProviderRequest{Image: testcase.image}, testcase.args)
261385

262386
if testcase.expectedError != nil && (testcase.expectedError.Error() != err.Error()) {
263387
t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error())

cmd/ecr-credential-provider/plugin.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828
"k8s.io/apimachinery/pkg/runtime/serializer"
2929
"k8s.io/apimachinery/pkg/runtime/serializer/json"
3030
"k8s.io/kubelet/pkg/apis/credentialprovider/install"
31-
"k8s.io/kubelet/pkg/apis/credentialprovider/v1"
31+
v1 "k8s.io/kubelet/pkg/apis/credentialprovider/v1"
3232
)
3333

3434
var (
@@ -43,7 +43,7 @@ func init() {
4343
// CredentialProvider is an interface implemented by the kubelet credential provider plugin to fetch
4444
// the username/password based on the provided image name.
4545
type CredentialProvider interface {
46-
GetCredentials(ctx context.Context, image string, args []string) (response *v1.CredentialProviderResponse, err error)
46+
GetCredentials(ctx context.Context, request *v1.CredentialProviderRequest, args []string) (response *v1.CredentialProviderResponse, err error)
4747
}
4848

4949
// ExecPlugin implements the exec-based plugin for fetching credentials that is invoked by the kubelet.
@@ -85,11 +85,7 @@ func (e *ExecPlugin) runPlugin(ctx context.Context, r io.Reader, w io.Writer, ar
8585
return err
8686
}
8787

88-
if request.Image == "" {
89-
return errors.New("image in plugin request was empty")
90-
}
91-
92-
response, err := e.plugin.GetCredentials(ctx, request.Image, args)
88+
response, err := e.plugin.GetCredentials(ctx, request, args)
9389
if err != nil {
9490
return err
9591
}

0 commit comments

Comments
 (0)