Skip to content

Commit 4c8a912

Browse files
lesebclaude
andauthored
feat(auth): add upstream header auth provider for gateway deployments (#5563)
# What does this PR do? Adds a new `upstream_header` authentication provider that extracts user identity from headers injected by an upstream gateway (Authorino, Istio, or any reverse proxy that handles auth). This enables Llama Stack to participate in gateway-fronted deployments without redundant token validation, while still scoping all storage queries to the authenticated user via the existing `AuthorizedSqlStore` pipeline. The `AuthProvider` base class gains a `requires_http_bearer` property (default `True`) so the `AuthenticationMiddleware` can skip Bearer token extraction for providers that read identity from other sources. The `authenticated_client_id` falls back to the principal when no token is present, preserving quota tracking. Configuration: ```yaml server: auth: provider_config: type: upstream_header principal_header: x-auth-user-id # required attributes_header: x-auth-attributes # optional, JSON dict ``` ## Test Plan 17 unit tests covering provider behavior and middleware integration: ```bash uv run pytest tests/unit/server/test_auth_upstream_header.py -v ``` ``` tests/unit/server/test_auth_upstream_header.py::test_valid_upstream_header_auth PASSED tests/unit/server/test_auth_upstream_header.py::test_valid_upstream_header_auth_principal_only PASSED tests/unit/server/test_auth_upstream_header.py::test_missing_principal_header PASSED tests/unit/server/test_auth_upstream_header.py::test_invalid_attributes_json PASSED tests/unit/server/test_auth_upstream_header.py::test_attributes_not_object PASSED tests/unit/server/test_auth_upstream_header.py::test_no_bearer_token_required PASSED tests/unit/server/test_auth_upstream_header.py::test_bearer_token_ignored PASSED tests/unit/server/test_auth_upstream_header.py::test_no_attributes_header_configured PASSED tests/unit/server/test_auth_upstream_header.py::test_case_insensitive_headers PASSED tests/unit/server/test_auth_upstream_header.py::test_attributes_string_values_normalized PASSED tests/unit/server/test_auth_upstream_header.py::test_error_message_includes_header_name PASSED tests/unit/server/test_auth_upstream_header.py::test_authenticated_client_id_uses_principal PASSED tests/unit/server/test_auth_upstream_header.py::test_provider_requires_http_bearer_false PASSED tests/unit/server/test_auth_upstream_header.py::test_provider_validate_token_extracts_principal PASSED tests/unit/server/test_auth_upstream_header.py::test_provider_validate_token_extracts_attributes PASSED tests/unit/server/test_auth_upstream_header.py::test_provider_validate_token_missing_principal PASSED tests/unit/server/test_auth_upstream_header.py::test_provider_validate_token_none_scope PASSED ``` All 35 existing auth tests continue to pass. Signed-off-by: Sébastien Han <seb@redhat.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 043394d commit 4c8a912

4 files changed

Lines changed: 394 additions & 10 deletions

File tree

src/llama_stack/core/datatypes.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ class AuthProviderType(StrEnum):
213213
GITHUB_TOKEN = "github_token"
214214
CUSTOM = "custom"
215215
KUBERNETES = "kubernetes"
216+
UPSTREAM_HEADER = "upstream_header"
216217

217218

218219
class OAuth2TokenAuthConfig(BaseModel):
@@ -320,8 +321,30 @@ def validate_claims_mapping(cls, v):
320321
return v
321322

322323

324+
class UpstreamHeaderAuthConfig(BaseModel):
325+
"""Configuration for upstream header authentication.
326+
327+
Used when an upstream gateway (Authorino, Istio, or any reverse proxy) handles
328+
authentication and injects user identity into request headers. Llama Stack trusts
329+
these headers and extracts the principal and optional attributes from them.
330+
"""
331+
332+
type: Literal[AuthProviderType.UPSTREAM_HEADER] = AuthProviderType.UPSTREAM_HEADER
333+
principal_header: str = Field(
334+
description="HTTP header containing the authenticated user's identity (e.g. x-auth-user-id)",
335+
)
336+
attributes_header: str | None = Field(
337+
default=None,
338+
description="HTTP header containing JSON-encoded user attributes for access control (e.g. x-auth-attributes)",
339+
)
340+
341+
323342
AuthProviderConfig = Annotated[
324-
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
343+
OAuth2TokenAuthConfig
344+
| GitHubTokenAuthConfig
345+
| CustomAuthConfig
346+
| KubernetesAuthProviderConfig
347+
| UpstreamHeaderAuthConfig,
325348
Field(discriminator="type"),
326349
]
327350

src/llama_stack/core/server/auth.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any:
120120
return await self.app(scope, receive, send)
121121

122122
# Handle authentication
123-
headers = dict(scope.get("headers", []))
124-
auth_header = headers.get(b"authorization", b"").decode()
123+
if self.auth_provider.requires_http_bearer:
124+
headers = dict(scope.get("headers", []))
125+
auth_header = headers.get(b"authorization", b"").decode()
125126

126-
if not auth_header:
127-
error_msg = self.auth_provider.get_auth_error_message(scope)
128-
return await self._send_auth_error(send, error_msg)
127+
if not auth_header:
128+
error_msg = self.auth_provider.get_auth_error_message(scope)
129+
return await self._send_auth_error(send, error_msg)
129130

130-
if not auth_header.startswith("Bearer "):
131-
return await self._send_auth_error(send, "Invalid Authorization header format")
131+
if not auth_header.startswith("Bearer "):
132+
return await self._send_auth_error(send, "Invalid Authorization header format")
132133

133-
token = auth_header.split("Bearer ", 1)[1]
134+
token = auth_header.split("Bearer ", 1)[1]
135+
else:
136+
token = ""
134137

135138
# Validate token and get access attributes
136139
try:
@@ -147,7 +150,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any:
147150

148151
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
149152
# can identify the requester and enforce per-client rate limits.
150-
scope["authenticated_client_id"] = token
153+
scope["authenticated_client_id"] = token or validation_result.principal
151154

152155
# Store attributes in request scope
153156
scope["principal"] = validation_result.principal

src/llama_stack/core/server/auth_providers.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
GitHubTokenAuthConfig,
2121
KubernetesAuthProviderConfig,
2222
OAuth2TokenAuthConfig,
23+
UpstreamHeaderAuthConfig,
2324
User,
2425
)
2526
from llama_stack.log import get_logger
@@ -60,6 +61,15 @@ class AuthRequest(BaseModel):
6061
class AuthProvider(ABC):
6162
"""Abstract base class for authentication providers."""
6263

64+
@property
65+
def requires_http_bearer(self) -> bool:
66+
"""Whether this provider requires a Bearer token from the Authorization header.
67+
68+
Providers that extract identity from other sources (e.g. gateway-injected
69+
headers) should override this to return False.
70+
"""
71+
return True
72+
6373
@abstractmethod
6474
async def validate_token(self, token: str, scope: Scope | None = None) -> User:
6575
"""Validate a token and return access attributes."""
@@ -502,6 +512,70 @@ async def close(self) -> None:
502512
pass
503513

504514

515+
class UpstreamHeaderAuthProvider(AuthProvider):
516+
"""Authentication provider that extracts identity from upstream gateway headers.
517+
518+
Used when an upstream gateway (Authorino, Istio, or any reverse proxy) handles
519+
authentication and injects user identity into request headers. This provider
520+
trusts the headers and performs no token validation or outbound calls.
521+
"""
522+
523+
def __init__(self, config: UpstreamHeaderAuthConfig) -> None:
524+
self.config = config
525+
526+
@property
527+
def requires_http_bearer(self) -> bool:
528+
return False
529+
530+
async def validate_token(self, token: str, scope: Scope | None = None) -> User:
531+
if scope is None:
532+
raise ValueError("Missing required authentication header: " + self.config.principal_header)
533+
534+
headers = dict(scope.get("headers", []))
535+
536+
# HTTP headers are case-insensitive; ASGI stores them as lowercase bytes
537+
principal_key = self.config.principal_header.lower().encode()
538+
principal_value = headers.get(principal_key)
539+
540+
if not principal_value:
541+
raise ValueError("Missing required authentication header: " + self.config.principal_header)
542+
543+
principal = principal_value.decode()
544+
545+
attributes: dict[str, list[str]] | None = None
546+
if self.config.attributes_header:
547+
attributes_key = self.config.attributes_header.lower().encode()
548+
attributes_value = headers.get(attributes_key)
549+
if attributes_value:
550+
import json
551+
552+
try:
553+
parsed = json.loads(attributes_value.decode())
554+
except (json.JSONDecodeError, UnicodeDecodeError) as e:
555+
raise ValueError("Failed to parse authentication attributes header: invalid JSON") from e
556+
557+
if not isinstance(parsed, dict):
558+
raise ValueError("Failed to parse authentication attributes header: expected JSON object")
559+
560+
# Normalize values to list[str] to match the User.attributes type
561+
attributes = {}
562+
for k, v in parsed.items():
563+
if isinstance(v, list):
564+
attributes[k] = [str(item) for item in v]
565+
elif isinstance(v, str):
566+
attributes[k] = [v]
567+
else:
568+
attributes[k] = [str(v)]
569+
570+
return User(principal=principal, attributes=attributes)
571+
572+
async def close(self) -> None:
573+
pass
574+
575+
def get_auth_error_message(self, scope: Scope | None = None) -> str:
576+
return f"Authentication required. Upstream gateway must set the {self.config.principal_header} header"
577+
578+
505579
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
506580
"""Factory function to create the appropriate auth provider."""
507581
provider_config = config.provider_config
@@ -514,5 +588,7 @@ def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
514588
return GitHubTokenAuthProvider(provider_config)
515589
elif isinstance(provider_config, KubernetesAuthProviderConfig):
516590
return KubernetesAuthProvider(provider_config)
591+
elif isinstance(provider_config, UpstreamHeaderAuthConfig):
592+
return UpstreamHeaderAuthProvider(provider_config)
517593
else:
518594
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")

0 commit comments

Comments
 (0)