1+ """Unit tests for the /query REST API endpoint."""
2+
3+ # pylint: disable=too-many-lines
4+
15import json
26from fastapi import HTTPException , status
37import pytest
48
9+ from llama_stack_client import APIConnectionError
10+ from llama_stack_client .types import UserMessage # type: ignore
11+
512from configuration import AppConfig
613from app .endpoints .query import (
714 query_endpoint_handler ,
1522 get_agent ,
1623 _agent_cache ,
1724)
18- from llama_stack_client import APIConnectionError
25+
1926from models .requests import QueryRequest , Attachment
2027from models .config import ModelContextProtocolServer
21- from llama_stack_client .types import UserMessage # type: ignore
2228
2329MOCK_AUTH = ("mock_user_id" , "mock_username" , "mock_token" )
2430
2531
26- @pytest .fixture (autouse = True )
27- def setup_configuration ():
32+ @pytest .fixture (name = "setup_configuration" )
33+ def setup_configuration_fixture ():
2834 """Set up configuration for tests."""
2935 config_dict = {
3036 "name" : "test" ,
@@ -52,12 +58,13 @@ def setup_configuration():
5258 return cfg
5359
5460
55- @pytest .fixture (autouse = True )
56- def prepare_agent_mocks (mocker ):
61+ @pytest .fixture (autouse = True , name = "prepare_agent_mocks" )
62+ def prepare_agent_mocks_fixture (mocker ):
63+ """Fixture that yields mock agent when called."""
5764 mock_client = mocker .Mock ()
5865 mock_agent = mocker .Mock ()
59- """Cleanup agent cache after tests."""
6066 yield mock_client , mock_agent
67+ # cleanup agent cache after tests
6168 _agent_cache .clear ()
6269
6370
@@ -98,7 +105,7 @@ def test_is_transcripts_disabled(setup_configuration, mocker):
98105 assert is_transcripts_enabled () is False , "Transcripts should be disabled"
99106
100107
101- def _test_query_endpoint_handler (mocker , store_transcript = False ):
108+ def _test_query_endpoint_handler (mocker , store_transcript_to_file = False ):
102109 """Test the query endpoint handler."""
103110 mock_client = mocker .Mock ()
104111 mock_lsc = mocker .patch ("client.LlamaStackClientHolder.get_client" )
@@ -110,7 +117,7 @@ def _test_query_endpoint_handler(mocker, store_transcript=False):
110117
111118 mock_config = mocker .Mock ()
112119 mock_config .user_data_collection_configuration .transcripts_disabled = (
113- not store_transcript
120+ not store_transcript_to_file
114121 )
115122 mocker .patch ("app.endpoints.query.configuration" , mock_config )
116123
@@ -124,7 +131,8 @@ def _test_query_endpoint_handler(mocker, store_transcript=False):
124131 )
125132 mocker .patch ("app.endpoints.query.select_model_id" , return_value = "fake_model_id" )
126133 mocker .patch (
127- "app.endpoints.query.is_transcripts_enabled" , return_value = store_transcript
134+ "app.endpoints.query.is_transcripts_enabled" ,
135+ return_value = store_transcript_to_file ,
128136 )
129137 mock_transcript = mocker .patch ("app.endpoints.query.store_transcript" )
130138
@@ -137,7 +145,7 @@ def _test_query_endpoint_handler(mocker, store_transcript=False):
137145 assert response .conversation_id == conversation_id
138146
139147 # Assert the store_transcript function is called if transcripts are enabled
140- if store_transcript :
148+ if store_transcript_to_file :
141149 mock_transcript .assert_called_once_with (
142150 user_id = "user_id_placeholder" ,
143151 conversation_id = conversation_id ,
@@ -155,12 +163,12 @@ def _test_query_endpoint_handler(mocker, store_transcript=False):
155163
156164def test_query_endpoint_handler_transcript_storage_disabled (mocker ):
157165 """Test the query endpoint handler with transcript storage disabled."""
158- _test_query_endpoint_handler (mocker , store_transcript = False )
166+ _test_query_endpoint_handler (mocker , store_transcript_to_file = False )
159167
160168
161169def test_query_endpoint_handler_store_transcript (mocker ):
162170 """Test the query endpoint handler with transcript storage enabled."""
163- _test_query_endpoint_handler (mocker , store_transcript = True )
171+ _test_query_endpoint_handler (mocker , store_transcript_to_file = True )
164172
165173
166174def test_select_model_id (mocker ):
@@ -368,9 +376,17 @@ def test_retrieve_response_one_available_shield(prepare_agent_mocks, mocker):
368376 """Test the retrieve_response function."""
369377
370378 class MockShield :
379+ """Mock for Llama Stack shield to be used."""
380+
371381 def __init__ (self , identifier ):
372382 self .identifier = identifier
373383
384+ def __str__ (self ):
385+ return "MockShield"
386+
387+ def __repr__ (self ):
388+ return "MockShield"
389+
374390 mock_client , mock_agent = prepare_agent_mocks
375391 mock_agent .create_turn .return_value .output_message .content = "LLM answer"
376392 mock_client .shields .list .return_value = [MockShield ("shield1" )]
@@ -407,9 +423,17 @@ def test_retrieve_response_two_available_shields(prepare_agent_mocks, mocker):
407423 """Test the retrieve_response function."""
408424
409425 class MockShield :
426+ """Mock for Llama Stack shield to be used."""
427+
410428 def __init__ (self , identifier ):
411429 self .identifier = identifier
412430
431+ def __str__ (self ):
432+ return "MockShield"
433+
434+ def __repr__ (self ):
435+ return "MockShield"
436+
413437 mock_client , mock_agent = prepare_agent_mocks
414438 mock_agent .create_turn .return_value .output_message .content = "LLM answer"
415439 mock_client .shields .list .return_value = [
@@ -832,7 +856,7 @@ def test_store_transcript(mocker):
832856 )
833857
834858
835- def test_get_rag_toolgroups (mocker ):
859+ def test_get_rag_toolgroups ():
836860 """Test get_rag_toolgroups function."""
837861 vector_db_ids = []
838862 result = get_rag_toolgroups (vector_db_ids )
@@ -864,7 +888,7 @@ def test_query_endpoint_handler_on_connection_error(mocker):
864888 query_endpoint_handler (query_request )
865889
866890
867- def test_get_agent_cache_hit (prepare_agent_mocks , mocker ):
891+ def test_get_agent_cache_hit (prepare_agent_mocks ):
868892 """Test get_agent function when agent exists in cache."""
869893 mock_client , mock_agent = prepare_agent_mocks
870894
0 commit comments