Skip to content

Commit 37ff286

Browse files
committed
Mock out all of the google auth methods by moving them to helpers
1 parent bda1f89 commit 37ff286

File tree

4 files changed

+47
-23
lines changed

4 files changed

+47
-23
lines changed

litellm/llms/vertex_ai/vertex_llm_base.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,7 @@ def get_vertex_region(self, vertex_region: Optional[str]) -> str:
4040
def load_auth(
4141
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
4242
) -> Tuple[Any, str]:
43-
import google.auth as google_auth
44-
from google.auth import identity_pool
45-
4643
if credentials is not None:
47-
import google.oauth2.service_account
48-
import google.oauth2.credentials as google_oauth_credentials
49-
5044
if isinstance(credentials, str):
5145
verbose_logger.debug(
5246
"Vertex: Loading vertex credentials from %s", credentials
@@ -78,27 +72,25 @@ def load_auth(
7872

7973
# Check if the JSON object contains Workload Identity Federation configuration
8074
if "type" in json_obj and json_obj["type"] == "external_account":
81-
creds = identity_pool.Credentials.from_info(json_obj)
75+
creds = self._credentials_from_identity_pool(json_obj)
8276
# Check if the JSON object contains Authorized User configuration (via gcloud auth application-default login)
8377
elif "type" in json_obj and json_obj["type"] == "authorized_user":
84-
creds = google_oauth_credentials.Credentials.from_authorized_user_info(
78+
creds = self._credentials_from_authorized_user(
8579
json_obj,
8680
scopes=["https://www.googleapis.com/auth/cloud-platform"],
8781
)
8882
if project_id is None:
8983
project_id = creds.quota_project_id # authorized user credentials don't have a project_id, only quota_project_id
9084
else:
91-
creds = (
92-
google.oauth2.service_account.Credentials.from_service_account_info(
93-
json_obj,
94-
scopes=["https://www.googleapis.com/auth/cloud-platform"],
95-
)
85+
creds = self._credentials_from_service_account(
86+
json_obj,
87+
scopes=["https://www.googleapis.com/auth/cloud-platform"],
9688
)
9789

9890
if project_id is None:
9991
project_id = getattr(creds, "project_id", None)
10092
else:
101-
creds, creds_project_id = google_auth.default(
93+
creds, creds_project_id = self._credentials_from_default_auth(
10294
quota_project_id=project_id,
10395
scopes=["https://www.googleapis.com/auth/cloud-platform"],
10496
)
@@ -116,6 +108,24 @@ def load_auth(
116108
)
117109

118110
return creds, project_id
111+
112+
# Google Auth Helpers -- extracted for mocking purposes in tests
113+
def _credentials_from_identity_pool(self, json_obj):
114+
from google.auth import identity_pool
115+
return identity_pool.Credentials.from_info(json_obj)
116+
117+
def _credentials_from_authorized_user(self, json_obj, scopes):
118+
import google.oauth2.credentials
119+
return google.oauth2.credentials.Credentials.from_authorized_user_info(json_obj, scopes=scopes)
120+
121+
def _credentials_from_service_account(self, json_obj, scopes):
122+
import google.oauth2.service_account
123+
return google.oauth2.service_account.Credentials.from_service_account_info(json_obj, scopes=scopes)
124+
125+
def _credentials_from_default_auth(self, quota_project_id, scopes):
126+
import google.auth as google_auth
127+
return google_auth.default(quota_project_id=quota_project_id, scopes=scopes)
128+
119129

120130
def refresh_auth(self, credentials: Any) -> None:
121131
from google.auth.transport.requests import (

poetry.lock

Lines changed: 6 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ redisvl = {version = "^0.4.1", optional = true, markers = "python_version >= '3.
5858
mcp = {version = "1.5.0", optional = true, python = ">=3.10"}
5959
litellm-proxy-extras = {version = "0.1.21", optional = true}
6060
rich = {version = "13.7.1", optional = true}
61-
google-auth = {version = "^2.40.1", optional = true}
6261
litellm-enterprise = {version = "0.1.4", optional = true}
6362

6463
[tool.poetry.extras]

tests/litellm/llms/vertex_ai/test_vertex_llm_base.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,13 @@ async def test_gemini_credentials(self, is_async):
175175
assert token == ""
176176
assert project == ""
177177

178-
179178
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
180179
@pytest.mark.asyncio
181180
async def test_authorized_user_credentials(self, is_async):
182181
vertex_base = VertexBase()
183182

183+
quota_project_id = "test-project"
184+
184185
credentials = {
185186
"account": "",
186187
"client_id": "fake-client-id",
@@ -191,7 +192,15 @@ async def test_authorized_user_credentials(self, is_async):
191192
"universe_domain": "googleapis.com"
192193
}
193194

194-
with patch.object(vertex_base, "refresh_auth") as mock_refresh:
195+
mock_creds = MagicMock()
196+
mock_creds.token = "token-1"
197+
mock_creds.expired = False
198+
mock_creds.quota_project_id = quota_project_id
199+
200+
201+
with patch.object(
202+
vertex_base, "_credentials_from_authorized_user", return_value=mock_creds
203+
) as mock_credentials_from_authorized_user, patch.object(vertex_base, "refresh_auth") as mock_refresh:
195204
def mock_refresh_impl(creds):
196205
creds.token = "refreshed-token"
197206

@@ -207,20 +216,21 @@ def mock_refresh_impl(creds):
207216
credentials=credentials, project_id=None, custom_llm_provider="vertex_ai"
208217
)
209218

210-
219+
assert mock_credentials_from_authorized_user.called
211220
assert token == "refreshed-token"
212-
assert project == "test-project"
221+
assert project == quota_project_id
213222

214223
# 2. Test that authorized_user-style credentials are correctly handled and uses passed in project_id
224+
not_quota_project_id = "new-project"
215225
if is_async:
216226
token, project = await vertex_base._ensure_access_token_async(
217-
credentials=credentials, project_id="new-project", custom_llm_provider="vertex_ai"
227+
credentials=credentials, project_id=not_quota_project_id, custom_llm_provider="vertex_ai"
218228
)
219229
else:
220230
token, project = vertex_base._ensure_access_token(
221-
credentials=credentials, project_id="new-project", custom_llm_provider="vertex_ai"
231+
credentials=credentials, project_id=not_quota_project_id, custom_llm_provider="vertex_ai"
222232
)
223233

224234

225235
assert token == "refreshed-token"
226-
assert project == "new-project"
236+
assert project == not_quota_project_id

0 commit comments

Comments
 (0)