@@ -40,13 +40,7 @@ def get_vertex_region(self, vertex_region: Optional[str]) -> str:
40
40
def load_auth (
41
41
self , credentials : Optional [VERTEX_CREDENTIALS_TYPES ], project_id : Optional [str ]
42
42
) -> Tuple [Any , str ]:
43
- import google .auth as google_auth
44
- from google .auth import identity_pool
45
-
46
43
if credentials is not None :
47
- import google .oauth2 .service_account
48
- import google .oauth2 .credentials as google_oauth_credentials
49
-
50
44
if isinstance (credentials , str ):
51
45
verbose_logger .debug (
52
46
"Vertex: Loading vertex credentials from %s" , credentials
@@ -78,27 +72,25 @@ def load_auth(
78
72
79
73
# Check if the JSON object contains Workload Identity Federation configuration
80
74
if "type" in json_obj and json_obj ["type" ] == "external_account" :
81
- creds = identity_pool . Credentials . from_info (json_obj )
75
+ creds = self . _credentials_from_identity_pool (json_obj )
82
76
# Check if the JSON object contains Authorized User configuration (via gcloud auth application-default login)
83
77
elif "type" in json_obj and json_obj ["type" ] == "authorized_user" :
84
- creds = google_oauth_credentials . Credentials . from_authorized_user_info (
78
+ creds = self . _credentials_from_authorized_user (
85
79
json_obj ,
86
80
scopes = ["https://www.googleapis.com/auth/cloud-platform" ],
87
81
)
88
82
if project_id is None :
89
83
project_id = creds .quota_project_id # authorized user credentials don't have a project_id, only quota_project_id
90
84
else :
91
- creds = (
92
- google .oauth2 .service_account .Credentials .from_service_account_info (
93
- json_obj ,
94
- scopes = ["https://www.googleapis.com/auth/cloud-platform" ],
95
- )
85
+ creds = self ._credentials_from_service_account (
86
+ json_obj ,
87
+ scopes = ["https://www.googleapis.com/auth/cloud-platform" ],
96
88
)
97
89
98
90
if project_id is None :
99
91
project_id = getattr (creds , "project_id" , None )
100
92
else :
101
- creds , creds_project_id = google_auth . default (
93
+ creds , creds_project_id = self . _credentials_from_default_auth (
102
94
quota_project_id = project_id ,
103
95
scopes = ["https://www.googleapis.com/auth/cloud-platform" ],
104
96
)
@@ -116,6 +108,24 @@ def load_auth(
116
108
)
117
109
118
110
return creds , project_id
111
+
112
+ # Google Auth Helpers -- extracted for mocking purposes in tests
113
+ def _credentials_from_identity_pool (self , json_obj ):
114
+ from google .auth import identity_pool
115
+ return identity_pool .Credentials .from_info (json_obj )
116
+
117
+ def _credentials_from_authorized_user (self , json_obj , scopes ):
118
+ import google .oauth2 .credentials
119
+ return google .oauth2 .credentials .Credentials .from_authorized_user_info (json_obj , scopes = scopes )
120
+
121
+ def _credentials_from_service_account (self , json_obj , scopes ):
122
+ import google .oauth2 .service_account
123
+ return google .oauth2 .service_account .Credentials .from_service_account_info (json_obj , scopes = scopes )
124
+
125
+ def _credentials_from_default_auth (self , quota_project_id , scopes ):
126
+ import google .auth as google_auth
127
+ return google_auth .default (quota_project_id = quota_project_id , scopes = scopes )
128
+
119
129
120
130
def refresh_auth (self , credentials : Any ) -> None :
121
131
from google .auth .transport .requests import (
0 commit comments