Skip to content

Commit 9e02999

Browse files
committed
Avoid code duplication at retrieve_response for responses API (v2)
1 parent e95bc2b commit 9e02999

2 files changed

Lines changed: 60 additions & 53 deletions

File tree

src/app/endpoints/query_v2.py

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from authentication import get_auth_dependency
1818
from authentication.interface import AuthTuple
1919
from authorization.middleware import authorize
20-
from configuration import configuration
20+
from configuration import AppConfig, configuration
2121
import metrics
2222
from models.config import Action
2323
from models.requests import QueryRequest
@@ -190,31 +190,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
190190
validate_attachments_metadata(query_request.attachments)
191191

192192
# Prepare tools for responses API
193-
toolgroups: list[dict[str, Any]] | None = None
194-
if not query_request.no_tools:
195-
toolgroups = []
196-
# Get vector stores for RAG tools
197-
vector_store_ids = [
198-
vector_store.id for vector_store in (await client.vector_stores.list()).data
199-
]
200-
201-
# Add RAG tools if vector stores are available
202-
rag_tools = get_rag_tools(vector_store_ids)
203-
if rag_tools:
204-
toolgroups.extend(rag_tools)
205-
206-
# Add MCP server tools
207-
mcp_tools = get_mcp_tools(configuration.mcp_servers, token)
208-
if mcp_tools:
209-
toolgroups.extend(mcp_tools)
210-
logger.debug(
211-
"Configured %d MCP tools: %s",
212-
len(mcp_tools),
213-
[tool.get("server_label", "unknown") for tool in mcp_tools],
214-
)
215-
# Convert empty list to None for consistency with existing behavior
216-
if not toolgroups:
217-
toolgroups = None
193+
toolgroups = await prepare_tools_for_responses_api(
194+
client, query_request, token, configuration
195+
)
218196

219197
# Prepare input for Responses API
220198
# Convert attachments to text and concatenate with query
@@ -470,3 +448,55 @@ def get_mcp_tools(mcp_servers: list, token: str | None = None) -> list[dict[str,
470448

471449
tools.append(tool_def)
472450
return tools
451+
452+
453+
async def prepare_tools_for_responses_api(
454+
client: AsyncLlamaStackClient,
455+
query_request: QueryRequest,
456+
token: str,
457+
config: AppConfig,
458+
) -> list[dict[str, Any]] | None:
459+
"""
460+
Prepare tools for Responses API including RAG and MCP tools.
461+
462+
This function retrieves vector stores and combines them with MCP
463+
server tools to create a unified toolgroups list for the Responses API.
464+
465+
Args:
466+
client: The Llama Stack client instance
467+
query_request: The user's query request
468+
token: Authentication token for MCP tools
469+
config: Configuration object containing MCP server settings
470+
471+
Returns:
472+
list[dict[str, Any]] | None: List of tool configurations for the
473+
Responses API, or None if no_tools is True or no tools are available
474+
"""
475+
if query_request.no_tools:
476+
return None
477+
478+
toolgroups = []
479+
# Get vector stores for RAG tools
480+
vector_store_ids = [
481+
vector_store.id for vector_store in (await client.vector_stores.list()).data
482+
]
483+
484+
# Add RAG tools if vector stores are available
485+
rag_tools = get_rag_tools(vector_store_ids)
486+
if rag_tools:
487+
toolgroups.extend(rag_tools)
488+
489+
# Add MCP server tools
490+
mcp_tools = get_mcp_tools(config.mcp_servers, token)
491+
if mcp_tools:
492+
toolgroups.extend(mcp_tools)
493+
logger.debug(
494+
"Configured %d MCP tools: %s",
495+
len(mcp_tools),
496+
[tool.get("server_label", "unknown") for tool in mcp_tools],
497+
)
498+
# Convert empty list to None for consistency with existing behavior
499+
if not toolgroups:
500+
return None
501+
502+
return toolgroups

src/app/endpoints/streaming_query_v2.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
)
2222
from app.endpoints.query_v2 import (
2323
extract_token_usage_from_responses_api,
24-
get_mcp_tools,
25-
get_rag_tools,
2624
get_topic_summary,
25+
prepare_tools_for_responses_api,
2726
)
2827
from app.endpoints.streaming_query import (
2928
format_stream_data,
@@ -448,31 +447,9 @@ async def retrieve_response(
448447
validate_attachments_metadata(query_request.attachments)
449448

450449
# Prepare tools for responses API
451-
toolgroups: list[dict[str, Any]] | None = None
452-
if not query_request.no_tools:
453-
toolgroups = []
454-
# Get vector stores for RAG tools
455-
vector_store_ids = [
456-
vector_store.id for vector_store in (await client.vector_stores.list()).data
457-
]
458-
459-
# Add RAG tools if vector stores are available
460-
rag_tools = get_rag_tools(vector_store_ids)
461-
if rag_tools:
462-
toolgroups.extend(rag_tools)
463-
464-
# Add MCP server tools
465-
mcp_tools = get_mcp_tools(configuration.mcp_servers, token)
466-
if mcp_tools:
467-
toolgroups.extend(mcp_tools)
468-
logger.debug(
469-
"Configured %d MCP tools: %s",
470-
len(mcp_tools),
471-
[tool.get("server_label", "unknown") for tool in mcp_tools],
472-
)
473-
# Convert empty list to None for consistency with existing behavior
474-
if not toolgroups:
475-
toolgroups = None
450+
toolgroups = await prepare_tools_for_responses_api(
451+
client, query_request, token, configuration
452+
)
476453

477454
response = await client.responses.create(
478455
input=query_request.query,

0 commit comments

Comments
 (0)