diff --git a/docs/my-website/docs/mcp.md b/docs/my-website/docs/mcp.md index 17bf4460458c..3a4de87beccf 100644 --- a/docs/my-website/docs/mcp.md +++ b/docs/my-website/docs/mcp.md @@ -265,6 +265,182 @@ if __name__ == "__main__": +## Using your MCP with client side credentials + +Use this if you want to pass a client side authentication token to LiteLLM to then pass to your MCP to auth to your MCP. + +You can specify your MCP auth token using the header `x-mcp-auth`. LiteLLM will forward this token to your MCP server for authentication. + + + + +#### Connect via OpenAI Responses API with MCP Auth + +Use the OpenAI Responses API and include the `x-mcp-auth` header for your MCP server authentication: + +```bash title="cURL Example with MCP Auth" showLineNumbers +curl --location 'https://api.openai.com/v1/responses' \ +--header 'Content-Type: application/json' \ +--header "Authorization: Bearer $OPENAI_API_KEY" \ +--data '{ + "model": "gpt-4o", + "tools": [ + { + "type": "mcp", + "server_label": "litellm", + "server_url": "/mcp", + "require_approval": "never", + "headers": { + "x-litellm-api-key": "Bearer YOUR_LITELLM_API_KEY", + "x-mcp-auth": YOUR_MCP_AUTH_TOKEN + } + } + ], + "input": "Run available tools", + "tool_choice": "required" +}' +``` + + + + + +#### Connect via LiteLLM Proxy Responses API with MCP Auth + +Use this when calling LiteLLM Proxy for LLM API requests to `/v1/responses` endpoint with MCP authentication: + +```bash title="cURL Example with MCP Auth" showLineNumbers +curl --location '/v1/responses' \ +--header 'Content-Type: application/json' \ +--header "Authorization: Bearer $LITELLM_API_KEY" \ +--data '{ + "model": "gpt-4o", + "tools": [ + { + "type": "mcp", + "server_label": "litellm", + "server_url": "/mcp", + "require_approval": "never", + "headers": { + "x-litellm-api-key": "Bearer YOUR_LITELLM_API_KEY", + "x-mcp-auth": "YOUR_MCP_AUTH_TOKEN" + } + } + ], + "input": "Run available tools", + "tool_choice": "required" +}' +``` + + + + + +#### Connect via Cursor IDE with MCP Auth + +Use tools directly from Cursor IDE with LiteLLM MCP and include your MCP authentication token: + +**Setup Instructions:** + +1. **Open Cursor Settings**: Use `⇧+⌘+J` (Mac) or `Ctrl+Shift+J` (Windows/Linux) +2. **Navigate to MCP Tools**: Go to the "MCP Tools" tab and click "New MCP Server" +3. **Add Configuration**: Copy and paste the JSON configuration below, then save with `Cmd+S` or `Ctrl+S` + +```json title="Cursor MCP Configuration with Auth" showLineNumbers +{ + "mcpServers": { + "LiteLLM": { + "url": "/mcp", + "headers": { + "x-litellm-api-key": "Bearer $LITELLM_API_KEY", + "x-mcp-auth": "$MCP_AUTH_TOKEN" + } + } + } +} +``` + + + + + +#### Connect via Streamable HTTP Transport with MCP Auth + +Connect to LiteLLM MCP using HTTP transport with MCP authentication: + +**Server URL:** +```text showLineNumbers +/mcp +``` + +**Headers:** +```text showLineNumbers +x-litellm-api-key: Bearer YOUR_LITELLM_API_KEY +x-mcp-auth: Bearer YOUR_MCP_AUTH_TOKEN +``` + +This URL can be used with any MCP client that supports HTTP transport. The `x-mcp-auth` header will be forwarded to your MCP server for authentication. + + + + + +#### Connect via Python FastMCP Client with MCP Auth + +Use the Python FastMCP client to connect to your LiteLLM MCP server with MCP authentication: + +```python title="Python FastMCP Example with MCP Auth" showLineNumbers +import asyncio +import json + +from fastmcp import Client +from fastmcp.client.transports import StreamableHttpTransport + +# Create the transport with your LiteLLM MCP server URL and auth headers +server_url = "/mcp" +transport = StreamableHttpTransport( + server_url, + headers={ + "x-litellm-api-key": "Bearer YOUR_LITELLM_API_KEY", + "x-mcp-auth": "Bearer YOUR_MCP_AUTH_TOKEN" + } +) + +# Initialize the client with the transport +client = Client(transport=transport) + + +async def main(): + # Connection is established here + print("Connecting to LiteLLM MCP server with authentication...") + async with client: + print(f"Client connected: {client.is_connected()}") + + # Make MCP calls within the context + print("Fetching available tools...") + tools = await client.list_tools() + + print(f"Available tools: {json.dumps([t.name for t in tools], indent=2)}") + + # Example: Call a tool (replace 'tool_name' with an actual tool name) + if tools: + tool_name = tools[0].name + print(f"Calling tool: {tool_name}") + + # Call the tool with appropriate arguments + result = await client.call_tool(tool_name, arguments={}) + print(f"Tool result: {result}") + + +# Run the example +if __name__ == "__main__": + asyncio.run(main()) +``` + + + + + ## ✨ MCP Permission Management LiteLLM supports managing permissions for MCP Servers by Keys, Teams, Organizations (entities) on LiteLLM. When a MCP client attempts to list tools, LiteLLM will only return the tools the entity has permissions to access. diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index e69de29bb2d1..af2cb171dad9 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -0,0 +1,164 @@ +""" +LiteLLM Proxy uses this MCP Client to connnect to other MCP servers. +""" +import base64 +from datetime import timedelta +from typing import List, Optional + +from mcp import ClientSession +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import CallToolRequestParams as MCPCallToolRequestParams +from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import Tool as MCPTool + +from litellm.types.mcp import MCPAuth, MCPAuthType, MCPTransport, MCPTransportType + + +def to_basic_auth(auth_value: str) -> str: + """Convert auth value to Basic Auth format.""" + return base64.b64encode(auth_value.encode("utf-8")).decode() + + +class MCPClient: + """ + MCP Client supporting: + SSE and HTTP transports + Authentication via Bearer token, Basic Auth, or API Key + Tool calling with error handling and result parsing + """ + + def __init__( + self, + server_url: str, + transport_type: MCPTransportType = MCPTransport.http, + auth_type: MCPAuthType = None, + auth_value: Optional[str] = None, + timeout: float = 60.0, + ): + self.server_url: str = server_url + self.transport_type: MCPTransport = transport_type + self.auth_type: MCPAuthType = auth_type + self.timeout: float = timeout + self._mcp_auth_value: Optional[str] = None + self._session: Optional[ClientSession] = None + self._context = None + self._transport_ctx = None + self._transport = None + self._session_ctx = None + + # handle the basic auth value if provided + if auth_value: + self.update_auth_value(auth_value) + + async def __aenter__(self): + """ + Enable async context manager support. + Initializes the transport and session. + """ + await self.connect() + return self + + async def connect(self): + """Initialize the transport and session.""" + if self._session: + return # Already connected + + headers = self._get_auth_headers() + + if self.transport_type == MCPTransport.sse: + self._transport_ctx = sse_client( + url=self.server_url, + timeout=self.timeout, + headers=headers, + ) + self._transport = await self._transport_ctx.__aenter__() + self._session_ctx = ClientSession(self._transport[0], self._transport[1]) + self._session = await self._session_ctx.__aenter__() + await self._session.initialize() + else: + self._transport_ctx = streamablehttp_client( + url=self.server_url, + timeout=timedelta(seconds=self.timeout), + headers=headers, + ) + self._transport = await self._transport_ctx.__aenter__() + self._session_ctx = ClientSession(self._transport[0], self._transport[1]) + self._session = await self._session_ctx.__aenter__() + await self._session.initialize() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Cleanup when exiting context manager.""" + if self._session: + await self._session_ctx.__aexit__(exc_type, exc_val, exc_tb) # type: ignore + if self._transport_ctx: + await self._transport_ctx.__aexit__(exc_type, exc_val, exc_tb) + + async def disconnect(self): + """Clean up session and connections.""" + if self._session: + try: + # Ensure session is properly closed + await self._session.close() # type: ignore + except Exception: + pass + self._session = None + + if self._context: + try: + await self._context.__aexit__(None, None, None) # type: ignore + except Exception: + pass + self._context = None + + def update_auth_value(self, mcp_auth_value: str): + """ + Set the authentication header for the MCP client. + """ + if self.auth_type == MCPAuth.basic: + # Assuming mcp_auth_value is in format "username:password", convert it when updating + mcp_auth_value = to_basic_auth(mcp_auth_value) + self._mcp_auth_value = mcp_auth_value + + def _get_auth_headers(self) -> dict: + """Generate authentication headers based on auth type.""" + if not self._mcp_auth_value: + return {} + + if self.auth_type == MCPAuth.bearer_token: + return {"Authorization": f"Bearer {self._mcp_auth_value}"} + elif self.auth_type == MCPAuth.basic: + return {"Authorization": f"Basic {self._mcp_auth_value}"} + elif self.auth_type == MCPAuth.api_key: + return {"X-API-Key": self._mcp_auth_value} + return {} + + async def list_tools(self) -> List[MCPTool]: + """List available tools from the server.""" + if not self._session: + await self.connect() + if self._session is None: + raise ValueError("Session is not initialized") + + result = await self._session.list_tools() + return result.tools + + async def call_tool( + self, call_tool_request_params: MCPCallToolRequestParams + ) -> MCPCallToolResult: + """ + Call an MCP Tool. + """ + if not self._session: + await self.connect() + + if self._session is None: + raise ValueError("Session is not initialized") + + tool_result = await self._session.call_tool( + name=call_tool_request_params.name, + arguments=call_tool_request_params.arguments, + ) + return tool_result + + diff --git a/litellm/proxy/_experimental/mcp_server/auth/litellm_auth_handler.py b/litellm/proxy/_experimental/mcp_server/auth/litellm_auth_handler.py index 655b9ad0d869..b04fc3a0a49e 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/litellm_auth_handler.py +++ b/litellm/proxy/_experimental/mcp_server/auth/litellm_auth_handler.py @@ -1,3 +1,5 @@ +from typing import Optional + from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser from litellm.proxy._types import UserAPIKeyAuth @@ -8,5 +10,6 @@ class LiteLLMAuthenticatedUser(AuthenticatedUser): Wrapper class to make UserAPIKeyAuth compatible with MCP's AuthenticatedUser """ - def __init__(self, user_api_key_auth: UserAPIKeyAuth): + def __init__(self, user_api_key_auth: UserAPIKeyAuth, mcp_auth_header: Optional[str] = None): self.user_api_key_auth = user_api_key_auth + self.mcp_auth_header = mcp_auth_header diff --git a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py index c92aedfe9872..177fbafa5e31 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py +++ b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py @@ -1,11 +1,11 @@ -from typing import List, Optional +from typing import List, Optional, Tuple from starlette.datastructures import Headers from starlette.requests import Request from starlette.types import Scope from litellm._logging import verbose_logger -from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth +from litellm.proxy._types import LiteLLM_TeamTable, SpecialHeaders, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth @@ -16,11 +16,14 @@ class UserAPIKeyAuthMCP: Utilizes the main `user_api_key_auth` function to validate the request """ - LITELLM_API_KEY_HEADER_NAME_PRIMARY = "x-litellm-api-key" - LITELLM_API_KEY_HEADER_NAME_SECONDARY = "Authorization" + LITELLM_API_KEY_HEADER_NAME_PRIMARY = SpecialHeaders.custom_litellm_api_key.value + LITELLM_API_KEY_HEADER_NAME_SECONDARY = SpecialHeaders.openai_authorization.value + + # This is the header to use if you want LiteLLM to use this header for authenticating to the MCP server + LITELLM_MCP_AUTH_HEADER_NAME = SpecialHeaders.mcp_auth.value @staticmethod - async def user_api_key_auth_mcp(scope: Scope) -> UserAPIKeyAuth: + async def user_api_key_auth_mcp(scope: Scope) -> Tuple[UserAPIKeyAuth, Optional[str]]: """ Validate and extract headers from the ASGI scope for MCP requests. @@ -29,6 +32,7 @@ async def user_api_key_auth_mcp(scope: Scope) -> UserAPIKeyAuth: Returns: UserAPIKeyAuth containing validated authentication information + mcp_auth_header: Optional[str] MCP auth header to be passed to the MCP server Raises: HTTPException: If headers are invalid or missing required headers @@ -37,6 +41,7 @@ async def user_api_key_auth_mcp(scope: Scope) -> UserAPIKeyAuth: litellm_api_key = ( UserAPIKeyAuthMCP.get_litellm_api_key_from_headers(headers) or "" ) + mcp_auth_header = headers.get(UserAPIKeyAuthMCP.LITELLM_MCP_AUTH_HEADER_NAME) # Create a proper Request object with mock body method to avoid ASGI receive channel issues request = Request(scope=scope) @@ -52,7 +57,7 @@ async def mock_body(): api_key=litellm_api_key, request=request ) - return validated_user_api_key_auth + return validated_user_api_key_auth, mcp_auth_header @staticmethod def get_litellm_api_key_from_headers(headers: Headers) -> Optional[str]: diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index dc2701af8013..d32a17791453 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -11,12 +11,12 @@ import json from typing import Any, Dict, List, Optional, cast -from mcp import ClientSession -from mcp.client.sse import sse_client +from mcp.types import CallToolRequestParams as MCPCallToolRequestParams from mcp.types import CallToolResult from mcp.types import Tool as MCPTool from litellm._logging import verbose_logger +from litellm.experimental_mcp_client.client import MCPClient from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( UserAPIKeyAuthMCP, ) @@ -29,12 +29,6 @@ MCPTransportType, UserAPIKeyAuth, ) - -try: - from mcp.client.streamable_http import streamablehttp_client -except ImportError: - streamablehttp_client = None # type: ignore - from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer @@ -164,7 +158,9 @@ async def get_allowed_mcp_servers( return list(self.get_registry().keys()) async def list_tools( - self, user_api_key_auth: Optional[UserAPIKeyAuth] = None + self, + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + mcp_auth_header: Optional[str] = None, ) -> List[MCPTool]: """ List all tools available across all MCP Servers. @@ -183,7 +179,10 @@ async def list_tools( verbose_logger.warning(f"MCP Server {server_id} not found") continue try: - tools = await self._get_tools_from_server(server) + tools = await self._get_tools_from_server( + server=server, + mcp_auth_header=mcp_auth_header, + ) list_tools_result.extend(tools) except Exception as e: verbose_logger.exception( @@ -192,7 +191,30 @@ async def list_tools( return list_tools_result - async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]: + ######################################################### + # Methods that call the upstream MCP servers + ######################################################### + def _create_mcp_client(self, server: MCPServer, mcp_auth_header: Optional[str] = None) -> MCPClient: + """ + Create an MCPClient instance for the given server. + + Args: + server (MCPServer): The server configuration + mcp_auth_header: MCP auth header to be passed to the MCP server. This is optional and will be used if provided. + + Returns: + MCPClient: Configured MCP client instance + """ + transport = server.transport or MCPTransport.sse + return MCPClient( + server_url=server.url, + transport_type=transport, + auth_type=server.auth_type, + auth_value=mcp_auth_header or server.authentication_token, + timeout=60.0, + ) + + async def _get_tools_from_server(self, server: MCPServer, mcp_auth_header: Optional[str] = None) -> List[MCPTool]: """ Helper method to get tools from a single MCP server. @@ -203,57 +225,51 @@ async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]: List[MCPTool]: List of tools available on the server """ verbose_logger.debug(f"Connecting to url: {server.url}") - verbose_logger.info("_get_tools_from_server...") - # send transport to connect to the server - if server.transport is None or server.transport == MCPTransport.sse: - async with sse_client(url=server.url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - - tools_result = await session.list_tools() - verbose_logger.debug(f"Tools from {server.name}: {tools_result}") - - # Update tool to server mapping - for tool in tools_result.tools: - self.tool_name_to_mcp_server_name_mapping[tool.name] = ( - server.name - ) - - return tools_result.tools - elif server.transport == MCPTransport.http: - if streamablehttp_client is None: - verbose_logger.error( - "streamablehttp_client not available - install mcp with HTTP support" - ) - raise ValueError( - "streamablehttp_client not available - please run `pip install mcp -U`" - ) - verbose_logger.debug(f"Using HTTP streamable transport for {server.url}") - async with streamablehttp_client( - url=server.url, - ) as (read_stream, write_stream, get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - if get_session_id is not None: - session_id = get_session_id() - if session_id: - verbose_logger.debug(f"HTTP session ID: {session_id}") - - tools_result = await session.list_tools() - verbose_logger.debug(f"Tools from {server.name}: {tools_result}") - - # Update tool to server mapping - for tool in tools_result.tools: - self.tool_name_to_mcp_server_name_mapping[tool.name] = ( - server.name - ) - - return tools_result.tools - else: - verbose_logger.warning(f"Unsupported transport type: {server.transport}") - return [] + + client = self._create_mcp_client( + server=server, + mcp_auth_header=mcp_auth_header, + ) + async with client: + tools = await client.list_tools() + verbose_logger.debug(f"Tools from {server.name}: {tools}") + + # Update tool to server mapping + for tool in tools: + self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name + + return tools + + async def call_tool( + self, + name: str, + arguments: Dict[str, Any], + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + mcp_auth_header: Optional[str] = None, + ) -> CallToolResult: + """ + Call a tool with the given name and arguments + """ + mcp_server = self._get_mcp_server_from_tool_name(name) + if mcp_server is None: + raise ValueError(f"Tool {name} not found") + + client = self._create_mcp_client( + server=mcp_server, + mcp_auth_header=mcp_auth_header, + ) + async with client: + call_tool_params = MCPCallToolRequestParams( + name=name, + arguments=arguments, + ) + return await client.call_tool(call_tool_params) + + ######################################################### + # End of Methods that call the upstream MCP servers + ######################################################### + def initialize_tool_name_to_mcp_server_name_mapping(self): """ @@ -278,46 +294,6 @@ async def _initialize_tool_name_to_mcp_server_name_mapping(self): for tool in tools: self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name - async def call_tool(self, name: str, arguments: Dict[str, Any]): - """ - Call a tool with the given name and arguments - """ - mcp_server = self._get_mcp_server_from_tool_name(name) - if mcp_server is None: - raise ValueError(f"Tool {name} not found") - elif mcp_server.transport is None or mcp_server.transport == MCPTransport.sse: - async with sse_client(url=mcp_server.url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - return await session.call_tool(name, arguments) - elif mcp_server.transport == MCPTransport.http: - if streamablehttp_client is None: - verbose_logger.error( - "streamablehttp_client not available - install mcp with HTTP support" - ) - raise ValueError( - "streamablehttp_client not available - please run `pip install mcp -U`" - ) - verbose_logger.debug( - f"Using HTTP streamable transport for tool call: {name}" - ) - async with streamablehttp_client( - url=mcp_server.url, - ) as (read_stream, write_stream, get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - if get_session_id is not None: - session_id = get_session_id() - if session_id: - verbose_logger.debug( - f"HTTP session ID for tool call: {session_id}" - ) - - return await session.call_tool(name, arguments) - else: - return CallToolResult(content=[], isError=True) - def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]: """ Get the MCP Server from the tool name diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index 2b983bd7e8c0..9094be6f42e4 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -70,7 +70,9 @@ async def list_tool_rest_api( if server_id and server.server_id != server_id: continue try: - tools = await global_mcp_server_manager._get_tools_from_server(server) + tools = await global_mcp_server_manager._get_tools_from_server( + server=server, + ) for tool in tools: list_tools_result.append( ListMCPToolsRestAPIResponseObject( diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 9dcb516ef5ba..83f922a223b3 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -4,7 +4,7 @@ import asyncio import contextlib -from typing import Any, AsyncIterator, Dict, List, Optional, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union from fastapi import FastAPI, HTTPException from pydantic import ConfigDict @@ -166,11 +166,14 @@ async def list_tools() -> list[MCPTool]: List all available tools """ # Get user authentication from context variable - user_api_key_auth = get_auth_context() + user_api_key_auth, mcp_auth_header = get_auth_context() verbose_logger.debug( f"MCP list_tools - User API Key Auth from context: {user_api_key_auth}" ) - return await _list_mcp_tools(user_api_key_auth) + return await _list_mcp_tools( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, + ) @server.call_tool() async def mcp_server_tool_call( @@ -190,9 +193,15 @@ async def mcp_server_tool_call( HTTPException: If tool not found or arguments missing """ # Validate arguments + user_api_key_auth, mcp_auth_header = get_auth_context() + verbose_logger.debug( + f"MCP mcp_server_tool_call - User API Key Auth from context: {user_api_key_auth}" + ) response = await call_mcp_tool( name=name, arguments=arguments, + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, ) return response @@ -206,6 +215,7 @@ async def mcp_server_tool_call( async def _list_mcp_tools( user_api_key_auth: Optional[UserAPIKeyAuth] = None, + mcp_auth_header: Optional[str] = None, ) -> List[MCPTool]: """ List all available tools @@ -229,6 +239,7 @@ async def _list_mcp_tools( tools_from_mcp_servers: List[MCPTool] = ( await global_mcp_server_manager.list_tools( user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, ) ) verbose_logger.debug("TOOLS FROM MCP SERVERS: %s", tools_from_mcp_servers) @@ -238,7 +249,11 @@ async def _list_mcp_tools( @client async def call_mcp_tool( - name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any + name: str, + arguments: Optional[Dict[str, Any]] = None, + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + mcp_auth_header: Optional[str] = None, + **kwargs: Any ) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: """ Call a specific tool with the provided arguments @@ -270,7 +285,12 @@ async def call_mcp_tool( # Try managed server tool first if name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping: - return await _handle_managed_mcp_tool(name, arguments) + return await _handle_managed_mcp_tool( + name=name, + arguments=arguments, + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, + ) # Fall back to local tool registry return await _handle_local_mcp_tool(name, arguments) @@ -295,12 +315,17 @@ def _get_standard_logging_mcp_tool_call( ) async def _handle_managed_mcp_tool( - name: str, arguments: Dict[str, Any] + name: str, + arguments: Dict[str, Any], + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + mcp_auth_header: Optional[str] = None, ) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: """Handle tool execution for managed server tools""" call_tool_result = await global_mcp_server_manager.call_tool( name=name, arguments=arguments, + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, ) verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result) return call_tool_result.content @@ -325,11 +350,14 @@ async def handle_streamable_http_mcp( """Handle MCP requests through StreamableHTTP.""" try: # Validate headers and log request info - user_api_key_auth: UserAPIKeyAuth = ( + user_api_key_auth, mcp_auth_header = ( await UserAPIKeyAuthMCP.user_api_key_auth_mcp(scope) ) # Set the auth context variable for easy access in MCP functions - set_auth_context(user_api_key_auth) + set_auth_context( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, + ) # Ensure session managers are initialized if not _SESSION_MANAGERS_INITIALIZED: @@ -346,11 +374,14 @@ async def handle_sse_mcp(scope: Scope, receive: Receive, send: Send) -> None: """Handle MCP requests through SSE.""" try: # Validate headers and log request info - user_api_key_auth: UserAPIKeyAuth = ( + user_api_key_auth, mcp_auth_header = ( await UserAPIKeyAuthMCP.user_api_key_auth_mcp(scope) ) # Set the auth context variable for easy access in MCP functions - set_auth_context(user_api_key_auth) + set_auth_context( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, + ) # Ensure session managers are initialized if not _SESSION_MANAGERS_INITIALIZED: @@ -390,17 +421,31 @@ def get_mcp_server_enabled() -> Dict[str, bool]: ############ Auth Context Functions #################### ######################################################## - def set_auth_context(user_api_key_auth: UserAPIKeyAuth) -> None: - """Set the UserAPIKeyAuth in the auth context variable.""" - auth_user = LiteLLMAuthenticatedUser(user_api_key_auth) + def set_auth_context(user_api_key_auth: UserAPIKeyAuth, mcp_auth_header: Optional[str] = None) -> None: + """ + Set the UserAPIKeyAuth in the auth context variable. + + Args: + user_api_key_auth: UserAPIKeyAuth object + mcp_auth_header: MCP auth header to be passed to the MCP server + """ + auth_user = LiteLLMAuthenticatedUser( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, + ) auth_context_var.set(auth_user) - def get_auth_context() -> Optional[UserAPIKeyAuth]: - """Get the UserAPIKeyAuth from the auth context variable.""" + def get_auth_context() -> Tuple[Optional[UserAPIKeyAuth], Optional[str]]: + """ + Get the UserAPIKeyAuth from the auth context variable. + + Returns: + Tuple[Optional[UserAPIKeyAuth], Optional[str]]: UserAPIKeyAuth object and MCP auth header + """ auth_user = auth_context_var.get() if auth_user and isinstance(auth_user, LiteLLMAuthenticatedUser): - return auth_user.user_api_key_auth - return None + return auth_user.user_api_key_auth, auth_user.mcp_auth_header + return None, None ######################################################## ############ End of Auth Context Functions ############# diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index f0c122690cc8..e214da35d8ba 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -17,6 +17,13 @@ from litellm.types.integrations.slack_alerting import AlertType from litellm.types.llms.openai import AllMessageValues, OpenAIFileObject +from litellm.types.mcp import ( + MCPAuthType, + MCPSpecVersion, + MCPSpecVersionType, + MCPTransport, + MCPTransportType, +) from litellm.types.router import RouterErrors, UpdateRouterConfig from litellm.types.utils import ( CallTypes, @@ -830,32 +837,6 @@ class SpecialMCPServerName(str, enum.Enum): all_team_servers = "all-team-mcpservers" all_proxy_servers = "all-proxy-mcpservers" - -class MCPTransport(str, enum.Enum): - sse = "sse" - http = "http" - - -class MCPSpecVersion(str, enum.Enum): - nov_2024 = "2024-11-05" - mar_2025 = "2025-03-26" - - -class MCPAuth(str, enum.Enum): - none = "none" - api_key = "api_key" - bearer_token = "bearer_token" - basic = "basic" - - -# MCP Literals -MCPTransportType = Literal[MCPTransport.sse, MCPTransport.http] -MCPSpecVersionType = Literal[MCPSpecVersion.nov_2024, MCPSpecVersion.mar_2025] -MCPAuthType = Optional[ - Literal[MCPAuth.none, MCPAuth.api_key, MCPAuth.bearer_token, MCPAuth.basic] -] - - # MCP Proxy Request Types class NewMCPServerRequest(LiteLLMPydanticObjectBase): server_id: Optional[str] = None @@ -2703,6 +2684,7 @@ class SpecialHeaders(enum.Enum): google_ai_studio_authorization = "x-goog-api-key" azure_apim_authorization = "Ocp-Apim-Subscription-Key" custom_litellm_api_key = "x-litellm-api-key" + mcp_auth = "x-mcp-auth" class LitellmDataForBackendLLMCall(TypedDict, total=False): diff --git a/litellm/types/mcp.py b/litellm/types/mcp.py new file mode 100644 index 000000000000..988b0ef8e8af --- /dev/null +++ b/litellm/types/mcp.py @@ -0,0 +1,29 @@ +import enum +from typing import Literal, Optional + +from pydantic import BaseModel, ConfigDict +from typing_extensions import TypedDict + + +class MCPTransport(str, enum.Enum): + sse = "sse" + http = "http" + + +class MCPSpecVersion(str, enum.Enum): + nov_2024 = "2024-11-05" + mar_2025 = "2025-03-26" + +class MCPAuth(str, enum.Enum): + none = "none" + api_key = "api_key" + bearer_token = "bearer_token" + basic = "basic" + + +# MCP Literals +MCPTransportType = Literal[MCPTransport.sse, MCPTransport.http] +MCPSpecVersionType = Literal[MCPSpecVersion.nov_2024, MCPSpecVersion.mar_2025] +MCPAuthType = Optional[ + Literal[MCPAuth.none, MCPAuth.api_key, MCPAuth.bearer_token, MCPAuth.basic] +] diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 9edbc80c649c..2a7c5e97366b 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -16,9 +16,9 @@ class MCPServer(BaseModel): server_id: str name: str url: str - # TODO: alter the types to be the Literal explicit transport: MCPTransportType spec_version: MCPSpecVersionType auth_type: Optional[MCPAuthType] = None + authentication_token: Optional[str] = None mcp_info: Optional[MCPInfo] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/mcp_tests/test_mcp_client_unit.py b/tests/mcp_tests/test_mcp_client_unit.py new file mode 100644 index 000000000000..e0efea89b3b4 --- /dev/null +++ b/tests/mcp_tests/test_mcp_client_unit.py @@ -0,0 +1,169 @@ +""" +Unit tests for the MCPClient class - critical functionality only. +""" +import base64 +import os +import sys +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +# Add the project root to the path +sys.path.insert(0, os.path.abspath("../../..")) + +from litellm.experimental_mcp_client.client import MCPClient +from litellm.types.mcp import MCPAuth, MCPTransport +from mcp.types import Tool as MCPTool, CallToolResult as MCPCallToolResult + + +class TestMCPClientUnitTests: + """Unit tests for MCPClient functionality.""" + + def test_init_with_auth(self): + """Test initialization with authentication.""" + client = MCPClient( + server_url="http://example.com", + transport_type=MCPTransport.sse, + auth_type=MCPAuth.bearer_token, + auth_value="test_token", + timeout=30.0 + ) + assert client.server_url == "http://example.com" + assert client.transport_type == MCPTransport.sse + assert client.auth_type == MCPAuth.bearer_token + assert client.timeout == 30.0 + assert client._mcp_auth_value == "test_token" + + def test_get_auth_headers(self): + """Test authentication header generation for different auth types.""" + # Bearer token + client = MCPClient( + "http://example.com", + auth_type=MCPAuth.bearer_token, + auth_value="test_token" + ) + headers = client._get_auth_headers() + assert headers == {"Authorization": "Bearer test_token"} + + # Basic auth + client = MCPClient( + "http://example.com", + auth_type=MCPAuth.basic, + auth_value="user:pass" + ) + expected_encoded = base64.b64encode("user:pass".encode("utf-8")).decode() + headers = client._get_auth_headers() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + # API key + client = MCPClient( + "http://example.com", + auth_type=MCPAuth.api_key, + auth_value="api_key_123" + ) + headers = client._get_auth_headers() + assert headers == {"X-API-Key": "api_key_123"} + + @pytest.mark.asyncio + @patch('litellm.experimental_mcp_client.client.streamablehttp_client') + @patch('litellm.experimental_mcp_client.client.ClientSession') + async def test_connect(self, mock_session_class, mock_transport): + """Test connecting to MCP server with authentication.""" + # Setup mocks + mock_transport_ctx = AsyncMock() + mock_transport.return_value = mock_transport_ctx + mock_transport_instance = MagicMock() + mock_transport_ctx.__aenter__ = AsyncMock(return_value=mock_transport_instance) + + mock_session_ctx = AsyncMock() + mock_session_class.return_value = mock_session_ctx + mock_session_instance = AsyncMock() + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session_instance) + + client = MCPClient( + "http://example.com", + auth_type=MCPAuth.bearer_token, + auth_value="test_token" + ) + await client.connect() + + # Verify transport was created with auth headers + call_args = mock_transport.call_args + assert call_args[1]['headers'] == {"Authorization": "Bearer test_token"} + + # Verify session was initialized + mock_session_instance.initialize.assert_called_once() + assert client._session == mock_session_instance + + @pytest.mark.asyncio + @patch('litellm.experimental_mcp_client.client.streamablehttp_client') + @patch('litellm.experimental_mcp_client.client.ClientSession') + async def test_list_tools(self, mock_session_class, mock_transport): + """Test listing tools from the server.""" + # Setup mocks + mock_transport_ctx = AsyncMock() + mock_transport.return_value = mock_transport_ctx + mock_transport_instance = MagicMock() + mock_transport_ctx.__aenter__ = AsyncMock(return_value=mock_transport_instance) + + mock_session_ctx = AsyncMock() + mock_session_class.return_value = mock_session_ctx + mock_session_instance = AsyncMock() + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session_instance) + + mock_tools = [ + MCPTool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"arg1": {"type": "string"}}, + "required": ["arg1"] + } + ) + ] + mock_result = MagicMock() + mock_result.tools = mock_tools + mock_session_instance.list_tools.return_value = mock_result + + client = MCPClient("http://example.com") + result = await client.list_tools() + + assert result == mock_tools + mock_session_instance.initialize.assert_called_once() + mock_session_instance.list_tools.assert_called_once() + + @pytest.mark.asyncio + @patch('litellm.experimental_mcp_client.client.streamablehttp_client') + @patch('litellm.experimental_mcp_client.client.ClientSession') + async def test_call_tool(self, mock_session_class, mock_transport): + """Test calling a tool.""" + from mcp.types import CallToolRequestParams + + # Setup mocks + mock_transport_ctx = AsyncMock() + mock_transport.return_value = mock_transport_ctx + mock_transport_instance = MagicMock() + mock_transport_ctx.__aenter__ = AsyncMock(return_value=mock_transport_instance) + + mock_session_ctx = AsyncMock() + mock_session_class.return_value = mock_session_ctx + mock_session_instance = AsyncMock() + mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session_instance) + + mock_result = MCPCallToolResult(content=[]) + mock_session_instance.call_tool.return_value = mock_result + + client = MCPClient("http://example.com") + params = CallToolRequestParams(name="test_tool", arguments={"arg1": "value1"}) + result = await client.call_tool(params) + + assert result == mock_result + mock_session_instance.initialize.assert_called_once() + mock_session_instance.call_tool.assert_called_once_with( + name="test_tool", + arguments={"arg1": "value1"} + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 3c009c54f852..3115ab80ef57 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -101,26 +101,17 @@ async def test_mcp_http_transport_list_tools_mock(): ) ] - # Mock the session and its methods - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=mock_tools)) - - # Create an async context manager mock for streamablehttp_client - @asynccontextmanager - async def mock_streamablehttp_client(url): - read_stream = AsyncMock() - write_stream = AsyncMock() - get_session_id = MagicMock(return_value="test-session-123") - yield (read_stream, write_stream, get_session_id) - - # Create an async context manager mock for ClientSession - @asynccontextmanager - async def mock_client_session(read_stream, write_stream): - yield mock_session - - with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.streamablehttp_client', mock_streamablehttp_client), \ - patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.ClientSession', mock_client_session): + # Create a mock MCPClient that returns our test tools + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tools) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + # Mock the MCPClient constructor to return our mock + def mock_client_constructor(*args, **kwargs): + return mock_client + + with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load server config with HTTP transport test_manager.load_servers_from_config({ @@ -139,9 +130,9 @@ async def mock_client_session(read_stream, write_stream): assert tools[0].name == "gmail_send_email" assert tools[1].name == "calendar_create_event" - # Verify session methods were called - mock_session.initialize.assert_called_once() - mock_session.list_tools.assert_called_once() + # Verify client methods were called + mock_client.__aenter__.assert_called() + mock_client.list_tools.assert_called_once() # Verify tool mapping was updated assert test_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] == "test_http_server" @@ -166,26 +157,17 @@ async def test_mcp_http_transport_call_tool_mock(): isError=False ) - # Mock the session and its methods - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - mock_session.call_tool = AsyncMock(return_value=mock_result) - - # Create an async context manager mock for streamablehttp_client - @asynccontextmanager - async def mock_streamablehttp_client(url): - read_stream = AsyncMock() - write_stream = AsyncMock() - get_session_id = MagicMock(return_value="test-session-456") - yield (read_stream, write_stream, get_session_id) - - # Create an async context manager mock for ClientSession - @asynccontextmanager - async def mock_client_session(read_stream, write_stream): - yield mock_session - - with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.streamablehttp_client', mock_streamablehttp_client), \ - patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.ClientSession', mock_client_session): + # Create a mock MCPClient that returns our test result + mock_client = AsyncMock() + mock_client.call_tool = AsyncMock(return_value=mock_result) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + # Mock the MCPClient constructor to return our mock + def mock_client_constructor(*args, **kwargs): + return mock_client + + with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load server config with HTTP transport test_manager.load_servers_from_config({ @@ -216,16 +198,9 @@ async def mock_client_session(read_stream, write_stream): assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Email sent successfully to test@example.com" - # Verify session methods were called - mock_session.initialize.assert_called_once() - mock_session.call_tool.assert_called_once_with( - "gmail_send_email", - { - "to": "test@example.com", - "subject": "Test Subject", - "body": "Test email body" - } - ) + # Verify client methods were called + mock_client.__aenter__.assert_called() + mock_client.call_tool.assert_called_once() @pytest.mark.asyncio @@ -246,26 +221,17 @@ async def test_mcp_http_transport_call_tool_error_mock(): isError=True ) - # Mock the session and its methods - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - mock_session.call_tool = AsyncMock(return_value=mock_error_result) - - # Create an async context manager mock for streamablehttp_client - @asynccontextmanager - async def mock_streamablehttp_client(url): - read_stream = AsyncMock() - write_stream = AsyncMock() - get_session_id = MagicMock(return_value="test-session-789") - yield (read_stream, write_stream, get_session_id) - - # Create an async context manager mock for ClientSession - @asynccontextmanager - async def mock_client_session(read_stream, write_stream): - yield mock_session - - with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.streamablehttp_client', mock_streamablehttp_client), \ - patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.ClientSession', mock_client_session): + # Create a mock MCPClient that returns our test error result + mock_client = AsyncMock() + mock_client.call_tool = AsyncMock(return_value=mock_error_result) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + # Mock the MCPClient constructor to return our mock + def mock_client_constructor(*args, **kwargs): + return mock_client + + with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient', mock_client_constructor): # Load server config with HTTP transport test_manager.load_servers_from_config({ @@ -292,9 +258,9 @@ async def mock_client_session(read_stream, write_stream): assert isinstance(result.content[0], TextContent) assert "Error: Invalid email address" in result.content[0].text - # Verify session methods were called - mock_session.initialize.assert_called_once() - mock_session.call_tool.assert_called_once() + # Verify client methods were called + mock_client.__aenter__.assert_called() + mock_client.call_tool.assert_called_once() @pytest.mark.asyncio diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py index f0c5c63a1406..8458f9a45ace 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py @@ -116,17 +116,19 @@ async def test_get_allowed_mcp_servers_for_key( mock_find_unique.assert_not_called() @pytest.mark.parametrize( - "headers,expected_api_key", + "headers,expected_api_key,expected_mcp_auth_header", [ # Test case 1: x-litellm-api-key header present ( [(b"x-litellm-api-key", b"test-api-key-123")], "test-api-key-123", + None, ), # Test case 2: Authorization header present (fallback) ( [(b"authorization", b"Bearer test-auth-token")], "Bearer test-auth-token", + None, ), # Test case 3: Both headers present (primary should win) ( @@ -135,22 +137,40 @@ async def test_get_allowed_mcp_servers_for_key( (b"authorization", b"Bearer fallback-token"), ], "primary-key", + None, ), # Test case 4: Case insensitive headers ( [(b"X-LITELLM-API-KEY", b"case-insensitive-key")], "case-insensitive-key", + None, ), # Test case 5: No relevant headers ( [(b"content-type", b"application/json")], "", + None, ), # Test case 6: Empty headers - ([], ""), + ([], "", None), + # Test case 7: MCP auth header present + ( + [ + (b"x-litellm-api-key", b"test-api-key-123"), + (b"x-mcp-auth", b"mcp-auth-token"), + ], + "test-api-key-123", + "mcp-auth-token", + ), + # Test case 8: Only MCP auth header present (no API key) + ( + [(b"x-mcp-auth", b"mcp-auth-token")], + "", + "mcp-auth-token", + ), ], ) - async def test_user_api_key_auth_mcp(self, headers, expected_api_key): + async def test_user_api_key_auth_mcp(self, headers, expected_api_key, expected_mcp_auth_header): """Test user_api_key_auth_mcp method with various header scenarios""" # Create ASGI scope with headers @@ -174,10 +194,11 @@ async def test_user_api_key_auth_mcp(self, headers, expected_api_key): mock_user_api_key_auth.return_value = mock_auth_result # Call the method - result = await UserAPIKeyAuthMCP.user_api_key_auth_mcp(scope) + auth_result, mcp_auth_header = await UserAPIKeyAuthMCP.user_api_key_auth_mcp(scope) - # Assert the result - assert result == mock_auth_result + # Assert the results + assert auth_result == mock_auth_result + assert mcp_auth_header == expected_mcp_auth_header # Verify user_api_key_auth was called with correct parameters mock_user_api_key_auth.assert_called_once() diff --git a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py index b34bf4ae82c5..ea5e8d9f3d6b 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py @@ -18,11 +18,11 @@ from litellm.proxy._types import ( LiteLLM_MCPServerTable, LitellmUserRoles, - MCPAuth, MCPSpecVersion, MCPTransport, UserAPIKeyAuth, ) +from litellm.types.mcp import MCPAuth from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer