diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index a0dd033d6..7cdd2667b 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -68,7 +68,7 @@ async def main(): import logging import warnings from collections.abc import Awaitable, Callable -from typing import Any, Sequence +from typing import Any, Optional, Sequence from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -379,6 +379,20 @@ async def handler(req: types.ProgressNotification): return decorator + def cancellation_notification(self): + def decorator( + func: Callable[[Optional[int], Optional[str]], Awaitable[None]], + ): + logger.debug("Registering handler for ProgressNotification") + + async def handler(req: types.CancellationNotification): + await func(req.params.requestId, req.params.reason) + + self.notification_handlers[types.CancellationNotification] = handler + return func + + return decorator + def completion(self): """Provides completions for prompts and resource templates""" diff --git a/src/mcp/types.py b/src/mcp/types.py index a2b897403..6eda16d76 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,4 +1,4 @@ -from typing import Any, Generic, Literal, TypeVar +from typing import Any, Generic, Literal, Optional, TypeVar from pydantic import BaseModel, ConfigDict, FileUrl, RootModel from pydantic.networks import AnyUrl @@ -325,6 +325,21 @@ class ProgressNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") +class CancelledParams(BaseModel): + requestId: Optional[int] = None + reason: Optional[str] = "" + + +class CancellationNotification(Notification): + """ + An out-of-band notification used to inform the receiver of a progress update for a + long-running request. + """ + + method: Literal["cancelled"] + params: CancelledParams + + class ProgressNotification(Notification): """ An out-of-band notification used to inform the receiver of a progress update for a @@ -997,7 +1012,10 @@ class ClientRequest( class ClientNotification( RootModel[ - ProgressNotification | InitializedNotification | RootsListChangedNotification + ProgressNotification + | InitializedNotification + | RootsListChangedNotification + | CancellationNotification ] ): pass @@ -1019,6 +1037,7 @@ class ServerNotification( | ResourceListChangedNotification | ToolListChangedNotification | PromptListChangedNotification + | CancellationNotification ] ): pass