Skip to content

Commit 919a71c

Browse files
committed
LCORE-533: updated docstrings for streaming query
1 parent f475e65 commit 919a71c

1 file changed

Lines changed: 178 additions & 8 deletions

File tree

src/app/endpoints/streaming_query.py

Lines changed: 178 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,32 @@
4949

5050

5151
def format_stream_data(d: dict) -> str:
52-
"""Format outbound data in the Event Stream Format."""
52+
"""
53+
Format a dictionary as a Server-Sent Events (SSE) data string.
54+
55+
Parameters:
56+
d (dict): The data to be formatted as an SSE event.
57+
58+
Returns:
59+
str: The formatted SSE data string.
60+
"""
5361
data = json.dumps(d)
5462
return f"data: {data}\n\n"
5563

5664

5765
def stream_start_event(conversation_id: str) -> str:
58-
"""Yield the start of the data stream.
66+
"""
67+
Return the start event of the data stream.
5968
60-
Args:
61-
conversation_id: The conversation ID (UUID).
69+
Format a Server-Sent Events (SSE) start event containing the
70+
conversation ID.
71+
72+
Parameters:
73+
conversation_id (str): Unique identifier for the
74+
conversation.
75+
76+
Returns:
77+
str: SSE-formatted string representing the start event.
6278
"""
6379
return format_stream_data(
6480
{
@@ -71,7 +87,21 @@ def stream_start_event(conversation_id: str) -> str:
7187

7288

7389
def stream_end_event(metadata_map: dict) -> str:
74-
"""Yield the end of the data stream."""
90+
"""
91+
Return the end event of the data stream.
92+
93+
Format and return the end event for a streaming response,
94+
including referenced document metadata and placeholder token
95+
counts.
96+
97+
Parameters:
98+
metadata_map (dict): A mapping containing metadata about
99+
referenced documents.
100+
101+
Returns:
102+
str: A Server-Sent Events (SSE) formatted string
103+
representing the end of the data stream.
104+
"""
75105
return format_stream_data(
76106
{
77107
"event": "end",
@@ -137,6 +167,16 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterato
137167
# Error handling
138168
# -----------------------------------
139169
def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
170+
"""
171+
Yield error event.
172+
173+
Yield a formatted Server-Sent Events (SSE) error event
174+
containing the error message from a streaming chunk.
175+
176+
Parameters:
177+
chunk_id (int): The unique identifier for the current
178+
streaming chunk.
179+
"""
140180
yield format_stream_data(
141181
{
142182
"event": "error",
@@ -152,6 +192,20 @@ def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
152192
# Turn handling
153193
# -----------------------------------
154194
def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
195+
"""
196+
Yield turn start event.
197+
198+
Yield a Server-Sent Event (SSE) token event indicating the
199+
start of a new conversation turn.
200+
201+
Parameters:
202+
chunk_id (int): The unique identifier for the current
203+
chunk.
204+
205+
Yields:
206+
str: SSE-formatted token event with an empty token to
207+
signal turn start.
208+
"""
155209
yield format_stream_data(
156210
{
157211
"event": "token",
@@ -164,6 +218,20 @@ def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
164218

165219

166220
def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
221+
"""
222+
Yield turn complete event.
223+
224+
Yield a Server-Sent Event (SSE) indicating the completion of a
225+
conversation turn, including the full output message content.
226+
227+
Parameters:
228+
chunk_id (int): The unique identifier for the current
229+
chunk.
230+
231+
Yields:
232+
str: SSE-formatted string containing the turn completion
233+
event and output message content.
234+
"""
167235
yield format_stream_data(
168236
{
169237
"event": "turn_complete",
@@ -181,6 +249,16 @@ def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
181249
# Shield handling
182250
# -----------------------------------
183251
def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
252+
"""
253+
Yield shield event.
254+
255+
Processes a shield event chunk and yields a formatted SSE token
256+
event indicating shield validation results.
257+
258+
Yields a "No Violation" token if no violation is detected, or a
259+
violation message if a shield violation occurs. Increments
260+
validation error metrics when violations are present.
261+
"""
184262
if chunk.event.payload.event_type == "step_complete":
185263
violation = chunk.event.payload.step_details.violation
186264
if not violation:
@@ -216,6 +294,16 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
216294
# Inference handling
217295
# -----------------------------------
218296
def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
297+
"""
298+
Yield inference step event.
299+
300+
Yield formatted Server-Sent Events (SSE) strings for inference
301+
step events during streaming.
302+
303+
Processes inference-related streaming chunks, yielding SSE
304+
events for step start, text token deltas, and tool call deltas.
305+
Supports both string and ToolCall object tool calls.
306+
"""
219307
if chunk.event.payload.event_type == "step_start":
220308
yield format_stream_data(
221309
{
@@ -273,6 +361,26 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
273361
def _handle_tool_execution_event(
274362
chunk: Any, chunk_id: int, metadata_map: dict
275363
) -> Iterator[str]:
364+
"""
365+
Yield tool call event.
366+
367+
Processes tool execution events from a streaming chunk and
368+
yields formatted Server-Sent Events (SSE) strings.
369+
370+
Handles both tool call initiation and completion, including
371+
tool call arguments, responses, and summaries. Extracts and
372+
updates document metadata from knowledge search tool responses
373+
when present.
374+
375+
Parameters:
376+
chunk_id (int): Unique identifier for the current streaming chunk.
377+
metadata_map (dict): Dictionary to be updated with document
378+
metadata extracted from tool responses.
379+
380+
Yields:
381+
str: SSE-formatted event strings representing tool call
382+
events and responses.
383+
"""
276384
if chunk.event.payload.event_type == "step_start":
277385
yield format_stream_data(
278386
{
@@ -372,6 +480,19 @@ def _handle_tool_execution_event(
372480
# Catch-all for everything else
373481
# -----------------------------------
374482
def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
483+
"""
484+
Yield a heartbeat event.
485+
486+
Yield a heartbeat event as a Server-Sent Event (SSE) for the
487+
given chunk ID.
488+
489+
Parameters:
490+
chunk_id (int): The identifier for the current streaming
491+
chunk.
492+
493+
Yields:
494+
str: SSE-formatted heartbeat event string.
495+
"""
375496
yield format_stream_data(
376497
{
377498
"event": "heartbeat",
@@ -390,7 +511,24 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
390511
auth: Annotated[AuthTuple, Depends(auth_dependency)],
391512
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
392513
) -> StreamingResponse:
393-
"""Handle request to the /streaming_query endpoint."""
514+
"""
515+
Handle request to the /streaming_query endpoint.
516+
517+
This endpoint receives a query request, authenticates the user,
518+
selects the appropriate model and provider, and streams
519+
incremental response events from the Llama Stack backend to the
520+
client. Events include start, token updates, tool calls, turn
521+
completions, errors, and end-of-stream metadata. Optionally
522+
stores the conversation transcript if enabled in configuration.
523+
524+
Returns:
525+
StreamingResponse: An HTTP streaming response yielding
526+
SSE-formatted events for the query lifecycle.
527+
528+
Raises:
529+
HTTPException: Raised with HTTP 500 if unable to connect to the
530+
Llama Stack server.
531+
"""
394532
check_configuration_loaded(configuration)
395533

396534
llama_stack_config = configuration.llama_stack_configuration
@@ -437,7 +575,17 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
437575
metadata_map: dict[str, dict[str, Any]] = {}
438576

439577
async def response_generator(turn_response: Any) -> AsyncIterator[str]:
440-
"""Generate SSE formatted streaming response."""
578+
"""
579+
Generate SSE formatted streaming response.
580+
581+
Asynchronously generates a stream of Server-Sent Events
582+
(SSE) representing incremental responses from a
583+
language model turn.
584+
585+
Yields start, token, tool call, turn completion, and
586+
end events as SSE-formatted strings. Collects the
587+
complete response for transcript storage if enabled.
588+
"""
441589
chunk_id = 0
442590
complete_response = "No response from the model"
443591

@@ -508,7 +656,29 @@ async def retrieve_response(
508656
token: str,
509657
mcp_headers: dict[str, dict[str, str]] | None = None,
510658
) -> tuple[Any, str]:
511-
"""Retrieve response from LLMs and agents."""
659+
"""
660+
Retrieve response from LLMs and agents.
661+
662+
Asynchronously retrieves a streaming response and conversation
663+
ID from the Llama Stack agent for a given user query.
664+
665+
This function configures input/output shields, system prompt,
666+
and tool usage based on the request and environment. It
667+
prepares the agent with appropriate headers and toolgroups,
668+
validates attachments if present, and initiates a streaming
669+
turn with the user's query and any provided documents.
670+
671+
Parameters:
672+
model_id (str): Identifier of the model to use for the query.
673+
query_request (QueryRequest): The user's query and associated metadata.
674+
token (str): Authentication token for downstream services.
675+
mcp_headers (dict[str, dict[str, str]], optional):
676+
Model Context Protocol (MCP) headers for tool integrations.
677+
678+
Returns:
679+
tuple: A tuple containing the streaming response object
680+
and the conversation ID.
681+
"""
512682
available_input_shields = [
513683
shield.identifier
514684
for shield in filter(is_input_shield, await client.shields.list())

0 commit comments

Comments
 (0)