|
31 | 31 |
|
32 | 32 | import httpx |
33 | 33 | from litellm import experimental_mcp_client |
34 | | -from litellm.types.utils import ModelResponse |
| 34 | +from litellm.types.utils import ChatCompletionMessageToolCall, Function, ModelResponse |
35 | 35 | from prompt_toolkit.patch_stdout import patch_stdout |
36 | 36 | from rich.console import Console |
37 | 37 |
|
|
64 | 64 | from cecli.sessions import SessionManager |
65 | 65 | from cecli.tools.utils.output import print_tool_response |
66 | 66 | from cecli.tools.utils.registry import ToolRegistry |
67 | | -from cecli.utils import format_tokens, is_image_file |
| 67 | +from cecli.utils import copy_tool_call, format_tokens, is_image_file |
68 | 68 |
|
69 | 69 | from ..dump import dump # noqa: F401 |
70 | 70 | from ..prompts.utils.registry import PromptObject, PromptRegistry |
@@ -2357,23 +2357,19 @@ async def send_message(self, inp): |
2357 | 2357 | return |
2358 | 2358 |
|
2359 | 2359 | async def process_tool_calls(self, tool_call_response): |
2360 | | - if tool_call_response is None: |
2361 | | - return False |
2362 | | - |
2363 | | - # Handle different response structures |
2364 | | - try: |
2365 | | - # Try to get tool calls from the standard OpenAI response format |
2366 | | - if hasattr(tool_call_response, "choices") and tool_call_response.choices: |
2367 | | - message = tool_call_response.choices[0].message |
2368 | | - if hasattr(message, "tool_calls") and message.tool_calls: |
2369 | | - original_tool_calls = message.tool_calls |
2370 | | - else: |
2371 | | - return False |
2372 | | - else: |
2373 | | - # Handle other response formats |
2374 | | - return False |
2375 | | - except (AttributeError, IndexError): |
2376 | | - return False |
| 2360 | + # Use partial_response_tool_calls if available (populated by consolidate_chunks) |
| 2361 | + # otherwise try to extract from tool_call_response |
| 2362 | + original_tool_calls = [] |
| 2363 | + if self.partial_response_tool_calls: |
| 2364 | + original_tool_calls = self.partial_response_tool_calls |
| 2365 | + elif tool_call_response is not None: |
| 2366 | + try: |
| 2367 | + if hasattr(tool_call_response, "choices") and tool_call_response.choices: |
| 2368 | + message = tool_call_response.choices[0].message |
| 2369 | + if hasattr(message, "tool_calls") and message.tool_calls: |
| 2370 | + original_tool_calls = message.tool_calls |
| 2371 | + except (AttributeError, IndexError): |
| 2372 | + pass |
2377 | 2373 |
|
2378 | 2374 | if not original_tool_calls: |
2379 | 2375 | return False |
@@ -2404,10 +2400,13 @@ async def process_tool_calls(self, tool_call_response): |
2404 | 2400 | continue |
2405 | 2401 |
|
2406 | 2402 | # Create a new tool call for each JSON chunk, with a unique ID. |
2407 | | - new_function = tool_call.function.model_copy(update={"arguments": chunk}) |
2408 | | - new_tool_call = tool_call.model_copy( |
2409 | | - update={"id": f"{tool_call.id}-{i}", "function": new_function} |
2410 | | - ) |
| 2403 | + new_tool_call = copy_tool_call(tool_call) |
| 2404 | + if hasattr(new_tool_call, "model_copy"): |
| 2405 | + new_tool_call.function.arguments = chunk |
| 2406 | + new_tool_call.id = f"{tool_call.id}-{i}" |
| 2407 | + else: |
| 2408 | + new_tool_call.function.arguments = chunk |
| 2409 | + new_tool_call.id = f"{getattr(tool_call, 'id', 'call')}-{i}" |
2411 | 2410 | expanded_tool_calls.append(new_tool_call) |
2412 | 2411 |
|
2413 | 2412 | # Collect all tool calls grouped by server |
@@ -2551,7 +2550,7 @@ async def _exec_server_tools(server, tool_calls_list): |
2551 | 2550 |
|
2552 | 2551 | all_results_content = [] |
2553 | 2552 | for args in parsed_args_list: |
2554 | | - new_tool_call = tool_call.model_copy(deep=True) |
| 2553 | + new_tool_call = copy_tool_call(tool_call) |
2555 | 2554 | new_tool_call.function.arguments = json.dumps(args) |
2556 | 2555 |
|
2557 | 2556 | call_result = await experimental_mcp_client.call_openai_tool( |
@@ -2806,6 +2805,7 @@ def add_assistant_reply_to_cur_messages(self): |
2806 | 2805 | ConversationManager.add_message( |
2807 | 2806 | message_dict=msg, |
2808 | 2807 | tag=MessageTag.CUR, |
| 2808 | + hash_key=("assistant_message", str(msg), str(time.monotonic_ns())), |
2809 | 2809 | ) |
2810 | 2810 |
|
2811 | 2811 | def get_file_mentions(self, content, ignore_current=False): |
@@ -3202,22 +3202,9 @@ def consolidate_chunks(self): |
3202 | 3202 | # Add provider-specific fields directly to the tool call object |
3203 | 3203 | tool_call.provider_specific_fields = provider_specific_fields_by_index[i] |
3204 | 3204 |
|
3205 | | - # Create dictionary version with provider-specific fields |
3206 | | - tool_call_dict = tool_call.model_dump() |
3207 | | - |
3208 | | - # Add provider-specific fields to the dictionary too (in case model_dump() doesn't include them) |
3209 | | - if tool_id in provider_specific_fields_by_id: |
3210 | | - tool_call_dict["provider_specific_fields"] = provider_specific_fields_by_id[ |
3211 | | - tool_id |
3212 | | - ] |
3213 | | - elif i in provider_specific_fields_by_index: |
3214 | | - tool_call_dict["provider_specific_fields"] = ( |
3215 | | - provider_specific_fields_by_index[i] |
3216 | | - ) |
3217 | | - |
3218 | 3205 | # Only append to partial_response_tool_calls if it's empty |
3219 | 3206 | if len(self.partial_response_tool_calls) == 0: |
3220 | | - self.partial_response_tool_calls.append(tool_call_dict) |
| 3207 | + self.partial_response_tool_calls.append(tool_call) |
3221 | 3208 |
|
3222 | 3209 | self.partial_response_function_call = ( |
3223 | 3210 | response.choices[0].message.tool_calls[0].function |
@@ -3253,6 +3240,70 @@ def consolidate_chunks(self): |
3253 | 3240 | except AttributeError as e: |
3254 | 3241 | content_err = e |
3255 | 3242 |
|
| 3243 | + # If no native tool calls, check if the content contains JSON tool calls |
| 3244 | + # This handles models that write JSON in text instead of using native calling |
| 3245 | + if not self.partial_response_tool_calls and self.partial_response_content: |
| 3246 | + try: |
| 3247 | + # Simple extraction of JSON-like structures that look like tool calls |
| 3248 | + # Only look for tool calls if it looks like JSON |
| 3249 | + if "{" in self.partial_response_content or "[" in self.partial_response_content: |
| 3250 | + json_chunks = utils.split_concatenated_json(self.partial_response_content) |
| 3251 | + extracted_calls = [] |
| 3252 | + chunk_index = 0 |
| 3253 | + |
| 3254 | + for chunk in json_chunks: |
| 3255 | + chunk_index += 1 |
| 3256 | + try: |
| 3257 | + json_obj = json.loads(chunk) |
| 3258 | + if ( |
| 3259 | + isinstance(json_obj, dict) |
| 3260 | + and "name" in json_obj |
| 3261 | + and "arguments" in json_obj |
| 3262 | + ): |
| 3263 | + # Create a Pydantic model for the tool call |
| 3264 | + function_obj = Function( |
| 3265 | + name=json_obj["name"], |
| 3266 | + arguments=( |
| 3267 | + json.dumps(json_obj["arguments"]) |
| 3268 | + if isinstance(json_obj["arguments"], (dict, list)) |
| 3269 | + else str(json_obj["arguments"]) |
| 3270 | + ), |
| 3271 | + ) |
| 3272 | + tool_call_obj = ChatCompletionMessageToolCall( |
| 3273 | + type="function", |
| 3274 | + function=function_obj, |
| 3275 | + id=f"call_{len(extracted_calls)}_{int(time.time())}_{chunk_index}", |
| 3276 | + ) |
| 3277 | + extracted_calls.append(tool_call_obj) |
| 3278 | + elif isinstance(json_obj, list): |
| 3279 | + for item in json_obj: |
| 3280 | + if ( |
| 3281 | + isinstance(item, dict) |
| 3282 | + and "name" in item |
| 3283 | + and "arguments" in item |
| 3284 | + ): |
| 3285 | + function_obj = Function( |
| 3286 | + name=item["name"], |
| 3287 | + arguments=( |
| 3288 | + json.dumps(item["arguments"]) |
| 3289 | + if isinstance(item["arguments"], (dict, list)) |
| 3290 | + else str(item["arguments"]) |
| 3291 | + ), |
| 3292 | + ) |
| 3293 | + tool_call_obj = ChatCompletionMessageToolCall( |
| 3294 | + type="function", |
| 3295 | + function=function_obj, |
| 3296 | + id=f"call_{len(extracted_calls)}_{int(time.time())}_{chunk_index}", |
| 3297 | + ) |
| 3298 | + extracted_calls.append(tool_call_obj) |
| 3299 | + except json.JSONDecodeError: |
| 3300 | + continue |
| 3301 | + |
| 3302 | + if extracted_calls: |
| 3303 | + self.partial_response_tool_calls = extracted_calls |
| 3304 | + except Exception: |
| 3305 | + pass |
| 3306 | + |
3256 | 3307 | return response, func_err, content_err |
3257 | 3308 |
|
3258 | 3309 | def stream_wrapper(self, content, final): |
@@ -3298,13 +3349,19 @@ def preprocess_response(self): |
3298 | 3349 | tool_list = [] |
3299 | 3350 | tool_id_set = set() |
3300 | 3351 |
|
3301 | | - for tool_call_dict in self.partial_response_tool_calls: |
| 3352 | + for tool_call in self.partial_response_tool_calls: |
| 3353 | + # Handle both dictionary and object tool calls |
| 3354 | + if isinstance(tool_call, dict): |
| 3355 | + tool_id = tool_call.get("id") |
| 3356 | + else: |
| 3357 | + tool_id = getattr(tool_call, "id", None) |
| 3358 | + |
3302 | 3359 | # LLM APIs sometimes return duplicates and that's annoying part 2 |
3303 | | - if tool_call_dict.get("id") in tool_id_set: |
| 3360 | + if tool_id in tool_id_set: |
3304 | 3361 | continue |
3305 | 3362 |
|
3306 | | - tool_id_set.add(tool_call_dict.get("id")) |
3307 | | - tool_list.append(tool_call_dict) |
| 3363 | + tool_id_set.add(tool_id) |
| 3364 | + tool_list.append(tool_call) |
3308 | 3365 |
|
3309 | 3366 | self.partial_response_tool_calls = tool_list |
3310 | 3367 |
|
|
0 commit comments