diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py
index 0370b4fba..97e0739c0 100644
--- a/src/app/endpoints/a2a.py
+++ b/src/app/endpoints/a2a.py
@@ -46,7 +46,7 @@ from models.requests import QueryRequest
 from utils.mcp_headers import mcp_headers_dependency, McpHeaders
 from utils.responses import (
-    extract_text_from_response_output_item,
+    extract_text_from_output_item,
     prepare_responses_params,
 )
 from utils.suid import normalize_conversation_id
@@ -107,7 +107,7 @@ def _convert_responses_content_to_a2a_parts(output: list[Any]) -> list[Part]:
     parts: list[Part] = []
     for output_item in output:
-        text = extract_text_from_response_output_item(output_item)
+        text = extract_text_from_output_item(output_item)
         if text:
             parts.append(Part(root=TextPart(text=text)))
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 4b099da63..e5ae37e51 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -3,7 +3,7 @@
 """Handler for REST API call to provide answer to query using Response API."""
 
 import datetime
-from typing import Annotated, Any, Optional, cast
+from typing import Annotated, Any, cast
 
 from fastapi import APIRouter, Depends, HTTPException, Request
 from llama_stack_api.openai_responses import OpenAIResponseObject
@@ -52,12 +52,10 @@
 )
 from utils.quota import check_tokens_available, get_available_quotas
 from utils.responses import (
-    build_tool_call_summary,
-    extract_text_from_response_output_item,
-    extract_token_usage,
+    build_turn_summary,
+    deduplicate_referenced_documents,
     extract_vector_store_ids_from_tools,
     get_topic_summary,
-    parse_referenced_documents,
     prepare_responses_params,
 )
 from utils.shields import (
@@ -66,6 +64,7 @@
 )
 from utils.suid import normalize_conversation_id
 from utils.types import (
+    RAGChunk,
     ResponsesApiParams,
     TurnSummary,
 )
@@ -130,7 +129,9 @@ async def query_endpoint_handler(
     check_tokens_available(configuration.quota_limiters, user_id)
 
     # Enforce RBAC: optionally disallow overriding model/provider in requests
-    validate_model_provider_override(query_request, request.state.authorized_actions)
+    validate_model_provider_override(
+        query_request.model, query_request.provider, request.state.authorized_actions
+    )
 
     # Validate attachments if provided
     if query_request.attachments:
@@ -153,7 +154,7 @@ async def query_endpoint_handler(
     client = AsyncLlamaStackClientHolder().get_client()
 
     doc_ids_from_chunks: list[ReferencedDocument] = []
-    pre_rag_chunks: list[Any] = []  # use your RAGChunk type (or the upstream one)
+    pre_rag_chunks: list[RAGChunk] = []
 
     _, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
         client, query_request, configuration
     )
@@ -198,7 +199,7 @@ async def query_endpoint_handler(
     turn_summary.rag_chunks = pre_rag_chunks + (turn_summary.rag_chunks or [])
 
     if doc_ids_from_chunks:
-        turn_summary.referenced_documents = parse_referenced_docs(
+        turn_summary.referenced_documents = deduplicate_referenced_documents(
             doc_ids_from_chunks + (turn_summary.referenced_documents or [])
         )
@@ -216,7 +217,6 @@ async def query_endpoint_handler(
         user_id=user_id,
         model_id=responses_params.model,
         token_usage=turn_summary.token_usage,
-        configuration=configuration,
     )
 
     logger.info("Getting available quotas")
@@ -238,7 +238,6 @@ async def query_endpoint_handler(
         completed_at=completed_at,
         summary=turn_summary,
         query_request=query_request,
-        configuration=configuration,
         skip_userid_check=_skip_userid_check,
         topic_summary=topic_summary,
     )
@@ -258,26 +257,11 @@ async def query_endpoint_handler(
     )
 
 
-def parse_referenced_docs(
-    docs: list[ReferencedDocument],
-) -> list[ReferencedDocument]:
-    """Remove duplicate referenced documents based on URL and title."""
-    seen: set[tuple[str | None, str | None]] = set()
-    out: list[ReferencedDocument] = []
-    for d in docs:
-        key = (str(d.doc_url) if d.doc_url else None, d.doc_title)
-        if key in seen:
-            continue
-        seen.add(key)
-        out.append(d)
-    return out
-
-
 async def retrieve_response(  # pylint: disable=too-many-locals
     client: AsyncLlamaStackClient,
     responses_params: ResponsesApiParams,
-    vector_store_ids: Optional[list[str]] = None,
-    rag_id_mapping: Optional[dict[str, str]] = None,
+    vector_store_ids: list[str] | None = None,
+    rag_id_mapping: dict[str, str] | None = None,
 ) -> TurnSummary:
     """
     Retrieve response from LLMs and agents.
@@ -294,8 +278,6 @@ async def retrieve_response(  # pylint: disable=too-many-locals
     Returns:
         TurnSummary: Summary of the LLM response content
     """
-    summary = TurnSummary()
-
     try:
         moderation_result = await run_shield_moderation(client, responses_params.input)
         if moderation_result.blocked:
@@ -307,8 +289,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals
                 responses_params.input,
                 violation_message,
             )
-            summary.llm_response = violation_message
-            return summary
+            return TurnSummary(llm_response=violation_message)
 
         response = await client.responses.create(**responses_params.model_dump())
         response = cast(OpenAIResponseObject, response)
@@ -327,30 +308,6 @@ async def retrieve_response(  # pylint: disable=too-many-locals
         error_response = handle_known_apistatus_errors(e, responses_params.model)
         raise HTTPException(**error_response.model_dump()) from e
 
-    # Process OpenAI response format
-    for output_item in response.output:
-        message_text = extract_text_from_response_output_item(output_item)
-        if message_text:
-            summary.llm_response += message_text
-
-        tool_call, tool_result = build_tool_call_summary(
-            output_item, summary.rag_chunks, vector_store_ids, rag_id_mapping
-        )
-        if tool_call:
-            summary.tool_calls.append(tool_call)
-        if tool_result:
-            summary.tool_results.append(tool_result)
-
-    logger.info(
-        "Response processing complete - Tool calls: %d, Response length: %d chars",
-        len(summary.tool_calls),
-        len(summary.llm_response),
-    )
-
-    # Extract referenced documents and token usage from Responses API response
-    summary.referenced_documents = parse_referenced_documents(
-        response, vector_store_ids, rag_id_mapping
+    return build_turn_summary(
+        response, responses_params.model, vector_store_ids, rag_id_mapping
     )
-    summary.token_usage = extract_token_usage(response, responses_params.model)
-
-    return summary
diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py
index 76f759ff8..6528bfce4 100644
--- a/src/app/endpoints/rlsapi_v1.py
+++ b/src/app/endpoints/rlsapi_v1.py
@@ -34,7 +34,10 @@
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
 from observability import InferenceEventData, build_inference_event, send_splunk_event
 from utils.query import handle_known_apistatus_errors
-from utils.responses import extract_text_from_response_output_item, get_mcp_tools
+from utils.responses import (
+    extract_text_from_output_items,
+    get_mcp_tools,
+)
 from utils.suid import get_suid
 
 from log import get_logger
@@ -189,10 +192,7 @@ async def retrieve_simple_response(
     )
     response = cast(OpenAIResponseObject, response)
 
-    return "".join(
-        extract_text_from_response_output_item(output_item)
-        for output_item in response.output
-    )
+    return extract_text_from_output_items(response.output)
 
 
 def _get_cla_version(request: Request) -> str:
@@ -307,7 +307,7 @@ async def infer_endpoint(
     input_source = infer_request.get_input_source()
     instructions = _build_instructions(infer_request.context.systeminfo)
     model_id = _get_default_model_id()
-    mcp_tools = await get_mcp_tools(configuration.mcp_servers)
+    mcp_tools = await get_mcp_tools()
     logger.debug(
         "Request %s: Combined input source length: %d", request_id, len(input_source)
     )
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index b45c4f625..b3d370a57 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -153,7 +153,9 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     check_tokens_available(configuration.quota_limiters, user_id)
 
     # Enforce RBAC: optionally disallow overriding model/provider in requests
-    validate_model_provider_override(query_request, request.state.authorized_actions)
+    validate_model_provider_override(
+        query_request.model, query_request.provider, request.state.authorized_actions
+    )
 
     # Validate attachments if provided
     if query_request.attachments:
@@ -379,7 +381,6 @@ async def generate_response(
         user_id=context.user_id,
         model_id=responses_params.model,
         token_usage=turn_summary.token_usage,
-        configuration=configuration,
     )
     # Get available quotas
     logger.info("Getting available quotas")
@@ -405,7 +406,6 @@ async def generate_response(
         started_at=context.started_at,
         summary=turn_summary,
         query_request=context.query_request,
-        configuration=configuration,
         skip_userid_check=context.skip_userid_check,
         topic_summary=topic_summary,
     )
@@ -591,8 +591,11 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
     )
 
     # Extract token usage and referenced documents from the final response object
+    if not latest_response_object:
+        return
+
     turn_summary.token_usage = extract_token_usage(
-        latest_response_object, context.model_id
+        latest_response_object.usage, context.model_id
     )
     tool_based_documents = parse_referenced_documents(
         latest_response_object,
diff --git a/src/models/responses.py b/src/models/responses.py
index 9b29be513..2174d9216 100644
--- a/src/models/responses.py
+++ b/src/models/responses.py
@@ -1753,16 +1753,20 @@ class NotFoundResponse(AbstractErrorResponse):
         }
     }
 
-    def __init__(self, *, resource: str, resource_id: str):
+    def __init__(self, *, resource: str, resource_id: str | None = None):
         """
         Create a NotFoundResponse for a missing resource and set the HTTP status to 404.
 
         Parameters:
             resource (str): Resource type that was not found (e.g., "conversation", "model").
-            resource_id (str): Identifier of the missing resource.
+            resource_id (str | None): Identifier of the missing resource. If None, indicates
+                the resource type is not configured (e.g., no model selected).
         """
         response = f"{resource.title()} not found"
-        cause = f"{resource.title()} with ID {resource_id} does not exist"
+        if resource_id is None:
+            cause = f"No {resource.title()} is configured"
+        else:
+            cause = f"{resource.title()} with ID {resource_id} does not exist"
         super().__init__(
             response=response, cause=cause, status_code=status.HTTP_404_NOT_FOUND
         )
diff --git a/src/utils/prompts.py b/src/utils/prompts.py
index 0b6410b75..aabb10399 100644
--- a/src/utils/prompts.py
+++ b/src/utils/prompts.py
@@ -3,41 +3,40 @@
 
 from fastapi import HTTPException
 
 import constants
-from configuration import AppConfig
-from models.requests import QueryRequest
+from configuration import configuration
 from models.responses import UnprocessableEntityResponse
 
 
-def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
+def get_system_prompt(system_prompt: str | None) -> str:
     """
     Resolve which system prompt to use for a query.
 
-    Precedence (highest to lowest):
-    1. Per-request `system_prompt` from `query_request.system_prompt`.
-    2. The `custom_profile`'s "default" prompt (when present), accessed via
-       `config.customization.custom_profile.get_prompts().get("default")`.
-    3. `config.customization.system_prompt` from application configuration.
+    get_system_prompt resolves the system prompt with the following precedence
+    (highest to lowest):
+    1. Per-request system prompt from the `system_prompt` argument (when allowed).
+    2. The custom profile's "default" prompt (when present), from application
+       configuration.
+    3. The application configuration system prompt.
     4. The module default `constants.DEFAULT_SYSTEM_PROMPT` (lowest precedence).
 
-    If configuration disables per-request system prompts
-    (config.customization.disable_query_system_prompt) and the incoming
-    `query_request` contains a `system_prompt`, an HTTP 422 Unprocessable
-    Entity is raised instructing the client to remove the field.
-
     Parameters:
-        query_request (QueryRequest): The incoming query payload; may contain a
-            per-request `system_prompt`.
-        config (AppConfig): Application configuration which may include
-            customization flags, a custom profile, and a default `system_prompt`.
+        system_prompt: Optional per-request system prompt from the query; may be
+            None.
 
     Returns:
-        str: The resolved system prompt to apply to the request.
+        The resolved system prompt string to apply to the request.
+
+    Raises:
+        HTTPException: 422 Unprocessable Entity when per-request system prompts
+            are disabled (disable_query_system_prompt) and a non-None
+            `system_prompt` is provided; the response instructs the client to
+            remove the system_prompt field from the request.
     """
     system_prompt_disabled = (
-        config.customization is not None
-        and config.customization.disable_query_system_prompt
+        configuration.customization is not None
+        and configuration.customization.disable_query_system_prompt
     )
-    if system_prompt_disabled and query_request.system_prompt:
+    if system_prompt_disabled and system_prompt:
         response = UnprocessableEntityResponse(
             response="System prompt customization is disabled",
             cause=(
@@ -48,49 +47,47 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
         )
         raise HTTPException(**response.model_dump())
 
-    if query_request.system_prompt:
+    if system_prompt:
         # Query taking precedence over configuration is the only behavior that
         # makes sense here - if the configuration wants precedence, it can
         # disable query system prompt altogether with disable_query_system_prompt.
-        return query_request.system_prompt
+        return system_prompt
 
     # profile takes precedence for setting prompt
     if (
-        config.customization is not None
-        and config.customization.custom_profile is not None
+        configuration.customization is not None
+        and configuration.customization.custom_profile is not None
     ):
-        prompt = config.customization.custom_profile.get_prompts().get("default")
+        prompt = configuration.customization.custom_profile.get_prompts().get("default")
         if prompt:
             return prompt
 
     if (
-        config.customization is not None
-        and config.customization.system_prompt is not None
+        configuration.customization is not None
+        and configuration.customization.system_prompt is not None
    ):
-        return config.customization.system_prompt
+        return configuration.customization.system_prompt
 
     # default system prompt has the lowest precedence
     return constants.DEFAULT_SYSTEM_PROMPT
 
 
-def get_topic_summary_system_prompt(config: AppConfig) -> str:
+def get_topic_summary_system_prompt() -> str:
     """
     Get the topic summary system prompt.
 
-    Parameters:
-        config (AppConfig): Application configuration from which to read
-            customization/profile settings.
-
     Returns:
         str: The topic summary system prompt from the active custom profile if
             set, otherwise the default prompt.
     """
     # profile takes precedence for setting prompt
     if (
-        config.customization is not None
-        and config.customization.custom_profile is not None
+        configuration.customization is not None
+        and configuration.customization.custom_profile is not None
     ):
-        prompt = config.customization.custom_profile.get_prompts().get("topic_summary")
+        prompt = configuration.customization.custom_profile.get_prompts().get(
+            "topic_summary"
+        )
         if prompt:
             return prompt
diff --git a/src/utils/query.py b/src/utils/query.py
index 740575a0f..d1ab6cfe0 100644
--- a/src/utils/query.py
+++ b/src/utils/query.py
@@ -9,11 +9,11 @@
     AsyncLlamaStackClient,
 )
 from openai._exceptions import APIStatusError as OpenAIAPIStatusError
-from llama_stack_client.types import ModelListResponse, Shield
+from llama_stack_client.types import Shield
 from fastapi import HTTPException
 from sqlalchemy import func
 
-from configuration import AppConfig, configuration
+from configuration import configuration
 from models.cache_entry import CacheEntry
 from models.config import Action
 from models.database.conversations import UserConversation, UserTurn
@@ -23,7 +23,6 @@
     AbstractErrorResponse,
     ForbiddenResponse,
     InternalServerErrorResponse,
-    NotFoundResponse,
     PromptTooLongResponse,
     QuotaExceededResponse,
     ServiceUnavailableResponse,
@@ -36,7 +35,11 @@
 from sqlalchemy.exc import SQLAlchemyError
 from app.database import get_session
 from client import AsyncLlamaStackClientHolder
-from utils.transcripts import store_transcript
+from utils.transcripts import (
+    create_transcript,
+    create_transcript_metadata,
+    store_transcript,
+)
 from utils.quota import consume_tokens
 from utils.suid import normalize_conversation_id
 from utils.token_counter import TokenCounter
@@ -47,11 +50,10 @@
 
 
 def store_conversation_into_cache(
-    config: AppConfig,
     user_id: str,
     conversation_id: str,
     cache_entry: CacheEntry,
-    _skip_userid_check: bool,
+    skip_userid_check: bool,
     topic_summary: Optional[str],
 ) -> None:
     """
@@ -62,122 +64,48 @@
     anything.
 
     Parameters:
-        config (AppConfig): Application configuration that may contain
-            conversation cache settings and instance.
         user_id (str): Owner identifier used as the cache key.
         conversation_id (str): Conversation identifier used as the cache key.
        cache_entry (CacheEntry): Entry to insert or append to the conversation history.
-        _skip_userid_check (bool): When true, bypasses enforcing that the cache
+        skip_userid_check (bool): When true, bypasses enforcing that the cache
             operation must match the user id.
         topic_summary (Optional[str]): Optional topic summary to store alongside
             the conversation; ignored if None or empty.
     """
-    if config.conversation_cache_configuration.type is not None:
-        cache = config.conversation_cache
-        if cache is None:
-            logger.warning("Conversation cache configured but not initialized")
-            return
-        cache.insert_or_append(
-            user_id, conversation_id, cache_entry, _skip_userid_check
-        )
-        if topic_summary:
-            cache.set_topic_summary(
-                user_id, conversation_id, topic_summary, _skip_userid_check
-            )
-
-
-def select_model_and_provider_id(
-    models: ModelListResponse, model_id: Optional[str], provider_id: Optional[str]
-) -> tuple[str, str, str]:
-    """
-    Select the model ID and provider ID based on the request or available models.
-
-    Determine and return the appropriate model and provider IDs for
-    a query request.
-
-    If the request specifies both model and provider IDs, those are used.
-    Otherwise, defaults from configuration are applied. If neither is
-    available, selects the first available LLM model from the provided model
-    list. Validates that the selected model exists among the available models.
-
-    Returns:
-        A tuple containing the combined model ID (in the format
-        "provider/model"), and its separated parts: the model label and the provider ID.
-
-    Raises:
-        HTTPException: If no suitable LLM model is found or the selected model is not available.
-    """
-    # If model_id and provider_id are provided in the request, use them
-
-    # If model_id is not provided in the request, check the configuration
-    if not model_id or not provider_id:
-        logger.debug(
-            "No model ID or provider ID specified in request, checking configuration"
-        )
-        model_id = configuration.inference.default_model  # type: ignore[reportAttributeAccessIssue]
-        provider_id = (
-            configuration.inference.default_provider  # type: ignore[reportAttributeAccessIssue]
+    if configuration.conversation_cache_configuration.type is None:
+        logger.warning("Conversation cache is not configured")
+        return
+
+    cache = configuration.conversation_cache
+    if cache is None:
+        logger.warning("Conversation cache configured but not initialized")
+        return
+
+    cache.insert_or_append(user_id, conversation_id, cache_entry, skip_userid_check)
+    if topic_summary:
+        cache.set_topic_summary(
+            user_id, conversation_id, topic_summary, skip_userid_check
         )
-
-    # If no model is specified in the request or configuration, use the first available LLM
-    if not model_id or not provider_id:
-        logger.debug(
-            "No model ID or provider ID specified in request or configuration, "
-            "using the first available LLM"
-        )
-        try:
-            model = next(
-                m
-                for m in models
-                if m.custom_metadata and m.custom_metadata.get("model_type") == "llm"
-            )
-            model_id = model.id
-            # Extract provider_id from custom_metadata
-            provider_id = (
-                str(model.custom_metadata.get("provider_id", ""))
-                if model.custom_metadata
-                else ""
-            )
-            logger.info("Selected model: %s", model)
-            model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
-            return model_id, model_label, provider_id
-        except (StopIteration, AttributeError) as e:
-            message = "No LLM model found in available models"
-            logger.error(message)
-            response = NotFoundResponse(resource="model", resource_id=model_id or "")
-            raise HTTPException(**response.model_dump()) from e
-
-    llama_stack_model_id = f"{provider_id}/{model_id}"
-    # Validate that the model_id and provider_id are in the available models
-    logger.debug("Searching for model: %s, provider: %s", model_id, provider_id)
-    # TODO: Create separate validation of provider
-    if not any(
-        m.id in (llama_stack_model_id, model_id)
-        and (
-            m.custom_metadata
-            and str(m.custom_metadata.get("provider_id", "")) == provider_id
-        )
-        for m in models
-    ):
-        message = f"Model {model_id} from provider {provider_id} not found in available models"
-        logger.error(message)
-        response = NotFoundResponse(resource="model", resource_id=model_id)
-        raise HTTPException(**response.model_dump())
-    return llama_stack_model_id, model_id, provider_id
-
 
 def validate_model_provider_override(
-    query_request: QueryRequest, authorized_actions: set[Action] | frozenset[Action]
+    model: str | None,
+    provider: str | None,
+    authorized_actions: set[Action] | frozenset[Action],
 ) -> None:
     """Validate whether model/provider overrides are allowed by RBAC.
 
+    Args:
+        model: Model identifier. In Responses API format, may be "provider/model".
+        provider: Provider identifier (supplied separately only by the query endpoints).
+        authorized_actions: Set of authorized actions for the caller.
+
     Raises:
-        HTTPException: HTTP 403 if the request includes model or provider and
+        HTTPException: HTTP 403 if the request includes a model/provider override and
             the caller lacks Action.MODEL_OVERRIDE permission.
     """
-    if (query_request.model is not None or query_request.provider is not None) and (
-        Action.MODEL_OVERRIDE not in authorized_actions
-    ):
+    has_override = provider is not None or (model is not None and "/" in model)
+    if has_override and Action.MODEL_OVERRIDE not in authorized_actions:
         response = ForbiddenResponse.model_override()
         raise HTTPException(**response.model_dump())
@@ -224,48 +152,6 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
-def evaluate_model_hints(
-    user_conversation: Optional[UserConversation],
-    query_request: QueryRequest,
-) -> tuple[Optional[str], Optional[str]]:
-    """Evaluate model hints from user conversation."""
-    model_id: Optional[str] = query_request.model
-    provider_id: Optional[str] = query_request.provider
-
-    if user_conversation is not None:
-        if query_request.model is not None:
-            if query_request.model != user_conversation.last_used_model:
-                logger.debug(
-                    "Model specified in request: %s, preferring it over user conversation model %s",
-                    query_request.model,
-                    user_conversation.last_used_model,
-                )
-        else:
-            logger.debug(
-                "No model specified in request, using latest model from user conversation: %s",
-                user_conversation.last_used_model,
-            )
-            model_id = user_conversation.last_used_model
-
-        if query_request.provider is not None:
-            if query_request.provider != user_conversation.last_used_provider:
-                logger.debug(
-                    "Provider specified in request: %s, "
-                    "preferring it over user conversation provider %s",
-                    query_request.provider,
-                    user_conversation.last_used_provider,
-                )
-        else:
-            logger.debug(
-                "No provider specified in request, "
-                "using latest provider from user conversation: %s",
-                user_conversation.last_used_provider,
-            )
-            provider_id = user_conversation.last_used_provider
-
-    return model_id, provider_id
-
-
 async def update_azure_token(
     client: AsyncLlamaStackClient,
 ) -> AsyncLlamaStackClient:
@@ -336,7 +222,6 @@ def store_query_results(  # pylint: disable=too-many-arguments,too-many-locals
     completed_at: str,
     summary: TurnSummary,
     query_request: QueryRequest,
-    configuration: AppConfig,
     skip_userid_check: bool,
     topic_summary: Optional[str],
 ) -> None:
@@ -356,7 +241,6 @@
         completed_at: ISO formatted timestamp when the request completed
         summary: Summary of the turn including LLM response and tool calls
         query_request: The original query request
-        configuration: Application configuration
         skip_userid_check: Whether to skip user ID validation
         topic_summary: Optional topic summary for the conversation
 
@@ -366,27 +250,22 @@
     provider_id, model_id = extract_provider_and_model_from_model_id(model)
     # Store transcript if enabled
     if is_transcripts_enabled():
-        try:
-            # Convert RAG chunks to dictionary format once for reuse
-            logger.info("Storing transcript")
-            rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]
-            store_transcript(
-                user_id=user_id,
-                conversation_id=conversation_id,
-                model_id=model_id,
-                provider_id=provider_id,
-                query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
-                query=query_request.query,
-                query_request=query_request,
-                summary=summary,
-                rag_chunks=rag_chunks_dict,
-                truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
-                attachments=query_request.attachments or [],
-            )
-        except (IOError, OSError) as e:
-            logger.exception("Error storing transcript: %s", e)
-            response = InternalServerErrorResponse.generic()
-            raise HTTPException(**response.model_dump()) from e
+        logger.info("Storing transcript")
+        metadata = create_transcript_metadata(
+            user_id=user_id,
+            conversation_id=conversation_id,
+            model_id=model_id,
+            provider_id=provider_id,
+            query_provider=query_request.provider,
+            query_model=query_request.model,
+        )
+        transcript = create_transcript(
+            metadata=metadata,
+            redacted_query=query_request.query,
+            summary=summary,
+            attachments=query_request.attachments or [],
+        )
+        store_transcript(transcript)
     else:
         logger.debug("Transcript collection is disabled in the configuration")
 
@@ -409,26 +288,24 @@
         raise HTTPException(**response.model_dump()) from e
 
     # Store conversation in cache
+    cache_entry = CacheEntry(
+        query=query_request.query,
+        response=summary.llm_response,
+        provider=provider_id,
+        model=model_id,
+        started_at=started_at,
+        completed_at=completed_at,
+        referenced_documents=summary.referenced_documents,
+        tool_calls=summary.tool_calls,
+        tool_results=summary.tool_results,
+    )
     try:
-        cache_entry = CacheEntry(
-            query=query_request.query,
-            response=summary.llm_response,
-            provider=provider_id,
-            model=model_id,
-            started_at=started_at,
-            completed_at=completed_at,
-            referenced_documents=summary.referenced_documents,
-            tool_calls=summary.tool_calls,
-            tool_results=summary.tool_results,
-        )
-
         logger.info("Storing conversation in cache")
         store_conversation_into_cache(
-            config=configuration,
             user_id=user_id,
             conversation_id=conversation_id,
             cache_entry=cache_entry,
-            _skip_userid_check=skip_userid_check,
+            skip_userid_check=skip_userid_check,
             topic_summary=topic_summary,
         )
     except (CacheError, ValueError, psycopg2.Error, sqlite3.Error) as e:
@@ -441,7 +318,6 @@ def consume_query_tokens(
     user_id: str,
     model_id: str,
     token_usage: TokenCounter,
-    configuration: AppConfig,
 ) -> None:
     """Consume tokens from quota limiters for a query.
@@ -453,7 +329,6 @@
         user_id: The authenticated user ID
         model_id: The full model identifier in "provider/model" format
         token_usage: TokenCounter object with input and output token counts
-        configuration: Application configuration
 
     Raises:
         HTTPException: On database errors during token consumption
diff --git a/src/utils/responses.py b/src/utils/responses.py
index 5b28d47bc..987629b2b 100644
--- a/src/utils/responses.py
+++ b/src/utils/responses.py
@@ -1,12 +1,21 @@
 """Utility functions for processing Responses API output."""
 
+# pylint: disable=too-many-lines
+
 import json
-from typing import Any, Optional, cast
+from collections.abc import Sequence
+from typing import Any, cast
 
 from fastapi import HTTPException
 from llama_stack_api.openai_responses import (
-    OpenAIResponseObject,
-    OpenAIResponseOutput,
+    OpenAIResponseContentPartRefusal as ContentPartRefusal,
+    OpenAIResponseInputMessageContent as InputMessageContent,
+    OpenAIResponseInputMessageContentText as InputTextPart,
+    OpenAIResponseMessage as ResponseMessage,
+    OpenAIResponseObject as ResponseObject,
+    OpenAIResponseOutput as ResponseOutput,
+    OpenAIResponseOutputMessageContent as OutputMessageContent,
+    OpenAIResponseOutputMessageContentOutputText as OutputTextPart,
     OpenAIResponseOutputMessageFileSearchToolCall as FileSearchCall,
     OpenAIResponseOutputMessageFunctionToolCall as FunctionCall,
     OpenAIResponseOutputMessageMCPCall as MCPCall,
@@ -14,28 +23,27 @@
     OpenAIResponseOutputMessageWebSearchToolCall as WebSearchCall,
     OpenAIResponseMCPApprovalRequest as MCPApprovalRequest,
     OpenAIResponseMCPApprovalResponse as MCPApprovalResponse,
+    OpenAIResponseUsage as ResponseUsage,
 )
 from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
 
 import constants
 import metrics
-from configuration import AppConfig, configuration
+from configuration import configuration
 from constants import DEFAULT_RAG_TOOL
-from models.config import ModelContextProtocolServer
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
 from models.responses import (
     InternalServerErrorResponse,
+    NotFoundResponse,
     ServiceUnavailableResponse,
 )
 from utils.mcp_oauth_probe import probe_mcp_oauth_and_raise_401
 from utils.prompts import get_system_prompt, get_topic_summary_system_prompt
 from utils.query import (
-    evaluate_model_hints,
     extract_provider_and_model_from_model_id,
     handle_known_apistatus_errors,
     prepare_input,
-    select_model_and_provider_id,
 )
 from utils.mcp_headers import McpHeaders
 from utils.suid import to_llama_stack_conversation_id
@@ -46,53 +54,14 @@
     ResponsesApiParams,
     ToolCallSummary,
     ToolResultSummary,
+    TurnSummary,
 )
 
 from log import get_logger
 
 logger = get_logger(__name__)
 
 
-def extract_text_from_response_output_item(output_item: Any) -> str:
-    """Extract assistant message text from a Responses API output item.
-
-    Args:
-        output_item: A Responses API output item from response.output array.
-
-    Returns:
-        Extracted text content, or empty string if not an assistant message.
-    """
-    if getattr(output_item, "type", None) != "message":
-        return ""
-    if getattr(output_item, "role", None) != "assistant":
-        return ""
-
-    content = getattr(output_item, "content", None)
-    if isinstance(content, str):
-        return content
-
-    text_fragments: list[str] = []
-    if isinstance(content, list):
-        for part in content:
-            if isinstance(part, str):
-                text_fragments.append(part)
-                continue
-            text_value = getattr(part, "text", None)
-            if text_value:
-                text_fragments.append(text_value)
-                continue
-            refusal = getattr(part, "refusal", None)
-            if refusal:
-                text_fragments.append(refusal)
-                continue
-            if isinstance(part, dict):
-                dict_text = part.get("text") or part.get("refusal")
-                if dict_text:
-                    text_fragments.append(str(dict_text))
-
-    return "".join(text_fragments)
-
-
-async def get_topic_summary(  # pylint: disable=too-many-nested-blocks
+async def get_topic_summary(
     question: str, client: AsyncLlamaStackClient, model_id: str
 ) -> str:
     """Get a topic summary for a question using Responses API.
@@ -105,16 +74,13 @@
     Returns:
         The topic summary for the question
     """
-    topic_summary_system_prompt = get_topic_summary_system_prompt(configuration)
-
-    # Use Responses API to generate topic summary
     try:
         response = cast(
-            OpenAIResponseObject,
+            ResponseObject,
             await client.responses.create(
                 input=question,
                 model=model_id,
-                instructions=topic_summary_system_prompt,
+                instructions=get_topic_summary_system_prompt(),
                 stream=False,
                 store=False,  # Don't store topic summary requests
             ),
@@ -129,42 +95,35 @@
         error_response = handle_known_apistatus_errors(e, model_id)
         raise HTTPException(**error_response.model_dump()) from e
 
-    # Extract text from response output
-    summary_text = "".join(
-        extract_text_from_response_output_item(output_item)
-        for output_item in response.output
-    )
-
-    return summary_text.strip() if summary_text else ""
+    return extract_text_from_output_items(response.output)
 
 
 async def prepare_tools(
     client: AsyncLlamaStackClient,
-    query_request: QueryRequest,
+    vector_store_ids: list[str] | None,
+    no_tools: bool | None,
     token: str,
-    config: AppConfig,
-    mcp_headers: Optional[McpHeaders] = None,
-) -> Optional[list[dict[str, Any]]]:
+    mcp_headers: McpHeaders | None = None,
+) -> list[dict[str, Any]] | None:
     """Prepare tools for Responses API including RAG and MCP tools.
 
     Args:
         client: The Llama Stack client instance
-        query_request: The user's query request
+        vector_store_ids: The list of vector store IDs to use for RAG tools
+            or None if all vector stores should be used
+        no_tools: Whether to skip tool preparation
         token: Authentication token for MCP tools
-        config: Configuration object containing MCP server settings
         mcp_headers: Per-request headers for MCP servers
 
     Returns:
-        List of tool configurations, or None if no_tools is True or no tools available
+        List of tool configurations, or None if no tools available
     """
-    if query_request.no_tools:
+    if no_tools:
         return None
 
     toolgroups = []
 
-    # Get vector stores for RAG tools - use specified ones or fetch all
-    if query_request.vector_store_ids:
-        vector_store_ids = query_request.vector_store_ids
-    else:
+    # Get all vector stores if vector stores are not restricted by request
+    if vector_store_ids is None:
         try:
             vector_stores = await client.vector_stores.list()
             vector_store_ids = [vector_store.id for vector_store in vector_stores.data]
@@ -184,7 +143,7 @@
         toolgroups.extend(rag_tools)
 
     # Add MCP server tools
-    mcp_tools = await get_mcp_tools(config.mcp_servers, token, mcp_headers)
+    mcp_tools = await get_mcp_tools(token, mcp_headers)
     if mcp_tools:
         toolgroups.extend(mcp_tools)
         logger.debug(
@@ -202,9 +161,9 @@
 async def prepare_responses_params(  # pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
     client: AsyncLlamaStackClient,
     query_request: QueryRequest,
-    user_conversation: Optional[UserConversation],
+    user_conversation: UserConversation | None,
     token: str,
-    mcp_headers: Optional[McpHeaders] = None,
+    mcp_headers: McpHeaders | None = None,
     stream: bool = False,
     store: bool = True,
 ) -> ResponsesApiParams:
@@ -222,40 +181,33 @@
     Returns:
         ResponsesApiParams containing all prepared parameters for the API request
     """
-    # Select model and provider
-    try:
-        models = await client.models.list()
-    except APIConnectionError as e:
-        error_response = ServiceUnavailableResponse(
-            backend_name="Llama Stack",
-            cause=str(e),
-        )
-        raise HTTPException(**error_response.model_dump()) from e
-    except APIStatusError as e:
-        error_response = InternalServerErrorResponse.generic()
-        raise HTTPException(**error_response.model_dump()) from e
+    if query_request.model and query_request.provider:
+        model = f"{query_request.provider}/{query_request.model}"
+    else:
+        model = await select_model_for_responses(client, user_conversation)
 
-    llama_stack_model_id, _model_id, _provider_id = select_model_and_provider_id(
-        models,
-        *evaluate_model_hints(
-            user_conversation=user_conversation, query_request=query_request
-        ),
-    )
+    if not await check_model_configured(client, model):
+        _, model_id = extract_provider_and_model_from_model_id(model)
+        error_response = NotFoundResponse(resource="model", resource_id=model_id)
+        raise HTTPException(**error_response.model_dump())
 
     # Use system prompt from request or default one
-    system_prompt = get_system_prompt(query_request, configuration)
+    system_prompt = get_system_prompt(query_request.system_prompt)
     logger.debug("Using system prompt: %s", system_prompt)
 
     # Prepare tools for responses API
     tools = await prepare_tools(
-        client, query_request, token, configuration, mcp_headers
+        client,
+        query_request.vector_store_ids,
+        query_request.no_tools,
+        token,
+        mcp_headers,
     )
 
     # Prepare input for Responses API
     input_text = prepare_input(query_request)
 
     # Handle conversation ID for Responses API
-    # Create conversation upfront if not provided
     conversation_id = query_request.conversation_id
     if conversation_id:
         # Conversation ID was provided - convert to llama-stack format
@@ -284,7 +236,7 @@
 
     return ResponsesApiParams(
         input=input_text,
-        model=llama_stack_model_id,
+        model=model,
         instructions=system_prompt,
         tools=tools,
         conversation=llama_stack_conv_id,
@@ -294,7 +246,7 @@
 
 
 def extract_vector_store_ids_from_tools(
-    tools: Optional[list[dict[str, Any]]],
+    tools: list[dict[str, Any]] | None,
 ) -> list[str]:
     """Extract vector store IDs from prepared tool configurations.
 
@@ -312,7 +264,7 @@
     return []
 
 
-def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]:
+def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:
     """Convert vector store IDs to tools format for Responses API.
 
     Args:
@@ -334,14 +286,12 @@
 
 
 async def get_mcp_tools(  # pylint: disable=too-many-return-statements,too-many-locals
-    mcp_servers: list[ModelContextProtocolServer],
     token: str | None = None,
-    mcp_headers: Optional[McpHeaders] = None,
+    mcp_headers: McpHeaders | None = None,
 ) -> list[dict[str, Any]]:
     """Convert MCP servers to tools format for Responses API.
 
     Args:
-        mcp_servers: List of MCP server configurations
         token: Optional authentication token for MCP server authorization
         mcp_headers: Optional per-request headers for MCP servers, keyed by server URL
 
@@ -382,7 +332,7 @@
         return original
 
     tools = []
-    for mcp_server in mcp_servers:
+    for mcp_server in configuration.mcp_servers:
         # Base tool definition
         tool_def = {
             "type": "mcp",
@@ -431,9 +381,9 @@
 
 
 def parse_referenced_documents(  # pylint: disable=too-many-locals
-    response: Optional[OpenAIResponseObject],
-    vector_store_ids: Optional[list[str]] = None,
-    rag_id_mapping: Optional[dict[str, str]] = None,
+    response: ResponseObject | None,
+    vector_store_ids: list[str] | None = None,
+    rag_id_mapping: dict[str, str] | None = None,
 ) -> list[ReferencedDocument]:
     """Parse referenced documents from Responses API response.
 
@@ -447,7 +397,7 @@
     """
     documents: list[ReferencedDocument] = []
     # Use a set to track unique documents by (doc_url, doc_title) tuple
-    seen_docs: set[tuple[Optional[str], Optional[str]]] = set()
+    seen_docs: set[tuple[str | None, str | None]] = set()
 
     # Handle None response (e.g., when agent fails)
     if response is None or not response.output:
@@ -481,7 +431,7 @@
                 doc_title = attributes.get("title")
 
                 if doc_title or doc_url:
-                    # Treat empty string as None for URL to satisfy Optional[AnyUrl]
+                    # Treat empty string as None for URL to satisfy AnyUrl | None
                     final_url = doc_url if doc_url else None
                     if (final_url, doc_title) not in seen_docs:
                         documents.append(
@@ -496,97 +446,58 @@
     return documents
 
 
-def extract_token_usage(
-    response: Optional[OpenAIResponseObject], model_id: str
-) -> TokenCounter:
-    """Extract token usage from Responses API response and update metrics.
+def extract_token_usage(usage: ResponseUsage | None, model: str) -> TokenCounter:
+    """Extract token usage from Responses API usage object and update metrics.
 
     Args:
-        response: The OpenAI Response API response object
-        model_id: The model identifier for metrics labeling
+        usage: ResponseUsage from the Responses API response, or None if not available.
+        model: The model identifier in "provider/model" format
 
     Returns:
         TokenCounter with input_tokens and output_tokens
     """
-    token_counter = TokenCounter()
-    token_counter.llm_calls = 1
-    provider, model = extract_provider_and_model_from_model_id(model_id)
-
-    # Extract usage from the response if available
-    # Note: usage attribute exists at runtime but may not be in type definitions
-    usage = getattr(response, "usage", None) if response else None
-    if usage:
-        try:
-            # Handle both dict and object cases due to llama_stack inconsistency:
-            # - When llama_stack converts to chat_completions internally, usage is a dict
-            # - When using proper Responses API, usage should be an object
-            # TODO: Remove dict handling once llama_stack standardizes on object type # pylint: disable=fixme
-            if isinstance(usage, dict):
-                input_tokens = usage.get("input_tokens", 0)
-                output_tokens = usage.get("output_tokens", 0)
-            else:
-                # Object with attributes (expected final behavior)
-                input_tokens = getattr(usage, "input_tokens", 0)
-                output_tokens = getattr(usage, "output_tokens", 0)
-            # Only set if we got valid values
-            if input_tokens or output_tokens:
-                token_counter.input_tokens = input_tokens or 0
-                token_counter.output_tokens = output_tokens or 0
-
-                logger.debug(
-                    "Extracted token usage from Responses API: input=%d, output=%d",
-                    token_counter.input_tokens,
-                    token_counter.output_tokens,
-                )
-
-                # Update Prometheus metrics only when we have actual usage data
-                try:
-                    metrics.llm_token_sent_total.labels(provider, model).inc(
-                        token_counter.input_tokens
-                    )
-                    metrics.llm_token_received_total.labels(provider, model).inc(
-                        token_counter.output_tokens
-                    )
-                except (AttributeError, TypeError, ValueError) as e:
-                    logger.warning("Failed to update token metrics: %s", e)
-                _increment_llm_call_metric(provider, model)
-            else:
-                logger.debug(
-                    "Usage object exists but tokens are 0 or None, treating as no usage info"
-                )
-                # Still increment the call counter
-                _increment_llm_call_metric(provider, model)
-        except (AttributeError, KeyError, TypeError) as e:
-            logger.warning(
-                "Failed to extract token usage from response.usage: %s. Usage value: %s",
-                e,
-                usage,
-            )
-            # Still increment the call counter
-            _increment_llm_call_metric(provider, model)
-    else:
-        # No usage information available - this is expected when llama stack
-        # internally converts to chat_completions
+    provider_id, model_id = extract_provider_and_model_from_model_id(model)
+    if usage is None:
         logger.debug(
             "No usage information in Responses API response, token counts will be 0"
         )
-        # token_counter already initialized with 0 values
-        # Still increment the call counter
-        _increment_llm_call_metric(provider, model)
+        _increment_llm_call_metric(provider_id, model_id)
+        return TokenCounter(llm_calls=1)
+
+    token_counter = TokenCounter(
+        input_tokens=usage.input_tokens, output_tokens=usage.output_tokens, llm_calls=1
+    )
+    logger.debug(
+        "Extracted token usage from Responses API: input=%d, output=%d",
+        token_counter.input_tokens,
+        token_counter.output_tokens,
+    )
+
+    # Update Prometheus metrics only when we have actual usage data
+    try:
+        metrics.llm_token_sent_total.labels(provider_id, model_id).inc(
+            token_counter.input_tokens
+        )
+        metrics.llm_token_received_total.labels(provider_id, model_id).inc(
+            token_counter.output_tokens
+        )
+    except (AttributeError, TypeError, ValueError) as e:
+        logger.warning("Failed to update token metrics: %s", e)
+        _increment_llm_call_metric(provider_id, model_id)
 
     return token_counter
 
 
 def build_tool_call_summary(  # pylint: disable=too-many-return-statements,too-many-branches,too-many-locals
-    output_item: OpenAIResponseOutput,
+    output_item: ResponseOutput,
     rag_chunks: list[RAGChunk],
-    vector_store_ids: Optional[list[str]] = None,
-    rag_id_mapping: Optional[dict[str, str]] = None,
-) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]:
+    vector_store_ids: list[str] | None = None,
+    rag_id_mapping: dict[str, str] | None = None,
+) -> tuple[ToolCallSummary | None, ToolResultSummary | None]:
     """Translate Responses API tool outputs into ToolCallSummary and ToolResultSummary.
 
     Args:
-        output_item: An OpenAIResponseOutput item from the response.output array
+        output_item: A ResponseOutput item from the response.output array
         rag_chunks: List to append extracted RAG chunks to (from file_search_call items)
         vector_store_ids: Vector store IDs used in the query for source resolution.
         rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
@@ -613,7 +524,7 @@
             extract_rag_chunks_from_file_search_item(
                 file_search_item, rag_chunks, vector_store_ids, rag_id_mapping
             )
-            response_payload: Optional[dict[str, Any]] = None
+            response_payload: dict[str, Any] | None = None
             if file_search_item.results is not None:
                 response_payload = {
                     "results": [result.model_dump() for result in file_search_item.results]
@@ -740,7 +651,7 @@
     output_index: int,
     arguments: str,
     mcp_call_items: dict[int, tuple[str, str]],
-) -> Optional[ToolCallSummary]:
+) -> ToolCallSummary | None:
     """Build ToolCallSummary from MCP call arguments completion event.
 
     Args:
@@ -796,7 +707,7 @@
     result: Any,
     vector_store_ids: list[str],
     rag_id_mapping: dict[str, str],
-) -> Optional[str]:
+) -> str | None:
     """Resolve the human-friendly index name for a file search result.
 
     Uses the vector store mapping to convert internal llama-stack IDs
@@ -816,14 +727,14 @@
 
     if len(vector_store_ids) > 1:
         attributes = getattr(result, "attributes", {}) or {}
-        attr_store_id: Optional[str] = attributes.get("vector_store_id")
+        attr_store_id: str | None = attributes.get("vector_store_id")
         if attr_store_id:
             return rag_id_mapping.get(attr_store_id, attr_store_id)
 
     return None
 
 
-def _build_chunk_attributes(result: Any) -> Optional[dict[str, Any]]:
+def _build_chunk_attributes(result: Any) -> dict[str, Any] | None:
     """Extract document metadata attributes from a file search result.
 
     Parameters:
@@ -843,8 +754,8 @@
 def extract_rag_chunks_from_file_search_item(
     item: FileSearchCall,
     rag_chunks: list[RAGChunk],
-    vector_store_ids: Optional[list[str]] = None,
-    rag_id_mapping: Optional[dict[str, str]] = None,
+    vector_store_ids: list[str] | None = None,
+    rag_id_mapping: dict[str, str] | None = None,
 ) -> None:
     """Extract RAG chunks from a file search tool call item.
 
@@ -908,3 +819,237 @@
     # Fallback: return wrapped in arguments key
     return {"args": arguments_str}
+
+
+async def check_model_configured(
+    client: AsyncLlamaStackClient,
+    model_id: str,
+) -> bool:
+    """Validate that a model is configured and available.
+
+    Args:
+        client: The AsyncLlamaStackClient instance
+        model_id: The model identifier in "provider/model" format
+
+    Returns:
+        True if the model is available, False if not found (404)
+
+    Raises:
+        HTTPException: If there's a connection error or other API error
+    """
+    try:
+        models = await client.models.list()
+        for model in models:
+            if model.id == model_id:
+                return True
+        return False
+    except APIStatusError as e:
+        response = InternalServerErrorResponse.generic()
+        raise HTTPException(**response.model_dump()) from e
+    except APIConnectionError as e:
+        error_response = ServiceUnavailableResponse(
+            backend_name="Llama Stack",
+            cause=str(e),
+        )
+        raise HTTPException(**error_response.model_dump()) from e
+
+
+async def select_model_for_responses(
+    client: AsyncLlamaStackClient,
+    user_conversation: UserConversation | None,
+) -> str:
+    """Select model for Responses API if not explicitly specified in the request.
+
+    Model selection precedence:
+    1. If conversation is provided and has last_used_model, use it
+    2. If default model is configured, use it
+    3. Otherwise, fetch available models and select the first LLM model (model_type="llm")
+    4. Raise HTTPException if no LLM model is found
+
+    Args:
+        client: The AsyncLlamaStackClient instance
+        user_conversation: The user conversation if conversation_id was provided, None otherwise
+
+    Returns:
+        The llama_stack_model_id in "provider/model" format
+
+    Raises:
+        HTTPException: If models cannot be fetched, or if no LLM model is found
+    """
+    # 1. Conversation has existing last_used_model
+    if (
+        user_conversation is not None
+        and user_conversation.last_used_model
+        and user_conversation.last_used_provider
+    ):
+        model_id = f"{user_conversation.last_used_provider}/{user_conversation.last_used_model}"
+        return model_id
+
+    # 2. Select default model from configuration
+    if configuration.inference is not None:
+        default_model = configuration.inference.default_model
+        default_provider = configuration.inference.default_provider
+        if default_model and default_provider:
+            return f"{default_provider}/{default_model}"
+
+    # 3. Fetch models list and select the first LLM model (model_type="llm")
+    try:
+        models = await client.models.list()
+    except APIConnectionError as e:
+        error_response = ServiceUnavailableResponse(
+            backend_name="Llama Stack",
+            cause=str(e),
+        )
+        raise HTTPException(**error_response.model_dump()) from e
+    except APIStatusError as e:
+        error_response = InternalServerErrorResponse.generic()
+        raise HTTPException(**error_response.model_dump()) from e
+
+    llm_models = [
+        m
+        for m in models
+        if m.custom_metadata and m.custom_metadata.get("model_type") == "llm"
+    ]
+    if not llm_models:
+        logger.error("No LLM model found in available models")
+        response = NotFoundResponse(resource="model", resource_id=None)
+        raise HTTPException(**response.model_dump())
+
+    model = llm_models[0]
+    logger.info("Selected first LLM model: %s", model.id)
+    return model.id
+
+
+def build_turn_summary(
+    response: ResponseObject | None,
+    model: str,
+    vector_store_ids: list[str] | None = None,
+    rag_id_mapping: dict[str, str] | None = None,
+) -> TurnSummary:
+    """Build a TurnSummary from a ResponseObject.
+
+    Args:
+        response: The ResponseObject to build the turn summary from, or None
+        model: The model identifier in "provider/model" format
+        vector_store_ids: Vector store IDs used in the query for source resolution.
+        rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
+    Returns:
+        TurnSummary with extracted response text, referenced_documents, rag_chunks,
+        tool_calls, and tool_results. All fields are empty/default if response is None
+        or has no output.
+    """
+    summary = TurnSummary()
+
+    if response is None or response.output is None:
+        return summary
+
+    # Extract text from output items
+    summary.llm_response = extract_text_from_output_items(response.output)
+
+    # Extract referenced documents and tool calls/results
+    summary.referenced_documents = parse_referenced_documents(
+        response, vector_store_ids, rag_id_mapping
+    )
+
+    for item in response.output:
+        tool_call, tool_result = build_tool_call_summary(
+            item, summary.rag_chunks, vector_store_ids, rag_id_mapping
+        )
+        if tool_call:
+            summary.tool_calls.append(tool_call)
+        if tool_result:
+            summary.tool_results.append(tool_result)
+
+    summary.token_usage = extract_token_usage(response.usage, model)
+    return summary
+
+
+def extract_text_from_output_items(
+    output_items: Sequence[ResponseOutput] | None,
+) -> str:
+    """Extract and join text from all response output items.
+
+    Args:
+        output_items: Sequence of output items from response.output, or None.
+
+    Returns:
+        Extracted text content concatenated from all items, or empty string if None.
+    """
+    if output_items is None:
+        return ""
+
+    text_fragments: list[str] = []
+    for item in output_items:
+        text = extract_text_from_output_item(item)
+        if text:
+            text_fragments.append(text)
+
+    return " ".join(text_fragments)
+
+
+def extract_text_from_output_item(output_item: ResponseOutput) -> str:
+    """Extract text from a single output item.
+
+    Args:
+        output_item: A single output item from response.output.
+
+    Returns:
+        Extracted text content, or an empty string if the item is not a message
+        or its role is "user".
+    """
+    if output_item.type != "message":
+        return ""
+
+    message_item = cast(ResponseMessage, output_item)
+    if message_item.role == "user":
+        return ""
+
+    return _extract_text_from_content(message_item.content)
+
+
+def _extract_text_from_content(
+    content: str | Sequence[InputMessageContent] | Sequence[OutputMessageContent],
+) -> str:
+    """Extract text from message content.
+ + Args: + content: Content from ResponseMessage.content which can be + str or sequence of content parts (input or output). + + Returns: + Extracted text content. Only extracts text from input_text, output_text, + or refusal types. Other content types (images, files, etc.) are ignored. + """ + if isinstance(content, str): + return content + + text_fragments: list[str] = [] + for part in content: + if part.type == "input_text": + input_text_part = cast(InputTextPart, part) + if input_text_part.text: + text_fragments.append(input_text_part.text.strip()) + elif part.type == "output_text": + output_text_part = cast(OutputTextPart, part) + if output_text_part.text: + text_fragments.append(output_text_part.text.strip()) + elif part.type == "refusal": + refusal_part = cast(ContentPartRefusal, part) + if refusal_part.refusal: + text_fragments.append(refusal_part.refusal.strip()) + + return " ".join(text_fragments) + + +def deduplicate_referenced_documents( + docs: list[ReferencedDocument], +) -> list[ReferencedDocument]: + """Remove duplicate referenced documents based on URL and title.""" + seen: set[tuple[str | None, str | None]] = set() + out: list[ReferencedDocument] = [] + for d in docs: + key = (str(d.doc_url) if d.doc_url else None, d.doc_title) + if key in seen: + continue + seen.add(key) + out.append(d) + return out diff --git a/src/utils/transcripts.py b/src/utils/transcripts.py index aba969512..8e64b3962 100644 --- a/src/utils/transcripts.py +++ b/src/utils/transcripts.py @@ -9,12 +9,19 @@ import os from pathlib import Path import hashlib -from typing import Optional + +from fastapi import HTTPException from configuration import configuration -from models.requests import Attachment, QueryRequest + +from models.requests import Attachment +from models.responses import InternalServerErrorResponse from utils.suid import get_suid -from utils.types import TurnSummary +from utils.types import ( + Transcript, + TranscriptMetadata, + TurnSummary, +) from log import get_logger logger = get_logger(__name__) @@ -34,17 +41,16 @@ def _hash_user_id(user_id: str) -> str: return hashlib.sha256(user_id.encode("utf-8")).hexdigest() -def construct_transcripts_path(user_id: str, conversation_id: str) -> Path: +def construct_transcripts_path(hashed_user_id: str, conversation_id: str) -> Path: """ Construct the filesystem path where transcripts for a given user and conversation are stored. The returned path is built from the configured transcripts storage base - directory, a filesystem-safe directory derived from a hash of `user_id`, + directory, a filesystem-safe directory derived from a pre-hashed `user_id`, and a filesystem-safe form of `conversation_id`. Parameters: - user_id (str): The identifier for the user; a hashed form of this value - is used as a path component. + hashed_user_id (str): The hashed identifier for the user conversation_id (str): The conversation identifier; this value is normalized for use as a path component. 
@@ -54,7 +60,6 @@ def construct_transcripts_path(user_id: str, conversation_id: str) -> Path: """ # these two normalizations are required by Snyk as it detects # this Path sanitization pattern - hashed_user_id = _hash_user_id(user_id) uid = os.path.normpath("/" + hashed_user_id).lstrip("/") cid = os.path.normpath("/" + conversation_id).lstrip("/") file_path = ( @@ -63,69 +68,94 @@ def construct_transcripts_path(user_id: str, conversation_id: str) -> Path: return Path(file_path, uid, cid) -def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals - user_id: str, - conversation_id: str, - model_id: str, - provider_id: Optional[str], - query_is_valid: bool, - query: str, - query_request: QueryRequest, - summary: TurnSummary, - rag_chunks: list[dict], - truncated: bool, - attachments: list[Attachment], +def store_transcript( + transcript: Transcript, ) -> None: """Store transcript in the local filesystem. Parameters: - user_id: The user ID (UUID). - conversation_id: The conversation ID (UUID). - model_id: Identifier of the model used to generate the LLM response. - provider_id: Optional provider identifier for the model. - query_is_valid: The result of the query validation. - query: The query (without attachments). - query_request: The request containing a query. - summary: Summary of the query/response turn. - rag_chunks: The list of serialized `RAGChunk` dictionaries. - truncated: The flag indicating if the history was truncated. - attachments: The list of `Attachment` objects. + transcript: BaseModel instance to be stored (e.g., Transcript). Raises: - IOError, OSError: If writing the transcript file to disk fails. + HTTPException: If writing the transcript file to disk fails. """ - transcripts_path = construct_transcripts_path(user_id, conversation_id) + transcripts_path = construct_transcripts_path( + transcript.metadata.user_id, transcript.metadata.conversation_id + ) transcripts_path.mkdir(parents=True, exist_ok=True) - hashed_user_id = _hash_user_id(user_id) - - data_to_store = { - "metadata": { - "provider": provider_id, - "model": model_id, - "query_provider": query_request.provider, - "query_model": query_request.model, - "user_id": hashed_user_id, - "conversation_id": conversation_id, - "timestamp": datetime.now(UTC).isoformat(), - }, - "redacted_query": query, - "query_is_valid": query_is_valid, - "llm_response": summary.llm_response, - "rag_chunks": rag_chunks, - "truncated": truncated, - "attachments": [attachment.model_dump() for attachment in attachments], - "tool_calls": [tc.model_dump() for tc in summary.tool_calls], - "tool_results": [tr.model_dump() for tr in summary.tool_results], - } - - # stores feedback in a file under unique uuid + # stores transcript in a file under unique uuid transcript_file_path = transcripts_path / f"{get_suid()}.json" try: with open(transcript_file_path, "w", encoding="utf-8") as transcript_file: - json.dump(data_to_store, transcript_file) + json.dump(transcript.model_dump(), transcript_file) + logger.info("Transcript successfully stored at: %s", transcript_file_path) except (IOError, OSError) as e: logger.error("Failed to store transcript into %s: %s", transcript_file_path, e) - raise + response = InternalServerErrorResponse.generic() + raise HTTPException(**response.model_dump()) from e + - logger.info("Transcript successfully stored at: %s", transcript_file_path) +def create_transcript_metadata( # pylint: disable=too-many-arguments,too-many-positional-arguments + user_id: str, + conversation_id: 
+def create_transcript_metadata(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    user_id: str,
+    conversation_id: str,
+    model_id: str,
+    provider_id: str | None,
+    query_provider: str | None,
+    query_model: str | None,
+) -> TranscriptMetadata:
+    """Create a TranscriptMetadata BaseModel instance.
+
+    Parameters:
+        user_id: The user ID (UUID).
+        conversation_id: The conversation ID (UUID).
+        model_id: Identifier of the model used to generate the LLM response.
+        provider_id: Optional provider identifier for the model.
+        query_provider: Optional provider identifier from the query request.
+        query_model: Optional model identifier from the query request.
+
+    Returns:
+        TranscriptMetadata: A TranscriptMetadata BaseModel instance.
+    """
+    hashed_user_id = _hash_user_id(user_id)
+
+    return TranscriptMetadata(
+        provider=provider_id,
+        model=model_id,
+        query_provider=query_provider,
+        query_model=query_model,
+        user_id=hashed_user_id,
+        conversation_id=conversation_id,
+        timestamp=datetime.now(UTC).isoformat(),
+    )
+
+
+def create_transcript(
+    metadata: TranscriptMetadata,
+    redacted_query: str,
+    summary: TurnSummary,
+    attachments: list[Attachment],
+) -> Transcript:
+    """Create a Transcript BaseModel instance from individual parameters.
+
+    Parameters:
+        metadata: The transcript metadata.
+        redacted_query: The query text (redacted if necessary).
+        summary: Summary of the query/response turn containing LLM response,
+            RAG chunks, tool calls, and tool results.
+        attachments: List of attachments from the query request.
+
+    Returns:
+        Transcript: A Transcript BaseModel instance ready to be stored.
+            The query is always recorded as valid and the history as
+            non-truncated.
+    """
+    return Transcript(
+        metadata=metadata,
+        redacted_query=redacted_query,
+        query_is_valid=True,
+        llm_response=summary.llm_response,
+        rag_chunks=[chunk.model_dump() for chunk in summary.rag_chunks],
+        truncated=False,
+        attachments=[attachment.model_dump() for attachment in attachments],
+        tool_calls=[tc.model_dump() for tc in summary.tool_calls],
+        tool_results=[tr.model_dump() for tr in summary.tool_results],
+    )
diff --git a/src/utils/types.py b/src/utils/types.py
index 3de88a7a6..5dd7910cc 100644
--- a/src/utils/types.py
+++ b/src/utils/types.py
@@ -197,3 +197,29 @@ class TurnSummary(BaseModel):
     referenced_documents: list[ReferencedDocument] = Field(default_factory=list)
     pre_rag_documents: list[ReferencedDocument] = Field(default_factory=list)
     token_usage: TokenCounter = Field(default_factory=TokenCounter)
+
+
+class TranscriptMetadata(BaseModel):
+    """Metadata for a transcript entry."""
+
+    provider: str | None = None
+    model: str
+    query_provider: str | None = None
+    query_model: str | None = None
+    user_id: str
+    conversation_id: str
+    timestamp: str
+
+
+class Transcript(BaseModel):
+    """Model representing a transcript entry to be stored."""
+
+    metadata: TranscriptMetadata
+    redacted_query: str
+    query_is_valid: bool
+    llm_response: str
+    rag_chunks: list[dict[str, Any]] = Field(default_factory=list)
+    truncated: bool
+    attachments: list[dict[str, Any]] = Field(default_factory=list)
+    tool_calls: list[dict[str, Any]] = Field(default_factory=list)
+    tool_results: list[dict[str, Any]] = Field(default_factory=list)
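+
+
+# Serialization note (illustrative): store_transcript persists
+# Transcript.model_dump(), so the on-disk JSON nests the metadata, e.g.:
+#
+#     {"metadata": {"provider": ..., "model": ..., "user_id": <hashed>, ...},
+#      "redacted_query": "...", "query_is_valid": true, "llm_response": "...", ...}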
diff --git a/tests/integration/endpoints/test_query_v2_integration.py b/tests/integration/endpoints/test_query_integration.py
similarity index 96%
rename from tests/integration/endpoints/test_query_v2_integration.py
rename to tests/integration/endpoints/test_query_integration.py
index c22691663..7728179ae 100644
--- a/tests/integration/endpoints/test_query_v2_integration.py
+++ b/tests/integration/endpoints/test_query_integration.py
@@ -71,6 +71,12 @@ def mock_llama_stack_client_fixture(
     # Mock tool calls (empty by default)
     mock_response.tool_calls = []
 
+    # Mock usage (required for token extraction)
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 10
+    mock_usage.output_tokens = 5
+    mock_response.usage = mock_usage
+
     mock_client.responses.create.return_value = mock_response
 
     # Mock models list (required for model selection)
@@ -391,6 +397,10 @@ async def test_query_v2_endpoint_with_tool_calls(
     mock_response.output = [mock_tool_output, mock_message_output]
     mock_response.stop_reason = "end_turn"
 
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 10
+    mock_usage.output_tokens = 5
+    mock_response.usage = mock_usage
 
     mock_llama_stack_client.responses.create.return_value = mock_response
 
@@ -458,7 +468,10 @@ async def test_query_v2_endpoint_with_mcp_list_tools(
     mock_response.output = [mock_mcp_list, mock_message]
     mock_response.tool_calls = []
 
-    mock_response.usage = {"input_tokens": 15, "output_tokens": 20}
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 15
+    mock_usage.output_tokens = 20
+    mock_response.usage = mock_usage
 
     mock_llama_stack_client.responses.create.return_value = mock_response
 
@@ -525,7 +538,10 @@ async def test_query_v2_endpoint_with_multiple_tool_types(
     mock_response.output = [mock_file_search, mock_function, mock_message]
     mock_response.tool_calls = []
 
-    mock_response.usage = {"input_tokens": 40, "output_tokens": 60}
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 40
+    mock_usage.output_tokens = 60
+    mock_response.usage = mock_usage
 
     mock_llama_stack_client.responses.create.return_value = mock_response
 
@@ -723,6 +739,7 @@ async def test_query_v2_endpoint_updates_existing_conversation(
     test_request: Request,
     test_auth: AuthTuple,
     patch_db_session: Session,
+    mocker: MockerFixture,
 ) -> None:
     """Test that existing conversation is updated (not recreated).
@@ -757,7 +774,15 @@ async def test_query_v2_endpoint_updates_existing_conversation(
     original_topic = existing_conversation.topic_summary
     original_count = existing_conversation.message_count
 
-    mock_llama_stack_client.responses.create.return_value.id = EXISTING_CONV_ID
+    # Create a proper mock response with all required attributes
+    mock_response = mocker.MagicMock(spec=OpenAIResponseObject)
+    mock_response.id = EXISTING_CONV_ID
+    mock_response.output = []
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 10
+    mock_usage.output_tokens = 5
+    mock_response.usage = mock_usage
+    mock_llama_stack_client.responses.create.return_value = mock_response
 
     query_request = QueryRequest(query="Tell me more", conversation_id=EXISTING_CONV_ID)
 
@@ -989,7 +1014,10 @@ async def test_query_v2_endpoint_with_shield_violation(
     mock_response.output = [mock_output_item]
     mock_response.tool_calls = []
 
-    mock_response.usage = {"input_tokens": 10, "output_tokens": 5}
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 10
+    mock_usage.output_tokens = 5
+    mock_response.usage = mock_usage
 
     mock_llama_stack_client.responses.create.return_value = mock_response
 
@@ -1097,7 +1125,10 @@ async def test_query_v2_endpoint_handles_empty_llm_response(
     mock_response.output = [mock_output_item]
     mock_response.stop_reason = "end_turn"
 
-    mock_response.usage = {"input_tokens": 10, "output_tokens": 0}
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 10
+    mock_usage.output_tokens = 0
+    mock_response.usage = mock_usage
 
     mock_llama_stack_client.responses.create.return_value = mock_response
 
@@ -1150,7 +1181,10 @@ async def test_query_v2_endpoint_quota_integration(
     mock_response = mocker.MagicMock()
     mock_response.id = "response-quota"
     mock_response.output = []
-    mock_response.usage = {"input_tokens": 100, "output_tokens": 50}
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 100
+    mock_usage.output_tokens = 50
+    mock_response.usage = mock_usage
 
     mock_llama_stack_client.responses.create.return_value = mock_response
 
@@ -1176,7 +1210,6 @@ async def test_query_v2_endpoint_quota_integration(
     assert consume_args.kwargs["token_usage"] is not None
     assert consume_args.kwargs["token_usage"].input_tokens == 100
     assert consume_args.kwargs["token_usage"].output_tokens == 50
-    assert consume_args.kwargs["configuration"] is not None
 
     assert response.available_quotas is not None
     assert isinstance(response.available_quotas, dict)
@@ -1344,6 +1377,7 @@ async def test_query_v2_endpoint_uses_conversation_history_model(
     test_request: Request,
     test_auth: AuthTuple,
     patch_db_session: Session,
+    mocker: MockerFixture,
 ) -> None:
     """Test that model from conversation history is used.
@@ -1376,8 +1410,14 @@ async def test_query_v2_endpoint_uses_conversation_history_model(
     patch_db_session.add(existing_conv)
     patch_db_session.commit()
 
-    # Configure mock to return the existing conversation_id (response.id becomes conversation_id)
-    mock_llama_stack_client.responses.create.return_value.id = EXISTING_CONV_ID
+    mock_response = mocker.MagicMock(spec=OpenAIResponseObject)
+    mock_response.id = EXISTING_CONV_ID
+    mock_response.output = []
+    mock_usage = mocker.MagicMock()
+    mock_usage.input_tokens = 10
+    mock_usage.output_tokens = 5
+    mock_response.usage = mock_usage
+    mock_llama_stack_client.responses.create.return_value = mock_response
 
     query_request = QueryRequest(query="Tell me more", conversation_id=EXISTING_CONV_ID)
 
diff --git a/tests/unit/app/endpoints/test_a2a.py b/tests/unit/app/endpoints/test_a2a.py
index 55eb47cf7..7fa7ee9d3 100644
--- a/tests/unit/app/endpoints/test_a2a.py
+++ b/tests/unit/app/endpoints/test_a2a.py
@@ -149,7 +149,7 @@ class TestConvertResponsesContentToA2AParts:
     def test_convert_empty_output(self, mocker: MockerFixture) -> None:
         """Test converting empty output returns empty list."""
         mocker.patch(
-            "app.endpoints.a2a.extract_text_from_response_output_item",
+            "app.endpoints.a2a.extract_text_from_output_item",
             return_value=None,
         )
         result = _convert_responses_content_to_a2a_parts([])
@@ -158,7 +158,7 @@ def test_convert_empty_output(self, mocker: MockerFixture) -> None:
     def test_convert_single_output_item(self, mocker: MockerFixture) -> None:
         """Test converting single output item with text."""
         mocker.patch(
-            "app.endpoints.a2a.extract_text_from_response_output_item",
+            "app.endpoints.a2a.extract_text_from_output_item",
             return_value="Hello, world!",
         )
         mock_output_item = MagicMock()
@@ -171,7 +171,7 @@ def test_convert_single_output_item(self, mocker: MockerFixture) -> None:
     def test_convert_multiple_output_items(self, mocker: MockerFixture) -> None:
         """Test converting multiple output items."""
         extract_mock = mocker.patch(
-            "app.endpoints.a2a.extract_text_from_response_output_item",
+            "app.endpoints.a2a.extract_text_from_output_item",
         )
         extract_mock.side_effect = ["First", "Second"]
 
@@ -190,7 +190,7 @@ def test_convert_multiple_output_items(self, mocker: MockerFixture) -> None:
     def test_convert_output_items_with_none_text(self, mocker: MockerFixture) -> None:
         """Test that output items with no text are filtered out."""
         extract_mock = mocker.patch(
-            "app.endpoints.a2a.extract_text_from_response_output_item",
+            "app.endpoints.a2a.extract_text_from_output_item",
         )
         extract_mock.side_effect = ["Valid text", None, "Another valid"]
 
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index db3991e54..4dcc7f2b1 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -16,7 +16,12 @@
 from models.requests import Attachment, QueryRequest
 from models.responses import QueryResponse
 from utils.token_counter import TokenCounter
-from utils.types import ResponsesApiParams, TurnSummary
+from utils.types import (
+    ResponsesApiParams,
+    ToolCallSummary,
+    ToolResultSummary,
+    TurnSummary,
+)
 
 # User ID must be proper UUID
 MOCK_AUTH = (
@@ -481,25 +486,25 @@ async def test_retrieve_response_success(self, mocker: MockerFixture) -> None:
         mock_output_item.type = "message"
         mock_output_item.content = "Response text"
 
+        mock_usage = mocker.Mock()
+        mock_usage.input_tokens = 10
+        mock_usage.output_tokens = 5
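+        # The usage mock matters here: build_turn_summary reads token counts
+        # from response.usage when it fills TurnSummary.token_usage.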
         mock_response = mocker.Mock(spec=OpenAIResponseObject)
         mock_response.output = [mock_output_item]
+        mock_response.usage = mock_usage
 
         mocker.patch(
             "app.endpoints.query.run_shield_moderation",
             return_value=mocker.Mock(blocked=False),
         )
         mock_client.responses.create = mocker.AsyncMock(return_value=mock_response)
+
+        mock_summary = TurnSummary()
+        mock_summary.llm_response = "Response text"
+        mock_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5)
         mocker.patch(
-            "app.endpoints.query.extract_text_from_response_output_item",
-            return_value="Response text",
-        )
-        mocker.patch(
-            "app.endpoints.query.build_tool_call_summary", return_value=(None, None)
-        )
-        mocker.patch("app.endpoints.query.parse_referenced_documents", return_value=[])
-        mocker.patch(
-            "app.endpoints.query.extract_token_usage",
-            return_value=TokenCounter(input_tokens=10, output_tokens=5),
+            "app.endpoints.query.build_turn_summary",
+            return_value=mock_summary,
         )
 
         result = await retrieve_response(mock_client, mock_responses_params)
@@ -668,8 +673,12 @@ async def test_retrieve_response_with_tool_calls(
             "model": "provider1/model1",
         }
 
+        mock_usage = mocker.Mock()
+        mock_usage.input_tokens = 10
+        mock_usage.output_tokens = 5
         mock_response = mocker.Mock(spec=OpenAIResponseObject)
         mock_response.output = [mocker.Mock(type="message")]
+        mock_response.usage = mock_usage
 
         mocker.patch(
             "app.endpoints.query.run_shield_moderation",
             return_value=mocker.Mock(blocked=False),
         )
         mock_client.responses.create = mocker.AsyncMock(return_value=mock_response)
 
-        mock_tool_call = mocker.Mock()
-        mock_tool_result = mocker.Mock()
-        mocker.patch(
-            "app.endpoints.query.extract_text_from_response_output_item",
-            return_value="Response text",
-        )
-        mocker.patch(
-            "app.endpoints.query.build_tool_call_summary",
-            return_value=(mock_tool_call, mock_tool_result),
+        mock_tool_call = ToolCallSummary(id="1", name="test", args={})
+        mock_tool_result = ToolResultSummary(
+            id="1", status="success", content="result", round=1
         )
-        mocker.patch("app.endpoints.query.parse_referenced_documents", return_value=[])
+
+        mock_summary = TurnSummary()
+        mock_summary.llm_response = "Response text"
+        mock_summary.tool_calls = [mock_tool_call]
+        mock_summary.tool_results = [mock_tool_result]
+        mock_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5)
         mocker.patch(
-            "app.endpoints.query.extract_token_usage",
-            return_value=TokenCounter(input_tokens=10, output_tokens=5),
+            "app.endpoints.query.build_turn_summary",
+            return_value=mock_summary,
         )
 
         result = await retrieve_response(mock_client, mock_responses_params)
+
         assert result.llm_response == "Response text"
         assert len(result.tool_calls) == 1
         assert len(result.tool_results) == 1
+        assert result.token_usage.input_tokens == 10
+        assert result.token_usage.output_tokens == 5
+        assert result.rag_chunks == []
+        assert result.referenced_documents == []
+        assert result.pre_rag_documents == []
diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py
index 6dcc3364b..52196b597 100644
--- a/tests/unit/app/endpoints/test_rlsapi_v1.py
+++ b/tests/unit/app/endpoints/test_rlsapi_v1.py
@@ -699,7 +699,7 @@ async def test_infer_endpoint_calls_get_mcp_tools(
         auth=MOCK_AUTH,
     )
 
-    mock_get_mcp_tools.assert_called_once_with(mock_configuration.mcp_servers)
+    mock_get_mcp_tools.assert_called_once_with()
 
 
 @pytest.mark.asyncio
diff --git a/tests/unit/utils/test_prompts.py b/tests/unit/utils/test_prompts.py
index acbf6b219..337812d85 100644
--- a/tests/unit/utils/test_prompts.py
+++ b/tests/unit/utils/test_prompts.py
@@ -148,10 +148,12 @@ def setup_configuration_fixture() -> AppConfig:
 def test_get_default_system_prompt(
     config_without_system_prompt: AppConfig,
     query_request_without_system_prompt: QueryRequest,
+    mocker: MockerFixture,
 ) -> None:
     """Test that default system prompt is returned when other prompts are not provided."""
+    mocker.patch("utils.prompts.configuration", config_without_system_prompt)
     system_prompt = prompts.get_system_prompt(
-        query_request_without_system_prompt, config_without_system_prompt
+        query_request_without_system_prompt.system_prompt
     )
     assert system_prompt == constants.DEFAULT_SYSTEM_PROMPT
 
@@ -159,10 +161,12 @@ def test_get_default_system_prompt(
 def test_get_customized_system_prompt(
     config_with_custom_system_prompt: AppConfig,
     query_request_without_system_prompt: QueryRequest,
+    mocker: MockerFixture,
 ) -> None:
     """Test that customized system prompt is used when system prompt is not provided in query."""
+    mocker.patch("utils.prompts.configuration", config_with_custom_system_prompt)
     system_prompt = prompts.get_system_prompt(
-        query_request_without_system_prompt, config_with_custom_system_prompt
+        query_request_without_system_prompt.system_prompt
     )
     assert system_prompt == CONFIGURED_SYSTEM_PROMPT
 
@@ -170,10 +174,12 @@ def test_get_customized_system_prompt(
 def test_get_query_system_prompt(
     config_without_system_prompt: AppConfig,
     query_request_with_system_prompt: QueryRequest,
+    mocker: MockerFixture,
 ) -> None:
     """Test that system prompt from query is returned."""
+    mocker.patch("utils.prompts.configuration", config_without_system_prompt)
     system_prompt = prompts.get_system_prompt(
-        query_request_with_system_prompt, config_without_system_prompt
+        query_request_with_system_prompt.system_prompt
     )
     assert system_prompt == query_request_with_system_prompt.system_prompt
 
@@ -181,10 +187,12 @@ def test_get_query_system_prompt(
 def test_get_query_system_prompt_not_customized_one(
     config_with_custom_system_prompt: AppConfig,
     query_request_with_system_prompt: QueryRequest,
+    mocker: MockerFixture,
 ) -> None:
     """Test that system prompt from query is returned even when customized one is specified."""
+    mocker.patch("utils.prompts.configuration", config_with_custom_system_prompt)
     system_prompt = prompts.get_system_prompt(
-        query_request_with_system_prompt, config_with_custom_system_prompt
+        query_request_with_system_prompt.system_prompt
     )
     assert system_prompt == query_request_with_system_prompt.system_prompt
 
@@ -192,38 +200,48 @@ def test_get_query_system_prompt_not_customized_one(
 def test_get_system_prompt_with_disable_query_system_prompt(
     config_with_custom_system_prompt_and_disable_query_system_prompt: AppConfig,
     query_request_with_system_prompt: QueryRequest,
+    mocker: MockerFixture,
 ) -> None:
     """Test that query system prompt is disallowed when disable_query_system_prompt is True."""
+    mocker.patch(
+        "utils.prompts.configuration",
+        config_with_custom_system_prompt_and_disable_query_system_prompt,
+    )
     with pytest.raises(HTTPException) as exc_info:
-        prompts.get_system_prompt(
-            query_request_with_system_prompt,
-            config_with_custom_system_prompt_and_disable_query_system_prompt,
-        )
+        prompts.get_system_prompt(query_request_with_system_prompt.system_prompt)
     assert exc_info.value.status_code == 422
 
 
 def test_get_system_prompt_with_disable_query_system_prompt_and_non_system_prompt_query(
     config_with_custom_system_prompt_and_disable_query_system_prompt: AppConfig,
     query_request_without_system_prompt: QueryRequest,
+    mocker: MockerFixture,
 ) -> None:
     """Test that query without system prompt is allowed when disable_query_system_prompt is True."""
-    system_prompt = prompts.get_system_prompt(
-        query_request_without_system_prompt,
+    mocker.patch(
+        "utils.prompts.configuration",
         config_with_custom_system_prompt_and_disable_query_system_prompt,
     )
+    system_prompt = prompts.get_system_prompt(
+        query_request_without_system_prompt.system_prompt
+    )
     assert system_prompt == CONFIGURED_SYSTEM_PROMPT
 
 
 def test_get_profile_prompt_with_disable_query_system_prompt(
     config_with_custom_profile_prompt_and_disable_query_system_prompt: AppConfig,
     query_request_without_system_prompt: QueryRequest,
+    mocker: MockerFixture,
 ) -> None:
     """Test that system prompt is set if profile enabled and query system prompt disabled."""
+    mocker.patch(
+        "utils.prompts.configuration",
+        config_with_custom_profile_prompt_and_disable_query_system_prompt,
+    )
     custom_profile = CustomProfile(path="tests/profiles/test/profile.py")
     profile_prompts = custom_profile.get_prompts()
 
     system_prompt = prompts.get_system_prompt(
-        query_request_without_system_prompt,
-        config_with_custom_profile_prompt_and_disable_query_system_prompt,
+        query_request_without_system_prompt.system_prompt
     )
     assert system_prompt == profile_prompts.get("default")
 
@@ -231,26 +249,34 @@ def test_get_profile_prompt_with_disable_query_system_prompt(
 def test_get_profile_prompt_with_enabled_query_system_prompt(
     config_with_custom_profile_prompt_and_enabled_query_system_prompt: AppConfig,
     query_request_with_system_prompt: QueryRequest,
+    mocker: MockerFixture,
 ) -> None:
     """Test that profile system prompt is overridden by query system prompt enabled."""
-    system_prompt = prompts.get_system_prompt(
-        query_request_with_system_prompt,
+    mocker.patch(
+        "utils.prompts.configuration",
         config_with_custom_profile_prompt_and_enabled_query_system_prompt,
     )
+    system_prompt = prompts.get_system_prompt(
+        query_request_with_system_prompt.system_prompt
+    )
     assert system_prompt == query_request_with_system_prompt.system_prompt
 
 
 def test_get_topic_summary_system_prompt_default(
     setup_configuration: AppConfig,
+    mocker: MockerFixture,
 ) -> None:
     """Test that default topic summary system prompt is returned when no custom profile is configured.
""" - topic_summary_prompt = prompts.get_topic_summary_system_prompt(setup_configuration) + mocker.patch("utils.prompts.configuration", setup_configuration) + topic_summary_prompt = prompts.get_topic_summary_system_prompt() assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT -def test_get_topic_summary_system_prompt_with_custom_profile() -> None: +def test_get_topic_summary_system_prompt_with_custom_profile( + mocker: MockerFixture, +) -> None: """Test that custom profile topic summary prompt is returned when available.""" test_config = config_dict.copy() test_config["customization"] = { @@ -258,12 +284,13 @@ def test_get_topic_summary_system_prompt_with_custom_profile() -> None: } cfg = AppConfig() cfg.init_from_dict(test_config) + mocker.patch("utils.prompts.configuration", cfg) # Mock the custom profile to return a topic_summary prompt custom_profile = CustomProfile(path="tests/profiles/test/profile.py") profile_prompts = custom_profile.get_prompts() - topic_summary_prompt = prompts.get_topic_summary_system_prompt(cfg) + topic_summary_prompt = prompts.get_topic_summary_system_prompt() assert topic_summary_prompt == profile_prompts.get("topic_summary") @@ -288,17 +315,21 @@ def test_get_topic_summary_system_prompt_with_custom_profile_no_topic_summary( # Patch the custom_profile property to return our mock mocker.patch.object(cfg.customization, "custom_profile", mock_profile) + mocker.patch("utils.prompts.configuration", cfg) - topic_summary_prompt = prompts.get_topic_summary_system_prompt(cfg) + topic_summary_prompt = prompts.get_topic_summary_system_prompt() assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT -def test_get_topic_summary_system_prompt_no_customization() -> None: +def test_get_topic_summary_system_prompt_no_customization( + mocker: MockerFixture, +) -> None: """Test that default topic summary prompt is returned when customization is None.""" test_config = config_dict.copy() test_config["customization"] = None cfg = AppConfig() cfg.init_from_dict(test_config) + mocker.patch("utils.prompts.configuration", cfg) - topic_summary_prompt = prompts.get_topic_summary_system_prompt(cfg) + topic_summary_prompt = prompts.get_topic_summary_system_prompt() assert topic_summary_prompt == constants.DEFAULT_TOPIC_SUMMARY_SYSTEM_PROMPT diff --git a/tests/unit/utils/test_query.py b/tests/unit/utils/test_query.py index 26bd0f0a2..a2cd705a5 100644 --- a/tests/unit/utils/test_query.py +++ b/tests/unit/utils/test_query.py @@ -28,7 +28,6 @@ from tests.unit import config_dict from utils.query import ( consume_query_tokens, - evaluate_model_hints, extract_provider_and_model_from_model_id, handle_known_apistatus_errors, is_input_shield, @@ -36,7 +35,6 @@ is_transcripts_enabled, persist_user_conversation_details, prepare_input, - select_model_and_provider_id, store_conversation_into_cache, store_query_results, update_azure_token, @@ -87,6 +85,7 @@ def test_store_with_cache_configured(self, mocker: MockerFixture) -> None: mock_config.conversation_cache = mock_cache mock_config.conversation_cache_configuration = mocker.Mock() mock_config.conversation_cache_configuration.type = "sqlite" + mocker.patch("utils.query.configuration", mock_config) cache_entry = CacheEntry( query="test query", @@ -98,11 +97,10 @@ def test_store_with_cache_configured(self, mocker: MockerFixture) -> None: ) store_conversation_into_cache( - config=mock_config, user_id="test_user", conversation_id="test_conv", cache_entry=cache_entry, - _skip_userid_check=False, + 
+            skip_userid_check=False,
             topic_summary="Test topic",
         )
 
@@ -130,12 +128,12 @@ def test_store_without_topic_summary(self, mocker: MockerFixture) -> None:
             completed_at="2024-01-01T00:00:05Z",
         )
 
+        mocker.patch("utils.query.configuration", mock_config)
         store_conversation_into_cache(
-            config=mock_config,
             user_id="test_user",
             conversation_id="test_conv",
             cache_entry=cache_entry,
-            _skip_userid_check=False,
+            skip_userid_check=False,
             topic_summary=None,
         )
 
@@ -159,152 +157,44 @@ def test_store_with_cache_not_initialized(self, mocker: MockerFixture) -> None:
         )
 
         # Should not raise an exception, just log a warning
+        mocker.patch("utils.query.configuration", mock_config)
         store_conversation_into_cache(
-            config=mock_config,
             user_id="test_user",
             conversation_id="test_conv",
             cache_entry=cache_entry,
-            _skip_userid_check=False,
+            skip_userid_check=False,
             topic_summary=None,
         )
 
 
-class TestSelectModelAndProviderId:
-    """Tests for select_model_and_provider_id function."""
-
-    def test_select_from_request(self, mock_models: ModelListResponse) -> None:
-        """Test selecting model and provider from request."""
-        result = select_model_and_provider_id(
-            models=mock_models,
-            model_id="model1",
-            provider_id="provider1",
-        )
-        assert result == ("provider1/model1", "model1", "provider1")
-
-    def test_select_first_available_llm(self, mock_models: ModelListResponse) -> None:
-        """Test selecting first available LLM when no model specified."""
-        result = select_model_and_provider_id(
-            models=mock_models,
-            model_id=None,
-            provider_id=None,
-        )
-        assert result[0] in ("provider1/model1", "provider2/model2")
-        assert result[1] in ("model1", "model2")
-        assert result[2] in ("provider1", "provider2")
-
-    def test_select_model_not_found(self, mock_models: ModelListResponse) -> None:
-        """Test selecting non-existent model raises HTTPException."""
-        with pytest.raises(HTTPException) as exc_info:
-            select_model_and_provider_id(
-                models=mock_models,
-                model_id="nonexistent",
-                provider_id="provider1",
-            )
-        assert exc_info.value.status_code == 404
-
-    def test_select_model_no_llm_models_available(self, mocker: MockerFixture) -> None:
-        """Test selecting model when no LLM models are available raises HTTPException."""
-        # Mock configuration to have no default model/provider
-        mocker.patch("utils.query.configuration.inference.default_model", None)
-        mocker.patch("utils.query.configuration.inference.default_provider", None)
-
-        # Empty models list
-        empty_models: ModelListResponse = []
-        with pytest.raises(HTTPException) as exc_info:
-            select_model_and_provider_id(
-                models=empty_models,
-                model_id=None,
-                provider_id=None,
-            )
-        assert exc_info.value.status_code == 404
-        # Verify it's a NotFoundResponse for model resource
-        detail = exc_info.value.detail
-        assert isinstance(detail, dict)
-        assert detail.get("response") == "Model not found"
-        assert "Model with ID" in detail.get("cause", "")
-
-    def test_select_model_no_llm_models_with_non_llm_only(
-        self, mocker: MockerFixture
-    ) -> None:
-        """Test selecting model when only non-LLM models are available raises HTTPException."""
-        # Mock configuration to have no default model/provider
-        mocker.patch("utils.query.configuration.inference.default_model", None)
-        mocker.patch("utils.query.configuration.inference.default_provider", None)
-
-        # Models list with only non-LLM models
-        non_llm_models = [
-            type(
-                "Model",
-                (),
-                {
-                    "id": "provider1/model1",
-                    "custom_metadata": {
-                        "model_type": "embeddings",
-                        "provider_id": "provider1",
-                    },
-                },
-            )(),
-        ]
-        with pytest.raises(HTTPException) as exc_info:
-            select_model_and_provider_id(
-                models=non_llm_models,
-                model_id=None,
-                provider_id=None,
-            )
-        assert exc_info.value.status_code == 404
-        # Verify it's a NotFoundResponse for model resource
-        detail = exc_info.value.detail
-        assert isinstance(detail, dict)
-        assert detail.get("response") == "Model not found"
-        assert "Model with ID" in detail.get("cause", "")
-
-    def test_select_model_attribute_error(self, mocker: MockerFixture) -> None:
-        """Test selecting model when model lacks custom_metadata raises HTTPException."""
-        # Mock configuration to have no default model/provider
-        mocker.patch("utils.query.configuration.inference.default_model", None)
-        mocker.patch("utils.query.configuration.inference.default_provider", None)
-
-        # Models list with model that has no custom_metadata attribute
-        models_without_metadata = [
-            type("Model", (), {"id": "provider1/model1"})(),
-        ]
-        with pytest.raises(HTTPException) as exc_info:
-            select_model_and_provider_id(
-                models=models_without_metadata,
-                model_id=None,
-                provider_id=None,
-            )
-        assert exc_info.value.status_code == 404
-        # Verify it's a NotFoundResponse for model resource
-        detail = exc_info.value.detail
-        assert isinstance(detail, dict)
-        assert detail.get("response") == "Model not found"
-        assert "Model with ID" in detail.get("cause", "")
-
-
 class TestValidateModelProviderOverride:
     """Tests for validate_model_provider_override function."""
 
     def test_allowed_with_action(self) -> None:
         """Test that override is allowed when user has MODEL_OVERRIDE action."""
-        query_request = QueryRequest(
-            query="test", model="model1", provider="provider1"
-        )  # pyright: ignore[reportCallIssue]
-        validate_model_provider_override(query_request, {Action.MODEL_OVERRIDE})
+        validate_model_provider_override("model1", "provider1", {Action.MODEL_OVERRIDE})
 
     def test_rejected_without_action(self) -> None:
         """Test that override is rejected when user lacks MODEL_OVERRIDE action."""
-        query_request = QueryRequest(
-            query="test", model="model1", provider="provider1"
-        )  # pyright: ignore[reportCallIssue]
         with pytest.raises(HTTPException) as exc_info:
-            validate_model_provider_override(query_request, set())
+            validate_model_provider_override("model1", "provider1", set())
         assert exc_info.value.status_code == 403
 
     def test_no_override_allowed(self) -> None:
         """Test that request without override is allowed regardless of permissions."""
-        query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-        validate_model_provider_override(query_request, set())
+        validate_model_provider_override(None, None, set())
+
+    def test_responses_api_format_with_action(self) -> None:
+        """Test that Responses API format (provider/model) is allowed with action."""
+        validate_model_provider_override(
+            "provider1/model1", None, {Action.MODEL_OVERRIDE}
+        )
+
+    def test_responses_api_format_without_action(self) -> None:
+        """Test that Responses API format (provider/model) is rejected without action."""
+        with pytest.raises(HTTPException) as exc_info:
+            validate_model_provider_override("provider1/model1", None, set())
+        assert exc_info.value.status_code == 403
 
 
 class TestShieldFunctions:
@@ -341,48 +231,6 @@ def test_is_input_shield_output_prefix(self) -> None:
         assert is_input_shield(shield) is False
 
 
-class TestEvaluateModelHints:
-    """Tests for evaluate_model_hints function."""
-
-    def test_with_user_conversation_no_request_hints(self) -> None:
-        """Test using hints from user conversation when request has none."""
-        user_conv = UserConversation(
-            id="conv1",
-            user_id="user1",
-            last_used_model="model1",
-            last_used_provider="provider1",
-        )
-        query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-
-        model_id, provider_id = evaluate_model_hints(user_conv, query_request)
-        assert model_id == "model1"
-        assert provider_id == "provider1"
-
-    def test_with_user_conversation_and_request_hints(self) -> None:
-        """Test request hints take precedence over conversation hints."""
-        user_conv = UserConversation(
-            id="conv1",
-            user_id="user1",
-            last_used_model="model1",
-            last_used_provider="provider1",
-        )
-        query_request = QueryRequest(
-            query="test", model="model2", provider="provider2"
-        )  # pyright: ignore[reportCallIssue]
-
-        model_id, provider_id = evaluate_model_hints(user_conv, query_request)
-        assert model_id == "model2"
-        assert provider_id == "provider2"
-
-    def test_without_user_conversation(self) -> None:
-        """Test without user conversation returns request hints or None."""
-        query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-
-        model_id, provider_id = evaluate_model_hints(None, query_request)
-        assert model_id is None
-        assert provider_id is None
-
-
 class TestPrepareInput:
     """Tests for prepare_input function."""
 
@@ -673,9 +521,7 @@ def query_side_effect(*args: Any) -> Any:
 class TestConsumeQueryTokens:
     """Tests for consume_query_tokens function."""
 
-    def test_consume_tokens_success(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_consume_tokens_success(self, mocker: MockerFixture) -> None:
         """Test successful token consumption."""
         mock_consume = mocker.patch("utils.query.consume_tokens")
 
@@ -684,15 +530,12 @@ def test_consume_tokens_success(self, mocker: MockerFixture) -> None:
             user_id="user1",
             model_id="provider1/model1",
             token_usage=token_usage,
-            configuration=mock_config,
         )
 
         # Verify consume_tokens was called
         mock_consume.assert_called_once()
 
-    def test_consume_tokens_database_error(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_consume_tokens_database_error(self, mocker: MockerFixture) -> None:
         """Test token consumption raises HTTPException on database error."""
         mocker.patch(
             "utils.query.consume_tokens", side_effect=sqlite3.Error("DB error")
@@ -704,7 +547,6 @@ def test_consume_tokens_database_error(self, mocker: MockerFixture) -> None:
             user_id="user1",
             model_id="provider1/model1",
             token_usage=token_usage,
-            configuration=mock_config,
         )
 
         assert exc_info.value.status_code == 500
@@ -808,9 +650,7 @@ async def test_update_with_api_status_error(self, mocker: MockerFixture) -> None:
 class TestStoreQueryResults:
     """Tests for store_query_results function."""
 
-    def test_store_query_results_success(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_store_query_results_success(self, mocker: MockerFixture) -> None:
         """Test successful storage of query results."""
         mocker.patch("utils.query.is_transcripts_enabled", return_value=False)
         mock_persist = mocker.patch("utils.query.persist_user_conversation_details")
@@ -830,7 +670,6 @@ def test_store_query_results_success(self, mocker: MockerFixture) -> None:
             completed_at="2024-01-01T00:00:05Z",
             summary=summary,
             query_request=query_request,
-            configuration=mock_config,
             skip_userid_check=False,
             topic_summary="Topic",
         )
@@ -839,12 +678,14 @@ def test_store_query_results_success(self, mocker: MockerFixture) -> None:
 
         mock_persist.assert_called_once()
         mock_store_cache.assert_called_once()
 
-    def test_store_query_results_transcript_error(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_store_query_results_transcript_error(self, mocker: MockerFixture) -> None:
         """Test storage raises HTTPException on transcript error."""
         mocker.patch("utils.query.is_transcripts_enabled", return_value=True)
-        mocker.patch("utils.query.store_transcript", side_effect=IOError("IO error"))
+        error_response = InternalServerErrorResponse.generic()
+        mocker.patch(
+            "utils.query.store_transcript",
+            side_effect=HTTPException(**error_response.model_dump()),
+        )
 
         summary = TurnSummary()
         summary.llm_response = "response"
@@ -861,15 +702,12 @@ def test_store_query_results_transcript_error(self, mocker: MockerFixture) -> None:
             completed_at="2024-01-01T00:00:05Z",
             summary=summary,
             query_request=query_request,
-            configuration=mock_config,
             skip_userid_check=False,
             topic_summary=None,
         )
 
         assert exc_info.value.status_code == 500
 
-    def test_store_query_results_sqlalchemy_error(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_store_query_results_sqlalchemy_error(self, mocker: MockerFixture) -> None:
         """Test storage raises HTTPException on SQLAlchemy error."""
         mocker.patch("utils.query.is_transcripts_enabled", return_value=False)
         mocker.patch(
@@ -892,15 +730,12 @@ def test_store_query_results_sqlalchemy_error(self, mocker: MockerFixture) -> None:
             completed_at="2024-01-01T00:00:05Z",
             summary=summary,
             query_request=query_request,
-            configuration=mock_config,
             skip_userid_check=False,
             topic_summary=None,
         )
 
         assert exc_info.value.status_code == 500
 
-    def test_store_query_results_cache_error(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_store_query_results_cache_error(self, mocker: MockerFixture) -> None:
         """Test storage raises HTTPException on cache error."""
         mocker.patch("utils.query.is_transcripts_enabled", return_value=False)
         mocker.patch("utils.query.persist_user_conversation_details")
@@ -924,15 +759,12 @@ def test_store_query_results_cache_error(self, mocker: MockerFixture) -> None:
             completed_at="2024-01-01T00:00:05Z",
             summary=summary,
             query_request=query_request,
-            configuration=mock_config,
             skip_userid_check=False,
             topic_summary=None,
         )
 
         assert exc_info.value.status_code == 500
 
-    def test_store_query_results_value_error(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_store_query_results_value_error(self, mocker: MockerFixture) -> None:
         """Test storage raises HTTPException on ValueError."""
         mocker.patch("utils.query.is_transcripts_enabled", return_value=False)
         mocker.patch("utils.query.persist_user_conversation_details")
@@ -956,15 +788,12 @@ def test_store_query_results_value_error(self, mocker: MockerFixture) -> None:
             completed_at="2024-01-01T00:00:05Z",
             summary=summary,
             query_request=query_request,
-            configuration=mock_config,
             skip_userid_check=False,
             topic_summary=None,
        )
 
         assert exc_info.value.status_code == 500
 
-    def test_store_query_results_psycopg2_error(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_store_query_results_psycopg2_error(self, mocker: MockerFixture) -> None:
         """Test storage raises HTTPException on psycopg2 error."""
         mocker.patch("utils.query.is_transcripts_enabled", return_value=False)
         mocker.patch("utils.query.persist_user_conversation_details")
@@ -988,15 +817,12 @@ def test_store_query_results_psycopg2_error(self, mocker: MockerFixture) -> None:
             completed_at="2024-01-01T00:00:05Z",
             summary=summary,
             query_request=query_request,
-            configuration=mock_config,
             skip_userid_check=False,
             topic_summary=None,
         )
 
         assert exc_info.value.status_code == 500
 
-    def test_store_query_results_sqlite_error(
-        self, mock_config: AppConfig, mocker: MockerFixture
-    ) -> None:
+    def test_store_query_results_sqlite_error(self, mocker: MockerFixture) -> None:
         """Test storage raises HTTPException on sqlite3 error."""
         mocker.patch("utils.query.is_transcripts_enabled", return_value=False)
         mocker.patch("utils.query.persist_user_conversation_details")
@@ -1020,7 +846,6 @@ def test_store_query_results_sqlite_error(self, mocker: MockerFixture) -> None:
             completed_at="2024-01-01T00:00:05Z",
             summary=summary,
             query_request=query_request,
-            configuration=mock_config,
             skip_userid_check=False,
             topic_summary=None,
         )
diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py
index efbdbf234..5331e7adf 100644
--- a/tests/unit/utils/test_responses.py
+++ b/tests/unit/utils/test_responses.py
@@ -21,7 +21,6 @@
 from pydantic import AnyUrl
 from pytest_mock import MockerFixture
 
-from configuration import AppConfig
 from models.config import ModelContextProtocolServer
 from models.requests import QueryRequest
 from utils.responses import (
@@ -29,7 +28,8 @@
     build_tool_call_summary,
     build_tool_result_from_mcp_output_item_done,
     extract_rag_chunks_from_file_search_item,
-    extract_text_from_response_output_item,
+    extract_text_from_output_item,
+    extract_text_from_output_items,
     extract_token_usage,
     extract_vector_store_ids_from_tools,
     get_mcp_tools,
@@ -65,10 +65,19 @@ class MockContentPart:  # pylint: disable=too-few-public-methods
     """Mock content part for message content."""
 
     def __init__(
-        self, text: Optional[str] = None, refusal: Optional[str] = None
+        self,
+        text: Optional[str] = None,
+        refusal: Optional[str] = None,
+        part_type: Optional[str] = None,
     ) -> None:
         self.text = text
         self.refusal = refusal
+        if part_type:
+            self.type = part_type
+        elif text:
+            self.type = "output_text"
+        elif refusal:
+            self.type = "refusal"
 
 
 def make_output_item(
@@ -89,18 +98,21 @@ def make_output_item(
 
 
 def make_content_part(
-    text: Optional[str] = None, refusal: Optional[str] = None
+    text: Optional[str] = None,
+    refusal: Optional[str] = None,
+    part_type: Optional[str] = None,
 ) -> MockContentPart:
     """Create a mock content part for message content.
     Args:
         text: Text content of the part
         refusal: Refusal message content
+        part_type: Type of the content part ("output_text" or "refusal")
 
     Returns:
         MockContentPart: Mock object with text and/or refusal attributes
     """
-    return MockContentPart(text=text, refusal=refusal)
+    return MockContentPart(text=text, refusal=refusal, part_type=part_type)
 
 
 @pytest.mark.parametrize(
@@ -110,26 +122,19 @@ def make_content_part(
         ("function_call", "assistant", "some text", ""),
         ("file_search_call", "assistant", "some text", ""),
         (None, "assistant", "some text", ""),
-        # Non-assistant roles should return empty string
+        # User role messages are filtered out - return empty string
         ("message", "user", "some text", ""),
-        ("message", "system", "some text", ""),
-        ("message", None, "some text", ""),
         # Valid assistant message with string content
         ("message", "assistant", "Hello, world!", "Hello, world!"),
         ("message", "assistant", "", ""),
-        # No content attribute
-        ("message", "assistant", None, ""),
     ],
     ids=[
         "function_call_type_returns_empty",
         "file_search_call_type_returns_empty",
         "none_type_returns_empty",
         "user_role_returns_empty",
-        "system_role_returns_empty",
-        "none_role_returns_empty",
         "valid_string_content",
         "empty_string_content",
-        "none_content",
     ],
 )
 def test_extract_text_basic_cases(
@@ -144,24 +149,22 @@
         expected: Expected extracted text
     """
     output_item = make_output_item(item_type=item_type, role=role, content=content)
-    result = extract_text_from_response_output_item(output_item)
+    result = extract_text_from_output_item(output_item)  # type: ignore[arg-type]
     assert result == expected
 
 
 @pytest.mark.parametrize(
     "content_parts,expected",
     [
-        # List with string items
-        (["Hello", " ", "world"], "Hello world"),
-        (["Single string"], "Single string"),
+        # Empty list
         ([], ""),
-        # List with make_content_part objects containing text
+        # List with make_content_part objects containing text (with type="output_text")
         (
             [make_content_part(text="Part 1"), make_content_part(text=" Part 2")],
             "Part 1 Part 2",
         ),
         ([make_content_part(text="Only text")], "Only text"),
-        # List with make_content_part objects containing refusal
+        # List with make_content_part objects containing refusal (with type="refusal")
         (
             [make_content_part(refusal="I cannot help with that")],
             "I cannot help with that",
         ),
         (
             [
                 make_content_part(text="Some text"),
                 make_content_part(refusal=" but I refuse"),
             ],
             "Some text but I refuse",
         ),
-        # List with dict items
-        ([{"text": "Dict text 1"}, {"text": "Dict text 2"}], "Dict text 1Dict text 2"),
-        ([{"refusal": "Dict refusal"}], "Dict refusal"),
-        ([{"text": "Text"}, {"refusal": "Refusal"}], "TextRefusal"),
-        # Mixed content types
-        (
-            [
-                "String part",
-                make_content_part(text=" Object part"),
-                {"text": " Dict part"},
-            ],
-            "String part Object part Dict part",
-        ),
-        (
-            [
-                make_content_part(text="Text"),
-                make_content_part(refusal=" Refusal"),
-                {"text": " DictText"},
-                " String",
-            ],
-            "Text Refusal DictText String",
-        ),
-        # Content parts with None or missing attributes
-        ([make_content_part(text=None), make_content_part(refusal=None)], ""),
-        ([{"other_key": "value"}], ""),
-        ([make_content_part(text="Valid"), {"invalid": "key"}], "Valid"),
     ],
     ids=[
-        "list_of_strings",
-        "list_single_string",
         "empty_list",
         "list_of_objects_with_text",
         "single_object_with_text",
         "object_with_refusal",
         "mixed_text_and_refusal_objects",
-        "list_of_dicts_with_text",
-        "dict_with_refusal",
-        "dict_mixed_text_refusal",
-        "mixed_string_object_dict",
-        "complex_mixed_content",
-        "none_attributes",
"dict_without_text_or_refusal", - "valid_mixed_with_invalid", ], ) def test_extract_text_list_content(content_parts: list[Any], expected: str) -> None: @@ -228,7 +195,7 @@ def test_extract_text_list_content(content_parts: list[Any], expected: str) -> N output_item = make_output_item( item_type="message", role="assistant", content=content_parts ) - result = extract_text_from_response_output_item(output_item) + result = extract_text_from_output_item(output_item) # type: ignore[arg-type] assert result == expected @@ -242,86 +209,122 @@ def test_extract_text_with_real_world_structure() -> None: content = [ make_content_part(text="I can help you with that. "), make_content_part(text="Here's the information you requested: "), - "The answer is 42.", + make_content_part(text="The answer is 42."), ] output_item = make_output_item( item_type="message", role="assistant", content=content ) - result = extract_text_from_response_output_item(output_item) + result = extract_text_from_output_item(output_item) # type: ignore[arg-type] expected = "I can help you with that. Here's the information you requested: The answer is 42." assert result == expected -def test_extract_text_with_numeric_dict_values() -> None: - """Test that numeric values in dicts are properly converted to strings. +def test_extract_text_preserves_order() -> None: + """Test that content parts are concatenated in the correct order. - Ensures that when dict values are numeric, they're converted to strings - during extraction. + Verifies that the order of content parts is preserved during extraction. """ - content = [{"text": 123}, {"refusal": 456}] + content = [ + make_content_part(text="First"), + make_content_part(text=" Second"), + make_content_part(text=" Third"), + make_content_part(text=" Fourth"), + ] output_item = make_output_item( item_type="message", role="assistant", content=content ) - result = extract_text_from_response_output_item(output_item) + result = extract_text_from_output_item(output_item) # type: ignore[arg-type] - # Numbers should be converted to strings - assert result == "123456" + assert result == "First Second Third Fourth" -def test_extract_text_preserves_order() -> None: - """Test that content parts are concatenated in the correct order. +def test_extract_text_with_refusal_content() -> None: + """Test extraction with refusal content parts. - Verifies that the order of content parts is preserved during extraction. + Verifies that refusal type content parts are properly extracted. """ content = [ - "First", - make_content_part(text=" Second"), - {"text": " Third"}, - " Fourth", + make_content_part(refusal="I understand your request, "), + make_content_part(refusal="but I cannot help with that."), ] output_item = make_output_item( item_type="message", role="assistant", content=content ) - result = extract_text_from_response_output_item(output_item) + result = extract_text_from_output_item(output_item) # type: ignore[arg-type] - assert result == "First Second Third Fourth" + assert result == "I understand your request, but I cannot help with that." -@pytest.mark.parametrize( - "missing_attr", - ["type", "role", "content"], - ids=["missing_type", "missing_role", "missing_content"], -) -def test_extract_text_handles_missing_attributes(missing_attr: str) -> None: - """Test graceful handling when expected attributes are missing. 
+class TestExtractTextFromOutputItems:
+    """Test cases for extract_text_from_output_items function."""
 
-    Args:
-        missing_attr: The attribute to omit from the mock object
-    """
+    def test_extract_text_from_output_items_none(self) -> None:
+        """Test extract_text_from_output_items returns empty string for None."""
+        result = extract_text_from_output_items(None)
+        assert result == ""
 
-    # Create a basic dict-like object without using make_output_item
-    # pylint: disable=too-few-public-methods,missing-class-docstring,attribute-defined-outside-init
-    class PartialMock:
-        pass
+    def test_extract_text_from_output_items_empty_list(self) -> None:
+        """Test extract_text_from_output_items returns empty string for empty list."""
+        result = extract_text_from_output_items([])
+        assert result == ""
 
-    output_item = PartialMock()
+    def test_extract_text_from_output_items_single_item(self) -> None:
+        """Test extract_text_from_output_items with a single message item."""
+        output_item = make_output_item(
+            item_type="message", role="assistant", content="Hello world"
+        )
+        result = extract_text_from_output_items([output_item])  # type: ignore[arg-type]
+        assert result == "Hello world"
 
-    # Add only the attributes we want
-    if missing_attr != "type":
-        output_item.type = "message"  # type: ignore
-    if missing_attr != "role":
-        output_item.role = "assistant"  # type: ignore
-    if missing_attr != "content":
-        output_item.content = "Some text"  # type: ignore
+    def test_extract_text_from_output_items_multiple_items(self) -> None:
+        """Test extract_text_from_output_items with multiple message items."""
+        item1 = make_output_item(
+            item_type="message", role="assistant", content="First message"
+        )
+        item2 = make_output_item(
+            item_type="message", role="assistant", content="Second message"
+        )
+        result = extract_text_from_output_items([item1, item2])  # type: ignore[arg-type]
+        assert result == "First message Second message"
 
-    result = extract_text_from_response_output_item(output_item)
+    def test_extract_text_from_output_items_filters_non_messages(self) -> None:
+        """Test extract_text_from_output_items filters out non-message items."""
+        item1 = make_output_item(
+            item_type="message", role="assistant", content="Valid message"
+        )
+        item2 = make_output_item(
+            item_type="function_call", role="assistant", content="Should be ignored"
+        )
+        result = extract_text_from_output_items([item1, item2])  # type: ignore[arg-type]
+        assert result == "Valid message"
 
-    # Should return empty string when critical attributes are missing
-    assert result == ""
+    def test_extract_text_from_output_items_filters_user_messages(self) -> None:
+        """Test extract_text_from_output_items filters out user role messages."""
+        item1 = make_output_item(
+            item_type="message", role="assistant", content="Assistant message"
+        )
+        item2 = make_output_item(
+            item_type="message", role="user", content="User message"
+        )
+        result = extract_text_from_output_items([item1, item2])  # type: ignore[arg-type]
+        # User messages are filtered out - only assistant message is included
+        assert result == "Assistant message"
+
+    def test_extract_text_from_output_items_with_list_content(self) -> None:
+        """Test extract_text_from_output_items with list-based content."""
+        content = [
+            make_content_part(text="Part 1"),
+            make_content_part(text="Part 2"),
+        ]
+        output_item = make_output_item(
+            item_type="message", role="assistant", content=content
+        )
+        result = extract_text_from_output_items([output_item])  # type: ignore[arg-type]
+        assert result == "Part 1 Part 2"
 
 
 class TestGetRAGTools:
@@ -345,14 +348,21 @@ class TestGetMCPTools:
     """Test cases for get_mcp_tools utility function."""
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_without_auth(self) -> None:
+    async def test_get_mcp_tools_without_auth(self, mocker: MockerFixture) -> None:
         """Test get_mcp_tools with servers without authorization headers."""
         servers_no_auth = [
-            ModelContextProtocolServer(name="fs", url="http://localhost:3000"),
-            ModelContextProtocolServer(name="git", url="https://git.example.com/mcp"),
+            ModelContextProtocolServer(
+                name="fs", url="http://localhost:3000", provider_id="mcp"
+            ),
+            ModelContextProtocolServer(
+                name="git", url="https://git.example.com/mcp", provider_id="mcp"
+            ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers_no_auth
+        mocker.patch("utils.responses.configuration", mock_config)
 
-        tools_no_auth = await get_mcp_tools(servers_no_auth, token=None)
+        tools_no_auth = await get_mcp_tools(token=None)
         assert len(tools_no_auth) == 2
         assert tools_no_auth[0]["type"] == "mcp"
         assert tools_no_auth[0]["server_label"] == "fs"
@@ -360,7 +370,9 @@ async def test_get_mcp_tools_without_auth(self, mocker: MockerFixture) -> None:
         assert "headers" not in tools_no_auth[0]
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_with_kubernetes_auth(self) -> None:
+    async def test_get_mcp_tools_with_kubernetes_auth(
+        self, mocker: MockerFixture
+    ) -> None:
         """Test get_mcp_tools with kubernetes auth."""
         servers_k8s = [
             ModelContextProtocolServer(
@@ -369,12 +381,15 @@
                 authorization_headers={"Authorization": "kubernetes"},
             ),
         ]
-        tools_k8s = await get_mcp_tools(servers_k8s, token="user-k8s-token")
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers_k8s
+        mocker.patch("utils.responses.configuration", mock_config)
+        tools_k8s = await get_mcp_tools(token="user-k8s-token")
         assert len(tools_k8s) == 1
         assert tools_k8s[0]["headers"] == {"Authorization": "Bearer user-k8s-token"}
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_with_mcp_headers(self) -> None:
+    async def test_get_mcp_tools_with_mcp_headers(self, mocker: MockerFixture) -> None:
         """Test get_mcp_tools with client-provided headers."""
         servers = [
             ModelContextProtocolServer(
@@ -383,6 +398,9 @@
                 authorization_headers={"Authorization": "client", "X-Custom": "client"},
             ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
         mcp_headers = {
             "fs": {
@@ -390,7 +408,7 @@
                 "X-Custom": "custom-value",
             }
         }
-        tools = await get_mcp_tools(servers, token=None, mcp_headers=mcp_headers)
+        tools = await get_mcp_tools(token=None, mcp_headers=mcp_headers)
         assert len(tools) == 1
         assert tools[0]["headers"] == {
             "Authorization": "client-provided-token",
@@ -398,11 +416,13 @@
         }
 
         # Test with mcp_headers=None (server should be skipped)
-        tools_no_headers = await get_mcp_tools(servers, token=None, mcp_headers=None)
+        tools_no_headers = await get_mcp_tools(token=None, mcp_headers=None)
         assert len(tools_no_headers) == 0
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_client_auth_no_mcp_headers(self) -> None:
+    async def test_get_mcp_tools_client_auth_no_mcp_headers(
+        self, mocker: MockerFixture
+    ) -> None:
         """Test get_mcp_tools skips server when mcp_headers is None and server requires client auth."""  # noqa: E501
         servers = [
             ModelContextProtocolServer(
@@ -411,17 +431,20 @@
                 authorization_headers={"X-Custom": "client"},
             ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
         # When mcp_headers is None and server requires client auth,
         # should return None for that header
         # This tests the specific path at line 391
-        tools = await get_mcp_tools(servers, token=None, mcp_headers=None)
+        tools = await get_mcp_tools(token=None, mcp_headers=None)
 
         # Server should be skipped because it requires client auth but mcp_headers is None
         assert len(tools) == 0
 
     @pytest.mark.asyncio
     async def test_get_mcp_tools_client_auth_missing_server_in_headers(
-        self,
+        self, mocker: MockerFixture
     ) -> None:
         """Test get_mcp_tools skips server when mcp_headers doesn't contain server name."""
         servers = [
             ModelContextProtocolServer(
@@ -431,16 +454,21 @@
                 authorization_headers={"X-Custom": "client"},
             ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
         # mcp_headers exists but doesn't contain this server name
         # This tests the specific path at line 394
         mcp_headers = {"other-server": {"X-Custom": "value"}}
-        tools = await get_mcp_tools(servers, token=None, mcp_headers=mcp_headers)
+        tools = await get_mcp_tools(token=None, mcp_headers=mcp_headers)
 
         # Server should be skipped because mcp_headers doesn't contain this server
         assert len(tools) == 0
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_with_static_headers(self, tmp_path: Path) -> None:
+    async def test_get_mcp_tools_with_static_headers(
+        self, tmp_path: Path, mocker: MockerFixture
+    ) -> None:
         """Test get_mcp_tools with static headers from config files."""
         secret_file = tmp_path / "token.txt"
         secret_file.write_text("static-secret-token")
@@ -452,13 +480,18 @@
                 authorization_headers={"Authorization": str(secret_file)},
             ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
-        tools = await get_mcp_tools(servers, token=None)
+        tools = await get_mcp_tools(token=None)
         assert len(tools) == 1
         assert tools[0]["headers"] == {"Authorization": "static-secret-token"}
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_with_mixed_headers(self, tmp_path: Path) -> None:
+    async def test_get_mcp_tools_with_mixed_headers(
+        self, tmp_path: Path, mocker: MockerFixture
+    ) -> None:
         """Test get_mcp_tools with mixed header types."""
         secret_file = tmp_path / "api-key.txt"
         secret_file.write_text("secret-api-key")
@@ -474,6 +507,9 @@
             },
         ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
         mcp_headers = {
             "mixed-server": {
@@ -481,7 +517,7 @@
             }
         }
 
-        tools = await get_mcp_tools(servers, token="k8s-token", mcp_headers=mcp_headers)
+        tools = await get_mcp_tools(token="k8s-token", mcp_headers=mcp_headers)
         assert len(tools) == 1
         assert tools[0]["headers"] == {
             "Authorization": "Bearer k8s-token",
@@ -490,7 +526,9 @@
         }
 
     @pytest.mark.asyncio
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_with_mixed_headers(self, tmp_path: Path) -> None:
+    async def test_get_mcp_tools_with_mixed_headers(
+        self, tmp_path: Path, mocker: MockerFixture
+    ) -> None:
         """Test get_mcp_tools with mixed header types."""
         secret_file = tmp_path / "api-key.txt"
         secret_file.write_text("secret-api-key")
@@ -474,6 +507,9 @@ async def test_get_mcp_tools_with_mixed_headers(self, tmp_path: Path) -> None:
                 },
             ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
         mcp_headers = {
             "mixed-server": {
@@ -481,7 +517,7 @@ async def test_get_mcp_tools_with_mixed_headers(self, tmp_path: Path) -> None:
             }
         }
 
-        tools = await get_mcp_tools(servers, token="k8s-token", mcp_headers=mcp_headers)
+        tools = await get_mcp_tools(token="k8s-token", mcp_headers=mcp_headers)
         assert len(tools) == 1
         assert tools[0]["headers"] == {
             "Authorization": "Bearer k8s-token",
@@ -490,7 +526,9 @@ async def test_get_mcp_tools_with_mixed_headers(self, tmp_path: Path) -> None:
         }
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_skips_server_with_missing_auth(self) -> None:
+    async def test_get_mcp_tools_skips_server_with_missing_auth(
+        self, mocker: MockerFixture
+    ) -> None:
         """Test that servers with required but unavailable auth headers are skipped."""
         servers = [
             ModelContextProtocolServer(
@@ -504,12 +542,17 @@ async def test_get_mcp_tools_skips_server_with_missing_auth(self) -> None:
                 authorization_headers={"X-Token": "client"},
             ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
-        tools = await get_mcp_tools(servers, token=None, mcp_headers=None)
+        tools = await get_mcp_tools(token=None, mcp_headers=None)
         assert len(tools) == 0
 
     @pytest.mark.asyncio
-    async def test_get_mcp_tools_includes_server_without_auth(self) -> None:
+    async def test_get_mcp_tools_includes_server_without_auth(
+        self, mocker: MockerFixture
+    ) -> None:
         """Test that servers without auth config are always included."""
         servers = [
             ModelContextProtocolServer(
@@ -518,8 +561,11 @@ async def test_get_mcp_tools_includes_server_without_auth(self) -> None:
                 authorization_headers={},
             ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
-        tools = await get_mcp_tools(servers, token=None, mcp_headers=None)
+        tools = await get_mcp_tools(token=None, mcp_headers=None)
         assert len(tools) == 1
         assert tools[0]["server_label"] == "public-server"
         assert "headers" not in tools[0]
@@ -536,6 +582,9 @@ async def test_get_mcp_tools_oauth_no_headers_raises_401_with_www_authenticate(
                 authorization_headers={"Authorization": "oauth"},
             ),
         ]
+        mock_config = mocker.Mock()
+        mock_config.mcp_servers = servers
+        mocker.patch("utils.responses.configuration", mock_config)
 
         mock_resp = mocker.Mock()
         mock_resp.headers = {"WWW-Authenticate": 'Bearer error="invalid_token"'}
@@ -553,7 +602,7 @@ async def test_get_mcp_tools_oauth_no_headers_raises_401_with_www_authenticate(
         )
 
         with pytest.raises(HTTPException) as exc_info:
-            await get_mcp_tools(servers, token=None, mcp_headers=None)
+            await get_mcp_tools(token=None, mcp_headers=None)
 
         assert exc_info.value.status_code == 401
         assert exc_info.value.headers is not None
@@ -656,32 +705,15 @@ async def test_get_topic_summary_api_error(self, mocker: MockerFixture) -> None:
 class TestPrepareTools:
     """Tests for prepare_tools function."""
 
-    @pytest.mark.asyncio
-    async def test_prepare_tools_no_tools(self, mocker: MockerFixture) -> None:
-        """Test prepare_tools returns None when no_tools is True."""
-        mock_client = mocker.AsyncMock()
-        query_request = QueryRequest(
-            query="test", no_tools=True
-        )  # pyright: ignore[reportCallIssue]
-        mock_config = mocker.Mock(spec=AppConfig)
-        mock_config.mcp_servers = []
-
-        result = await prepare_tools(mock_client, query_request, "token", mock_config)
-        assert result is None
-
     @pytest.mark.asyncio
     async def test_prepare_tools_with_vector_store_ids(
         self, mocker: MockerFixture
     ) -> None:
         """Test prepare_tools with specified vector store IDs."""
         mock_client = mocker.AsyncMock()
-        query_request = QueryRequest(
-            query="test", vector_store_ids=["vs1", "vs2"]
-        )  # pyright: ignore[reportCallIssue]
-        mock_config = mocker.Mock(spec=AppConfig)
-        mock_config.mcp_servers = []
+        mocker.patch("utils.responses.get_mcp_tools", return_value=None)
 
-        result = await prepare_tools(mock_client, query_request, "token", mock_config)
+        result = await prepare_tools(mock_client, ["vs1", "vs2"], False, "token")
         assert result is not None
         assert len(result) == 1
         assert result[0]["type"] == "file_search"
@@ -702,12 +734,9 @@ async def test_prepare_tools_fetch_vector_stores(
         mock_client.vector_stores.list = mocker.AsyncMock(
             return_value=mock_vector_stores
         )
+        mocker.patch("utils.responses.get_mcp_tools", return_value=None)
 
-        query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-        mock_config = mocker.Mock(spec=AppConfig)
-        mock_config.mcp_servers = []
-
-        result = await prepare_tools(mock_client, query_request, "token", mock_config)
+        result = await prepare_tools(mock_client, None, False, "token")
         assert result is not None
         assert len(result) == 1
         assert result[0]["vector_store_ids"] == ["vs1", "vs2"]
@@ -722,27 +751,18 @@ async def test_prepare_tools_connection_error(self, mocker: MockerFixture) -> No
             )
         )
 
-        query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-        mock_config = mocker.Mock(spec=AppConfig)
-        mock_config.mcp_servers = []
-
         with pytest.raises(HTTPException) as exc_info:
-            await prepare_tools(mock_client, query_request, "token", mock_config)
+            await prepare_tools(mock_client, None, False, "token")
         assert exc_info.value.status_code == 503
 
     @pytest.mark.asyncio
     async def test_prepare_tools_with_mcp_servers(self, mocker: MockerFixture) -> None:
         """Test prepare_tools includes MCP tools."""
         mock_client = mocker.AsyncMock()
-        query_request = QueryRequest(
-            query="test", vector_store_ids=["vs1"]
-        )  # pyright: ignore[reportCallIssue]
-        mock_config = mocker.Mock(spec=AppConfig)
-        mock_config.mcp_servers = [
-            ModelContextProtocolServer(name="test-server", url="http://localhost:3000")
-        ]
+        mock_mcp_tool = {"type": "mcp", "server_label": "test-server"}
+        mocker.patch("utils.responses.get_mcp_tools", return_value=[mock_mcp_tool])
 
-        result = await prepare_tools(mock_client, query_request, "token", mock_config)
+        result = await prepare_tools(mock_client, ["vs1"], False, "token")
         assert result is not None
         assert len(result) == 2  # RAG tool + MCP tool
         assert any(tool.get("type") == "mcp" for tool in result)
@@ -757,12 +777,8 @@ async def test_prepare_tools_api_status_error(self, mocker: MockerFixture) -> No
             )
         )
 
-        query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-        mock_config = mocker.Mock(spec=AppConfig)
-        mock_config.mcp_servers = []
-
         with pytest.raises(HTTPException) as exc_info:
-            await prepare_tools(mock_client, query_request, "token", mock_config)
+            await prepare_tools(mock_client, None, False, "token")
         assert exc_info.value.status_code == 500
 
     @pytest.mark.asyncio
@@ -774,13 +790,22 @@ async def test_prepare_tools_empty_toolgroups(self, mocker: MockerFixture) -> No
         mock_client.vector_stores.list = mocker.AsyncMock(
             return_value=mock_vector_stores
         )
+        mocker.patch("utils.responses.get_mcp_tools", return_value=None)
 
-        query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-        mock_config = mocker.Mock(spec=AppConfig)
-        mock_config.mcp_servers = []  # No MCP servers
+        result = await prepare_tools(mock_client, None, False, "token")
+        assert result is None
+
+    @pytest.mark.asyncio
+    async def test_prepare_tools_no_tools_true(self, mocker: MockerFixture) -> None:
+        """Test prepare_tools returns None when no_tools=True."""
+        mock_client = mocker.AsyncMock()
+
+        # Should return None immediately without fetching vector stores or MCP tools
+        result = await prepare_tools(mock_client, ["vs1", "vs2"], True, "token")
 
-        result = await prepare_tools(mock_client, query_request, "token", mock_config)
         assert result is None
+        # Verify that vector_stores.list was not called
+        mock_client.vector_stores.list.assert_not_called()
 
 
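`prepare_tools` likewise loses its `QueryRequest` and `AppConfig` parameters. The call sites above suggest the positional shape `(client, vector_store_ids, no_tools, token)`, though the parameter names are inferred from usage rather than confirmed by this diff. A hedged sketch of the two behaviours these tests pin down:

```python
import pytest
from pytest_mock import MockerFixture

from utils.responses import prepare_tools


@pytest.mark.asyncio
async def test_prepare_tools_call_shape(mocker: MockerFixture) -> None:
    mock_client = mocker.AsyncMock()
    # MCP lookup is stubbed out so only the RAG path is exercised.
    mocker.patch("utils.responses.get_mcp_tools", return_value=None)

    # Explicit vector store IDs produce a single file_search tool.
    result = await prepare_tools(mock_client, ["vs1"], False, "token")
    assert result is not None and result[0]["type"] == "file_search"

    # no_tools=True short-circuits before any vector store lookup.
    assert await prepare_tools(mock_client, ["vs1"], True, "token") is None
    mock_client.vector_stores.list.assert_not_called()
```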
 class TestPrepareResponsesParams:
@@ -801,12 +826,9 @@ async def test_prepare_responses_params_with_conversation_id(
             query="test", conversation_id="123e4567-e89b-12d3-a456-426614174000"
         )  # pyright: ignore[reportCallIssue]
 
-        mocker.patch("utils.responses.configuration", mocker.Mock())
-        mocker.patch(
-            "utils.responses.select_model_and_provider_id",
-            return_value=("provider1/model1", "model1", "provider1"),
-        )
-        mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None))
+        mock_config = mocker.Mock()
+        mock_config.inference = None
+        mocker.patch("utils.responses.configuration", mock_config)
         mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
         mocker.patch("utils.responses.prepare_tools", return_value=None)
         mocker.patch("utils.responses.prepare_input", return_value="test")
@@ -840,12 +862,9 @@ async def test_prepare_responses_params_create_conversation(
 
         query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
 
-        mocker.patch("utils.responses.configuration", mocker.Mock())
-        mocker.patch(
-            "utils.responses.select_model_and_provider_id",
-            return_value=("provider1/model1", "model1", "provider1"),
-        )
-        mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None))
+        mock_config = mocker.Mock()
+        mock_config.inference = None
+        mocker.patch("utils.responses.configuration", mock_config)
         mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
         mocker.patch("utils.responses.prepare_tools", return_value=None)
         mocker.patch("utils.responses.prepare_input", return_value="test")
@@ -869,7 +888,9 @@ async def test_prepare_responses_params_connection_error_on_models(
         )
 
         query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-        mocker.patch("utils.responses.configuration", mocker.Mock())
+        mock_config = mocker.Mock()
+        mock_config.inference = None
+        mocker.patch("utils.responses.configuration", mock_config)
 
         with pytest.raises(HTTPException) as exc_info:
             await prepare_responses_params(mock_client, query_request, None, "token")
@@ -893,12 +914,9 @@ async def test_prepare_responses_params_connection_error_on_conversation(
 
         query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
 
-        mocker.patch("utils.responses.configuration", mocker.Mock())
-        mocker.patch(
-            "utils.responses.select_model_and_provider_id",
-            return_value=("provider1/model1", "model1", "provider1"),
-        )
-        mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None))
+        mock_config = mocker.Mock()
+        mock_config.inference = None
+        mocker.patch("utils.responses.configuration", mock_config)
         mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
         mocker.patch("utils.responses.prepare_tools", return_value=None)
         mocker.patch("utils.responses.prepare_input", return_value="test")
@@ -920,7 +938,9 @@ async def test_prepare_responses_params_api_status_error_on_models(
         )
 
         query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
-        mocker.patch("utils.responses.configuration", mocker.Mock())
+        mock_config = mocker.Mock()
+        mock_config.inference = None
+        mocker.patch("utils.responses.configuration", mock_config)
 
         with pytest.raises(HTTPException) as exc_info:
             await prepare_responses_params(mock_client, query_request, None, "token")
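The `TestPrepareResponsesParams` hunks repeat one scaffold: a configuration mock with `inference = None` replaces the old `select_model_and_provider_id` and `evaluate_model_hints` patches. If the duplication grows, a fixture along these lines (hypothetical, not part of this diff) could centralize it, assuming the same patch targets the tests already use:

```python
import pytest
from pytest_mock import MockerFixture


@pytest.fixture
def responses_params_env(mocker: MockerFixture) -> None:
    """Stub out configuration and prompt/tool/input helpers for one test."""
    mock_config = mocker.Mock()
    mock_config.inference = None  # no inference defaults configured
    mocker.patch("utils.responses.configuration", mock_config)
    mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
    mocker.patch("utils.responses.prepare_tools", return_value=None)
    mocker.patch("utils.responses.prepare_input", return_value="test")
```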
@@ -944,12 +964,9 @@ async def test_prepare_responses_params_api_status_error_on_conversation(
 
         query_request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]
 
-        mocker.patch("utils.responses.configuration", mocker.Mock())
-        mocker.patch(
-            "utils.responses.select_model_and_provider_id",
-            return_value=("provider1/model1", "model1", "provider1"),
-        )
-        mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None))
+        mock_config = mocker.Mock()
+        mock_config.inference = None
+        mocker.patch("utils.responses.configuration", mock_config)
         mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
         mocker.patch("utils.responses.prepare_tools", return_value=None)
         mocker.patch("utils.responses.prepare_input", return_value="test")
@@ -1080,32 +1097,21 @@ def test_parse_referenced_documents_deduplication(
 class TestExtractTokenUsage:
     """Tests for extract_token_usage function."""
 
-    def test_extract_token_usage_with_dict_usage(self, mocker: MockerFixture) -> None:
-        """Test extracting token usage from dict format."""
-        mock_response = mocker.Mock()
-        mock_response.usage = {"input_tokens": 100, "output_tokens": 50}
-
-        mocker.patch(
-            "utils.responses.extract_provider_and_model_from_model_id",
-            return_value=("provider1", "model1"),
-        )
-        mocker.patch("utils.responses.metrics.llm_token_sent_total")
-        mocker.patch("utils.responses.metrics.llm_token_received_total")
-        mocker.patch("utils.responses._increment_llm_call_metric")
-
-        result = extract_token_usage(mock_response, "provider1/model1")
-        assert result.input_tokens == 100
-        assert result.output_tokens == 50
-        assert result.llm_calls == 1
-
-    def test_extract_token_usage_with_object_usage(self, mocker: MockerFixture) -> None:
-        """Test extracting token usage from object format."""
+    @pytest.mark.parametrize(
+        "input_tokens,output_tokens",
+        [(100, 50), (200, 100)],
+        ids=["usage_100_50", "usage_200_100"],
+    )
+    def test_extract_token_usage_with_usage_object(
+        self,
+        mocker: MockerFixture,
+        input_tokens: int,
+        output_tokens: int,
+    ) -> None:
+        """Test extracting token usage from usage object (mock with input/output tokens)."""
         mock_usage = mocker.Mock()
-        mock_usage.input_tokens = 200
-        mock_usage.output_tokens = 100
-
-        mock_response = mocker.Mock()
-        mock_response.usage = mock_usage
+        mock_usage.input_tokens = input_tokens
+        mock_usage.output_tokens = output_tokens
 
         mocker.patch(
             "utils.responses.extract_provider_and_model_from_model_id",
@@ -1115,22 +1121,20 @@ def test_extract_token_usage_with_object_usage(self, mocker: MockerFixture) -> N
         mocker.patch("utils.responses.metrics.llm_token_received_total")
         mocker.patch("utils.responses._increment_llm_call_metric")
 
-        result = extract_token_usage(mock_response, "provider1/model1")
-        assert result.input_tokens == 200
-        assert result.output_tokens == 100
+        result = extract_token_usage(mock_usage, "provider1/model1")
+        assert result.input_tokens == input_tokens
+        assert result.output_tokens == output_tokens
+        assert result.llm_calls == 1
 
     def test_extract_token_usage_no_usage(self, mocker: MockerFixture) -> None:
         """Test extracting token usage when usage is None."""
-        mock_response = mocker.Mock()
-        mock_response.usage = None
-
         mocker.patch(
             "utils.responses.extract_provider_and_model_from_model_id",
             return_value=("provider1", "model1"),
         )
        mocker.patch("utils.responses._increment_llm_call_metric")
 
-        result = extract_token_usage(mock_response, "provider1/model1")
+        result = extract_token_usage(None, "provider1/model1")
         assert result.input_tokens == 0
         assert result.output_tokens == 0
         assert result.llm_calls == 1
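After this change `extract_token_usage` takes the usage object itself (or `None`) plus the model id, rather than a whole response. A minimal sketch of the narrowed contract, using only the patch targets the tests above already rely on:

```python
from pytest_mock import MockerFixture

from utils.responses import extract_token_usage


def test_extract_token_usage_contract(mocker: MockerFixture) -> None:
    # Isolate the function from real metrics and model-id parsing.
    mocker.patch(
        "utils.responses.extract_provider_and_model_from_model_id",
        return_value=("provider1", "model1"),
    )
    mocker.patch("utils.responses.metrics.llm_token_sent_total")
    mocker.patch("utils.responses.metrics.llm_token_received_total")
    mocker.patch("utils.responses._increment_llm_call_metric")

    usage = mocker.Mock(input_tokens=10, output_tokens=5)
    result = extract_token_usage(usage, "provider1/model1")
    assert (result.input_tokens, result.output_tokens, result.llm_calls) == (10, 5, 1)

    # None falls back to zeroed counts but still records one LLM call.
    assert extract_token_usage(None, "provider1/model1").input_tokens == 0
```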
@@ -1141,16 +1145,13 @@ def test_extract_token_usage_zero_tokens(self, mocker: MockerFixture) -> None:
         mock_usage.input_tokens = 0
         mock_usage.output_tokens = 0
 
-        mock_response = mocker.Mock()
-        mock_response.usage = mock_usage
-
         mocker.patch(
             "utils.responses.extract_provider_and_model_from_model_id",
             return_value=("provider1", "model1"),
         )
         mocker.patch("utils.responses._increment_llm_call_metric")
 
-        result = extract_token_usage(mock_response, "provider1/model1")
+        result = extract_token_usage(mock_usage, "provider1/model1")
         assert result.input_tokens == 0
         assert result.output_tokens == 0
 
@@ -1172,9 +1173,6 @@ def test_extract_token_usage_metrics_error(self, mocker: MockerFixture) -> None:
         mock_usage.input_tokens = 100
         mock_usage.output_tokens = 50
 
-        mock_response = mocker.Mock()
-        mock_response.usage = mock_usage
-
         mocker.patch(
             "utils.responses.extract_provider_and_model_from_model_id",
             return_value=("provider1", "model1"),
@@ -1190,42 +1188,10 @@ def test_extract_token_usage_metrics_error(self, mocker: MockerFixture) -> None:
         mocker.patch("utils.responses._increment_llm_call_metric")
 
         # Should not raise, just log warning
-        result = extract_token_usage(mock_response, "provider1/model1")
+        result = extract_token_usage(mock_usage, "provider1/model1")
         assert result.input_tokens == 100
         assert result.output_tokens == 50
 
-    def test_extract_token_usage_extraction_error(self, mocker: MockerFixture) -> None:
-        """Test extracting token usage handles errors when extracting usage."""
-
-        # Create a usage object that raises TypeError when attributes are accessed
-        # getattr catches AttributeError but not TypeError, so TypeError will propagate
-        class ErrorUsage:  # pylint: disable=too-few-public-methods
-            """Mock usage object that raises TypeError."""
-
-            def __getattribute__(self, name: str) -> Any:
-                if name in ("input_tokens", "output_tokens"):
-                    # Raise TypeError which getattr won't catch (only catches AttributeError)
-                    raise TypeError(f"Cannot access {name}")
-                return super().__getattribute__(name)
-
-        mock_usage = ErrorUsage()
-        mock_response = mocker.Mock()
-        mock_response.usage = mock_usage
-
-        mocker.patch(
-            "utils.responses.extract_provider_and_model_from_model_id",
-            return_value=("provider1", "model1"),
-        )
-        mocker.patch("utils.responses.logger")
-        mocker.patch("utils.responses._increment_llm_call_metric")
-
-        # getattr with default catches AttributeError but not TypeError
-        # TypeError will propagate to exception handler at line 611
-        # Should not raise, just log warning and return 0 tokens
-        result = extract_token_usage(mock_response, "provider1/model1")
-        assert result.input_tokens == 0
-        assert result.output_tokens == 0
-
 
 class TestBuildToolCallSummary:
     """Tests for build_tool_call_summary function."""
diff --git a/tests/unit/utils/test_transcripts.py b/tests/unit/utils/test_transcripts.py
index cbe2e5827..1c11d6ff4 100644
--- a/tests/unit/utils/test_transcripts.py
+++ b/tests/unit/utils/test_transcripts.py
@@ -4,10 +4,12 @@
 from pytest_mock import MockerFixture
 
 from configuration import AppConfig
-from models.requests import Attachment, QueryRequest
+from models.requests import QueryRequest
 from utils.transcripts import (
     construct_transcripts_path,
+    create_transcript,
+    create_transcript_metadata,
     store_transcript,
 )
 from utils.types import ToolCallSummary, ToolResultSummary, TurnSummary
@@ -44,7 +46,7 @@ def test_construct_transcripts_path(mocker: MockerFixture) -> None:
     conversation_id = "123e4567-e89b-12d3-a456-426614174000"
     hashed_user_id = hashlib.sha256(user_id.encode("utf-8")).hexdigest()
 
-    path = construct_transcripts_path(user_id, conversation_id)
+    path = construct_transcripts_path(hashed_user_id, conversation_id)
 
     assert (
         str(path)
@@ -52,7 +54,9 @@ def test_construct_transcripts_path(mocker: MockerFixture) -> None:
     ), "Path should be constructed correctly"
 
 
-def test_store_transcript(mocker: MockerFixture) -> None:
+def test_store_transcript(  # pylint: disable=too-many-locals
+    mocker: MockerFixture,
+) -> None:
     """Test the store_transcript function."""
 
     mocker.patch("builtins.open", mocker.mock_open())
@@ -95,60 +99,46 @@ def test_store_transcript(mocker: MockerFixture) -> None:
         rag_chunks=[],
     )
     query_is_valid = True
-    rag_chunks: list[dict] = []
     truncated = False
-    attachments: list[Attachment] = []
-
-    store_transcript(
-        user_id,
-        conversation_id,
-        model,
-        provider,
-        query_is_valid,
-        query,
-        query_request,
-        summary,
-        rag_chunks,
-        truncated,
-        attachments,
+
+    metadata = create_transcript_metadata(
+        user_id=user_id,
+        conversation_id=conversation_id,
+        model_id=model,
+        provider_id=provider,
+        query_provider=query_request.provider,
+        query_model=query_request.model,
+    )
+    transcript = create_transcript(
+        metadata=metadata,
+        redacted_query=query_request.query,
+        summary=summary,
+        attachments=query_request.attachments or [],
     )
+    store_transcript(transcript)
 
     # Assert that the transcript was stored correctly
     hashed_user_id = hashlib.sha256(user_id.encode("utf-8")).hexdigest()
-    mock_json.dump.assert_called_once_with(
-        {
-            "metadata": {
-                "provider": "fake-provider",
-                "model": "fake-model",
-                "query_provider": query_request.provider,
-                "query_model": query_request.model,
-                "user_id": hashed_user_id,
-                "conversation_id": conversation_id,
-                "timestamp": mocker.ANY,
-            },
-            "redacted_query": query,
-            "query_is_valid": query_is_valid,
-            "llm_response": summary.llm_response,
-            "rag_chunks": rag_chunks,
-            "truncated": truncated,
-            "attachments": attachments,
-            "tool_calls": [
-                {
-                    "id": "123",
-                    "name": "test-tool",
-                    "args": {"testing": "testing"},
-                    "type": "tool_call",
-                }
-            ],
-            "tool_results": [
-                {
-                    "id": "123",
-                    "status": "success",
-                    "content": "tool response",
-                    "type": "tool_result",
-                    "round": 1,
-                }
-            ],
-        },
-        mocker.ANY,
-    )
+    mock_json.dump.assert_called_once()
+    call_args = mock_json.dump.call_args[0]
+    stored_data = call_args[0]
+
+    assert stored_data["metadata"]["provider"] == "fake-provider"
+    assert stored_data["metadata"]["model"] == "fake-model"
+    assert stored_data["metadata"]["query_provider"] == query_request.provider
+    assert stored_data["metadata"]["query_model"] == query_request.model
+    assert stored_data["metadata"]["user_id"] == hashed_user_id
+    assert stored_data["metadata"]["conversation_id"] == conversation_id
+    assert "timestamp" in stored_data["metadata"]
+    assert stored_data["redacted_query"] == query
+    assert stored_data["query_is_valid"] == query_is_valid
+    assert stored_data["llm_response"] == summary.llm_response
+    assert stored_data["rag_chunks"] == []
+    assert stored_data["truncated"] == truncated
+    assert stored_data["attachments"] == []
+    assert len(stored_data["tool_calls"]) == 1
+    assert stored_data["tool_calls"][0]["id"] == "123"
+    assert stored_data["tool_calls"][0]["name"] == "test-tool"
+    assert len(stored_data["tool_results"]) == 1
+    assert stored_data["tool_results"][0]["id"] == "123"
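For reference, the rewritten test decomposes the old monolithic `store_transcript(...)` call into three steps: build the metadata, assemble the transcript, then persist it. A sketch of that flow, assuming `create_transcript` returns the serializable payload that `store_transcript` writes (field values here are illustrative, not from the diff):

```python
from models.requests import QueryRequest
from utils.transcripts import (
    create_transcript,
    create_transcript_metadata,
    store_transcript,
)
from utils.types import TurnSummary

query_request = QueryRequest(query="What is Kubernetes?")  # pyright: ignore[reportCallIssue]
summary = TurnSummary(llm_response="Kubernetes is a container orchestrator.")

# Metadata carries both the resolved model/provider and the requested ones.
metadata = create_transcript_metadata(
    user_id="user-1",
    conversation_id="123e4567-e89b-12d3-a456-426614174000",
    model_id="fake-model",
    provider_id="fake-provider",
    query_provider=query_request.provider,
    query_model=query_request.model,
)
transcript = create_transcript(
    metadata=metadata,
    redacted_query=query_request.query,
    summary=summary,
    attachments=query_request.attachments or [],
)
store_transcript(transcript)  # writes JSON under the per-user transcripts path
```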