Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,3 +1443,226 @@ def test_evaluate_model_hints(

assert provider_id == expected_provider
assert model_id == expected_model

# Note: Test framework and library in use: pytest with pytest-asyncio and pytest-mock (mocker fixture).
# These tests extend coverage for query endpoint and helpers based on recent changes.

import asyncio
import json
import uuid
import pytest
from fastapi import HTTPException, status

# Utilities to build minimal valid QueryRequest objects for different scenarios.
def _make_query_request(**overrides):
    """Build a ``QueryRequest`` from default field values plus ``overrides``.

    Every keyword in ``overrides`` replaces the matching default. Keywords that
    are not among the defaults are forwarded to ``QueryRequest`` as well, rather
    than being silently discarded, so a misspelt or unsupported field name
    reaches the model constructor instead of vanishing.

    Entries whose final value is ``None`` are omitted so the model's own
    optional defaults apply.
    """
    # Import inside to avoid issues at import time if module side-effects occur
    from app.schemas.query import QueryRequest  # path inferred; adjust if necessary

    defaults = {
        "query": "What is OpenStack?",
        "provider": None,
        "model": None,
        "conversation_id": None,
        "vector_db_ids": None,
        "tools": None,
        "metadata": None,
    }
    merged = {**defaults, **overrides}
    # Remove keys with None to respect pydantic optional defaults
    payload = {key: value for key, value in merged.items() if value is not None}
    return QueryRequest(**payload)

@pytest.mark.asyncio
async def test_query_endpoint_handler_raises_when_auth_missing(mocker, setup_configuration):
    """Calling the handler with ``auth=None`` must raise a 401 or 403 ``HTTPException``.

    The module-level configuration, the conversation-ownership check and
    ``generate_response`` are all patched so that only the handler's own
    logic is exercised.
    """
    from app.endpoints.query import query_endpoint_handler

    # Arrange: configuration present, collaborators stubbed out.
    mocker.patch("app.endpoints.query.configuration", setup_configuration)
    mocker.patch("app.endpoints.query.validate_conversation_ownership", return_value=True)
    # Simulate the LLM client call path returning a canned response.
    mock_response = {"response": "ok", "messages": []}
    mocker.patch("app.endpoints.query.generate_response", return_value=mock_response)

    query_request = _make_query_request(query="Ping")

    # Act: auth missing -> expect HTTPException (401 or 403 depending on code).
    with pytest.raises(HTTPException) as exc_info:
        await query_endpoint_handler(query_request, auth=None)

    # Assert: checked after the ``with`` block, otherwise it would never run.
    assert exc_info.value.status_code in (
        status.HTTP_401_UNAUTHORIZED,
        status.HTTP_403_FORBIDDEN,
    )

@pytest.mark.asyncio
async def test_query_endpoint_handler_happy_path_minimal(mocker, setup_configuration):
    """Happy path: a minimal request returns the payload from ``generate_response``.

    ``generate_response`` is stubbed with a canned dict; the test checks that
    the handler's result exposes that dict's ``response`` and ``messages`` and
    that the stub was called exactly once.
    """
    from app.endpoints.query import query_endpoint_handler

    # Disable auth in config if needed.
    # NOTE(review): this mutates the object supplied by the fixture; confirm the
    # fixture is function-scoped so the change cannot leak into other tests.
    cfg = setup_configuration
    cfg.service.auth_enabled = False
    mocker.patch("app.endpoints.query.configuration", cfg)

    # Mocks
    mocker.patch("app.endpoints.query.validate_conversation_ownership", return_value=True)
    mock_resp = {"response": "All good", "messages": [{"role": "assistant", "content": "All good"}]}
    mock_generate = mocker.patch("app.endpoints.query.generate_response", return_value=mock_resp)
    # Persistence is stubbed so the test has no storage side effects. Whether
    # the handler calls it depends on the code path, so nothing is asserted on it.
    mocker.patch("app.endpoints.query.persist_user_conversation_details")

    req = _make_query_request(query="Hello?")
    result = await query_endpoint_handler(req, auth=["user-123", "", "token"])

    assert result["response"] == "All good"
    assert "messages" in result
    mock_generate.assert_called_once()

@pytest.mark.asyncio
async def test_query_endpoint_handler_invalid_conversation_ownership(mocker, setup_configuration):
    """A request carrying a conversation id whose ownership check fails must be rejected.

    The ownership validator is stubbed to return ``False``; the handler is
    expected to raise an ``HTTPException`` with status 403 or 400.
    """
    from app.endpoints.query import query_endpoint_handler

    mocker.patch("app.endpoints.query.configuration", setup_configuration)
    # Force the ownership check to fail.
    mocker.patch("app.endpoints.query.validate_conversation_ownership", return_value=False)

    conversation_id = str(uuid.uuid4())
    request = _make_query_request(query="Q", conversation_id=conversation_id)
    caller_auth = ["user-321", "", "token"]

    with pytest.raises(HTTPException) as exc_info:
        await query_endpoint_handler(request, auth=caller_auth)

    accepted_codes = (status.HTTP_403_FORBIDDEN, status.HTTP_400_BAD_REQUEST)
    assert exc_info.value.status_code in accepted_codes

@pytest.mark.asyncio
async def test_query_endpoint_handler_handles_generate_response_failure(mocker, setup_configuration):
    """A ``RuntimeError`` from ``generate_response`` must surface as an HTTP 500.

    The raised ``HTTPException`` is expected to carry a dict ``detail`` that
    contains at least one of the keys ``response``, ``message`` or ``error``.
    """
    from app.endpoints.query import query_endpoint_handler

    mocker.patch("app.endpoints.query.configuration", setup_configuration)
    mocker.patch("app.endpoints.query.validate_conversation_ownership", return_value=True)
    backend_failure = RuntimeError("backend down")
    mocker.patch("app.endpoints.query.generate_response", side_effect=backend_failure)

    request = _make_query_request(query="Q")
    with pytest.raises(HTTPException) as exc_info:
        await query_endpoint_handler(request, auth=["user-1", "", "token"])

    raised = exc_info.value
    assert raised.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    # Ensure detail has response/message text per implementation.
    detail = getattr(raised, "detail", {})
    assert isinstance(detail, dict)
    assert not {"response", "message", "error"}.isdisjoint(detail)

def test_get_rag_toolgroups_none_input_treated_as_empty(mocker):
    """``get_rag_toolgroups`` returns ``None`` for both ``None`` and an empty list."""
    from app.endpoints.query import get_rag_toolgroups

    for empty_input in (None, []):
        assert get_rag_toolgroups(empty_input) is None

def test_get_rag_toolgroups_dedupes_and_filters_invalid_ids(mocker):
    """Resilience check for ``get_rag_toolgroups`` with duplicate and invalid ids.

    The input mixes repeated strings, an empty string, ``None`` and an int.
    Only structural invariants are asserted: a ``None`` result is accepted;
    otherwise the first toolgroup must be the knowledge-search tool and its
    ``vector_db_ids`` must include the valid ids ``"a"`` and ``"b"``.
    De-duplication and filtering themselves are not asserted.
    """
    from app.endpoints.query import get_rag_toolgroups

    vec_ids = ["a", "a", "b", "", None, 123]  # type: ignore
    result = get_rag_toolgroups(vec_ids)

    # Acceptable if empty/invalid inputs yield None.
    if result is None:
        return

    assert isinstance(result, list) and len(result) >= 1
    first_tool = result[0]
    assert first_tool["name"] == "builtin::rag/knowledge_search"
    assert "args" in first_tool and "vector_db_ids" in first_tool["args"]

    # vector_db_ids should contain at least the valid string ids.
    forwarded_ids = first_tool["args"]["vector_db_ids"]
    for valid_id in ("a", "b"):
        assert valid_id in forwarded_ids

@pytest.mark.parametrize(
    "user_conversation, request_provider, request_model, expected_provider, expected_model",
    [
        # Request specifies provider only -> provider respected, model remains None.
        pytest.param(None, "provA", None, "provA", None, id="provider_only"),
        # Request specifies model only -> model respected, provider remains None.
        pytest.param(None, None, "modelX", None, "modelX", id="model_only"),
        # Conversation carries a provider/model pair; the request overrides the provider only.
        pytest.param(
            __import__("types").SimpleNamespace(
                id="c1",
                user_id="u1",
                last_used_provider="provB",
                last_used_model="modelY",
                message_count=2,
            ),
            "provA",
            None,
            "provA",
            "modelY",
            id="override_provider_keep_conv_model",
        ),
        # Same conversation; the request overrides the model only.
        pytest.param(
            __import__("types").SimpleNamespace(
                id="c1",
                user_id="u1",
                last_used_provider="provB",
                last_used_model="modelY",
                message_count=2,
            ),
            None,
            "modelX",
            "provB",
            "modelX",
            id="override_model_keep_conv_provider",
        ),
        # Request provides mismatched or unknown values; pass-through expected.
        pytest.param(
            None,
            "unknownProv",
            "unknownModel",
            "unknownProv",
            "unknownModel",
            id="unknown_values_passthrough",
        ),
    ],
)
def test_evaluate_model_hints_extended_cases(user_conversation, request_provider, request_model, expected_provider, expected_model):
    """Check the provider/model pair ``evaluate_model_hints`` resolves for each case.

    Each case supplies an optional stand-in conversation object and the
    provider/model hints placed on the request, plus the pair expected back.
    """
    from app.endpoints.query import evaluate_model_hints
    from app.schemas.query import QueryRequest

    request = QueryRequest(query="X", provider=request_provider, model=request_model)  # pylint: disable=missing-kwoa
    resolved_model, resolved_provider = evaluate_model_hints(user_conversation, request)

    assert resolved_provider == expected_provider
    assert resolved_model == expected_model

def test_evaluate_model_hints_handles_nulls_gracefully():
    """With no conversation and no hints on the request, both results are ``None``."""
    from app.endpoints.query import evaluate_model_hints
    from app.schemas.query import QueryRequest

    hintless_request = QueryRequest(query="X")  # no hints
    resolved_model, resolved_provider = evaluate_model_hints(None, hintless_request)

    assert resolved_model is None
    assert resolved_provider is None

@pytest.mark.asyncio
async def test_query_endpoint_handler_rag_integration_with_vector_ids(mocker, setup_configuration):
    """When ``vector_db_ids`` are supplied, RAG toolgroups must reach ``generate_response``.

    ``generate_response`` is replaced by a fake that records the ``toolgroups``
    keyword argument (falling back to ``tools``). The test then requires that a
    non-empty list was recorded and that it contains the knowledge-search tool.
    """
    from app.endpoints.query import query_endpoint_handler

    mocker.patch("app.endpoints.query.configuration", setup_configuration)
    mocker.patch("app.endpoints.query.validate_conversation_ownership", return_value=True)

    captured = {}

    def fake_generate_response(*args, **kwargs):
        # Record only the toolgroups value itself. Falling back to the whole
        # kwargs dict would hide a missing argument behind an unrelated
        # "dict is not a list" failure further down.
        captured["toolgroups"] = kwargs.get("toolgroups") or kwargs.get("tools")
        return {"response": "rag-ok"}

    mocker.patch("app.endpoints.query.generate_response", side_effect=fake_generate_response)

    req = _make_query_request(query="RAG?", vector_db_ids=["Vector-DB-1", "Vector-DB-2"])
    out = await query_endpoint_handler(req, auth=["user-9", "", "token"])

    assert out["response"] == "rag-ok"
    assert "toolgroups" in captured, "generate_response was never invoked"

    tg = captured["toolgroups"]
    assert tg is not None, "no non-empty toolgroups/tools keyword reached generate_response"
    assert isinstance(tg, list)
    assert any(t.get("name") == "builtin::rag/knowledge_search" for t in tg)

@pytest.mark.asyncio
async def test_query_endpoint_handler_configuration_loaded_but_invalid_llama_stack(mocker, setup_configuration):
    """With a blank llama_stack ``api_key`` and ``url``, the handler must raise HTTP 500.

    The raised ``HTTPException`` is expected to carry a dict ``detail`` that
    contains at least one of the keys ``response``, ``message`` or ``error``.
    """
    from app.endpoints.query import query_endpoint_handler

    # Intentionally break the llama_stack settings on the fixture-provided config.
    broken_config = setup_configuration
    broken_config.llama_stack.api_key = ""
    broken_config.llama_stack.url = ""
    mocker.patch("app.endpoints.query.configuration", broken_config)
    mocker.patch("app.endpoints.query.validate_conversation_ownership", return_value=True)

    request = _make_query_request(query="Q")
    with pytest.raises(HTTPException) as exc_info:
        await query_endpoint_handler(request, auth=["user", "", "t"])

    raised = exc_info.value
    assert raised.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    detail = getattr(raised, "detail", {})
    assert isinstance(detail, dict)
    assert not {"response", "message", "error"}.isdisjoint(detail)

Loading