diff --git a/src/a2a/server/agent_execution/simple_request_context_builder.py b/src/a2a/server/agent_execution/simple_request_context_builder.py index 9a1223afa..7a2f85148 100644 --- a/src/a2a/server/agent_execution/simple_request_context_builder.py +++ b/src/a2a/server/agent_execution/simple_request_context_builder.py @@ -68,7 +68,9 @@ async def build( ): tasks = await asyncio.gather( *[ - self._task_store.get(task_id) + self._task_store.get( + task_id, context or ServerCallContext() + ) for task_id in params.message.reference_task_ids ] ) diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py index 798eb8c9b..4fca42d24 100644 --- a/src/a2a/server/owner_resolver.py +++ b/src/a2a/server/owner_resolver.py @@ -4,13 +4,10 @@ # Definition -OwnerResolver = Callable[[ServerCallContext | None], str] +OwnerResolver = Callable[[ServerCallContext], str] # Example Default Implementation -def resolve_user_scope(context: ServerCallContext | None) -> str: +def resolve_user_scope(context: ServerCallContext) -> str: """Resolves the owner scope based on the user in the context.""" - if not context: - return 'unknown' - # Example: Basic user name. Adapt as needed for your user model. return context.user.user_name diff --git a/src/a2a/server/tasks/copying_task_store.py b/src/a2a/server/tasks/copying_task_store.py index 6bfda5e74..f7f41bf1f 100644 --- a/src/a2a/server/tasks/copying_task_store.py +++ b/src/a2a/server/tasks/copying_task_store.py @@ -24,16 +24,14 @@ class CopyingTaskStoreAdapter(TaskStore): def __init__(self, underlying_store: TaskStore): self._store = underlying_store - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves a copy of the task to the underlying store.""" task_copy = Task() task_copy.CopyFrom(task) await self._store.save(task_copy, context) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the underlying store and returns a copy.""" task = await self._store.get(task_id, context) @@ -46,7 +44,7 @@ async def get( async def list( self, params: ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> ListTasksResponse: """Retrieves a list of tasks from the underlying store and returns a copy.""" response = await self._store.list(params, context) @@ -54,8 +52,6 @@ async def list( response_copy.CopyFrom(response) return response_copy - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the underlying store.""" await self._store.delete(task_id, context) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index ac1cf947b..cfa007e56 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -169,9 +169,7 @@ def _from_orm(self, task_model: TaskModel) -> Task: # Legacy conversion return compat_task_model_to_core(task_model) - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the database for the resolved owner.""" await self._ensure_initialized() owner = self.owner_resolver(context) @@ -185,7 +183,7 @@ async def save( ) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the database by ID, for the given owner.""" await self._ensure_initialized() @@ -216,7 +214,7 @@ async def get( async def list( self, params: a2a_pb2.ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> a2a_pb2.ListTasksResponse: """Retrieves tasks from the database based on provided parameters, for the given owner.""" await self._ensure_initialized() @@ -315,9 +313,7 @@ async def list( page_size=page_size, ) - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the database by ID, for the given owner.""" await self._ensure_initialized() owner = self.owner_resolver(context) diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index f887b77ba..75d2269bc 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -35,9 +35,7 @@ def __init__( def _get_owner_tasks(self, owner: str) -> dict[str, Task]: return self.tasks.get(owner, {}) - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the in-memory store for the resolved owner.""" owner = self.owner_resolver(context) if owner not in self.tasks: @@ -50,7 +48,7 @@ async def save( ) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the in-memory store by ID, for the given owner.""" owner = self.owner_resolver(context) @@ -77,7 +75,7 @@ async def get( async def list( self, params: a2a_pb2.ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> a2a_pb2.ListTasksResponse: """Retrieves a list of tasks from the store, for the given owner.""" owner = self.owner_resolver(context) @@ -156,9 +154,7 @@ async def list( page_size=page_size, ) - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the in-memory store by ID, for the given owner.""" owner = self.owner_resolver(context) async with self.lock: @@ -211,14 +207,12 @@ def __init__( CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl ) - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the store.""" await self._store.save(task, context) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the store by ID.""" return await self._store.get(task_id, context) @@ -226,13 +220,11 @@ async def get( async def list( self, params: a2a_pb2.ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> a2a_pb2.ListTasksResponse: """Retrieves a list of tasks from the store.""" return await self._store.list(params, context) - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the store by ID.""" await self._store.delete(task_id, context) diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index 440100b1f..4daabb42c 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -31,7 +31,7 @@ def __init__( context_id: str | None, task_store: TaskStore, initial_message: Message | None, - context: ServerCallContext | None = None, + context: ServerCallContext, ): """Initializes the TaskManager. @@ -51,7 +51,7 @@ def __init__( self.task_store = task_store self._initial_message = initial_message self._current_task: Task | None = None - self._call_context: ServerCallContext | None = context + self._call_context: ServerCallContext = context logger.debug( 'TaskManager initialized with task_id: %s, context_id: %s', task_id, diff --git a/src/a2a/server/tasks/task_store.py b/src/a2a/server/tasks/task_store.py index a4d3308c0..25e4838d1 100644 --- a/src/a2a/server/tasks/task_store.py +++ b/src/a2a/server/tasks/task_store.py @@ -11,14 +11,12 @@ class TaskStore(ABC): """ @abstractmethod - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the store.""" @abstractmethod async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the store by ID.""" @@ -26,12 +24,10 @@ async def get( async def list( self, params: ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> ListTasksResponse: """Retrieves a list of tasks from the store.""" @abstractmethod - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the store by ID.""" diff --git a/tests/server/agent_execution/test_simple_request_context_builder.py b/tests/server/agent_execution/test_simple_request_context_builder.py index caab48342..ef374e364 100644 --- a/tests/server/agent_execution/test_simple_request_context_builder.py +++ b/tests/server/agent_execution/test_simple_request_context_builder.py @@ -127,10 +127,12 @@ async def test_build_populate_true_with_reference_task_ids(self) -> None: mock_ref_task1 = create_sample_task(task_id=ref_task_id1) mock_ref_task3 = create_sample_task(task_id=ref_task_id3) + server_call_context = ServerCallContext(user=UnauthenticatedUser()) + # Configure task_store.get mock # Note: AsyncMock side_effect needs to handle multiple calls if they have different args. # A simple way is a list of return values, or a function. - async def get_side_effect(task_id): + async def get_side_effect(task_id, server_call_context): if task_id == ref_task_id1: return mock_ref_task1 if task_id == ref_task_id3: @@ -144,7 +146,6 @@ async def get_side_effect(task_id): reference_task_ids=[ref_task_id1, ref_task_id2, ref_task_id3] ) ) - server_call_context = ServerCallContext(user=UnauthenticatedUser()) request_context = await builder.build( params=params, @@ -155,9 +156,15 @@ async def get_side_effect(task_id): ) self.assertEqual(self.mock_task_store.get.call_count, 3) - self.mock_task_store.get.assert_any_call(ref_task_id1) - self.mock_task_store.get.assert_any_call(ref_task_id2) - self.mock_task_store.get.assert_any_call(ref_task_id3) + self.mock_task_store.get.assert_any_call( + ref_task_id1, server_call_context + ) + self.mock_task_store.get.assert_any_call( + ref_task_id2, server_call_context + ) + self.mock_task_store.get.assert_any_call( + ref_task_id3, server_call_context + ) self.assertIsNotNone(request_context.related_tasks) self.assertEqual( diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index 445a45a37..8c9b7d07d 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -56,6 +56,9 @@ def user_name(self) -> str: return self._user_name +TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user')) + + # DSNs for different databases SQLITE_TEST_DSN = ( 'sqlite+aiosqlite:///file:testdb?mode=memory&cache=shared&uri=true' @@ -170,13 +173,17 @@ async def test_save_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save.id = ( f'save-task-{db_store_parameterized.engine.url.drivername}' ) - await db_store_parameterized.save(task_to_save) + await db_store_parameterized.save(task_to_save, TEST_CONTEXT) - retrieved_task = await db_store_parameterized.get(task_to_save.id) + retrieved_task = await db_store_parameterized.get( + task_to_save.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == task_to_save.id assert MessageToDict(retrieved_task) == MessageToDict(task_to_save) - await db_store_parameterized.delete(task_to_save.id) # Cleanup + await db_store_parameterized.delete( + task_to_save.id, TEST_CONTEXT + ) # Cleanup @pytest.mark.asyncio @@ -186,14 +193,18 @@ async def test_get_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save = Task() task_to_save.CopyFrom(MINIMAL_TASK_OBJ) task_to_save.id = task_id - await db_store_parameterized.save(task_to_save) + await db_store_parameterized.save(task_to_save, TEST_CONTEXT) - retrieved_task = await db_store_parameterized.get(task_to_save.id) + retrieved_task = await db_store_parameterized.get( + task_to_save.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == task_to_save.id assert retrieved_task.context_id == task_to_save.context_id assert retrieved_task.status.state == TaskState.TASK_STATE_SUBMITTED - await db_store_parameterized.delete(task_to_save.id) # Cleanup + await db_store_parameterized.delete( + task_to_save.id, TEST_CONTEXT + ) # Cleanup @pytest.mark.asyncio @@ -321,9 +332,9 @@ async def test_list_tasks( ), ] for task in tasks_to_create: - await db_store_parameterized.save(task) + await db_store_parameterized.save(task, TEST_CONTEXT) - page = await db_store_parameterized.list(params) + page = await db_store_parameterized.list(params, TEST_CONTEXT) retrieved_ids = [task.id for task in page.tasks] assert retrieved_ids == expected_ids @@ -333,7 +344,7 @@ async def test_list_tasks( # Cleanup for task in tasks_to_create: - await db_store_parameterized.delete(task.id) + await db_store_parameterized.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -381,16 +392,16 @@ async def test_list_tasks_fails( ), ] for task in tasks_to_create: - await db_store_parameterized.save(task) + await db_store_parameterized.save(task, TEST_CONTEXT) with pytest.raises(InvalidParamsError) as excinfo: - await db_store_parameterized.list(params) + await db_store_parameterized.list(params, TEST_CONTEXT) assert expected_error_message in str(excinfo.value) # Cleanup for task in tasks_to_create: - await db_store_parameterized.delete(task.id) + await db_store_parameterized.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -398,7 +409,9 @@ async def test_get_nonexistent_task( db_store_parameterized: DatabaseTaskStore, ) -> None: """Test retrieving a nonexistent task.""" - retrieved_task = await db_store_parameterized.get('nonexistent-task-id') + retrieved_task = await db_store_parameterized.get( + 'nonexistent-task-id', TEST_CONTEXT + ) assert retrieved_task is None @@ -409,13 +422,23 @@ async def test_delete_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save_and_delete = Task() task_to_save_and_delete.CopyFrom(MINIMAL_TASK_OBJ) task_to_save_and_delete.id = task_id - await db_store_parameterized.save(task_to_save_and_delete) + await db_store_parameterized.save(task_to_save_and_delete, TEST_CONTEXT) assert ( - await db_store_parameterized.get(task_to_save_and_delete.id) is not None + await db_store_parameterized.get( + task_to_save_and_delete.id, TEST_CONTEXT + ) + is not None + ) + await db_store_parameterized.delete( + task_to_save_and_delete.id, TEST_CONTEXT + ) + assert ( + await db_store_parameterized.get( + task_to_save_and_delete.id, TEST_CONTEXT + ) + is None ) - await db_store_parameterized.delete(task_to_save_and_delete.id) - assert await db_store_parameterized.get(task_to_save_and_delete.id) is None @pytest.mark.asyncio @@ -423,7 +446,9 @@ async def test_delete_nonexistent_task( db_store_parameterized: DatabaseTaskStore, ) -> None: """Test deleting a nonexistent task. Should not error.""" - await db_store_parameterized.delete('nonexistent-delete-task-id') + await db_store_parameterized.delete( + 'nonexistent-delete-task-id', TEST_CONTEXT + ) @pytest.mark.asyncio @@ -455,8 +480,10 @@ async def test_save_and_get_detailed_task( ], ) - await db_store_parameterized.save(test_task) - retrieved_task = await db_store_parameterized.get(test_task.id) + await db_store_parameterized.save(test_task, TEST_CONTEXT) + retrieved_task = await db_store_parameterized.get( + test_task.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == test_task.id @@ -479,8 +506,8 @@ async def test_save_and_get_detailed_task( == MessageToDict(test_task)['history'] ) - await db_store_parameterized.delete(test_task.id) - assert await db_store_parameterized.get(test_task.id) is None + await db_store_parameterized.delete(test_task.id, TEST_CONTEXT) + assert await db_store_parameterized.get(test_task.id, TEST_CONTEXT) is None @pytest.mark.asyncio @@ -498,9 +525,11 @@ async def test_update_task(db_store_parameterized: DatabaseTaskStore) -> None: artifacts=[], history=[], ) - await db_store_parameterized.save(original_task) + await db_store_parameterized.save(original_task, TEST_CONTEXT) - retrieved_before_update = await db_store_parameterized.get(task_id) + retrieved_before_update = await db_store_parameterized.get( + task_id, TEST_CONTEXT + ) assert retrieved_before_update is not None assert ( retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED @@ -516,16 +545,18 @@ async def test_update_task(db_store_parameterized: DatabaseTaskStore) -> None: updated_task.status.timestamp.FromDatetime(updated_timestamp) updated_task.metadata['update_key'] = 'update_value' - await db_store_parameterized.save(updated_task) + await db_store_parameterized.save(updated_task, TEST_CONTEXT) - retrieved_after_update = await db_store_parameterized.get(task_id) + retrieved_after_update = await db_store_parameterized.get( + task_id, TEST_CONTEXT + ) assert retrieved_after_update is not None assert retrieved_after_update.status.state == TaskState.TASK_STATE_COMPLETED assert dict(retrieved_after_update.metadata) == { 'update_key': 'update_value' } - await db_store_parameterized.delete(task_id) + await db_store_parameterized.delete(task_id, TEST_CONTEXT) @pytest.mark.asyncio @@ -547,9 +578,9 @@ async def test_metadata_field_mapping( context_id='session-meta-1', status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) - await db_store_parameterized.save(task_no_metadata) + await db_store_parameterized.save(task_no_metadata, TEST_CONTEXT) retrieved_no_metadata = await db_store_parameterized.get( - 'task-metadata-test-1' + 'task-metadata-test-1', TEST_CONTEXT ) assert retrieved_no_metadata is not None # Proto Struct is empty, not None @@ -563,8 +594,10 @@ async def test_metadata_field_mapping( status=TaskStatus(state=TaskState.TASK_STATE_WORKING), metadata=simple_metadata, ) - await db_store_parameterized.save(task_simple_metadata) - retrieved_simple = await db_store_parameterized.get('task-metadata-test-2') + await db_store_parameterized.save(task_simple_metadata, TEST_CONTEXT) + retrieved_simple = await db_store_parameterized.get( + 'task-metadata-test-2', TEST_CONTEXT + ) assert retrieved_simple is not None assert dict(retrieved_simple.metadata) == simple_metadata @@ -586,8 +619,10 @@ async def test_metadata_field_mapping( status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), metadata=complex_metadata, ) - await db_store_parameterized.save(task_complex_metadata) - retrieved_complex = await db_store_parameterized.get('task-metadata-test-3') + await db_store_parameterized.save(task_complex_metadata, TEST_CONTEXT) + retrieved_complex = await db_store_parameterized.get( + 'task-metadata-test-3', TEST_CONTEXT + ) assert retrieved_complex is not None # Convert proto Struct to dict for comparison retrieved_meta = MessageToDict(retrieved_complex.metadata) @@ -599,14 +634,16 @@ async def test_metadata_field_mapping( context_id='session-meta-4', status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) # Update metadata task_update_metadata.metadata['updated'] = True task_update_metadata.metadata['timestamp'] = '2024-01-01' - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) - retrieved_updated = await db_store_parameterized.get('task-metadata-test-4') + retrieved_updated = await db_store_parameterized.get( + 'task-metadata-test-4', TEST_CONTEXT + ) assert retrieved_updated is not None assert dict(retrieved_updated.metadata) == { 'updated': True, @@ -615,17 +652,19 @@ async def test_metadata_field_mapping( # Test 5: Clear metadata (set to empty) task_update_metadata.metadata.Clear() - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) - retrieved_none = await db_store_parameterized.get('task-metadata-test-4') + retrieved_none = await db_store_parameterized.get( + 'task-metadata-test-4', TEST_CONTEXT + ) assert retrieved_none is not None assert len(retrieved_none.metadata) == 0 # Cleanup - await db_store_parameterized.delete('task-metadata-test-1') - await db_store_parameterized.delete('task-metadata-test-2') - await db_store_parameterized.delete('task-metadata-test-3') - await db_store_parameterized.delete('task-metadata-test-4') + await db_store_parameterized.delete('task-metadata-test-1', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-2', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-3', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-4', TEST_CONTEXT) @pytest.mark.asyncio @@ -874,7 +913,7 @@ async def test_core_to_0_3_model_conversion( ) # 1. Save the task (will use core_to_compat_task_model) - await store.save(original_task) + await store.save(original_task, TEST_CONTEXT) # 2. Verify it's stored in v0.3 format directly in DB async with store.async_session_maker() as session: @@ -882,17 +921,18 @@ async def test_core_to_0_3_model_conversion( assert db_task is not None assert db_task.protocol_version == '0.3' # v0.3 status JSON uses string for state + assert isinstance(db_task.status, dict) assert db_task.status['state'] == 'working' # 3. Retrieve the task (will use compat_task_model_to_core) - retrieved_task = await store.get(task_id) + retrieved_task = await store.get(task_id, context=TEST_CONTEXT) assert retrieved_task is not None assert retrieved_task.id == original_task.id assert retrieved_task.status.state == TaskState.TASK_STATE_WORKING assert dict(retrieved_task.metadata) == {'key': 'value'} # Reset conversion attributes store.core_to_model_conversion = None - await store.delete('v03-persistence-task') + await store.delete('v03-persistence-task', TEST_CONTEXT) # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml). diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index af3531e33..f04a69170 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -25,6 +25,9 @@ def user_name(self) -> str: return self._user_name +TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user')) + + def create_minimal_task( task_id: str = 'task-abc', context_id: str = 'session-xyz' ) -> Task: @@ -41,8 +44,8 @@ async def test_in_memory_task_store_save_and_get() -> None: """Test saving and retrieving a task from the in-memory store.""" store = InMemoryTaskStore() task = create_minimal_task() - await store.save(task) - retrieved_task = await store.get('task-abc') + await store.save(task, TEST_CONTEXT) + retrieved_task = await store.get('task-abc', TEST_CONTEXT) assert retrieved_task == task @@ -50,7 +53,7 @@ async def test_in_memory_task_store_save_and_get() -> None: async def test_in_memory_task_store_get_nonexistent() -> None: """Test retrieving a nonexistent task.""" store = InMemoryTaskStore() - retrieved_task = await store.get('nonexistent') + retrieved_task = await store.get('nonexistent', TEST_CONTEXT) assert retrieved_task is None @@ -179,9 +182,9 @@ async def test_list_tasks( ), ] for task in tasks_to_create: - await store.save(task) + await store.save(task, TEST_CONTEXT) - page = await store.list(params) + page = await store.list(params, TEST_CONTEXT) retrieved_ids = [task.id for task in page.tasks] assert retrieved_ids == expected_ids @@ -191,7 +194,7 @@ async def test_list_tasks( # Cleanup for task in tasks_to_create: - await store.delete(task.id) + await store.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -238,16 +241,16 @@ async def test_list_tasks_fails( ), ] for task in tasks_to_create: - await store.save(task) + await store.save(task, TEST_CONTEXT) with pytest.raises(InvalidParamsError) as excinfo: - await store.list(params) + await store.list(params, TEST_CONTEXT) assert expected_error_message in str(excinfo.value) # Cleanup for task in tasks_to_create: - await store.delete(task.id) + await store.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -255,9 +258,9 @@ async def test_in_memory_task_store_delete() -> None: """Test deleting a task from the store.""" store = InMemoryTaskStore() task = create_minimal_task() - await store.save(task) - await store.delete('task-abc') - retrieved_task = await store.get('task-abc') + await store.save(task, TEST_CONTEXT) + await store.delete('task-abc', TEST_CONTEXT) + retrieved_task = await store.get('task-abc', TEST_CONTEXT) assert retrieved_task is None @@ -265,7 +268,7 @@ async def test_in_memory_task_store_delete() -> None: async def test_in_memory_task_store_delete_nonexistent() -> None: """Test deleting a nonexistent task.""" store = InMemoryTaskStore() - await store.delete('nonexistent') + await store.delete('nonexistent', TEST_CONTEXT) @pytest.mark.asyncio @@ -341,10 +344,10 @@ async def test_inmemory_task_store_copying_behavior(use_copying: bool): original_task = Task( id='test_task', status=TaskStatus(state=TaskState.TASK_STATE_WORKING) ) - await store.save(original_task) + await store.save(original_task, TEST_CONTEXT) # Retrieve it - retrieved_task = await store.get('test_task') + retrieved_task = await store.get('test_task', TEST_CONTEXT) assert retrieved_task is not None if use_copying: @@ -356,7 +359,7 @@ async def test_inmemory_task_store_copying_behavior(use_copying: bool): retrieved_task.status.state = TaskState.TASK_STATE_COMPLETED # Retrieve it again, it should NOT be modified in the store if use_copying=True - retrieved_task_2 = await store.get('test_task') + retrieved_task_2 = await store.get('test_task', TEST_CONTEXT) assert retrieved_task_2 is not None if use_copying: diff --git a/tests/server/tasks/test_task_manager.py b/tests/server/tasks/test_task_manager.py index 381f71593..bdfbf525c 100644 --- a/tests/server/tasks/test_task_manager.py +++ b/tests/server/tasks/test_task_manager.py @@ -3,8 +3,9 @@ import pytest +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.server.tasks import TaskManager -from a2a.utils.errors import InvalidParamsError from a2a.types.a2a_pb2 import ( Artifact, Message, @@ -19,6 +20,24 @@ from a2a.utils.errors import InvalidParamsError +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user')) + + # Create proto task instead of dict def create_minimal_task( task_id: str = 'task-abc', @@ -49,6 +68,7 @@ def task_manager(mock_task_store: AsyncMock) -> TaskManager: context_id=MINIMAL_CONTEXT_ID, task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) @@ -63,6 +83,7 @@ def test_task_manager_invalid_task_id( context_id='test_context', task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) @@ -75,7 +96,7 @@ async def test_get_task_existing( mock_task_store.get.return_value = expected_task retrieved_task = await task_manager.get_task() assert retrieved_task == expected_task - mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT) @pytest.mark.asyncio @@ -86,7 +107,7 @@ async def test_get_task_nonexistent( mock_task_store.get.return_value = None retrieved_task = await task_manager.get_task() assert retrieved_task is None - mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT) @pytest.mark.asyncio @@ -96,7 +117,7 @@ async def test_save_task_event_new_task( """Test saving a new task.""" task = create_minimal_task() await task_manager.save_task_event(task) - mock_task_store.save.assert_called_once_with(task, None) + mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT) @pytest.mark.asyncio @@ -188,7 +209,7 @@ async def test_ensure_task_existing( ) retrieved_task = await task_manager.ensure_task(event) assert retrieved_task == expected_task - mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT) @pytest.mark.asyncio @@ -202,6 +223,7 @@ async def test_ensure_task_nonexistent( context_id=None, task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) event = TaskStatusUpdateEvent( task_id='new-task', @@ -212,7 +234,7 @@ async def test_ensure_task_nonexistent( assert new_task.id == 'new-task' assert new_task.context_id == 'some-context' assert new_task.status.state == TaskState.TASK_STATE_SUBMITTED - mock_task_store.save.assert_called_once_with(new_task, None) + mock_task_store.save.assert_called_once_with(new_task, TEST_CONTEXT) assert task_manager_without_id.task_id == 'new-task' assert task_manager_without_id.context_id == 'some-context' @@ -233,7 +255,7 @@ async def test_save_task( """Test saving a task.""" task = create_minimal_task() await task_manager._save_task(task) # type: ignore - mock_task_store.save.assert_called_once_with(task, None) + mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT) @pytest.mark.asyncio @@ -263,6 +285,7 @@ async def test_save_task_event_new_task_no_task_id( context_id=None, task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) task = Task( id='new-task-id', @@ -270,7 +293,7 @@ async def test_save_task_event_new_task_no_task_id( status=TaskStatus(state=TaskState.TASK_STATE_WORKING), ) await task_manager_without_id.save_task_event(task) - mock_task_store.save.assert_called_once_with(task, None) + mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT) assert task_manager_without_id.task_id == 'new-task-id' assert task_manager_without_id.context_id == 'some-context' # initial submit should be updated to working @@ -287,6 +310,7 @@ async def test_get_task_no_task_id( context_id='some-context', task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) retrieved_task = await task_manager_without_id.get_task() assert retrieved_task is None @@ -303,6 +327,7 @@ async def test_save_task_event_no_task_existing( context_id=None, task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) mock_task_store.get.return_value = None event = TaskStatusUpdateEvent( diff --git a/tests/server/test_owner_resolver.py b/tests/server/test_owner_resolver.py index 5bac5c605..bb7b91012 100644 --- a/tests/server/test_owner_resolver.py +++ b/tests/server/test_owner_resolver.py @@ -19,13 +19,13 @@ def user_name(self) -> str: return self._user_name -def test_resolve_user_scope_valid_user(): - """Test resolve_user_scope with a valid user in the context.""" +def test_resolve_user(): + """Test resolve_user_scope.""" user = SampleUser(user_name='SampleUser') context = ServerCallContext(user=user) assert resolve_user_scope(context) == 'SampleUser' -def test_resolve_user_scope_no_context(): - """Test resolve_user_scope when the context is None.""" - assert resolve_user_scope(None) == 'unknown' +def test_resolve_user_default_context(): + """Test resolve_user_scope with default context.""" + assert resolve_user_scope(ServerCallContext()) == ''