Skip to content

Commit 3dfedf0

Browse files
authored
Merge pull request #984 from yangcao77/store-more-info-db
LCORE-1166: store tool calls and tool results in cache entry
2 parents d649176 + 3899c70 commit 3dfedf0

8 files changed

Lines changed: 401 additions & 16 deletions

File tree

src/app/endpoints/query.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -374,7 +374,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
374374
)
375375

376376
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
377-
378377
cache_entry = CacheEntry(
379378
query=query_request.query,
380379
response=summary.llm_response,
@@ -383,6 +382,8 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
383382
started_at=started_at,
384383
completed_at=completed_at,
385384
referenced_documents=referenced_documents if referenced_documents else None,
385+
tool_calls=summary.tool_calls if summary.tool_calls else None,
386+
tool_results=summary.tool_results if summary.tool_results else None,
386387
)
387388

388389
consume_tokens(

src/cache/postgres_cache.py

Lines changed: 79 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,9 @@
99
from models.cache_entry import CacheEntry
1010
from models.config import PostgreSQLDatabaseConfiguration
1111
from models.responses import ConversationData, ReferencedDocument
12-
from log import get_logger
1312
from utils.connection_decorator import connection
13+
from utils.types import ToolCallSummary, ToolResultSummary
14+
from log import get_logger
1415

1516
logger = get_logger("cache.postgres_cache")
1617

@@ -32,7 +33,9 @@ class PostgresCache(Cache):
3233
response | text | |
3334
provider | text | |
3435
model | text | |
35-
referenced_documents | jsonb | |
36+
referenced_documents | jsonb | |
37+
tool_calls | jsonb | |
38+
tool_results | jsonb | |
3639
Indexes:
3740
"cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
3841
"timestamps" btree (created_at)
@@ -55,6 +58,8 @@ class PostgresCache(Cache):
5558
provider text,
5659
model text,
5760
referenced_documents jsonb,
61+
tool_calls jsonb,
62+
tool_results jsonb,
5863
PRIMARY KEY(user_id, conversation_id, created_at)
5964
);
6065
"""
@@ -75,16 +80,18 @@ class PostgresCache(Cache):
7580
"""
7681

7782
SELECT_CONVERSATION_HISTORY_STATEMENT = """
78-
SELECT query, response, provider, model, started_at, completed_at, referenced_documents
83+
SELECT query, response, provider, model, started_at, completed_at,
84+
referenced_documents, tool_calls, tool_results
7985
FROM cache
8086
WHERE user_id=%s AND conversation_id=%s
8187
ORDER BY created_at
8288
"""
8389

8490
INSERT_CONVERSATION_HISTORY_STATEMENT = """
8591
INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
86-
query, response, provider, model, referenced_documents)
87-
VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s)
92+
query, response, provider, model, referenced_documents,
93+
tool_calls, tool_results)
94+
VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s, %s, %s)
8895
"""
8996

9097
QUERY_CACHE_SIZE = """
@@ -220,7 +227,7 @@ def initialize_cache(self, namespace: str) -> None:
220227
self.connection.commit()
221228

222229
@connection
223-
def get(
230+
def get( # pylint: disable=R0914
224231
self, user_id: str, conversation_id: str, skip_user_id_check: bool = False
225232
) -> list[CacheEntry]:
226233
"""Get the value associated with the given key.
@@ -260,6 +267,40 @@ def get(
260267
conversation_id,
261268
e,
262269
)
270+
271+
# Parse tool_calls back into ToolCallSummary objects
272+
tool_calls_data = conversation_entry[7]
273+
tool_calls_obj = None
274+
if tool_calls_data:
275+
try:
276+
tool_calls_obj = [
277+
ToolCallSummary.model_validate(tc) for tc in tool_calls_data
278+
]
279+
except (ValueError, TypeError) as e:
280+
logger.warning(
281+
"Failed to deserialize tool_calls for "
282+
"conversation %s: %s",
283+
conversation_id,
284+
e,
285+
)
286+
287+
# Parse tool_results back into ToolResultSummary objects
288+
tool_results_data = conversation_entry[8]
289+
tool_results_obj = None
290+
if tool_results_data:
291+
try:
292+
tool_results_obj = [
293+
ToolResultSummary.model_validate(tr)
294+
for tr in tool_results_data
295+
]
296+
except (ValueError, TypeError) as e:
297+
logger.warning(
298+
"Failed to deserialize tool_results for "
299+
"conversation %s: %s",
300+
conversation_id,
301+
e,
302+
)
303+
263304
cache_entry = CacheEntry(
264305
query=conversation_entry[0],
265306
response=conversation_entry[1],
@@ -268,6 +309,8 @@ def get(
268309
started_at=conversation_entry[4],
269310
completed_at=conversation_entry[5],
270311
referenced_documents=docs_obj,
312+
tool_calls=tool_calls_obj,
313+
tool_results=tool_results_obj,
271314
)
272315
result.append(cache_entry)
273316

@@ -311,6 +354,34 @@ def insert_or_append(
311354
e,
312355
)
313356

357+
tool_calls_json = None
358+
if cache_entry.tool_calls:
359+
try:
360+
tool_calls_as_dicts = [
361+
tc.model_dump(mode="json") for tc in cache_entry.tool_calls
362+
]
363+
tool_calls_json = json.dumps(tool_calls_as_dicts)
364+
except (TypeError, ValueError) as e:
365+
logger.warning(
366+
"Failed to serialize tool_calls for conversation %s: %s",
367+
conversation_id,
368+
e,
369+
)
370+
371+
tool_results_json = None
372+
if cache_entry.tool_results:
373+
try:
374+
tool_results_as_dicts = [
375+
tr.model_dump(mode="json") for tr in cache_entry.tool_results
376+
]
377+
tool_results_json = json.dumps(tool_results_as_dicts)
378+
except (TypeError, ValueError) as e:
379+
logger.warning(
380+
"Failed to serialize tool_results for conversation %s: %s",
381+
conversation_id,
382+
e,
383+
)
384+
314385
# the whole operation is run in one transaction
315386
with self.connection.cursor() as cursor:
316387
cursor.execute(
@@ -325,6 +396,8 @@ def insert_or_append(
325396
cache_entry.provider,
326397
cache_entry.model,
327398
referenced_documents_json,
399+
tool_calls_json,
400+
tool_results_json,
328401
),
329402
)
330403

src/cache/sqlite_cache.py

Lines changed: 77 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,9 @@
1010
from models.cache_entry import CacheEntry
1111
from models.config import SQLiteDatabaseConfiguration
1212
from models.responses import ConversationData, ReferencedDocument
13-
from log import get_logger
1413
from utils.connection_decorator import connection
14+
from utils.types import ToolCallSummary, ToolResultSummary
15+
from log import get_logger
1516

1617
logger = get_logger("cache.sqlite_cache")
1718

@@ -34,6 +35,8 @@ class SQLiteCache(Cache):
3435
provider | text | |
3536
model | text | |
3637
referenced_documents | text | |
38+
tool_calls | text | |
39+
tool_results | text | |
3740
Indexes:
3841
"cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
3942
"cache_key_key" UNIQUE CONSTRAINT, btree (key)
@@ -54,6 +57,8 @@ class SQLiteCache(Cache):
5457
provider text,
5558
model text,
5659
referenced_documents text,
60+
tool_calls text,
61+
tool_results text,
5762
PRIMARY KEY(user_id, conversation_id, created_at)
5863
);
5964
"""
@@ -74,16 +79,18 @@ class SQLiteCache(Cache):
7479
"""
7580

7681
SELECT_CONVERSATION_HISTORY_STATEMENT = """
77-
SELECT query, response, provider, model, started_at, completed_at, referenced_documents
82+
SELECT query, response, provider, model, started_at, completed_at,
83+
referenced_documents, tool_calls, tool_results
7884
FROM cache
7985
WHERE user_id=? AND conversation_id=?
8086
ORDER BY created_at
8187
"""
8288

8389
INSERT_CONVERSATION_HISTORY_STATEMENT = """
8490
INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
85-
query, response, provider, model, referenced_documents)
86-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
91+
query, response, provider, model, referenced_documents,
92+
tool_calls, tool_results)
93+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
8794
"""
8895

8996
QUERY_CACHE_SIZE = """
@@ -187,7 +194,7 @@ def initialize_cache(self) -> None:
187194
self.connection.commit()
188195

189196
@connection
190-
def get(
197+
def get( # pylint: disable=R0914
191198
self, user_id: str, conversation_id: str, skip_user_id_check: bool = False
192199
) -> list[CacheEntry]:
193200
"""Get the value associated with the given key.
@@ -228,6 +235,39 @@ def get(
228235
conversation_id,
229236
e,
230237
)
238+
239+
# Parse tool_calls back into ToolCallSummary objects
240+
tool_calls_json_str = conversation_entry[7]
241+
tool_calls_obj = None
242+
if tool_calls_json_str:
243+
try:
244+
tool_calls_data = json.loads(tool_calls_json_str)
245+
tool_calls_obj = [
246+
ToolCallSummary.model_validate(tc) for tc in tool_calls_data
247+
]
248+
except (json.JSONDecodeError, ValueError) as e:
249+
logger.warning(
250+
"Failed to deserialize tool_calls for conversation %s: %s",
251+
conversation_id,
252+
e,
253+
)
254+
255+
# Parse tool_results back into ToolResultSummary objects
256+
tool_results_json_str = conversation_entry[8]
257+
tool_results_obj = None
258+
if tool_results_json_str:
259+
try:
260+
tool_results_data = json.loads(tool_results_json_str)
261+
tool_results_obj = [
262+
ToolResultSummary.model_validate(tr) for tr in tool_results_data
263+
]
264+
except (json.JSONDecodeError, ValueError) as e:
265+
logger.warning(
266+
"Failed to deserialize tool_results for conversation %s: %s",
267+
conversation_id,
268+
e,
269+
)
270+
231271
cache_entry = CacheEntry(
232272
query=conversation_entry[0],
233273
response=conversation_entry[1],
@@ -236,6 +276,8 @@ def get(
236276
started_at=conversation_entry[4],
237277
completed_at=conversation_entry[5],
238278
referenced_documents=docs_obj,
279+
tool_calls=tool_calls_obj,
280+
tool_results=tool_results_obj,
239281
)
240282
result.append(cache_entry)
241283

@@ -281,6 +323,34 @@ def insert_or_append(
281323
e,
282324
)
283325

326+
tool_calls_json = None
327+
if cache_entry.tool_calls:
328+
try:
329+
tool_calls_as_dicts = [
330+
tc.model_dump(mode="json") for tc in cache_entry.tool_calls
331+
]
332+
tool_calls_json = json.dumps(tool_calls_as_dicts)
333+
except (TypeError, ValueError) as e:
334+
logger.warning(
335+
"Failed to serialize tool_calls for conversation %s: %s",
336+
conversation_id,
337+
e,
338+
)
339+
340+
tool_results_json = None
341+
if cache_entry.tool_results:
342+
try:
343+
tool_results_as_dicts = [
344+
tr.model_dump(mode="json") for tr in cache_entry.tool_results
345+
]
346+
tool_results_json = json.dumps(tool_results_as_dicts)
347+
except (TypeError, ValueError) as e:
348+
logger.warning(
349+
"Failed to serialize tool_results for conversation %s: %s",
350+
conversation_id,
351+
e,
352+
)
353+
284354
cursor.execute(
285355
self.INSERT_CONVERSATION_HISTORY_STATEMENT,
286356
(
@@ -294,6 +364,8 @@ def insert_or_append(
294364
cache_entry.provider,
295365
cache_entry.model,
296366
referenced_documents_json,
367+
tool_calls_json,
368+
tool_results_json,
297369
),
298370
)
299371

src/models/cache_entry.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from typing import Optional
44
from pydantic import BaseModel
55
from models.responses import ReferencedDocument
6+
from utils.types import ToolCallSummary, ToolResultSummary
67

78

89
class CacheEntry(BaseModel):
@@ -14,6 +15,8 @@ class CacheEntry(BaseModel):
1415
provider: Provider identification
1516
model: Model identification
1617
referenced_documents: List of documents referenced by the response
18+
tool_calls: List of tool calls made during response generation
19+
tool_results: List of tool results from tool calls
1720
"""
1821

1922
query: str
@@ -23,3 +26,5 @@ class CacheEntry(BaseModel):
2326
started_at: str
2427
completed_at: str
2528
referenced_documents: Optional[list[ReferencedDocument]] = None
29+
tool_calls: Optional[list[ToolCallSummary]] = None
30+
tool_results: Optional[list[ToolResultSummary]] = None

src/utils/endpoints.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -806,6 +806,8 @@ async def cleanup_after_streaming(
806806
started_at=started_at,
807807
completed_at=completed_at,
808808
referenced_documents=referenced_documents if referenced_documents else None,
809+
tool_calls=summary.tool_calls if summary.tool_calls else None,
810+
tool_results=summary.tool_results if summary.tool_results else None,
809811
)
810812

811813
store_conversation_into_cache(

0 commit comments

Comments
 (0)