Skip to content

Commit e07897c

Browse files
authored
Merge pull request #1251 from asamal4/avoid-unittest
[LCORE-1394] chore: enforce avoiding unittest and fix existing test cases
2 parents f95faa7 + 05ab874 commit e07897c

7 files changed

Lines changed: 188 additions & 157 deletions

File tree

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,12 @@ disable = ["R0801"]
221221
[tool.ruff]
222222
extend-exclude = ["tests/profiles/syntax_error.py"]
223223

224-
[tool.ruff.lint.flake8-tidy-imports]
225-
banned-api = { "unittest" = { msg = "use pytest instead of unittest" }, "unittest.mock" = { msg = "use pytest-mock instead of unittest.mock" } }
224+
[tool.ruff.lint]
225+
extend-select = ["TID251"]
226+
227+
[tool.ruff.lint.flake8-tidy-imports.banned-api]
228+
unittest = { msg = "use pytest instead of unittest" }
229+
"unittest.mock" = { msg = "use pytest-mock instead of unittest.mock" }
226230

227231
[tool.mypy]
228232
explicit_package_bases = true

tests/unit/a2a_storage/test_storage_factory.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
# pylint: disable=protected-access
44

55
from pathlib import Path
6-
from typing import Generator
7-
from unittest.mock import PropertyMock
6+
from typing import Any, Generator
87

98
import pytest
109
from pytest_mock import MockerFixture
@@ -17,6 +16,16 @@
1716
from models.config import A2AStateConfiguration, SQLiteDatabaseConfiguration
1817

1918

19+
class _FakeProperty: # pylint: disable=too-few-public-methods
20+
"""Descriptor that returns a fixed value (like PropertyMock)."""
21+
22+
def __init__(self, value: Any) -> None:
23+
self._value = value
24+
25+
def __get__(self, obj: Any, owner: Any = None) -> Any:
26+
return self._value
27+
28+
2029
class TestA2AStorageFactory:
2130
"""Tests for A2AStorageFactory."""
2231

@@ -127,12 +136,11 @@ async def test_invalid_storage_type_raises_error(
127136
"""Test that an invalid storage type raises ValueError."""
128137
config = A2AStateConfiguration()
129138

130-
# Mock the storage_type property to return an invalid value
139+
# Replace property on class so config.storage_type returns "invalid"
131140
mocker.patch.object(
132141
A2AStateConfiguration,
133142
"storage_type",
134-
new_callable=PropertyMock,
135-
return_value="invalid",
143+
_FakeProperty("invalid"),
136144
)
137145
with pytest.raises(ValueError, match="Unknown A2A state type"):
138146
await A2AStorageFactory.create_task_store(config)
@@ -144,12 +152,11 @@ async def test_sqlite_storage_type_without_config_raises_error(
144152
"""Test that SQLite storage type without config raises ValueError."""
145153
config = A2AStateConfiguration()
146154

147-
# Mock to simulate misconfiguration
155+
# Replace property on class so config.storage_type returns "sqlite"
148156
mocker.patch.object(
149157
A2AStateConfiguration,
150158
"storage_type",
151-
new_callable=PropertyMock,
152-
return_value="sqlite",
159+
_FakeProperty("sqlite"),
153160
)
154161
with pytest.raises(ValueError, match="SQLite configuration required"):
155162
await A2AStorageFactory.create_task_store(config)
@@ -161,12 +168,11 @@ async def test_postgres_storage_type_without_config_raises_error(
161168
"""Test that PostgreSQL storage type without config raises ValueError."""
162169
config = A2AStateConfiguration()
163170

164-
# Mock to simulate misconfiguration
171+
# Replace property on class so config.storage_type returns "postgres"
165172
mocker.patch.object(
166173
A2AStateConfiguration,
167174
"storage_type",
168-
new_callable=PropertyMock,
169-
return_value="postgres",
175+
_FakeProperty("postgres"),
170176
)
171177
with pytest.raises(ValueError, match="PostgreSQL configuration required"):
172178
await A2AStorageFactory.create_task_store(config)

tests/unit/app/endpoints/test_a2a.py

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# pylint: disable=protected-access
55

66
from typing import Any
7-
from unittest.mock import AsyncMock, MagicMock
87

98
import httpx
109
import pytest
@@ -161,7 +160,7 @@ def test_convert_single_output_item(self, mocker: MockerFixture) -> None:
161160
"app.endpoints.a2a.extract_text_from_response_item",
162161
return_value="Hello, world!",
163162
)
164-
mock_output_item = MagicMock()
163+
mock_output_item = mocker.MagicMock()
165164
result = _convert_responses_content_to_a2a_parts([mock_output_item])
166165
assert len(result) == 1
167166
assert result[0].root is not None
@@ -175,8 +174,8 @@ def test_convert_multiple_output_items(self, mocker: MockerFixture) -> None:
175174
)
176175
extract_mock.side_effect = ["First", "Second"]
177176

178-
mock_item1 = MagicMock()
179-
mock_item2 = MagicMock()
177+
mock_item1 = mocker.MagicMock()
178+
mock_item2 = mocker.MagicMock()
180179

181180
result = _convert_responses_content_to_a2a_parts([mock_item1, mock_item2])
182181
assert len(result) == 2
@@ -194,7 +193,7 @@ def test_convert_output_items_with_none_text(self, mocker: MockerFixture) -> Non
194193
)
195194
extract_mock.side_effect = ["Valid text", None, "Another valid"]
196195

197-
mock_items = [MagicMock(), MagicMock(), MagicMock()]
196+
mock_items = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()]
198197

199198
result = _convert_responses_content_to_a2a_parts(mock_items)
200199
assert len(result) == 2
@@ -470,14 +469,16 @@ def test_executor_initialization_default_mcp_headers(self) -> None:
470469
assert executor.mcp_headers == {}
471470

472471
@pytest.mark.asyncio
473-
async def test_execute_without_message_raises_error(self) -> None:
472+
async def test_execute_without_message_raises_error(
473+
self, mocker: MockerFixture
474+
) -> None:
474475
"""Test that execute raises error when message is missing."""
475476
executor = A2AAgentExecutor(auth_token="test-token")
476477

477-
context = MagicMock(spec=RequestContext)
478+
context = mocker.MagicMock(spec=RequestContext)
478479
context.message = None
479480

480-
event_queue = AsyncMock(spec=EventQueue)
481+
event_queue = mocker.AsyncMock(spec=EventQueue)
481482

482483
with pytest.raises(ValueError, match="A2A request must have a message"):
483484
await executor.execute(context, event_queue)
@@ -492,23 +493,23 @@ async def test_execute_creates_new_task(
492493
executor = A2AAgentExecutor(auth_token="test-token")
493494

494495
# Mock the context with a mock message
495-
mock_message = MagicMock()
496+
mock_message = mocker.MagicMock()
496497
mock_message.role = "user"
497498
mock_message.parts = [Part(root=TextPart(text="Hello"))]
498499
mock_message.metadata = {}
499500

500-
context = MagicMock(spec=RequestContext)
501+
context = mocker.MagicMock(spec=RequestContext)
501502
context.message = mock_message
502503
context.current_task = None
503504
context.task_id = None
504505
context.context_id = None
505506
context.get_user_input.return_value = "Hello"
506507

507508
# Mock event queue
508-
event_queue = AsyncMock(spec=EventQueue)
509+
event_queue = mocker.AsyncMock(spec=EventQueue)
509510

510511
# Mock new_task to return a mock Task
511-
mock_task = MagicMock()
512+
mock_task = mocker.MagicMock()
512513
mock_task.id = "test-task-id"
513514
mock_task.context_id = "test-context-id"
514515
mocker.patch("app.endpoints.a2a.new_task", return_value=mock_task)
@@ -517,7 +518,7 @@ async def test_execute_creates_new_task(
517518
mocker.patch.object(
518519
executor,
519520
"_process_task_streaming",
520-
new_callable=AsyncMock,
521+
new_callable=mocker.AsyncMock,
521522
)
522523

523524
await executor.execute(context, event_queue)
@@ -540,23 +541,23 @@ async def test_execute_passes_task_ids_to_streaming(
540541
executor = A2AAgentExecutor(auth_token="test-token")
541542

542543
# Mock the context with empty task_id and context_id (first-turn scenario)
543-
mock_message = MagicMock()
544+
mock_message = mocker.MagicMock()
544545
mock_message.role = "user"
545546
mock_message.parts = [Part(root=TextPart(text="Hello"))]
546547
mock_message.metadata = {}
547548

548-
context = MagicMock(spec=RequestContext)
549+
context = mocker.MagicMock(spec=RequestContext)
549550
context.message = mock_message
550551
context.current_task = None
551552
context.task_id = None # Empty in context object
552553
context.context_id = None # Empty in context object
553554
context.get_user_input.return_value = "Hello"
554555

555556
# Mock event queue
556-
event_queue = AsyncMock(spec=EventQueue)
557+
event_queue = mocker.AsyncMock(spec=EventQueue)
557558

558559
# Mock new_task to return a task with specific IDs
559-
mock_task = MagicMock()
560+
mock_task = mocker.MagicMock()
560561
mock_task.id = "computed-task-id-123"
561562
mock_task.context_id = "computed-context-id-456"
562563
mocker.patch("app.endpoints.a2a.new_task", return_value=mock_task)
@@ -565,7 +566,7 @@ async def test_execute_passes_task_ids_to_streaming(
565566
mock_process_streaming = mocker.patch.object(
566567
executor,
567568
"_process_task_streaming",
568-
new_callable=AsyncMock,
569+
new_callable=mocker.AsyncMock,
569570
)
570571

571572
await executor.execute(context, event_queue)
@@ -591,20 +592,20 @@ async def test_execute_handles_errors_gracefully(
591592
executor = A2AAgentExecutor(auth_token="test-token")
592593

593594
# Mock the context with a mock message
594-
mock_message = MagicMock()
595+
mock_message = mocker.MagicMock()
595596
mock_message.role = "user"
596597
mock_message.parts = [Part(root=TextPart(text="Hello"))]
597598
mock_message.metadata = {}
598599

599-
context = MagicMock(spec=RequestContext)
600+
context = mocker.MagicMock(spec=RequestContext)
600601
context.message = mock_message
601-
context.current_task = MagicMock()
602+
context.current_task = mocker.MagicMock()
602603
context.task_id = "task-123"
603604
context.context_id = "ctx-456"
604605
context.get_user_input.return_value = "Hello"
605606

606607
# Mock event queue
607-
event_queue = AsyncMock(spec=EventQueue)
608+
event_queue = mocker.AsyncMock(spec=EventQueue)
608609

609610
# Mock the streaming process to raise an error
610611
mocker.patch.object(
@@ -637,23 +638,23 @@ async def test_process_task_streaming_no_input(
637638
executor = A2AAgentExecutor(auth_token="test-token")
638639

639640
# Mock the context with no input
640-
mock_message = MagicMock()
641+
mock_message = mocker.MagicMock()
641642
mock_message.role = "user"
642643
mock_message.parts = []
643644
mock_message.metadata = {}
644645

645-
context = MagicMock(spec=RequestContext)
646+
context = mocker.MagicMock(spec=RequestContext)
646647
context.task_id = "task-123"
647648
context.context_id = "ctx-456"
648649
context.message = mock_message
649650
context.get_user_input.return_value = ""
650651

651652
# Mock event queue
652-
event_queue = AsyncMock(spec=EventQueue)
653+
event_queue = mocker.AsyncMock(spec=EventQueue)
653654

654655
# Create task updater mock
655-
task_updater = MagicMock()
656-
task_updater.update_status = AsyncMock()
656+
task_updater = mocker.MagicMock()
657+
task_updater.update_status = mocker.AsyncMock()
657658
task_updater.event_queue = event_queue
658659

659660
await executor._process_task_streaming(
@@ -675,34 +676,34 @@ async def test_process_task_streaming_handles_api_connection_error_on_models_lis
675676
executor = A2AAgentExecutor(auth_token="test-token")
676677

677678
# Mock the context with valid input
678-
mock_message = MagicMock()
679+
mock_message = mocker.MagicMock()
679680
mock_message.role = "user"
680681
mock_message.parts = [Part(root=TextPart(text="Hello"))]
681682
mock_message.metadata = {}
682683

683-
context = MagicMock(spec=RequestContext)
684+
context = mocker.MagicMock(spec=RequestContext)
684685
context.task_id = "task-123"
685686
context.context_id = "ctx-456"
686687
context.message = mock_message
687688
context.get_user_input.return_value = "Hello"
688689

689690
# Mock event queue
690-
event_queue = AsyncMock(spec=EventQueue)
691+
event_queue = mocker.AsyncMock(spec=EventQueue)
691692

692693
# Create task updater mock
693-
task_updater = MagicMock()
694-
task_updater.update_status = AsyncMock()
694+
task_updater = mocker.MagicMock()
695+
task_updater.update_status = mocker.AsyncMock()
695696
task_updater.event_queue = event_queue
696697

697698
# Mock the context store
698-
mock_context_store = AsyncMock()
699+
mock_context_store = mocker.AsyncMock()
699700
mock_context_store.get.return_value = None
700701
mocker.patch(
701702
"app.endpoints.a2a._get_context_store", return_value=mock_context_store
702703
)
703704

704705
# Mock the client to raise APIConnectionError on models.list()
705-
mock_client = AsyncMock()
706+
mock_client = mocker.AsyncMock()
706707
# Create a mock httpx.Request for APIConnectionError
707708
mock_request = httpx.Request("GET", "http://test-llama-stack/models")
708709
mock_client.models.list.side_effect = APIConnectionError(
@@ -733,35 +734,35 @@ async def test_process_task_streaming_handles_api_connection_error_on_retrieve_r
733734
executor = A2AAgentExecutor(auth_token="test-token")
734735

735736
# Mock the context with valid input
736-
mock_message = MagicMock()
737+
mock_message = mocker.MagicMock()
737738
mock_message.role = "user"
738739
mock_message.parts = [Part(root=TextPart(text="Hello"))]
739740
mock_message.metadata = {}
740741

741-
context = MagicMock(spec=RequestContext)
742+
context = mocker.MagicMock(spec=RequestContext)
742743
context.task_id = "task-123"
743744
context.context_id = "ctx-456"
744745
context.message = mock_message
745746
context.get_user_input.return_value = "Hello"
746747

747748
# Mock event queue
748-
event_queue = AsyncMock(spec=EventQueue)
749+
event_queue = mocker.AsyncMock(spec=EventQueue)
749750

750751
# Create task updater mock
751-
task_updater = MagicMock()
752-
task_updater.update_status = AsyncMock()
752+
task_updater = mocker.MagicMock()
753+
task_updater.update_status = mocker.AsyncMock()
753754
task_updater.event_queue = event_queue
754755

755756
# Mock the context store
756-
mock_context_store = AsyncMock()
757+
mock_context_store = mocker.AsyncMock()
757758
mock_context_store.get.return_value = None
758759
mocker.patch(
759760
"app.endpoints.a2a._get_context_store", return_value=mock_context_store
760761
)
761762

762763
# Mock the client to succeed on models.list()
763-
mock_client = AsyncMock()
764-
mock_models = [MagicMock()] # Return a list of models
764+
mock_client = mocker.AsyncMock()
765+
mock_models = [mocker.MagicMock()] # Return a list of models
765766
mock_client.models.list = mocker.AsyncMock(return_value=mock_models)
766767

767768
# Mock responses.create to raise APIConnectionError
@@ -801,12 +802,12 @@ async def test_process_task_streaming_handles_api_connection_error_on_retrieve_r
801802
assert "Unable to connect to Llama Stack backend service" in str(error_message)
802803

803804
@pytest.mark.asyncio
804-
async def test_cancel_raises_not_implemented(self) -> None:
805+
async def test_cancel_raises_not_implemented(self, mocker: MockerFixture) -> None:
805806
"""Test that cancel raises NotImplementedError."""
806807
executor = A2AAgentExecutor(auth_token="test-token")
807808

808-
context = MagicMock(spec=RequestContext)
809-
event_queue = AsyncMock(spec=EventQueue)
809+
context = mocker.MagicMock(spec=RequestContext)
810+
event_queue = mocker.AsyncMock(spec=EventQueue)
810811

811812
with pytest.raises(NotImplementedError):
812813
await executor.cancel(context, event_queue)

0 commit comments

Comments
 (0)