Skip to content

Commit cd8b7fd

Browse files
authored
fix connection inside tool execution (#712)
1 parent 2c15bd1 commit cd8b7fd

File tree

6 files changed

+239
-49
lines changed

6 files changed

+239
-49
lines changed

.changeset/breezy-poets-try.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"agents": patch
3+
---
4+
5+
fix connection inside tool execution

packages/agents/src/ai-chat-agent.ts

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import {
1616
type ConnectionContext,
1717
type WSMessage
1818
} from "./";
19+
import { agentContext } from "./context";
1920
import {
2021
MessageType,
2122
type IncomingMessage,
@@ -212,40 +213,52 @@ export class AIChatAgent<Env = unknown, State = unknown> extends Agent<
212213
const abortSignal = this._getAbortSignal(chatMessageId);
213214

214215
return this._tryCatchChat(async () => {
215-
const response = await this.onChatMessage(
216-
async (_finishResult) => {
217-
this._removeAbortController(chatMessageId);
218-
219-
this.observability?.emit(
220-
{
221-
displayMessage: "Chat message response",
222-
id: data.id,
223-
payload: {},
224-
timestamp: Date.now(),
225-
type: "message:response"
216+
// Wrap in agentContext.run() to propagate connection context to onChatMessage
217+
// This ensures getCurrentAgent() returns the connection inside tool execute functions
218+
return agentContext.run(
219+
{
220+
agent: this,
221+
connection,
222+
request: undefined,
223+
email: undefined
224+
},
225+
async () => {
226+
const response = await this.onChatMessage(
227+
async (_finishResult) => {
228+
this._removeAbortController(chatMessageId);
229+
230+
this.observability?.emit(
231+
{
232+
displayMessage: "Chat message response",
233+
id: data.id,
234+
payload: {},
235+
timestamp: Date.now(),
236+
type: "message:response"
237+
},
238+
this.ctx
239+
);
226240
},
227-
this.ctx
241+
abortSignal ? { abortSignal } : undefined
228242
);
229-
},
230-
abortSignal ? { abortSignal } : undefined
231-
);
232243

233-
if (response) {
234-
await this._reply(data.id, response);
235-
} else {
236-
console.warn(
237-
`[AIChatAgent] onChatMessage returned no response for chatMessageId: ${chatMessageId}`
238-
);
239-
this._broadcastChatMessage(
240-
{
241-
body: "No response was generated by the agent.",
242-
done: true,
243-
id: data.id,
244-
type: MessageType.CF_AGENT_USE_CHAT_RESPONSE
245-
},
246-
[connection.id]
247-
);
248-
}
244+
if (response) {
245+
await this._reply(data.id, response);
246+
} else {
247+
console.warn(
248+
`[AIChatAgent] onChatMessage returned no response for chatMessageId: ${chatMessageId}`
249+
);
250+
this._broadcastChatMessage(
251+
{
252+
body: "No response was generated by the agent.",
253+
done: true,
254+
id: data.id,
255+
type: MessageType.CF_AGENT_USE_CHAT_RESPONSE
256+
},
257+
[connection.id]
258+
);
259+
}
260+
}
261+
);
249262
});
250263
}
251264

packages/agents/src/context.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import { AsyncLocalStorage } from "node:async_hooks";
2+
import type { Connection } from "partyserver";
3+
4+
export type AgentEmail = {
5+
from: string;
6+
to: string;
7+
getRaw: () => Promise<Uint8Array>;
8+
headers: Headers;
9+
rawSize: number;
10+
setReject: (reason: string) => void;
11+
forward: (rcptTo: string, headers?: Headers) => Promise<void>;
12+
reply: (options: { from: string; to: string; raw: string }) => Promise<void>;
13+
};
14+
15+
export type AgentContextStore = {
16+
// Using unknown to avoid circular dependency with Agent
17+
agent: unknown;
18+
connection: Connection | undefined;
19+
request: Request | undefined;
20+
email: AgentEmail | undefined;
21+
};
22+
23+
export const agentContext = new AsyncLocalStorage<AgentContextStore>();

packages/agents/src/index.ts

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import type { env } from "cloudflare:workers";
2-
import { AsyncLocalStorage } from "node:async_hooks";
32
import type { Client } from "@modelcontextprotocol/sdk/client/index.js";
3+
import { agentContext, type AgentEmail } from "./context";
44
import type { SSEClientTransportOptions } from "@modelcontextprotocol/sdk/client/sse.js";
55

66
import type {
@@ -234,13 +234,6 @@ const STATE_WAS_CHANGED = "cf_state_was_changed";
234234

235235
const DEFAULT_STATE = {} as unknown;
236236

237-
const agentContext = new AsyncLocalStorage<{
238-
agent: Agent<unknown, unknown>;
239-
connection: Connection | undefined;
240-
request: Request | undefined;
241-
email: AgentEmail | undefined;
242-
}>();
243-
244237
export function getCurrentAgent<
245238
T extends Agent<unknown, unknown> = Agent<unknown, unknown>
246239
>(): {
@@ -1921,16 +1914,8 @@ export async function routeAgentEmail<Env>(
19211914
await agent._onEmail(serialisableEmail);
19221915
}
19231916

1924-
export type AgentEmail = {
1925-
from: string;
1926-
to: string;
1927-
getRaw: () => Promise<Uint8Array>;
1928-
headers: Headers;
1929-
rawSize: number;
1930-
setReject: (reason: string) => void;
1931-
forward: (rcptTo: string, headers?: Headers) => Promise<void>;
1932-
reply: (options: { from: string; to: string; raw: string }) => Promise<void>;
1933-
};
1917+
// AgentEmail is re-exported from ./context
1918+
export type { AgentEmail } from "./context";
19341919

19351920
export type EmailSendOptions = {
19361921
to: string;
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import { createExecutionContext, env } from "cloudflare:test";
2+
import { describe, it, expect } from "vitest";
3+
import worker, { type Env } from "./worker";
4+
import { MessageType } from "../ai-types";
5+
import type { UIMessage as ChatMessage } from "ai";
6+
7+
declare module "cloudflare:test" {
8+
interface ProvidedEnv extends Env {}
9+
}
10+
11+
async function connectChatWS(path: string) {
12+
const ctx = createExecutionContext();
13+
const req = new Request(`http://example.com${path}`, {
14+
headers: { Upgrade: "websocket" }
15+
});
16+
const res = await worker.fetch(req, env, ctx);
17+
expect(res.status).toBe(101);
18+
const ws = res.webSocket as WebSocket;
19+
expect(ws).toBeDefined();
20+
ws.accept();
21+
return { ws, ctx };
22+
}
23+
24+
describe("AIChatAgent Connection Context - Issue #711", () => {
25+
it("getCurrentAgent() should return connection in onChatMessage and nested async functions (tool execute)", async () => {
26+
const room = crypto.randomUUID();
27+
const { ws } = await connectChatWS(`/agents/test-chat-agent/${room}`);
28+
29+
// Get the agent stub to access captured context
30+
const agentStub = env.TestChatAgent.get(env.TestChatAgent.idFromName(room));
31+
32+
// Clear any previous captured context
33+
await agentStub.clearCapturedContext();
34+
35+
let resolvePromise: (value: boolean) => void;
36+
const donePromise = new Promise<boolean>((res) => {
37+
resolvePromise = res;
38+
});
39+
40+
const timeout = setTimeout(() => resolvePromise(false), 2000);
41+
42+
ws.addEventListener("message", (e: MessageEvent) => {
43+
const data = JSON.parse(e.data as string);
44+
if (data.type === MessageType.CF_AGENT_USE_CHAT_RESPONSE && data.done) {
45+
clearTimeout(timeout);
46+
resolvePromise(true);
47+
}
48+
});
49+
50+
const userMessage: ChatMessage = {
51+
id: "msg1",
52+
role: "user",
53+
parts: [{ type: "text", text: "Hello" }]
54+
};
55+
56+
// Send a chat message which will trigger onChatMessage
57+
ws.send(
58+
JSON.stringify({
59+
type: MessageType.CF_AGENT_USE_CHAT_REQUEST,
60+
id: "req1",
61+
init: {
62+
method: "POST",
63+
body: JSON.stringify({ messages: [userMessage] })
64+
}
65+
})
66+
);
67+
68+
const done = await donePromise;
69+
expect(done).toBe(true);
70+
71+
// Wait a bit to ensure context is captured
72+
await new Promise((resolve) => setTimeout(resolve, 100));
73+
74+
// Check the captured context from onChatMessage
75+
const capturedContext = await agentStub.getCapturedContext();
76+
77+
expect(capturedContext).not.toBeNull();
78+
// The agent should be available
79+
expect(capturedContext?.hasAgent).toBe(true);
80+
// The connection should be available - this is the bug being tested
81+
// Before the fix, this would be false
82+
expect(capturedContext?.hasConnection).toBe(true);
83+
// The connection ID should be defined
84+
expect(capturedContext?.connectionId).toBeDefined();
85+
86+
// Check the nested context
87+
// Tools called from onChatMessage couldn't access connection context
88+
const nestedContext = await agentStub.getNestedContext();
89+
90+
expect(nestedContext).not.toBeNull();
91+
// The agent should be available in nested async functions
92+
expect(nestedContext?.hasAgent).toBe(true);
93+
// The connection should ALSO be available in nested async functions (tool execute)
94+
// Before the fix, this would be false
95+
expect(nestedContext?.hasConnection).toBe(true);
96+
// The connection ID should match between onChatMessage and nested function
97+
expect(nestedContext?.connectionId).toBe(capturedContext?.connectionId);
98+
99+
ws.close();
100+
});
101+
});

packages/agents/src/tests/worker.ts

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { McpAgent } from "../mcp/index.ts";
1111
import {
1212
Agent,
1313
callable,
14+
getCurrentAgent,
1415
routeAgentRequest,
1516
type AgentEmail,
1617
type Connection,
@@ -571,14 +572,76 @@ export class TestOAuthAgent extends Agent<Env> {
571572

572573
export class TestChatAgent extends AIChatAgent<Env> {
573574
observability = undefined;
575+
// Store captured context for testing
576+
private _capturedContext: {
577+
hasAgent: boolean;
578+
hasConnection: boolean;
579+
connectionId: string | undefined;
580+
} | null = null;
581+
// Store context captured from nested async function (simulates tool execute)
582+
private _nestedContext: {
583+
hasAgent: boolean;
584+
hasConnection: boolean;
585+
connectionId: string | undefined;
586+
} | null = null;
574587

575588
async onChatMessage() {
589+
// Capture getCurrentAgent() context for testing
590+
const { agent, connection } = getCurrentAgent();
591+
this._capturedContext = {
592+
hasAgent: agent !== undefined,
593+
hasConnection: connection !== undefined,
594+
connectionId: connection?.id
595+
};
596+
597+
// Simulate what happens inside a tool's execute function:
598+
// It's a nested async function called from within onChatMessage
599+
await this._simulateToolExecute();
600+
576601
// Simple echo response for testing
577602
return new Response("Hello from chat agent!", {
578603
headers: { "Content-Type": "text/plain" }
579604
});
580605
}
581606

607+
// This simulates an AI SDK tool's execute function being called
608+
private async _simulateToolExecute(): Promise<void> {
609+
// Add a small delay to ensure we're in a new microtask (like real tool execution)
610+
await Promise.resolve();
611+
612+
// Capture context inside the "tool execute" function
613+
const { agent, connection } = getCurrentAgent();
614+
this._nestedContext = {
615+
hasAgent: agent !== undefined,
616+
hasConnection: connection !== undefined,
617+
connectionId: connection?.id
618+
};
619+
}
620+
621+
@callable()
622+
getCapturedContext(): {
623+
hasAgent: boolean;
624+
hasConnection: boolean;
625+
connectionId: string | undefined;
626+
} | null {
627+
return this._capturedContext;
628+
}
629+
630+
@callable()
631+
getNestedContext(): {
632+
hasAgent: boolean;
633+
hasConnection: boolean;
634+
connectionId: string | undefined;
635+
} | null {
636+
return this._nestedContext;
637+
}
638+
639+
@callable()
640+
clearCapturedContext(): void {
641+
this._capturedContext = null;
642+
this._nestedContext = null;
643+
}
644+
582645
@callable()
583646
getPersistedMessages(): ChatMessage[] {
584647
const rawMessages = (

0 commit comments

Comments
 (0)