diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index fd378cf47..6a38933f6 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -180,14 +180,21 @@ async def on_cancel_task( consumer = EventConsumer(queue) result = await result_aggregator.consume_all(consumer) - if isinstance(result, Task): - return result + if not isinstance(result, Task): + raise ServerError( + error=InternalError( + message='Agent did not return valid response for cancel' + ) + ) - raise ServerError( - error=InternalError( - message='Agent did not return valid response for cancel' + if result.status.state != TaskState.canceled: + raise ServerError( + error=TaskNotCancelableError( + message=f'Task cannot be canceled - current state: {result.status.state}' + ) ) - ) + + return result async def _run_event_stream( self, request: RequestContext, queue: EventQueue diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index e8906554a..88fb7d3e5 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -263,6 +263,56 @@ async def test_on_cancel_task_cancels_running_agent(): mock_agent_executor.cancel.assert_awaited_once() +@pytest.mark.asyncio +async def test_on_cancel_task_completes_during_cancellation(): + """Test on_cancel_task fails to cancel a task due to concurrent task completion.""" + task_id = 'running_agent_task_to_cancel' + sample_task = create_sample_task(task_id=task_id) + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = sample_task + + mock_queue_manager = AsyncMock(spec=QueueManager) + mock_event_queue = AsyncMock(spec=EventQueue) + mock_queue_manager.tap.return_value = mock_event_queue + + mock_agent_executor = AsyncMock(spec=AgentExecutor) + + # Mock ResultAggregator + mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) + mock_result_aggregator_instance.consume_all.return_value = ( + create_sample_task(task_id=task_id, status_state=TaskState.completed) + ) + + request_handler = DefaultRequestHandler( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + queue_manager=mock_queue_manager, + ) + + # Simulate a running agent task + mock_producer_task = AsyncMock(spec=asyncio.Task) + request_handler._running_agents[task_id] = mock_producer_task + + from a2a.utils.errors import ( + ServerError, # Local import + TaskNotCancelableError, # Local import + ) + + with patch( + 'a2a.server.request_handlers.default_request_handler.ResultAggregator', + return_value=mock_result_aggregator_instance, + ): + params = TaskIdParams(id=task_id) + with pytest.raises(ServerError) as exc_info: + await request_handler.on_cancel_task( + params, create_server_call_context() + ) + + mock_producer_task.cancel.assert_called_once() + mock_agent_executor.cancel.assert_awaited_once() + assert isinstance(exc_info.value.error, TaskNotCancelableError) + + @pytest.mark.asyncio async def test_on_cancel_task_invalid_result_type(): """Test on_cancel_task when result_aggregator returns a Message instead of a Task.""" diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 19cf8be06..1d1b3c5d1 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -149,6 +149,7 @@ async def test_on_cancel_task_success(self) -> None: call_context = ServerCallContext(state={'foo': 'bar'}) async def streaming_coro(): + mock_task.status.state = TaskState.canceled yield mock_task with patch( @@ -160,6 +161,7 @@ async def streaming_coro(): assert mock_agent_executor.cancel.call_count == 1 self.assertIsInstance(response.root, CancelTaskSuccessResponse) assert response.root.result == mock_task # type: ignore + assert response.root.result.status.state == TaskState.canceled mock_agent_executor.cancel.assert_called_once() async def test_on_cancel_task_not_supported(self) -> None: