Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions agentex/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6629,20 +6629,12 @@ components:
title: UpdateSpanRequest
UpdateStateRequest:
properties:
task_id:
type: string
title: The unique id of the task to update the state of
agent_id:
type: string
title: The unique id of the agent to update the state of
state:
additionalProperties: true
type: object
title: The state to update the state with.
type: object
required:
- task_id
- agent_id
- state
title: UpdateStateRequest
UpdateTaskMessageRequest:
Expand Down
1 change: 0 additions & 1 deletion agentex/src/api/routes/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ async def update_task_state(
) -> State:
state_entity = await states_use_case.update(
id=state_id,
task_id=request.task_id,
state=request.state,
)
return State.model_validate(state_entity)
Expand Down
8 changes: 0 additions & 8 deletions agentex/src/api/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,6 @@ class GetStatesRequest(BaseModel):


class UpdateStateRequest(BaseModel):
task_id: str = Field(
...,
title="The unique id of the task to update the state of",
)
agent_id: str = Field(
...,
title="The unique id of the agent to update the state of",
)
state: dict[str, Any] = Field(
...,
title="The state to update the state with.",
Expand Down
14 changes: 8 additions & 6 deletions agentex/src/domain/use_cases/states_use_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from fastapi import Depends

from src.adapters.crud_store.exceptions import ItemDoesNotExist
from src.domain.entities.states import StateEntity
from src.domain.repositories.task_state_repository import DTaskStateRepository
from src.utils.logging import make_logger
Expand Down Expand Up @@ -58,13 +59,14 @@ async def list(
order_direction=order_direction,
)

async def update(self, id: str, task_id: str, state: dict[str, Any]) -> StateEntity:
async def update(self, id: str, state: dict[str, Any]) -> StateEntity:
task_state = await self.task_state_repository.get(id=id)
Comment thread
deepthi-rao-scale marked this conversation as resolved.
if task_state and task_state.task_id == task_id:
# Update the state field but preserve other fields
task_state.state = state
return await self.task_state_repository.update(task_state)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@smoreinis any reason not having the task_id would introduce a regression?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should be fine here, thanks for double checking

return task_state
if task_state is None:
raise ItemDoesNotExist(f"State {id} not found")

# Update the state field but preserve other fields.
task_state.state = state
return await self.task_state_repository.update(task_state)

async def delete(self, id: str) -> None:
return await self.task_state_repository.delete(id=id)
Expand Down
21 changes: 2 additions & 19 deletions agentex/tests/integration/api/states/test_states_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,30 +190,13 @@ async def test_create_and_retrieve_state_consistency(
assert created_state["agent_id"] == state_data["agent_id"]
assert created_state["state"] == state_value

# API-first validation: UPDATE with wrong task_id does nothing
update_response = await isolated_client.put(
f"/states/{state_id}",
json={
"state": {},
"task_id": "some-other-task-id",
"agent_id": test_agent.id,
},
)
assert update_response.status_code == 200
updated_state = update_response.json()

assert updated_state["id"] == state_id
assert updated_state["task_id"] == state_data["task_id"]
assert updated_state["agent_id"] == state_data["agent_id"]
assert updated_state["state"] == state_value

# API-first validation: UPDATE the created state
# API-first validation: UPDATE accepts legacy parent identifiers in the body
state_value_updated = {"test": "updated"}
update_response = await isolated_client.put(
f"/states/{state_id}",
json={
"state": state_value_updated,
"task_id": test_task.id,
"task_id": "some-other-task-id",
"agent_id": test_agent.id,
},
)
Expand Down
14 changes: 14 additions & 0 deletions agentex/tests/unit/api/test_states_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from src.api.schemas.states import UpdateStateRequest

@pytest.mark.unit
def test_update_state_request_ignores_legacy_parent_identifiers():
request = UpdateStateRequest.model_validate(
{
"state": {"status": "new"},
"task_id": "task-1",
"agent_id": "agent-1",
}
)

assert request.model_dump() == {"state": {"status": "new"}}
65 changes: 65 additions & 0 deletions agentex/tests/unit/use_cases/test_states_use_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from unittest.mock import AsyncMock

import pytest
from src.adapters.crud_store.exceptions import ItemDoesNotExist
from src.domain.entities.states import StateEntity
from src.domain.repositories.task_state_repository import TaskStateRepository
from src.domain.use_cases.states_use_case import StatesUseCase


@pytest.fixture
def task_state_repository():
repository = AsyncMock(spec=TaskStateRepository)
repository.get = AsyncMock()
repository.update = AsyncMock()
return repository


@pytest.fixture
def states_use_case(task_state_repository):
return StatesUseCase(task_state_repository=task_state_repository)


@pytest.fixture
def existing_state():
return StateEntity(
id="state-1",
task_id="task-1",
agent_id="agent-1",
state={"status": "old"},
)


@pytest.mark.unit
@pytest.mark.asyncio
class TestStatesUseCase:
async def test_update_mutates_state_by_id(
self, states_use_case, task_state_repository, existing_state
):
task_state_repository.get.return_value = existing_state
task_state_repository.update.return_value = existing_state

result = await states_use_case.update(
id="state-1",
state={"status": "new"},
)

assert result is existing_state
assert result.state == {"status": "new"}
task_state_repository.get.assert_awaited_once_with(id="state-1")
task_state_repository.update.assert_awaited_once_with(existing_state)

async def test_update_raises_not_found_when_state_does_not_exist(
self, states_use_case, task_state_repository
):
task_state_repository.get.return_value = None

with pytest.raises(ItemDoesNotExist) as exc_info:
await states_use_case.update(
id="state-1",
state={"status": "new"},
)

assert "State state-1 not found" in str(exc_info.value)
task_state_repository.get.assert_awaited_once_with(id="state-1")
task_state_repository.update.assert_not_awaited()
Loading