Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions examples/mcp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ import (
"google.golang.org/adk/tool/mcptoolset"
)

// This example demonstrates 2 ways to use MCP tools with ADK:
// To select between two, set AGENT_MODE="local" or "github" ("local" is default).
// This example demonstrates 3 ways to use MCP tools with ADK:
// To select between them, set AGENT_MODE="local", "github", or "custom-headers" ("local" is default).
//
// 1. in-memory MCP server:
// - define golang function (in this case -- GetWeather)
Expand All @@ -45,6 +45,12 @@ import (
// 2. GitHub's remote MCP server (https://github.com/github/github-mcp-server):
// - create http.Client with authenticated transport. In this case it's oauth2 transport with GitHub personal access token.
// - use `export GITHUB_PAT=...` to set GitHub personal access token.
//
// 3. MCP server with custom headers using HeaderProvider:
// - demonstrates how HeaderProvider injects context-aware headers for a real HTTP transport
// - uses the same GitHub MCP endpoint but leaves HTTPClient nil for brevity
// - headers are generated per tool call (e.g., auth token + user/session IDs)
// - use `export GITHUB_PAT=...` and optionally `CUSTOM_API_KEY=...`

type Input struct {
City string `json:"city" jsonschema:"city name"`
Expand Down Expand Up @@ -96,16 +102,44 @@ func main() {
log.Fatalf("Failed to create model: %v", err)
}

var transport mcp.Transport
if strings.ToLower(os.Getenv("AGENT_MODE")) == "github" {
transport = githubMCPTransport(ctx)
} else {
transport = localMCPTransport(ctx)
var mcpToolSet tool.Toolset
agentMode := strings.ToLower(os.Getenv("AGENT_MODE"))
if agentMode == "" {
agentMode = "local"
}

switch agentMode {
case "github":
mcpToolSet, err = mcptoolset.New(mcptoolset.Config{
Transport: githubMCPTransport(ctx),
})
case "custom-headers":
headerProvider := func(ctx agent.ReadonlyContext) map[string]string {
headers := make(map[string]string)

if pat := os.Getenv("GITHUB_PAT"); pat != "" {
headers["Authorization"] = "Bearer " + pat
}

if userID := ctx.UserID(); userID != "" {
headers["X-User-ID"] = userID
}

return headers
}

mcpToolSet, err = mcptoolset.New(mcptoolset.Config{
Transport: &mcp.StreamableClientTransport{
Endpoint: "https://api.githubcopilot.com/mcp/",
},
HeaderProvider: headerProvider,
})
default:
mcpToolSet, err = mcptoolset.New(mcptoolset.Config{
Transport: localMCPTransport(ctx),
})
}

mcpToolSet, err := mcptoolset.New(mcptoolset.Config{
Transport: transport,
})
if err != nil {
log.Fatalf("Failed to create MCP tool set: %v", err)
}
Expand Down
230 changes: 230 additions & 0 deletions tool/mcptoolset/session_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mcptoolset

import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net/http"
"sort"
"strings"
"sync"
"time"

"github.com/modelcontextprotocol/go-sdk/mcp"
)

const (
defaultSessionKey = "default"
defaultPingTimeout = 2 * time.Second
)

// sessionManager manages MCP client sessions with header-based pooling
type sessionManager struct {
client *mcp.Client
transport mcp.Transport

mu sync.RWMutex
sessions map[string]*sessionEntry
}

type sessionEntry struct {
session *mcp.ClientSession
headers map[string]string
}

// newSessionManager creates a new session manager
func newSessionManager(client *mcp.Client, transport mcp.Transport) *sessionManager {
return &sessionManager{
client: client,
transport: transport,
sessions: make(map[string]*sessionEntry),
}
}

// headersAffectSession returns true only for HTTP-based transports where
// headers are actually used by the connection.
func (sm *sessionManager) headersAffectSession() bool {
switch sm.transport.(type) {
case *mcp.SSEClientTransport, *mcp.StreamableClientTransport:
return true
default:
return false
}
}

// generateSessionKey creates a hash-based key from headers
func (sm *sessionManager) generateSessionKey(headers map[string]string) string {
// For non-HTTP transports (e.g., stdio, in-memory), headers don't apply,
// so we always pool into the same session.
if !sm.headersAffectSession() {
return defaultSessionKey
}
if len(headers) == 0 {
return defaultSessionKey
}

keys := make([]string, 0, len(headers))
for k := range headers {
keys = append(keys, k)
}
sort.Strings(keys)

var pairs []string
for _, k := range keys {
pairs = append(pairs, fmt.Sprintf("%q:%q", k, headers[k]))
}
jsonStr := "{" + strings.Join(pairs, ",") + "}"

h := md5.Sum([]byte(jsonStr))
return hex.EncodeToString(h[:])
}

// GetSession returns a session for the given headers, creating if necessary
func (sm *sessionManager) GetSession(ctx context.Context, headers map[string]string) (*mcp.ClientSession, error) {
key := sm.generateSessionKey(headers)

sm.mu.RLock()
entry, ok := sm.sessions[key]
sm.mu.RUnlock()

if ok && sm.isSessionValid(ctx, entry.session) {
return entry.session, nil
}

sm.mu.Lock()
defer sm.mu.Unlock()

if entry, ok := sm.sessions[key]; ok && sm.isSessionValid(ctx, entry.session) {
return entry.session, nil
}

wrappedTransport := sm.wrapTransportWithHeaders(headers)

session, err := sm.client.Connect(ctx, wrappedTransport, nil)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}

sm.sessions[key] = &sessionEntry{
session: session,
headers: headers,
}

return session, nil
}

// isSessionValid checks if a session is still usable
func (sm *sessionManager) isSessionValid(ctx context.Context, session *mcp.ClientSession) bool {
if session == nil {
return false
}

pingCtx := ctx
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
pingCtx, cancel = context.WithTimeout(ctx, defaultPingTimeout)
defer cancel()
}

if err := session.Ping(pingCtx, nil); err != nil {
return false
}
return true
}

// wrapTransportWithHeaders creates a transport that injects headers
func (sm *sessionManager) wrapTransportWithHeaders(headers map[string]string) mcp.Transport {
switch t := sm.transport.(type) {

case *mcp.SSEClientTransport:
return &mcp.SSEClientTransport{
Endpoint: t.Endpoint,
HTTPClient: wrapHTTPClient(t.HTTPClient, headers),
}

case *mcp.StreamableClientTransport:
return &mcp.StreamableClientTransport{
Endpoint: t.Endpoint,
HTTPClient: wrapHTTPClient(t.HTTPClient, headers),
}

default:
return sm.transport
}
}

func wrapHTTPClient(httpClient *http.Client, headers map[string]string) *http.Client {
if httpClient == nil {
httpClient = &http.Client{}
}

return &http.Client{
Transport: &headerTransport{
Base: httpClient.Transport,
Headers: headers,
},
CheckRedirect: httpClient.CheckRedirect,
Jar: httpClient.Jar,
Timeout: httpClient.Timeout,
}
}

// Close closes all sessions
func (sm *sessionManager) Close() error {
sm.mu.Lock()
defer sm.mu.Unlock()

var errs []error
for _, entry := range sm.sessions {
if err := entry.session.Close(); err != nil {
errs = append(errs, err)
}
}

sm.sessions = make(map[string]*sessionEntry)

if len(errs) > 0 {
return fmt.Errorf("errors closing sessions: %v", errs)
}
return nil
}

type headerTransport struct {
Base http.RoundTripper
Headers map[string]string
}

// RoundTrip adds the configured headers to the request.
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if len(t.Headers) == 0 {
return t.base().RoundTrip(req)
}

req2 := req.Clone(req.Context())
for key, value := range t.Headers {
req2.Header.Set(key, value)
}
return t.base().RoundTrip(req2)
}

func (t *headerTransport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
Loading