Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,21 @@ async def on_cancel_task(

consumer = EventConsumer(queue)
result = await result_aggregator.consume_all(consumer)
if isinstance(result, Task):
return result
if not isinstance(result, Task):
raise ServerError(
error=InternalError(
message='Agent did not return valid response for cancel'
)
)

raise ServerError(
error=InternalError(
message='Agent did not return valid response for cancel'
if result.status.state != TaskState.canceled:
raise ServerError(
error=TaskNotCancelableError(
message=f'Task cannot be canceled - current state: {result.status.state}'
)
)
)

return result

async def _run_event_stream(
self, request: RequestContext, queue: EventQueue
Expand Down
50 changes: 50 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,56 @@ async def test_on_cancel_task_cancels_running_agent():
mock_agent_executor.cancel.assert_awaited_once()


@pytest.mark.asyncio
async def test_on_cancel_task_completes_during_cancellation():
"""Test on_cancel_task fails to cancel a task due to concurrent task completion."""
task_id = 'running_agent_task_to_cancel'
sample_task = create_sample_task(task_id=task_id)
mock_task_store = AsyncMock(spec=TaskStore)
mock_task_store.get.return_value = sample_task

mock_queue_manager = AsyncMock(spec=QueueManager)
mock_event_queue = AsyncMock(spec=EventQueue)
mock_queue_manager.tap.return_value = mock_event_queue

mock_agent_executor = AsyncMock(spec=AgentExecutor)

# Mock ResultAggregator
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
mock_result_aggregator_instance.consume_all.return_value = (
create_sample_task(task_id=task_id, status_state=TaskState.completed)
)

request_handler = DefaultRequestHandler(
agent_executor=mock_agent_executor,
task_store=mock_task_store,
queue_manager=mock_queue_manager,
)

# Simulate a running agent task
mock_producer_task = AsyncMock(spec=asyncio.Task)
request_handler._running_agents[task_id] = mock_producer_task

from a2a.utils.errors import (
ServerError, # Local import
TaskNotCancelableError, # Local import
)

with patch(
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
return_value=mock_result_aggregator_instance,
):
params = TaskIdParams(id=task_id)
with pytest.raises(ServerError) as exc_info:
await request_handler.on_cancel_task(
params, create_server_call_context()
)

mock_producer_task.cancel.assert_called_once()
mock_agent_executor.cancel.assert_awaited_once()
assert isinstance(exc_info.value.error, TaskNotCancelableError)


@pytest.mark.asyncio
async def test_on_cancel_task_invalid_result_type():
"""Test on_cancel_task when result_aggregator returns a Message instead of a Task."""
Expand Down
2 changes: 2 additions & 0 deletions tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ async def test_on_cancel_task_success(self) -> None:
call_context = ServerCallContext(state={'foo': 'bar'})

async def streaming_coro():
mock_task.status.state = TaskState.canceled
yield mock_task

with patch(
Expand All @@ -160,6 +161,7 @@ async def streaming_coro():
assert mock_agent_executor.cancel.call_count == 1
self.assertIsInstance(response.root, CancelTaskSuccessResponse)
assert response.root.result == mock_task # type: ignore
assert response.root.result.status.state == TaskState.canceled
mock_agent_executor.cancel.assert_called_once()

async def test_on_cancel_task_not_supported(self) -> None:
Expand Down
Loading