diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 55b58b564..f36ab40ff 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -107,7 +107,7 @@ async def handle_sse(request): ) async def handle_messages(request): - await sse.handle_post_message(request.scope, request.receive, request._send) + return await sse.handle_post_message(request.scope, request.receive) starlette_app = Starlette( debug=True, diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 520c8871b..988c87194 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -64,7 +64,7 @@ async def handle_sse(request): ) async def handle_messages(request): - await sse.handle_post_message(request.scope, request.receive, request._send) + return await sse.handle_post_message(request.scope, request.receive) starlette_app = Starlette( debug=True, diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index a0f5b7b76..b50179bce 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -78,7 +78,7 @@ async def handle_sse(request): ) async def handle_messages(request): - await sse.handle_post_message(request.scope, request.receive, request._send) + return await sse.handle_post_message(request.scope, request.receive) starlette_app = Starlette( debug=True, diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 65369a785..8d5bce636 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -24,7 +24,7 @@ async def handle_sse(request): ) async def handle_messages(request): - await sse.handle_post_message(request.scope, request.receive, request._send) + return await sse.handle_post_message(request.scope, request.receive) # Create and run Starlette app starlette_app = Starlette(routes=routes) @@ -133,31 +133,26 @@ async def sse_writer(): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message( - self, scope: Scope, receive: Receive, send: Send - ) -> None: + async def handle_post_message(self, scope: Scope, receive: Receive) -> Response: logger.debug("Handling POST message") request = Request(scope, receive) session_id_param = request.query_params.get("session_id") if session_id_param is None: logger.warning("Received request without session_id") - response = Response("session_id is required", status_code=400) - return await response(scope, receive, send) + return Response("session_id is required", status_code=400) try: session_id = UUID(hex=session_id_param) logger.debug(f"Parsed session ID: {session_id}") except ValueError: logger.warning(f"Received invalid session ID: {session_id_param}") - response = Response("Invalid session ID", status_code=400) - return await response(scope, receive, send) + return Response("Invalid session ID", status_code=400) writer = self._read_stream_writers.get(session_id) if not writer: logger.warning(f"Could not find session for ID: {session_id}") - response = Response("Could not find session", status_code=404) - return await response(scope, receive, send) + return Response("Could not find session", status_code=404) json = await request.json() logger.debug(f"Received JSON: {json}") @@ -167,12 +162,9 @@ async def handle_post_message( logger.debug(f"Validated client message: {message}") except ValidationError as err: logger.error(f"Failed to parse message: {err}") - response = Response("Could not parse message", status_code=400) - await response(scope, receive, send) await writer.send(err) - return + return Response("Could not parse message", status_code=400) logger.debug(f"Sending message to writer: {message}") - response = Response("Accepted", status_code=202) - await response(scope, receive, send) await writer.send(message) + return Response("Accepted", status_code=202) diff --git a/uv.lock b/uv.lock index b3fb27d81..5f99abe32 100644 --- a/uv.lock +++ b/uv.lock @@ -171,7 +171,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.0.0.dev0" +version = "1.0.1.dev0" source = { editable = "." } dependencies = [ { name = "anyio" },