Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/app/endpoints/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from models.requests import QueryRequest
from utils.mcp_headers import mcp_headers_dependency, McpHeaders
from utils.responses import (
extract_text_from_response_output_item,
extract_text_from_output_item,
prepare_responses_params,
)
from utils.suid import normalize_conversation_id
Expand Down Expand Up @@ -107,7 +107,7 @@ def _convert_responses_content_to_a2a_parts(output: list[Any]) -> list[Part]:
parts: list[Part] = []

for output_item in output:
text = extract_text_from_response_output_item(output_item)
text = extract_text_from_output_item(output_item)
if text:
parts.append(Part(root=TextPart(text=text)))

Expand Down
71 changes: 14 additions & 57 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""Handler for REST API call to provide answer to query using Response API."""

import datetime
from typing import Annotated, Any, Optional, cast
from typing import Annotated, Any, cast

from fastapi import APIRouter, Depends, HTTPException, Request
from llama_stack_api.openai_responses import OpenAIResponseObject
Expand Down Expand Up @@ -52,12 +52,10 @@
)
from utils.quota import check_tokens_available, get_available_quotas
from utils.responses import (
build_tool_call_summary,
extract_text_from_response_output_item,
extract_token_usage,
build_turn_summary,
deduplicate_referenced_documents,
extract_vector_store_ids_from_tools,
get_topic_summary,
parse_referenced_documents,
prepare_responses_params,
)
from utils.shields import (
Expand All @@ -66,6 +64,7 @@
)
from utils.suid import normalize_conversation_id
from utils.types import (
RAGChunk,
ResponsesApiParams,
TurnSummary,
)
Expand Down Expand Up @@ -130,7 +129,9 @@ async def query_endpoint_handler(
check_tokens_available(configuration.quota_limiters, user_id)

# Enforce RBAC: optionally disallow overriding model/provider in requests
validate_model_provider_override(query_request, request.state.authorized_actions)
validate_model_provider_override(
query_request.model, query_request.provider, request.state.authorized_actions
)

# Validate attachments if provided
if query_request.attachments:
Expand All @@ -153,7 +154,7 @@ async def query_endpoint_handler(
client = AsyncLlamaStackClientHolder().get_client()

doc_ids_from_chunks: list[ReferencedDocument] = []
pre_rag_chunks: list[Any] = [] # use your RAGChunk type (or the upstream one)
pre_rag_chunks: list[RAGChunk] = []

_, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
client, query_request, configuration
Expand Down Expand Up @@ -198,7 +199,7 @@ async def query_endpoint_handler(
turn_summary.rag_chunks = pre_rag_chunks + (turn_summary.rag_chunks or [])

if doc_ids_from_chunks:
turn_summary.referenced_documents = parse_referenced_docs(
turn_summary.referenced_documents = deduplicate_referenced_documents(
doc_ids_from_chunks + (turn_summary.referenced_documents or [])
)

Expand All @@ -216,7 +217,6 @@ async def query_endpoint_handler(
user_id=user_id,
model_id=responses_params.model,
token_usage=turn_summary.token_usage,
configuration=configuration,
)

logger.info("Getting available quotas")
Expand All @@ -238,7 +238,6 @@ async def query_endpoint_handler(
completed_at=completed_at,
summary=turn_summary,
query_request=query_request,
configuration=configuration,
skip_userid_check=_skip_userid_check,
topic_summary=topic_summary,
)
Expand All @@ -258,26 +257,11 @@ async def query_endpoint_handler(
)


def parse_referenced_docs(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to `deduplicate_referenced_documents` and moved to `utils.responses`.

docs: list[ReferencedDocument],
) -> list[ReferencedDocument]:
"""Remove duplicate referenced documents based on URL and title."""
seen: set[tuple[str | None, str | None]] = set()
out: list[ReferencedDocument] = []
for d in docs:
key = (str(d.doc_url) if d.doc_url else None, d.doc_title)
if key in seen:
continue
seen.add(key)
out.append(d)
return out


async def retrieve_response( # pylint: disable=too-many-locals
client: AsyncLlamaStackClient,
responses_params: ResponsesApiParams,
vector_store_ids: Optional[list[str]] = None,
rag_id_mapping: Optional[dict[str, str]] = None,
vector_store_ids: list[str] | None = None,
rag_id_mapping: dict[str, str] | None = None,
) -> TurnSummary:
"""
Retrieve response from LLMs and agents.
Expand All @@ -294,8 +278,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
Returns:
TurnSummary: Summary of the LLM response content
"""
summary = TurnSummary()

try:
moderation_result = await run_shield_moderation(client, responses_params.input)
if moderation_result.blocked:
Expand All @@ -307,8 +289,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
responses_params.input,
violation_message,
)
summary.llm_response = violation_message
return summary
return TurnSummary(llm_response=violation_message)
response = await client.responses.create(**responses_params.model_dump())
response = cast(OpenAIResponseObject, response)

Expand All @@ -327,30 +308,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
error_response = handle_known_apistatus_errors(e, responses_params.model)
raise HTTPException(**error_response.model_dump()) from e

# Process OpenAI response format
for output_item in response.output:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracted this output-processing loop into the reusable `build_turn_summary` function in `utils.responses`.

message_text = extract_text_from_response_output_item(output_item)
if message_text:
summary.llm_response += message_text

tool_call, tool_result = build_tool_call_summary(
output_item, summary.rag_chunks, vector_store_ids, rag_id_mapping
)
if tool_call:
summary.tool_calls.append(tool_call)
if tool_result:
summary.tool_results.append(tool_result)

logger.info(
"Response processing complete - Tool calls: %d, Response length: %d chars",
len(summary.tool_calls),
len(summary.llm_response),
)

# Extract referenced documents and token usage from Responses API response
summary.referenced_documents = parse_referenced_documents(
response, vector_store_ids, rag_id_mapping
return build_turn_summary(
response, responses_params.model, vector_store_ids, rag_id_mapping
)
summary.token_usage = extract_token_usage(response, responses_params.model)

return summary
12 changes: 6 additions & 6 deletions src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
from observability import InferenceEventData, build_inference_event, send_splunk_event
from utils.query import handle_known_apistatus_errors
from utils.responses import extract_text_from_response_output_item, get_mcp_tools
from utils.responses import (
extract_text_from_output_items,
get_mcp_tools,
)
from utils.suid import get_suid
from log import get_logger

Expand Down Expand Up @@ -189,10 +192,7 @@ async def retrieve_simple_response(
)
response = cast(OpenAIResponseObject, response)

return "".join(
extract_text_from_response_output_item(output_item)
for output_item in response.output
)
return extract_text_from_output_items(response.output)


def _get_cla_version(request: Request) -> str:
Expand Down Expand Up @@ -307,7 +307,7 @@ async def infer_endpoint(
input_source = infer_request.get_input_source()
instructions = _build_instructions(infer_request.context.systeminfo)
model_id = _get_default_model_id()
mcp_tools = await get_mcp_tools(configuration.mcp_servers)
mcp_tools = await get_mcp_tools()
logger.debug(
"Request %s: Combined input source length: %d", request_id, len(input_source)
)
Expand Down
11 changes: 7 additions & 4 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
check_tokens_available(configuration.quota_limiters, user_id)

# Enforce RBAC: optionally disallow overriding model/provider in requests
validate_model_provider_override(query_request, request.state.authorized_actions)
validate_model_provider_override(
query_request.model, query_request.provider, request.state.authorized_actions
)

# Validate attachments if provided
if query_request.attachments:
Expand Down Expand Up @@ -379,7 +381,6 @@ async def generate_response(
user_id=context.user_id,
model_id=responses_params.model,
token_usage=turn_summary.token_usage,
configuration=configuration,
)
# Get available quotas
logger.info("Getting available quotas")
Expand All @@ -405,7 +406,6 @@ async def generate_response(
started_at=context.started_at,
summary=turn_summary,
query_request=context.query_request,
configuration=configuration,
skip_userid_check=context.skip_userid_check,
topic_summary=topic_summary,
)
Expand Down Expand Up @@ -591,8 +591,11 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
)

# Extract token usage and referenced documents from the final response object
if not latest_response_object:
return

turn_summary.token_usage = extract_token_usage(
latest_response_object, context.model_id
latest_response_object.usage, context.model_id
)
tool_based_documents = parse_referenced_documents(
latest_response_object,
Expand Down
10 changes: 7 additions & 3 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,16 +1753,20 @@ class NotFoundResponse(AbstractErrorResponse):
}
}

def __init__(self, *, resource: str, resource_id: str):
def __init__(self, *, resource: str, resource_id: str | None = None):
"""
Create a NotFoundResponse for a missing resource and set the HTTP status to 404.

Parameters:
resource (str): Resource type that was not found (e.g., "conversation", "model").
resource_id (str): Identifier of the missing resource.
resource_id (str | None): Identifier of the missing resource. If None, indicates
the resource type is not configured (e.g., no model selected).
"""
response = f"{resource.title()} not found"
cause = f"{resource.title()} with ID {resource_id} does not exist"
if resource_id is None:
cause = f"No {resource.title()} is configured"
else:
cause = f"{resource.title()} with ID {resource_id} does not exist"
super().__init__(
response=response, cause=cause, status_code=status.HTTP_404_NOT_FOUND
)
Expand Down
71 changes: 34 additions & 37 deletions src/utils/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,40 @@
from fastapi import HTTPException

import constants
from configuration import AppConfig
from models.requests import QueryRequest
from configuration import configuration
from models.responses import UnprocessableEntityResponse


def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
def get_system_prompt(system_prompt: str | None) -> str:
"""
Resolve which system prompt to use for a query.

Precedence (highest to lowest):
1. Per-request `system_prompt` from `query_request.system_prompt`.
2. The `custom_profile`'s "default" prompt (when present), accessed via
`config.customization.custom_profile.get_prompts().get("default")`.
3. `config.customization.system_prompt` from application configuration.
get_system_prompt resolves the system prompt with the following precedence
(highest to lowest):
1. Per-request system prompt from the `system_prompt` argument (when allowed).
2. The custom profile's "default" prompt (when present), from application
configuration.
3. The application configuration system prompt.
4. The module default `constants.DEFAULT_SYSTEM_PROMPT` (lowest precedence).

If configuration disables per-request system prompts
(config.customization.disable_query_system_prompt) and the incoming
`query_request` contains a `system_prompt`, an HTTP 422 Unprocessable
Entity is raised instructing the client to remove the field.

Parameters:
query_request (QueryRequest): The incoming query payload; may contain a
per-request `system_prompt`.
config (AppConfig): Application configuration which may include
customization flags, a custom profile, and a default `system_prompt`.
system_prompt: Optional per-request system prompt from the query; may be
None.

Returns:
str: The resolved system prompt to apply to the request.
The resolved system prompt string to apply to the request.

Raises:
HTTPException: 422 Unprocessable Entity when per-request system prompts
are disabled (disable_query_system_prompt) and a non-None
`system_prompt` is provided; the response instructs the client to
remove the system_prompt field from the request.
"""
system_prompt_disabled = (
config.customization is not None
and config.customization.disable_query_system_prompt
configuration.customization is not None
and configuration.customization.disable_query_system_prompt
)
if system_prompt_disabled and query_request.system_prompt:
if system_prompt_disabled and system_prompt:
response = UnprocessableEntityResponse(
response="System prompt customization is disabled",
cause=(
Expand All @@ -48,49 +47,47 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
)
raise HTTPException(**response.model_dump())

if query_request.system_prompt:
if system_prompt:
# Query taking precedence over configuration is the only behavior that
# makes sense here - if the configuration wants precedence, it can
# disable query system prompt altogether with disable_query_system_prompt.
return query_request.system_prompt
return system_prompt

# profile takes precedence for setting prompt
if (
config.customization is not None
and config.customization.custom_profile is not None
configuration.customization is not None
and configuration.customization.custom_profile is not None
):
prompt = config.customization.custom_profile.get_prompts().get("default")
prompt = configuration.customization.custom_profile.get_prompts().get("default")
if prompt:
return prompt

if (
config.customization is not None
and config.customization.system_prompt is not None
configuration.customization is not None
and configuration.customization.system_prompt is not None
):
return config.customization.system_prompt
return configuration.customization.system_prompt

# default system prompt has the lowest precedence
return constants.DEFAULT_SYSTEM_PROMPT


def get_topic_summary_system_prompt(config: AppConfig) -> str:
def get_topic_summary_system_prompt() -> str:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not pass the configuration as an argument; read it from the module-level `configuration` import instead.

"""
Get the topic summary system prompt.

Parameters:
config (AppConfig): Application configuration from which to read
customization/profile settings.

Returns:
str: The topic summary system prompt from the active custom profile if
set, otherwise the default prompt.
"""
# profile takes precedence for setting prompt
if (
config.customization is not None
and config.customization.custom_profile is not None
configuration.customization is not None
and configuration.customization.custom_profile is not None
):
prompt = config.customization.custom_profile.get_prompts().get("topic_summary")
prompt = configuration.customization.custom_profile.get_prompts().get(
"topic_summary"
)
if prompt:
return prompt

Expand Down
Loading
Loading