Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import asyncio
import base64
import contextvars
import json
import logging
import re
Expand Down Expand Up @@ -61,7 +60,6 @@ class MCPSpecificApproval(TypedDict, total=False):

_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers")
MCP_DEFAULT_TIMEOUT = 30
MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5

Expand Down Expand Up @@ -1760,6 +1758,8 @@ def __init__(
self.terminate_on_close = terminate_on_close
self._httpx_client: AsyncClient | None = http_client
self._header_provider = header_provider
self._active_call_headers: dict[str, str] = {}
self._header_provider_call_lock = asyncio.Lock()

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
"""Get an MCP streamable HTTP client.
Expand All @@ -1784,8 +1784,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
async def _inject_headers(request: Request) -> None: # noqa: RUF029
if _url_origin(request.url) != target_origin:
return
headers = _mcp_call_headers.get({})
for key, value in headers.items():
for key, value in self._active_call_headers.items():
request.headers[key] = value

self._inject_headers_hook = _inject_headers # type: ignore[attr-defined]
Expand All @@ -1802,8 +1801,8 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:

When a ``header_provider`` was supplied at construction time, the runtime
*kwargs* (originating from ``FunctionInvocationContext.kwargs``) are passed
to the provider. The returned headers are attached to every HTTP request
made during this tool call via a ``contextvars.ContextVar``.
to the provider. The returned headers are attached to every HTTP request
made during this tool call.

Args:
tool_name: The name of the tool to call.
Expand All @@ -1815,12 +1814,13 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
A list of Content items representing the tool output.
"""
if self._header_provider is not None:
headers = self._header_provider(kwargs)
token = _mcp_call_headers.set(headers)
try:
return await super().call_tool(tool_name, **kwargs)
finally:
_mcp_call_headers.reset(token)
async with self._header_provider_call_lock:
previous_headers = self._active_call_headers
self._active_call_headers = dict(self._header_provider(kwargs))
try:
return await super().call_tool(tool_name, **kwargs)
finally:
self._active_call_headers = previous_headers
return await super().call_tool(tool_name, **kwargs)


Expand Down
120 changes: 60 additions & 60 deletions python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4639,19 +4639,13 @@ def provider(kwargs):
server.session.call_tool.assert_called_once()


async def test_mcp_streamable_http_tool_header_provider_sets_contextvar():
"""Test that call_tool sets the contextvar with headers from header_provider."""
from agent_framework._mcp import _mcp_call_headers

async def test_mcp_streamable_http_tool_header_provider_sets_active_headers():
"""Test that call_tool exposes headers from header_provider during the call."""
observed_headers: list[dict[str, str]] = []
original_call_tool = MCPTool.call_tool

async def spy_call_tool(self, tool_name, **kwargs):
# Capture the contextvar value during the super call
try:
observed_headers.append(_mcp_call_headers.get())
except LookupError:
observed_headers.append({})
observed_headers.append(dict(self._active_call_headers))
return await original_call_tool(self, tool_name, **kwargs)

class _TestServer(MCPStreamableHTTPTool):
Expand Down Expand Up @@ -4692,10 +4686,8 @@ def get_mcp_client(self):
assert observed_headers[0] == {"X-Auth": "bearer-xyz"}


async def test_mcp_streamable_http_tool_header_provider_contextvar_reset_after_call():
"""Test that the contextvar is properly reset after call_tool completes."""
from agent_framework._mcp import _mcp_call_headers

async def test_mcp_streamable_http_tool_header_provider_active_headers_reset_after_call():
"""Test that active headers are reset after call_tool completes."""
class _TestServer(MCPStreamableHTTPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
Expand Down Expand Up @@ -4728,9 +4720,7 @@ def get_mcp_client(self):
await server.load_tools()
await server.call_tool("greet", name="Alice", token="secret")

# After call_tool, the contextvar should be unset (reset to no value)
with pytest.raises(LookupError):
_mcp_call_headers.get()
assert server._active_call_headers == {}


async def test_mcp_streamable_http_tool_without_header_provider():
Expand Down Expand Up @@ -4773,10 +4763,10 @@ def get_mcp_client(self):


async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
"""Test that the httpx event hook injects headers from the contextvar."""
"""Test that the httpx event hook injects the tool's active headers."""
import httpx

from agent_framework._mcp import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, _mcp_call_headers
from agent_framework._mcp import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT

tool = MCPStreamableHTTPTool(
name="test",
Expand All @@ -4797,14 +4787,10 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
hooks = tool._httpx_client.event_hooks.get("request", [])
assert len(hooks) == 1, "Expected one request event hook"

# Simulate what happens during a call_tool: contextvar is set
token = _mcp_call_headers.set({"X-Custom": "test-value"})
try:
request = httpx.Request("POST", "http://example.com/mcp")
await hooks[0](request)
assert request.headers.get("X-Custom") == "test-value"
finally:
_mcp_call_headers.reset(token)
tool._active_call_headers = {"X-Custom": "test-value"}
request = httpx.Request("POST", "http://example.com/mcp")
await hooks[0](request)
assert request.headers.get("X-Custom") == "test-value"
finally:
# Ensure any created httpx client is properly closed
if getattr(tool, "_httpx_client", None) is not None:
Expand All @@ -4815,8 +4801,6 @@ async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redir
"""The request hook must not re-add caller headers after a cross-origin redirect."""
import httpx

from agent_framework._mcp import _mcp_call_headers

tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com/mcp",
Expand All @@ -4831,17 +4815,45 @@ async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redir
hooks = tool._httpx_client.event_hooks.get("request", [])
assert len(hooks) == 1

token = _mcp_call_headers.set({"Authorization": "Bearer secret"})
try:
same_origin = httpx.Request("POST", "http://example.com/redirected")
await hooks[0](same_origin)
assert same_origin.headers.get("Authorization") == "Bearer secret"

cross_origin = httpx.Request("POST", "http://attacker.example/capture")
await hooks[0](cross_origin)
assert "Authorization" not in cross_origin.headers
finally:
_mcp_call_headers.reset(token)
tool._active_call_headers = {"Authorization": "Bearer secret"}

same_origin = httpx.Request("POST", "http://example.com/redirected")
await hooks[0](same_origin)
assert same_origin.headers.get("Authorization") == "Bearer secret"

cross_origin = httpx.Request("POST", "http://attacker.example/capture")
await hooks[0](cross_origin)
assert "Authorization" not in cross_origin.headers
finally:
if getattr(tool, "_httpx_client", None) is not None:
await tool._httpx_client.aclose()


async def test_mcp_streamable_http_tool_header_provider_hook_reads_headers_from_transport_task():
"""Test that request hooks can read updated headers from another task."""
import httpx

tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com/mcp",
header_provider=lambda kw: {"X-Custom": kw.get("custom", "")},
)

try:
with patch("agent_framework._mcp.streamable_http_client"):
tool.get_mcp_client()

assert tool._httpx_client is not None
hooks = tool._httpx_client.event_hooks.get("request", [])
assert len(hooks) == 1

async def run_hook_in_transport_task() -> str | None:
request = httpx.Request("POST", "http://example.com/mcp")
await hooks[0](request)
return request.headers.get("X-Custom")

tool._active_call_headers = {"X-Custom": "test-value"}
assert await asyncio.create_task(run_hook_in_transport_task()) == "test-value"
finally:
if getattr(tool, "_httpx_client", None) is not None:
await tool._httpx_client.aclose()
Expand All @@ -4851,8 +4863,6 @@ async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client()
"""Test that header_provider works when the user provides their own httpx client."""
import httpx

from agent_framework._mcp import _mcp_call_headers

user_client = httpx.AsyncClient(headers={"X-Base": "static"})

tool = MCPStreamableHTTPTool(
Expand All @@ -4870,14 +4880,10 @@ async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client()
hooks = user_client.event_hooks.get("request", [])
assert len(hooks) == 1

# Verify the hook injects headers
token = _mcp_call_headers.set({"X-Dynamic": "per-request"})
try:
request = httpx.Request("POST", "http://example.com/mcp")
await hooks[0](request)
assert request.headers.get("X-Dynamic") == "per-request"
finally:
_mcp_call_headers.reset(token)
tool._active_call_headers = {"X-Dynamic": "per-request"}
request = httpx.Request("POST", "http://example.com/mcp")
await hooks[0](request)
assert request.headers.get("X-Dynamic") == "per-request"

await user_client.aclose()

Expand All @@ -4888,19 +4894,12 @@ async def test_mcp_streamable_http_tool_header_provider_via_invoke_with_context(
This exercises the full pipeline: FunctionInvocationContext.kwargs -> FunctionTool.invoke
-> MCPStreamableHTTPTool.call_tool -> header_provider.
"""
from agent_framework._mcp import _mcp_call_headers

observed_headers: list[dict[str, str]] = []
original_call_tool = MCPStreamableHTTPTool.call_tool
original_call_tool = MCPTool.call_tool

async def spy_call_tool(self, tool_name, **kwargs):
# Capture the contextvar value set by call_tool before delegating
result = await original_call_tool(self, tool_name, **kwargs)
try:
observed_headers.append(_mcp_call_headers.get())
except LookupError:
observed_headers.append({})
return result
observed_headers.append(dict(self._active_call_headers))
return await original_call_tool(self, tool_name, **kwargs)

class _TestServer(MCPStreamableHTTPTool):
async def connect(self):
Expand Down Expand Up @@ -4951,7 +4950,7 @@ def provider(kwargs):
kwargs={"some_token": "my-secret"},
)

with patch.object(MCPStreamableHTTPTool, "call_tool", spy_call_tool):
with patch.object(MCPTool, "call_tool", spy_call_tool):
result = await func.invoke(arguments={"name": "Alice"}, context=context)

# Verify the invoke produced a result
Expand All @@ -4961,6 +4960,7 @@ def provider(kwargs):
# Verify header_provider was called with the runtime kwargs
assert len(provider_received) == 1
assert provider_received[0]["some_token"] == "my-secret"
assert observed_headers == [{"X-Some-Token": "my-secret"}]

# Verify session.call_tool was called with the tool arguments (not the runtime kwargs)
server.session.call_tool.assert_called_once()
Expand Down