Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion packages/typescript/ai-client/src/chat-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export class ChatClient {
private uniqueId: string
private body: Record<string, any> = {}
private pendingMessageBody: Record<string, any> | undefined = undefined
private context: unknown = undefined
private isLoading = false
private isSubscribed = false
private error: Error | undefined = undefined
Expand Down Expand Up @@ -81,6 +82,7 @@ export class ChatClient {
constructor(options: ChatClientOptions) {
this.uniqueId = options.id || this.generateUniqueId('chat')
this.body = options.body || {}
this.context = options.context
this.connection = normalizeConnectionAdapter(options.connection)
this.events = new DefaultChatClientEventEmitter(this.uniqueId)

Expand Down Expand Up @@ -202,7 +204,7 @@ export class ChatClient {
// Create and track the execution promise
const executionPromise = (async () => {
try {
const output = await executeFunc(args.input)
const output = await executeFunc(args.input, { userContext: this.context })
await this.addToolResult({
toolCallId: args.toolCallId,
tool: args.toolName,
Expand Down
22 changes: 22 additions & 0 deletions packages/typescript/ai-client/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ export interface UIMessage<TTools extends ReadonlyArray<AnyClientTool> = any> {

export interface ChatClientOptions<
TTools extends ReadonlyArray<AnyClientTool> = any,
TContext = unknown,
> {
/**
* Connection adapter for streaming.
Expand All @@ -193,6 +194,27 @@ export interface ChatClientOptions<
*/
connection: ConnectionAdapter

/**
* Context object passed to client-side tool execute functions during execution.
*
* This is client-side only — it is NOT serialized or sent to the server.
* For server-side tool context, pass `context` directly to `chat()` on the server.
*
* Available as `context.userContext` inside tool execute functions.
*
* @example
* const client = new ChatClient({
* context: { userId: '123', api },
* tools: [myClientTool],
* })
*
* // In tool definition:
* const myClientTool = toolDef.client(async (args, context) => {
* return context?.userContext?.api.fetch(args.id)
* })
*/
context?: TContext

/**
* Initial messages to populate the chat
*/
Expand Down
2 changes: 2 additions & 0 deletions packages/typescript/ai/src/activities/chat/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ class TextEngine<
)
},
},
this.params.context,
)

// Consume the async generator, yielding custom events and collecting the return value
Expand Down Expand Up @@ -856,6 +857,7 @@ class TextEngine<
)
},
},
this.params.context,
)

// Consume the async generator, yielding custom events and collecting the return value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,11 @@ export class ToolCallManager {
* Execute all tool calls and return tool result messages
* Yields TOOL_CALL_END events for streaming
* @param finishEvent - RUN_FINISHED event from the stream
* @param userContext - Optional user-provided context passed to tool execute functions
*/
async *executeTools(
finishEvent: RunFinishedEvent,
userContext?: unknown,
): AsyncGenerator<ToolCallEndEvent, Array<ModelMessage>, void> {
const toolCallsArray = this.getToolCalls()
const toolResults: Array<ModelMessage> = []
Expand Down Expand Up @@ -191,7 +193,7 @@ export class ToolCallManager {
}

// Execute the tool
let result = await tool.execute(args)
let result = await tool.execute(args, userContext !== undefined ? { userContext } : undefined)

// Validate output against outputSchema if provided (for Standard Schema compliant schemas)
if (
Expand Down Expand Up @@ -495,6 +497,7 @@ export async function* executeToolCalls(
value: Record<string, any>,
) => CustomEvent,
middlewareHooks?: ToolExecutionMiddlewareHooks,
userContext?: unknown,
): AsyncGenerator<CustomEvent, ExecuteToolCallsResult, void> {
const results: Array<ToolResult> = []
const needsApproval: Array<ApprovalRequest> = []
Expand Down Expand Up @@ -572,6 +575,7 @@ export async function* executeToolCalls(
const pendingEvents: Array<CustomEvent> = []
const context: ToolExecutionContext = {
toolCallId: toolCall.id,
userContext,
emitCustomEvent: (eventName: string, value: Record<string, any>) => {
if (createCustomEventChunk) {
pendingEvents.push(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export interface ClientTool<
TInput extends SchemaInput = SchemaInput,
TOutput extends SchemaInput = SchemaInput,
TName extends string = string,
> {
> extends Tool<TInput, TOutput, TName> {
__toolSide: 'client'
name: TName
description: string
Expand All @@ -36,6 +36,7 @@ export interface ClientTool<
metadata?: Record<string, unknown>
execute?: (
args: InferSchemaType<TInput>,
context?: ToolExecutionContext,
) => Promise<InferSchemaType<TOutput>> | InferSchemaType<TOutput>
}

Expand Down Expand Up @@ -125,6 +126,7 @@ export interface ToolDefinition<
client: (
execute?: (
args: InferSchemaType<TInput>,
context?: ToolExecutionContext,
) => Promise<InferSchemaType<TOutput>> | InferSchemaType<TOutput>,
) => ClientTool<TInput, TOutput, TName>
}
Expand Down Expand Up @@ -203,19 +205,22 @@ export function toolDefinition<
return {
__toolSide: 'server',
...config,
execute,
execute: (args, context) => execute(args, context),
}
},

client(
execute?: (
args: InferSchemaType<TInput>,
context?: ToolExecutionContext,
) => Promise<InferSchemaType<TOutput>> | InferSchemaType<TOutput>,
): ClientTool<TInput, TOutput, TName> {
return {
__toolSide: 'client',
...config,
execute,
execute: execute
? (args: any, context: any) => execute(args, context)
: undefined,
}
},
}
Expand Down
2 changes: 2 additions & 0 deletions packages/typescript/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ export type {
ErrorInfo,
} from './activities/chat/middleware/index'

export type { ToolExecutionContext } from './types'

// All types
export * from './types'

Expand Down
45 changes: 40 additions & 5 deletions packages/typescript/ai/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ export type ConstrainedModelMessage<

/**
* Context passed to tool execute functions, providing capabilities like
* emitting custom events during execution.
* emitting custom events during execution and access to user-provided context.
*/
export interface ToolExecutionContext {
/** The ID of the tool call being executed */
Expand All @@ -361,15 +361,25 @@ export interface ToolExecutionContext {
* @example
* ```ts
* const tool = toolDefinition({ ... }).server(async (args, context) => {
* context?.emitCustomEvent('progress', { step: 1, total: 3 })
* context?.emitCustomEvent?.('progress', { step: 1, total: 3 })
* // ... do work ...
* context?.emitCustomEvent('progress', { step: 2, total: 3 })
* context?.emitCustomEvent?.('progress', { step: 2, total: 3 })
* // ... do more work ...
* return result
* })
* ```
*/
emitCustomEvent: (eventName: string, value: Record<string, any>) => void
emitCustomEvent?: (eventName: string, value: Record<string, any>) => void
/**
* User-provided context passed from chat() options.
* Allows tools to access shared context (e.g. database connections, user ID, request metadata)
* without needing to capture them via closures.
*
* @example
* chat({ context: { db, userId }, ... })
* // In tool: context?.userContext?.db.users.find(...)
*/
userContext?: unknown
}

/**
Expand Down Expand Up @@ -480,10 +490,12 @@ export interface Tool<
* Can return any value - will be automatically stringified if needed.
*
* @param args - The arguments parsed from the model's tool call (validated against inputSchema)
* @param context - SDK context providing toolCallId, emitCustomEvent, and userContext
* @returns Result to send back to the model (validated against outputSchema if provided)
*
* @example
* execute: async (args) => {
* execute: async (args, context) => {
* const user = await context?.userContext?.db.users.find({ id: context.userContext.userId }); // Can access user context
* const weather = await fetchWeather(args.location);
* return weather; // Can return object or string
* }
Expand Down Expand Up @@ -682,6 +694,29 @@ export interface TextOptions<
metadata?: Record<string, any>
modelOptions?: TProviderOptionsForModel
request?: Request | RequestInit
/**
* Context object that is automatically passed to all tool execute functions.
*
* This allows tools to access shared context (like user ID, database connections,
* request metadata, etc.) without needing to capture them via closures.
* Works for both server and client tools.
*
* @example
* const stream = chat({
* adapter: openai(),
* model: 'gpt-4o',
* messages,
* context: { userId: '123', db },
* tools: [getUserData],
* });
*
* // In tool definition:
* const getUserData = getUserDataDef.server(async (args, context) => {
* // context.userContext.userId and context.userContext.db are available
* return await context?.userContext?.db.users.find({ userId: context.userContext.userId });
* });
*/
context?: unknown

/**
* Schema for structured output.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { describe, expect, it, vi } from 'vitest'
import { z } from 'zod'
import { toolDefinition } from '../src/activities/chat/tools/tool-definition'
import { StreamProcessor } from '../src/activities/chat/stream/processor'
import { z } from 'zod'

describe('Custom Events Integration', () => {
it('should emit custom events from tool execution context', async () => {
Expand All @@ -16,7 +16,7 @@ describe('Custom Events Integration', () => {
}),
}).server(async (args, context) => {
// Emit progress event
context?.emitCustomEvent('tool:progress', {
context?.emitCustomEvent?.('tool:progress', {
tool: 'testTool',
progress: 25,
message: 'Starting processing...',
Expand All @@ -26,14 +26,14 @@ describe('Custom Events Integration', () => {
await new Promise((resolve) => setTimeout(resolve, 10))

// Emit another progress event
context?.emitCustomEvent('tool:progress', {
context?.emitCustomEvent?.('tool:progress', {
tool: 'testTool',
progress: 75,
message: 'Almost done...',
})

// Emit completion event
context?.emitCustomEvent('tool:complete', {
context?.emitCustomEvent?.('tool:complete', {
tool: 'testTool',
result: 'success',
duration: 20,
Expand Down
25 changes: 13 additions & 12 deletions packages/typescript/ai/tests/tool-call-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ describe('ToolCallManager', () => {
inputSchema: z.object({
location: z.string().optional(),
}),
execute: vi.fn((args: any) => {
execute: vi.fn((args: any, _options?: any) => {
return JSON.stringify({ temp: 72, location: args.location })
}),
}
Expand Down Expand Up @@ -138,8 +138,8 @@ describe('ToolCallManager', () => {
expect(finalResult[0]?.role).toBe('tool')
expect(finalResult[0]?.toolCallId).toBe('call_123')

// Tool execute should have been called
expect(mockWeatherTool.execute).toHaveBeenCalledWith({ location: 'Paris' })
// Tool execute should have been called without context (none provided)
expect(mockWeatherTool.execute).toHaveBeenCalledWith({ location: 'Paris' }, undefined)
})

it('should handle tool execution errors gracefully', async () => {
Expand Down Expand Up @@ -240,15 +240,16 @@ describe('ToolCallManager', () => {
})

it('should handle multiple tool calls in same iteration', async () => {
const calculateTool: Tool = {
name: 'calculate',
description: 'Calculate',
inputSchema: z.object({
expression: z.string(),
}),
execute: vi.fn((args: any) => {
return JSON.stringify({ result: eval(args.expression) })
}),
const calculateTool: Tool = {
name: 'calculate',
description: 'Calculate',
inputSchema: z.object({
expression: z.string(),
}),
execute: vi.fn((args: any, _options?: any) => {
const results: Record<string, number> = { '5+3': 8 }
return JSON.stringify({ result: results[args.expression] ?? 0 })
}),
}

const manager = new ToolCallManager([mockWeatherTool, calculateTool])
Expand Down
Loading