Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions tests/unit/utils/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import constants
from configuration import AppConfig
from models.config import CustomProfile
from models.responses import ReferencedDocument
from models.requests import QueryRequest
from models.config import Action
from utils import endpoints
Expand Down Expand Up @@ -111,13 +112,17 @@ def config_with_custom_profile_prompt_and_disable_query_system_prompt_fixture():
@pytest.fixture(name="query_request_without_system_prompt")
def query_request_without_system_prompt_fixture():
"""Fixture for query request without system prompt."""
return QueryRequest(query="query", system_prompt=None)
return QueryRequest(
query="query", system_prompt=None
) # pyright: ignore[reportCallIssue]


@pytest.fixture(name="query_request_with_system_prompt")
def query_request_with_system_prompt_fixture():
"""Fixture for query request with system prompt."""
return QueryRequest(query="query", system_prompt="System prompt defined in query")
return QueryRequest(
query="query", system_prompt="System prompt defined in query"
) # pyright: ignore[reportCallIssue]


@pytest.fixture(name="setup_configuration")
Expand Down Expand Up @@ -762,14 +767,18 @@ async def test_get_temp_agent_no_persistence(prepare_agent_mocks, mocker):

def test_validate_model_provider_override_allowed_with_action():
"""Ensure no exception when caller has MODEL_OVERRIDE and request includes model/provider."""
query_request = QueryRequest(query="q", model="m", provider="p")
query_request = QueryRequest(
query="q", model="m", provider="p"
) # pyright: ignore[reportCallIssue]
authorized_actions = {Action.MODEL_OVERRIDE}
endpoints.validate_model_provider_override(query_request, authorized_actions)


def test_validate_model_provider_override_rejected_without_action():
"""Ensure HTTP 403 when request includes model/provider and caller lacks permission."""
query_request = QueryRequest(query="q", model="m", provider="p")
query_request = QueryRequest(
query="q", model="m", provider="p"
) # pyright: ignore[reportCallIssue]
authorized_actions: set[Action] = set()
with pytest.raises(HTTPException) as exc_info:
endpoints.validate_model_provider_override(query_request, authorized_actions)
Expand All @@ -778,7 +787,7 @@ def test_validate_model_provider_override_rejected_without_action():

def test_validate_model_provider_override_no_override_without_action():
"""No exception when request does not include model/provider regardless of permission."""
query_request = QueryRequest(query="q")
query_request = QueryRequest(query="q") # pyright:ignore[reportCallIssue]
endpoints.validate_model_provider_override(query_request, set())


Expand Down Expand Up @@ -861,7 +870,14 @@ def test_create_referenced_documents_http_urls_referenced_document_format(self):

result = endpoints.create_referenced_documents([mock_chunk1, mock_chunk2])

# two referenced documents are expected
assert len(result) == 2
# results must exist
assert result[0] is not None
assert result[1] is not None
# results must be of the right type
assert isinstance(result[0], ReferencedDocument)
assert isinstance(result[1], ReferencedDocument)
assert result[0].doc_url == AnyUrl("https://example.com/doc1")
assert result[0].doc_title == "doc1"
assert result[1].doc_url == AnyUrl("https://example.com/doc2")
Expand All @@ -882,7 +898,14 @@ def test_create_referenced_documents_document_ids_with_metadata(self):
[mock_chunk1, mock_chunk2], metadata_map
)

# two referenced documents are expected
assert len(result) == 2
# results must exist
assert result[0] is not None
assert result[1] is not None
# results must be of the right type
assert isinstance(result[0], ReferencedDocument)
assert isinstance(result[1], ReferencedDocument)
assert result[0].doc_url == AnyUrl("https://example.com/doc1")
assert result[0].doc_title == "Document 1"
assert result[1].doc_url == AnyUrl("https://example.com/doc2")
Expand All @@ -896,7 +919,12 @@ def test_create_referenced_documents_skips_tool_names(self):

result = endpoints.create_referenced_documents([mock_chunk1, mock_chunk2])

# one referenced document is expected
assert len(result) == 1
# result must exist
assert result[0] is not None
# result must be of the right type
assert isinstance(result[0], ReferencedDocument)
assert result[0].doc_url == AnyUrl("https://example.com/doc1")
assert result[0].doc_title == "doc1"

Expand All @@ -911,7 +939,12 @@ def test_create_referenced_documents_skips_empty_sources(self):
[mock_chunk1, mock_chunk2, mock_chunk3]
)

# one referenced document is expected
assert len(result) == 1
# result must exist
assert result[0] is not None
# result must be of the right type
assert isinstance(result[0], ReferencedDocument)
assert result[0].doc_url == AnyUrl("https://example.com/doc1")
assert result[0].doc_title == "doc1"

Expand All @@ -929,7 +962,14 @@ def test_create_referenced_documents_deduplication(self):
[mock_chunk1, mock_chunk2, mock_chunk3, mock_chunk4]
)

# two referenced documents are expected
assert len(result) == 2
# results must exist
assert result[0] is not None
assert result[1] is not None
# results must be of the right type
assert isinstance(result[0], ReferencedDocument)
assert isinstance(result[1], ReferencedDocument)
assert result[0].doc_url == AnyUrl("https://example.com/doc1")
assert result[1].doc_title == "doc_id_1"

Expand All @@ -941,7 +981,14 @@ def test_create_referenced_documents_invalid_urls(self):

result = endpoints.create_referenced_documents([mock_chunk1, mock_chunk2])

# two referenced documents are expected
assert len(result) == 2
# results must exist
assert result[0] is not None
assert result[1] is not None
# results must be of the right type
assert isinstance(result[0], ReferencedDocument)
assert isinstance(result[1], ReferencedDocument)
assert result[0].doc_url is None
assert result[0].doc_title == "not-a-valid-url"
assert result[1].doc_url == AnyUrl("https://example.com/doc1")
Expand Down
Loading