@@ -29,6 +29,8 @@ import (
2929 "github.com/aws/aws-sdk-go-v2/service/ecr/types"
3030 "github.com/aws/aws-sdk-go-v2/service/ecrpublic"
3131 publictypes "github.com/aws/aws-sdk-go-v2/service/ecrpublic/types"
32+ "github.com/aws/aws-sdk-go-v2/service/sts"
33+ ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
3234 "github.com/stretchr/testify/mock"
3335 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3436 v1 "k8s.io/kubelet/pkg/apis/credentialprovider/v1"
@@ -40,6 +42,15 @@ type MockedECR struct {
4042
4143func (m * MockedECR ) GetAuthorizationToken (ctx context.Context , params * ecr.GetAuthorizationTokenInput , optFns ... func (* ecr.Options )) (* ecr.GetAuthorizationTokenOutput , error ) {
4244 args := m .Called (ctx , params )
45+
46+ opts := ecr.Options {}
47+ for _ , fn := range optFns {
48+ fn (& opts )
49+ }
50+ if opts .Credentials != nil {
51+ opts .Credentials .Retrieve (ctx )
52+ }
53+
4354 if args .Get (1 ) != nil {
4455 return args .Get (0 ).(* ecr.GetAuthorizationTokenOutput ), args .Get (1 ).(error )
4556 }
@@ -59,6 +70,18 @@ func (m *MockedECRPublic) GetAuthorizationToken(ctx context.Context, params *ecr
5970 return args .Get (0 ).(* ecrpublic.GetAuthorizationTokenOutput ), nil
6071}
6172
73+ type MockedSTS struct {
74+ mock.Mock
75+ }
76+
77+ func (m * MockedSTS ) AssumeRoleWithWebIdentity (ctx context.Context , params * sts.AssumeRoleWithWebIdentityInput , optFns ... func (* sts.Options )) (* sts.AssumeRoleWithWebIdentityOutput , error ) {
78+ args := m .Called (ctx , params )
79+ if args .Get (1 ) != nil {
80+ return args .Get (0 ).(* sts.AssumeRoleWithWebIdentityOutput ), args .Get (1 ).(error )
81+ }
82+ return args .Get (0 ).(* sts.AssumeRoleWithWebIdentityOutput ), nil
83+ }
84+
6285func generatePrivateGetAuthorizationTokenOutput (user string , password string , proxy string , expiration * time.Time ) * ecr.GetAuthorizationTokenOutput {
6386 creds := []byte (fmt .Sprintf ("%s:%s" , user , password ))
6487 data := types.AuthorizationData {
@@ -159,7 +182,7 @@ func Test_GetCredentials_Private(t *testing.T) {
159182 p := & ecrPlugin {ecr : & mockECR }
160183 mockECR .On ("GetAuthorizationToken" , mock .Anything , mock .Anything ).Return (testcase .getAuthorizationTokenOutput , testcase .getAuthorizationTokenError )
161184
162- creds , err := p .GetCredentials (context .TODO (), testcase .image , testcase .args )
185+ creds , err := p .GetCredentials (context .TODO (), & v1. CredentialProviderRequest { Image : testcase .image } , testcase .args )
163186
164187 if testcase .expectedError != nil && (testcase .expectedError .Error () != err .Error ()) {
165188 t .Fatalf ("expected %s, got %s" , testcase .expectedError .Error (), err .Error ())
@@ -182,7 +205,101 @@ func Test_GetCredentials_Private(t *testing.T) {
182205 }
183206}
184207
185- func generatePublicGetAuthorizationTokenOutput (user string , password string , proxy string , expiration * time.Time ) * ecrpublic.GetAuthorizationTokenOutput {
208+ func Test_GetCredentials_PrivateForServiceAccount (t * testing.T ) {
209+ testcases := []struct {
210+ name string
211+ request * v1.CredentialProviderRequest
212+ args []string
213+ expectedAssumeArn string
214+ getAuthorizationTokenOutput * ecr.GetAuthorizationTokenOutput
215+ getAuthorizationTokenError error
216+ assumeRoleWithWebIdentityOutput * sts.AssumeRoleWithWebIdentityOutput
217+ assumeRoleWithWebIdentityError error
218+ response * v1.CredentialProviderResponse
219+ expectedError error
220+ }{
221+ {
222+ name : "success" ,
223+ request : & v1.CredentialProviderRequest {Image : "123456789123.dkr.ecr.us-west-2.amazonaws.com" , ServiceAccountToken : "DEADBEEF=" , ServiceAccountAnnotations : map [string ]string {"eks.amazonaws.com/ecr-role-arn" : "arn:expected" }},
224+ expectedAssumeArn : "arn:expected" ,
225+ getAuthorizationTokenOutput : generatePrivateGetAuthorizationTokenOutput ("user" , "pass" , "" , nil ),
226+ assumeRoleWithWebIdentityOutput : & sts.AssumeRoleWithWebIdentityOutput {
227+ Credentials : & ststypes.Credentials {
228+ AccessKeyId : aws .String ("access-key-id" ),
229+ SecretAccessKey : aws .String ("secret-access-key" ),
230+ SessionToken : aws .String ("session-token" ),
231+ },
232+ },
233+ response : generateResponse ("123456789123.dkr.ecr.us-west-2.amazonaws.com" , "user" , "pass" ),
234+ },
235+ {
236+ name : "no arn provided" ,
237+ request : & v1.CredentialProviderRequest {Image : "123456789123.dkr.ecr.us-west-2.amazonaws.com" , ServiceAccountToken : "DEADBEEF=" },
238+ expectedAssumeArn : "arn:expected" ,
239+ getAuthorizationTokenOutput : generatePrivateGetAuthorizationTokenOutput ("user" , "pass" , "" , nil ),
240+ assumeRoleWithWebIdentityOutput : & sts.AssumeRoleWithWebIdentityOutput {
241+ Credentials : & ststypes.Credentials {
242+ AccessKeyId : aws .String ("access-key-id" ),
243+ SecretAccessKey : aws .String ("secret-access-key" ),
244+ SessionToken : aws .String ("session-token" ),
245+ },
246+ },
247+ response : generateResponse ("123456789123.dkr.ecr.us-west-2.amazonaws.com" , "user" , "pass" ),
248+ expectedError : errors .New ("no arn provided, cannot assume role using ServiceAccountToken" ),
249+ },
250+ {
251+ name : "assume error" ,
252+ request : & v1.CredentialProviderRequest {Image : "123456789123.dkr.ecr.us-west-2.amazonaws.com" , ServiceAccountToken : "DEADBEEF=" , ServiceAccountAnnotations : map [string ]string {"eks.amazonaws.com/ecr-role-arn" : "arn:expected" }},
253+ expectedAssumeArn : "arn:expected" ,
254+ getAuthorizationTokenOutput : generatePrivateGetAuthorizationTokenOutput ("user" , "pass" , "" , nil ),
255+ assumeRoleWithWebIdentityError : errors .New ("injected error" ),
256+ response : generateResponse ("123456789123.dkr.ecr.us-west-2.amazonaws.com" , "user" , "pass" ),
257+ expectedError : errors .New ("injected error" ),
258+ },
259+ }
260+ for _ , testcase := range testcases {
261+ t .Run (testcase .name , func (t * testing.T ) {
262+ mockECR := MockedECR {}
263+ mockSTS := MockedSTS {}
264+ p := & ecrPlugin {ecr : & mockECR , sts : & mockSTS }
265+ mockECR .On ("GetAuthorizationToken" , mock .Anything , mock .Anything ).Return (testcase .getAuthorizationTokenOutput , testcase .getAuthorizationTokenError )
266+
267+ expectedInput := sts.AssumeRoleWithWebIdentityInput {
268+ RoleArn : aws .String (testcase .expectedAssumeArn ),
269+ RoleSessionName : aws .String ("ecr-credential-provider" ),
270+ WebIdentityToken : aws .String (testcase .request .ServiceAccountToken ),
271+ }
272+ mockSTS .On ("AssumeRoleWithWebIdentity" , mock .Anything , & expectedInput ).Return (testcase .assumeRoleWithWebIdentityOutput , testcase .assumeRoleWithWebIdentityError )
273+ creds , err := p .GetCredentials (context .TODO (), testcase .request , testcase .args )
274+ if err != nil {
275+ if testcase .expectedError == nil {
276+ t .Fatalf ("got unexpected error %s" , err .Error ())
277+
278+ }
279+
280+ if testcase .expectedError .Error () != err .Error () {
281+ t .Fatalf ("expected %s, got %s" , testcase .expectedError .Error (), err .Error ())
282+ }
283+ }
284+
285+ if testcase .expectedError == nil {
286+ if creds .CacheKeyType != testcase .response .CacheKeyType {
287+ t .Fatalf ("Unexpected CacheKeyType. Expected: %s, got: %s" , testcase .response .CacheKeyType , creds .CacheKeyType )
288+ }
289+
290+ if creds .Auth [testcase .request .Image ] != testcase .response .Auth [testcase .request .Image ] {
291+ t .Fatalf ("Unexpected Auth. Expected: %s, got: %s" , testcase .response .Auth [testcase .request .Image ], creds .Auth [testcase .request .Image ])
292+ }
293+
294+ if creds .CacheDuration .Duration != testcase .response .CacheDuration .Duration {
295+ t .Fatalf ("Unexpected CacheDuration. Expected: %s, got: %s" , testcase .response .CacheDuration .Duration , creds .CacheDuration .Duration )
296+ }
297+ }
298+ })
299+ }
300+ }
301+
302+ func generatePublicGetAuthorizationTokenOutput (user string , password string , expiration * time.Time ) * ecrpublic.GetAuthorizationTokenOutput {
186303 creds := []byte (fmt .Sprintf ("%s:%s" , user , password ))
187304 data := & publictypes.AuthorizationData {
188305 AuthorizationToken : aws .String (base64 .StdEncoding .EncodeToString (creds )),
@@ -207,9 +324,16 @@ func Test_GetCredentials_Public(t *testing.T) {
207324 {
208325 name : "success" ,
209326 image : "public.ecr.aws" ,
210- getAuthorizationTokenOutput : generatePublicGetAuthorizationTokenOutput ("user" , "pass" , "" , nil ),
327+ getAuthorizationTokenOutput : generatePublicGetAuthorizationTokenOutput ("user" , "pass" , nil ),
211328 response : generateResponse ("public.ecr.aws" , "user" , "pass" ),
212329 },
330+ {
331+ name : "empty image" ,
332+ image : "" ,
333+ getAuthorizationTokenOutput : & ecrpublic.GetAuthorizationTokenOutput {},
334+ getAuthorizationTokenError : nil ,
335+ expectedError : errors .New ("image in plugin request was empty" ),
336+ },
213337 {
214338 name : "empty authorization data" ,
215339 image : "public.ecr.aws" ,
@@ -257,7 +381,7 @@ func Test_GetCredentials_Public(t *testing.T) {
257381 p := & ecrPlugin {ecrPublic : & mockECRPublic }
258382 mockECRPublic .On ("GetAuthorizationToken" , mock .Anything , mock .Anything ).Return (testcase .getAuthorizationTokenOutput , testcase .getAuthorizationTokenError )
259383
260- creds , err := p .GetCredentials (context .TODO (), testcase .image , testcase .args )
384+ creds , err := p .GetCredentials (context .TODO (), & v1. CredentialProviderRequest { Image : testcase .image } , testcase .args )
261385
262386 if testcase .expectedError != nil && (testcase .expectedError .Error () != err .Error ()) {
263387 t .Fatalf ("expected %s, got %s" , testcase .expectedError .Error (), err .Error ())
0 commit comments