Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 11 additions & 95 deletions internal/server/routed.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"

"github.com/github/gh-aw-mcpg/internal/httputil"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/syncutil"
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)

Expand All @@ -30,86 +29,11 @@ func rejectIfShutdown(unifiedServer *UnifiedServer, next http.Handler, logNamesp
})
}

// filteredServerCacheMaxSize is the maximum number of entries the filtered
// server cache will hold. When the cache is full, the least-recently-used entry
// is evicted to make room.
const filteredServerCacheMaxSize = 1000

// filteredServerCache caches filtered server instances per (backend, session) key.
// Entries are evicted after the configured TTL to prevent unbounded memory growth
// in long-running deployments with many sessions. A max-size cap provides an additional
// safety guard against an unbounded number of unique sessions.
type filteredServerCache struct {
// servers holds the cached entries, keyed by "backendID/sessionID".
servers map[string]*filteredServerEntry
// ttl is the idle duration after which getOrCreate evicts an entry.
ttl time.Duration
// maxSize is the entry count at which getOrCreate evicts the
// least-recently-used entry before inserting a new one.
maxSize int
// mu guards servers; getOrCreate holds the write lock for its whole body.
mu sync.RWMutex
}

// filteredServerEntry is a single slot in filteredServerCache: the cached
// server plus the timestamp used for TTL and LRU eviction decisions.
type filteredServerEntry struct {
// server is the cached filtered server instance returned on a cache hit.
server *sdk.Server
// lastUsed is set when the entry is created and refreshed on every cache hit.
lastUsed time.Time
}

// newFilteredServerCache builds an empty filtered-server cache whose entries
// expire after ttl of idleness. The size cap is fixed at
// filteredServerCacheMaxSize. The chosen settings are logged once at creation.
func newFilteredServerCache(ttl time.Duration) *filteredServerCache {
	logRouted.Printf("[CACHE] Creating filtered server cache: ttl=%s, maxSize=%d", ttl, filteredServerCacheMaxSize)

	cache := new(filteredServerCache)
	cache.servers = make(map[string]*filteredServerEntry)
	cache.ttl = ttl
	cache.maxSize = filteredServerCacheMaxSize
	return cache
}

// getOrCreate returns the server cached for the (backendID, sessionID) pair,
// calling creator to build and store a new one when no entry exists.
//
// Every call first sweeps out all entries that have been idle for longer than
// the TTL. On a miss with the cache still at its maximum size, the entry with
// the oldest lastUsed timestamp is dropped before the new one is inserted.
// The whole operation, including the creator call, runs under the write lock.
func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator func() *sdk.Server) *sdk.Server {
	key := fmt.Sprintf("%s/%s", backendID, sessionID)
	now := time.Now()

	c.mu.Lock()
	defer c.mu.Unlock()

	// Lazy sweep: drop everything that has outlived the TTL.
	for k, entry := range c.servers {
		idle := now.Sub(entry.lastUsed)
		if idle <= c.ttl {
			continue
		}
		logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", truncateCacheKeyForLog(k), idle.Round(time.Second))
		delete(c.servers, k)
	}

	if cached, hit := c.servers[key]; hit {
		cached.lastUsed = now
		logRouted.Printf("[CACHE] Cache hit: key=%s", truncateCacheKeyForLog(key))
		return cached.server
	}

	// Still full after the TTL sweep: sacrifice the least-recently-used entry.
	// This can interrupt an active session for that (backend, session) pair,
	// but bounded memory is preferred over unbounded growth.
	if len(c.servers) >= c.maxSize {
		var (
			oldestKey  string
			oldestUsed time.Time
			seen       bool
		)
		for k, entry := range c.servers {
			if seen && !entry.lastUsed.Before(oldestUsed) {
				continue
			}
			oldestKey, oldestUsed, seen = k, entry.lastUsed, true
		}
		if oldestKey != "" {
			logRouted.Printf("[CACHE] Max size reached (%d), evicting LRU entry: key=%s (idle %s)", c.maxSize, truncateCacheKeyForLog(oldestKey), now.Sub(oldestUsed).Round(time.Second))
			delete(c.servers, oldestKey)
		}
	}

	logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, truncateSessionID(sessionID))
	created := creator()
	c.servers[key] = &filteredServerEntry{server: created, lastUsed: now}
	return created
}

// CreateHTTPServerForRoutedMode creates an HTTP server for routed mode
// In routed mode, each backend is accessible at /mcp/<server>
// Multiple routes from the same Authorization header share a session
Expand All @@ -129,7 +53,8 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
// TTL matches the SDK SessionTimeout so cache entries expire with sessions.
// Long-running agentic workflows (e.g. >30 min GitHub Actions jobs) need this
// to be at least as long as the workflow to avoid spurious "session not found" errors.
serverCache := newFilteredServerCache(sessionTimeout)
logRouted.Printf("[CACHE] Creating filtered server cache: ttl=%s, maxSize=%d", sessionTimeout, filteredServerCacheMaxSize)
serverCache := syncutil.NewTTLCache[string, *sdk.Server](sessionTimeout, filteredServerCacheMaxSize)

// Create a proxy for each backend server
for _, serverID := range allBackends {
Expand All @@ -143,10 +68,12 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
return nil
}

// Return a cached filtered proxy server for this backend and session
// This ensures the same server instance is reused for all requests in a session
// Return a cached filtered proxy server for this backend and session.
// This ensures the same server instance is reused for all requests in a session.
sessionID := SessionIDFromContext(r.Context())
return serverCache.getOrCreate(backendID, sessionID, func() *sdk.Server {
cacheKey := fmt.Sprintf("%s/%s", backendID, sessionID)
return serverCache.GetOrCreate(cacheKey, func() *sdk.Server {
logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, truncateSessionID(sessionID))
return createFilteredServer(unifiedServer, backendID)
})
}, buildDefaultHandlerConfig(unifiedServer, sessionTimeout, defaultHandlerConfigOptions{
Expand Down Expand Up @@ -204,14 +131,3 @@ func createFilteredServer(unifiedServer *UnifiedServer, backendID string) *sdk.S

return server
}

// truncateCacheKeyForLog makes a "backendID/sessionID" cache key safe to log
// by passing the part after the first "/" through truncateSessionID. A key
// that contains no "/" is returned unchanged.
func truncateCacheKeyForLog(key string) string {
	if backend, session, ok := strings.Cut(key, "/"); ok {
		return fmt.Sprintf("%s/%s", backend, truncateSessionID(session))
	}
	return key
}
81 changes: 0 additions & 81 deletions internal/server/routed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,87 +557,6 @@ func TestCreateFilteredServer_EdgeCases(t *testing.T) {
})
}

// TestFilteredServerCache_MaxSize checks the size cap: once the cache holds
// maxSize entries, inserting another one evicts the entry with the oldest
// lastUsed timestamp and leaves the rest in place.
func TestFilteredServerCache_MaxSize(t *testing.T) {
	assert := assert.New(t)

	cache := newFilteredServerCache(time.Hour)
	cache.maxSize = 3 // Use a small max for the test

	created := 0
	newServer := func() *sdk.Server {
		created++
		return sdk.NewServer(&sdk.Implementation{Name: "test", Version: "1.0"}, &sdk.ServerOptions{})
	}

	// Fill the cache up to its capacity.
	initial := []*sdk.Server{
		cache.getOrCreate("backend", "session1", newServer),
		cache.getOrCreate("backend", "session2", newServer),
		cache.getOrCreate("backend", "session3", newServer),
	}
	assert.Equal(3, created, "Should have created 3 servers")
	for _, srv := range initial {
		assert.NotNil(srv)
	}
	assert.Len(cache.servers, 3, "Cache should have 3 entries")

	// Pin lastUsed so the LRU order is deterministic: session3 is the most
	// recently used (-1ms), session1 the least (-3ms).
	base := time.Now()
	for age, key := range []string{"backend/session3", "backend/session2", "backend/session1"} {
		cache.servers[key].lastUsed = base.Add(-time.Duration(age+1) * time.Millisecond)
	}

	// A fourth insert must push out session1 to stay within maxSize.
	fourth := cache.getOrCreate("backend", "session4", newServer)
	assert.Equal(4, created, "Should have created a 4th server")
	assert.NotNil(fourth)
	assert.Len(cache.servers, 3, "Cache should maintain maxSize by evicting the LRU entry")

	has := func(key string) bool {
		_, ok := cache.servers[key]
		return ok
	}
	assert.False(has("backend/session1"), "session1 (LRU) should have been evicted to make room")
	assert.True(has("backend/session2"), "session2 should still be cached")
	assert.True(has("backend/session3"), "session3 should still be cached")
	assert.True(has("backend/session4"), "session4 should be cached")
}

// TestFilteredServerCache_TTLEviction verifies that an entry idle for longer
// than the TTL is removed by the lazy sweep on the next getOrCreate call.
//
// The entry's lastUsed timestamp is backdated rather than waiting for real
// time to pass, so the test is deterministic and does not sleep.
func TestFilteredServerCache_TTLEviction(t *testing.T) {
	assert := assert.New(t)

	ttl := 100 * time.Millisecond
	cache := newFilteredServerCache(ttl)

	callCount := 0
	creator := func() *sdk.Server {
		callCount++
		return sdk.NewServer(&sdk.Implementation{Name: "test", Version: "1.0"}, &sdk.ServerOptions{})
	}

	// Add an entry
	cache.getOrCreate("backend", "session1", creator)
	assert.Equal(1, callCount)
	assert.Len(cache.servers, 1)

	// Age the entry well past the TTL without sleeping.
	entry, ok := cache.servers["backend/session1"]
	if !ok {
		t.Fatal("session1 should be cached before it is aged")
	}
	entry.lastUsed = time.Now().Add(-2 * ttl)

	// Next call should evict the expired entry and create a new one
	cache.getOrCreate("backend", "session2", creator)
	assert.Equal(2, callCount, "Should have created a new server after TTL eviction")

	// session1 should have been evicted during the lazy eviction scan
	_, session1Exists := cache.servers["backend/session1"]
	assert.False(session1Exists, "Expired session1 should have been evicted")
}

// TestRegisterToolWithoutValidation verifies that tools are registered on the server
// and that the wrapped handler forwards calls correctly via in-memory transport.
func TestRegisterToolWithoutValidation(t *testing.T) {
Expand Down
11 changes: 11 additions & 0 deletions internal/server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ func truncateSessionID(sessionID string) string {
return strutil.Truncate(sessionID, 8)
}

// truncateCacheKeyForLog makes a "backendID/sessionID" cache key safe to log
// by passing the part after the first "/" through truncateSessionID. A key
// that contains no "/" is returned unchanged.
func truncateCacheKeyForLog(key string) string {
	if backend, session, ok := strings.Cut(key, "/"); ok {
		return fmt.Sprintf("%s/%s", backend, truncateSessionID(session))
	}
	return key
}

// extractSessionIDFromRequest extracts the session ID from X-Agent-ID and
// Authorization headers. Returns "" if neither header is present or valid.
func extractSessionIDFromRequest(r *http.Request) string {
Expand Down
104 changes: 104 additions & 0 deletions internal/syncutil/ttl_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package syncutil

import (
"sync"
"time"
)

// ttlEntry holds a cached value together with its last-used timestamp. The
// timestamp is set on insertion, refreshed on every cache hit, and drives both
// TTL expiry and LRU selection.
type ttlEntry[V any] struct {
	value    V
	lastUsed time.Time
}

// TTLCache is a thread-safe generic get-or-create cache that combines lazy TTL
// eviction with an LRU size cap.
//
// On every [TTLCache.GetOrCreate] call, all entries whose idle time exceeds the
// configured TTL are evicted first. If the cache is still at capacity after TTL
// eviction, the least-recently-used entry is evicted to make room.
//
// A TTLCache must be created with [NewTTLCache] and must not be copied after
// first use, because it contains a mutex.
type TTLCache[K comparable, V any] struct {
	mu      sync.Mutex         // guards entries
	entries map[K]*ttlEntry[V] // live entries by key
	ttl     time.Duration      // max idle time; <= 0 disables TTL eviction
	maxSize int                // entry cap; <= 0 disables caching entirely
	nowFn   func() time.Time   // clock, injectable for tests
}

// NewTTLCache creates a new TTLCache with the given entry TTL and maximum size.
// Entries idle longer than ttl are evicted lazily; when the cache reaches
// maxSize the least-recently-used entry is evicted on the next GetOrCreate call
// that inserts a new key.
//
// A ttl <= 0 disables TTL eviction. A maxSize <= 0 disables caching: every
// GetOrCreate call invokes its create function and nothing is stored.
func NewTTLCache[K comparable, V any](ttl time.Duration, maxSize int) *TTLCache[K, V] {
	return newTTLCacheWithClock[K, V](ttl, maxSize, time.Now)
}

// newTTLCacheWithClock creates a TTLCache with an injectable clock function.
// Intended for use in unit tests only. A nil nowFn falls back to time.Now so
// the cache never calls through a nil function value.
func newTTLCacheWithClock[K comparable, V any](ttl time.Duration, maxSize int, nowFn func() time.Time) *TTLCache[K, V] {
	if nowFn == nil {
		nowFn = time.Now
	}
	return &TTLCache[K, V]{
		entries: make(map[K]*ttlEntry[V]),
		ttl:     ttl,
		maxSize: maxSize,
		nowFn:   nowFn,
	}
}

// GetOrCreate returns the cached value for key. If the key is not present, or
// has been evicted, create is called to produce a new value which is then
// stored and returned.
//
// On each call, all expired entries are lazily evicted before the lookup. If
// the cache has reached its capacity after TTL eviction, the LRU entry is
// removed to make room.
//
// If the cache was built with maxSize <= 0, create is called on every
// invocation and its result is returned without being stored.
//
// create is called while the cache lock is held, so it must not call back into
// the same cache.
func (c *TTLCache[K, V]) GetOrCreate(key K, create func() V) V {
	if c.maxSize <= 0 {
		return create()
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// Read the clock only after the lock is held. A timestamp taken before
	// blocking on the mutex could be older than one already written by another
	// goroutine, which would move an entry's lastUsed backwards.
	now := c.nowFn()

	c.evictExpiredLocked(now)

	if entry, ok := c.entries[key]; ok {
		entry.lastUsed = now
		return entry.value
	}

	// LRU eviction when still at capacity after the TTL sweep.
	if len(c.entries) >= c.maxSize {
		c.evictLRULocked()
	}

	v := create()
	c.entries[key] = &ttlEntry[V]{value: v, lastUsed: now}
	return v
}

// evictExpiredLocked removes every entry that has been idle for strictly
// longer than the TTL as of now. It is a no-op when the TTL is not positive.
// The caller must hold c.mu.
func (c *TTLCache[K, V]) evictExpiredLocked(now time.Time) {
	if c.ttl <= 0 {
		return
	}
	for k, entry := range c.entries {
		if now.Sub(entry.lastUsed) > c.ttl {
			delete(c.entries, k)
		}
	}
}

// evictLRULocked removes the entry with the oldest lastUsed timestamp, if the
// cache holds any entry at all. The caller must hold c.mu.
func (c *TTLCache[K, V]) evictLRULocked() {
	var (
		lruKey  K
		lruTime time.Time
		found   bool
	)
	for k, entry := range c.entries {
		if !found || entry.lastUsed.Before(lruTime) {
			lruKey, lruTime, found = k, entry.lastUsed, true
		}
	}
	if found {
		delete(c.entries, lruKey)
	}
}

// Len returns the current number of entries held in the cache. Expired entries
// that have not yet been swept by a GetOrCreate call are still counted.
func (c *TTLCache[K, V]) Len() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return len(c.entries)
}
Loading
Loading