diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index f523adbe4e1..5c0a119b9cf 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -95,6 +95,15 @@ def __init__(self, *args, **kwargs): ToolRegistry.build_registry(agent_config=self.agent_config) super().__init__(*args, **kwargs) + async def send(self, messages, model=None, functions=None, tools=None): + if not model: + if self.main_model.agent_model and self.main_model.agent_model is not self.main_model: + model = self.main_model.agent_model + else: + model = self.main_model + async for chunk in super().send(messages, model, functions, tools): + yield chunk + def _setup_agent(self): os.makedirs(".cecli/workspace", exist_ok=True) diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index c77c0e6c202..b0290470b9d 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -2146,6 +2146,11 @@ async def check_tokens(self, messages): return False return True + def get_active_model_name(self): + if self.edit_format == "agent" and self.main_model.agent_model: + return self.main_model.agent_model.name + return self.main_model.name + async def send_message(self, inp): # Notify IO that LLM processing is starting self.io.llm_started() @@ -2175,7 +2180,8 @@ async def send_message(self, inp): self.multi_response_content = "" if self.show_pretty(): spinner_text = ( - f"Waiting for {self.main_model.name} • ${self.format_cost(self.total_cost)} session" + f"Waiting for {self.get_active_model_name()} •" + f" ${self.format_cost(self.total_cost)} session" ) self.io.start_spinner(spinner_text) diff --git a/cecli/tui/widgets/footer.py b/cecli/tui/widgets/footer.py index 1a4934c4b42..75200c30c04 100644 --- a/cecli/tui/widgets/footer.py +++ b/cecli/tui/widgets/footer.py @@ -63,7 +63,7 @@ def _get_display_model(self) -> str: if not self.model_name: return "" # Strip common prefixes like "openrouter/x-ai/" - name = self.app.worker.coder.main_model.name + name = self.app.worker.coder.get_active_model_name() if len(name) > 40: if "/" in name: name = name.split("/")[-1]