diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index 30887e9f60dc..e8130741732d 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -11,7 +11,7 @@ import TabItem from '@theme/TabItem'; | Description | Vertex AI is a fully-managed AI development platform for building and using generative AI. | | Provider Route on LiteLLM | `vertex_ai/` | | Link to Provider Doc | [Vertex AI ↗](https://cloud.google.com/vertex-ai) | -| Base URL | [https://{vertex_location}-aiplatform.googleapis.com/](https://{vertex_location}-aiplatform.googleapis.com/) | +| Base URL | 1. Regional endpoints<br/>[https://{vertex_location}-aiplatform.googleapis.com/](https://{vertex_location}-aiplatform.googleapis.com/)<br/>2. Global endpoints (limited availability)<br/>[https://aiplatform.googleapis.com/](https://aiplatform.googleapis.com/)| | Supported Operations | [`/chat/completions`](#sample-usage), `/completions`, [`/embeddings`](#embedding-models), [`/audio/speech`](#text-to-speech-apis), [`/fine_tuning`](#fine-tuning-apis), [`/batches`](#batch-apis), [`/files`](#batch-apis), [`/images`](#image-generation-models) | @@ -832,7 +832,7 @@ OR You can set: - `vertex_credentials` (str) - can be a json string or filepath to your vertex ai service account.json -- `vertex_location` (str) - place where vertex model is deployed (us-central1, asia-southeast1, etc.) +- `vertex_location` (str) - place where vertex model is deployed (us-central1, asia-southeast1, etc.). Some models support the global location, please see [Vertex AI documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#supported_models) - `vertex_project` Optional[str] - use if vertex project different from the one in vertex_credentials as dynamic params for a `litellm.completion` call.
diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index 477995a1578a..f96848c6d563 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -84,9 +84,15 @@ def _get_vertex_url( endpoint = "generateContent" if stream is True: endpoint = "streamGenerateContent" - url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse" + if vertex_location == "global": + url = f"https://aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/global/publishers/google/models/{model}:{endpoint}?alt=sse" + else: + url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse" else: + if vertex_location == "global": + url = f"https://aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/global/publishers/google/models/{model}:{endpoint}" + else: + url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" # if model is only numeric chars then it's a fine tuned gemini model # model = 4965075652664360960 diff --git a/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py b/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py index 90a7fb30e199..fca89650a53e 100644 --- a/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py +++ b/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py @@ -17,6 +17,7 @@ get_vertex_location_from_url, get_vertex_project_id_from_url,
set_schema_property_ordering, + _get_vertex_url ) @@ -516,3 +517,45 @@ def test_vertex_ai_complex_response_schema(): assert "additionalProperties" not in type2 assert "additionalProperties" not in type3 assert "additionalProperties" not in type3_prop3_items + +@pytest.mark.parametrize( + "stream, expected_endpoint_suffix", + [ + (True, "streamGenerateContent?alt=sse"), + (False, "generateContent"), + ], +) +def test_get_vertex_url_global_region(stream, expected_endpoint_suffix): + """ + Test _get_vertex_url when vertex_location is 'global' for chat mode. + """ + mode = "chat" + model = "gemini-1.5-pro-preview-0409" + vertex_project = "test-g-project" + vertex_location = "global" + vertex_api_version = "v1" + + # Mock litellm.VertexGeminiConfig.get_model_for_vertex_ai_url to return model as is + # as we are not testing that part here, just the URL construction + with patch("litellm.VertexGeminiConfig.get_model_for_vertex_ai_url", side_effect=lambda model: model): + url, endpoint = _get_vertex_url( + mode=mode, + model=model, + stream=stream, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_api_version=vertex_api_version, + ) + + expected_url_base = f"https://aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/global/publishers/google/models/{model}" + + if stream: + expected_endpoint = "streamGenerateContent" + expected_url = f"{expected_url_base}:{expected_endpoint}?alt=sse" + else: + expected_endpoint = "generateContent" + expected_url = f"{expected_url_base}:{expected_endpoint}" + + + assert endpoint == expected_endpoint + assert url == expected_url