44# pylint: disable=protected-access
55
66from typing import Any
7- from unittest .mock import AsyncMock , MagicMock
87
98import httpx
109import pytest
@@ -161,7 +160,7 @@ def test_convert_single_output_item(self, mocker: MockerFixture) -> None:
161160 "app.endpoints.a2a.extract_text_from_response_item" ,
162161 return_value = "Hello, world!" ,
163162 )
164- mock_output_item = MagicMock ()
163+ mock_output_item = mocker . MagicMock ()
165164 result = _convert_responses_content_to_a2a_parts ([mock_output_item ])
166165 assert len (result ) == 1
167166 assert result [0 ].root is not None
@@ -175,8 +174,8 @@ def test_convert_multiple_output_items(self, mocker: MockerFixture) -> None:
175174 )
176175 extract_mock .side_effect = ["First" , "Second" ]
177176
178- mock_item1 = MagicMock ()
179- mock_item2 = MagicMock ()
177+ mock_item1 = mocker . MagicMock ()
178+ mock_item2 = mocker . MagicMock ()
180179
181180 result = _convert_responses_content_to_a2a_parts ([mock_item1 , mock_item2 ])
182181 assert len (result ) == 2
@@ -194,7 +193,7 @@ def test_convert_output_items_with_none_text(self, mocker: MockerFixture) -> Non
194193 )
195194 extract_mock .side_effect = ["Valid text" , None , "Another valid" ]
196195
197- mock_items = [MagicMock (), MagicMock (), MagicMock ()]
196+ mock_items = [mocker . MagicMock (), mocker . MagicMock (), mocker . MagicMock ()]
198197
199198 result = _convert_responses_content_to_a2a_parts (mock_items )
200199 assert len (result ) == 2
@@ -470,14 +469,16 @@ def test_executor_initialization_default_mcp_headers(self) -> None:
470469 assert executor .mcp_headers == {}
471470
472471 @pytest .mark .asyncio
473- async def test_execute_without_message_raises_error (self ) -> None :
472+ async def test_execute_without_message_raises_error (
473+ self , mocker : MockerFixture
474+ ) -> None :
474475 """Test that execute raises error when message is missing."""
475476 executor = A2AAgentExecutor (auth_token = "test-token" )
476477
477- context = MagicMock (spec = RequestContext )
478+ context = mocker . MagicMock (spec = RequestContext )
478479 context .message = None
479480
480- event_queue = AsyncMock (spec = EventQueue )
481+ event_queue = mocker . AsyncMock (spec = EventQueue )
481482
482483 with pytest .raises (ValueError , match = "A2A request must have a message" ):
483484 await executor .execute (context , event_queue )
@@ -492,23 +493,23 @@ async def test_execute_creates_new_task(
492493 executor = A2AAgentExecutor (auth_token = "test-token" )
493494
494495 # Mock the context with a mock message
495- mock_message = MagicMock ()
496+ mock_message = mocker . MagicMock ()
496497 mock_message .role = "user"
497498 mock_message .parts = [Part (root = TextPart (text = "Hello" ))]
498499 mock_message .metadata = {}
499500
500- context = MagicMock (spec = RequestContext )
501+ context = mocker . MagicMock (spec = RequestContext )
501502 context .message = mock_message
502503 context .current_task = None
503504 context .task_id = None
504505 context .context_id = None
505506 context .get_user_input .return_value = "Hello"
506507
507508 # Mock event queue
508- event_queue = AsyncMock (spec = EventQueue )
509+ event_queue = mocker . AsyncMock (spec = EventQueue )
509510
510511 # Mock new_task to return a mock Task
511- mock_task = MagicMock ()
512+ mock_task = mocker . MagicMock ()
512513 mock_task .id = "test-task-id"
513514 mock_task .context_id = "test-context-id"
514515 mocker .patch ("app.endpoints.a2a.new_task" , return_value = mock_task )
@@ -517,7 +518,7 @@ async def test_execute_creates_new_task(
517518 mocker .patch .object (
518519 executor ,
519520 "_process_task_streaming" ,
520- new_callable = AsyncMock ,
521+ new_callable = mocker . AsyncMock ,
521522 )
522523
523524 await executor .execute (context , event_queue )
@@ -540,23 +541,23 @@ async def test_execute_passes_task_ids_to_streaming(
540541 executor = A2AAgentExecutor (auth_token = "test-token" )
541542
542543 # Mock the context with empty task_id and context_id (first-turn scenario)
543- mock_message = MagicMock ()
544+ mock_message = mocker . MagicMock ()
544545 mock_message .role = "user"
545546 mock_message .parts = [Part (root = TextPart (text = "Hello" ))]
546547 mock_message .metadata = {}
547548
548- context = MagicMock (spec = RequestContext )
549+ context = mocker . MagicMock (spec = RequestContext )
549550 context .message = mock_message
550551 context .current_task = None
551552 context .task_id = None # Empty in context object
552553 context .context_id = None # Empty in context object
553554 context .get_user_input .return_value = "Hello"
554555
555556 # Mock event queue
556- event_queue = AsyncMock (spec = EventQueue )
557+ event_queue = mocker . AsyncMock (spec = EventQueue )
557558
558559 # Mock new_task to return a task with specific IDs
559- mock_task = MagicMock ()
560+ mock_task = mocker . MagicMock ()
560561 mock_task .id = "computed-task-id-123"
561562 mock_task .context_id = "computed-context-id-456"
562563 mocker .patch ("app.endpoints.a2a.new_task" , return_value = mock_task )
@@ -565,7 +566,7 @@ async def test_execute_passes_task_ids_to_streaming(
565566 mock_process_streaming = mocker .patch .object (
566567 executor ,
567568 "_process_task_streaming" ,
568- new_callable = AsyncMock ,
569+ new_callable = mocker . AsyncMock ,
569570 )
570571
571572 await executor .execute (context , event_queue )
@@ -591,20 +592,20 @@ async def test_execute_handles_errors_gracefully(
591592 executor = A2AAgentExecutor (auth_token = "test-token" )
592593
593594 # Mock the context with a mock message
594- mock_message = MagicMock ()
595+ mock_message = mocker . MagicMock ()
595596 mock_message .role = "user"
596597 mock_message .parts = [Part (root = TextPart (text = "Hello" ))]
597598 mock_message .metadata = {}
598599
599- context = MagicMock (spec = RequestContext )
600+ context = mocker . MagicMock (spec = RequestContext )
600601 context .message = mock_message
601- context .current_task = MagicMock ()
602+ context .current_task = mocker . MagicMock ()
602603 context .task_id = "task-123"
603604 context .context_id = "ctx-456"
604605 context .get_user_input .return_value = "Hello"
605606
606607 # Mock event queue
607- event_queue = AsyncMock (spec = EventQueue )
608+ event_queue = mocker . AsyncMock (spec = EventQueue )
608609
609610 # Mock the streaming process to raise an error
610611 mocker .patch .object (
@@ -637,23 +638,23 @@ async def test_process_task_streaming_no_input(
637638 executor = A2AAgentExecutor (auth_token = "test-token" )
638639
639640 # Mock the context with no input
640- mock_message = MagicMock ()
641+ mock_message = mocker . MagicMock ()
641642 mock_message .role = "user"
642643 mock_message .parts = []
643644 mock_message .metadata = {}
644645
645- context = MagicMock (spec = RequestContext )
646+ context = mocker . MagicMock (spec = RequestContext )
646647 context .task_id = "task-123"
647648 context .context_id = "ctx-456"
648649 context .message = mock_message
649650 context .get_user_input .return_value = ""
650651
651652 # Mock event queue
652- event_queue = AsyncMock (spec = EventQueue )
653+ event_queue = mocker . AsyncMock (spec = EventQueue )
653654
654655 # Create task updater mock
655- task_updater = MagicMock ()
656- task_updater .update_status = AsyncMock ()
656+ task_updater = mocker . MagicMock ()
657+ task_updater .update_status = mocker . AsyncMock ()
657658 task_updater .event_queue = event_queue
658659
659660 await executor ._process_task_streaming (
@@ -675,34 +676,34 @@ async def test_process_task_streaming_handles_api_connection_error_on_models_lis
675676 executor = A2AAgentExecutor (auth_token = "test-token" )
676677
677678 # Mock the context with valid input
678- mock_message = MagicMock ()
679+ mock_message = mocker . MagicMock ()
679680 mock_message .role = "user"
680681 mock_message .parts = [Part (root = TextPart (text = "Hello" ))]
681682 mock_message .metadata = {}
682683
683- context = MagicMock (spec = RequestContext )
684+ context = mocker . MagicMock (spec = RequestContext )
684685 context .task_id = "task-123"
685686 context .context_id = "ctx-456"
686687 context .message = mock_message
687688 context .get_user_input .return_value = "Hello"
688689
689690 # Mock event queue
690- event_queue = AsyncMock (spec = EventQueue )
691+ event_queue = mocker . AsyncMock (spec = EventQueue )
691692
692693 # Create task updater mock
693- task_updater = MagicMock ()
694- task_updater .update_status = AsyncMock ()
694+ task_updater = mocker . MagicMock ()
695+ task_updater .update_status = mocker . AsyncMock ()
695696 task_updater .event_queue = event_queue
696697
697698 # Mock the context store
698- mock_context_store = AsyncMock ()
699+ mock_context_store = mocker . AsyncMock ()
699700 mock_context_store .get .return_value = None
700701 mocker .patch (
701702 "app.endpoints.a2a._get_context_store" , return_value = mock_context_store
702703 )
703704
704705 # Mock the client to raise APIConnectionError on models.list()
705- mock_client = AsyncMock ()
706+ mock_client = mocker . AsyncMock ()
706707 # Create a mock httpx.Request for APIConnectionError
707708 mock_request = httpx .Request ("GET" , "http://test-llama-stack/models" )
708709 mock_client .models .list .side_effect = APIConnectionError (
@@ -733,35 +734,35 @@ async def test_process_task_streaming_handles_api_connection_error_on_retrieve_r
733734 executor = A2AAgentExecutor (auth_token = "test-token" )
734735
735736 # Mock the context with valid input
736- mock_message = MagicMock ()
737+ mock_message = mocker . MagicMock ()
737738 mock_message .role = "user"
738739 mock_message .parts = [Part (root = TextPart (text = "Hello" ))]
739740 mock_message .metadata = {}
740741
741- context = MagicMock (spec = RequestContext )
742+ context = mocker . MagicMock (spec = RequestContext )
742743 context .task_id = "task-123"
743744 context .context_id = "ctx-456"
744745 context .message = mock_message
745746 context .get_user_input .return_value = "Hello"
746747
747748 # Mock event queue
748- event_queue = AsyncMock (spec = EventQueue )
749+ event_queue = mocker . AsyncMock (spec = EventQueue )
749750
750751 # Create task updater mock
751- task_updater = MagicMock ()
752- task_updater .update_status = AsyncMock ()
752+ task_updater = mocker . MagicMock ()
753+ task_updater .update_status = mocker . AsyncMock ()
753754 task_updater .event_queue = event_queue
754755
755756 # Mock the context store
756- mock_context_store = AsyncMock ()
757+ mock_context_store = mocker . AsyncMock ()
757758 mock_context_store .get .return_value = None
758759 mocker .patch (
759760 "app.endpoints.a2a._get_context_store" , return_value = mock_context_store
760761 )
761762
762763 # Mock the client to succeed on models.list()
763- mock_client = AsyncMock ()
764- mock_models = [MagicMock ()] # Return a list of models
764+ mock_client = mocker . AsyncMock ()
765+ mock_models = [mocker . MagicMock ()] # Return a list of models
765766 mock_client .models .list = mocker .AsyncMock (return_value = mock_models )
766767
767768 # Mock responses.create to raise APIConnectionError
@@ -801,12 +802,12 @@ async def test_process_task_streaming_handles_api_connection_error_on_retrieve_r
801802 assert "Unable to connect to Llama Stack backend service" in str (error_message )
802803
803804 @pytest .mark .asyncio
804- async def test_cancel_raises_not_implemented (self ) -> None :
805+ async def test_cancel_raises_not_implemented (self , mocker : MockerFixture ) -> None :
805806 """Test that cancel raises NotImplementedError."""
806807 executor = A2AAgentExecutor (auth_token = "test-token" )
807808
808- context = MagicMock (spec = RequestContext )
809- event_queue = AsyncMock (spec = EventQueue )
809+ context = mocker . MagicMock (spec = RequestContext )
810+ event_queue = mocker . AsyncMock (spec = EventQueue )
810811
811812 with pytest .raises (NotImplementedError ):
812813 await executor .cancel (context , event_queue )
0 commit comments