Skip to content

Commit 6a930ef

Browse files
authored
fix callback validation in mcp client (#696)
* fix: callback urls were being matched by multiple servers * feat: server url in state param * fix: validate nonces * remove legacy setting a client id manually * rename function to validateState * feat: add multiple servers to playground * fix some ui * fix: add check state and consume state rather than validate * more tests * changeset * add cleanup of expired keys * fix: changeset * fix: readme to remove custom authProvider wtffffff hahaha
1 parent 7360c51 commit 6a930ef

File tree

11 files changed

+1300
-580
lines changed

11 files changed

+1300
-580
lines changed

.changeset/nine-bottles-heal.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"agents": patch
3+
---
4+
5+
Enables connecting to multiple MCP servers simultaneously and hardens OAuth state handling against replay/DoS attacks.
6+
7+
**Note:** Inflight OAuth flows that were initiated on a previous version will not complete after upgrading, as the state parameter format has changed. Users will need to restart the authentication flow.

examples/mcp-client/README.md

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,17 @@ Then, follow the steps below to setup the client:
2222

2323
Tap "O + enter" to open the front end. It should list out all the tools, prompts, and resources available for each server added.
2424

25-
## Transport Configuration
25+
## Usage
2626

27-
The MCP client defaults to HTTP Streamable transport for better performance. You can specify transport type explicitly:
27+
The recommended way to add MCP servers is via `Agent.addMcpServer()`:
2828

2929
```typescript
30-
// Using MCPClientManager directly
31-
const mcpClient = new MCPClientManager("my-app", "1.0.0");
32-
33-
// HTTP Streamable transport (default, recommended)
34-
await mcpClient.connect(serverUrl, {
35-
transport: {
36-
type: "streamable-http",
37-
authProvider: myAuthProvider
38-
}
39-
});
40-
41-
// SSE transport (legacy compatibility)
42-
await mcpClient.connect(serverUrl, {
43-
transport: {
44-
type: "sse",
45-
authProvider: myAuthProvider
46-
}
47-
});
48-
49-
// Or using Agent.addMcpServer() (as shown in the example)
5030
export class MyAgent extends Agent<Env, never> {
5131
async addServer(name: string, url: string, callbackHost: string) {
32+
// Uses HTTP Streamable transport by default
5233
await this.addMcpServer(name, url, callbackHost);
5334
}
5435
}
5536
```
37+
38+
The MCP client handles OAuth authentication automatically using the built-in `DurableObjectOAuthClientProvider`.

packages/agents/src/mcp/client.ts

Lines changed: 57 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -635,46 +635,81 @@ export class MCPClientManager {
635635
}
636636
}
637637

638+
private extractServerIdFromState(state: string | null): string | null {
639+
if (!state) return null;
640+
const parts = state.split(".");
641+
return parts.length === 2 ? parts[1] : null;
642+
}
643+
638644
isCallbackRequest(req: Request): boolean {
639645
if (req.method !== "GET") {
640646
return false;
641647
}
642648

643-
// Quick heuristic check: most callback URLs contain "/callback"
644-
// This avoids DB queries for obviously non-callback requests
645649
if (!req.url.includes("/callback")) {
646650
return false;
647651
}
648652

649-
// Check database for matching callback URL
653+
const url = new URL(req.url);
654+
const state = url.searchParams.get("state");
655+
const serverId = this.extractServerIdFromState(state);
656+
if (!serverId) {
657+
return false;
658+
}
659+
650660
const servers = this.getServersFromStorage();
651-
return servers.some(
652-
(server) => server.callback_url && req.url.startsWith(server.callback_url)
653-
);
661+
return servers.some((server) => server.id === serverId);
654662
}
655663

656664
async handleCallbackRequest(req: Request) {
657665
const url = new URL(req.url);
666+
const code = url.searchParams.get("code");
667+
const state = url.searchParams.get("state");
668+
const error = url.searchParams.get("error");
669+
const errorDescription = url.searchParams.get("error_description");
670+
671+
if (!state) {
672+
throw new Error("Unauthorized: no state provided");
673+
}
674+
675+
const serverId = this.extractServerIdFromState(state);
676+
677+
if (!serverId) {
678+
throw new Error(
679+
"No serverId found in state parameter. Expected format: {nonce}.{serverId}"
680+
);
681+
}
658682

659-
// Find the matching server from database
660683
const servers = this.getServersFromStorage();
661-
const matchingServer = servers.find((server: MCPServerRow) => {
662-
return server.callback_url && req.url.startsWith(server.callback_url);
663-
});
684+
const serverExists = servers.some((server) => server.id === serverId);
664685

665-
if (!matchingServer) {
686+
if (!serverExists) {
666687
throw new Error(
667-
`No callback URI match found for the request url: ${req.url}. Was the request matched with \`isCallbackRequest()\`?`
688+
`No server found with id "${serverId}". Was the request matched with \`isCallbackRequest()\`?`
668689
);
669690
}
670691

671-
const serverId = matchingServer.id;
672-
const code = url.searchParams.get("code");
673-
const state = url.searchParams.get("state");
674-
const error = url.searchParams.get("error");
675-
const errorDescription = url.searchParams.get("error_description");
692+
if (this.mcpConnections[serverId] === undefined) {
693+
throw new Error(`Could not find serverId: ${serverId}`);
694+
}
695+
696+
const conn = this.mcpConnections[serverId];
697+
if (!conn.options.transport.authProvider) {
698+
throw new Error(
699+
"Trying to finalize authentication for a server connection without an authProvider"
700+
);
701+
}
702+
703+
const authProvider = conn.options.transport.authProvider;
704+
authProvider.serverId = serverId;
705+
706+
// Two-phase state validation: check first (non-destructive), consume later
707+
// This prevents DoS attacks where attacker consumes valid state before legitimate callback
708+
const stateValidation = await authProvider.checkState(state);
709+
if (!stateValidation.valid) {
710+
throw new Error(`Invalid state: ${stateValidation.error}`);
711+
}
676712

677-
// Handle OAuth error responses from the provider
678713
if (error) {
679714
return {
680715
serverId,
@@ -686,25 +721,14 @@ export class MCPClientManager {
686721
if (!code) {
687722
throw new Error("Unauthorized: no code provided");
688723
}
689-
if (!state) {
690-
throw new Error("Unauthorized: no state provided");
691-
}
692724

693-
if (this.mcpConnections[serverId] === undefined) {
694-
throw new Error(`Could not find serverId: ${serverId}`);
695-
}
696-
697-
// If connection is already ready/connected, this is likely a duplicate callback
698725
if (
699726
this.mcpConnections[serverId].connectionState ===
700727
MCPConnectionState.READY ||
701728
this.mcpConnections[serverId].connectionState ===
702729
MCPConnectionState.CONNECTED
703730
) {
704-
// make sure auth_url is cleared
705731
this.clearServerAuthUrl(serverId);
706-
707-
// Already authenticated and ready, treat as success
708732
return {
709733
serverId,
710734
authSuccess: true
@@ -720,32 +744,20 @@ export class MCPClientManager {
720744
);
721745
}
722746

723-
const conn = this.mcpConnections[serverId];
724-
if (!conn.options.transport.authProvider) {
725-
throw new Error(
726-
"Trying to finalize authentication for a server connection without an authProvider"
727-
);
728-
}
729-
730-
// Get clientId from auth provider (stored during redirectToAuthorization) or fallback to state for backward compatibility
731-
const clientId = conn.options.transport.authProvider.clientId || state;
732-
733-
// Set the OAuth credentials
734-
conn.options.transport.authProvider.clientId = clientId;
735-
conn.options.transport.authProvider.serverId = serverId;
736-
737747
try {
748+
await authProvider.consumeState(state);
738749
await conn.completeAuthorization(code);
750+
await authProvider.deleteCodeVerifier();
739751
this.clearServerAuthUrl(serverId);
740752
this._onServerStateChanged.fire();
741753

742754
return {
743755
serverId,
744756
authSuccess: true
745757
};
746-
} catch (error) {
758+
} catch (authError) {
747759
const errorMessage =
748-
error instanceof Error ? error.message : String(error);
760+
authError instanceof Error ? authError.message : String(authError);
749761

750762
this._onServerStateChanged.fire();
751763

packages/agents/src/mcp/do-oauth-client-provider.ts

Lines changed: 100 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,25 @@ import type {
77
} from "@modelcontextprotocol/sdk/shared/auth.js";
88
import { nanoid } from "nanoid";
99

10+
const STATE_EXPIRATION_MS = 10 * 60 * 1000; // 10 minutes
11+
12+
interface StoredState {
13+
nonce: string;
14+
serverId: string;
15+
createdAt: number;
16+
}
17+
1018
// A slight extension to the standard OAuthClientProvider interface because `redirectToAuthorization` doesn't give us the interface we need
1119
// This allows us to track authentication for a specific server and associated dynamic client registration
1220
export interface AgentsOAuthProvider extends OAuthClientProvider {
1321
authUrl: string | undefined;
1422
clientId: string | undefined;
1523
serverId: string | undefined;
24+
checkState(
25+
state: string
26+
): Promise<{ valid: boolean; serverId?: string; error?: string }>;
27+
consumeState(state: string): Promise<void>;
28+
deleteCodeVerifier(): Promise<void>;
1629
}
1730

1831
export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider {
@@ -48,7 +61,7 @@ export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider {
4861
}
4962

5063
get redirectUrl() {
51-
return `${this.baseRedirectUrl}/${this.serverId}`;
64+
return this.baseRedirectUrl;
5265
}
5366

5467
get clientId() {
@@ -124,17 +137,92 @@ export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider {
124137
return this._authUrl_;
125138
}
126139

127-
/**
128-
* Because this operates on the server side (but we need browser auth), we send this url back to the user
129-
* and require user interact to initiate the redirect flow
130-
*/
140+
stateKey(nonce: string) {
141+
return `/${this.clientName}/${this.serverId}/state/${nonce}`;
142+
}
143+
144+
async state(): Promise<string> {
145+
const nonce = nanoid();
146+
const state = `${nonce}.${this.serverId}`;
147+
const storedState: StoredState = {
148+
nonce,
149+
serverId: this.serverId,
150+
createdAt: Date.now()
151+
};
152+
await this.storage.put(this.stateKey(nonce), storedState);
153+
return state;
154+
}
155+
156+
async checkState(
157+
state: string
158+
): Promise<{ valid: boolean; serverId?: string; error?: string }> {
159+
const parts = state.split(".");
160+
if (parts.length !== 2) {
161+
return { valid: false, error: "Invalid state format" };
162+
}
163+
164+
const [nonce, serverId] = parts;
165+
const key = this.stateKey(nonce);
166+
const storedState = await this.storage.get<StoredState>(key);
167+
168+
if (!storedState) {
169+
return { valid: false, error: "State not found or already used" };
170+
}
171+
172+
if (storedState.serverId !== serverId) {
173+
await this.storage.delete(key);
174+
return { valid: false, error: "State serverId mismatch" };
175+
}
176+
177+
const age = Date.now() - storedState.createdAt;
178+
if (age > STATE_EXPIRATION_MS) {
179+
await this.storage.delete(key);
180+
return { valid: false, error: "State expired" };
181+
}
182+
183+
return { valid: true, serverId };
184+
}
185+
186+
async consumeState(state: string): Promise<void> {
187+
const parts = state.split(".");
188+
if (parts.length !== 2) {
189+
// This should never happen since checkState validates format first.
190+
// Log for debugging but don't throw - state consumption is best-effort.
191+
console.warn(
192+
`[OAuth] consumeState called with invalid state format: ${state.substring(0, 20)}...`
193+
);
194+
return;
195+
}
196+
const [nonce] = parts;
197+
await this.storage.delete(this.stateKey(nonce));
198+
}
199+
131200
async redirectToAuthorization(authUrl: URL): Promise<void> {
132-
// Generate secure random token for state parameter
133-
const stateToken = nanoid();
134-
authUrl.searchParams.set("state", stateToken);
135201
this._authUrl_ = authUrl.toString();
136202
}
137203

204+
async invalidateCredentials(
205+
scope: "all" | "client" | "tokens" | "verifier"
206+
): Promise<void> {
207+
if (!this._clientId_) return;
208+
209+
const deleteKeys: string[] = [];
210+
211+
if (scope === "all" || scope === "client") {
212+
deleteKeys.push(this.clientInfoKey(this.clientId));
213+
}
214+
if (scope === "all" || scope === "tokens") {
215+
deleteKeys.push(this.tokenKey(this.clientId));
216+
}
217+
if (scope === "all" || scope === "verifier") {
218+
deleteKeys.push(this.codeVerifierKey(this.clientId));
219+
}
220+
221+
if (deleteKeys.length > 0) {
222+
await this.storage.delete(deleteKeys);
223+
}
224+
}
225+
138226
codeVerifierKey(clientId: string) {
139227
return `${this.keyPrefix(clientId)}/code_verifier`;
140228
}
@@ -160,4 +248,8 @@ export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider {
160248
}
161249
return codeVerifier;
162250
}
251+
252+
async deleteCodeVerifier(): Promise<void> {
253+
await this.storage.delete(this.codeVerifierKey(this.clientId));
254+
}
163255
}

0 commit comments

Comments
 (0)