diff --git a/cecli/args.py b/cecli/args.py index 60f04324259..c0984a79edb 100644 --- a/cecli/args.py +++ b/cecli/args.py @@ -297,6 +297,12 @@ def get_parser(default_config_files, git_root): help="Specify Agent Mode configuration as a JSON string", default=None, ) + group.add_argument( + "--agent-model", + metavar="AGENT_MODEL", + default=None, + help="Specify the model to use for Agent mode (default depends on --model)", + ) group.add_argument( "--auto-save", action=argparse.BooleanOptionalAction, @@ -1111,7 +1117,10 @@ def main(): shell = sys.argv[2] if shell not in shtab.SUPPORTED_SHELLS: print(f"Error: Unsupported shell '{shell}'.", file=sys.stderr) - print(f"Supported shells are: {', '.join(shtab.SUPPORTED_SHELLS)}", file=sys.stderr) + print( + f"Supported shells are: {', '.join(shtab.SUPPORTED_SHELLS)}", + file=sys.stderr, + ) sys.exit(1) parser = get_parser([], None) parser.prog = "cecli" # Set the program name on the parser @@ -1119,7 +1128,10 @@ def main(): else: print("Error: Please specify a shell for completion.", file=sys.stderr) print(f"Usage: python {sys.argv[0]} completion ", file=sys.stderr) - print(f"Supported shells are: {', '.join(shtab.SUPPORTED_SHELLS)}", file=sys.stderr) + print( + f"Supported shells are: {', '.join(shtab.SUPPORTED_SHELLS)}", + file=sys.stderr, + ) sys.exit(1) else: # Default to YAML for any other unrecognized argument, or if 'yaml' was explicitly passed diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 734fb2f2c65..66e49e96770 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -660,6 +660,7 @@ def get_announcements(self): # Model main_model = self.main_model weak_model = main_model.weak_model + agent_model = main_model.agent_model if weak_model is not main_model: prefix = "Main model" @@ -698,6 +699,10 @@ def get_announcements(self): output = f"Weak model: {weak_model.name}" lines.append(output) + if agent_model is not main_model: + output = f"Agent model: {agent_model.name}" + lines.append(output) + # Repo if self.repo: rel_repo_dir = self.repo.get_rel_repo_dir() diff --git a/cecli/commands/__init__.py b/cecli/commands/__init__.py index ef7dff9dad5..528e18cecb8 100644 --- a/cecli/commands/__init__.py +++ b/cecli/commands/__init__.py @@ -7,6 +7,7 @@ from .add import AddCommand from .agent import AgentCommand +from .agent_model import AgentModelCommand from .architect import ArchitectCommand from .ask import AskCommand from .clear import ClearCommand @@ -77,6 +78,7 @@ # Register commands CommandRegistry.register(AddCommand) CommandRegistry.register(AgentCommand) +CommandRegistry.register(AgentModelCommand) CommandRegistry.register(ArchitectCommand) CommandRegistry.register(AskCommand) CommandRegistry.register(ClearCommand) @@ -136,6 +138,7 @@ __all__ = [ "AddCommand", "AgentCommand", + "AgentModelCommand", "ArchitectCommand", "AskCommand", "BaseCommand", diff --git a/cecli/commands/agent_model.py b/cecli/commands/agent_model.py new file mode 100644 index 00000000000..64d1e4a807c --- /dev/null +++ b/cecli/commands/agent_model.py @@ -0,0 +1,143 @@ +from typing import List + +import cecli.models as models +from cecli.commands.utils.base_command import BaseCommand +from cecli.commands.utils.helpers import format_command_result +from cecli.helpers.conversation import ConversationManager, MessageTag + + +class AgentModelCommand(BaseCommand): + NORM_NAME = "agent-model" + DESCRIPTION = "Switch the Agent Model to a new LLM" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the agent-model command with given parameters.""" + arg_split = args.split(" ", 1) + model_name = arg_split[0].strip() + if not model_name: + # If no model name provided, show current agent model + current_agent_model = coder.main_model.agent_model.name + io.tool_output(f"Current agent model: {current_agent_model}") + return format_command_result( + io, "agent-model", f"Displayed current agent model: {current_agent_model}" + ) + + # Create a new model with the same main model and editor model, but updated agent model + model = models.Model( + coder.main_model.name, + editor_model=coder.main_model.editor_model.name, + weak_model=coder.main_model.weak_model.name, + agent_model=model_name, + io=io, + retries=coder.main_model.retries, + debug=coder.main_model.debug, + ) + await models.sanity_check_models(io, model) + + if len(arg_split) > 1: + # implement architect coder-like generation call for agent model + message = arg_split[1].strip() + + # Store the original model configuration + original_main_model = coder.main_model + original_edit_format = coder.edit_format + + # Create a temporary coder with the new model + from cecli.coders import Coder + + kwargs = dict() + kwargs["main_model"] = model + kwargs["edit_format"] = coder.edit_format # Keep the same edit format + kwargs["suggest_shell_commands"] = False + kwargs["total_cost"] = coder.total_cost + kwargs["num_cache_warming_pings"] = 0 + kwargs["summarize_from_coder"] = False + kwargs["done_messages"] = [] + kwargs["cur_messages"] = [] + + new_kwargs = dict(io=io, from_coder=coder) + new_kwargs.update(kwargs) + + # Save current conversation state + original_coder = coder + + temp_coder = await Coder.create(**new_kwargs) + + # Re-initialize ConversationManager with temp coder + ConversationManager.initialize( + temp_coder, + reset=True, + reformat=True, + preserve_tags=[MessageTag.DONE, MessageTag.CUR], + ) + + verbose = kwargs.get("verbose", False) + if verbose: + temp_coder.show_announcements() + + try: + await temp_coder.generate(user_message=message, preproc=False) + coder.total_cost = temp_coder.total_cost + coder.coder_commit_hashes = temp_coder.coder_commit_hashes + + # Clear manager and restore original state + ConversationManager.initialize( + original_coder, + reset=True, + reformat=True, + preserve_tags=[MessageTag.DONE, MessageTag.CUR], + ) + + # Restore the original model configuration + from cecli.commands import SwitchCoderSignal + + raise SwitchCoderSignal( + main_model=original_main_model, edit_format=original_edit_format + ) + except Exception as e: + # If there's an error, still restore the original model + if not isinstance(e, SwitchCoderSignal): + io.tool_error(str(e)) + raise SwitchCoderSignal( + main_model=original_main_model, edit_format=original_edit_format + ) + else: + # Re-raise SwitchCoderSignal if that's what was thrown + raise + else: + from cecli.commands import SwitchCoderSignal + + raise SwitchCoderSignal(main_model=model, edit_format=coder.edit_format) + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for agent-model command.""" + return models.get_chat_model_names() + + @classmethod + def get_help(cls) -> str: + """Get help text for the agent-model command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /agent-model # Switch to a new agent model\n" + help_text += ( + " /agent-model # Use a specific agent model for a single" + " prompt\n" + ) + help_text += "\nExamples:\n" + help_text += ( + " /agent-model gpt-4o-mini # Switch to GPT-4o Mini as agent model\n" + ) + help_text += ( + " /agent-model claude-3-haiku # Switch to Claude 3 Haiku as agent model\n" + ) + help_text += ' /agent-model o1-mini "review this code" # Use o1-mini to review code\n' + help_text += ( + "\nWhen switching agent models, the main model and editor model remain unchanged.\n" + ) + help_text += ( + "\nIf you provide a prompt after the model name, that agent model will be used\n" + ) + help_text += "just for that prompt, then you'll return to your original agent model.\n" + return help_text diff --git a/cecli/commands/editor_model.py b/cecli/commands/editor_model.py index de3b581cb2d..646604f9463 100644 --- a/cecli/commands/editor_model.py +++ b/cecli/commands/editor_model.py @@ -28,6 +28,7 @@ async def execute(cls, io, coder, args, **kwargs): coder.main_model.name, editor_model=model_name, weak_model=coder.main_model.weak_model.name, + agent_model=coder.main_model.agent_model.name, io=io, retries=coder.main_model.retries, debug=coder.main_model.debug, diff --git a/cecli/commands/settings.py b/cecli/commands/settings.py index 2dd7f6010ab..864db4686bf 100644 --- a/cecli/commands/settings.py +++ b/cecli/commands/settings.py @@ -30,6 +30,7 @@ async def execute(cls, io, coder, args, **kwargs): ("Main model", coder.main_model), ("Editor model", getattr(coder.main_model, "editor_model", None)), ("Weak model", getattr(coder.main_model, "weak_model", None)), + ("Agent model", getattr(coder.main_model, "agent_model", None)), ] for label, model in active_models: if not model: diff --git a/cecli/main.py b/cecli/main.py index 33310f01d09..c3519a931a6 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -826,6 +826,7 @@ def apply_model_overrides(model_name): main_model_name, main_model_overrides = apply_model_overrides(args.model) weak_model_name, weak_model_overrides = apply_model_overrides(args.weak_model) editor_model_name, editor_model_overrides = apply_model_overrides(args.editor_model) + agent_model_name, agent_model_overrides = apply_model_overrides(args.agent_model) weak_model_obj = None if weak_model_name: weak_model_obj = models.Model( @@ -848,6 +849,18 @@ def apply_model_overrides(model_name): retries=args.retries, debug=args.debug, ) + agent_model_obj = None + if agent_model_name: + agent_model_obj = models.Model( + agent_model_name, + agent_model=False, + verbose=args.verbose, + io=io, + override_kwargs=agent_model_overrides, + retries=args.retries, + debug=args.debug, + ) + if main_model_name.startswith("openrouter/") and not os.environ.get("OPENROUTER_API_KEY"): io.tool_warning( f"The specified model '{main_model_name}' requires an OpenRouter API key, which was not" @@ -873,6 +886,7 @@ def apply_model_overrides(model_name): main_model_name, weak_model=weak_model_obj, editor_model=editor_model_obj, + agent_model=agent_model_obj, editor_edit_format=args.editor_edit_format, verbose=args.verbose, io=io, diff --git a/cecli/models.py b/cecli/models.py index b96ce661628..2ab173576dd 100644 --- a/cecli/models.py +++ b/cecli/models.py @@ -105,6 +105,7 @@ class ModelSettings: name: str edit_format: str = "diff" weak_model_name: Optional[str] = None + agent_model_name: Optional[str] = None use_repo_map: bool = False send_undo_reply: bool = False lazy: bool = False @@ -314,6 +315,7 @@ def __init__( model, weak_model=None, editor_model=None, + agent_model=None, editor_edit_format=None, verbose=False, io=None, @@ -341,6 +343,7 @@ def __init__( self.max_chat_history_tokens = 1024 self.weak_model = None self.editor_model = None + self.agent_model = None self.extra_model_settings = next( (ms for ms in MODEL_SETTINGS if ms.name == "cecli/extra_params"), None ) @@ -354,6 +357,7 @@ def __init__( self.configure_model_settings(model) self._apply_provider_defaults() self.get_weak_model(weak_model) + self.get_agent_model(agent_model) self.retries = retries self.debug = debug @@ -580,6 +584,30 @@ def get_weak_model(self, provided_weak_model): self.weak_model = Model(self.weak_model_name, weak_model=False, io=self.io) return self.weak_model + def get_agent_model(self, provided_weak_model): + if provided_weak_model is False: + self.agent_model = self + self.agent_model_name = None + return + if self.copy_paste_transport == "clipboard": + self.agent_model = self + self.agent_model_name = None + return + if isinstance(provided_weak_model, Model): + self.agent_model = provided_weak_model + self.agent_model_name = provided_weak_model.name + return + if provided_weak_model: + self.agent_model_name = provided_weak_model + if not self.agent_model_name: + self.agent_model = self + return + if self.agent_model_name == self.name: + self.agent_model = self + return + self.agent_model = Model(self.agent_model_name, agent_model=False, io=self.io) + return self.agent_model + def commit_message_models(self): return [self.weak_model, self] diff --git a/cecli/sessions.py b/cecli/sessions.py index 18c13d4b9a5..2c5b633db6e 100644 --- a/cecli/sessions.py +++ b/cecli/sessions.py @@ -152,6 +152,7 @@ def _build_session_data(self, session_name) -> Dict: "model": self.coder.main_model.name, "weak_model": self.coder.main_model.weak_model.name, "editor_model": self.coder.main_model.editor_model.name, + "agent_model": self.coder.main_model.agent_model.name, "editor_edit_format": self.coder.main_model.editor_edit_format, "edit_format": self.coder.edit_format, "chat_history": {