Skip to content

Commit ae6da91

Browse files
committed
Add support for output_shields in agents
1 parent eb13e53 commit ae6da91

5 files changed

Lines changed: 246 additions & 36 deletions

File tree

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@ customization:
133133
disable_query_system_prompt: true
134134
```
135135

136+
## Safety Shields
137+
138+
A single Llama Stack configuration file can include multiple safety shields, which are utilized in agent
139+
configurations to monitor input and/or output streams. LCS uses the following naming convention to specify how each safety shield is
140+
utilized:
141+
142+
1. If the `shield_id` starts with `input_`, it will be used for input only.
143+
1. If the `shield_id` starts with `output_`, it will be used for output only.
144+
1. If the `shield_id` starts with `inout_`, it will be used both for input and output.
145+
1. Otherwise (if the `shield_id` has none of the prefixes above), it will be used for input only.
136146

137147
# Usage
138148

src/app/endpoints/query.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from llama_stack_client.lib.agents.agent import Agent
1313
from llama_stack_client import APIConnectionError
1414
from llama_stack_client import LlamaStackClient # type: ignore
15-
from llama_stack_client.types import UserMessage # type: ignore
15+
from llama_stack_client.types import UserMessage, Shield # type: ignore
1616
from llama_stack_client.types.agents.turn_create_params import (
1717
ToolgroupAgentToolGroupWithArgs,
1818
Toolgroup,
@@ -72,11 +72,12 @@ def is_transcripts_enabled() -> bool:
7272
return not configuration.user_data_collection_configuration.transcripts_disabled
7373

7474

75-
def get_agent(
75+
def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
7676
client: LlamaStackClient,
7777
model_id: str,
7878
system_prompt: str,
79-
available_shields: list[str],
79+
available_input_shields: list[str],
80+
available_output_shields: list[str],
8081
conversation_id: str | None,
8182
) -> tuple[Agent, str]:
8283
"""Get existing agent or create a new one with session persistence."""
@@ -92,7 +93,8 @@ def get_agent(
9293
client,
9394
model=model_id,
9495
instructions=system_prompt,
95-
input_shields=available_shields if available_shields else [],
96+
input_shields=available_input_shields if available_input_shields else [],
97+
output_shields=available_output_shields if available_output_shields else [],
9698
tool_parser=GraniteToolParser.get_parser(model_id),
9799
enable_session_persistence=True,
98100
)
@@ -202,6 +204,20 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
202204
return model_id
203205

204206

207+
def _is_inout_shield(shield: Shield) -> bool:
208+
return shield.identifier.startswith("inout_")
209+
210+
211+
def is_output_shield(shield: Shield) -> bool:
212+
"""Determine if the shield is for monitoring output."""
213+
return _is_inout_shield(shield) or shield.identifier.startswith("output_")
214+
215+
216+
def is_input_shield(shield: Shield) -> bool:
217+
"""Determine if the shield is for monitoring input."""
218+
return _is_inout_shield(shield) or not is_output_shield(shield)
219+
220+
205221
def retrieve_response(
206222
client: LlamaStackClient,
207223
model_id: str,
@@ -210,11 +226,21 @@ def retrieve_response(
210226
mcp_headers: dict[str, dict[str, str]] | None = None,
211227
) -> tuple[str, str]:
212228
"""Retrieve response from LLMs and agents."""
213-
available_shields = [shield.identifier for shield in client.shields.list()]
214-
if not available_shields:
215-
logger.info("No available shields. Disabling safety")
229+
available_input_shields = [
230+
shield.identifier for shield in filter(is_input_shield, client.shields.list())
231+
]
232+
if not available_input_shields:
233+
logger.info("No available input shields.")
234+
else:
235+
logger.info("Available input shields found: %s", available_input_shields)
236+
237+
available_output_shields = [
238+
shield.identifier for shield in filter(is_output_shield, client.shields.list())
239+
]
240+
if not available_output_shields:
241+
logger.info("No available output shields.")
216242
else:
217-
logger.info("Available shields found: %s", available_shields)
243+
logger.info("Available output shields found: %s", available_output_shields)
218244

219245
# use system prompt from request or default one
220246
system_prompt = get_system_prompt(query_request, configuration)
@@ -229,7 +255,8 @@ def retrieve_response(
229255
client,
230256
model_id,
231257
system_prompt,
232-
available_shields,
258+
available_input_shields,
259+
available_output_shields,
233260
query_request.conversation_id,
234261
)
235262

src/app/endpoints/streaming_query.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import logging
55
import re
6+
from json import JSONDecodeError
67
from typing import Any, AsyncIterator
78

89
from cachetools import TTLCache # type: ignore
@@ -29,6 +30,8 @@
2930
from app.endpoints.conversations import conversation_id_to_agent_id
3031
from app.endpoints.query import (
3132
get_rag_toolgroups,
33+
is_input_shield,
34+
is_output_shield,
3235
is_transcripts_enabled,
3336
store_transcript,
3437
select_model_id,
@@ -43,11 +46,12 @@
4346
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)
4447

4548

46-
async def get_agent(
49+
async def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
4750
client: AsyncLlamaStackClient,
4851
model_id: str,
4952
system_prompt: str,
50-
available_shields: list[str],
53+
available_input_shields: list[str],
54+
available_output_shields: list[str],
5155
conversation_id: str | None,
5256
) -> tuple[AsyncAgent, str]:
5357
"""Get existing agent or create a new one with session persistence."""
@@ -62,7 +66,8 @@ async def get_agent(
6266
client, # type: ignore[arg-type]
6367
model=model_id,
6468
instructions=system_prompt,
65-
input_shields=available_shields if available_shields else [],
69+
input_shields=available_input_shields if available_input_shields else [],
70+
output_shields=available_output_shields if available_output_shields else [],
6671
tool_parser=GraniteToolParser.get_parser(model_id),
6772
enable_session_persistence=True,
6873
)
@@ -166,8 +171,13 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
166171
for match in METADATA_PATTERN.findall(
167172
text_content_item.text
168173
):
169-
meta = json.loads(match.replace("'", '"'))
170-
metadata_map[meta["document_id"]] = meta
174+
try:
175+
meta = json.loads(match.replace("'", '"'))
176+
metadata_map[meta["document_id"]] = meta
177+
except JSONDecodeError:
178+
pass
179+
finally:
180+
pass
171181
if chunk.event.payload.step_details.tool_calls:
172182
tool_name = str(
173183
chunk.event.payload.step_details.tool_calls[0].tool_name
@@ -268,11 +278,18 @@ async def retrieve_response(
268278
mcp_headers: dict[str, dict[str, str]] | None = None,
269279
) -> tuple[Any, str]:
270280
"""Retrieve response from LLMs and agents."""
271-
available_shields = [shield.identifier for shield in await client.shields.list()]
272-
if not available_shields:
281+
available_input_shields = [
282+
shield.identifier
283+
for shield in filter(is_input_shield, await client.shields.list())
284+
]
285+
available_output_shields = [
286+
shield.identifier
287+
for shield in filter(is_output_shield, await client.shields.list())
288+
]
289+
if not available_input_shields:
273290
logger.info("No available shields. Disabling safety")
274291
else:
275-
logger.info("Available shields found: %s", available_shields)
292+
logger.info("Available shields found: %s", available_input_shields)
276293

277294
# use system prompt from request or default one
278295
system_prompt = get_system_prompt(query_request, configuration)
@@ -287,7 +304,8 @@ async def retrieve_response(
287304
client,
288305
model_id,
289306
system_prompt,
290-
available_shields,
307+
available_input_shields,
308+
available_output_shields,
291309
query_request.conversation_id,
292310
)
293311

tests/unit/app/endpoints/test_query.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,69 @@ def __repr__(self):
469469
)
470470

471471

472+
def test_retrieve_response_four_available_shields(prepare_agent_mocks, mocker):
473+
"""Test the retrieve_response function."""
474+
475+
class MockShield:
476+
"""Mock for Llama Stack shield to be used."""
477+
478+
def __init__(self, identifier):
479+
self.identifier = identifier
480+
481+
def __str__(self):
482+
return "MockShield"
483+
484+
def __repr__(self):
485+
return "MockShield"
486+
487+
mock_client, mock_agent = prepare_agent_mocks
488+
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
489+
mock_client.shields.list.return_value = [
490+
MockShield("shield1"),
491+
MockShield("input_shield2"),
492+
MockShield("output_shield3"),
493+
MockShield("inout_shield4"),
494+
]
495+
mock_client.vector_dbs.list.return_value = []
496+
497+
# Mock configuration with empty MCP servers
498+
mock_config = mocker.Mock()
499+
mock_config.mcp_servers = []
500+
mocker.patch("app.endpoints.query.configuration", mock_config)
501+
mock_get_agent = mocker.patch(
502+
"app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id")
503+
)
504+
505+
query_request = QueryRequest(query="What is OpenStack?")
506+
model_id = "fake_model_id"
507+
access_token = "test_token"
508+
509+
response, conversation_id = retrieve_response(
510+
mock_client, model_id, query_request, access_token
511+
)
512+
513+
assert response == "LLM answer"
514+
assert conversation_id == "fake_session_id"
515+
516+
# Verify get_agent was called with the correct parameters
517+
mock_get_agent.assert_called_once_with(
518+
mock_client,
519+
model_id,
520+
mocker.ANY, # system_prompt
521+
["shield1", "input_shield2", "inout_shield4"], # available_input_shields
522+
["output_shield3", "inout_shield4"], # available_output_shields
523+
None, # conversation_id
524+
)
525+
526+
mock_agent.create_turn.assert_called_once_with(
527+
messages=[UserMessage(content="What is OpenStack?", role="user")],
528+
session_id="fake_session_id",
529+
documents=[],
530+
stream=False,
531+
toolgroups=None,
532+
)
533+
534+
472535
def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker):
473536
"""Test the retrieve_response function."""
474537
mock_client, mock_agent = prepare_agent_mocks
@@ -613,7 +676,8 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
613676
mock_client,
614677
model_id,
615678
mocker.ANY, # system_prompt
616-
[], # available_shields
679+
[], # available_input_shields
680+
[], # available_output_shields
617681
None, # conversation_id
618682
)
619683

@@ -676,7 +740,8 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc
676740
mock_client,
677741
model_id,
678742
mocker.ANY, # system_prompt
679-
[], # available_shields
743+
[], # available_input_shields
744+
[], # available_output_shields
680745
None, # conversation_id
681746
)
682747

@@ -746,7 +811,8 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
746811
mock_client,
747812
model_id,
748813
mocker.ANY, # system_prompt
749-
[], # available_shields
814+
[], # available_input_shields
815+
[], # available_output_shields
750816
None, # conversation_id
751817
)
752818

@@ -900,7 +966,8 @@ def test_get_agent_cache_hit(prepare_agent_mocks):
900966
client=mock_client,
901967
model_id="test_model",
902968
system_prompt="test_prompt",
903-
available_shields=["shield1"],
969+
available_input_shields=["shield1"],
970+
available_output_shields=["output_shield2"],
904971
conversation_id=conversation_id,
905972
)
906973

@@ -940,7 +1007,8 @@ def test_get_agent_cache_miss_with_conversation_id(
9401007
client=mock_client,
9411008
model_id="test_model",
9421009
system_prompt="test_prompt",
943-
available_shields=["shield1"],
1010+
available_input_shields=["shield1"],
1011+
available_output_shields=["output_shield2"],
9441012
conversation_id="non_existent_conversation_id",
9451013
)
9461014

@@ -954,6 +1022,7 @@ def test_get_agent_cache_miss_with_conversation_id(
9541022
model="test_model",
9551023
instructions="test_prompt",
9561024
input_shields=["shield1"],
1025+
output_shields=["output_shield2"],
9571026
tool_parser=None,
9581027
enable_session_persistence=True,
9591028
)
@@ -991,7 +1060,8 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks,
9911060
client=mock_client,
9921061
model_id="test_model",
9931062
system_prompt="test_prompt",
994-
available_shields=["shield1"],
1063+
available_input_shields=["shield1"],
1064+
available_output_shields=["output_shield2"],
9951065
conversation_id=None,
9961066
)
9971067

@@ -1005,6 +1075,7 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks,
10051075
model="test_model",
10061076
instructions="test_prompt",
10071077
input_shields=["shield1"],
1078+
output_shields=["output_shield2"],
10081079
tool_parser=None,
10091080
enable_session_persistence=True,
10101081
)
@@ -1042,7 +1113,8 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke
10421113
client=mock_client,
10431114
model_id="test_model",
10441115
system_prompt="test_prompt",
1045-
available_shields=[],
1116+
available_input_shields=[],
1117+
available_output_shields=[],
10461118
conversation_id=None,
10471119
)
10481120

@@ -1056,6 +1128,7 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke
10561128
model="test_model",
10571129
instructions="test_prompt",
10581130
input_shields=[],
1131+
output_shields=[],
10591132
tool_parser=None,
10601133
enable_session_persistence=True,
10611134
)
@@ -1094,7 +1167,8 @@ def test_get_agent_multiple_mcp_servers(
10941167
client=mock_client,
10951168
model_id="test_model",
10961169
system_prompt="test_prompt",
1097-
available_shields=["shield1", "shield2"],
1170+
available_input_shields=["shield1", "shield2"],
1171+
available_output_shields=["output_shield3", "output_shield4"],
10981172
conversation_id=None,
10991173
)
11001174

@@ -1108,6 +1182,7 @@ def test_get_agent_multiple_mcp_servers(
11081182
model="test_model",
11091183
instructions="test_prompt",
11101184
input_shields=["shield1", "shield2"],
1185+
output_shields=["output_shield3", "output_shield4"],
11111186
tool_parser=None,
11121187
enable_session_persistence=True,
11131188
)
@@ -1144,7 +1219,8 @@ def test_get_agent_session_persistence_enabled(
11441219
client=mock_client,
11451220
model_id="test_model",
11461221
system_prompt="test_prompt",
1147-
available_shields=["shield1"],
1222+
available_input_shields=["shield1"],
1223+
available_output_shields=["output_shield2"],
11481224
conversation_id=None,
11491225
)
11501226

@@ -1154,6 +1230,7 @@ def test_get_agent_session_persistence_enabled(
11541230
model="test_model",
11551231
instructions="test_prompt",
11561232
input_shields=["shield1"],
1233+
output_shields=["output_shield2"],
11571234
tool_parser=None,
11581235
enable_session_persistence=True,
11591236
)

0 commit comments

Comments
 (0)