diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts index 4b0e40f70..44e009a2e 100644 --- a/packages/typescript/ai-client/src/chat-client.ts +++ b/packages/typescript/ai-client/src/chat-client.ts @@ -32,6 +32,7 @@ export class ChatClient { private uniqueId: string private body: Record = {} private pendingMessageBody: Record | undefined = undefined + private context: unknown = undefined private isLoading = false private isSubscribed = false private error: Error | undefined = undefined @@ -81,6 +82,7 @@ export class ChatClient { constructor(options: ChatClientOptions) { this.uniqueId = options.id || this.generateUniqueId('chat') this.body = options.body || {} + this.context = options.context this.connection = normalizeConnectionAdapter(options.connection) this.events = new DefaultChatClientEventEmitter(this.uniqueId) @@ -202,7 +204,7 @@ export class ChatClient { // Create and track the execution promise const executionPromise = (async () => { try { - const output = await executeFunc(args.input) + const output = await executeFunc(args.input, { userContext: this.context }) await this.addToolResult({ toolCallId: args.toolCallId, tool: args.toolName, diff --git a/packages/typescript/ai-client/src/types.ts b/packages/typescript/ai-client/src/types.ts index b705ebbdf..cbaf539a3 100644 --- a/packages/typescript/ai-client/src/types.ts +++ b/packages/typescript/ai-client/src/types.ts @@ -185,6 +185,7 @@ export interface UIMessage = any> { export interface ChatClientOptions< TTools extends ReadonlyArray = any, + TContext = unknown, > { /** * Connection adapter for streaming. @@ -193,6 +194,27 @@ export interface ChatClientOptions< */ connection: ConnectionAdapter + /** + * Context object passed to client-side tool execute functions during execution. + * + * This is client-side only — it is NOT serialized or sent to the server. + * For server-side tool context, pass `context` directly to `chat()` on the server. + * + * Available as `context.userContext` inside tool execute functions. + * + * @example + * const client = new ChatClient({ + * context: { userId: '123', api }, + * tools: [myClientTool], + * }) + * + * // In tool definition: + * const myClientTool = toolDef.client(async (args, context) => { + * return context?.userContext?.api.fetch(args.id) + * }) + */ + context?: TContext + /** * Initial messages to populate the chat */ diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index c7ec866d6..7564d2859 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -712,6 +712,7 @@ class TextEngine< ) }, }, + this.params.context, ) // Consume the async generator, yielding custom events and collecting the return value @@ -856,6 +857,7 @@ class TextEngine< ) }, }, + this.params.context, ) // Consume the async generator, yielding custom events and collecting the return value diff --git a/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts b/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts index f6fe2060e..7348f4b9c 100644 --- a/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts +++ b/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts @@ -151,9 +151,11 @@ export class ToolCallManager { * Execute all tool calls and return tool result messages * Yields TOOL_CALL_END events for streaming * @param finishEvent - RUN_FINISHED event from the stream + * @param userContext - Optional user-provided context passed to tool execute functions */ async *executeTools( finishEvent: RunFinishedEvent, + userContext?: unknown, ): AsyncGenerator, void> { const toolCallsArray = this.getToolCalls() const toolResults: Array = [] @@ -191,7 +193,7 @@ export class ToolCallManager { } // Execute the tool - let result = await tool.execute(args) + let result = await tool.execute(args, userContext !== undefined ? { userContext } : undefined) // Validate output against outputSchema if provided (for Standard Schema compliant schemas) if ( @@ -495,6 +497,7 @@ export async function* executeToolCalls( value: Record, ) => CustomEvent, middlewareHooks?: ToolExecutionMiddlewareHooks, + userContext?: unknown, ): AsyncGenerator { const results: Array = [] const needsApproval: Array = [] @@ -572,6 +575,7 @@ export async function* executeToolCalls( const pendingEvents: Array = [] const context: ToolExecutionContext = { toolCallId: toolCall.id, + userContext, emitCustomEvent: (eventName: string, value: Record) => { if (createCustomEventChunk) { pendingEvents.push( diff --git a/packages/typescript/ai/src/activities/chat/tools/tool-definition.ts b/packages/typescript/ai/src/activities/chat/tools/tool-definition.ts index c12c61898..589abf275 100644 --- a/packages/typescript/ai/src/activities/chat/tools/tool-definition.ts +++ b/packages/typescript/ai/src/activities/chat/tools/tool-definition.ts @@ -25,7 +25,7 @@ export interface ClientTool< TInput extends SchemaInput = SchemaInput, TOutput extends SchemaInput = SchemaInput, TName extends string = string, -> { +> extends Tool { __toolSide: 'client' name: TName description: string @@ -36,6 +36,7 @@ export interface ClientTool< metadata?: Record execute?: ( args: InferSchemaType, + context?: ToolExecutionContext, ) => Promise> | InferSchemaType } @@ -125,6 +126,7 @@ export interface ToolDefinition< client: ( execute?: ( args: InferSchemaType, + context?: ToolExecutionContext, ) => Promise> | InferSchemaType, ) => ClientTool } @@ -203,19 +205,22 @@ export function toolDefinition< return { __toolSide: 'server', ...config, - execute, + execute: (args, context) => execute(args, context), } }, client( execute?: ( args: InferSchemaType, + context?: ToolExecutionContext, ) => Promise> | InferSchemaType, ): ClientTool { return { __toolSide: 'client', ...config, - execute, + execute: execute + ? (args: any, context: any) => execute(args, context) + : undefined, } }, } diff --git a/packages/typescript/ai/src/index.ts b/packages/typescript/ai/src/index.ts index 0d1f7aeef..a219353ee 100644 --- a/packages/typescript/ai/src/index.ts +++ b/packages/typescript/ai/src/index.ts @@ -87,6 +87,8 @@ export type { ErrorInfo, } from './activities/chat/middleware/index' +export type { ToolExecutionContext } from './types' + // All types export * from './types' diff --git a/packages/typescript/ai/src/types.ts b/packages/typescript/ai/src/types.ts index f5ddaee59..7c05249fb 100644 --- a/packages/typescript/ai/src/types.ts +++ b/packages/typescript/ai/src/types.ts @@ -346,7 +346,7 @@ export type ConstrainedModelMessage< /** * Context passed to tool execute functions, providing capabilities like - * emitting custom events during execution. + * emitting custom events during execution and access to user-provided context. */ export interface ToolExecutionContext { /** The ID of the tool call being executed */ @@ -361,15 +361,25 @@ export interface ToolExecutionContext { * @example * ```ts * const tool = toolDefinition({ ... }).server(async (args, context) => { - * context?.emitCustomEvent('progress', { step: 1, total: 3 }) + * context?.emitCustomEvent?.('progress', { step: 1, total: 3 }) * // ... do work ... - * context?.emitCustomEvent('progress', { step: 2, total: 3 }) + * context?.emitCustomEvent?.('progress', { step: 2, total: 3 }) * // ... do more work ... * return result * }) * ``` */ - emitCustomEvent: (eventName: string, value: Record) => void + emitCustomEvent?: (eventName: string, value: Record) => void + /** + * User-provided context passed from chat() options. + * Allows tools to access shared context (e.g. database connections, user ID, request metadata) + * without needing to capture them via closures. + * + * @example + * chat({ context: { db, userId }, ... }) + * // In tool: context?.userContext?.db.users.find(...) + */ + userContext?: unknown } /** @@ -480,10 +490,12 @@ export interface Tool< * Can return any value - will be automatically stringified if needed. * * @param args - The arguments parsed from the model's tool call (validated against inputSchema) + * @param context - SDK context providing toolCallId, emitCustomEvent, and userContext * @returns Result to send back to the model (validated against outputSchema if provided) * * @example - * execute: async (args) => { + * execute: async (args, context) => { + * const user = await context?.userContext?.db.users.find({ id: context.userContext.userId }); // Can access user context * const weather = await fetchWeather(args.location); * return weather; // Can return object or string * } @@ -682,6 +694,29 @@ export interface TextOptions< metadata?: Record modelOptions?: TProviderOptionsForModel request?: Request | RequestInit + /** + * Context object that is automatically passed to all tool execute functions. + * + * This allows tools to access shared context (like user ID, database connections, + * request metadata, etc.) without needing to capture them via closures. + * Works for both server and client tools. + * + * @example + * const stream = chat({ + * adapter: openai(), + * model: 'gpt-4o', + * messages, + * context: { userId: '123', db }, + * tools: [getUserData], + * }); + * + * // In tool definition: + * const getUserData = getUserDataDef.server(async (args, context) => { + * // context.userContext.userId and context.userContext.db are available + * return await context?.userContext?.db.users.find({ userId: context.userContext.userId }); + * }); + */ + context?: unknown /** * Schema for structured output. diff --git a/packages/typescript/ai/tests/custom-events-integration.test.ts b/packages/typescript/ai/tests/custom-events-integration.test.ts index 9fe31fb6c..31f65e1ca 100644 --- a/packages/typescript/ai/tests/custom-events-integration.test.ts +++ b/packages/typescript/ai/tests/custom-events-integration.test.ts @@ -1,7 +1,7 @@ import { describe, expect, it, vi } from 'vitest' +import { z } from 'zod' import { toolDefinition } from '../src/activities/chat/tools/tool-definition' import { StreamProcessor } from '../src/activities/chat/stream/processor' -import { z } from 'zod' describe('Custom Events Integration', () => { it('should emit custom events from tool execution context', async () => { @@ -16,7 +16,7 @@ describe('Custom Events Integration', () => { }), }).server(async (args, context) => { // Emit progress event - context?.emitCustomEvent('tool:progress', { + context?.emitCustomEvent?.('tool:progress', { tool: 'testTool', progress: 25, message: 'Starting processing...', @@ -26,14 +26,14 @@ describe('Custom Events Integration', () => { await new Promise((resolve) => setTimeout(resolve, 10)) // Emit another progress event - context?.emitCustomEvent('tool:progress', { + context?.emitCustomEvent?.('tool:progress', { tool: 'testTool', progress: 75, message: 'Almost done...', }) // Emit completion event - context?.emitCustomEvent('tool:complete', { + context?.emitCustomEvent?.('tool:complete', { tool: 'testTool', result: 'success', duration: 20, diff --git a/packages/typescript/ai/tests/tool-call-manager.test.ts b/packages/typescript/ai/tests/tool-call-manager.test.ts index 546fc7a95..a528db0dd 100644 --- a/packages/typescript/ai/tests/tool-call-manager.test.ts +++ b/packages/typescript/ai/tests/tool-call-manager.test.ts @@ -21,7 +21,7 @@ describe('ToolCallManager', () => { inputSchema: z.object({ location: z.string().optional(), }), - execute: vi.fn((args: any) => { + execute: vi.fn((args: any, _options?: any) => { return JSON.stringify({ temp: 72, location: args.location }) }), } @@ -138,8 +138,8 @@ describe('ToolCallManager', () => { expect(finalResult[0]?.role).toBe('tool') expect(finalResult[0]?.toolCallId).toBe('call_123') - // Tool execute should have been called - expect(mockWeatherTool.execute).toHaveBeenCalledWith({ location: 'Paris' }) + // Tool execute should have been called without context (none provided) + expect(mockWeatherTool.execute).toHaveBeenCalledWith({ location: 'Paris' }, undefined) }) it('should handle tool execution errors gracefully', async () => { @@ -240,15 +240,16 @@ describe('ToolCallManager', () => { }) it('should handle multiple tool calls in same iteration', async () => { - const calculateTool: Tool = { - name: 'calculate', - description: 'Calculate', - inputSchema: z.object({ - expression: z.string(), - }), - execute: vi.fn((args: any) => { - return JSON.stringify({ result: eval(args.expression) }) - }), +const calculateTool: Tool = { + name: 'calculate', + description: 'Calculate', + inputSchema: z.object({ + expression: z.string(), + }), + execute: vi.fn((args: any, _options?: any) => { + const results: Record = { '5+3': 8 } + return JSON.stringify({ result: results[args.expression] ?? 0 }) + }), } const manager = new ToolCallManager([mockWeatherTool, calculateTool]) diff --git a/packages/typescript/ai/tests/tool-definition.test.ts b/packages/typescript/ai/tests/tool-definition.test.ts index 642992c68..a53b39d04 100644 --- a/packages/typescript/ai/tests/tool-definition.test.ts +++ b/packages/typescript/ai/tests/tool-definition.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, vi } from 'vitest' +import { describe, expect, it, vi } from 'vitest' import { z } from 'zod' import { toolDefinition } from '../src/activities/chat/tools/tool-definition' @@ -46,7 +46,7 @@ describe('toolDefinition', () => { }), }) - const executeFn = vi.fn(async (_args: { location: string }) => { + const executeFn = vi.fn((_args: { location: string }, _options?: unknown) => { return { temperature: 72, conditions: 'sunny', @@ -60,9 +60,9 @@ describe('toolDefinition', () => { expect(serverTool.execute).toBeDefined() if (serverTool.execute) { - const result = await serverTool.execute({ location: 'Paris' }) + const result = await serverTool.execute({ location: 'Paris' }, undefined) expect(result).toEqual({ temperature: 72, conditions: 'sunny' }) - expect(executeFn).toHaveBeenCalledWith({ location: 'Paris' }) + expect(executeFn).toHaveBeenCalledWith({ location: 'Paris' }, undefined) } }) @@ -79,7 +79,7 @@ describe('toolDefinition', () => { }), }) - const executeFn = vi.fn(async (_args: { key: string; value: string }) => { + const executeFn = vi.fn(async (_args: { key: string; value: string }, _options?: unknown) => { return { success: true } }) @@ -90,9 +90,9 @@ describe('toolDefinition', () => { expect(clientTool.execute).toBeDefined() if (clientTool.execute) { - const result = await clientTool.execute({ key: 'test', value: 'data' }) + const result = await clientTool.execute({ key: 'test', value: 'data' }, undefined) expect(result).toEqual({ success: true }) - expect(executeFn).toHaveBeenCalledWith({ key: 'test', value: 'data' }) + expect(executeFn).toHaveBeenCalledWith({ key: 'test', value: 'data' }, undefined) } }) @@ -207,7 +207,7 @@ describe('toolDefinition', () => { }) if (serverTool.execute) { - const result = serverTool.execute({ value: 5 }) + const result = serverTool.execute({ value: 5 }, undefined) expect(result).toEqual({ doubled: 10 }) } }) @@ -253,7 +253,7 @@ describe('toolDefinition', () => { orderId: '123', items: [], shipping: { address: '123 Main St', method: 'standard' }, - }) + }, undefined) expect(serverTool.__toolSide).toBe('server') }) @@ -286,4 +286,109 @@ describe('toolDefinition', () => { expect(tool.__toolSide).toBe('definition') expect(tool.inputSchema).toBeDefined() }) + + it('should pass context to server tool execute function', async () => { + const tool = toolDefinition({ + name: 'getContextValue', + description: 'Get a value from context', + inputSchema: z.object({ + key: z.string(), + }), + outputSchema: z.object({ + exists: z.boolean(), + value: z.string().optional(), + }), + }) + + const contextValue = 'test-value' + const context = { testData: contextValue } + + const serverTool = tool.server( + (_: unknown, ctx) => { + const userCtx = ctx?.userContext as typeof context | undefined + const exists = userCtx?.testData !== undefined + const value = exists ? userCtx.testData : undefined + return { exists, value } + } + ) + + if (serverTool.execute) { + const result = await serverTool.execute( + { key: 'testData' }, + { userContext: context, emitCustomEvent: () => {} } + ) + + expect(result.exists).toBe(true) + expect(result.value).toBe(contextValue) + } + }) + + it('should pass context to client tool execute function', async () => { + const tool = toolDefinition({ + name: 'getContextValue', + description: 'Get a value from context', + inputSchema: z.object({ + key: z.string(), + }), + outputSchema: z.object({ + exists: z.boolean(), + value: z.string().optional(), + }), + }) + + const contextValue = 'test-value' + const context = { testData: contextValue } + + const clientTool = tool.client( + (_: unknown, ctx) => { + const userCtx = ctx?.userContext as typeof context | undefined + const exists = userCtx?.testData !== undefined + const value = exists ? userCtx.testData : undefined + return { exists, value } + } + ) + + if (clientTool.execute) { + const result = await clientTool.execute( + { key: 'testData' }, + { userContext: context } + ) + + expect(result.exists).toBe(true) + expect(result.value).toBe(contextValue) + } + }) + + it('should handle missing context gracefully', async () => { + const tool = toolDefinition({ + name: 'getContextValue', + description: 'Get a value from context', + inputSchema: z.object({ + key: z.string(), + }), + outputSchema: z.object({ + exists: z.boolean(), + value: z.string().optional(), + }), + }) + + const serverTool = tool.server( + (_: unknown, ctx) => { + const userCtx = ctx?.userContext as { testData?: string } | undefined + const exists = userCtx?.testData !== undefined + const value = exists ? userCtx.testData : undefined + return { exists, value } + } + ) + + if (serverTool.execute) { + const result = await serverTool.execute( + { key: 'testData' }, + { userContext: {} } // Empty context + ) + + expect(result.exists).toBe(false) + expect(result.value).toBeUndefined() + } + }) }) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 37caa9ab7..8e9eea23a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -9,7 +9,7 @@ overrides: patchedDependencies: '@changesets/assemble-release-plan@6.0.9': - hash: 1bc53c741da20baad9cdeb674c599225dcf9fe6aa7b0de16ff87e63f12a97b24 + hash: dyjlmgavov3d7e6popx3ho7cei path: patches/@changesets__assemble-release-plan@6.0.9.patch importers: @@ -9804,7 +9804,7 @@ snapshots: resolve-from: 5.0.0 semver: 7.7.4 - '@changesets/assemble-release-plan@6.0.9(patch_hash=1bc53c741da20baad9cdeb674c599225dcf9fe6aa7b0de16ff87e63f12a97b24)': + '@changesets/assemble-release-plan@6.0.9(patch_hash=dyjlmgavov3d7e6popx3ho7cei)': dependencies: '@changesets/errors': 0.2.0 '@changesets/get-dependents-graph': 2.1.3 @@ -9820,7 +9820,7 @@ snapshots: '@changesets/cli@2.30.0(@types/node@24.10.3)': dependencies: '@changesets/apply-release-plan': 7.1.0 - '@changesets/assemble-release-plan': 6.0.9(patch_hash=1bc53c741da20baad9cdeb674c599225dcf9fe6aa7b0de16ff87e63f12a97b24) + '@changesets/assemble-release-plan': 6.0.9(patch_hash=dyjlmgavov3d7e6popx3ho7cei) '@changesets/changelog-git': 0.2.1 '@changesets/config': 3.1.3 '@changesets/errors': 0.2.0 @@ -9879,7 +9879,7 @@ snapshots: '@changesets/get-release-plan@4.0.15': dependencies: - '@changesets/assemble-release-plan': 6.0.9(patch_hash=1bc53c741da20baad9cdeb674c599225dcf9fe6aa7b0de16ff87e63f12a97b24) + '@changesets/assemble-release-plan': 6.0.9(patch_hash=dyjlmgavov3d7e6popx3ho7cei) '@changesets/config': 3.1.3 '@changesets/pre': 2.0.2 '@changesets/read': 0.6.7