diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go new file mode 100644 index 00000000..2c2e1a13 --- /dev/null +++ b/pkg/github/notifications.go @@ -0,0 +1,243 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// getNotifications creates a tool to list notifications for the current user. +func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_notifications", + mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")), + mcp.WithBoolean("all", + mcp.Description("If true, show notifications marked as read. Default: false"), + ), + mcp.WithBoolean("participating", + mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"), + ), + mcp.WithString("since", + mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"), + ), + mcp.WithString("before", + mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"), + ), + mcp.WithNumber("per_page", + mcp.Description("Results per page (max 100). Default: 30"), + ), + mcp.WithNumber("page", + mcp.Description("Page number of the results to fetch. Default: 1"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + // Extract optional parameters with defaults + all, err := OptionalBoolParamWithDefault(request, "all", false) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + participating, err := OptionalBoolParamWithDefault(request, "participating", false) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + since, err := OptionalStringParamWithDefault(request, "since", "") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + before, err := OptionalStringParam(request, "before") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + perPage, err := OptionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + page, err := OptionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Build options + opts := &github.NotificationListOptions{ + All: all, + Participating: participating, + ListOptions: github.ListOptions{ + Page: page, + PerPage: perPage, + }, + } + + // Parse time parameters if provided + if since != "" { + sinceTime, err := time.Parse(time.RFC3339, since) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil + } + opts.Since = sinceTime + } + + if before != "" { + beforeTime, err := time.Parse(time.RFC3339, before) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil + } + opts.Before = beforeTime + } + + // Call GitHub API + notifications, resp, err := client.Activity.ListNotifications(ctx, opts) + if err != nil { + return nil, fmt.Errorf("failed to get notifications: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil + } + + // Marshal response to JSON + r, err := json.Marshal(notifications) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// ManageNotifications creates a tool to manage notifications (mark as read, mark all as read, or mark as done). +func ManageNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("manage_notifications", + mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATIONS_DESCRIPTION", "Manage notifications (mark as read, mark all as read, or mark as done)")), + mcp.WithString("action", + mcp.Required(), + mcp.Description("The action to perform: 'mark_read', 'mark_all_read', or 'mark_done'"), + ), + mcp.WithString("threadID", + mcp.Description("The ID of the notification thread (required for 'mark_read' and 'mark_done')"), + ), + mcp.WithString("lastReadAt", + mcp.Description("Describes the last point that notifications were checked (optional, for 'mark_all_read'). Default: Now"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + action, err := requiredParam[string](request, "action") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + switch action { + case "mark_read": + threadID, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + resp, err := client.Activity.MarkThreadRead(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to mark notification as read: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("Notification marked as read"), nil + + case "mark_done": + threadIDStr, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + threadID, err := strconv.ParseInt(threadIDStr, 10, 64) + if err != nil { + return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil + } + + resp, err := client.Activity.MarkThreadDone(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to mark notification as done: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("Notification marked as done"), nil + + case "mark_all_read": + lastReadAt, err := OptionalStringParam(request, "lastReadAt") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var markReadOptions github.Timestamp + if lastReadAt != "" { + lastReadTime, err := time.Parse(time.RFC3339, lastReadAt) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil + } + markReadOptions = github.Timestamp{ + Time: lastReadTime, + } + } + + resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions) + if err != nil { + return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("All notifications marked as read"), nil + + default: + return mcp.NewToolResultError("Invalid action: must be 'mark_read', 'mark_all_read', or 'mark_done'"), nil + } + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index e4c24171..c51e4732 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -143,6 +143,47 @@ func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e return v, nil } +// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalParam, but it also takes a default value. +func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) { + v, err := OptionalParam[bool](r, p) + if err != nil { + return false, err + } + if !v { + return d, nil + } + return v, nil +} + +// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) { + v, err := OptionalParam[string](r, p) + if err != nil { + return "", err + } + if v == "" { + return "", nil + } + return v, nil +} + +// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalParam, but it also takes a default value. +func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) { + v, err := OptionalParam[string](r, p) + if err != nil { + return "", err + } + if v == "" { + return d, nil + } + return v, nil +} + // OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 35dabaef..4d4889a8 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -78,6 +78,15 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)), toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)), ) + + notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools"). + AddReadTools( + toolsets.NewServerTool(GetNotifications(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(ManageNotifications(getClient, t)), + ) + // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") @@ -88,6 +97,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(pullRequests) tsg.AddToolset(codeSecurity) tsg.AddToolset(secretProtection) + tsg.AddToolset(notifications) tsg.AddToolset(experiments) // Enable the requested features