diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index fb03a5273..b3ae1a389 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -11,6 +11,9 @@ def override(func): # noqa: ANN001, ANN201 return func +from a2a.types.a2a_pb2 import Artifact, Message, TaskStatus + + try: from sqlalchemy import JSON, DateTime, Index, LargeBinary, String from sqlalchemy.orm import ( @@ -48,13 +51,11 @@ class TaskMixin: last_updated: Mapped[datetime | None] = mapped_column( DateTime, nullable=True ) - status: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) - artifacts: Mapped[list[dict[str, Any]] | None] = mapped_column( - JSON, nullable=True - ) - history: Mapped[list[dict[str, Any]] | None] = mapped_column( + status: Mapped[TaskStatus] = mapped_column(JSON, nullable=False) + artifacts: Mapped[list[Artifact] | None] = mapped_column( JSON, nullable=True ) + history: Mapped[list[Message] | None] = mapped_column(JSON, nullable=True) protocol_version: Mapped[str | None] = mapped_column( String(16), nullable=True )