Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions cecli/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ def get_parser(default_config_files, git_root):
help="Specify Agent Mode configuration as a JSON string",
default=None,
)
group.add_argument(
"--agent-model",
metavar="AGENT_MODEL",
default=None,
help="Specify the model to use for Agent mode (default depends on --model)",
)
group.add_argument(
"--auto-save",
action=argparse.BooleanOptionalAction,
Expand Down Expand Up @@ -1111,15 +1117,21 @@ def main():
shell = sys.argv[2]
if shell not in shtab.SUPPORTED_SHELLS:
print(f"Error: Unsupported shell '{shell}'.", file=sys.stderr)
print(f"Supported shells are: {', '.join(shtab.SUPPORTED_SHELLS)}", file=sys.stderr)
print(
f"Supported shells are: {', '.join(shtab.SUPPORTED_SHELLS)}",
file=sys.stderr,
)
sys.exit(1)
parser = get_parser([], None)
parser.prog = "cecli" # Set the program name on the parser
print(shtab.complete(parser, shell=shell))
else:
print("Error: Please specify a shell for completion.", file=sys.stderr)
print(f"Usage: python {sys.argv[0]} completion <shell_name>", file=sys.stderr)
print(f"Supported shells are: {', '.join(shtab.SUPPORTED_SHELLS)}", file=sys.stderr)
print(
f"Supported shells are: {', '.join(shtab.SUPPORTED_SHELLS)}",
file=sys.stderr,
)
sys.exit(1)
else:
# Default to YAML for any other unrecognized argument, or if 'yaml' was explicitly passed
Expand Down
5 changes: 5 additions & 0 deletions cecli/coders/base_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ def get_announcements(self):
# Model
main_model = self.main_model
weak_model = main_model.weak_model
agent_model = main_model.agent_model

if weak_model is not main_model:
prefix = "Main model"
Expand Down Expand Up @@ -698,6 +699,10 @@ def get_announcements(self):
output = f"Weak model: {weak_model.name}"
lines.append(output)

if agent_model is not main_model:
output = f"Agent model: {agent_model.name}"
lines.append(output)

# Repo
if self.repo:
rel_repo_dir = self.repo.get_rel_repo_dir()
Expand Down
3 changes: 3 additions & 0 deletions cecli/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .add import AddCommand
from .agent import AgentCommand
from .agent_model import AgentModelCommand
from .architect import ArchitectCommand
from .ask import AskCommand
from .clear import ClearCommand
Expand Down Expand Up @@ -77,6 +78,7 @@
# Register commands
CommandRegistry.register(AddCommand)
CommandRegistry.register(AgentCommand)
CommandRegistry.register(AgentModelCommand)
CommandRegistry.register(ArchitectCommand)
CommandRegistry.register(AskCommand)
CommandRegistry.register(ClearCommand)
Expand Down Expand Up @@ -136,6 +138,7 @@
__all__ = [
"AddCommand",
"AgentCommand",
"AgentModelCommand",
"ArchitectCommand",
"AskCommand",
"BaseCommand",
Expand Down
143 changes: 143 additions & 0 deletions cecli/commands/agent_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from typing import List

import cecli.models as models
from cecli.commands.utils.base_command import BaseCommand
from cecli.commands.utils.helpers import format_command_result
from cecli.helpers.conversation import ConversationManager, MessageTag


class AgentModelCommand(BaseCommand):
    """Slash command that shows, switches, or temporarily borrows the Agent model."""

    NORM_NAME = "agent-model"
    DESCRIPTION = "Switch the Agent Model to a new LLM"

    @classmethod
    async def execute(cls, io, coder, args, **kwargs):
        """Execute the agent-model command with given parameters.

        Three forms are supported, selected by the contents of ``args``:

        * empty: print the current agent model and return the formatted result;
        * ``<model-name>``: build a new main ``Model`` that keeps the current
          main/editor/weak models but uses ``<model-name>`` as agent model, then
          raise ``SwitchCoderSignal`` so the caller switches to it;
        * ``<model-name> <prompt>``: run ``<prompt>`` once on a temporary coder
          using that model, then raise ``SwitchCoderSignal`` back to the
          original model configuration.

        Args:
            io: I/O object used for user-facing output and errors.
            coder: The currently active coder.
            args: Raw argument string typed after the command name.
            **kwargs: Optional flags; ``verbose`` makes the temporary coder show
                its announcements before running a one-off prompt.

        Raises:
            SwitchCoderSignal: Whenever a model name was supplied.
        """
        # Strip before splitting so stray leading/trailing whitespace is not
        # mistaken for "no model name" or for an empty one-off prompt.
        model_name, _, message = args.strip().partition(" ")
        message = message.strip()

        if not model_name:
            # If no model name provided, show current agent model
            current_agent_model = coder.main_model.agent_model.name
            io.tool_output(f"Current agent model: {current_agent_model}")
            return format_command_result(
                io, "agent-model", f"Displayed current agent model: {current_agent_model}"
            )

        # Kept as a function-level import, as in the original code
        # (presumably to avoid a circular import with cecli.commands).
        from cecli.commands import SwitchCoderSignal

        # Create a new model with the same main, editor and weak models, but updated agent model
        model = models.Model(
            coder.main_model.name,
            editor_model=coder.main_model.editor_model.name,
            weak_model=coder.main_model.weak_model.name,
            agent_model=model_name,
            io=io,
            retries=coder.main_model.retries,
            debug=coder.main_model.debug,
        )
        await models.sanity_check_models(io, model)

        if not message:
            raise SwitchCoderSignal(main_model=model, edit_format=coder.edit_format)

        await cls._run_single_prompt(
            io, coder, model, message, verbose=kwargs.get("verbose", False)
        )

    @classmethod
    async def _run_single_prompt(cls, io, coder, model, message, verbose=False):
        """Run ``message`` once on a temporary coder built around ``model``.

        Always finishes by raising ``SwitchCoderSignal`` for the original main
        model and edit format, unless the temporary coder itself raised a
        ``SwitchCoderSignal``, which is propagated untouched.
        """
        # Bound before the try block so the except clauses below can always
        # reference SwitchCoderSignal, even when generate() fails.
        from cecli.coders import Coder
        from cecli.commands import SwitchCoderSignal

        # Store the original model configuration
        original_main_model = coder.main_model
        original_edit_format = coder.edit_format

        # Create a temporary coder with the new model and an empty conversation
        temp_coder = await Coder.create(
            io=io,
            from_coder=coder,
            main_model=model,
            edit_format=coder.edit_format,  # Keep the same edit format
            suggest_shell_commands=False,
            total_cost=coder.total_cost,
            num_cache_warming_pings=0,
            summarize_from_coder=False,
            done_messages=[],
            cur_messages=[],
        )

        # Re-initialize ConversationManager with temp coder
        ConversationManager.initialize(
            temp_coder,
            reset=True,
            reformat=True,
            preserve_tags=[MessageTag.DONE, MessageTag.CUR],
        )

        if verbose:
            temp_coder.show_announcements()

        try:
            await temp_coder.generate(user_message=message, preproc=False)
            # Carry cost and commit bookkeeping back to the long-lived coder.
            coder.total_cost = temp_coder.total_cost
            coder.coder_commit_hashes = temp_coder.coder_commit_hashes
        except SwitchCoderSignal:
            # A switch requested by the temporary coder takes precedence.
            raise
        except Exception as e:
            # Report the failure, then fall through to restore the original model.
            io.tool_error(str(e))

        # Point the manager back at the original coder on success and on failure.
        ConversationManager.initialize(
            coder,
            reset=True,
            reformat=True,
            preserve_tags=[MessageTag.DONE, MessageTag.CUR],
        )

        # Restore the original model configuration
        raise SwitchCoderSignal(main_model=original_main_model, edit_format=original_edit_format)

    @classmethod
    def get_completions(cls, io, coder, args) -> List[str]:
        """Get completion options for agent-model command."""
        return models.get_chat_model_names()

    @classmethod
    def get_help(cls) -> str:
        """Get help text for the agent-model command."""
        help_text = super().get_help()
        help_text += "\nUsage:\n"
        help_text += " /agent-model <model-name> # Switch to a new agent model\n"
        help_text += (
            " /agent-model <model-name> <prompt> # Use a specific agent model for a single"
            " prompt\n"
        )
        help_text += "\nExamples:\n"
        help_text += (
            " /agent-model gpt-4o-mini # Switch to GPT-4o Mini as agent model\n"
        )
        help_text += (
            " /agent-model claude-3-haiku # Switch to Claude 3 Haiku as agent model\n"
        )
        help_text += ' /agent-model o1-mini "review this code" # Use o1-mini to review code\n'
        help_text += (
            "\nWhen switching agent models, the main model and editor model remain unchanged.\n"
        )
        help_text += (
            "\nIf you provide a prompt after the model name, that agent model will be used\n"
        )
        help_text += "just for that prompt, then you'll return to your original agent model.\n"
        return help_text
1 change: 1 addition & 0 deletions cecli/commands/editor_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def execute(cls, io, coder, args, **kwargs):
coder.main_model.name,
editor_model=model_name,
weak_model=coder.main_model.weak_model.name,
agent_model=coder.main_model.agent_model.name,
io=io,
retries=coder.main_model.retries,
debug=coder.main_model.debug,
Expand Down
1 change: 1 addition & 0 deletions cecli/commands/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ async def execute(cls, io, coder, args, **kwargs):
("Main model", coder.main_model),
("Editor model", getattr(coder.main_model, "editor_model", None)),
("Weak model", getattr(coder.main_model, "weak_model", None)),
("Agent model", getattr(coder.main_model, "agent_model", None)),
]
for label, model in active_models:
if not model:
Expand Down
14 changes: 14 additions & 0 deletions cecli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ def apply_model_overrides(model_name):
main_model_name, main_model_overrides = apply_model_overrides(args.model)
weak_model_name, weak_model_overrides = apply_model_overrides(args.weak_model)
editor_model_name, editor_model_overrides = apply_model_overrides(args.editor_model)
agent_model_name, agent_model_overrides = apply_model_overrides(args.agent_model)
weak_model_obj = None
if weak_model_name:
weak_model_obj = models.Model(
Expand All @@ -848,6 +849,18 @@ def apply_model_overrides(model_name):
retries=args.retries,
debug=args.debug,
)
agent_model_obj = None
if agent_model_name:
agent_model_obj = models.Model(
agent_model_name,
agent_model=False,
verbose=args.verbose,
io=io,
override_kwargs=agent_model_overrides,
retries=args.retries,
debug=args.debug,
)

if main_model_name.startswith("openrouter/") and not os.environ.get("OPENROUTER_API_KEY"):
io.tool_warning(
f"The specified model '{main_model_name}' requires an OpenRouter API key, which was not"
Expand All @@ -873,6 +886,7 @@ def apply_model_overrides(model_name):
main_model_name,
weak_model=weak_model_obj,
editor_model=editor_model_obj,
agent_model=agent_model_obj,
editor_edit_format=args.editor_edit_format,
verbose=args.verbose,
io=io,
Expand Down
28 changes: 28 additions & 0 deletions cecli/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class ModelSettings:
name: str
edit_format: str = "diff"
weak_model_name: Optional[str] = None
agent_model_name: Optional[str] = None
use_repo_map: bool = False
send_undo_reply: bool = False
lazy: bool = False
Expand Down Expand Up @@ -314,6 +315,7 @@ def __init__(
model,
weak_model=None,
editor_model=None,
agent_model=None,
editor_edit_format=None,
verbose=False,
io=None,
Expand Down Expand Up @@ -341,6 +343,7 @@ def __init__(
self.max_chat_history_tokens = 1024
self.weak_model = None
self.editor_model = None
self.agent_model = None
self.extra_model_settings = next(
(ms for ms in MODEL_SETTINGS if ms.name == "cecli/extra_params"), None
)
Expand All @@ -354,6 +357,7 @@ def __init__(
self.configure_model_settings(model)
self._apply_provider_defaults()
self.get_weak_model(weak_model)
self.get_agent_model(agent_model)
self.retries = retries
self.debug = debug

Expand Down Expand Up @@ -580,6 +584,30 @@ def get_weak_model(self, provided_weak_model):
self.weak_model = Model(self.weak_model_name, weak_model=False, io=self.io)
return self.weak_model

def get_agent_model(self, provided_weak_model):
    """Resolve, store and return the model used for Agent mode.

    Sets ``self.agent_model`` (and, where applicable, ``self.agent_model_name``).

    Args:
        provided_weak_model: The requested agent model. Despite its name this
            is the *agent* model argument; the name is kept so existing callers
            keep working. Accepted values:

            * ``False``: use this model itself as the agent model;
            * a ``Model`` instance: used as-is;
            * a non-empty string: a model name overriding ``self.agent_model_name``;
            * any other falsy value: fall back to the existing
              ``self.agent_model_name``, or to this model when it is unset.

    Returns:
        The resolved agent model, on every path (also stored on
        ``self.agent_model``).
    """
    requested = provided_weak_model

    # An explicit opt-out and clipboard transport both make this model its own
    # agent model. ``is False`` is checked first so the transport attribute is
    # only read when needed.
    if requested is False or self.copy_paste_transport == "clipboard":
        self.agent_model = self
        self.agent_model_name = None
        return self.agent_model

    if isinstance(requested, Model):
        self.agent_model = requested
        self.agent_model_name = requested.name
        return self.agent_model

    if requested:
        self.agent_model_name = requested

    # No agent model configured, or it names this very model: reuse self.
    if not self.agent_model_name or self.agent_model_name == self.name:
        self.agent_model = self
        return self.agent_model

    # agent_model=False stops the new model from resolving an agent model of its own.
    self.agent_model = Model(self.agent_model_name, agent_model=False, io=self.io)
    return self.agent_model

def commit_message_models(self):
return [self.weak_model, self]

Expand Down
1 change: 1 addition & 0 deletions cecli/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def _build_session_data(self, session_name) -> Dict:
"model": self.coder.main_model.name,
"weak_model": self.coder.main_model.weak_model.name,
"editor_model": self.coder.main_model.editor_model.name,
"agent_model": self.coder.main_model.agent_model.name,
"editor_edit_format": self.coder.main_model.editor_edit_format,
"edit_format": self.coder.edit_format,
"chat_history": {
Expand Down