@@ -469,6 +469,69 @@ def __repr__(self):
469469 )
470470
471471
472+ def test_retrieve_response_four_available_shields (prepare_agent_mocks , mocker ):
473+ """Test the retrieve_response function."""
474+
475+ class MockShield :
476+ """Mock for Llama Stack shield to be used."""
477+
478+ def __init__ (self , identifier ):
479+ self .identifier = identifier
480+
481+ def __str__ (self ):
482+ return "MockShield"
483+
484+ def __repr__ (self ):
485+ return "MockShield"
486+
487+ mock_client , mock_agent = prepare_agent_mocks
488+ mock_agent .create_turn .return_value .output_message .content = "LLM answer"
489+ mock_client .shields .list .return_value = [
490+ MockShield ("shield1" ),
491+ MockShield ("input_shield2" ),
492+ MockShield ("output_shield3" ),
493+ MockShield ("inout_shield4" ),
494+ ]
495+ mock_client .vector_dbs .list .return_value = []
496+
497+ # Mock configuration with empty MCP servers
498+ mock_config = mocker .Mock ()
499+ mock_config .mcp_servers = []
500+ mocker .patch ("app.endpoints.query.configuration" , mock_config )
501+ mock_get_agent = mocker .patch (
502+ "app.endpoints.query.get_agent" , return_value = (mock_agent , "fake_session_id" )
503+ )
504+
505+ query_request = QueryRequest (query = "What is OpenStack?" )
506+ model_id = "fake_model_id"
507+ access_token = "test_token"
508+
509+ response , conversation_id = retrieve_response (
510+ mock_client , model_id , query_request , access_token
511+ )
512+
513+ assert response == "LLM answer"
514+ assert conversation_id == "fake_session_id"
515+
516+ # Verify get_agent was called with the correct parameters
517+ mock_get_agent .assert_called_once_with (
518+ mock_client ,
519+ model_id ,
520+ mocker .ANY , # system_prompt
521+ ["shield1" , "input_shield2" , "inout_shield4" ], # available_input_shields
522+ ["output_shield3" , "inout_shield4" ], # available_output_shields
523+ None , # conversation_id
524+ )
525+
526+ mock_agent .create_turn .assert_called_once_with (
527+ messages = [UserMessage (content = "What is OpenStack?" , role = "user" )],
528+ session_id = "fake_session_id" ,
529+ documents = [],
530+ stream = False ,
531+ toolgroups = None ,
532+ )
533+
534+
472535def test_retrieve_response_with_one_attachment (prepare_agent_mocks , mocker ):
473536 """Test the retrieve_response function."""
474537 mock_client , mock_agent = prepare_agent_mocks
@@ -613,7 +676,8 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
613676 mock_client ,
614677 model_id ,
615678 mocker .ANY , # system_prompt
616- [], # available_shields
679+ [], # available_input_shields
680+ [], # available_output_shields
617681 None , # conversation_id
618682 )
619683
@@ -676,7 +740,8 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc
676740 mock_client ,
677741 model_id ,
678742 mocker .ANY , # system_prompt
679- [], # available_shields
743+ [], # available_input_shields
744+ [], # available_output_shields
680745 None , # conversation_id
681746 )
682747
@@ -746,7 +811,8 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
746811 mock_client ,
747812 model_id ,
748813 mocker .ANY , # system_prompt
749- [], # available_shields
814+ [], # available_input_shields
815+ [], # available_output_shields
750816 None , # conversation_id
751817 )
752818
@@ -900,7 +966,8 @@ def test_get_agent_cache_hit(prepare_agent_mocks):
900966 client = mock_client ,
901967 model_id = "test_model" ,
902968 system_prompt = "test_prompt" ,
903- available_shields = ["shield1" ],
969+ available_input_shields = ["shield1" ],
970+ available_output_shields = ["output_shield2" ],
904971 conversation_id = conversation_id ,
905972 )
906973
@@ -940,7 +1007,8 @@ def test_get_agent_cache_miss_with_conversation_id(
9401007 client = mock_client ,
9411008 model_id = "test_model" ,
9421009 system_prompt = "test_prompt" ,
943- available_shields = ["shield1" ],
1010+ available_input_shields = ["shield1" ],
1011+ available_output_shields = ["output_shield2" ],
9441012 conversation_id = "non_existent_conversation_id" ,
9451013 )
9461014
@@ -954,6 +1022,7 @@ def test_get_agent_cache_miss_with_conversation_id(
9541022 model = "test_model" ,
9551023 instructions = "test_prompt" ,
9561024 input_shields = ["shield1" ],
1025+ output_shields = ["output_shield2" ],
9571026 tool_parser = None ,
9581027 enable_session_persistence = True ,
9591028 )
@@ -991,7 +1060,8 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks,
9911060 client = mock_client ,
9921061 model_id = "test_model" ,
9931062 system_prompt = "test_prompt" ,
994- available_shields = ["shield1" ],
1063+ available_input_shields = ["shield1" ],
1064+ available_output_shields = ["output_shield2" ],
9951065 conversation_id = None ,
9961066 )
9971067
@@ -1005,6 +1075,7 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks,
10051075 model = "test_model" ,
10061076 instructions = "test_prompt" ,
10071077 input_shields = ["shield1" ],
1078+ output_shields = ["output_shield2" ],
10081079 tool_parser = None ,
10091080 enable_session_persistence = True ,
10101081 )
@@ -1042,7 +1113,8 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke
10421113 client = mock_client ,
10431114 model_id = "test_model" ,
10441115 system_prompt = "test_prompt" ,
1045- available_shields = [],
1116+ available_input_shields = [],
1117+ available_output_shields = [],
10461118 conversation_id = None ,
10471119 )
10481120
@@ -1056,6 +1128,7 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke
10561128 model = "test_model" ,
10571129 instructions = "test_prompt" ,
10581130 input_shields = [],
1131+ output_shields = [],
10591132 tool_parser = None ,
10601133 enable_session_persistence = True ,
10611134 )
@@ -1094,7 +1167,8 @@ def test_get_agent_multiple_mcp_servers(
10941167 client = mock_client ,
10951168 model_id = "test_model" ,
10961169 system_prompt = "test_prompt" ,
1097- available_shields = ["shield1" , "shield2" ],
1170+ available_input_shields = ["shield1" , "shield2" ],
1171+ available_output_shields = ["output_shield3" , "output_shield4" ],
10981172 conversation_id = None ,
10991173 )
11001174
@@ -1108,6 +1182,7 @@ def test_get_agent_multiple_mcp_servers(
11081182 model = "test_model" ,
11091183 instructions = "test_prompt" ,
11101184 input_shields = ["shield1" , "shield2" ],
1185+ output_shields = ["output_shield3" , "output_shield4" ],
11111186 tool_parser = None ,
11121187 enable_session_persistence = True ,
11131188 )
@@ -1144,7 +1219,8 @@ def test_get_agent_session_persistence_enabled(
11441219 client = mock_client ,
11451220 model_id = "test_model" ,
11461221 system_prompt = "test_prompt" ,
1147- available_shields = ["shield1" ],
1222+ available_input_shields = ["shield1" ],
1223+ available_output_shields = ["output_shield2" ],
11481224 conversation_id = None ,
11491225 )
11501226
@@ -1154,6 +1230,7 @@ def test_get_agent_session_persistence_enabled(
11541230 model = "test_model" ,
11551231 instructions = "test_prompt" ,
11561232 input_shields = ["shield1" ],
1233+ output_shields = ["output_shield2" ],
11571234 tool_parser = None ,
11581235 enable_session_persistence = True ,
11591236 )
0 commit comments