Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ async def get_task(
params = MessageToDict(request)
if 'id' in params:
del params['id'] # id is part of the URL path
if 'tenant' in params:
del params['tenant']

response_data = await self._execute_request(
'GET',
Expand All @@ -127,12 +129,16 @@ async def list_tasks(
context: ClientCallContext | None = None,
) -> ListTasksResponse:
"""Retrieves tasks for an agent."""
params = MessageToDict(request)
if 'tenant' in params:
del params['tenant']

response_data = await self._execute_request(
'GET',
'/tasks',
request.tenant,
context=context,
params=MessageToDict(request),
params=params,
)
response: ListTasksResponse = ParseDict(
response_data, ListTasksResponse()
Expand Down Expand Up @@ -185,8 +191,10 @@ async def get_task_push_notification_config(
params = MessageToDict(request)
if 'id' in params:
del params['id']
if 'task_id' in params:
del params['task_id']
if 'taskId' in params:
del params['taskId']
if 'tenant' in params:
del params['tenant']

response_data = await self._execute_request(
'GET',
Expand All @@ -208,8 +216,10 @@ async def list_task_push_notification_configs(
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task."""
params = MessageToDict(request)
if 'task_id' in params:
del params['task_id']
if 'taskId' in params:
del params['taskId']
if 'tenant' in params:
del params['tenant']

response_data = await self._execute_request(
'GET',
Expand All @@ -233,8 +243,10 @@ async def delete_task_push_notification_config(
params = MessageToDict(request)
if 'id' in params:
del params['id']
if 'task_id' in params:
del params['task_id']
if 'taskId' in params:
del params['taskId']
if 'tenant' in params:
del params['tenant']

await self._execute_request(
'DELETE',
Expand Down
30 changes: 9 additions & 21 deletions src/a2a/server/request_handlers/rest_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
MessageToDict,
MessageToJson,
Parse,
ParseDict,
)


Expand All @@ -27,7 +26,6 @@
AgentCard,
CancelTaskRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
SubscribeToTaskRequest,
)
from a2a.utils import proto_utils
Expand Down Expand Up @@ -220,12 +218,11 @@ async def set_push_notification(
(due to the `@validate` decorator), A2AError if processing error is
found.
"""
task_id = request.path_params['id']
body = await request.body()
params = a2a_pb2.TaskPushNotificationConfig()
Parse(body, params)
# Set the parent to the task resource name format
params.task_id = task_id
params.task_id = request.path_params['id']
config = (
await self.request_handler.on_create_task_push_notification_config(
params, context
Expand All @@ -247,10 +244,9 @@ async def on_get_task(
Returns:
A `Task` object containing the Task.
"""
task_id = request.path_params['id']
history_length_str = request.query_params.get('historyLength')
history_length = int(history_length_str) if history_length_str else None
params = GetTaskRequest(id=task_id, history_length=history_length)
params = a2a_pb2.GetTaskRequest()
proto_utils.parse_params(request.query_params, params)
params.id = request.path_params['id']
task = await self.request_handler.on_get_task(params, context)
if task:
return MessageToDict(task)
Expand Down Expand Up @@ -295,12 +291,8 @@ async def list_tasks(
A list of `dict` representing the `Task` objects.
"""
params = a2a_pb2.ListTasksRequest()
# Parse query params, keeping arrays/repeated fields in mind if there are any
# Using a simple ParseDict for now, might need more robust query param parsing
# if the request structure contains nested or repeated elements
ParseDict(
dict(request.query_params), params, ignore_unknown_fields=True
)
proto_utils.parse_params(request.query_params, params)

result = await self.request_handler.on_list_tasks(params, context)
return MessageToDict(result)

Expand All @@ -318,13 +310,9 @@ async def list_push_notifications(
Returns:
A list of `dict` representing the `TaskPushNotificationConfig` objects.
"""
task_id = request.path_params['id']
params = a2a_pb2.ListTaskPushNotificationConfigsRequest(task_id=task_id)

# Parse query params, keeping arrays/repeated fields in mind if there are any
ParseDict(
dict(request.query_params), params, ignore_unknown_fields=True
)
params = a2a_pb2.ListTaskPushNotificationConfigsRequest()
proto_utils.parse_params(request.query_params, params)
params.task_id = request.path_params['id']

result = (
await self.request_handler.on_list_task_push_notification_configs(
Expand Down
60 changes: 59 additions & 1 deletion src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,19 @@
This module provides helper functions for common proto type operations.
"""

from typing import Any
from typing import TYPE_CHECKING, Any

from google.protobuf.json_format import ParseDict
from google.protobuf.message import Message as ProtobufMessage


if TYPE_CHECKING:
from starlette.datastructures import QueryParams
else:
try:
from starlette.datastructures import QueryParams
except ImportError:
QueryParams = Any

from a2a.types.a2a_pb2 import (
Message,
Expand Down Expand Up @@ -131,3 +143,49 @@ def parse_string_integers_in_dict(value: Any, max_safe_digits: int = 15) -> Any:
if stripped_value.isdigit() and len(stripped_value) > max_safe_digits:
return int(value)
return value


def parse_params(params: QueryParams, message: ProtobufMessage) -> None:
"""Converts REST query parameters back into a Protobuf message.

Handles A2A-specific pre-processing before calling ParseDict:
- Booleans: 'true'/'false' -> True/False
- Repeated: Supports BOTH repeated keys and comma-separated values.
- Others: Handles string->enum/timestamp/number conversion via ParseDict.

See Also:
https://a2a-protocol.org/latest/specification/#115-query-parameter-naming-for-request-parameters
"""
descriptor = message.DESCRIPTOR
fields = {f.camelcase_name: f for f in descriptor.fields}
processed: dict[str, Any] = {}

keys = params.keys()

for k in keys:
if k not in fields:
continue

field = fields[k]
v_list = params.getlist(k)

if field.label == field.LABEL_REPEATED:
accumulated: list[Any] = []
for v in v_list:
if not v:
continue
if isinstance(v, str):
accumulated.extend([x for x in v.split(',') if x])
else:
accumulated.append(v)
processed[k] = accumulated
else:
# For non-repeated fields, the last one wins.
raw_val = v_list[-1]
if raw_val is not None:
parsed_val: Any = raw_val
if field.type == field.TYPE_BOOL and isinstance(raw_val, str):
parsed_val = raw_val.lower() == 'true'
processed[k] = parsed_val

ParseDict(processed, message, ignore_unknown_fields=True)
48 changes: 45 additions & 3 deletions tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from google.protobuf import json_format
from google.protobuf.timestamp_pb2 import Timestamp
from httpx_sse import EventSource, ServerSentEvent

from a2a.client import create_text_message_object
Expand All @@ -16,16 +17,16 @@
AgentCard,
AgentInterface,
CancelTaskRequest,
TaskPushNotificationConfig,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
ListTasksRequest,
Message,
SendMessageRequest,
SubscribeToTaskRequest,
TaskPushNotificationConfig,
TaskState,
)
from a2a.utils.constants import TransportProtocol
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
Expand Down Expand Up @@ -175,6 +176,47 @@ async def test_send_message_with_timeout_context(
assert 'timeout' in kwargs
assert kwargs['timeout'] == httpx.Timeout(10.0)

@pytest.mark.asyncio
async def test_url_serialization(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
"""Test that query parameters are correctly serialized to the URL."""
client = RestTransport(
httpx_client=mock_httpx_client,
agent_card=mock_agent_card,
url='http://agent.example.com/api',
)

timestamp = Timestamp()
timestamp.FromJsonString('2024-03-09T16:00:00Z')

request = ListTasksRequest(
tenant='my-tenant',
status=TaskState.TASK_STATE_WORKING,
include_artifacts=True,
status_timestamp_after=timestamp,
)

# Use real build_request to get actual URL serialization
mock_httpx_client.build_request.side_effect = (
httpx.AsyncClient().build_request
)
mock_httpx_client.send.return_value = AsyncMock(
spec=httpx.Response, status_code=200, json=lambda: {'tasks': []}
)

await client.list_tasks(request=request)

mock_httpx_client.send.assert_called_once()
sent_request = mock_httpx_client.send.call_args[0][0]

# Check decoded query parameters for spec compliance
params = sent_request.url.params
assert params['status'] == 'TASK_STATE_WORKING'
assert params['includeArtifacts'] == 'true'
assert params['statusTimestampAfter'] == '2024-03-09T16:00:00Z'
assert 'tenant' not in params


class TestRestTransportExtensions:
@pytest.mark.asyncio
Expand Down Expand Up @@ -616,7 +658,7 @@ async def test_rest_get_task_prepend_empty_tenant(

# 3. Verify the URL
args, _ = mock_httpx_client.build_request.call_args
assert args[1] == f'http://agent.example.com/api/tasks/task-123'
assert args[1] == 'http://agent.example.com/api/tasks/task-123'

@pytest.mark.parametrize(
'method_name, request_obj, expected_path',
Expand Down
Loading
Loading