
Commit 2e91d95

Author: Radovan Fuchs
Commit message: add option to disable topic summary
1 parent db45d8c, commit 2e91d95

6 files changed: 243 additions & 8 deletions

File tree:
  src/app/endpoints/query.py
  src/models/requests.py
  src/utils/endpoints.py
  tests/unit/app/endpoints/test_query.py
  tests/unit/models/requests/test_query_request.py
  tests/unit/utils/test_endpoints.py

src/app/endpoints/query.py
Lines changed: 13 additions & 3 deletions

@@ -318,9 +318,19 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
             session.query(UserConversation).filter_by(id=conversation_id).first()
         )
         if not existing_conversation:
-            topic_summary = await get_topic_summary_func(
-                query_request.query, client, llama_stack_model_id
-            )
+            # Check if topic summary should be generated (default: True)
+            should_generate = query_request.generate_topic_summary
+
+            if should_generate:
+                logger.debug("Generating topic summary for new conversation")
+                topic_summary = await get_topic_summary_func(
+                    query_request.query, client, llama_stack_model_id
+                )
+            else:
+                logger.debug(
+                    "Topic summary generation disabled by request parameter"
+                )
+                topic_summary = None
         # Convert RAG chunks to dictionary format once for reuse
         logger.info("Processing RAG chunks...")
         rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]
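
For orientation, a minimal client-side sketch of the new flag in use. Only the request-body fields come from this commit; the base URL and the /v1/query route are assumptions for illustration, not taken from the diff.

    # Hypothetical request that opts out of topic-summary generation.
    # Endpoint URL and route are placeholders; "generate_topic_summary" is the new field.
    import requests

    payload = {
        "query": "Tell me about Kubernetes",
        "generate_topic_summary": False,  # skip the summarization call for a new conversation
    }

    response = requests.post("http://localhost:8080/v1/query", json=payload, timeout=60)
    print(response.status_code, response.json())

Omitting the field keeps the previous behaviour, since the model defaults it to True.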

src/models/requests.py
Lines changed: 8 additions & 0 deletions

@@ -81,6 +81,7 @@ class QueryRequest(BaseModel):
     system_prompt: The optional system prompt.
     attachments: The optional attachments.
     no_tools: Whether to bypass all tools and MCP servers (default: False).
+    generate_topic_summary: Whether to generate topic summary for new conversations.
     media_type: The optional media type for response format (application/json or text/plain).
 
     Example:
@@ -146,6 +147,12 @@
         examples=[True, False],
     )
 
+    generate_topic_summary: Optional[bool] = Field(
+        True,
+        description="Whether to generate topic summary for new conversations",
+        examples=[True, False],
+    )
+
     media_type: Optional[str] = Field(
         None,
         description="Media type for the response format",
@@ -164,6 +171,7 @@
                 "model": "model-name",
                 "system_prompt": "You are a helpful assistant",
                 "no_tools": False,
+                "generate_topic_summary": True,
                 "attachments": [
                     {
                         "attachment_type": "log",

src/utils/endpoints.py
Lines changed: 11 additions & 5 deletions

@@ -671,11 +671,17 @@ async def cleanup_after_streaming(
             session.query(UserConversation).filter_by(id=conversation_id).first()
         )
         if not existing_conversation:
-            topic_summary = await get_topic_summary_func(
-                query_request.query,
-                client,
-                llama_stack_model_id,
-            )
+            # Check if topic summary should be generated (default: True)
+            should_generate = query_request.generate_topic_summary
+
+            if should_generate:
+                logger.debug("Generating topic summary for new conversation")
+                topic_summary = await get_topic_summary_func(
+                    query_request.query, client, llama_stack_model_id
+                )
+            else:
+                logger.debug("Topic summary generation disabled by request parameter")
+                topic_summary = None
 
         completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
 
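
The guard is easy to exercise in isolation because cleanup_after_streaming (like the query handler above) receives get_topic_summary_func as an injected callable. A self-contained sketch of that pattern, not the project's code, showing why an AsyncMock can stand in for the real summarizer in the tests below:

    # Minimal illustration of the injected-summarizer guard (names are illustrative only).
    import asyncio
    from typing import Awaitable, Callable, Optional

    async def maybe_summarize(
        query: str,
        generate_topic_summary: Optional[bool],
        get_topic_summary_func: Callable[[str], Awaitable[str]],
    ) -> Optional[str]:
        # Mirrors the diff: only await the summarizer when the flag is truthy.
        if generate_topic_summary:
            return await get_topic_summary_func(query)
        return None

    async def _demo() -> None:
        async def fake_summary(q: str) -> str:
            return f"Topic for: {q}"

        assert await maybe_summarize("hello", True, fake_summary) == "Topic for: hello"
        assert await maybe_summarize("hello", False, fake_summary) is None

    asyncio.run(_demo())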

tests/unit/app/endpoints/test_query.py
Lines changed: 97 additions & 0 deletions

@@ -2265,6 +2265,7 @@ async def test_get_topic_summary_create_turn_parameters(mocker: MockerFixture) -
 
 
 @pytest.mark.asyncio
+<<<<<<< HEAD
 async def test_query_endpoint_quota_exceeded(
     mocker: MockerFixture, dummy_request: Request
 ) -> None:
@@ -2305,3 +2306,99 @@
     assert isinstance(detail, dict)
     assert detail["response"] == "Model quota exceeded"  # type: ignore
     assert "gpt-4-turbo" in detail["cause"]  # type: ignore
+=======
+async def test_query_endpoint_generate_topic_summary_default_true(
+    mocker: MockerFixture, dummy_request: Request
+) -> None:
+    """Test that topic summary is generated by default for new conversations."""
+    mock_client = mocker.AsyncMock()
+    mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
+    mock_lsc.return_value = mock_client
+    mock_client.models.list.return_value = [
+        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
+    ]
+
+    mock_config = mocker.Mock()
+    mock_config.quota_limiters = []
+    mocker.patch("app.endpoints.query.configuration", mock_config)
+
+    summary = TurnSummary(llm_response="Test response", tool_calls=[])
+    mocker.patch(
+        "app.endpoints.query.retrieve_response",
+        return_value=(
+            summary,
+            "00000000-0000-0000-0000-000000000000",
+            [],
+            TokenCounter(),
+        ),
+    )
+
+    mocker.patch(
+        "app.endpoints.query.select_model_and_provider_id",
+        return_value=("test_model", "test_model", "test_provider"),
+    )
+    mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
+
+    mock_get_topic_summary = mocker.patch(
+        "app.endpoints.query.get_topic_summary", return_value="Generated topic"
+    )
+    mock_database_operations(mocker)
+
+    await query_endpoint_handler(
+        request=dummy_request,
+        query_request=QueryRequest(query="test query"),
+        auth=("user123", "username", False, "auth_token_123"),
+        mcp_headers={},
+    )
+
+    mock_get_topic_summary.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_query_endpoint_generate_topic_summary_explicit_false(
+    mocker: MockerFixture, dummy_request: Request
+) -> None:
+    """Test that topic summary is NOT generated when explicitly set to False."""
+    mock_client = mocker.AsyncMock()
+    mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
+    mock_lsc.return_value = mock_client
+    mock_client.models.list.return_value = [
+        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
+    ]
+
+    mock_config = mocker.Mock()
+    mock_config.quota_limiters = []
+    mocker.patch("app.endpoints.query.configuration", mock_config)
+
+    summary = TurnSummary(llm_response="Test response", tool_calls=[])
+    mocker.patch(
+        "app.endpoints.query.retrieve_response",
+        return_value=(
+            summary,
+            "00000000-0000-0000-0000-000000000000",
+            [],
+            TokenCounter(),
+        ),
+    )
+
+    mocker.patch(
+        "app.endpoints.query.select_model_and_provider_id",
+        return_value=("test_model", "test_model", "test_provider"),
+    )
+    mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
+
+    mock_get_topic_summary = mocker.patch(
+        "app.endpoints.query.get_topic_summary", return_value="Generated topic"
+    )
+
+    mock_database_operations(mocker)
+
+    await query_endpoint_handler(
+        request=dummy_request,
+        query_request=QueryRequest(query="test query", generate_topic_summary=False),
+        auth=("user123", "username", False, "auth_token_123"),
+        mcp_headers={},
+    )
+
+    mock_get_topic_summary.assert_not_called()
+>>>>>>> 81b4b90 (added unit tests for the extra logic)

tests/unit/models/requests/test_query_request.py
Lines changed: 12 additions & 0 deletions

@@ -154,3 +154,15 @@ def test_validate_media_type(self, mocker: MockerFixture) -> None:
 
         # Media type is now fully supported, no warning expected
         mock_logger.warning.assert_not_called()
+
+    def test_generate_topic_summary_explicit_false(self) -> None:
+        """Test that generate_topic_summary can be explicitly set to False."""
+        qr = QueryRequest(
+            query="Tell me about Kubernetes", generate_topic_summary=False
+        )
+        assert qr.generate_topic_summary is False
+
+    def test_generate_topic_summary_explicit_true(self) -> None:
+        """Test that generate_topic_summary can be explicitly set to True."""
+        qr = QueryRequest(query="Tell me about Kubernetes", generate_topic_summary=True)
+        assert qr.generate_topic_summary is True

tests/unit/utils/test_endpoints.py
Lines changed: 102 additions & 0 deletions

@@ -1036,3 +1036,105 @@ def test_create_referenced_documents_invalid_urls(self) -> None:
         assert result[0].doc_title == "not-a-valid-url"
         assert result[1].doc_url == AnyUrl("https://example.com/doc1")
         assert result[1].doc_title == "doc1"
+
+
+@pytest.mark.asyncio
+async def test_cleanup_after_streaming_generate_topic_summary_default_true(
+    mocker: MockerFixture,
+) -> None:
+    """Test that topic summary is generated by default for new conversations."""
+    mock_is_transcripts_enabled = mocker.Mock(return_value=False)
+    mock_get_topic_summary = mocker.AsyncMock(return_value="Generated topic")
+    mock_store_transcript = mocker.Mock()
+    mock_persist_conversation = mocker.Mock()
+    mock_client = mocker.AsyncMock()
+    mock_config = mocker.Mock()
+
+    mock_session = mocker.Mock()
+    mock_session.query.return_value.filter_by.return_value.first.return_value = None
+    mock_session.__enter__ = mocker.Mock(return_value=mock_session)
+    mock_session.__exit__ = mocker.Mock(return_value=None)
+    mocker.patch("utils.endpoints.get_session", return_value=mock_session)
+
+    mocker.patch(
+        "utils.endpoints.create_referenced_documents_with_metadata", return_value=[]
+    )
+    mocker.patch("utils.endpoints.store_conversation_into_cache")
+
+    query_request = QueryRequest(query="test query")
+
+    await endpoints.cleanup_after_streaming(
+        user_id="test_user",
+        conversation_id="test_conv_id",
+        model_id="test_model",
+        provider_id="test_provider",
+        llama_stack_model_id="test_llama_model",
+        query_request=query_request,
+        summary=mocker.Mock(llm_response="test response", tool_calls=[]),
+        metadata_map={},
+        started_at="2024-01-01T00:00:00Z",
+        client=mock_client,
+        config=mock_config,
+        skip_userid_check=False,
+        get_topic_summary_func=mock_get_topic_summary,
+        is_transcripts_enabled_func=mock_is_transcripts_enabled,
+        store_transcript_func=mock_store_transcript,
+        persist_user_conversation_details_func=mock_persist_conversation,
+    )
+
+    mock_get_topic_summary.assert_called_once_with(
+        "test query", mock_client, "test_llama_model"
+    )
+
+    mock_persist_conversation.assert_called_once()
+    assert mock_persist_conversation.call_args[1]["topic_summary"] == "Generated topic"
+
+
+@pytest.mark.asyncio
+async def test_cleanup_after_streaming_generate_topic_summary_explicit_false(
+    mocker: MockerFixture,
+) -> None:
+    """Test that topic summary is NOT generated when explicitly set to False."""
+    mock_is_transcripts_enabled = mocker.Mock(return_value=False)
+    mock_get_topic_summary = mocker.AsyncMock(return_value="Generated topic")
+    mock_store_transcript = mocker.Mock()
+    mock_persist_conversation = mocker.Mock()
+    mock_client = mocker.AsyncMock()
+    mock_config = mocker.Mock()
+
+    mock_session = mocker.Mock()
+    mock_session.query.return_value.filter_by.return_value.first.return_value = None
+    mock_session.__enter__ = mocker.Mock(return_value=mock_session)
+    mock_session.__exit__ = mocker.Mock(return_value=None)
+    mocker.patch("utils.endpoints.get_session", return_value=mock_session)
+
+    mocker.patch(
+        "utils.endpoints.create_referenced_documents_with_metadata", return_value=[]
+    )
+    mocker.patch("utils.endpoints.store_conversation_into_cache")
+
+    query_request = QueryRequest(query="test query", generate_topic_summary=False)
+
+    await endpoints.cleanup_after_streaming(
+        user_id="test_user",
+        conversation_id="test_conv_id",
+        model_id="test_model",
+        provider_id="test_provider",
+        llama_stack_model_id="test_llama_model",
+        query_request=query_request,
+        summary=mocker.Mock(llm_response="test response", tool_calls=[]),
+        metadata_map={},
+        started_at="2024-01-01T00:00:00Z",
+        client=mock_client,
+        config=mock_config,
+        skip_userid_check=False,
+        get_topic_summary_func=mock_get_topic_summary,
+        is_transcripts_enabled_func=mock_is_transcripts_enabled,
+        store_transcript_func=mock_store_transcript,
+        persist_user_conversation_details_func=mock_persist_conversation,
+    )
+
+    mock_get_topic_summary.assert_not_called()
+
+    mock_persist_conversation.assert_called_once()
+    assert mock_persist_conversation.call_args[1]["topic_summary"] is None
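
As a side note on the assertions above: mocker.AsyncMock and mocker.patch come from pytest-mock, which exposes the standard unittest.mock classes with automatic teardown, so the call-tracking behaviour can be reproduced with the standard library alone:

    # Standalone illustration of the AsyncMock assertions used in these tests.
    import asyncio
    from unittest.mock import AsyncMock

    async def _demo() -> None:
        summarizer = AsyncMock(return_value="Generated topic")

        # Flag enabled: the summarizer is awaited once and the call is recorded.
        result = await summarizer("test query")
        assert result == "Generated topic"
        summarizer.assert_called_once_with("test query")

        # Flag disabled: a mock that was never awaited passes assert_not_called().
        skipped = AsyncMock(return_value="Generated topic")
        skipped.assert_not_called()

    asyncio.run(_demo())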
