diff --git a/pkg/api/authz/authz.go b/pkg/api/authz/authz.go index bbcf6df2d8..c5f93cd813 100644 --- a/pkg/api/authz/authz.go +++ b/pkg/api/authz/authz.go @@ -338,6 +338,7 @@ func defaultRules(devMode bool, registryNoAuth bool) []rule { f = (*fake)(nil) ) + // Set up the static rules for group := range staticRules { rule := rule{ group: group, @@ -349,22 +350,38 @@ func defaultRules(devMode bool, registryNoAuth bool) []rule { rules = append(rules, rule) } - var registryRule rule + // Set up the MCP Registry API rules + var ( + registryRule rule + acrRegistryRule rule + ) + if registryNoAuth { registryRule = rule{ group: anyGroup, mux: http.NewServeMux(), } + acrRegistryRule = rule{ + group: anyGroup, + mux: http.NewServeMux(), + } } else { registryRule = rule{ group: types.GroupBasic, mux: http.NewServeMux(), } + acrRegistryRule = rule{ + group: types.GroupBasic, + mux: http.NewServeMux(), + } } + registryRule.mux.Handle("GET /v0.1", f) registryRule.mux.Handle("GET /v0.1/", f) - rules = append(rules, registryRule) + acrRegistryRule.mux.Handle("GET /acr-registry/", f) + rules = append(rules, registryRule, acrRegistryRule) + // Set up the Dev Mode rules if devMode { for group := range devModeRules { rule := rule{ diff --git a/pkg/api/handlers/registry/acr_handler.go b/pkg/api/handlers/registry/acr_handler.go new file mode 100644 index 0000000000..1b46b586e2 --- /dev/null +++ b/pkg/api/handlers/registry/acr_handler.go @@ -0,0 +1,631 @@ +package registry + +import ( + "fmt" + "net/http" + "strings" + + "github.com/obot-platform/obot/apiclient/types" + "github.com/obot-platform/obot/pkg/accesscontrolrule" + "github.com/obot-platform/obot/pkg/api" + "github.com/obot-platform/obot/pkg/api/handlers" + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + "github.com/obot-platform/obot/pkg/system" + "k8s.io/apimachinery/pkg/fields" + kclient "sigs.k8s.io/controller-runtime/pkg/client" +) + +// ACRHandler handles the per-ACR registry API endpoints +type ACRHandler struct { + acrHelper *accesscontrolrule.Helper + serverURL string + registryNoAuth bool + mimeFetcher *mimeFetcher +} + +func NewACRHandler(acrHelper *accesscontrolrule.Helper, serverURL string, registryNoAuth bool) *ACRHandler { + return &ACRHandler{ + acrHelper: acrHelper, + serverURL: serverURL, + registryNoAuth: registryNoAuth, + mimeFetcher: newMimeFetcher(), + } +} + +// ListServers handles GET /mcp-registry/{acr_id}/v0.1/servers +func (h *ACRHandler) ListServers(req api.Context) error { + acrID := req.PathValue("acr_id") + if acrID == "" { + return h.notFoundError("access control rule ID is required") + } + + // Fetch the ACR + var acr v1.AccessControlRule + if err := req.Get(&acr, acrID); err != nil { + return h.notFoundError("access control rule not found") + } + + // Check authorization + if !h.isAuthorized(req, acr) { + return h.notFoundError("access control rule not found") + } + + // Parse query parameters + cursor := req.URL.Query().Get("cursor") + limit := parseLimit(req.URL.Query().Get("limit")) + search := req.URL.Query().Get("search") + + reverseDNS, err := ReverseDNSFromURL(h.serverURL) + if err != nil { + return fmt.Errorf("failed to generate reverse DNS: %w", err) + } + + // Collect servers based on ACR resources + servers, err := h.collectServersFromACR(req, acr, reverseDNS) + if err != nil { + return err + } + + // Apply search filter if provided + if search != "" { + servers = filterServersBySearch(servers, search) + } + + // Apply pagination + response := paginateServers(servers, cursor, limit) + + return req.Write(response) +} + +// isAuthorized checks if the request is authorized to access this ACR's registry +func (h *ACRHandler) isAuthorized(req api.Context, acr v1.AccessControlRule) bool { + if h.registryNoAuth { + // When auth is OFF, only allow if ACR has wildcard subject + return h.hasWildcardSubject(acr) + } + + // When auth is ON, check if user is targeted by ACR subjects + return h.userMatchesSubjects(req, acr) +} + +// hasWildcardSubject checks if the ACR targets all users via wildcard +func (h *ACRHandler) hasWildcardSubject(acr v1.AccessControlRule) bool { + for _, subject := range acr.Spec.Manifest.Subjects { + if subject.Type == types.SubjectTypeSelector && subject.ID == "*" { + return true + } + } + return false +} + +// userMatchesSubjects checks if the current user matches any of the ACR's subjects +func (h *ACRHandler) userMatchesSubjects(req api.Context, acr v1.AccessControlRule) bool { + userID := req.User.GetUID() + groups := authGroupSet(req.User) + + for _, subject := range acr.Spec.Manifest.Subjects { + switch subject.Type { + case types.SubjectTypeUser: + if subject.ID == userID { + return true + } + case types.SubjectTypeGroup: + if _, ok := groups[subject.ID]; ok { + return true + } + case types.SubjectTypeSelector: + if subject.ID == "*" { + return true + } + } + } + return false +} + +// collectServersFromACR collects servers/entries based on ACR resources +func (h *ACRHandler) collectServersFromACR(req api.Context, acr v1.AccessControlRule, reverseDNS string) ([]types.RegistryServerResponse, error) { + var result []types.RegistryServerResponse + userID := req.User.GetUID() + + // Track catalog entries that have been "overridden" by deployed servers + addedCatalogEntries := make(map[string]bool) + + // Determine scope: catalog or workspace + catalogID := acr.Spec.MCPCatalogID + workspaceID := acr.Spec.PowerUserWorkspaceID + + // If auth is ON, first check for user's deployed servers from entries in this ACR + if !h.registryNoAuth { + deployedServers, err := h.collectDeployedServersFromACREntries(req, acr, reverseDNS, userID) + if err != nil { + return nil, err + } + result = append(result, deployedServers...) + + // Mark which catalog entries are now "overridden" + for _, server := range deployedServers { + // The server name in registry format contains the original entry name + // We track by the MCPServerCatalogEntryName from the server spec + addedCatalogEntries[server.Server.Name] = true + } + } + + // Process each resource in the ACR + for _, resource := range acr.Spec.Manifest.Resources { + switch resource.Type { + case types.ResourceTypeMCPServerCatalogEntry: + // Skip if user has already deployed from this entry + if addedCatalogEntries[resource.ID] { + continue + } + server, err := h.fetchCatalogEntry(req, resource.ID, catalogID, workspaceID, reverseDNS) + if err != nil { + // Skip entries that can't be fetched + continue + } + result = append(result, server) + + case types.ResourceTypeMCPServer: + server, err := h.fetchMCPServer(req, resource.ID, catalogID, workspaceID, reverseDNS, userID) + if err != nil { + // Skip servers that can't be fetched + continue + } + result = append(result, server) + + case types.ResourceTypeSelector: + if resource.ID == "*" { + // Wildcard: include all servers/entries in the scope + if catalogID != "" { + servers, err := h.collectAllFromCatalog(req, catalogID, reverseDNS, userID, addedCatalogEntries) + if err != nil { + return nil, err + } + result = append(result, servers...) + } else if workspaceID != "" { + servers, err := h.collectAllFromWorkspace(req, workspaceID, reverseDNS, userID, addedCatalogEntries) + if err != nil { + return nil, err + } + result = append(result, servers...) + } + } + } + } + + return result, nil +} + +// collectDeployedServersFromACREntries finds user's deployed servers that came from entries in this ACR +func (h *ACRHandler) collectDeployedServersFromACREntries(req api.Context, acr v1.AccessControlRule, reverseDNS, userID string) ([]types.RegistryServerResponse, error) { + var result []types.RegistryServerResponse + + // Build set of catalog entry names in this ACR + entryNames := make(map[string]bool) + hasWildcard := false + for _, resource := range acr.Spec.Manifest.Resources { + if resource.Type == types.ResourceTypeMCPServerCatalogEntry { + entryNames[resource.ID] = true + } else if resource.Type == types.ResourceTypeSelector && resource.ID == "*" { + hasWildcard = true + } + } + + // If no entries and no wildcard, nothing to check + if len(entryNames) == 0 && !hasWildcard { + return result, nil + } + + // List user's personal servers + var serverList v1.MCPServerList + if err := req.Storage.List(req.Context(), &serverList, &kclient.ListOptions{ + Namespace: system.DefaultNamespace, + FieldSelector: fields.SelectorFromSet(map[string]string{ + "spec.userID": userID, + "spec.mcpCatalogID": "", + "spec.powerUserWorkspaceID": "", + }), + }); err != nil { + return nil, fmt.Errorf("failed to list personal servers: %w", err) + } + + for _, server := range serverList.Items { + // Skip templates and components + if server.Spec.Template || server.Spec.CompositeName != "" { + continue + } + + // Check if this server was deployed from an entry in this ACR + entryName := server.Spec.MCPServerCatalogEntryName + if entryName == "" { + continue + } + + // Fetch the catalog entry to verify its scope matches the ACR's scope + var entry v1.MCPServerCatalogEntry + if err := req.Get(&entry, entryName); err != nil { + // Entry not found, skip this server + continue + } + + // Verify the entry's scope matches the ACR's scope + if acr.Spec.MCPCatalogID != "" { + // ACR is catalog-scoped, entry must be in the same catalog + if entry.Spec.MCPCatalogName != acr.Spec.MCPCatalogID { + continue + } + } else if acr.Spec.PowerUserWorkspaceID != "" { + // ACR is workspace-scoped, entry must be in the same workspace + if entry.Spec.PowerUserWorkspaceID != acr.Spec.PowerUserWorkspaceID { + continue + } + } else { + // ACR has no scope, this shouldn't happen but skip to be safe + continue + } + + // Check if entry is in ACR (directly or via wildcard) + inACR := entryNames[entryName] || hasWildcard + if !inACR { + continue + } + + // Get slug for this server + slug, err := handlers.SlugForMCPServer(req.Context(), req.Storage, server, userID, "", "") + if err != nil { + continue + } + + // Get credentials + credEnv, _ := h.getCredentialsForServer(req, server, userID, "", "") + + converted, err := ConvertMCPServerToRegistry(req.Context(), server, credEnv, h.serverURL, slug, reverseDNS, userID, h.mimeFetcher) + if err != nil { + continue + } + result = append(result, converted) + } + + return result, nil +} + +// fetchCatalogEntry fetches a single catalog entry and converts to registry format +func (h *ACRHandler) fetchCatalogEntry(req api.Context, entryName, catalogID, workspaceID, reverseDNS string) (types.RegistryServerResponse, error) { + var entry v1.MCPServerCatalogEntry + if err := req.Get(&entry, entryName); err != nil { + return types.RegistryServerResponse{}, fmt.Errorf("entry not found") + } + + // Verify scope matches + if catalogID != "" && entry.Spec.MCPCatalogName != catalogID { + return types.RegistryServerResponse{}, fmt.Errorf("entry not in catalog") + } + if workspaceID != "" && entry.Spec.PowerUserWorkspaceID != workspaceID { + return types.RegistryServerResponse{}, fmt.Errorf("entry not in workspace") + } + + return ConvertMCPServerCatalogEntryToRegistry(req.Context(), entry, h.serverURL, reverseDNS, h.mimeFetcher) +} + +// fetchMCPServer fetches a single MCP server and converts to registry format +func (h *ACRHandler) fetchMCPServer(req api.Context, serverName, catalogID, workspaceID, reverseDNS, userID string) (types.RegistryServerResponse, error) { + var server v1.MCPServer + if err := req.Get(&server, serverName); err != nil { + return types.RegistryServerResponse{}, fmt.Errorf("server not found") + } + + // Skip templates and components + if server.Spec.Template || server.Spec.CompositeName != "" { + return types.RegistryServerResponse{}, fmt.Errorf("server is template or component") + } + + // Verify scope matches + if catalogID != "" && server.Spec.MCPCatalogID != catalogID { + return types.RegistryServerResponse{}, fmt.Errorf("server not in catalog") + } + if workspaceID != "" && server.Spec.PowerUserWorkspaceID != workspaceID { + return types.RegistryServerResponse{}, fmt.Errorf("server not in workspace") + } + + // Get slug + var ( + slug string + err error + ) + if catalogID != "" { + slug, err = handlers.SlugForMCPServer(req.Context(), req.Storage, server, "", catalogID, "") + } else if workspaceID != "" { + slug, err = handlers.SlugForMCPServer(req.Context(), req.Storage, server, "", "", workspaceID) + } else { + return types.RegistryServerResponse{}, fmt.Errorf("no scope for server") + } + if err != nil { + return types.RegistryServerResponse{}, fmt.Errorf("failed to generate slug") + } + + // Get credentials + credEnv, _ := h.getCredentialsForServer(req, server, "", catalogID, workspaceID) + + return ConvertMCPServerToRegistry(req.Context(), server, credEnv, h.serverURL, slug, reverseDNS, userID, h.mimeFetcher) +} + +// collectAllFromCatalog collects all entries and servers from a catalog +func (h *ACRHandler) collectAllFromCatalog(req api.Context, catalogID, reverseDNS, userID string, exclude map[string]bool) ([]types.RegistryServerResponse, error) { + var result []types.RegistryServerResponse + + // List catalog entries + var entryList v1.MCPServerCatalogEntryList + if err := req.Storage.List(req.Context(), &entryList, &kclient.ListOptions{ + Namespace: system.DefaultNamespace, + FieldSelector: fields.SelectorFromSet(map[string]string{ + "spec.mcpCatalogName": catalogID, + }), + }); err != nil { + return nil, fmt.Errorf("failed to list catalog entries: %w", err) + } + + for _, entry := range entryList.Items { + if exclude[entry.Name] { + continue + } + converted, err := ConvertMCPServerCatalogEntryToRegistry(req.Context(), entry, h.serverURL, reverseDNS, h.mimeFetcher) + if err != nil { + continue + } + result = append(result, converted) + } + + // List servers in catalog + var serverList v1.MCPServerList + if err := req.Storage.List(req.Context(), &serverList, &kclient.ListOptions{ + Namespace: system.DefaultNamespace, + FieldSelector: fields.SelectorFromSet(map[string]string{ + "spec.mcpCatalogID": catalogID, + }), + }); err != nil { + return nil, fmt.Errorf("failed to list catalog servers: %w", err) + } + + for _, server := range serverList.Items { + if server.Spec.Template || server.Spec.CompositeName != "" { + continue + } + slug, err := handlers.SlugForMCPServer(req.Context(), req.Storage, server, "", catalogID, "") + if err != nil { + continue + } + credEnv, _ := h.getCredentialsForServer(req, server, "", catalogID, "") + converted, err := ConvertMCPServerToRegistry(req.Context(), server, credEnv, h.serverURL, slug, reverseDNS, userID, h.mimeFetcher) + if err != nil { + continue + } + result = append(result, converted) + } + + return result, nil +} + +// collectAllFromWorkspace collects all entries and servers from a workspace +func (h *ACRHandler) collectAllFromWorkspace(req api.Context, workspaceID, reverseDNS, userID string, exclude map[string]bool) ([]types.RegistryServerResponse, error) { + var result []types.RegistryServerResponse + + // List workspace entries + var entryList v1.MCPServerCatalogEntryList + if err := req.Storage.List(req.Context(), &entryList, &kclient.ListOptions{ + Namespace: system.DefaultNamespace, + FieldSelector: fields.SelectorFromSet(map[string]string{ + "spec.powerUserWorkspaceID": workspaceID, + }), + }); err != nil { + return nil, fmt.Errorf("failed to list workspace entries: %w", err) + } + + for _, entry := range entryList.Items { + if exclude[entry.Name] { + continue + } + converted, err := ConvertMCPServerCatalogEntryToRegistry(req.Context(), entry, h.serverURL, reverseDNS, h.mimeFetcher) + if err != nil { + continue + } + result = append(result, converted) + } + + // List workspace servers + var serverList v1.MCPServerList + if err := req.Storage.List(req.Context(), &serverList, &kclient.ListOptions{ + Namespace: system.DefaultNamespace, + FieldSelector: fields.SelectorFromSet(map[string]string{ + "spec.powerUserWorkspaceID": workspaceID, + }), + }); err != nil { + return nil, fmt.Errorf("failed to list workspace servers: %w", err) + } + + for _, server := range serverList.Items { + if server.Spec.Template || server.Spec.CompositeName != "" { + continue + } + slug, err := handlers.SlugForMCPServer(req.Context(), req.Storage, server, "", "", workspaceID) + if err != nil { + continue + } + credEnv, _ := h.getCredentialsForServer(req, server, "", "", workspaceID) + converted, err := ConvertMCPServerToRegistry(req.Context(), server, credEnv, h.serverURL, slug, reverseDNS, userID, h.mimeFetcher) + if err != nil { + continue + } + result = append(result, converted) + } + + return result, nil +} + +// getCredentialsForServer retrieves credentials for a server +func (h *ACRHandler) getCredentialsForServer(req api.Context, server v1.MCPServer, userID, catalogID, workspaceID string) (map[string]string, error) { + var ctx string + if catalogID != "" { + ctx = fmt.Sprintf("%s-%s", catalogID, server.Name) + } else if workspaceID != "" { + ctx = fmt.Sprintf("%s-%s", workspaceID, server.Name) + } else if userID != "" { + ctx = fmt.Sprintf("%s-%s", userID, server.Name) + } else { + ctx = fmt.Sprintf("%s-%s", server.Spec.UserID, server.Name) + } + + revealed, err := req.GPTClient.RevealCredential(req.Context(), []string{ctx}, server.Name) + if err != nil { + return make(map[string]string), nil + } + return revealed.Env, nil +} + +// ListServerVersions handles GET /mcp-registry/{acr_id}/v0.1/servers/{serverName}/versions +func (h *ACRHandler) ListServerVersions(req api.Context) error { + acrID := req.PathValue("acr_id") + serverName := req.PathValue("serverName") + + if acrID == "" { + return h.notFoundError("access control rule ID is required") + } + if serverName == "" { + return h.notFoundError("serverName is required") + } + + // Fetch the ACR + var acr v1.AccessControlRule + if err := req.Get(&acr, acrID); err != nil { + return h.notFoundError("access control rule not found") + } + + // Check authorization + if !h.isAuthorized(req, acr) { + return h.notFoundError("access control rule not found") + } + + // Parse reverse DNS and actual server name + parts := strings.SplitN(serverName, "/", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return h.notFoundError("Invalid server name format. Expected: reverseDNS/serverName") + } + reverseDNS, actualServerName := parts[0], parts[1] + + // Find the server within this ACR's scope + server, err := h.findServerInACR(req, acr, actualServerName, reverseDNS) + if err != nil { + return h.notFoundError("Server not found") + } + + // Return as a ServerList with single item + response := types.RegistryServerList{ + Servers: []types.RegistryServerResponse{server}, + Metadata: &types.RegistryServerListMetadata{ + Count: 1, + }, + } + + return req.Write(response) +} + +// GetServerVersion handles GET /mcp-registry/{acr_id}/v0.1/servers/{serverName}/versions/{version} +func (h *ACRHandler) GetServerVersion(req api.Context) error { + acrID := req.PathValue("acr_id") + serverName := req.PathValue("serverName") + version := req.PathValue("version") + + if acrID == "" { + return h.notFoundError("access control rule ID is required") + } + if serverName == "" { + return h.notFoundError("serverName is required") + } + if version == "" { + return h.notFoundError("version is required") + } + + // Only support "latest" version + if version != "latest" { + return h.notFoundError("Version not found. Only 'latest' is supported.") + } + + // Fetch the ACR + var acr v1.AccessControlRule + if err := req.Get(&acr, acrID); err != nil { + return h.notFoundError("access control rule not found") + } + + // Check authorization + if !h.isAuthorized(req, acr) { + return h.notFoundError("access control rule not found") + } + + // Parse reverse DNS and actual server name + parts := strings.SplitN(serverName, "/", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return h.notFoundError("Invalid server name format. Expected: reverseDNS/serverName") + } + reverseDNS, actualServerName := parts[0], parts[1] + + // Find the server within this ACR's scope + server, err := h.findServerInACR(req, acr, actualServerName, reverseDNS) + if err != nil { + return h.notFoundError("Server not found") + } + + return req.Write(server) +} + +// findServerInACR finds a specific server within the ACR's resource scope +func (h *ACRHandler) findServerInACR(req api.Context, acr v1.AccessControlRule, serverName, reverseDNS string) (types.RegistryServerResponse, error) { + userID := req.User.GetUID() + catalogID := acr.Spec.MCPCatalogID + workspaceID := acr.Spec.PowerUserWorkspaceID + + // Check if this server is in the ACR's resources + if !h.serverInACRResources(acr, serverName) { + return types.RegistryServerResponse{}, fmt.Errorf("server not in ACR") + } + + // Determine if this is an MCPServer or MCPServerCatalogEntry + if system.IsMCPServerID(serverName) { + return h.fetchMCPServer(req, serverName, catalogID, workspaceID, reverseDNS, userID) + } + return h.fetchCatalogEntry(req, serverName, catalogID, workspaceID, reverseDNS) +} + +// serverInACRResources checks if a server/entry is included in the ACR's resources +func (h *ACRHandler) serverInACRResources(acr v1.AccessControlRule, serverName string) bool { + for _, resource := range acr.Spec.Manifest.Resources { + switch resource.Type { + case types.ResourceTypeMCPServer, types.ResourceTypeMCPServerCatalogEntry: + if resource.ID == serverName { + return true + } + case types.ResourceTypeSelector: + if resource.ID == "*" { + return true + } + } + } + return false +} + +// notFoundError returns a standard 404 error +func (h *ACRHandler) notFoundError(detail string) error { + return &types.ErrHTTP{ + Code: http.StatusNotFound, + Message: fmt.Sprintf(`{"title":"Not Found","status":404,"detail":"%s"}`, detail), + } +} + +// authGroupSet extracts auth groups from user info (reuse from helper.go) +func authGroupSet(user interface{ GetExtra() map[string][]string }) map[string]struct{} { + extra := user.GetExtra() + groups := extra["auth_provider_groups"] + set := make(map[string]struct{}, len(groups)) + for _, group := range groups { + set[group] = struct{}{} + } + return set +} diff --git a/pkg/api/handlers/registry/acr_handler_test.go b/pkg/api/handlers/registry/acr_handler_test.go new file mode 100644 index 0000000000..025760e3e8 --- /dev/null +++ b/pkg/api/handlers/registry/acr_handler_test.go @@ -0,0 +1,588 @@ +package registry + +import ( + "testing" + + "github.com/obot-platform/obot/apiclient/types" + "github.com/obot-platform/obot/pkg/api" + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + "k8s.io/apiserver/pkg/authentication/user" +) + +// mockUser implements user.Info for testing +type mockUser struct { + uid string + groups []string + extra map[string][]string +} + +func (m *mockUser) GetName() string { return m.uid } +func (m *mockUser) GetUID() string { return m.uid } +func (m *mockUser) GetGroups() []string { return m.groups } +func (m *mockUser) GetExtra() map[string][]string { return m.extra } +func (m *mockUser) GetAuthID() string { return "" } +func (m *mockUser) GetAuthentications() []string { return nil } +func (m *mockUser) GetPersistentIdentity() (string, bool) { return "", false } + +func TestHasWildcardSubject(t *testing.T) { + handler := &ACRHandler{} + + tests := []struct { + name string + acr v1.AccessControlRule + expected bool + }{ + { + name: "has wildcard selector", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeSelector, ID: "*"}, + }, + }, + }, + }, + expected: true, + }, + { + name: "has wildcard among other subjects", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeUser, ID: "user-123"}, + {Type: types.SubjectTypeSelector, ID: "*"}, + }, + }, + }, + }, + expected: true, + }, + { + name: "no wildcard selector", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeUser, ID: "user-123"}, + {Type: types.SubjectTypeGroup, ID: "group-456"}, + }, + }, + }, + }, + expected: false, + }, + { + name: "empty subjects", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{}, + }, + }, + }, + expected: false, + }, + { + name: "selector with non-wildcard ID", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeSelector, ID: "some-other-selector"}, + }, + }, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := handler.hasWildcardSubject(tt.acr) + if result != tt.expected { + t.Errorf("hasWildcardSubject() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestUserMatchesSubjects(t *testing.T) { + handler := &ACRHandler{} + + tests := []struct { + name string + acr v1.AccessControlRule + user user.Info + expected bool + }{ + { + name: "user ID matches", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeUser, ID: "user-123"}, + }, + }, + }, + }, + user: &mockUser{uid: "user-123"}, + expected: true, + }, + { + name: "user group matches", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeGroup, ID: "developers"}, + }, + }, + }, + }, + user: &mockUser{ + uid: "user-456", + extra: map[string][]string{ + "auth_provider_groups": {"developers", "admins"}, + }, + }, + expected: true, + }, + { + name: "wildcard selector matches", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeSelector, ID: "*"}, + }, + }, + }, + }, + user: &mockUser{uid: "any-user"}, + expected: true, + }, + { + name: "no match", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeUser, ID: "user-123"}, + {Type: types.SubjectTypeGroup, ID: "admins"}, + }, + }, + }, + }, + user: &mockUser{ + uid: "user-456", + extra: map[string][]string{ + "auth_provider_groups": {"developers"}, + }, + }, + expected: false, + }, + { + name: "multiple subjects, one matches", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeUser, ID: "user-123"}, + {Type: types.SubjectTypeUser, ID: "user-456"}, + {Type: types.SubjectTypeGroup, ID: "admins"}, + }, + }, + }, + }, + user: &mockUser{uid: "user-456"}, + expected: true, + }, + { + name: "empty subjects", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{}, + }, + }, + }, + user: &mockUser{uid: "user-123"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a minimal api.Context with just the user + ctx := api.Context{ + User: tt.user, + } + result := handler.userMatchesSubjects(ctx, tt.acr) + if result != tt.expected { + t.Errorf("userMatchesSubjects() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestIsAuthorized(t *testing.T) { + tests := []struct { + name string + registryNoAuth bool + acr v1.AccessControlRule + user user.Info + expected bool + }{ + { + name: "auth OFF - wildcard subject", + registryNoAuth: true, + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeSelector, ID: "*"}, + }, + }, + }, + }, + user: &mockUser{uid: "any-user"}, + expected: true, + }, + { + name: "auth OFF - no wildcard subject", + registryNoAuth: true, + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeUser, ID: "user-123"}, + }, + }, + }, + }, + user: &mockUser{uid: "user-123"}, + expected: false, + }, + { + name: "auth ON - user matches", + registryNoAuth: false, + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeUser, ID: "user-123"}, + }, + }, + }, + }, + user: &mockUser{uid: "user-123"}, + expected: true, + }, + { + name: "auth ON - user doesn't match", + registryNoAuth: false, + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeUser, ID: "user-123"}, + }, + }, + }, + }, + user: &mockUser{uid: "user-456"}, + expected: false, + }, + { + name: "auth ON - group matches", + registryNoAuth: false, + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeGroup, ID: "developers"}, + }, + }, + }, + }, + user: &mockUser{ + uid: "user-123", + extra: map[string][]string{ + "auth_provider_groups": {"developers"}, + }, + }, + expected: true, + }, + { + name: "auth ON - wildcard matches", + registryNoAuth: false, + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Subjects: []types.Subject{ + {Type: types.SubjectTypeSelector, ID: "*"}, + }, + }, + }, + }, + user: &mockUser{uid: "any-user"}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := &ACRHandler{ + registryNoAuth: tt.registryNoAuth, + } + ctx := api.Context{ + User: tt.user, + } + result := handler.isAuthorized(ctx, tt.acr) + if result != tt.expected { + t.Errorf("isAuthorized() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestServerInACRResources(t *testing.T) { + handler := &ACRHandler{} + + tests := []struct { + name string + acr v1.AccessControlRule + serverName string + expected bool + }{ + { + name: "server explicitly listed", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Resources: []types.Resource{ + {Type: types.ResourceTypeMCPServer, ID: "mcp-server-abc123"}, + }, + }, + }, + }, + serverName: "mcp-server-abc123", + expected: true, + }, + { + name: "catalog entry explicitly listed", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Resources: []types.Resource{ + {Type: types.ResourceTypeMCPServerCatalogEntry, ID: "filesystem"}, + }, + }, + }, + }, + serverName: "filesystem", + expected: true, + }, + { + name: "wildcard selector includes all", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Resources: []types.Resource{ + {Type: types.ResourceTypeSelector, ID: "*"}, + }, + }, + }, + }, + serverName: "any-server", + expected: true, + }, + { + name: "server not in list", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Resources: []types.Resource{ + {Type: types.ResourceTypeMCPServer, ID: "mcp-server-abc123"}, + {Type: types.ResourceTypeMCPServerCatalogEntry, ID: "filesystem"}, + }, + }, + }, + }, + serverName: "other-server", + expected: false, + }, + { + name: "empty resources", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Resources: []types.Resource{}, + }, + }, + }, + serverName: "any-server", + expected: false, + }, + { + name: "multiple resources, one matches", + acr: v1.AccessControlRule{ + Spec: v1.AccessControlRuleSpec{ + Manifest: types.AccessControlRuleManifest{ + Resources: []types.Resource{ + {Type: types.ResourceTypeMCPServer, ID: "server-1"}, + {Type: types.ResourceTypeMCPServer, ID: "server-2"}, + {Type: types.ResourceTypeMCPServerCatalogEntry, ID: "filesystem"}, + }, + }, + }, + }, + serverName: "server-2", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := handler.serverInACRResources(tt.acr, tt.serverName) + if result != tt.expected { + t.Errorf("serverInACRResources() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestAuthGroupSet(t *testing.T) { + tests := []struct { + name string + user user.Info + expected map[string]struct{} + }{ + { + name: "single group", + user: &mockUser{ + extra: map[string][]string{ + "auth_provider_groups": {"developers"}, + }, + }, + expected: map[string]struct{}{ + "developers": {}, + }, + }, + { + name: "multiple groups", + user: &mockUser{ + extra: map[string][]string{ + "auth_provider_groups": {"developers", "admins", "users"}, + }, + }, + expected: map[string]struct{}{ + "developers": {}, + "admins": {}, + "users": {}, + }, + }, + { + name: "no groups", + user: &mockUser{ + extra: map[string][]string{}, + }, + expected: map[string]struct{}{}, + }, + { + name: "nil extra", + user: &mockUser{ + extra: nil, + }, + expected: map[string]struct{}{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := authGroupSet(tt.user) + if len(result) != len(tt.expected) { + t.Errorf("authGroupSet() length = %v, want %v", len(result), len(tt.expected)) + return + } + for group := range tt.expected { + if _, ok := result[group]; !ok { + t.Errorf("authGroupSet() missing group %v", group) + } + } + }) + } +} + +func TestNotFoundError(t *testing.T) { + handler := &ACRHandler{} + + tests := []struct { + name string + detail string + wantCode int + wantMsg string + }{ + { + name: "simple message", + detail: "server not found", + wantCode: 404, + wantMsg: `{"title":"Not Found","status":404,"detail":"server not found"}`, + }, + { + name: "empty message", + detail: "", + wantCode: 404, + wantMsg: `{"title":"Not Found","status":404,"detail":""}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := handler.notFoundError(tt.detail) + if err == nil { + t.Fatal("notFoundError() returned nil") + } + + httpErr, ok := err.(*types.ErrHTTP) + if !ok { + t.Fatalf("notFoundError() did not return *types.ErrHTTP") + } + + if httpErr.Code != tt.wantCode { + t.Errorf("notFoundError() code = %v, want %v", httpErr.Code, tt.wantCode) + } + + if httpErr.Message != tt.wantMsg { + t.Errorf("notFoundError() message = %v, want %v", httpErr.Message, tt.wantMsg) + } + }) + } +} + +func TestNewACRHandler(t *testing.T) { + serverURL := "https://obot.example.com" + registryNoAuth := false + + handler := NewACRHandler(nil, serverURL, registryNoAuth) + + if handler == nil { + t.Fatal("NewACRHandler() returned nil") + } + + if handler.serverURL != serverURL { + t.Errorf("NewACRHandler() serverURL = %v, want %v", handler.serverURL, serverURL) + } + + if handler.registryNoAuth != registryNoAuth { + t.Errorf("NewACRHandler() registryNoAuth = %v, want %v", handler.registryNoAuth, registryNoAuth) + } + + if handler.mimeFetcher == nil { + t.Error("NewACRHandler() mimeFetcher is nil") + } +} diff --git a/pkg/api/handlers/wellknown/handler.go b/pkg/api/handlers/wellknown/handler.go index 25eea642d4..711baae1a4 100644 --- a/pkg/api/handlers/wellknown/handler.go +++ b/pkg/api/handlers/wellknown/handler.go @@ -25,6 +25,7 @@ func SetupHandlers(baseURL string, config services.OAuthAuthorizationServerConfi mux.HandleFunc("GET /.well-known/oauth-authorization-server/mcp-connect/{mcp_id}", h.oauthAuthorization) mux.HandleFunc("GET /.well-known/oauth-protected-resource/v0.1/servers", h.registryOAuthProtectedResource) + mux.HandleFunc("GET /.well-known/oauth-protected-resource/mcp-registry/{acr_id}/v0.1/servers", h.acrRegistryOAuthProtectedResource) // These will allow clients that don't follow the WWW-Authenticate header to connect to the MCP gateway. // Such clients won't be able to do the second-level OAuth, but will be able to connect to all MCP servers diff --git a/pkg/api/handlers/wellknown/oauth.go b/pkg/api/handlers/wellknown/oauth.go index f7bbbd314b..f57d5d2ad7 100644 --- a/pkg/api/handlers/wellknown/oauth.go +++ b/pkg/api/handlers/wellknown/oauth.go @@ -57,3 +57,28 @@ func (h *handler) registryOAuthProtectedResource(req api.Context) error { "bearer_methods_supported": ["header"] }`, h.baseURL)) } + +func (h *handler) acrRegistryOAuthProtectedResource(req api.Context) error { + // Return 404 if registry is in no-auth mode + if h.registryNoAuth { + return &types.ErrHTTP{ + Code: http.StatusNotFound, + Message: "Registry OAuth is not available when registry authentication is disabled", + } + } + + acrID := req.PathValue("acr_id") + if acrID == "" { + return &types.ErrHTTP{ + Code: http.StatusNotFound, + Message: "ACR ID is required", + } + } + + // Return the same OAuth metadata as root registry, but scoped to the ACR path + return req.Write(fmt.Sprintf(`{ + "resource": "%s/mcp-registry/%s/v0.1/servers", + "authorization_servers": ["%[1]s"], + "bearer_methods_supported": ["header"] +}`, h.baseURL, acrID)) +} diff --git a/pkg/api/router/router.go b/pkg/api/router/router.go index b3d481402a..9de09b8453 100644 --- a/pkg/api/router/router.go +++ b/pkg/api/router/router.go @@ -77,6 +77,7 @@ func Router(ctx context.Context, services *services.Services) (http.Handler, err userDefaultRoleSettings := handlers.NewUserDefaultRoleSettingHandler() setupHandler := setup.NewHandler(services.ServerURL) registryHandler := registry.NewHandler(services.AccessControlRuleHelper, services.ServerURL, services.RegistryNoAuth) + acrRegistryHandler := registry.NewACRHandler(services.AccessControlRuleHelper, services.ServerURL, services.RegistryNoAuth) // Version mux.HandleFunc("GET /api/version", version.GetVersion) @@ -593,6 +594,11 @@ func Router(ctx context.Context, services *services.Services) (http.Handler, err mux.HandleFunc("GET /v0.1/servers/{serverName}/versions", registryHandler.ListServerVersions) mux.HandleFunc("GET /v0.1/servers/{serverName}/versions/{version}", registryHandler.GetServerVersion) + // Per-ACR Registry API + mux.HandleFunc("GET /mcp-registry/{acr_id}/v0.1/servers", acrRegistryHandler.ListServers) + mux.HandleFunc("GET /mcp-registry/{acr_id}/v0.1/servers/{serverName}/versions", acrRegistryHandler.ListServerVersions) + mux.HandleFunc("GET /mcp-registry/{acr_id}/v0.1/servers/{serverName}/versions/{version}", acrRegistryHandler.GetServerVersion) + // MCP Audit Logs mux.HandleFunc("GET /api/mcp-audit-logs", mcpAuditLogs.ListAuditLogs) mux.HandleFunc("POST /api/mcp-audit-logs", mcpAuditLogs.SubmitAuditLogs) diff --git a/pkg/api/server/server.go b/pkg/api/server/server.go index 525a9fb32d..273b54c76c 100644 --- a/pkg/api/server/server.go +++ b/pkg/api/server/server.go @@ -166,8 +166,18 @@ func (s *Server) Wrap(f api.HandlerFunc) http.HandlerFunc { } // Only set WWW-Authenticate if not in no-auth mode - if strings.HasPrefix(req.URL.Path, "/v0.1") && !s.registryNoAuth { - rw.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="MCP Registry", resource_metadata="%s/.well-known/oauth-protected-resource/v0.1/servers"`, strings.TrimSuffix(s.baseURL, "/api"))) + if !s.registryNoAuth { + if strings.HasPrefix(req.URL.Path, "/v0.1") { + rw.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="MCP Registry", resource_metadata="%s/.well-known/oauth-protected-resource/v0.1/servers"`, strings.TrimSuffix(s.baseURL, "/api"))) + } else if strings.HasPrefix(req.URL.Path, "/mcp-registry/") { + // Per-ACR registry path - extract ACR ID and set WWW-Authenticate header + // Path format: /mcp-registry/{acr_id}/v0.1/... + pathParts := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/mcp-registry/"), "/", 2) + if len(pathParts) > 0 && pathParts[0] != "" { + acrID := pathParts[0] + rw.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="MCP Registry", resource_metadata="%s/.well-known/oauth-protected-resource/mcp-registry/%s/v0.1/servers"`, strings.TrimSuffix(s.baseURL, "/api"), acrID)) + } + } } if authenticated {