Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 247 additions & 40 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
import logging
import re
from json import JSONDecodeError
from typing import Any, AsyncIterator
from typing import Any, AsyncIterator, Iterator

from cachetools import TTLCache # type: ignore

from llama_stack_client import APIConnectionError
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
from llama_stack_client.types import UserMessage # type: ignore

from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.types.shared import ToolCall
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem

from fastapi import APIRouter, HTTPException, Request, Depends, status
from fastapi.responses import StreamingResponse

Expand Down Expand Up @@ -46,7 +49,8 @@
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)


async def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
# # pylint: disable=R0913,R0917
async def get_agent(
client: AsyncLlamaStackClient,
model_id: str,
system_prompt: str,
Expand Down Expand Up @@ -127,7 +131,7 @@ def stream_end_event(metadata_map: dict) -> str:
)


def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
"""Build a streaming event from a chunk response.

This function processes chunks from the LLama Stack streaming response and formats
Expand All @@ -142,58 +146,261 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
chunk_id: The current chunk ID counter (gets incremented for each token)

Returns:
str | None: A formatted SSE data string with event information, or None if
the chunk doesn't contain processable event data
Iterator[str]: An iterable list of formatted SSE data strings with event information
"""
# pylint: disable=R1702
if hasattr(chunk.event, "payload"):
if chunk.event.payload.event_type == "step_progress":
if hasattr(chunk.event.payload.delta, "text"):
text = chunk.event.payload.delta.text
return format_stream_data(
if hasattr(chunk, "error"):
yield from _handle_error_event(chunk, chunk_id)
return

event_type = chunk.event.payload.event_type
step_type = getattr(chunk.event.payload, "step_type", None)

if event_type in {"turn_start", "turn_awaiting_input"}:
yield from _handle_turn_start_event(chunk_id)
elif event_type == "turn_complete":
yield from _handle_turn_complete_event(chunk, chunk_id)
elif step_type == "shield_call":
yield from _handle_shield_event(chunk, chunk_id)
elif step_type == "inference":
yield from _handle_inference_event(chunk, chunk_id)
elif step_type == "tool_execution":
yield from _handle_tool_execution_event(chunk, chunk_id, metadata_map)
else:
yield from _handle_heartbeat_event(chunk_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love this! So much easier to follow the flow



# -----------------------------------
# Error handling
# -----------------------------------
def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
yield format_stream_data(
{
"event": "error",
"data": {
"id": chunk_id,
"token": chunk.error["message"],
},
}
)


# -----------------------------------
# Turn handling
# -----------------------------------
def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
yield format_stream_data(
{
"event": "token",
"data": {
"id": chunk_id,
"token": "",
},
}
)


def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
yield format_stream_data(
{
"event": "turn_complete",
"data": {
"id": chunk_id,
"token": chunk.event.payload.turn.output_message.content,
},
}
)


# -----------------------------------
# Shield handling
# -----------------------------------
def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
if chunk.event.payload.event_type == "step_complete":
violation = chunk.event.payload.step_details.violation
if not violation:
yield format_stream_data(
{
"event": "token",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": "No Violation",
},
}
)
else:
violation = (
f"Violation: {violation.user_message} (Metadata: {violation.metadata})"
)
yield format_stream_data(
{
"event": "token",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": violation,
},
}
)


# -----------------------------------
# Inference handling
# -----------------------------------
def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
if chunk.event.payload.event_type == "step_start":
yield format_stream_data(
{
"event": "token",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": "",
},
}
)

elif chunk.event.payload.event_type == "step_progress":
if chunk.event.payload.delta.type == "tool_call":
if isinstance(chunk.event.payload.delta.tool_call, str):
yield format_stream_data(
{
"event": "token",
"event": "tool_call",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": text,
"token": chunk.event.payload.delta.tool_call,
},
}
)
if (
chunk.event.payload.event_type == "step_complete"
and chunk.event.payload.step_details.step_type == "tool_execution"
):
for r in chunk.event.payload.step_details.tool_responses:
if r.tool_name == "knowledge_search" and r.content:
for text_content_item in r.content:
if isinstance(text_content_item, TextContentItem):
for match in METADATA_PATTERN.findall(
text_content_item.text
):
try:
meta = json.loads(match.replace("'", '"'))
elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
yield format_stream_data(
{
"event": "tool_call",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": chunk.event.payload.delta.tool_call.tool_name,
},
}
)

elif chunk.event.payload.delta.type == "text":
yield format_stream_data(
{
"event": "token",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": chunk.event.payload.delta.text,
},
}
)


# -----------------------------------
# Tool Execution handling
# -----------------------------------
# pylint: disable=R1702,R0912
def _handle_tool_execution_event(
chunk: Any, chunk_id: int, metadata_map: dict
) -> Iterator[str]:
if chunk.event.payload.event_type == "step_start":
yield format_stream_data(
{
"event": "tool_call",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": "",
},
}
)

elif chunk.event.payload.event_type == "step_complete":
for t in chunk.event.payload.step_details.tool_calls:
yield format_stream_data(
{
"event": "tool_call",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": f"Tool:{t.tool_name} arguments:{t.arguments}",
},
}
)

for r in chunk.event.payload.step_details.tool_responses:
if r.tool_name == "query_from_memory":
inserted_context = interleaved_content_as_str(r.content)
yield format_stream_data(
{
"event": "tool_call",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": f"Fetched {len(inserted_context)} bytes from memory",
},
}
)

elif r.tool_name == "knowledge_search" and r.content:
summary = ""
for i, text_content_item in enumerate(r.content):
if isinstance(text_content_item, TextContentItem):
if i == 0:
summary = text_content_item.text
newline_pos = summary.find("\n")
if newline_pos > 0:
summary = summary[:newline_pos]
for match in METADATA_PATTERN.findall(text_content_item.text):
try:
meta = json.loads(match.replace("'", '"'))
if "document_id" in meta:
metadata_map[meta["document_id"]] = meta
except JSONDecodeError:
logger.debug(
"JSONDecodeError was thrown in processing %s",
match,
)
if chunk.event.payload.step_details.tool_calls:
tool_name = str(
chunk.event.payload.step_details.tool_calls[0].tool_name
except JSONDecodeError:
logger.debug(
"JSONDecodeError was thrown in processing %s",
match,
)

yield format_stream_data(
{
"event": "tool_call",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": f"Tool:{r.tool_name} summary:{summary}",
},
}
)
return format_stream_data(

else:
yield format_stream_data(
{
"event": "token",
"event": "tool_call",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": tool_name,
"token": f"Tool:{r.tool_name} response:{r.content}",
},
}
)
return None


# -----------------------------------
# Catch-all for everything else
# -----------------------------------
def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
yield format_stream_data(
{
"event": "heartbeat",
"data": {
"id": chunk_id,
"token": "heartbeat",
},
}
)


@router.post("/streaming_query")
Expand Down Expand Up @@ -233,7 +440,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
yield stream_start_event(conversation_id)

async for chunk in turn_response:
if event := stream_build_event(chunk, chunk_id, metadata_map):
for event in stream_build_event(chunk, chunk_id, metadata_map):
complete_response += json.loads(event.replace("data: ", ""))[
"data"
]["token"]
Expand Down
Loading