Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ async def build(
):
tasks = await asyncio.gather(
*[
self._task_store.get(task_id)
self._task_store.get(
task_id, context or ServerCallContext()
)
for task_id in params.message.reference_task_ids
]
)
Expand Down
7 changes: 2 additions & 5 deletions src/a2a/server/owner_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@


# Definition
OwnerResolver = Callable[[ServerCallContext | None], str]
OwnerResolver = Callable[[ServerCallContext], str]


# Example Default Implementation
def resolve_user_scope(context: ServerCallContext | None) -> str:
def resolve_user_scope(context: ServerCallContext) -> str:
"""Resolves the owner scope based on the user in the context."""
if not context:
return 'unknown'
# Example: Basic user name. Adapt as needed for your user model.
return context.user.user_name
12 changes: 4 additions & 8 deletions src/a2a/server/tasks/copying_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ class CopyingTaskStoreAdapter(TaskStore):
def __init__(self, underlying_store: TaskStore):
self._store = underlying_store

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves a copy of the task to the underlying store."""
task_copy = Task()
task_copy.CopyFrom(task)
await self._store.save(task_copy, context)

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the underlying store and returns a copy."""
task = await self._store.get(task_id, context)
Expand All @@ -46,16 +44,14 @@ async def get(
async def list(
self,
params: ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> ListTasksResponse:
"""Retrieves a list of tasks from the underlying store and returns a copy."""
response = await self._store.list(params, context)
response_copy = ListTasksResponse()
response_copy.CopyFrom(response)
return response_copy

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the underlying store."""
await self._store.delete(task_id, context)
12 changes: 4 additions & 8 deletions src/a2a/server/tasks/database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
# Legacy conversion
return compat_task_model_to_core(task_model)

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves or updates a task in the database for the resolved owner."""
await self._ensure_initialized()
owner = self.owner_resolver(context)
Expand All @@ -185,7 +183,7 @@ async def save(
)

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the database by ID, for the given owner."""
await self._ensure_initialized()
Expand Down Expand Up @@ -216,7 +214,7 @@ async def get(
async def list(
self,
params: a2a_pb2.ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> a2a_pb2.ListTasksResponse:
"""Retrieves tasks from the database based on provided parameters, for the given owner."""
await self._ensure_initialized()
Expand Down Expand Up @@ -315,9 +313,7 @@ async def list(
page_size=page_size,
)

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the database by ID, for the given owner."""
await self._ensure_initialized()
owner = self.owner_resolver(context)
Expand Down
24 changes: 8 additions & 16 deletions src/a2a/server/tasks/inmemory_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def __init__(
def _get_owner_tasks(self, owner: str) -> dict[str, Task]:
return self.tasks.get(owner, {})

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves or updates a task in the in-memory store for the resolved owner."""
owner = self.owner_resolver(context)
if owner not in self.tasks:
Expand All @@ -50,7 +48,7 @@ async def save(
)

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the in-memory store by ID, for the given owner."""
owner = self.owner_resolver(context)
Expand All @@ -77,7 +75,7 @@ async def get(
async def list(
self,
params: a2a_pb2.ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> a2a_pb2.ListTasksResponse:
"""Retrieves a list of tasks from the store, for the given owner."""
owner = self.owner_resolver(context)
Expand Down Expand Up @@ -156,9 +154,7 @@ async def list(
page_size=page_size,
)

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the in-memory store by ID, for the given owner."""
owner = self.owner_resolver(context)
async with self.lock:
Expand Down Expand Up @@ -211,28 +207,24 @@ def __init__(
CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl
)

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves or updates a task in the store."""
await self._store.save(task, context)

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the store by ID."""
return await self._store.get(task_id, context)

async def list(
self,
params: a2a_pb2.ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> a2a_pb2.ListTasksResponse:
"""Retrieves a list of tasks from the store."""
return await self._store.list(params, context)

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the store by ID."""
await self._store.delete(task_id, context)
4 changes: 2 additions & 2 deletions src/a2a/server/tasks/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
context_id: str | None,
task_store: TaskStore,
initial_message: Message | None,
context: ServerCallContext | None = None,
context: ServerCallContext,
):
"""Initializes the TaskManager.

Expand All @@ -51,7 +51,7 @@ def __init__(
self.task_store = task_store
self._initial_message = initial_message
self._current_task: Task | None = None
self._call_context: ServerCallContext | None = context
self._call_context: ServerCallContext = context
logger.debug(
'TaskManager initialized with task_id: %s, context_id: %s',
task_id,
Expand Down
12 changes: 4 additions & 8 deletions src/a2a/server/tasks/task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,23 @@ class TaskStore(ABC):
"""

@abstractmethod
async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves or updates a task in the store."""

@abstractmethod
async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the store by ID."""

@abstractmethod
async def list(
self,
params: ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> ListTasksResponse:
"""Retrieves a list of tasks from the store."""

@abstractmethod
async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the store by ID."""
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ async def test_build_populate_true_with_reference_task_ids(self) -> None:
mock_ref_task1 = create_sample_task(task_id=ref_task_id1)
mock_ref_task3 = create_sample_task(task_id=ref_task_id3)

server_call_context = ServerCallContext(user=UnauthenticatedUser())

# Configure task_store.get mock
# Note: AsyncMock side_effect needs to handle multiple calls if they have different args.
# A simple way is a list of return values, or a function.
async def get_side_effect(task_id):
async def get_side_effect(task_id, server_call_context):
if task_id == ref_task_id1:
return mock_ref_task1
if task_id == ref_task_id3:
Expand All @@ -144,7 +146,6 @@ async def get_side_effect(task_id):
reference_task_ids=[ref_task_id1, ref_task_id2, ref_task_id3]
)
)
server_call_context = ServerCallContext(user=UnauthenticatedUser())

request_context = await builder.build(
params=params,
Expand All @@ -155,9 +156,15 @@ async def get_side_effect(task_id):
)

self.assertEqual(self.mock_task_store.get.call_count, 3)
self.mock_task_store.get.assert_any_call(ref_task_id1)
self.mock_task_store.get.assert_any_call(ref_task_id2)
self.mock_task_store.get.assert_any_call(ref_task_id3)
self.mock_task_store.get.assert_any_call(
ref_task_id1, server_call_context
)
self.mock_task_store.get.assert_any_call(
ref_task_id2, server_call_context
)
self.mock_task_store.get.assert_any_call(
ref_task_id3, server_call_context
)

self.assertIsNotNone(request_context.related_tasks)
self.assertEqual(
Expand Down
Loading
Loading