diff --git a/README.md b/README.md index 064919f5db3..05e521ee7da 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ LLMs are a part of our lives from here on out so join us in learning about and c * [TUI Configuration](https://github.com/dwash96/cecli/blob/main/cecli/website/docs/config/tui.md) * [Skills](https://github.com/dwash96/cecli/blob/main/cecli/website/docs/config/skills.md) * [Session Management](https://github.com/dwash96/cecli/blob/main/cecli/website/docs/sessions.md) +* [Hooks](https://github.com/dwash96/cecli/blob/main/cecli/website/docs/config/hooks.md) * [Custom Commands](https://github.com/dwash96/cecli/blob/main/cecli/website/docs/config/custom-commands.md) * [Custom System Prompts](https://github.com/dwash96/cecli/blob/main/cecli/website/docs/config/custom-system-prompts.md) * [Custom Tools](https://github.com/dwash96/cecli/blob/main/cecli/website/docs/config/agent-mode.md#creating-custom-tools) @@ -171,15 +172,15 @@ The current priorities are to improve core capabilities and user experience of t * [ ] Add visibility into active sub agent calls in TUI 8. **Hooks** - * [ ] Add hooks base class for user defined python hooks with an execute method with type and priority settings - * [ ] Add hook manager that can accept user defined files and command line commands - * [ ] Integrate hook manager with coder classes with hooks for `start`, `on_message`, `end_message`, `pre_tool`, and `post_tool` + * [x] Add hooks base class for user defined python hooks with an execute method with type and priority settings + * [x] Add hook manager that can accept user defined files and command line commands + * [x] Integrate hook manager with coder classes with hooks for `start`, `end`, `on_message`, `end_message`, `pre_tool`, and `post_tool` 9. 
**Efficient File Editing** - * [ ] Explore use of hashline file representation for more targeted file editing - * [ ] Assuming viability, update SEARCH part of SEARCH/REPLACE with hashline identification - * [ ] Update agent mode edit tools to work with hashline identification - * [ ] Update internal file diff representation to support hashline propagation + * [x] Explore use of hashline file representation for more targeted file editing + * [x] Assuming viability, update SEARCH part of SEARCH/REPLACE with hashline identification (Done with new edit format) + * [x] Update agent mode edit tools to work with hashline identification + * [x] Update internal file diff representation to support hashline propagation 10. **Dynamic Context Management** * [ ] Update compaction to use observational memory sub agent calls to generate decision records that are used as the compaction basis diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 5cb857ce121..f052cbffbdb 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -9,6 +9,7 @@ import shutil import subprocess import sys +import tarfile import time import traceback from collections import defaultdict @@ -174,6 +175,12 @@ def main( stats: bool = typer.Option( False, "--stats", help="Generate statistics YAML file from benchmark results" ), + aggregate: Optional[str] = typer.Option( + None, "--aggregate", help="Aggregate results from directories matching this pattern" + ), + tar: Optional[str] = typer.Option( + None, "--tar", help="Create a tar.gz of directories matching this pattern" + ), ): # setup logging and verbosity if quiet: @@ -185,19 +192,19 @@ def main( logging.basicConfig(level=log_level, format="%(message)s") + # Convert SimpleNamespace to dict for YAML serialization + def simple_namespace_to_dict(obj): + if isinstance(obj, SimpleNamespace): + return {k: simple_namespace_to_dict(v) for k, v in vars(obj).items()} + elif isinstance(obj, dict): + return {k: simple_namespace_to_dict(v) for k, 
v in obj.items()} + elif isinstance(obj, list): + return [simple_namespace_to_dict(item) for item in obj] + else: + return obj + # Handle --stats flag: generate statistics YAML file and exit if stats: - # Convert SimpleNamespace to dict for YAML serialization - def simple_namespace_to_dict(obj): - if isinstance(obj, SimpleNamespace): - return {k: simple_namespace_to_dict(v) for k, v in vars(obj).items()} - elif isinstance(obj, dict): - return {k: simple_namespace_to_dict(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [simple_namespace_to_dict(item) for item in obj] - else: - return obj - # Get statistics stats_result = summarize_results(results_dir, verbose, stats_languages=languages) @@ -212,6 +219,50 @@ def simple_namespace_to_dict(obj): print(f"Statistics written to: {results_yaml_path}") return 0 + if aggregate: + # Find matching directories + matching_dirs = [d for d in BENCHMARK_DNAME.iterdir() if d.is_dir() and aggregate in d.name] + + if not matching_dirs: + print(f"No directories matching '{aggregate}' found in {BENCHMARK_DNAME}") + return 1 + + all_results = {} + for d in matching_dirs: + stats_result = summarize_results(d, verbose, stats_languages=languages) + if stats_result: + # Remove timestamp from directory name for the key + key = re.sub(r"^\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}--", "", d.name) + all_results[key] = simple_namespace_to_dict(stats_result) + + # Sort the results by the keys (directory names without timestamps) + all_results = dict(sorted(all_results.items())) + + output_path = BENCHMARK_DNAME / f"all_results_{aggregate}.yml" + with open(output_path, "w") as f: + yaml.dump(all_results, f, default_flow_style=False) + + print(f"Aggregated results written to: {output_path}") + return 0 + + if tar: + # Find matching directories + matching_dirs = [d for d in BENCHMARK_DNAME.iterdir() if d.is_dir() and tar in d.name] + + if not matching_dirs: + print(f"No directories matching '{tar}' found in {BENCHMARK_DNAME}") + return 1 + 
+ output_path = BENCHMARK_DNAME / f"benchmarks_{tar}.tar.gz" + print(f"Creating tarball: {output_path}") + with tarfile.open(output_path, "w:gz") as tar_handle: + for d in matching_dirs: + print(f" Adding {d.name}...") + tar_handle.add(d, arcname=d.name) + + print(f"Tarball created at: {output_path}") + return 0 + from cecli import models if dry: @@ -424,7 +475,7 @@ def get_exercise_dirs(base_dir, languages=None, sets=None, hash_re=None, legacy= all_results = run_test_threaded.gather(tqdm=True) else: all_results = [] - for test_path in test_dnames: + for test_path in sorted(test_dnames): results = run_test(original_dname, results_dir / test_path, **test_args) all_results.append(results) summarize_results(results_dir, verbose) @@ -454,6 +505,21 @@ def load_results(results_dir, stats_languages=None): # BUG20251223 logger.debug(f"Processing result file: {fname}") + # Check if test case counts are missing and compute them if needed + test_dir = fname.parent + + logger.debug(f"Computing test case counts for {test_dir}") + total_cases, passed_cases = get_test_case_counts_from_directory(test_dir) + + if total_cases is not None: + results["test_cases_total"] = total_cases + if passed_cases is not None: + results["test_cases_passed"] = passed_cases + + # Update the JSON file immediately to keep data fresh + fname.write_text(json.dumps(results, indent=4)) + logger.debug(f"Updated {fname} with test case counts") + # Try to get language from cat.yaml if it exists in the same dir lang = "unknown" cat_yaml = fname.parent / "cat.yaml" @@ -515,6 +581,8 @@ def summarize_results(results_dir, verbose, stats_languages=None): res.lazy_comments = 0 res.prompt_tokens = 0 res.completion_tokens = 0 + res.test_cases_total = 0 + res.test_cases_passed = 0 res.reasoning_effort = None res.thinking_tokens = None @@ -551,6 +619,8 @@ def add(attr_name, increment, global_stats, lang_stats): lang_stats.lazy_comments = 0 lang_stats.prompt_tokens = 0 lang_stats.completion_tokens = 0 + 
lang_stats.test_cases_total = 0 + lang_stats.test_cases_passed = 0 lang_to_stats[lang] = lang_stats lang_to_passed_tests[lang] = [0] * tries @@ -604,6 +674,14 @@ def add(attr_name, increment, global_stats, lang_stats): lang_stats, ) + # Collect test case statistics from pre-computed results + total_cases = results.get("test_cases_total") + passed_cases = results.get("test_cases_passed") + if total_cases is not None: + add("test_cases_total", total_cases, res, lang_stats) + if passed_cases is not None: + add("test_cases_passed", passed_cases, res, lang_stats) + res.reasoning_effort = results.get("reasoning_effort") res.thinking_tokens = results.get("thinking_tokens") res.map_tokens = results.get("map_tokens") @@ -679,6 +757,13 @@ def show(stat, red="red"): show("test_timeouts") print(f" total_tests: {res.total_tests}") + # Add test case statistics and percentage + if res.test_cases_total > 0: + print(f" test_cases_total: {res.test_cases_total}") + print(f" test_cases_passed: {res.test_cases_passed}") + res.test_cases_percentage = 100 * res.test_cases_passed / res.test_cases_total + print(f" test_cases_percentage: {res.test_cases_percentage:.1f}") + if variants["model"]: a_model = set(variants["model"]).pop() command = f"cecli --model {a_model}" @@ -1317,5 +1402,267 @@ def cleanup_test_output(output, testdir): return res +def parse_test_results_from_history(history_file_path, total_cases=None, test_dir=None): + """ + Parse .cecli.dev.history.md file to extract test results. + Returns a tuple of (test_cases_total, test_cases_passed) or (None, None) if not found. 
+ """ + try: + # First, check if we can get test results from .cecli.results.json + if test_dir: + results_json_path = Path(test_dir) / ".cecli.results.json" + if results_json_path.exists(): + try: + import json + + with open(results_json_path, "r") as f: + results_data = json.load(f) + + # Check if tests_outcomes contains true (indicating all tests passed) + tests_outcomes = results_data.get("tests_outcomes", []) + if True in tests_outcomes: + # All tests passed at some point + # Try to get test_cases_total from results or use provided total_cases + total_from_results = results_data.get("test_cases_total") + if total_from_results is not None: + logger.debug( + "All tests passed according to results.json, returning" + f" total_cases={total_from_results}" + ) + return total_from_results, total_from_results + elif total_cases is not None: + logger.debug( + "All tests passed according to results.json, using provided" + f" total_cases={total_cases}" + ) + return total_cases, total_cases + else: + logger.debug("All tests passed but no total cases available") + except Exception as e: + logger.debug(f"Failed to parse .cecli.results.json: {e}") + + # Determine language from cat.yaml if test_dir is provided + language = None + if test_dir: + cat_yaml_path = Path(test_dir) / "cat.yaml" + if cat_yaml_path.exists(): + try: + import yaml + + with open(cat_yaml_path, "r") as f: + metadata = yaml.safe_load(f) + language = metadata.get("language") + logger.debug(f"Detected language from cat.yaml: {language}") + except Exception as e: + logger.debug(f"Failed to read cat.yaml: {e}") + with open(history_file_path, "r") as f: + content = f.read() + # Find test output after the last "Tokens:" line + # Token lines look like: "Tokens: 4.6k sent, 5.8k received. Cost: $0.07 message, $0.07 session." 
+ # Everything after the last token line is test output + import re + + lines = content.split("\n") + token_line_index = -1 + + # Find the last "Tokens:" line + for i, line in enumerate(lines): + if "Tokens:" in line: + token_line_index = i + + if token_line_index != -1 and token_line_index < len(lines) - 1: + # Capture everything after the last token line + test_output = "\n".join(lines[token_line_index + 1 :]) + else: + # No token line found, try to find test output in the last code block as fallback + test_output = "" + in_test_output = False + + for line in reversed(lines): + if line.strip().startswith("```") and not in_test_output: + # Found the end of a code block, start capturing + in_test_output = True + continue + elif line.strip().startswith("```") and in_test_output: + # Found the start of the code block, stop capturing + break + elif in_test_output: + test_output = line + "\n" + test_output + + if not test_output: + return None, None + # Parse test output based on detected language or fallback to pattern matching + passed_cases = 0 + + # Use language-specific parsing if language is known + if language: + language_lower = language.lower() + + # Python (pytest) format + if language_lower == "python": + # Example: "25 passed" or "24 passed, 1 failed" + python_match = re.search(r"(\d+)\s+passed", test_output) + if python_match: + passed_cases = int(python_match.group(1)) + # Try to get total from "collected X items" or "X passed; Y failed" + collected_match = re.search(r"collected\s+(\d+)\s+items", test_output) + if collected_match: + total_cases = int(collected_match.group(1)) + else: + # Try to get total from "X passed; Y failed" + failed_match = re.search(r"(\d+)\s+failed", test_output) + if failed_match: + total_cases = passed_cases + int(failed_match.group(1)) + else: + total_cases = passed_cases # Assume all passed if no failures mentioned + + # Go format + elif language_lower == "go": + # Example: "ok\t{test name}\t0.003s" for success + # Example: 
"--- FAIL:" for failures + if total_cases is not None: + # Check if test suite passed (has "ok" line) + if re.search(r"^[\t ]*ok", test_output, re.MULTILINE): + # Test suite passed - all tests passed + passed_cases = total_cases + else: + # Test suite failed - count failures and subtract from total + failed_count = test_output.count("--- FAIL:") + passed_cases = total_cases - failed_count + if passed_cases < 0: + passed_cases = 0 # Ensure non-negative + # Use the provided total_cases + total_cases = total_cases + else: + # Fallback to counting lines (legacy behavior) + # Count lines with "ok" (starting with any whitespace) as passed tests + passed_cases = len(re.findall(r"^[\t ]*ok", test_output, re.MULTILINE)) + # Count lines with "--- FAIL:" as failed tests + failed_count = test_output.count("--- FAIL:") + total_cases = passed_cases + failed_count + + # Rust format + elif language_lower == "rust": + # Example: "test result: ok. 10 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out" + # Need to match the LAST instance of the pattern in the output + # Use findall to get all matches, then take the last one + rust_matches = re.findall( + r"test result:.*?(\d+)\s+passed.*?(\d+)\s+failed", test_output, re.DOTALL + ) + if rust_matches: + # Get the last match (most recent test result) + last_match = rust_matches[-1] + # last_match is a tuple like ("15", "2") from the example + passed_cases = int(last_match[0]) + failed_cases = int(last_match[1]) + total_cases = passed_cases + failed_cases + + # JavaScript format + elif language_lower in ["javascript", "typescript"]: + # Example: "Tests: 11 failed, 4 passed, 15 total" + js_match = re.search( + r"Tests:\s*(\d+)\s+failed,\s*(\d+)\s+passed,\s*(\d+)\s+total", test_output + ) + if js_match: + failed_cases = int(js_match.group(1)) + passed_cases = int(js_match.group(2)) + total_cases = int(js_match.group(3)) + + # Java format + elif language_lower == "java": + # Look for "BUILD SUCCESSFUL" and count "PASSED" lines + if 
"BUILD SUCCESSFUL" in test_output: + passed_cases = test_output.count("PASSED") + failed_cases = test_output.count("FAILED") + total_cases = passed_cases + failed_cases + elif "FAILURE:" in test_output: + # Also look for "X tests completed, Y failed" pattern + java_match = re.search( + r"(\d+)\s+tests\s+completed,\s*(\d+)\s+failed", test_output + ) + + if java_match: + total_cases = int(java_match.group(1)) + failed_cases = int(java_match.group(2)) + passed_cases = total_cases - failed_cases + + # Check for build failures + if "[build failed]" in test_output: + # Build failure means 0 passed tests + return total_cases if total_cases else 0, 0 + + # C/C++ format + elif language_lower in ["c", "cpp", "c++"]: + # Look for test summary patterns + # Example: "[ PASSED ] 25 tests." + cpp_match = re.search(r"\[\s*PASSED\s*\]\s*(\d+)\s+tests", test_output) + if cpp_match: + passed_cases = int(cpp_match.group(1)) + # Try to get total from "[ FAILED ] X tests." + cpp_failed_match = re.search(r"\[\s*FAILED\s*\]\s*(\d+)\s+tests", test_output) + if cpp_failed_match: + failed_cases = int(cpp_failed_match.group(1)) + total_cases = passed_cases + failed_cases + else: + total_cases = passed_cases + + # If we found any test results, return them + if total_cases is not None and total_cases > 0: + return total_cases, passed_cases + elif passed_cases > 0: + return passed_cases, passed_cases # Assume all passed if total not found + + return None, None + + except Exception as e: + logger.debug(f"Failed to parse test results from {history_file_path}: {e}") + return None, None + + +def get_test_case_counts_from_directory(test_dir): + """ + Extract test case counts from a benchmark test directory. + Returns a tuple of (test_cases_total, test_cases_passed) or (None, None) if not found. 
+ """ + test_dir = Path(test_dir) + + # Check for tests.toml file + tests_toml_path = test_dir / ".meta" / "tests.toml" + if not tests_toml_path.exists(): + return None, None + + # Parse tests.toml to get total test cases + total_cases = 0 + try: + with open(tests_toml_path, "r") as f: + content = f.read() + + # Count test entries by looking for UUID patterns in brackets + import re + + uuid_pattern = r"\[[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\]" + test_entries = re.findall(uuid_pattern, content) + + # Count include = false entries to exclude them + include_false = content.count("include = false") + + total_cases = len(test_entries) - include_false + except Exception as e: + logger.debug(f"Failed to parse tests.toml from {tests_toml_path}: {e}") + return None, None + + # Check for .cecli.dev.history.md file to get passed test counts + history_file_path = test_dir / ".cecli.dev.history.md" + if not history_file_path.exists(): + return total_cases, None + + # Parse test results from history file + total_cases, passed_cases = parse_test_results_from_history( + history_file_path, total_cases, test_dir + ) + + return total_cases, passed_cases + + if __name__ == "__main__": app() diff --git a/benchmark/primary_variations.sh b/benchmark/primary_variations.sh index 31b8a653b6d..d3484333a4a 100755 --- a/benchmark/primary_variations.sh +++ b/benchmark/primary_variations.sh @@ -5,13 +5,13 @@ set -e # Exit on error # Default values -BASE_NAME="cecli-base-hashline-9" -EDIT_FORMAT="hashline" +BASE_NAME="cecli-base-d-big-3" +EDIT_FORMAT="diff" MAP_TOKENS="512" THREADS="1" LANGUAGES="javascript,python,rust,go,java" -HASH_RE="^[15]" -NUM_TESTS="32" +HASH_RE="^.[15ef]" +NUM_TESTS="72" EXERCISES_DIR="polyglot-benchmark" OUTPUT_DIR="tmp.benchmarks" SLEEP_BETWEEN=30 # Seconds to sleep between runs @@ -21,15 +21,14 @@ SLEEP_BETWEEN=30 # Seconds to sleep between runs # "openrouter/minimax/minimax-m2.1" # "openrouter/qwen/qwen3-vl-235b-a22b-thinking" MODELS=( +# 
"openrouter/deepseek/deepseek-v3.2-exp" + "openrouter/moonshotai/kimi-k2.5" + "openrouter/openai/gpt-oss-120b" + "openrouter/openai/gpt-5.2" "openrouter/google/gemini-3-flash-preview" - "openrouter/deepseek/deepseek-v3.2-exp" -# "openrouter/moonshotai/kimi-k2.5" -# "openrouter/openai/gpt-oss-120b" -# "openrouter/openai/gpt-5.2" -# "openrouter/google/gemini-3-flash-preview" -# "openrouter/google/gemini-3-pro-preview" -# "openrouter/anthropic/claude-haiku-4.5" -# "openrouter/anthropic/claude-sonnet-4.5" + "openrouter/google/gemini-3-pro-preview" + "openrouter/anthropic/claude-haiku-4.5" + "openrouter/anthropic/claude-sonnet-4.5" ) # Parse command line arguments @@ -111,6 +110,7 @@ run_benchmark() { # Create the benchmark command ./benchmark/benchmark.py "$run_name" \ + --new \ --model "$model" \ --edit-format "$EDIT_FORMAT" \ --map-tokens "$MAP_TOKENS" \ diff --git a/cecli/args.py b/cecli/args.py index 03efe64c8b5..7c1de5b0315 100644 --- a/cecli/args.py +++ b/cecli/args.py @@ -310,6 +310,12 @@ def get_parser(default_config_files, git_root): help="Specify Agent Mode configuration as a JSON string", default=None, ) + group.add_argument( + "--hooks", + metavar="HOOKS_CONFIG_JSON", + help="Specify hooks configuration as a JSON string", + default=None, + ) group.add_argument( "--agent-model", metavar="AGENT_MODEL", diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index f6f2583eaa8..dd4ebdea870 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -4,6 +4,7 @@ import locale import os import platform +import random import time import traceback from collections import Counter, defaultdict @@ -25,6 +26,7 @@ normalize_vector, ) from cecli.helpers.skills import SkillsManager +from cecli.hooks import HookIntegration from cecli.llm import litellm from cecli.mcp import LocalServer, McpServerManager from cecli.tools.utils.registry import ToolRegistry @@ -50,6 +52,8 @@ def __init__(self, *args, **kwargs): self.tool_similarity_threshold = 
0.99 self.max_tool_vector_history = 10 self.read_tools = { + "command", + "commandinteractive", "viewfilesatglob", "viewfilesmatching", "ls", @@ -60,8 +64,6 @@ def __init__(self, *args, **kwargs): "thinking", } self.write_tools = { - "command", - "commandinteractive", "deletetext", "indenttext", "inserttext", @@ -75,6 +77,7 @@ def __init__(self, *args, **kwargs): self.args = kwargs.get("args") self.files_added_in_exploration = set() self.tool_call_count = 0 + self.turn_count = 0 self.max_reflections = 15 self.use_enhanced_context = True self._last_edited_file = None @@ -239,6 +242,17 @@ async def _execute_local_tool_calls(self, tool_calls_list): try: args_string = tool_call.function.arguments.strip() parsed_args_list = [] + + if not await HookIntegration.call_pre_tool_hooks(self, tool_name, args_string): + tool_responses.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "Tool Request Aborted.", + } + ) + continue + if args_string: json_chunks = utils.split_concatenated_json(args_string) for chunk in json_chunks: @@ -291,6 +305,19 @@ async def _execute_local_tool_calls(self, tool_calls_list): if tasks: task_results = await asyncio.gather(*tasks) all_results_content.extend(str(res) for res in task_results) + + if not await HookIntegration.call_post_tool_hooks( + self, tool_name, args_string, "\n\n".join(all_results_content) + ): + tool_responses.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "Tool Response Redacted.", + } + ) + continue + result_message = "\n\n".join(all_results_content) except Exception as e: result_message = f"Error executing {tool_name}: {e}" @@ -527,6 +554,7 @@ def format_chat_chunks(self): ConversationChunks.add_readonly_files_messages(self) ConversationChunks.add_chat_files_messages(self) + ConversationChunks.add_file_context_messages(self) # Add post-message context blocks (priority 250 - between CUR and REMINDER) ConversationChunks.add_post_message_context_blocks(self) @@ -687,6 +715,8 @@ 
async def process_tool_calls(self, tool_call_response): await self.auto_save_session() self.last_round_tools = [] + self.turn_count += 1 + if self.partial_response_tool_calls: for tool_call in self.partial_response_tool_calls: tool_name = getattr(tool_call.function, "name", None) @@ -931,12 +961,18 @@ def _generate_tool_context(self, repetitive_tools): """ if not self.tool_usage_history: return "" + + if not hasattr(self, "_last_repetitive_warning_turn"): + self._last_repetitive_warning_turn = 0 + self._last_repetitive_warning_severity = 0 + context_parts = [''] context_parts.append("## Turn and Tool Call Statistics") - context_parts.append(f"- Current turn: {self.num_reflections + 1}") + context_parts.append(f"- Current turn: {self.turn_count + 1}") context_parts.append(f"- Total tool calls this turn: {self.num_tool_calls}") context_parts.append("\n\n") context_parts.append("## Recent Tool Usage History") + if len(self.tool_usage_history) > 10: recent_history = self.tool_usage_history[-10:] context_parts.append("(Showing last 10 tools)") @@ -944,39 +980,75 @@ def _generate_tool_context(self, repetitive_tools): recent_history = self.tool_usage_history for i, tool in enumerate(recent_history, 1): context_parts.append(f"{i}. {tool}") + context_parts.append("\n\n") - if repetitive_tools and len(self.tool_usage_history) >= 8: - context_parts.append("""**Instruction:** -You have used the following tool(s) repeatedly:""") - context_parts.append("### DO NOT USE THE FOLLOWING TOOLS/FUNCTIONS") - for tool in repetitive_tools: - context_parts.append(f"- `{tool}`") - context_parts.append( - "Your exploration appears to be stuck in a loop. Please try a different approach." - " Use the `Thinking` tool to clarify your intentions and new approach to what you" - " are currently attempting to accomplish." 
- ) - context_parts.append("\n") - context_parts.append("**Suggestions for alternative approaches:**") - context_parts.append( - "- If you've been searching for files, try working with the files already in" - " context" - ) - context_parts.append( - "- If you've been viewing files, try making actual edits to move forward" - ) - context_parts.append("- Consider using different tools that you haven't used recently") - context_parts.append( - "- Focus on making concrete progress rather than gathering more information" - ) - context_parts.append( - "- Use the files you've already discovered to implement the requested changes" - ) - context_parts.append("\n") - context_parts.append( - "You most likely have enough context for a subset of the necessary changes." - ) - context_parts.append("Please prioritize file editing over further exploration.") + if repetitive_tools: + if self.turn_count - self._last_repetitive_warning_turn > 2: + self._last_repetitive_warning_turn = self.turn_count + self._last_repetitive_warning_severity += 1 + + repetition_warning = f""" +## Repetition Detected: Strategy Adjustment Required +I have detected repetitive usage of the following tools: {', '.join([f'`{t}`' for t in repetitive_tools])}. +**Constraint:** Do not repeat the exact same parameters for these tools in your next turn. 
+ """ + + if self._last_repetitive_warning_severity > 2: + self._last_repetitive_warning_severity = 0 + + fruit = random.choice( + [ + "an apple", + "a banana", + "a cantaloupe", + "a cherry", + "a honeydew", + "an orange", + "a mango", + "a pomegranate", + "a watermelon", + ] + ) + animal = random.choice( + [ + "a bird", + "a bear", + "a cat", + "a deer", + "a dog", + "an elephant", + "a fish", + "a fox", + "a monkey", + "a rabbit", + ] + ) + verb = random.choice( + [ + "absorbing", + "becoming", + "creating", + "dreaming of", + "eating", + "fighting with", + "playing with", + "painting", + "smashing", + "writing a song about", + ] + ) + + repetition_warning += f""" +### CRITICAL: Execution Loop Detected +You are currently "spinning." To break the logic trap, you must: +1. **Analyze**: Use the `Thinking` tool to summarize exactly what you have found so far and why you were stuck. +2. **Pivot**: Abandon or modify your current exploration strategy. Try focusing on different files or running tests. +3. **Reframe**: To ensure your logic reset, include a 2-sentence story about {animal} {verb} {fruit} in your thoughts. + +Prioritize editing or verification over further exploration. + """ + + context_parts.append(repetition_warning) context_parts.append("") return "\n".join(context_parts) @@ -1061,6 +1133,7 @@ async def preproc_user_input(self, inp): inp = f'\n{inp}\n' self.agent_finished = False + self.turn_count = 0 return inp def get_directory_structure(self): @@ -1139,11 +1212,11 @@ def print_tree(node, prefix="- ", indent=" ", current_path=""): def get_todo_list(self): """ - Generate a todo list context block from the .cecli/todo.txt file. + Generate a todo list context block from the todo.txt file. Returns formatted string with the current todo list or None if empty/not present. 
""" try: - todo_file_path = ".cecli/todo.txt" + todo_file_path = self.local_agent_folder("todo.txt") abs_path = self.abs_root_path(todo_file_path) import os diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 7ca2288c45b..77ee3e7ff74 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -16,7 +16,7 @@ import traceback import weakref from collections import defaultdict -from datetime import datetime +from datetime import date, datetime # Optional dependency: used to convert locale codes (eg ``en_US``) # into human-readable language names (eg ``English``). @@ -28,6 +28,7 @@ from pathlib import Path from typing import List from urllib.parse import urlparse +from uuid import uuid4 as generate_unique_id import httpx from litellm import experimental_mcp_client @@ -47,6 +48,7 @@ ) from cecli.helpers.profiler import TokenProfiler from cecli.history import ChatSummary +from cecli.hooks import HookIntegration from cecli.io import ConfirmGroup, InputOutput from cecli.linter import Linter from cecli.llm import litellm @@ -69,6 +71,8 @@ from ..dump import dump # noqa: F401 from ..prompts.utils.registry import PromptObject, PromptRegistry +GLOBAL_DATE = date.today().isoformat() + class UnknownEditFormat(ValueError): def __init__(self, edit_format, valid_formats): @@ -150,6 +154,9 @@ class Coder: compact_context_completed = True suppress_announcements_for_next_prompt = False tool_reflection = False + last_user_message = "" + uuid = "" + # Task coordination state variables input_running = False output_running = False @@ -239,6 +246,7 @@ async def create( total_tokens_received=from_coder.total_tokens_received, file_watcher=from_coder.file_watcher, mcp_manager=from_coder.mcp_manager, + uuid=from_coder.uuid, ) use_kwargs.update(update) # override to complete the switch use_kwargs.update(kwargs) # override passed kwargs @@ -334,8 +342,13 @@ def __init__( repomap_in_memory=False, linear_output=False, security_config=None, + uuid="", ): # 
initialize from args.map_cache_dir + self.uuid = generate_unique_id() + if uuid: + self.uuid = uuid + self.map_cache_dir = map_cache_dir self.chat_language = chat_language @@ -479,7 +492,7 @@ def __init__( self.io.tool_warning(f"Skipping {fname} that matches gitignore spec.") continue - if self.repo and self.repo.ignored_file(fname): + if self.repo and self.repo.ignored_file(fname) and not self.add_gitignore_files: self.io.tool_warning(f"Skipping {fname} that matches cecli.ignore spec.") continue @@ -569,7 +582,7 @@ def __init__( self.test_cmd = test_cmd # Clean up todo list file on startup; sessions will restore it when needed - todo_file_path = ".cecli/todo.txt" + todo_file_path = self.local_agent_folder("todo.txt") abs_path = self.abs_root_path(todo_file_path) if os.path.isfile(abs_path): try: @@ -1517,7 +1530,7 @@ async def generate(self, user_message, preproc): await asyncio.sleep(0.1) try: - if not self.enable_context_compaction: + if self.enable_context_compaction: self.compact_context_completed = False await self.compact_context_if_needed() self.compact_context_completed = True @@ -1578,6 +1591,10 @@ async def preproc_user_input(self, inp): async def run_one(self, user_message, preproc): self.init_before_message() + if not await HookIntegration.call_start_hooks(self): + self.io.tool_warning("Execution stopped by start hook") + return + if preproc: message = await self.preproc_user_input(user_message) else: @@ -1588,6 +1605,9 @@ async def run_one(self, user_message, preproc): ): return + if not self.commands.is_command(user_message): + self.last_user_message = user_message + while True: self.reflected_message = None self.tool_reflection = False @@ -1601,7 +1621,7 @@ async def run_one(self, user_message, preproc): if self.num_reflections >= self.max_reflections: self.io.tool_warning(f"Only {self.max_reflections} reflections allowed, stopping.") - return + break self.num_reflections += 1 @@ -1622,6 +1642,10 @@ async def run_one(self, user_message, preproc): 
await self.auto_save_session(force=True) + if not await HookIntegration.call_end_hooks(self): + self.io.tool_warning("Execution stopped by end hook") + return + def _is_url_allowed(self, url): allowed_domains = self.security_config.get("allowed-domains") if not allowed_domains: @@ -1781,9 +1805,12 @@ async def compact_context_if_needed(self, force=False, message=""): ConversationManager.clear_tag(MessageTag.CUR) # Keep the first message (user's initial input) if it exists - if cur_messages: + if self.last_user_message: ConversationManager.add_message( - message_dict=cur_messages[0], + message_dict={ + "role": "user", + "content": self.last_user_message, + }, tag=MessageTag.CUR, ) @@ -1819,8 +1846,10 @@ async def compact_context_if_needed(self, force=False, message=""): self.io.tool_output("...chat history compacted.") self.io.update_spinner(self.io.last_spinner_text) - # Clear all diff messages + # Clear all diff and file context messages ConversationManager.clear_tag(MessageTag.DIFFS) + ConversationManager.clear_tag(MessageTag.FILE_CONTEXTS) + # Reset ConversationFiles cache entirely from cecli.helpers.conversation.files import ConversationFiles @@ -2254,7 +2283,7 @@ async def send_message(self, inp): self.io.tool_output() self.show_usage_report() - self.add_assistant_reply_to_cur_messages() + await self.add_assistant_reply_to_cur_messages() if exhausted: cur_messages = ConversationManager.get_messages_dict(MessageTag.CUR) @@ -2576,6 +2605,13 @@ async def _exec_server_tools(server, tool_calls_list): new_tool_call = copy_tool_call(tool_call) new_tool_call.function.arguments = json.dumps(args) + if not await HookIntegration.call_pre_tool_hooks( + self, new_tool_call.function.name, args + ): + self.io.tool_warning("Tool call skipped by pre-tool call hook") + all_results_content.append("Tool Request Aborted.") + continue + call_result = await experimental_mcp_client.call_openai_tool( session=session, openai_tool=new_tool_call, @@ -2608,6 +2644,16 @@ async def 
_exec_server_tools(server, tool_calls_list): content_parts.append(item.text) result_text = "".join(content_parts) + + if not await HookIntegration.call_post_tool_hooks( + self, new_tool_call.function.name, args, result_text + ): + self.io.tool_warning( + "Tool call output skipped by post-tool call hook" + ) + all_results_content.append("Tool Response Redacted.") + continue + all_results_content.append(result_text) tool_responses.append( @@ -2783,7 +2829,7 @@ def __del__(self): """Cleanup when the Coder object is destroyed.""" self.ok_to_warm_cache = False - def add_assistant_reply_to_cur_messages(self): + async def add_assistant_reply_to_cur_messages(self): """ Add the assistant's reply to `cur_messages`. Handles model-specific quirks, like Deepseek which requires `content` @@ -2825,6 +2871,10 @@ def add_assistant_reply_to_cur_messages(self): or msg.get("tool_calls", None) or msg.get("function_call", None) ): + if not await HookIntegration.call_end_message_hooks(self, str(msg)): + self.io.tool_warning("Execution stopped by end message hook") + return + ConversationManager.add_message( message_dict=msg, tag=MessageTag.CUR, @@ -2952,7 +3002,7 @@ async def send(self, messages, model=None, functions=None, tools=None): async for chunk in self.show_send_output_stream(completion): yield chunk else: - self.show_send_output(completion) + await self.show_send_output(completion) response, func_err, content_err = self.consolidate_chunks() @@ -2982,7 +3032,7 @@ async def send(self, messages, model=None, functions=None, tools=None): if args: self.io.ai_output(json.dumps(args, indent=4)) - def show_send_output(self, completion): + async def show_send_output(self, completion): if self.verbose: print(completion) @@ -2998,6 +3048,10 @@ def show_send_output(self, completion): response, func_err, content_err = self.consolidate_chunks() + if not await HookIntegration.call_on_message_hooks(self, self.partial_response_content): + self.io.tool_warning("Execution stopped by on message 
hook") + return + resp_hash = dict( function_call=str(self.partial_response_function_call), content=self.partial_response_content, @@ -3163,6 +3217,10 @@ async def show_send_output_stream(self, completion): # The Part Doing the Heavy Lifting Now self.consolidate_chunks() + if not await HookIntegration.call_on_message_hooks(self, self.partial_response_content): + self.io.tool_warning("Execution stopped by on message hook") + return + if not received_content and len(self.partial_response_tool_calls) == 0: self.io.tool_warning("Empty response received from LLM. Check your provider account?") @@ -3753,6 +3811,7 @@ async def apply_updates(self): return edited except ANY_GIT_ERROR as err: + self.io.tool_error(traceback.format_exc()) self.io.tool_error(str(err)) return edited except Exception as err: @@ -3891,6 +3950,12 @@ def apply_edits(self, edits): def apply_edits_dry_run(self, edits): return edits + def local_agent_folder(self, path): + os.makedirs(f".cecli/agents/{GLOBAL_DATE}/{self.uuid}", exist_ok=True) + + stripped = path.lstrip("/") + return f".cecli/agents/{GLOBAL_DATE}/{self.uuid}/{stripped}" + async def auto_save_session(self, force=False): """Automatically save the current session to {auto-save-session-name}.json.""" if not getattr(self.args, "auto_save", False): @@ -3953,7 +4018,11 @@ async def run_shell_commands(self): self.commands.cmd_running_event.set() # Command finished async def handle_shell_commands(self, commands_str, group): - commands = command_parser.split_shell_commands(commands_str) + commands = [ + cmd + for cmd in command_parser.split_shell_commands(commands_str) + if cmd and not (isinstance(cmd, str) and cmd.startswith("#")) + ] # Early return if none of the command strings have length after stripping whitespace if not any(cmd.strip() for cmd in commands): diff --git a/cecli/coders/copypaste_coder.py b/cecli/coders/copypaste_coder.py index 1ca00540f31..65dd32d5a1f 100644 --- a/cecli/coders/copypaste_coder.py +++ 
b/cecli/coders/copypaste_coder.py @@ -146,7 +146,7 @@ async def send(self, messages, model=None, functions=None, tools=None): try: hash_object, completion = self.copy_paste_completion(messages, model) self.chat_completion_call_hashes.append(hash_object.hexdigest()) - self.show_send_output(completion) + await self.show_send_output(completion) self.calculate_and_show_tokens_and_cost(messages, completion) finally: self.preprocess_response() diff --git a/cecli/coders/hashline_coder.py b/cecli/coders/hashline_coder.py index 17d459e1baa..dd048a79f37 100644 --- a/cecli/coders/hashline_coder.py +++ b/cecli/coders/hashline_coder.py @@ -67,8 +67,8 @@ def apply_edits(self, edits, dry_run=False): # Validate operation if operation in ["replace", "insert", "delete"]: # Validate hashline format - if (isinstance(start_hash, str) and "|" in start_hash) and ( - operation == "insert" or (isinstance(end_hash, str) and "|" in end_hash) + if isinstance(start_hash, str) and ( + operation == "insert" or isinstance(end_hash, str) ): if path not in hashline_edits_by_file: hashline_edits_by_file[path] = [] @@ -116,36 +116,35 @@ def apply_edits(self, edits, dry_run=False): full_path = self.abs_root_path(path) new_content = None - if Path(full_path).exists(): - try: + try: + # Read existing content or use empty string for new files + if Path(full_path).exists(): content = self.io.read_text(full_path) - # Apply all hashline operations for this file in batch - new_content, _, _ = apply_hashline_operations( - original_content=strip_hashline(content), - operations=operations, - ) - - if dry_run: - # For dry runs, preserve the original edit format - updated_edits.extend(original_hashline_edits_by_file[path]) - else: - updated_edits.append((path, operations, "Batch hashline operations")) + else: + content = "" - if new_content: - if not dry_run: - self.io.write_text(full_path, new_content) - passed.append((path, operations, "Batch hashline operations")) - else: - # No changes or failed - 
failed.append((path, operations, "Batch hashline operations")) + # Apply all hashline operations for this file in batch + new_content, _, _ = apply_hashline_operations( + original_content=strip_hashline(content), + operations=operations, + ) - except (ValueError, HashlineError) as e: - # Record failure - failed.append((path, operations, f"Hashline batch operation failed: {e}")) - continue - else: - # File doesn't exist - failed.append((path, operations, "File not found")) + if dry_run: + # For dry runs, preserve the original edit format + updated_edits.extend(original_hashline_edits_by_file[path]) + else: + updated_edits.append((path, operations, "Batch hashline operations")) + + # Write the file if operation was successful (no exception) + # new_content could be empty string for empty file creation + if not dry_run: + self.io.write_text(full_path, new_content) + passed.append((path, operations, "Batch hashline operations")) + + except (ValueError, HashlineError) as e: + # Record failure + failed.append((path, operations, f"Hashline batch operation failed: {e}")) + continue # Process regular edits one by one (existing logic) for edit in regular_edits: @@ -156,9 +155,9 @@ def apply_edits(self, edits, dry_run=False): if not isinstance(original, str) or not isinstance(updated, str): continue - if Path(full_path).exists(): - content = self.io.read_text(full_path) - new_content = do_replace(full_path, content, original, updated, self.fence) + # Always try do_replace, it handles new file creation + content = self.io.read_text(full_path) if Path(full_path).exists() else "" + new_content = do_replace(full_path, content, original, updated, self.fence) # If the edit failed, and # this is not a "create a new file" with an empty original... 
@@ -177,7 +176,7 @@ def apply_edits(self, edits, dry_run=False): updated_edits.append((path, original, updated)) - if new_content: + if new_content is not None: if not dry_run: self.io.write_text(full_path, new_content) passed.append(edit) @@ -216,20 +215,6 @@ def apply_edits(self, edits, dry_run=False): {did_you_mean} {self.fence[1]} -""" - if updated in content and updated: - res += f"""Are you sure you need this LOCATE/CONTENTS block? -The CONTENTS lines are already in {path}! - -""" - did_you_mean = find_similar_lines(original, content) - if did_you_mean: - res += f"""Did you mean to match some of these actual lines from {path}? - -{self.fence[0]} -{did_you_mean} -{self.fence[1]} - """ if updated in content and updated: @@ -238,9 +223,9 @@ def apply_edits(self, edits, dry_run=False): """ res += ( - "The search section must be a valid JSON array in the format:\n" + "The LOCATE section must be a valid JSON array in the format:\n" '["{start hashline}", "{end hashline}", "{operation}"]\n' - "Hashline prefixes must have the structure `{line_num}|{hash_fragment}` (e.g., `20|Bv`)" + "Hashline prefixes must have the structure `{line_num}{hash_fragment}` (e.g., `20Bv`)" " and match one found directly in the file" ) if passed: @@ -670,13 +655,12 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE, valid_fnames=None) if isinstance(parsed, list) and len(parsed) == 3: # Validate the format: all strings if all(isinstance(item, str) for item in parsed): - # Check if first two items look like hashline format (e.g., "1|ab") - if all("|" in item for item in parsed[:2]): - # Check if operation is valid - if parsed[2] in ["replace", "insert", "delete"]: - # This is a hashline JSON block - yield filename, parsed, updated_text_str - continue + # Check if first two items look like hashline format (e.g., "1ab") + + if parsed[2] in ["replace", "insert", "delete"]: + # This is a hashline JSON block + yield filename, parsed, updated_text_str + continue except 
(json.JSONDecodeError, ValueError): # Not a valid JSON, treat as regular edit block pass @@ -759,6 +743,10 @@ def find_similar_lines(search_lines, content_lines, threshold=0.6): search_lines = search_lines.splitlines() content_lines = content_lines.splitlines() + # Handle empty search lines + if not search_lines: + return "" + best_ratio = 0 best_match = None diff --git a/cecli/coders/single_wholefile_func_coder.py b/cecli/coders/single_wholefile_func_coder.py index 802d714536b..6690d311b62 100644 --- a/cecli/coders/single_wholefile_func_coder.py +++ b/cecli/coders/single_wholefile_func_coder.py @@ -38,7 +38,7 @@ class SingleWholeFileFunctionCoder(Coder): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def add_assistant_reply_to_cur_messages(self, edited): + async def add_assistant_reply_to_cur_messages(self, edited): if edited: # Always add to conversation manager ConversationManager.add_message( diff --git a/cecli/coders/wholefile_func_coder.py b/cecli/coders/wholefile_func_coder.py index f1871480975..2c9668a52dd 100644 --- a/cecli/coders/wholefile_func_coder.py +++ b/cecli/coders/wholefile_func_coder.py @@ -49,7 +49,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def add_assistant_reply_to_cur_messages(self, edited): + async def add_assistant_reply_to_cur_messages(self, edited): if edited: # Always add to conversation manager ConversationManager.add_message( diff --git a/cecli/commands/__init__.py b/cecli/commands/__init__.py index 3f438898fde..9f214034011 100644 --- a/cecli/commands/__init__.py +++ b/cecli/commands/__init__.py @@ -30,9 +30,11 @@ from .hashline import HashlineCommand from .help import HelpCommand from .history_search import HistorySearchCommand +from .hooks import HooksCommand from .lint import LintCommand from .list_sessions import ListSessionsCommand from .load import LoadCommand +from .load_hook import LoadHookCommand from .load_mcp import LoadMcpCommand from .load_session import 
LoadSessionCommand from .load_skill import LoadSkillCommand @@ -47,6 +49,7 @@ from .read_only import ReadOnlyCommand from .read_only_stub import ReadOnlyStubCommand from .reasoning_effort import ReasoningEffortCommand +from .remove_hook import RemoveHookCommand from .remove_mcp import RemoveMcpCommand from .remove_skill import RemoveSkillCommand from .report import ReportCommand @@ -102,9 +105,11 @@ CommandRegistry.register(HashlineCommand) CommandRegistry.register(HelpCommand) CommandRegistry.register(HistorySearchCommand) +CommandRegistry.register(HooksCommand) CommandRegistry.register(LintCommand) CommandRegistry.register(ListSessionsCommand) CommandRegistry.register(LoadCommand) +CommandRegistry.register(LoadHookCommand) CommandRegistry.register(LoadMcpCommand) CommandRegistry.register(LoadSessionCommand) CommandRegistry.register(LoadSkillCommand) @@ -119,6 +124,7 @@ CommandRegistry.register(ReadOnlyCommand) CommandRegistry.register(ReadOnlyStubCommand) CommandRegistry.register(ReasoningEffortCommand) +CommandRegistry.register(RemoveHookCommand) CommandRegistry.register(RemoveMcpCommand) CommandRegistry.register(RemoveSkillCommand) CommandRegistry.register(ReportCommand) @@ -171,9 +177,11 @@ "HashlineCommand", "HelpCommand", "HistorySearchCommand", + "HookCommand", "LintCommand", "ListSessionsCommand", "LoadCommand", + "LoadHookCommand", "LoadMcpCommand", "LoadSessionCommand", "LoadSkillCommand", @@ -190,6 +198,7 @@ "ReadOnlyCommand", "ReadOnlyStubCommand", "ReasoningEffortCommand", + "RemoveHookCommand", "RemoveMcpCommand", "RemoveSkillCommand", "ReportCommand", diff --git a/cecli/commands/add.py b/cecli/commands/add.py index d0166e0db31..f7080a2512f 100644 --- a/cecli/commands/add.py +++ b/cecli/commands/add.py @@ -41,7 +41,7 @@ async def execute(cls, io, coder, args, **kwargs): else: fname = Path(coder.root) / word - if coder.repo and coder.repo.ignored_file(fname): + if coder.repo and coder.repo.ignored_file(fname) and not coder.add_gitignore_files: 
io.tool_warning(f"Skipping {fname} due to cecli.ignore or --subtree-only.") continue diff --git a/cecli/commands/clear.py b/cecli/commands/clear.py index 64a1c21c45c..a7743460129 100644 --- a/cecli/commands/clear.py +++ b/cecli/commands/clear.py @@ -20,6 +20,7 @@ async def execute(cls, io, coder, args, **kwargs): ConversationManager.clear_tag(MessageTag.CUR) ConversationManager.clear_tag(MessageTag.DONE) ConversationManager.clear_tag(MessageTag.DIFFS) + ConversationManager.clear_tag(MessageTag.FILE_CONTEXTS) ConversationFiles.reset() diff --git a/cecli/commands/hooks.py b/cecli/commands/hooks.py new file mode 100644 index 00000000000..1ab6d299391 --- /dev/null +++ b/cecli/commands/hooks.py @@ -0,0 +1,117 @@ +import argparse +from typing import List + +from cecli.commands.utils.base_command import BaseCommand +from cecli.hooks.manager import HookManager +from cecli.hooks.types import HookType + + +class HooksCommand(BaseCommand): + NORM_NAME = "hooks" + DESCRIPTION = "List all registered hooks by type with their current state" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the hooks command with given parameters.""" + # Parse the args string + parsed_args = cls._parse_args(args) + + # Get all hooks grouped by type + hook_manager = HookManager() + all_hooks = hook_manager.get_all_hooks() + + # Apply type filter if specified + if parsed_args.type: + filtered_hooks = {} + if parsed_args.type in all_hooks: + filtered_hooks[parsed_args.type] = all_hooks[parsed_args.type] + all_hooks = filtered_hooks + + # Display hooks + if not all_hooks: + io.tool_output("No hooks registered") + return 0 + + total_hooks = 0 + total_enabled = 0 + + for hook_type, hooks in sorted(all_hooks.items(), reverse=True): + # Apply state filters + filtered_hooks = [] + for hook in hooks: + if parsed_args.enabled_only and not hook.enabled: + continue + if parsed_args.disabled_only and hook.enabled: + continue + filtered_hooks.append(hook) + + if not 
filtered_hooks: + continue + + io.tool_output(f"\n{hook_type.upper()} hooks:") + io.tool_output("-" * 40) + + for hook in filtered_hooks: + status = "✓ ENABLED" if hook.enabled else "✗ DISABLED" + io.tool_output(f" {hook.name:30} {status}") + + # Show additional info if available + if hasattr(hook, "description") and hook.description: + io.tool_output(f" Description: {hook.description}") + if hasattr(hook, "priority"): + io.tool_output(f" Priority: {hook.priority}") + + total_hooks += 1 + if hook.enabled: + total_enabled += 1 + + io.tool_output( + f"\nTotal hooks: {total_hooks} ({total_enabled} enabled," + f" {total_hooks - total_enabled} disabled)" + ) + return 0 + + @classmethod + def _parse_args(cls, args_string: str) -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(prog="/hooks", add_help=False) + parser.add_argument( + "--type", + choices=[t.value for t in HookType], + help="Filter hooks by type (start, on_message, end_message, pre_tool, post_tool, end)", + ) + parser.add_argument("--enabled-only", action="store_true", help="Show only enabled hooks") + parser.add_argument("--disabled-only", action="store_true", help="Show only disabled hooks") + + try: + # Split args string and parse + args_list = args_string.split() + return parser.parse_args(args_list) + except SystemExit: + # argparse will call sys.exit() on error, we need to catch it + return argparse.Namespace(type=None, enabled_only=False, disabled_only=False) + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for hooks command.""" + return [] + + @classmethod + def get_help(cls) -> str: + """Get help text for the hooks command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /hooks # List all hooks\n" + help_text += " /hooks --type pre_tool # List only pre_tool hooks\n" + help_text += " /hooks --enabled-only # List only enabled hooks\n" + help_text += " /hooks 
--disabled-only # List only disabled hooks\n" + help_text += "\nExamples:\n" + help_text += " /hooks\n" + help_text += " /hooks --type start\n" + help_text += " /hooks --enabled-only\n" + help_text += "\nThis command displays all registered hooks grouped by type.\n" + help_text += ( + "Each hook shows its name, current state (enabled/disabled), and additional info.\n" + ) + help_text += "Use /load-hook and /remove-hook to enable or disable specific hooks.\n" + return help_text diff --git a/cecli/commands/load_hook.py b/cecli/commands/load_hook.py new file mode 100644 index 00000000000..e3370e4fe1b --- /dev/null +++ b/cecli/commands/load_hook.py @@ -0,0 +1,67 @@ +from typing import List + +from cecli.commands.utils.base_command import BaseCommand +from cecli.hooks.manager import HookManager + + +class LoadHookCommand(BaseCommand): + NORM_NAME = "load-hook" + DESCRIPTION = "Enable a specific hook by name" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the load-hook command with given parameters.""" + # Get hook name from args string + if not args.strip(): + io.tool_error("Usage: /load-hook ") + return 1 + + hook_name = args.strip() + + # Check if hook exists + hook_manager = HookManager() + if not hook_manager.hook_exists(hook_name): + io.tool_error(f"Error: Hook '{hook_name}' not found") + return 1 + + # Enable the hook + success = hook_manager.enable_hook(hook_name) + + if success: + io.tool_output(f"Hook '{hook_name}' enabled successfully") + return 0 + else: + io.tool_error(f"Error: Failed to enable hook '{hook_name}'") + return 1 + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for load-hook command.""" + hook_manager = HookManager() + all_hooks = hook_manager.get_all_hooks() + + # Get all hook names + hook_names = [] + for hooks in all_hooks.values(): + for hook in hooks: + hook_names.append(hook.name) + + # Filter based on current args + current_arg = args.strip() + if 
current_arg: + return [name for name in hook_names if name.startswith(current_arg)] + else: + return hook_names + + @classmethod + def get_help(cls) -> str: + """Get help text for the load-hook command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /load-hook # Enable a specific hook\n" + help_text += "\nExamples:\n" + help_text += " /load-hook my_start_hook\n" + help_text += " /load-hook check_commands\n" + help_text += "\nThis command enables a hook that was previously disabled.\n" + help_text += "Use /hooks to see all available hooks and their current state.\n" + return help_text diff --git a/cecli/commands/remove_hook.py b/cecli/commands/remove_hook.py new file mode 100644 index 00000000000..378c5528326 --- /dev/null +++ b/cecli/commands/remove_hook.py @@ -0,0 +1,68 @@ +from typing import List + +from cecli.commands.utils.base_command import BaseCommand +from cecli.hooks.manager import HookManager + + +class RemoveHookCommand(BaseCommand): + NORM_NAME = "remove-hook" + DESCRIPTION = "Disable a specific hook by name" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the remove-hook command with given parameters.""" + # Get hook name from args string + if not args.strip(): + io.tool_error("Usage: /remove-hook ") + return 1 + + hook_name = args.strip() + + # Check if hook exists + hook_manager = HookManager() + if not hook_manager.hook_exists(hook_name): + io.tool_error(f"Error: Hook '{hook_name}' not found") + return 1 + + # Disable the hook + success = hook_manager.disable_hook(hook_name) + + if success: + io.tool_output(f"Hook '{hook_name}' disabled successfully") + return 0 + else: + io.tool_error(f"Error: Failed to disable hook '{hook_name}'") + return 1 + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for remove-hook command.""" + hook_manager = HookManager() + all_hooks = hook_manager.get_all_hooks() + + # Get all hook names + hook_names 
= [] + for hooks in all_hooks.values(): + for hook in hooks: + hook_names.append(hook.name) + + # Filter based on current args + current_arg = args.strip() + if current_arg: + return [name for name in hook_names if name.startswith(current_arg)] + else: + return hook_names + + @classmethod + def get_help(cls) -> str: + """Get help text for the remove-hook command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /remove-hook # Disable a specific hook\n" + help_text += "\nExamples:\n" + help_text += " /remove-hook my_start_hook\n" + help_text += " /remove-hook check_commands\n" + help_text += "\nThis command disables a hook without removing it from the registry.\n" + help_text += "Use /load-hook to re-enable it later.\n" + help_text += "Use /hooks to see all available hooks and their current state.\n" + return help_text diff --git a/cecli/commands/tokens.py b/cecli/commands/tokens.py index 49651332bc2..ca5016ffa2e 100644 --- a/cecli/commands/tokens.py +++ b/cecli/commands/tokens.py @@ -59,6 +59,8 @@ async def execute(cls, io, coder, args, **kwargs): msgs_done = ConversationManager.get_messages_dict(tag=MessageTag.DONE) msgs_cur = ConversationManager.get_messages_dict(tag=MessageTag.CUR) msgs_diffs = ConversationManager.get_messages_dict(tag=MessageTag.DIFFS) + msgs_file_contexts = ConversationManager.get_messages_dict(tag=MessageTag.FILE_CONTEXTS) + tokens_done = 0 tokens_cur = 0 tokens_diffs = 0 @@ -72,6 +74,9 @@ async def execute(cls, io, coder, args, **kwargs): if msgs_diffs: tokens_diffs = coder.main_model.token_count(msgs_diffs) + if msgs_file_contexts: + tokens_file_contexts = coder.main_model.token_count(msgs_file_contexts) + if tokens_cur + tokens_done: res.append((tokens_cur + tokens_done, "chat history", "use /clear to clear")) # Add separate line for diffs if they exist @@ -79,6 +84,9 @@ async def execute(cls, io, coder, args, **kwargs): if tokens_diffs: res.append((tokens_diffs, "file diffs", "part of chat history")) + if 
tokens_file_contexts: + res.append((tokens_file_contexts, "numbered context messages", "part of chat history")) + # repo map if coder.repo_map: tokens = coder.main_model.token_count( diff --git a/cecli/commands/undo.py b/cecli/commands/undo.py index 3c6f3781edd..e4e09ea8211 100644 --- a/cecli/commands/undo.py +++ b/cecli/commands/undo.py @@ -13,6 +13,17 @@ class UndoCommand(BaseCommand): @classmethod async def execute(cls, io, coder, args, **kwargs): try: + # Clear chat history using ConversationManager + from cecli.helpers.conversation import ( + ConversationFiles, + ConversationManager, + MessageTag, + ) + + ConversationManager.clear_tag(MessageTag.DIFFS) + ConversationManager.clear_tag(MessageTag.FILE_CONTEXTS) + ConversationFiles.reset() + return await cls._raw_cmd_undo(io, coder, args) except ANY_GIT_ERROR as err: io.tool_error(f"Unable to complete undo: {err}") diff --git a/cecli/helpers/conversation/files.py b/cecli/helpers/conversation/files.py index b4fe510a6c1..cbe838de933 100644 --- a/cecli/helpers/conversation/files.py +++ b/cecli/helpers/conversation/files.py @@ -1,6 +1,6 @@ import os import weakref -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Tuple from cecli.helpers.hashline import get_hashline_content_diff, hashline from cecli.repomap import RepoMap @@ -25,6 +25,8 @@ class ConversationFiles: _file_to_message_id: Dict[str, str] = {} # Track image files separately since they don't have text content _image_files: Dict[str, bool] = {} + # Track numbered context ranges for files + _numbered_contexts: Dict[str, List[Tuple[int, int]]] = {} _coder_ref = None _initialized = False @@ -296,6 +298,7 @@ def clear_file_cache(cls, fname: Optional[str] = None) -> None: cls._file_timestamps.clear() cls._file_diffs.clear() cls._file_to_message_id.clear() + cls._numbered_contexts.clear() # New line else: abs_fname = os.path.abspath(fname) cls._file_contents_original.pop(abs_fname, None) @@ -304,6 +307,7 @@ def 
clear_file_cache(cls, fname: Optional[str] = None) -> None: cls._file_diffs.pop(abs_fname, None) cls._file_to_message_id.pop(abs_fname, None) cls._image_files.pop(abs_fname, None) + cls._numbered_contexts.pop(abs_fname, None) # New line @classmethod def add_image_file(cls, fname: str) -> None: @@ -346,10 +350,141 @@ def get_coder(cls): return cls._coder_ref() return None + @classmethod + def update_file_context(cls, file_path: str, start_line: int, end_line: int) -> None: + """ + Update numbered contexts for a file with a new range. + + Args: + file_path: Absolute file path + start_line: Start line number (1-based) + end_line: End line number (1-based) + """ + abs_fname = os.path.abspath(file_path) + + # Validate range + if start_line > end_line: + start_line, end_line = end_line, start_line + + # Get existing ranges + existing_ranges = cls._numbered_contexts.get(abs_fname, []) + + # Add new range + new_range = (start_line, end_line) + all_ranges = existing_ranges + [new_range] + + # Sort by start line + all_ranges.sort(key=lambda x: x[0]) + + # Merge overlapping or close ranges + merged_ranges = [] + for current_start, current_end in all_ranges: + if not merged_ranges: + merged_ranges.append([current_start, current_end]) + else: + last_start, last_end = merged_ranges[-1] + + # Check if ranges overlap or are close (within 20 lines) + if current_start <= last_end + 20: # Overlap or close + # Extend the range + merged_ranges[-1][1] = max(last_end, current_end) + else: + # Add as new range + merged_ranges.append([current_start, current_end]) + + # Convert back to tuples + cls._numbered_contexts[abs_fname] = [(start, end) for start, end in merged_ranges] + + # Remove using hash key (file_context, abs_fname) + ConversationManager.remove_message_by_hash_key(("file_context_user", abs_fname)) + ConversationManager.remove_message_by_hash_key(("file_context_assistant", abs_fname)) + + @classmethod + def get_file_context(cls, file_path: str) -> str: + """ + Generate hashline 
representation of cached context ranges. + + Args: + file_path: Absolute file path + + Returns: + Hashline representation of cached ranges, or empty string if no ranges + """ + abs_fname = os.path.abspath(file_path) + + # Get cached ranges + ranges = cls._numbered_contexts.get(abs_fname, []) + if not ranges: + return "" + + # Get coder instance + coder = cls.get_coder() + if not coder: + return "" + + # Read file content + try: + content = coder.io.read_text(abs_fname) + if not content: + return "" + except Exception: + return "" + + # Generate hashline representations for each range + context_parts = [] + for i, (start_line, end_line) in enumerate(ranges): + # Note: hashline uses 1-based line numbers, so no conversion needed + start_line_adj = max(1, start_line) + end_line_adj = min(len(content.splitlines()), end_line) + + if start_line_adj > end_line_adj: + continue + + # Extract lines for this range (0-based indexing for list) + lines = content.splitlines()[start_line_adj - 1 : end_line_adj] + + # Generate hashline representation using the hashline() function + # Join lines back with newlines for hashline() + range_content = "\n".join(lines) + hashline_content = hashline(range_content, start_line=start_line_adj) + + context_parts.append(hashline_content.strip()) + + # Join with ellipsis separator + return "\n...\n\n".join(context_parts) + + @classmethod + def remove_file_context(cls, file_path: str) -> None: + """ + Remove all cached context for a file. 
+ + Args: + file_path: Absolute file path + """ + abs_fname = os.path.abspath(file_path) + + # Remove from numbered contexts + cls._numbered_contexts.pop(abs_fname, None) + + # Remove using hash key (file_context, abs_fname) + ConversationManager.remove_message_by_hash_key(("file_context_user", abs_fname)) + ConversationManager.remove_message_by_hash_key(("file_context_assistant", abs_fname)) + + @classmethod + def clear_all_numbered_contexts(cls) -> None: + """Clear all numbered contexts for all files.""" + cls._numbered_contexts.clear() + + @classmethod + def _get_numbered_contexts(cls) -> Dict[str, List[Tuple[int, int]]]: + """Get the numbered contexts dictionary.""" + return cls._numbered_contexts + @classmethod def reset(cls) -> None: """Clear all file caches and reset to initial state.""" cls.clear_file_cache() + cls.clear_all_numbered_contexts() cls._coder_ref = None cls._initialized = False diff --git a/cecli/helpers/conversation/integration.py b/cecli/helpers/conversation/integration.py index ba26651084e..30bd2b00856 100644 --- a/cecli/helpers/conversation/integration.py +++ b/cecli/helpers/conversation/integration.py @@ -85,7 +85,6 @@ def cleanup_files(cls, coder) -> None: coder: The coder instance """ - """ # Check diff message ratio and clear if too many diffs diff_messages = ConversationManager.get_messages_dict(MessageTag.DIFFS) read_only_messages = ConversationManager.get_messages_dict(MessageTag.READONLY_FILES) @@ -124,9 +123,9 @@ def cleanup_files(cls, coder) -> None: if should_clear: # Clear all diff messages ConversationManager.clear_tag(MessageTag.DIFFS) + ConversationManager.clear_tag(MessageTag.FILE_CONTEXTS) # Clear ConversationFiles caches to force regeneration ConversationFiles.clear_file_cache() - """ # Get all tracked files (both regular and image files) tracked_files = ConversationFiles.get_all_tracked_files() @@ -528,6 +527,7 @@ def add_chat_files_messages(cls, coder) -> Dict[str, Any]: # Get file content (with proper caching and stub 
generation) content = ConversationFiles.get_file_stub(fname) if not content: + ConversationFiles.clear_file_cache(fname) continue rel_fname = coder.get_rel_fname(fname) @@ -608,6 +608,52 @@ def add_chat_files_messages(cls, coder) -> Dict[str, Any]: return result + @classmethod + def add_file_context_messages(cls, coder) -> None: + """ + Create and insert FILE_CONTEXTS messages based on cached contexts. + + Args: + coder: The coder instance + """ + # Get numbered contexts + numbered_contexts = ConversationFiles._get_numbered_contexts() + + for file_path, ranges in numbered_contexts.items(): + if not ranges: + continue + + # Generate context content + context_content = ConversationFiles.get_file_context(file_path) + if not context_content: + continue + + # Get relative file name + rel_fname = coder.get_rel_fname(file_path) + + user_msg = { + "role": "user", + "content": f"Numbered Context For:\n{rel_fname}\n\n{context_content}", + } + + assistant_msg = { + "role": "assistant", + "content": "I understand, thank you for sharing the file contents.", + } + + # Add to conversation manager + ConversationManager.add_message( + message_dict=user_msg, + tag=MessageTag.FILE_CONTEXTS, + hash_key=("file_context_user", file_path), + ) + + ConversationManager.add_message( + message_dict=assistant_msg, + tag=MessageTag.FILE_CONTEXTS, + hash_key=("file_context_assistant", file_path), + ) + @classmethod def add_assistant_reply(cls, coder, partial_response_chunks) -> None: """ diff --git a/cecli/helpers/conversation/tags.py b/cecli/helpers/conversation/tags.py index 9755b198ea7..4972a1b90de 100644 --- a/cecli/helpers/conversation/tags.py +++ b/cecli/helpers/conversation/tags.py @@ -18,6 +18,7 @@ class MessageTag(str, Enum): CHAT_FILES = "chat_files" EDIT_FILES = "edit_files" DIFFS = "diffs" + FILE_CONTEXTS = "file_contexts" CUR = "cur" DONE = "done" REMINDER = "reminder" @@ -34,6 +35,7 @@ class MessageTag(str, Enum): MessageTag.CHAT_FILES: 200, MessageTag.EDIT_FILES: 200, 
MessageTag.DIFFS: 200, + MessageTag.FILE_CONTEXTS: 200, MessageTag.DONE: 200, MessageTag.CUR: 200, MessageTag.REMINDER: 300, @@ -51,6 +53,7 @@ class MessageTag(str, Enum): MessageTag.CHAT_FILES: 0, MessageTag.EDIT_FILES: 0, MessageTag.DIFFS: 0, + MessageTag.FILE_CONTEXTS: 0, MessageTag.DONE: 0, MessageTag.CUR: 0, MessageTag.REMINDER: 0, diff --git a/cecli/helpers/hashline.py b/cecli/helpers/hashline.py index 47c8008618f..364e539bd14 100644 --- a/cecli/helpers/hashline.py +++ b/cecli/helpers/hashline.py @@ -4,10 +4,10 @@ import xxhash # Regex patterns for hashline parsing -# Format: {line_number}|{hash_fragment}| -HASHLINE_PREFIX_RE = re.compile(r"^(-?\d+)\|([a-zA-Z]{2})\|") -# Format: {line_number}|{hash_fragment} -PARSE_NEW_FORMAT_RE = re.compile(r"^(-?\d+)\|([a-zA-Z]{2})$") +# Format: |{line_number}{hash_fragment}| +HASHLINE_PREFIX_RE = re.compile(r"^\|?(-?\d+)([a-zA-Z]{2})\|") +# Format: |{line_number}{hash_fragment}| +PARSE_NEW_FORMAT_RE = re.compile(r"^\|?(-?\d+)([a-zA-Z]{2})\|?$") # Format: {hash_fragment}|{line_number} PARSE_OLD_FORMAT_RE = re.compile(r"^([a-zA-Z]{2})\|(-?\d+)$") @@ -23,7 +23,7 @@ def hashline(text: str, start_line: int = 1) -> str: Add a hash scheme to each line of text. 
For each line in the input text, returns a string where each line is prefixed with: - "{line number}|{2-digit base52 of xxhash mod 52^2}|{line contents}" + "|{line number}{2-digit base52 of xxhash mod 52^2}|{line contents}" Args: text: Input text (most likely representing a file's text) @@ -37,7 +37,7 @@ def hashline(text: str, start_line: int = 1) -> str: for i, line in enumerate(lines, start=start_line): # Calculate xxhash for the line content - hash_value = xxhash.xxh3_64_intdigest(line.encode("utf-8")) + hash_value = xxhash.xxh3_64_intdigest(line.strip().encode("utf-8")) # Use mod 52^2 (2704) for faster computation mod_value = hash_value % 2704 # 52^2 = 2704 @@ -46,7 +46,7 @@ def hashline(text: str, start_line: int = 1) -> str: last_two_str = int_to_2digit_52(mod_value) # Format the line - formatted_line = f"{i}|{last_two_str}|{line}" + formatted_line = f"|{i}{last_two_str}|{line}" result_lines.append(formatted_line) return "".join(result_lines) @@ -93,7 +93,7 @@ def strip_hashline(text: str) -> str: """ Remove hashline-like sequences from the start of every line. - Removes prefixes that match the pattern: "{line number}|{2-digit base52}|" + Removes prefixes that match the pattern: "|{line number}{2-digit base52}|" where line number can be any integer (positive, negative, or zero) and the 2-digit base52 is exactly 2 characters from the set [a-zA-Z]. @@ -118,7 +118,7 @@ def parse_hashline(hashline_str: str): Parse a hashline string into hash fragment and line number. Args: - hashline_str: Hashline format string: "{line_num}|{hash_fragment}" + hashline_str: Hashline format string: "{line_num}{hash_fragment}" Returns: tuple: (hash_fragment, line_num_str, line_num) @@ -130,9 +130,10 @@ def parse_hashline(hashline_str: str): raise HashlineError("Hashline string cannot be None") try: - hashline_str = hashline_str.rstrip("|") + # No longer rstrip("|") here as the regex handles optional trailing pipe + # and we want to preserve the leading pipe for the new format. 
- # Try new format first: {line_num}|{hash_fragment} + # Try new format first: |{line_num}{hash_fragment}| match = PARSE_NEW_FORMAT_RE.match(hashline_str) if match: line_num_str, hash_fragment = match.groups() @@ -151,16 +152,18 @@ def parse_hashline(hashline_str: str): def normalize_hashline(hashline_str: str) -> str: """ - Normalize a hashline string to the proper "{line_num}|{hash_fragment}" format. + Normalize a hashline string to the proper "{line_num}{hash_fragment}" format. - Accepts hashline strings in either "{hash_fragment}|{line_num}" format or - "{line_num}|{hash_fragment}" format and returns it in the proper format. + Accepts hashline strings in either "{line_num}{hash_fragment}" format or + "{hash_fragment}|{line_num}" format and returns it in the proper format. + Also extracts hashline from strings that contain content after the hashline, + e.g., "|1100df| # Range-shifting logic..." Args: - hashline_str: Hashline string in either format + hashline_str: Hashline string in either format, optionally with content after Returns: - str: Hashline string in "{line_num}|{hash_fragment}" format + str: Hashline string in "{line_num}{hash_fragment}" format Raises: HashlineError: If format is invalid @@ -168,28 +171,50 @@ def normalize_hashline(hashline_str: str) -> str: if hashline_str is None: raise HashlineError("Hashline string cannot be None") - # Try to parse as "{line_num}|{hash_fragment}" first (preferred) + # Try to parse as exact "|{line_num}{hash_fragment}|" first (preferred) match1 = PARSE_NEW_FORMAT_RE.match(hashline_str) if match1: return hashline_str - # Try to parse as "{hash_fragment}|{line_num}" + # Try to parse as exact "{hash_fragment}|{line_num}" match2 = PARSE_OLD_FORMAT_RE.match(hashline_str) if match2: hash_fragment, line_num_str = match2.groups() - return f"{line_num_str}|{hash_fragment}" + return f"|{line_num_str}{hash_fragment}|" + + # If exact matches fail, try to extract hashline from the beginning of the string + # First try new 
format with content: |{line_num}{hash_fragment}|... + match3 = HASHLINE_PREFIX_RE.match(hashline_str) + if match3: + line_num_str, hash_fragment = match3.groups() + return f"|{line_num_str}{hash_fragment}|" + + # Try to extract old format with content: {hash_fragment}|{line_num}|... + # We need a regex that matches the old format with optional content after + # Pattern: {hash_fragment}|{line_num}|... where hash_fragment is 2 letters, line_num is integer + old_format_with_content_re = re.compile(r"^([a-zA-Z]{2})\|(-?\d+)\|?") + match4 = old_format_with_content_re.match(hashline_str) + if match4: + hash_fragment, line_num_str = match4.groups() + return f"|{line_num_str}{hash_fragment}|" + + old_format_with_content_re = re.compile(r"^(-?\d+)\|([a-zA-Z]{2})\|?") + match5 = old_format_with_content_re.match(hashline_str) + if match5: + line_num_str, hash_fragment = match5.groups() + return f"|{line_num_str}{hash_fragment}|" # If neither pattern matches, raise error raise HashlineError( f"Invalid hashline format '{hashline_str}'. " - "Expected either '{line_num}|{hash_fragment}' or '{hash_fragment}|{line_num}' " - "where hash_fragment is exactly 2 letters and line_num is an integer." + "Expected '{line_num}{hash_fragment}' " + "where line_num is an integer and hash_fragment is exactly 2 letters. " ) def find_hashline_by_exact_match(hashed_lines, hash_fragment, line_num_str): """ - Find a hashline by exact line_num|hash_fragment match. + Find a hashline by |{exact line_num}{hash_fragment match}|. 
Args: hashed_lines: List of hashed lines @@ -200,7 +225,7 @@ def find_hashline_by_exact_match(hashed_lines, hash_fragment, line_num_str): int: Index of matching line, or None if not found """ for i, line in enumerate(hashed_lines): - if line.startswith(f"{line_num_str}|{hash_fragment}|"): + if line.startswith(f"|{line_num_str}{hash_fragment}|"): return i return None @@ -221,16 +246,15 @@ def find_hashline_by_fragment(hashed_lines, hash_fragment, target_line_num=None) """ matches = [] for i, line in enumerate(hashed_lines): - parts = line.split("|", 2) - if len(parts) < 3: + match = HASHLINE_PREFIX_RE.match(line) + if not match: continue - line_hash_fragment = parts[1] + line_num_part, line_hash_fragment = match.groups() if line_hash_fragment == hash_fragment: if target_line_num is None: return i # Return first match for backward compatibility # Extract line number from hashline - line_num_part = parts[0] try: line_num = int(line_num_part) distance = abs(line_num - target_line_num) @@ -274,10 +298,19 @@ def find_hashline_range( """ # Parse start_line_hash start_hash_fragment, start_line_num_str, start_line_num = parse_hashline(start_line_hash) + found_start_line = None + # Special handling for genesis anchor "0aa" + if start_hash_fragment == "aa" and start_line_num == 0: + found_start_line = 0 + if not hashed_lines: + # Genesis anchor for empty content - return 0 for both start and end + found_end_line = 0 + return found_start_line, found_end_line + # For non-empty files, 0aa as start anchor means the first line (index 0) + # We continue to find found_end_line normally. 
# Try to find start line - found_start_line = None - if allow_exact_match: + if found_start_line is None and allow_exact_match: found_start_line = find_hashline_by_exact_match( hashed_lines, start_hash_fragment, start_line_num_str ) @@ -319,7 +352,8 @@ def find_hashline_range( # Check if end hash fragment matches at the expected position # If not, use find_hashline_by_fragment() to find the closest match actual_end_hashed_line = hashed_lines[expected_found_end_line] - actual_end_hash_fragment = actual_end_hashed_line.split(":", 1)[0] + match = HASHLINE_PREFIX_RE.match(actual_end_hashed_line) + actual_end_hash_fragment = match.group(2) if match else None if actual_end_hash_fragment != end_hash_fragment: # Instead of raising an error, try to find the closest matching hash fragment @@ -529,6 +563,9 @@ def get_hashline_diff( HashlineError: If hashline verification fails or operation is invalid """ + start_line_hash = normalize_hashline(start_line_hash) + end_line_hash = normalize_hashline(end_line_hash) + if operation == "insert": end_line_hash = start_line_hash @@ -610,17 +647,16 @@ def _parse_content_for_diff(content: str): content_only_lines = [] for line_num, line in enumerate(content.splitlines(keepends=True), 1): - if "|" in line: - parts = line.split("|", 1) - if len(parts) == 2: - line_content = parts[1].rstrip("\r\n") - hashline_prefixed = line.rstrip("\r\n") - hashlines.append(hashline_prefixed) - content_only_lines.append(line_content) - if line_content not in content_to_lines: - content_to_lines[line_content] = [] - content_to_lines[line_content].append(line_num) - continue + match = HASHLINE_PREFIX_RE.match(line) + if match: + line_content = line[match.end() :].rstrip("\r\n") + hashline_prefixed = line.rstrip("\r\n") + hashlines.append(hashline_prefixed) + content_only_lines.append(line_content) + if line_content not in content_to_lines: + content_to_lines[line_content] = [] + content_to_lines[line_content].append(line_num) + continue # Line without 
def _apply_range_shifting(hashed_lines, resolved_ops):
    """
    Nudge replace ranges whose replacement text repeats a neighbouring file line.

    When the last (or first) replacement line equals the file line just below
    (or above) the resolved range, but differs from the range's own boundary
    line, the range is either expanded by one line or shifted by one line so
    that the repeated context is not duplicated in the result. A move is
    skipped when either new endpoint would land inside another operation.

    Args:
        hashed_lines: List of hashed lines from the file
        resolved_ops: List of resolved operation dictionaries (modified in place)

    Returns:
        The same resolved_ops list, with adjusted start_idx/end_idx values
    """

    def _normalized(raw_line):
        # Drop the hashline prefix and guarantee a trailing newline so that an
        # unterminated final line still compares equal to a terminated one.
        bare = strip_hashline(raw_line)
        return bare if bare.endswith("\n") else bare + "\n"

    def _collides(own_pos, lo, hi):
        # True when either candidate endpoint falls inside another op's range.
        for pos, other in enumerate(resolved_ops):
            if pos == own_pos:
                continue
            if other["start_idx"] <= lo <= other["end_idx"]:
                return True
            if other["start_idx"] <= hi <= other["end_idx"]:
                return True
        return False

    for pos, entry in enumerate(resolved_ops):
        operation = entry["op"]
        if operation["operation"] != "replace" or not operation.get("text"):
            continue

        new_lines = operation["text"].splitlines(keepends=True)
        if not new_lines:
            continue

        # Downward move: replacement ends with the file line right below the range.
        if entry["end_idx"] < len(hashed_lines) - 1:
            tail = _normalized(new_lines[-1])
            below = _normalized(hashed_lines[entry["end_idx"] + 1])
            if tail == below and tail != _normalized(hashed_lines[entry["end_idx"]]):
                span = entry["end_idx"] - entry["start_idx"] + 1
                # Longer replacement than range -> expand; otherwise slide the window.
                grows = len(new_lines) > span
                candidate_start = entry["start_idx"] if grows else entry["start_idx"] + 1
                candidate_end = entry["end_idx"] + 1
                if not _collides(pos, candidate_start, candidate_end):
                    entry["start_idx"] = candidate_start
                    entry["end_idx"] = candidate_end

        # Upward move: evaluated against the indices as possibly updated above.
        if entry["start_idx"] > 0:
            head = _normalized(new_lines[0])
            above = _normalized(hashed_lines[entry["start_idx"] - 1])
            if head == above and head != _normalized(hashed_lines[entry["start_idx"]]):
                span = entry["end_idx"] - entry["start_idx"] + 1
                grows = len(new_lines) > span
                candidate_start = entry["start_idx"] - 1
                candidate_end = entry["end_idx"] if grows else entry["end_idx"] - 1
                if not _collides(pos, candidate_start, candidate_end):
                    entry["start_idx"] = candidate_start
                    entry["end_idx"] = candidate_end

    return resolved_ops
def _apply_closure_safeguard(hashed_lines, resolved_ops):
    """
    Keep a replace range from swallowing an outer closing brace/bracket.

    If the replacement text ends on a closing ``}``, ``)`` or ``]`` (trailing
    ``;`` / ``,`` ignored) and the last file line of the resolved range is also
    a closer but with less leading whitespace, the range end is pulled up by
    one line. The move is skipped for single-line ranges and when the range
    start or the new end would fall inside another operation.

    Args:
        hashed_lines: List of hashed lines from the file
        resolved_ops: List of resolved operation dictionaries (modified in place)

    Returns:
        The same resolved_ops list, with adjusted end_idx values
    """

    def _ends_with_closer(line):
        # Ignore surrounding whitespace and trailing ';' / ',' punctuation.
        core = line.strip().rstrip(";,")
        return bool(core) and core[-1] in "})]"

    def _indent_width(line):
        # Number of leading space/tab characters.
        return len(line) - len(line.lstrip(" \t"))

    for pos, entry in enumerate(resolved_ops):
        operation = entry["op"]
        if operation["operation"] != "replace" or not operation.get("text"):
            continue

        new_lines = operation["text"].splitlines(keepends=True)
        if not new_lines:
            continue

        final_line = strip_hashline(new_lines[-1])
        if not _ends_with_closer(final_line):
            continue

        if entry["end_idx"] >= len(hashed_lines):
            continue

        boundary_line = strip_hashline(hashed_lines[entry["end_idx"]])
        if not _ends_with_closer(boundary_line):
            continue

        # Only act when the file's closer is less indented than the
        # replacement's closer and the range spans more than one line.
        if not (
            _indent_width(boundary_line) < _indent_width(final_line)
            and entry["end_idx"] > entry["start_idx"]
        ):
            continue

        shrunk_end = entry["end_idx"] - 1
        clashes = any(
            idx != pos
            and (
                other["start_idx"] <= entry["start_idx"] <= other["end_idx"]
                or other["start_idx"] <= shrunk_end <= other["end_idx"]
            )
            for idx, other in enumerate(resolved_ops)
        )
        if not clashes:
            entry["end_idx"] = shrunk_end

    return resolved_ops
def _merge_replace_operations(resolved_ops):
    """
    Merge contiguous or overlapping replace operations into single operations.

    Operations are sorted by ``(start_idx, end_idx)``. Two neighbouring entries
    are merged when both are ``replace`` operations with a non-None ``text`` and
    the earlier range ends at, after, or directly before the start of the later
    one. The merged text is the earlier text followed by the later text, minus
    the longest run of lines that is both a suffix of the earlier text and a
    prefix of the later text.

    Args:
        resolved_ops: List of resolved operation dictionaries, each with
            "index", "start_idx", "end_idx" and "op" keys. The list is sorted
            in place; surviving entries are updated in place.

    Returns:
        List of resolved operations after merging. A merged entry carries a
        "merged_indices" list with the "index" of every operation folded into it.
    """
    if len(resolved_ops) < 2:
        return resolved_ops

    # Sort by start_idx so that mergeable operations become neighbours
    resolved_ops.sort(key=lambda item: (item["start_idx"], item["end_idx"]))

    def _is_text_replace(entry):
        op = entry["op"]
        return op["operation"] == "replace" and op.get("text") is not None

    merged = []
    for current in resolved_ops:
        prev = merged[-1] if merged else None

        if (
            prev is None
            or not (_is_text_replace(prev) and _is_text_replace(current))
            # Mergeable only when overlapping or directly adjacent
            or prev["end_idx"] < current["start_idx"] - 1
        ):
            merged.append(current)
            continue

        prev_text = prev["op"]["text"]
        curr_text = current["op"]["text"]

        # Without a terminator the last line of the earlier text would be
        # glued onto the first line of the later text when concatenated.
        if prev_text and curr_text and not prev_text.endswith("\n"):
            prev_text += "\n"

        prev_lines = prev_text.splitlines(keepends=True)
        curr_lines = curr_text.splitlines(keepends=True)

        # Compare against a copy whose final line is terminated, so a missing
        # trailing newline does not hide a genuine overlap.
        comparable = list(curr_lines)
        if comparable and not comparable[-1].endswith("\n"):
            comparable[-1] += "\n"

        # Longest suffix of prev that is also a prefix of current
        overlap_len = 0
        for size in range(1, min(len(prev_lines), len(comparable)) + 1):
            if prev_lines[-size:] == comparable[:size]:
                overlap_len = size

        new_text = prev_text + "".join(curr_lines[overlap_len:])

        prev["end_idx"] = max(prev["end_idx"], current["end_idx"])
        # Store the merged text on a copy so the caller-supplied operation
        # dictionary is not modified behind its back.
        prev["op"] = {**prev["op"], "text": new_text}

        # Remember every original operation index that went into this entry
        if "merged_indices" not in prev:
            prev["merged_indices"] = [prev["index"]]
        prev["merged_indices"].append(current["index"])

    return merged
operations + resolved_ops = _merge_replace_operations(resolved_ops) + # Apply content-aware range expansion/shifting for replace operations + resolved_ops = _apply_range_shifting(hashed_lines, resolved_ops) + # Apply closure safeguard for braces/brackets + resolved_ops = _apply_closure_safeguard(hashed_lines, resolved_ops) + # Sort by start_idx descending to apply from bottom to top # When operations have same start_idx, apply in order: insert, replace, delete # This ensures correct behavior when multiple operations target the same line @@ -1171,9 +1467,14 @@ def sort_key(op): text = op["text"] if text and not text.endswith("\n"): text += "\n" - if not hashed_lines[start_idx].endswith("\n"): - hashed_lines[start_idx] += "\n" - hashed_lines.insert(start_idx + 1, text) + # Special handling for empty hashed_lines (genesis anchor case) + if hashed_lines: + if not hashed_lines[start_idx].endswith("\n"): + hashed_lines[start_idx] += "\n" + hashed_lines.insert(start_idx + 1, text) + else: + # Empty content with genesis anchor - just add the text + hashed_lines.append(text) elif op["operation"] == "delete": del hashed_lines[start_idx : end_idx + 1] elif op["operation"] == "replace": @@ -1212,7 +1513,10 @@ def sort_key(op): # Empty text - replace with nothing (delete) hashed_lines[start_idx : end_idx + 1] = [] - successful_ops.append(resolved["index"]) + if "merged_indices" in resolved: + successful_ops.extend(resolved["merged_indices"]) + else: + successful_ops.append(resolved["index"]) except Exception as e: failed_ops.append( {"index": resolved["index"], "error": str(e), "operation": resolved["op"]} diff --git a/cecli/hooks/__init__.py b/cecli/hooks/__init__.py new file mode 100644 index 00000000000..339a5fbe6db --- /dev/null +++ b/cecli/hooks/__init__.py @@ -0,0 +1,17 @@ +"""Hooks module for extending cecli functionality.""" + +from .base import BaseHook, CommandHook +from .integration import HookIntegration +from .manager import HookManager +from .registry import 
class BaseHook(ABC):
    """Abstract parent of every hook.

    A concrete hook must provide a ``type`` attribute (checked in
    ``__init__``) and implement the asynchronous ``execute`` method.
    """

    type: HookType
    name: str
    priority: int = 10
    enabled: bool = True
    description: Optional[str] = None

    def __init__(self, name: Optional[str] = None, priority: int = 10, enabled: bool = True):
        """Set up the common hook attributes.

        Args:
            name: Hook name; falls back to the class name when omitted or empty.
            priority: Ordering key, lower values are ordered first. Default is 10.
            enabled: Whether the hook starts out enabled. Default is True.

        Raises:
            ValueError: If no ``type`` attribute is defined, or it is None.
        """
        self.name = name or self.__class__.__name__
        self.priority = priority
        self.enabled = enabled

        # ``type`` is only annotated on this class, so it exists solely when a
        # subclass (or the instance, before calling this method) has set it.
        if getattr(self, "type", None) is None:
            raise ValueError(f"Hook {self.__class__.__name__} must define a 'type' attribute")

    @abstractmethod
    async def execute(self, coder: Any, metadata: Dict[str, Any]) -> Any:
        """Run the hook.

        Args:
            coder: The coder instance providing context.
            metadata: Dictionary with metadata about the current operation.

        Returns:
            Any value. For Python hooks, return False or falsy value to abort operation.
            For command hooks, non-zero exit code aborts operation.
        """

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (
            f"{cls_name}(name='{self.name}', type={self.type},"
            f" priority={self.priority}, enabled={self.enabled})"
        )
class CommandHook(BaseHook):
    """Hook that executes a command-line script."""

    command: str

    def __init__(self, command: str, hook_type: str, **kwargs):
        """Initialize a command hook.

        Args:
            command: The command template to execute. ``{key}`` placeholders are
                filled from the metadata passed to ``execute``.
            hook_type: The hook type this command is attached to. It is stored
                as ``self.type`` as given.
            **kwargs: Additional arguments passed to BaseHook.
        """
        # Must be set before BaseHook.__init__, which rejects a missing type.
        self.type = hook_type
        super().__init__(**kwargs)
        self.command = command

    async def execute(self, coder: Any, metadata: Dict[str, Any]) -> Any:
        """Execute the command hook.

        Args:
            coder: The coder instance providing context.
            metadata: Dictionary with metadata about the current operation.

        Returns:
            Exit code of the command. Non-zero exit code aborts operation.
            A command template that cannot be formatted, a timeout, or any
            other error while running the command yields 1.
        """
        import subprocess

        try:
            # Escape metadata values for shell safety
            safe_metadata = {k: shlex.quote(str(v)) for k, v in metadata.items()}

            # Format command with metadata. A placeholder that is not in the
            # metadata raises KeyError/IndexError; unbalanced braces raise
            # ValueError. Report these as a failed hook rather than letting
            # the exception escape.
            formatted_command = self.command.format(**safe_metadata)
        except (KeyError, IndexError, ValueError) as e:
            print(f"[Hook {self.name}] Error: invalid command template: {e!r}")
            return 1  # Non-zero to abort

        try:
            # run_cmd is called in a worker thread via asyncio.to_thread
            exit_status, result = await asyncio.to_thread(
                run_cmd, formatted_command, error_print=coder.io.tool_error, cwd=coder.root
            )

            printed_result = ""

            if result:
                printed_result = f" result: {result}"

            if coder.verbose or exit_status != 0:
                print(f"[Hook {self.name}]{printed_result}")

            return exit_status

        except subprocess.TimeoutExpired:
            print(f"[Hook {self.name}] Timeout")
            return 1  # Non-zero to abort
        except Exception as e:
            print(f"[Hook {self.name}] Error: {e}")
            return 1  # Non-zero to abort
class HookIntegrationBase:
    """Entry points the agent loop uses to fire each kind of hook."""

    def __init__(self, hook_manager: Optional[HookManager] = None):
        """Store the hook manager to dispatch through (a HookManager() if none is given)."""
        self.hook_manager = hook_manager or HookManager()

    async def _dispatch(self, hook_type: HookType, coder: Any, **details: Any) -> bool:
        """Build the metadata (timestamp first, then ``details``) and call the hooks."""
        metadata = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), **details}
        return await self.hook_manager.call_hooks(hook_type.value, coder, metadata)

    async def call_start_hooks(self, coder: Any) -> bool:
        """Fire the start hooks.

        Args:
            coder: The coder instance.

        Returns:
            True if all hooks succeeded, False otherwise.
        """
        return await self._dispatch(
            HookType.START, coder, coder_type=coder.__class__.__name__
        )

    async def call_end_hooks(self, coder: Any) -> bool:
        """Fire the end hooks.

        Args:
            coder: The coder instance.

        Returns:
            True if all hooks succeeded, False otherwise.
        """
        return await self._dispatch(HookType.END, coder, coder_type=coder.__class__.__name__)

    async def call_on_message_hooks(self, coder: Any, message: str) -> bool:
        """Fire the on_message hooks for a message.

        Args:
            coder: The coder instance.
            message: The user message content.

        Returns:
            True if all hooks succeeded, False otherwise.
        """
        return await self._dispatch(
            HookType.ON_MESSAGE, coder, message=message, message_length=len(message)
        )

    async def call_end_message_hooks(self, coder: Any, message: str) -> bool:
        """Fire the end_message hooks for a message.

        Args:
            coder: The coder instance.
            message: The user message content.

        Returns:
            True if all hooks succeeded, False otherwise.
        """
        return await self._dispatch(
            HookType.END_MESSAGE, coder, message=message, message_length=len(message)
        )

    async def call_pre_tool_hooks(self, coder: Any, tool_name: str, arg_string: str) -> bool:
        """Fire the pre_tool hooks for a tool call.

        Args:
            coder: The coder instance.
            tool_name: The name of the tool to be executed.
            arg_string: The argument string for the tool.

        Returns:
            True if all hooks succeeded (tool execution should proceed), False otherwise.
        """
        return await self._dispatch(
            HookType.PRE_TOOL, coder, tool_name=tool_name, arg_string=arg_string
        )

    async def call_post_tool_hooks(
        self, coder: Any, tool_name: str, arg_string: str, output: str
    ) -> bool:
        """Fire the post_tool hooks for a tool call.

        Args:
            coder: The coder instance.
            tool_name: The name of the tool that was executed.
            arg_string: The argument string for the tool.
            output: The output from the tool execution.

        Returns:
            True if all hooks succeeded, False otherwise.
        """
        return await self._dispatch(
            HookType.POST_TOOL,
            coder,
            tool_name=tool_name,
            arg_string=arg_string,
            output=output,
        )


# Global instance for easy access
HookIntegration = HookIntegrationBase()
class HookManager:
    """Central registry and dispatcher for hooks.

    Implemented as a singleton: every ``HookManager()`` call returns the same
    instance, and ``__init__`` only sets up state the first time it runs.
    Hooks are indexed both by type (kept sorted by priority) and by name.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Singleton pattern."""
        if cls._instance is None:
            cls._instance = super(HookManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize the hook manager (no-op after the first call)."""
        if self._initialized:
            return

        self._hooks_by_type: Dict[str, List[BaseHook]] = defaultdict(list)
        self._hooks_by_name: Dict[str, BaseHook] = {}
        self._state_file = Path.home() / ".cecli" / "hooks_state.json"
        self._state_lock = threading.Lock()

        # Ensure state directory exists. State persistence is best-effort, so
        # an unwritable home directory must not prevent the manager from
        # being created.
        try:
            self._state_file.parent.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            print(f"Warning: Could not create hook state directory {self._state_file.parent}: {e}")

        self._initialized = True

    def register_hook(self, hook: BaseHook) -> None:
        """Register a hook instance.

        Args:
            hook: The hook instance to register.

        Raises:
            ValueError: If hook with same name already exists.
        """
        with self._lock:
            if hook.name in self._hooks_by_name:
                raise ValueError(f"Hook with name '{hook.name}' already exists")

            # NOTE: restoring saved state (_load_hook_state) is currently disabled.

            # Add to registries
            self._hooks_by_type[hook.type.value].append(hook)
            self._hooks_by_name[hook.name] = hook

            # Sort hooks by priority (lower = higher priority)
            self._hooks_by_type[hook.type.value].sort(key=lambda h: h.priority)

    def get_hooks(self, hook_type: str) -> List[BaseHook]:
        """Return enabled hooks of specific type, sorted by priority.

        Args:
            hook_type: The hook type to retrieve.

        Returns:
            List of enabled hooks of the specified type, sorted by priority.
        """
        with self._lock:
            # .get() avoids creating an empty defaultdict entry for unknown types
            hooks = self._hooks_by_type.get(hook_type, [])
            # Return only enabled hooks
            return [h for h in hooks if h.enabled]

    def get_all_hooks(self) -> Dict[str, List[BaseHook]]:
        """Get all hooks grouped by type for display.

        Returns:
            Dictionary mapping hook types to (copied) lists of hooks,
            including disabled ones.
        """
        with self._lock:
            return {hook_type: hooks.copy() for hook_type, hooks in self._hooks_by_type.items()}

    def hook_exists(self, name: str) -> bool:
        """Check if a hook exists by name.

        Args:
            name: The hook name to check.

        Returns:
            True if hook exists, False otherwise.
        """
        with self._lock:
            return name in self._hooks_by_name

    def enable_hook(self, name: str) -> bool:
        """Enable a hook by name.

        Args:
            name: The hook name to enable.

        Returns:
            True if hook was enabled, False if hook not found.
        """
        with self._lock:
            if name not in self._hooks_by_name:
                return False

            hook = self._hooks_by_name[name]
            hook.enabled = True
            # NOTE: persisting the change (_save_state) is currently disabled.
            return True

    def disable_hook(self, name: str) -> bool:
        """Disable a hook by name.

        Args:
            name: The hook name to disable.

        Returns:
            True if hook was disabled, False if hook not found.
        """
        with self._lock:
            if name not in self._hooks_by_name:
                return False

            hook = self._hooks_by_name[name]
            hook.enabled = False
            # NOTE: persisting the change (_save_state) is currently disabled.
            return True

    async def call_hooks(self, hook_type: str, coder: Any, metadata: Dict[str, Any]) -> bool:
        """Execute all hooks of a type.

        Every enabled hook is run even if an earlier one fails. Results are
        only interpreted for pre_tool/post_tool hooks; for other hook types
        the return value is always True.

        Args:
            hook_type: The hook type to execute.
            coder: The coder instance providing context.
            metadata: Dictionary with metadata about the current operation.

        Returns:
            True if all hooks succeeded (or no hooks to run), False if any
            tool hook failed: it returned False, a non-zero integer or
            another falsy value, or raised an exception.
        """
        hooks = self.get_hooks(hook_type)
        if not hooks:
            return True

        is_tool_hook = hook_type in (HookType.PRE_TOOL.value, HookType.POST_TOOL.value)
        all_succeeded = True

        for hook in hooks:
            # Re-check: the hook may have been disabled since get_hooks() ran
            if not hook.enabled:
                continue

            try:
                result = await hook.execute(coder, metadata)
            except Exception as e:
                print(f"[Hook {hook.name}] Error during execution: {e}")
                # A hook that raised has failed; report that for tool hooks,
                # but continue with other hooks even if one fails.
                if is_tool_hook:
                    all_succeeded = False
                continue

            if not is_tool_hook:
                continue

            # For tool hooks, falsy value or non-zero exit code indicates failure.
            # bool is checked first because bool is a subclass of int.
            if isinstance(result, bool):
                # Boolean result: False indicates failure
                if not result:
                    print(f"[Hook {hook.name}] Returned False")
                    all_succeeded = False
            elif isinstance(result, int):
                # Integer result: non-zero indicates failure
                if result != 0:
                    print(f"[Hook {hook.name}] Failed with exit code {result}")
                    all_succeeded = False
            elif not result:
                # Other falsy value indicates failure
                print(f"[Hook {hook.name}] Returned falsy value: {result}")
                all_succeeded = False

        return all_succeeded

    def _load_hook_state(self, hook: BaseHook) -> None:
        """Load saved enabled/disabled state for a single hook, if present."""
        if not self._state_file.exists():
            return

        try:
            with self._state_lock:
                with open(self._state_file, "r", encoding="utf-8") as f:
                    state = json.load(f)

            if hook.name in state:
                hook.enabled = state[hook.name]
        except (json.JSONDecodeError, IOError) as e:
            print(f"Warning: Could not load hook state from {self._state_file}: {e}")

    def _save_state(self) -> None:
        """Save hook states to configuration file.

        The previous file is copied to a ``.json.bak`` backup, the new state is
        written to a ``.json.tmp`` file and then moved over the real file.
        Failures are reported as a warning and otherwise ignored.
        """
        try:
            with self._state_lock:
                self._state_file.parent.mkdir(parents=True, exist_ok=True)

                # Create backup of existing state
                if self._state_file.exists():
                    backup_file = self._state_file.with_suffix(".json.bak")
                    import shutil

                    shutil.copy2(self._state_file, backup_file)

                # Save current state
                state = {name: hook.enabled for name, hook in self._hooks_by_name.items()}

                # Write to temporary file first, then move it into place
                temp_file = self._state_file.with_suffix(".json.tmp")
                with open(temp_file, "w", encoding="utf-8") as f:
                    json.dump(state, f, indent=2)

                # replace() overwrites an existing target on every platform;
                # rename() fails on Windows when the state file already exists.
                temp_file.replace(self._state_file)

        except Exception as e:
            print(f"Warning: Could not save hook state to {self._state_file}: {e}")

    def _load_state(self) -> None:
        """Load hook states from configuration file and apply them to registered hooks."""
        if not self._state_file.exists():
            return

        try:
            with self._state_lock:
                with open(self._state_file, "r", encoding="utf-8") as f:
                    state = json.load(f)

                # Apply loaded state to registered hooks
                for name, enabled in state.items():
                    if name in self._hooks_by_name:
                        self._hooks_by_name[name].enabled = enabled

        except (json.JSONDecodeError, IOError) as e:
            print(f"Warning: Could not load hook state from {self._state_file}: {e}")

    def clear(self) -> None:
        """Clear all registered hooks (for testing)."""
        with self._lock:
            self._hooks_by_type.clear()
            self._hooks_by_name.clear()
def load_hooks_from_config(self, config_file: Path) -> List[str]:
    """Load hooks from a YAML configuration file.

    The file is either a mapping of hook type to a list of hook definitions,
    or that same mapping nested under a top-level ``hooks`` key. Malformed
    sections are reported with a warning and skipped instead of aborting the
    whole load.

    Args:
        config_file: Path to YAML configuration file.

    Returns:
        List of hook names that were loaded.
    """
    if not config_file.exists():
        return []

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except (yaml.YAMLError, OSError, UnicodeDecodeError) as e:
        print(f"Warning: Could not load hook config from {config_file}: {e}")
        return []

    if not config:
        return []

    # Accept both {"hooks": {...}} and the bare {hook_type: [...]} layout.
    if isinstance(config, dict) and "hooks" in config:
        hooks_config = config["hooks"]
    else:
        hooks_config = config

    if not isinstance(hooks_config, dict):
        print(f"Warning: Hook config in {config_file} must map hook types to hook lists")
        return []

    loaded_hooks = []

    for hook_type_str, hook_defs in hooks_config.items():
        # Validate hook type
        try:
            hook_type = HookType(hook_type_str)
        except ValueError:
            print(f"Warning: Invalid hook type '{hook_type_str}' in config")
            continue

        # An empty YAML entry ("pre_tool:") parses to None; anything that is
        # not a list cannot be iterated as hook definitions.
        if hook_defs is None:
            continue
        if not isinstance(hook_defs, list):
            print(f"Warning: Hooks for type '{hook_type_str}' must be a list")
            continue

        for hook_def in hook_defs:
            hook = self._create_hook_from_config(hook_def, hook_type)
            if not hook:
                continue

            try:
                self.hook_manager.register_hook(hook)
            except ValueError as e:
                # File-based hooks are registered while their module is
                # loaded, so "already exists" still counts as loaded.
                if "already exists" not in str(e):
                    print(f"Warning: Could not register hook '{hook.name}': {e}")
                    continue

            loaded_hooks.append(hook.name)

    return loaded_hooks
[] + + def _create_hook_from_config( + self, hook_def: Dict[str, Any], hook_type: HookType + ) -> Optional[BaseHook]: + """Create a hook instance from configuration definition.""" + if not isinstance(hook_def, dict): + print(f"Warning: Hook definition must be a dictionary, got {type(hook_def)}") + return None + + # Get hook name + name = hook_def.get("name") + if not name: + print("Warning: Hook definition missing 'name' field") + return None + + # Get priority, enabled state, and description + priority = hook_def.get("priority", 10) + enabled = hook_def.get("enabled", True) + description = hook_def.get("description") + + # Check if it's a file-based hook or command hook + if "file" in hook_def: + # Python file hook + file_path = Path(hook_def["file"]).expanduser() + if not file_path.exists(): + print(f"Warning: Hook file not found: {file_path}") + return None + + # Load the module and find the hook class + hooks = self._load_hooks_from_python_file(file_path) + if not hooks: + print(f"Warning: No hooks found in file: {file_path}") + return None + + # The hook should have been registered by _load_hooks_from_python_file + # We need to find it and update its priority/enabled state + if self.hook_manager.hook_exists(name): + hook = self.hook_manager._hooks_by_name[name] + hook.priority = priority + hook.enabled = enabled + if description is not None: + hook.description = description + return hook + else: + print(f"Warning: Hook '{name}' not found in file {file_path}") + return None + + elif "command" in hook_def: + # Command hook + command = hook_def["command"] + hook = CommandHook( + command=command, name=name, priority=priority, enabled=enabled, hook_type=hook_type + ) + hook.type = hook_type + if description is not None: + hook.description = description + return hook + + else: + print(f"Warning: Hook '{name}' must have either 'file' or 'command' field") + return None + + def load_default_hooks(self) -> List[str]: + """Load hooks from default locations.""" + 
def load_hooks_from_json(self, json_string: str) -> List[str]:
    """Load hooks from a JSON string.

    The JSON document is converted to YAML, written to a temporary file and
    handed to ``load_hooks_from_config``, so it accepts the same layout as the
    YAML configuration. The temporary file is removed on every exit path.

    Args:
        json_string: JSON string containing hooks configuration.

    Returns:
        List of hook names that were loaded; empty when the string is not
        valid JSON or loading fails (an error is printed in both cases).
    """
    import json as json_module
    import tempfile

    # Recorded as soon as the file exists so that a failing write cannot
    # leave the file behind.
    temp_file_path = None

    try:
        config_data = json_module.loads(json_string)
        yaml_data = yaml.dump(config_data)

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as temp_file:
            temp_file_path = Path(temp_file.name)
            temp_file.write(yaml_data)

        return self.load_hooks_from_config(temp_file_path)

    except json_module.JSONDecodeError as e:
        print(f"Error: Invalid JSON string for hooks configuration: {e}")
        return []
    except Exception as e:
        # Top-level boundary for configuration passed on the command line:
        # report the problem and continue without hooks.
        print(f"Error loading hooks from JSON: {e}")
        return []
    finally:
        if temp_file_path is not None:
            try:
                temp_file_path.unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; a stale temp file is not worth failing for.
                pass
class InterruptibleInput:
    """
    Unix-only, interruptible replacement for input(), designed for:
        await asyncio.get_event_loop().run_in_executor(None, obj.input, prompt)

    input() blocks in a selector that watches stdin and an internal pipe;
    interrupt() sets a cancel flag and writes one byte to that pipe so the
    blocked input() wakes up and raises InterruptedError.

    interrupt() and close() are safe from any thread, and both are harmless
    after the object has been closed. The object also works as a context
    manager that closes itself on exit.
    """

    def __init__(self):
        if os.name == "nt":
            raise RuntimeError("InterruptibleInput is Unix-only (requires selectable stdin).")

        self._cancel = threading.Event()
        # Serializes close() against interrupt(): once the pipe is closed its
        # descriptor numbers can be handed out again by the OS, so they must
        # never be written to or closed a second time.
        self._fd_lock = threading.Lock()
        self._closed = False
        self._sel = selectors.DefaultSelector()

        # self-pipe to wake up select() from interrupt()
        try:
            self._r, self._w = os.pipe()
        except OSError:
            self._sel.close()
            raise
        os.set_blocking(self._r, False)
        os.set_blocking(self._w, False)
        self._sel.register(self._r, selectors.EVENT_READ, data="__wakeup__")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self) -> None:
        """Release the selector and the wake-up pipe; later calls do nothing."""
        with self._fd_lock:
            if self._closed:
                return
            self._closed = True

            release_steps = (
                lambda: self._sel.unregister(self._r),
                lambda: os.close(self._r),
                lambda: os.close(self._w),
                self._sel.close,
            )
            for step in release_steps:
                try:
                    step()
                except Exception:
                    # Best-effort teardown: a failing step must not stop the
                    # remaining resources from being released.
                    pass

    def interrupt(self) -> None:
        """Make a pending (or the next) input() call raise InterruptedError."""
        self._cancel.set()
        with self._fd_lock:
            if self._closed:
                return
            try:
                os.write(self._w, b"\x01")  # wake selector immediately
            except OSError:
                # Includes BlockingIOError (pipe full): a wake-up byte is
                # already pending, so the selector will wake anyway.
                pass

    def input(self, prompt: str = "") -> str:
        """Read one line from stdin, without its trailing newline.

        Raises:
            InterruptedError: If interrupt() was called before or during the read.
            EOFError: If stdin reached end of file.
        """
        if prompt:
            sys.stdout.write(prompt)
            sys.stdout.flush()

        if self._cancel.is_set():
            self._cancel.clear()
            raise InterruptedError("Input interrupted")

        stdin = sys.stdin
        fd = stdin.fileno()

        # NOTE(review): readiness is checked on the raw descriptor while the
        # read goes through the buffered sys.stdin, so a line already sitting
        # in that buffer may not wake the selector - confirm this is
        # acceptable for multi-line pastes.
        self._sel.register(fd, selectors.EVENT_READ, data="__stdin__")
        try:
            while True:
                for key, _ in self._sel.select():
                    if key.data == "__wakeup__":
                        # drain wake bytes
                        try:
                            while os.read(self._r, 1024):
                                pass
                        except BlockingIOError:
                            pass

                        if self._cancel.is_set():
                            self._cancel.clear()
                            raise InterruptedError("Input interrupted")
                        # Stale wake-up from an earlier interrupt(); keep waiting.
                        continue

                    if key.data == "__stdin__":
                        line = stdin.readline()
                        if line == "":
                            raise EOFError
                        return line.rstrip("\n")
        finally:
            try:
                self._sel.unregister(fd)
            except Exception:
                # The selector may already be closed; nothing left to undo.
                pass
Spinner @@ -493,6 +494,8 @@ def __init__( self.fallback_spinner = None self.fallback_spinner_enabled = True + self.interruptible_input = None + if fancy_input: # If unicode is supported, use the rich 'dots2' spinner, otherwise an ascii fallback if self._spinner_supports_unicode(): @@ -764,6 +767,15 @@ def interrupt_input(self): self.prompt_session.app.exit() finally: pass + else: + if self.interruptible_input is not None: + # Interrupt the dumb terminal input + self.interruptible_input.interrupt() + else: + # Give the user some feedback (this happens on Windows + # until someone extends InterruptibleInput to work + # there) + print("Warning: Interrupting input does not work in dumb terminal mode (yet!).") def reject_outstanding_confirmations(self): """Reject all outstanding confirmation dialogs.""" @@ -930,18 +942,18 @@ def _(event): show = self.prompt_prefix try: + self.interrupted = False + if not multiline_input: + if self.file_watcher: + self.file_watcher.start() + if self.clipboard_watcher: + self.clipboard_watcher.start() + if self.prompt_session: # Use placeholder if set, then clear it default = self.placeholder or "" self.placeholder = None - self.interrupted = False - if not multiline_input: - if self.file_watcher: - self.file_watcher.start() - if self.clipboard_watcher: - self.clipboard_watcher.start() - def get_continuation(width, line_number, is_soft_wrap): return self.prompt_prefix @@ -957,7 +969,23 @@ def get_continuation(width, line_number, is_soft_wrap): prompt_continuation=get_continuation, ) else: - line = await asyncio.get_event_loop().run_in_executor(None, input, show) + try: + self.interruptible_input = InterruptibleInput() + except RuntimeError: + # Fallback to non-interruptible input (Windows ...) 
+ line = await asyncio.get_event_loop().run_in_executor(None, input, show) + + if self.interruptible_input: + try: + line = await asyncio.get_event_loop().run_in_executor( + None, self.interruptible_input.input, show + ) + except InterruptedError: + self.interrupted = True + line = "" + finally: + self.interruptible_input.close() + self.interruptible_input = None # Check if we were interrupted by a file change if self.interrupted: @@ -1450,6 +1478,9 @@ def tool_success(self, message="", strip=True): self._tool_message(message, strip, self.user_input_color) def tool_error(self, message="", strip=True): + # import traceback + # traceback.print_stack() + self.num_error_outputs += 1 message = self.format_json_in_string(message) self._tool_message(message, strip, self.tool_error_color) diff --git a/cecli/main.py b/cecli/main.py index 2390f5cb02d..b27f9f3b251 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -50,6 +50,7 @@ from cecli.helpers.copypaste import ClipboardWatcher from cecli.helpers.file_searcher import generate_search_path_list from cecli.history import ChatSummary +from cecli.hooks import HookRegistry from cecli.io import InputOutput from cecli.llm import litellm from cecli.mcp import McpServerManager, load_mcp_servers @@ -555,6 +556,7 @@ async def main_async(argv=None, input=None, output=None, force_git_root=None, re set_args_error_data(args) if len(unknown): print("Unknown Args: ", unknown) + if hasattr(args, "agent_config") and args.agent_config is not None: args.agent_config = convert_yaml_to_json_string(args.agent_config) if hasattr(args, "tui_config") and args.tui_config is not None: @@ -567,6 +569,9 @@ async def main_async(argv=None, input=None, output=None, force_git_root=None, re args.security_config = convert_yaml_to_json_string(args.security_config) if hasattr(args, "retries") and args.retries is not None: args.retries = convert_yaml_to_json_string(args.retries) + if hasattr(args, "hooks") and args.hooks is not None: + args.hooks = 
convert_yaml_to_json_string(args.hooks) + if args.debug: global log_file os.makedirs(".cecli/logs/", exist_ok=True) @@ -1031,6 +1036,16 @@ def apply_model_overrides(model_name): args.mcp_servers, args.mcp_servers_file, io, args.verbose, args.mcp_transport ) mcp_manager = await McpServerManager.from_servers(mcp_servers, io, args.verbose) + # Load hooks if specified + if args.hooks: + hook_registry = HookRegistry() + loaded_hooks = hook_registry.load_hooks_from_json(args.hooks) + + if args.verbose and loaded_hooks: + io.tool_output( + f"Loaded {len(loaded_hooks)} hooks from --hooks config:" + f" {', '.join(loaded_hooks)}" + ) coder = await Coder.create( main_model=main_model, diff --git a/cecli/prompts/agent.yml b/cecli/prompts/agent.yml index 2faf1508824..32a8f9477a1 100644 --- a/cecli/prompts/agent.yml +++ b/cecli/prompts/agent.yml @@ -1,73 +1,63 @@ # Agent prompts - inherits from base.yaml -# Overrides specific prompts _inherits: [base] files_content_assistant_reply: | - I understand. I'll use these files to help with your request. + I have received the file contents. I will use them to provide a precise solution. files_no_full_files: | - I don't have full contents of any files yet. I'll add them as needed using the tool commands. + + I currently lack full file contents. I will use discovery tools to pull necessary context as I progress. + files_no_full_files_with_repo_map: | - I have a repository map but no full file contents yet. I will use my navigation tools to add relevant files to the context. + I have a repository map. I will use it to target my navigation and add relevant files to the context. -files_no_full_files_with_repo_map_reply: | - I understand. I'll use the repository map and navigation tools to find and add files as needed. - main_system: | ## Core Directives - **Role**: Act as an expert software engineer. 
- - **Act Proactively**: Autonomously use file discovery and context management tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `ContextManager`) to gather information and fulfill the user's request. Chain tool calls across multiple turns to continue exploration. - - **Be Decisive**: Trust that your initial findings are valid. Refrain from asking the same question or searching for the same term in multiple similar ways. - - **Be Efficient**: Some tools allow you to perform multiple actions at a time, use them to work quickly and effectively. Respect their usage limits + - **Act Proactively**: Autonomously use discovery and management tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `ContextManager`) to fulfill the request. Chain tool calls across multiple turns for continuous exploration. + - **Be Decisive**: Trust your findings. Do not repeat identical searches or ask redundant questions once a path is established. + - **Be Efficient**: Batch tool calls where supported. Respect usage limits while maximizing the utility of each turn. + ## Core Workflow - 1. **Plan**: Determine the necessary changes. Use the `UpdateTodoList` tool to manage your plan. Always begin by updating the todo list. - 2. **Explore**: Use discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `Grep`) to find relevant files. These tools add files to context as read-only. Use `Grep` first for broad searches to avoid context clutter. Concisely describe your search strategy with the `Thinking` tool. - 3. **Think**: Given the contents of your exploration, concisely reason through the edits with the `Thinking` tool that need to be made to accomplish the goal. For complex edits, briefly outline your plan for the user. Do not chain multiple `Thinking` calls in a row - 4. **Execute**: Use the appropriate editing tool. Remember to mark a file as editable with `ContextManager` before modifying it. Do not attempt large contiguous edits (those greater than 100 lines). 
Break them into multiple smaller steps. Proactively use skills if they are available - 5. **Verify & Recover**: After every edit, check the resulting diff snippet. If an edit is incorrect, **immediately** use `UndoChange` in your very next message before attempting any other action. - 6. **Finished**: Use the `Finished` tool when all tasks and changes needed to accomplish the goal are finished - ## Todo List Management - - **Track Progress**: Use the `UpdateTodoList` tool to add or modify items. - - **Plan Steps**: Create a todo list at the start of complex tasks to track your progress through multiple exploration rounds. - - **Stay Organized**: Update the todo list as you complete steps every 3-10 tool calls to maintain context across multiple tool calls. - ### Editing Tools - Use these for precision and safety. Files are provided with hashline prefixes in the format `{{line_num}}|{{hash_fragment}}` (e.g., `20|Bv`) and separated from the content by a pipe (|). - - **Line-Based Edits**: `ReplaceText`, `InsertText`, `DeleteText`, `IndentText` - - **Refactoring & History**: `ListChanges`, `UndoChange` - - **Skill Management**: `LoadSkill`, `RemoveSkill` - **MANDATORY Safety Protocol for Line-Based Tools:** Line numbers are fragile. You **MUST** use a two-turn process: - 1. **Turn 1**: Use `ShowNumberedContext` to get the exact, current line numbers. - 2. **Turn 2**: In your *next* message, use a line-based editing tool with the verified numbers. + 1. **Plan**: Start by using `UpdateTodoList` to outline the task. Always begin a complex interaction by setting or updating the roadmap. + 2. **Explore**: Use `Grep` for broad searches, but if results exceed 50 matches, refine your pattern immediately. Use discovery tools to add files as read-only context. + 3. **Think**: Use the `Thinking` tool to reason through edits. Avoid "thinking loops" (multiple consecutive `Thinking` calls), but ensure a clear logical path is established before editing. + 4. 
**Execute**: Use the appropriate editing tool. Mark files as editable with `ContextManager` when needed. Proactively use skills if they are available. + 5. **Verify & Recover**: Review every diff. If an edit fails or introduces errors, prioritize `UndoChange` to restore a known good state before attempting a fix. + 6. **Finished**: Use the `Finished` tool only after verifying the solution. Briefly summarize the changes for the user. - Do not neglect spaces and indentation, they are EXTREMELY important to preserve. + ## Todo List Management + - Use `UpdateTodoList` every 3-10 tool calls to keep the state synchronized. + - Break complex tasks into granular steps to maintain context across long interactions. + + ### Editing Tools (Precision Protocol) + Files use hashline prefixes: `{{line_num}}|{{hash_fragment}}`. + - **MANDATORY Two-Turn Safety Protocol**: + 1. **Turn 1**: Use `ShowNumberedContext` to verify exact, current line numbers. + 2. **Turn 2**: Execute the edit (Replace, Insert, Delete, Indent) using those verified numbers. + - **Indentation**: Preserve all spaces and tabs. In Python, a single-space error is a syntax error. Use `IndentText` to fix structural alignment. - Use the .cecli/workspace directory for temporary and test files you make to verify functionality + + Use the `.cecli/workspace` directory for all temporary, test, or scratch files. Always reply to the user in {language}. -repo_content_prefix: | - - I am working with code in a git repository. Here are summaries of some files: - - system_reminder: | ## Reminders - - Stay on task. Do not pursue goals the user did not ask for. - - Any tool call automatically continues to the next turn. Provide no tool calls in your final answer. - - Use the .cecli/workspace directory for temporary and test files you make to verify functionality - - Do not neglect spaces and indentation, they are EXTREMELY important to preserve. Fix indentation errors with the `IndentText` tool. 
- - Remove files from the context when you no longer need them with the `ContextManager` tool. It is fine to re-add them later, if they are needed again - - Remove skills if they are not helpful for your current task with `RemoveSkill` + - **Strict Scope**: Stay on task. Do not pursue unrequested refactors. + - **Context Hygiene**: Remove files or skills from context using `ContextManager` or `RemoveSkill` once they are no longer needed to save tokens and prevent confusion. + - **Turn Management**: Tool calls trigger the next turn. Do not include tool calls in your final summary to the user. + - **Sandbox**: Use `.cecli/workspace` for all verification and temporary logic. + - **Precision**: Never guess line numbers. Always use `ShowNumberedContext` first. {lazy_prompt} {shell_cmd_reminder} try_again: | - I need to retry my exploration. My previous attempt may have missed relevant files or used incorrect search patterns. - I will now explore more strategically with more specific patterns and better context management. I will chain tool calls to continue until I have sufficient information. + My previous exploration was insufficient. I will now adjust my strategy, use more specific search patterns, and manage my context more aggressively to find the correct solution. \ No newline at end of file diff --git a/cecli/prompts/base.yml b/cecli/prompts/base.yml index ec9f1b79f29..f7962a15c1a 100644 --- a/cecli/prompts/base.yml +++ b/cecli/prompts/base.yml @@ -84,14 +84,30 @@ rename_with_shell: "" go_ahead_tip: "" compaction_prompt: | - The user is going to provide you with a conversation. - This conversation is getting too long to fit in the context window of a large language model. - You need to summarize the conversation to reduce its length, while retaining all the important information. - Prioritize the latest instructions and don't include conflicting information from earlier instructions. 
- The summary should contain four parts: - - Overall Goal: What is the user trying to achieve with this conversation? - - Event Log: Keep information most important to prevent having to search for it again - This should be quite specific (e/g. the list of actions taken so far in a bulleted list so the next round maintains history) - - Next Steps: What are the next steps for the language model to take to help the user? - Describe the current investigation path and intention. - Here is the conversation so far: + # Instruction: Context Compaction & State Preservation + The following conversation is exceeding the context limit. Transform this history into a "Mission Briefing" that allows a new LLM instance to resume with zero loss of technical momentum. + + ## Required Output Format: + + ### 1. Core Objective + A concise statement of the final goal and the specific success criteria. + + ### 2. Narrative Event Log (Up to 50 Outcomes) + Provide a bulleted list documenting the sequence of **outcomes and milestones** reached. Do not describe tool syntax; describe what was learned or changed in one sentence per bullet: + - (e.g., "Mapped the project structure and identified `core/logic.py` as the primary target.") + - (e.g., "Discovered that the connection timeout error is triggered by the `RetryPolicy` class.") + - (e.g., "Successfully refactored the `validate_input` function to handle null bytes.") + - (e.g., "Reverted changes to `db.py` after determining the issue was in the environment config instead.") + - (e.g., "Verified that the fix works in isolation using a temporary script in `.cecli/workspace`.") + + ### 3. Current Technical Context + - **Files In-Scope**: List paths currently being edited or actively referenced. + - **Verified Facts**: List specific findings about the code logic that are now "known truths." + - **Discarded Hypotheses**: List paths or theories that were tested and proven incorrect to avoid repetition. + + ### 4. 
Strategic Pivot & Next Steps + - **Current Intent**: What is the model currently trying to prove or implement? + - **Immediate Next Steps**: The prioritized next tool calls and logic steps. + + --- + ## Conversation History to Compact: diff --git a/cecli/prompts/hashline.yml b/cecli/prompts/hashline.yml index 6e18957c4d8..aa2974fbeba 100644 --- a/cecli/prompts/hashline.yml +++ b/cecli/prompts/hashline.yml @@ -5,20 +5,20 @@ _inherits: [base] main_system: | Act as an expert software developer. Plan carefully, explain your logic briefly, and execute via LOCATE/CONTENTS blocks. - ### 1. SOURCE FORMAT - Files are provided in "Hashline" format. Each line starts with the line number and a 2-character hash, separated and followed by pipes. + ### 1. FILE FORMAT + Files are provided in "Hashline" format. Each line starts with a leading pipe (|), the line number and a 2-character hash, and a trailing pipe. - **Example Input Format:** - 1|Hm|#!/usr/bin/env python3 - 2|eU| - 3|mL|def example_method(): - 4|bk| return "example" - 5|eU| + **Example File Format :** + |1Hm|#!/usr/bin/env python3 + |2eU| + |3mL|def example_method(): + |4bk| return "example" + |5eU| ### 2. FILE ACCESS & WORKFLOW - If you need to edit files NOT yet in the chat, list their full paths and ask the user to add them. - You may create NEW files immediately without asking. - - Explain your plan concisely, then provide the LOCATE/CONTENTS blocks. + - Explain your plan concisely in 2-4 sentences, then provide the LOCATE/CONTENTS blocks. ### 3. EDITING PROTOCOL (LOCATE/CONTENTS) You must use this exact structure for every edit: @@ -32,16 +32,18 @@ main_system: | >>>>>>> CONTENTS {fence[1]} - **Strict LOCATE Rules:** - - **JSON ONLY:** The area between `<<<<<<< LOCATE` and `=======` must contain ONLY the JSON array (e.g., `["3|mL", "4|bk", "replace"]`). Never include source code here. + ### 4. 
EDITING RULES + - **JSON ONLY:** The area between `<<<<<<< LOCATE` and `=======` must contain ONLY the JSON array (e.g., `["3mL", "4bk", "replace"]`). Never include source code here. - **Operations:** `replace` (overwrites range) or `delete` (removes range). - **Inclusion:** Ranges are inclusive of the start and end hashlines. - - **New Files:** To create a file, use the "Genesis" anchor: `["0|aa", "0|aa", "replace"]`. + - **New Files:** To create a file, use the "Genesis" anchor: `["0aa", "0aa", "replace"]`. + - **Integrity:** Include full method/loop bodies. No partial syntax. + - **Constraints:** No overlapping ranges. Do not use the `end_hash` of one block as the `start_hash` of the next. - ### 4. QUALITY STANDARDS + ### 5. QUALITY STANDARDS - Respect existing conventions and libraries. - - Include full method/function bodies in edits to ensure syntactical correctness. - - Verify changes mentally for edge cases before outputting blocks. + - Include full method/function/control flow/loop bodies in edits to ensure syntactical correctness. + - Think through changes for edge cases, syntax errors and duplicated code before outputting blocks. 
{shell_cmd_prompt} {final_reminders} @@ -59,7 +61,7 @@ example_messages: mathweb/flask/app.py {fence[0]}python <<<<<<< LOCATE - ["1|aB", "1|aB", "replace"] + ["2Mk", "3Ul", "replace"] ======= import math from flask import Flask @@ -68,14 +70,14 @@ example_messages: mathweb/flask/app.py {fence[0]}python <<<<<<< LOCATE - ["10|cD", "15|eF", "delete"] + ["10cD", "15eF", "delete"] ======= >>>>>>> CONTENTS {fence[1]} mathweb/flask/app.py {fence[0]}python <<<<<<< LOCATE - ["20|gH", "20|gH", "replace"] + ["20gH", "20gH", "replace"] ======= return str(math.factorial(n)) >>>>>>> CONTENTS @@ -91,7 +93,7 @@ example_messages: hello.py {fence[0]}python <<<<<<< LOCATE - ["0|aa", "0|aa", "replace"] + ["0aa", "0aa", "replace"] ======= def hello(): "print a greeting" @@ -101,7 +103,7 @@ example_messages: main.py {fence[0]}python <<<<<<< LOCATE - ["5|iJ", "8|kL", "replace"] + ["5iJ", "8kL", "replace"] ======= from hello import hello >>>>>>> CONTENTS @@ -112,12 +114,15 @@ system_reminder: | # CRITICAL FORMATTING RULES: 1. **Path Accuracy:** The filename must be on its own line above the code fence, exactly as shown in the chat. 2. **JSON Only:** The area between `<<<<<<< LOCATE` and `=======` must be a valid JSON array with format: ["start_hashline", "end_hashline", "operation"]. - 3. **No Overlaps:** Ensure blocks target unique ranges. If multiple blocks share the same starting hashline, only the final block provided for that line will be processed. - 4. **Moving Code:** Use one `delete` block at the source and one `replace` block at the destination. - 5. **Empty Deletes:** For `delete` operations, the area between `=======` and `>>>>>>> CONTENTS` must be empty. + 3. **No Partials:** Always return complete blocks/closures for syntactical correctness. + 4. **Non-Adjacent:** Do not chain blocks (where end_hash = next start_hash). Leave space or edit a larger range. + 5. **Empty Deletes:** `delete` operations must have an empty CONTENTS section. 
- # NEW FILE TEMPLATE: - To create a file, use `["0|aa", "0|aa", "replace"]`. + Ensure you follow all hashline format guidelines before finalizing your answer. You may repeat your changes once to confirm your intentions + + # UPDATING YOUR PLAN + At times, it may be advantageous to change your strategy as you work through a problem. + This can be accomplished by specifying the same hashline range bounds and operation with new content to update your approach to the problem. {quad_backtick_reminder} {rename_with_shell}{go_ahead_tip}{final_reminders} diff --git a/cecli/sessions.py b/cecli/sessions.py index 2c5b633db6e..a5ed582c146 100644 --- a/cecli/sessions.py +++ b/cecli/sessions.py @@ -134,7 +134,7 @@ def _build_session_data(self, session_name) -> Dict: # Capture todo list content so it can be restored with the session todo_content = None try: - todo_path = self.coder.abs_root_path(".cecli/todo.txt") + todo_path = self.coder.abs_root_path(self.coder.local_agent_folder("todo.txt")) if os.path.isfile(todo_path): todo_content = self.io.read_text(todo_path) if todo_content is None: @@ -249,7 +249,7 @@ def _apply_session_data(self, session_data: Dict, session_file: Path) -> bool: # Restore todo list content if present in the session if "todo_list" in session_data: - todo_path = self.coder.abs_root_path(".cecli/todo.txt") + todo_path = self.coder.abs_root_path(self.coder.local_agent_folder("todo.txt")) todo_content = session_data.get("todo_list") try: if todo_content is None: diff --git a/cecli/tools/delete_text.py b/cecli/tools/delete_text.py index 2b563929369..26a1432a9fb 100644 --- a/cecli/tools/delete_text.py +++ b/cecli/tools/delete_text.py @@ -17,7 +17,7 @@ class Tool(BaseTool): "name": "DeleteText", "description": ( "Delete a block of lines from a file using hashline markers. 
" - 'Uses start_line and end_line parameters with format "{line_num}|{hash_fragment}" ' + 'Uses start_line and end_line parameters with format "{line_num}{hash_fragment}" ' "to specify the range to delete." ), "parameters": { @@ -27,12 +27,12 @@ class Tool(BaseTool): "start_line": { "type": "string", "description": ( - 'Hashline format for start line: "{line_num}|{hash_fragment}"' + 'Hashline format for start line: "{line_num}{hash_fragment}"' ), }, "end_line": { "type": "string", - "description": 'Hashline format for end line: "{line_num}|{hash_fragment}"', + "description": 'Hashline format for end line: "{line_num}{hash_fragment}"', }, "change_id": {"type": "string"}, "dry_run": {"type": "boolean", "default": False}, diff --git a/cecli/tools/indent_text.py b/cecli/tools/indent_text.py index f3f9af21400..90097acfe4b 100644 --- a/cecli/tools/indent_text.py +++ b/cecli/tools/indent_text.py @@ -27,12 +27,12 @@ class Tool(BaseTool): "start_line": { "type": "string", "description": ( - 'Hashline format for start line: "{line_num}|{hash_fragment}"' + 'Hashline format for start line: "{line_num}{hash_fragment}"' ), }, "end_line": { "type": "string", - "description": 'Hashline format for end line: "{line_num}|{hash_fragment}"', + "description": 'Hashline format for end line: "{line_num}{hash_fragment}"', }, "indent_levels": {"type": "integer", "default": 1}, "change_id": {"type": "string"}, @@ -61,8 +61,8 @@ def execute( Parameters: - coder: The Coder instance - file_path: Path to the file to modify - - start_line: Hashline format for start line: "{line_num}|{hash_fragment}" - - end_line: Hashline format for end line: "{line_num}|{hash_fragment}" + - start_line: Hashline format for start line: "{line_num}{hash_fragment}" + - end_line: Hashline format for end line: "{line_num}{hash_fragment}" - indent_levels: Number of levels to indent (positive) or unindent (negative) - change_id: Optional ID for tracking the change - dry_run: If True, simulate the change without 
modifying the file diff --git a/cecli/tools/insert_text.py b/cecli/tools/insert_text.py index 762b4d6885a..96cde7e925f 100644 --- a/cecli/tools/insert_text.py +++ b/cecli/tools/insert_text.py @@ -20,9 +20,9 @@ class Tool(BaseTool): "name": "InsertText", "description": ( "Insert content into a file using hashline markers. " - 'Uses start_line parameter with format "{line_num}|{hash_fragment}" ' + 'Uses start_line parameter with format "{line_num}{hash_fragment}" ' "to specify where to insert content. For empty files, " - 'use "0|aa" as the hashline reference.' + 'use "0aa" as the hashline reference.' ), "parameters": { "type": "object", @@ -32,7 +32,7 @@ class Tool(BaseTool): "start_line": { "type": "string", "description": ( - 'Hashline format for insertion point: "{line_num}|{hash_fragment}"' + 'Hashline format for insertion point: "{line_num}{hash_fragment}"' ), }, "change_id": {"type": "string"}, @@ -61,7 +61,7 @@ def execute( coder: The coder instance file_path: Path to the file to modify content: The content to insert - start_line: Hashline format for insertion point: "{line_num}|{hash_fragment}" + start_line: Hashline format for insertion point: "{line_num}{hash_fragment}" change_id: Optional ID for tracking changes dry_run: If True, only simulate the change """ diff --git a/cecli/tools/replace_text.py b/cecli/tools/replace_text.py index ccb85af154a..2d59ce7af4d 100644 --- a/cecli/tools/replace_text.py +++ b/cecli/tools/replace_text.py @@ -26,7 +26,7 @@ class Tool(BaseTool): "Replace text in one or more files. Can handle an array of up to 10 edits across" " multiple files. Each edit must include its own file_path. Use hashline ranges" " with the start_line and end_line parameters with format" - ' "{line_num}|{hash_fragment}". For empty files, use "0|aa" as the hashline' + ' "{line_num}{hash_fragment}". For empty files, use "0aa" as the hashline' " reference." 
), "parameters": { @@ -46,13 +46,13 @@ class Tool(BaseTool): "type": "string", "description": ( "Hashline format for start line:" - ' "{line_num}|{hash_fragment}"' + ' "{line_num}{hash_fragment}"' ), }, "end_line": { "type": "string", "description": ( - 'Hashline format for end line: "{line_num}|{hash_fragment}"' + 'Hashline format for end line: "{line_num}{hash_fragment}"' ), }, }, diff --git a/cecli/tools/show_numbered_context.py b/cecli/tools/show_numbered_context.py index 2d47b568fb6..2a13f8843b8 100644 --- a/cecli/tools/show_numbered_context.py +++ b/cecli/tools/show_numbered_context.py @@ -194,9 +194,18 @@ def execute(cls, coder, show, **kwargs): all_outputs.append("") all_outputs.extend(output_lines) + # Update the conversation cache with the displayed range + from cecli.helpers.conversation.files import ConversationFiles + + # Update the conversation cache with the displayed range + # Note: start_line_idx and end_line_idx are 0-based, convert to 1-based for hashline + start_line = start_line_idx + 1 # Convert to 1-based + end_line = end_line_idx + 1 # Convert to 1-based + ConversationFiles.update_file_context(abs_path, start_line, end_line) + # Log success and return the formatted context directly coder.io.tool_output(f"Successfully retrieved context for {len(show)} file(s)") - return "\n".join(all_outputs) + return f"Successfully retrieved context for {len(show)} file(s)" except ToolError as e: # Handle expected errors raised by utility functions or validation diff --git a/cecli/tools/update_todo_list.py b/cecli/tools/update_todo_list.py index 2fe3a838874..efd3ffa2741 100644 --- a/cecli/tools/update_todo_list.py +++ b/cecli/tools/update_todo_list.py @@ -62,13 +62,13 @@ class Tool(BaseTool): @classmethod def execute(cls, coder, tasks, append=False, change_id=None, dry_run=False, **kwargs): """ - Update the todo list file (.cecli/todo.txt) with formatted task items. + Update the todo list file (todo.txt) with formatted task items. 
Can either replace the entire content or append to it. """ tool_name = "UpdateTodoList" try: # Define the todo file path - todo_file_path = ".cecli/todo.txt" + todo_file_path = coder.local_agent_folder("todo.txt") abs_path = coder.abs_root_path(todo_file_path) # Format tasks into string @@ -161,7 +161,7 @@ def execute(cls, coder, tasks, append=False, change_id=None, dry_run=False, **kw # Format and return result action = "appended to" if append else "updated" - success_message = f"Successfully {action} todo list in {todo_file_path}" + success_message = f"Successfully {action} todo list" return format_tool_result( coder, tool_name, diff --git a/cecli/website/docs/config/hooks.md b/cecli/website/docs/config/hooks.md new file mode 100644 index 00000000000..e95acf557ab --- /dev/null +++ b/cecli/website/docs/config/hooks.md @@ -0,0 +1,121 @@ +--- +parent: Configuration +nav_order: 35 +description: Create and use hooks to run custom actions at specific points in cecli's workflow. +--- + +# Hooks + +Hooks allow you to extend `cecli` by defining custom actions that trigger at specific points in the agent's workflow. You can use hooks to automate tasks, integrate with external systems, or add custom validation logic. + +## Getting Started + +Hooks are configured in your `.cecli.conf.yml` file under the `hooks` section. You can define two types of hooks: +1. **Command Hooks**: Execute shell commands or scripts. +2. **Python Hooks**: Execute custom Python code by providing a path to a Python file. + +### Basic Configuration + +```yaml +hooks: + start: + - name: log_session_start + command: "echo 'Session started at {timestamp}' >> .cecli/hooks_log.txt" + priority: 10 + enabled: true + description: "Logs session start to file" +``` + +## Hook Types + +The following hook types are available: + +| Hook Type | Trigger Point | +|-----------|---------------| +| `start` | When the agent session begins. | +| `on_message` | When a new user message is received.
| +| `end_message` | When message processing completes. | +| `pre_tool` | Before a tool is executed. | +| `post_tool` | After a tool execution completes. | +| `end` | When the agent session ends. | + +## Configuration Options + +Each hook entry supports the following options: + +- `name`: (Required) A unique name for the hook. +- `command`: The shell command to execute (for Command Hooks). +- `file`: The path to a Python file (for Python Hooks). +- `priority`: (Optional) Execution order (lower numbers run first). Default is 10. +- `enabled`: (Optional) Whether the hook is active. Default is true. +- `description`: (Optional) A brief description of what the hook does. + +## Command Hooks + +Command hooks are simple shell commands. You can use placeholders in the command string that will be replaced with metadata from the hook event. + +### Available Metadata + +| Hook Type | Available Placeholders | +|-----------|------------------------| +| `start`, `end` | `{timestamp}`, `{coder_type}` | +| `on_message`, `end_message` | `{timestamp}`, `{message}`, `{message_length}` | +| `pre_tool` | `{timestamp}`, `{tool_name}`, `{arg_string}` | +| `post_tool` | `{timestamp}`, `{tool_name}`, `{arg_string}`, `{output}` | + +### Example: Aborting Tool Execution +If a `pre_tool` command hook returns a non-zero exit code, the tool execution will be aborted. + +```yaml +hooks: + pre_tool: + - name: check_dangerous_command + command: "python3 scripts/check_safety.py --args '{arg_string}'" +``` + +## Python Hooks + +Python hooks allow for more complex logic. To create a Python hook, create a `.py` file and define a class that inherits from `BaseHook`. 
+ +### Example Python Hook + +**File**: `.cecli/hooks/my_hook.py` +```python +from cecli.hooks import BaseHook +from cecli.hooks.types import HookType + +class MyCustomHook(BaseHook): + type = HookType.PRE_TOOL + + async def execute(self, coder, metadata): + # Access coder instance or metadata + tool_name = metadata.get("tool_name") + + if tool_name == "delete_file": + print("Warning: Deleting a file!") + + # Return False to abort operation (for pre_tool/post_tool) + return True +``` + +**Configuration**: +```yaml +hooks: + pre_tool: + - name: my_custom_python_hook + file: .cecli/hooks/my_hook.py +``` + +## Managing Hooks + +You can manage hooks during an active session using the following slash commands: + +- `/hooks`: List all registered hooks and their status. +- `/load-hook <name>`: Enable a specific hook. +- `/remove-hook <name>`: Disable a specific hook. + +## Best Practices + +- **Security**: Hooks run with the same permissions as `cecli`. Be careful when running scripts from untrusted sources. +- **Performance**: Avoid long-running tasks in hooks, as they can block the agent's loop. +- **Error Handling**: If a hook fails, `cecli` will generally log the error and continue, except for `pre_tool` hooks which can abort execution.
diff --git a/cecli/website/docs/sessions.md b/cecli/website/docs/sessions.md index bf185146335..9173b4dd51a 100644 --- a/cecli/website/docs/sessions.md +++ b/cecli/website/docs/sessions.md @@ -42,7 +42,7 @@ When `--auto-save` is enabled, cecli will automatically save your session as 'au - All files in the chat (editable, read-only, and read-only stubs) - Current model and edit format settings - Auto-commit, auto-lint, and auto-test settings -- Todo list content from `.cecli.todo.txt` +- Todo list content from `.cecli/run/{date}/{agent id}/todo.txt` - Session metadata (timestamp, version) ### `/load-session ` @@ -148,7 +148,6 @@ Sessions are stored as JSON files in the `.cecli/sessions/` directory within you - Session files include all file paths, so they work best when project structure is stable - External files (outside the project root) are stored with absolute paths - Missing files are skipped with warnings during loading -- The todo list file (`.cecli.todo.txt`) is cleared on startup; it is restored when you load a session or when you update it during a run ### Version Control - Consider adding `.cecli/sessions/` to your `.gitignore` if sessions contain sensitive information @@ -166,7 +165,6 @@ If files are reported as missing during loading: - The files may have been moved or deleted - Session files store relative paths, so directory structure changes can affect this - External files must exist at their original locations -- The todo list (`.cecli.todo.txt`) is cleared on startup unless restored from a loaded session ### Corrupted Sessions If a session fails to load: diff --git a/pytest.ini b/pytest.ini index 1f76a55a81d..34abfaea1f5 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,6 +6,8 @@ testpaths = tests/basic tests/tools tests/help + tests/hooks + tests/mcp tests/browser tests/scrape diff --git a/scripts/get_hashline.py b/scripts/get_hashline.py new file mode 100644 index 00000000000..0c22a718ffc --- /dev/null +++ b/scripts/get_hashline.py @@ -0,0 +1,31 
@@ +import os +import sys +from pathlib import Path + +# Add the current directory to sys.path to allow importing from cecli +sys.path.append(os.getcwd()) + +from cecli.helpers.hashline import hashline # noqa + + +def main(): + if len(sys.argv) < 2: + print("Usage: python scripts/get_hashline.py ") + sys.exit(1) + + file_path = Path(sys.argv[1]) + if not file_path.exists(): + print(f"Error: File '{file_path}' not found.") + sys.exit(1) + + try: + content = file_path.read_text(encoding="utf-8") + hashed_content = hashline(content) + print(hashed_content, end="") + except Exception as e: + print(f"Error reading file: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/basic/test_hashline.py b/tests/basic/test_hashline.py index 30c5c27b7ac..0adbdf1ad1a 100644 --- a/tests/basic/test_hashline.py +++ b/tests/basic/test_hashline.py @@ -46,18 +46,30 @@ def test_hashline_basic(): lines = result.splitlines() assert len(lines) == 3 - # Check each line has the format "line_number|hash|content" (new format) + # Check each line has the format "|line_numberhash|content" (correct format) for i, line in enumerate(lines, start=1): assert "|" in line + # Format should be "|{line_num}{hash_fragment}|{content}" + # So splitting by "|" should give 3 parts: empty string, line_num+hash, content parts = line.split("|", 2) assert len(parts) == 3 - # Check line number matches expected - assert parts[0] == str(i) - # Check hash is 2 characters - hash_part = parts[1] - assert len(hash_part) == 2 + # First part should be empty (leading pipe) + assert parts[0] == "" + # Second part should be line number + hash fragment + line_num_hash = parts[1] + # Extract line number (all digits at the beginning) + line_num_str = "" + for char in line_num_hash: + if char.isdigit(): + line_num_str += char + else: + break + assert line_num_str == str(i) + # Check hash fragment is 2 characters + hash_fragment = line_num_hash[len(line_num_str) :] + assert len(hash_fragment) == 2 # Check 
all hash characters are valid base52 - for char in hash_part: + for char in hash_fragment: assert char in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -68,15 +80,19 @@ def test_hashline_with_start_line(): lines = result.splitlines() assert len(lines) == 2 - # Check format is line_number|hash|content (new format) - assert "10|" in lines[0] - assert "11|" in lines[1] + # Check format is |line_numberhash|content (correct format) + assert "|10" in lines[0] + assert "|11" in lines[1] # Extract hash fragments to verify they're valid + # Format is |line_numhash|content, so split by "|" gives ["", "line_numhash", "content"] hash1 = lines[0].split("|")[1] hash2 = lines[1].split("|")[1] - assert len(hash1) == 2 - assert len(hash2) == 2 - for char in hash1 + hash2: + # Remove line number from hash to get just the hash fragment + hash_fragment1 = hash1[2:] # Skip "10" + hash_fragment2 = hash2[2:] # Skip "11" + assert len(hash_fragment1) == 2 + assert len(hash_fragment2) == 2 + for char in hash_fragment1 + hash_fragment2: assert char in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -92,12 +108,15 @@ def test_hashline_single_line(): result = hashline(text) lines = result.splitlines() assert len(lines) == 1 - # Check format is line_number|hash|content (new format) - assert "1|" in lines[0] + # Check format is |line_numberhash|content (correct format) + assert "|1" in lines[0] assert lines[0].endswith("|Single line") # Extract hash fragment to verify it's valid - hash_part = lines[0].split("|")[1] - for char in hash_part: + # Format is |line_numhash|content, so split by "|" gives ["", "line_numhash", "content"] + line_num_hash = lines[0].split("|")[1] + # Remove line number from hash to get just the hash fragment + hash_fragment = line_num_hash[1:] # Skip "1" + for char in hash_fragment: assert char in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -117,8 +136,8 @@ def test_hashline_preserves_newlines(): def test_strip_hashline_basic(): """Test 
basic strip_hashline functionality.""" - # Create a hashline-formatted text with correct format: line_number|hash|content - text = "1|ab|Hello\n2|cd|World\n3|ef|Test" + # Create a hashline-formatted text with correct format: |line_numberhash|content + text = "|1ab|Hello\n|2cd|World\n|3ef|Test" stripped = strip_hashline(text) assert stripped == "Hello\nWorld\nTest" @@ -127,21 +146,21 @@ def test_strip_hashline_with_negative_line_numbers(): """Test strip_hashline with negative line numbers.""" # Note: Negative line numbers are no longer supported since line numbers in files are always positive # But the regex still handles them if they appear - text = "-1|ab|Hello\n0|cd|World\n1|ef|Test" + text = "|-1ab|Hello\n|0cd|World\n|1ef|Test" stripped = strip_hashline(text) assert stripped == "Hello\nWorld\nTest" def test_strip_hashline_mixed_lines(): """Test strip_hashline with mixed hashline and non-hashline lines.""" - text = "1|ab|Hello\nPlain line\n3|cd|World" + text = "|1ab|Hello\nPlain line\n|3cd|World" stripped = strip_hashline(text) assert stripped == "Hello\nPlain line\nWorld" def test_strip_hashline_preserves_newlines(): """Test that strip_hashline preserves newline characters.""" - text = "1|ab|Line 1\n2|cd|Line 2\n" + text = "|1ab|Line 1\n|2cd|Line 2\n" stripped = strip_hashline(text) assert stripped == "Line 1\nLine 2\n" @@ -184,14 +203,14 @@ def test_hashline_different_inputs(): def test_parse_hashline(): """Test parse_hashline function.""" - # Test basic parsing (new format: line_num|hash) - hash_fragment, line_num_str, line_num = parse_hashline("10|ab") + # Test basic parsing (new format: |line_numhash|) + hash_fragment, line_num_str, line_num = parse_hashline("|10ab|") assert hash_fragment == "ab" assert line_num_str == "10" assert line_num == 10 # Test with trailing pipe - hash_fragment, line_num_str, line_num = parse_hashline("5|cd|") + hash_fragment, line_num_str, line_num = parse_hashline("|5cd|") assert hash_fragment == "cd" assert line_num_str == "5" 
assert line_num == 5 @@ -217,10 +236,10 @@ def test_parse_hashline(): def test_normalize_hashline(): """Test normalize_hashline function.""" # Test new format (should return unchanged) - assert normalize_hashline("10|ab") == "10|ab" + assert normalize_hashline("|10ab|") == "|10ab|" # Test old order with new separator (should normalize to new order) - assert normalize_hashline("ab|10") == "10|ab" + assert normalize_hashline("ab|10") == "|10ab|" # Test that colons are no longer supported with pytest.raises(HashlineError, match="Invalid hashline format"): @@ -230,9 +249,9 @@ def test_normalize_hashline(): def test_find_hashline_by_exact_match(): """Test find_hashline_by_exact_match function.""" hashed_lines = [ - "1|ab|Hello", - "2|cd|World", - "3|ef|Test", + "|1ab|Hello", + "|2cd|World", + "|3ef|Test", ] # Test exact match found @@ -251,10 +270,10 @@ def test_find_hashline_by_exact_match(): def test_find_hashline_by_fragment(): """Test find_hashline_by_fragment function.""" hashed_lines = [ - "1|ab|Hello", - "2|cd|World", - "3|ab|Test", # Same hash fragment as line 1 - "4|ef|Another", + "|1ab|Hello", + "|2cd|World", + "|3ab|Test", # Same hash fragment as line 1 + "|4ef|Another", ] # Test fragment found @@ -277,37 +296,48 @@ def test_find_hashline_range(): hashed = hashline(original) hashed_lines = hashed.splitlines(keepends=True) - # Get hash fragments for testing (hash is first part before colon) - # Get hash fragments for testing (hash is second part in new format) + # Get hash fragments for testing + # Format is |line_numhash|content, so split by "|" gives ["", "line_numhash", "content"] + # The hash fragment is part of the second element line1_hash = hashed_lines[0].split("|")[1] line3_hash = hashed_lines[2].split("|")[1] line5_hash = hashed_lines[4].split("|")[1] # Test exact match + # Extract just the hash fragments (last 2 characters) + hash_fragment1 = line1_hash[-2:] # This gives "vm" + hash_fragment3 = line3_hash[-2:] # This gives "Cx" start_idx, end_idx = 
find_hashline_range( hashed_lines, - f"1|{line1_hash}", - f"3|{line3_hash}", + f"|1{hash_fragment1}|", + f"|3{hash_fragment3}|", allow_exact_match=True, ) assert start_idx == 0 assert end_idx == 2 # Test fragment match (no exact match) + # Extract just the hash fragments (last 2 characters) + hash_fragment1 = line1_hash[-2:] # This gives "vm" + hash_fragment3 = line3_hash[-2:] # This gives "Cx" start_idx, end_idx = find_hashline_range( hashed_lines, - f"99|{line1_hash}", # Wrong line number - f"101|{line3_hash}", # Wrong line number + f"|99{hash_fragment1}|", # Wrong line number + f"|101{hash_fragment3}|", # Wrong line number allow_exact_match=True, ) assert start_idx == 0 # Should find by fragment assert end_idx == 2 # Should calculate distance # Test with allow_exact_match=False + # Use parse_hashline to extract hash fragments from the hashline strings + # line1_hash is "1vm" (line number + hash fragment), we need to parse it + hash_fragment1, line_num_str1, line_num1 = parse_hashline(f"|{line1_hash}|") + hash_fragment5, line_num_str5, line_num5 = parse_hashline(f"|{line5_hash}|") start_idx, end_idx = find_hashline_range( hashed_lines, - f"1|{line1_hash}", - f"5|{line5_hash}", + f"|1{hash_fragment1}|", + f"|5{hash_fragment5}|", allow_exact_match=False, ) assert start_idx == 0 @@ -315,7 +345,23 @@ def test_find_hashline_range(): # Test error cases with pytest.raises(HashlineError, match="Start line hash fragment 'zz' not found in file"): - find_hashline_range(hashed_lines, "1|zz", "3|zz") + find_hashline_range(hashed_lines, "|1zz|", "|3zz|") + # Test with allow_exact_match=False + # Extract just the hash fragments (last 2 characters) + hash_fragment1 = line1_hash[-2:] # This gives "vm" + hash_fragment5 = line5_hash[-2:] # This gives "BG" + start_idx, end_idx = find_hashline_range( + hashed_lines, + f"|1{hash_fragment1}|", + f"|5{hash_fragment5}|", + allow_exact_match=False, + ) + assert start_idx == 0 + assert end_idx == 4 + + # Test error cases + with 
pytest.raises(HashlineError, match="Start line hash fragment 'zz' not found in file"): + find_hashline_range(hashed_lines, "|1zz|", "|3zz|") def test_apply_hashline_operation_insert(): @@ -323,14 +369,18 @@ def test_apply_hashline_operation_insert(): original = "Line 1\nLine 2\nLine 3" hashed = hashline(original) - # Get hash fragment for line 2 (hash is second part in new format) + # Get hash fragment for line 2 + # Format is |line_numhash|content, so split by "|" gives ["", "line_numhash", "content"] hashed_lines = hashed.splitlines() - line2_hash = hashed_lines[1].split("|")[1] + line2_hash = hashed_lines[1].split("|")[1] # This gives "2Fy" (line number + hash fragment) + # Extract just the hash fragment (last 2 characters) + hash_fragment = line2_hash[-2:] # This gives "Fy" # Insert after line 2 + # Construct hashline string in correct format: |line_numhash_fragment| new_content = apply_hashline_operation( original, - f"2|{line2_hash}", + f"|2{hash_fragment}|", operation="insert", text="Inserted line", ) @@ -344,16 +394,21 @@ def test_apply_hashline_operation_delete(): original = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" hashed = hashline(original) - # Get hash fragments (hash is second part in new format) + # Get hash fragments + # Format is |line_numhash|content, so split by "|" gives ["", "line_numhash", "content"] hashed_lines = hashed.splitlines() - line2_hash = hashed_lines[1].split("|")[1] - line4_hash = hashed_lines[3].split("|")[1] + line2_hash = hashed_lines[1].split("|")[1] # This gives "2Fy" (line number + hash fragment) + line4_hash = hashed_lines[3].split("|")[1] # This gives "4Xj" (line number + hash fragment) + # Extract just the hash fragments (last 2 characters) + hash_fragment2 = line2_hash[-2:] # This gives "Fy" + hash_fragment4 = line4_hash[-2:] # This gives "Xj" # Delete lines 2-4 + # Construct hashline strings in correct format: |line_numhash_fragment| new_content = apply_hashline_operation( original, - f"2|{line2_hash}", - 
f"4|{line4_hash}", + f"|2{hash_fragment2}|", + f"|4{hash_fragment4}|", operation="delete", ) @@ -366,16 +421,21 @@ def test_extract_hashline_range(): original = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" hashed = hashline(original) - # Get hash fragments (hash is second part in new format) + # Get hash fragments + # Format is |line_numhash|content, so split by "|" gives ["", "line_numhash", "content"] hashed_lines = hashed.splitlines() - line2_hash = hashed_lines[1].split("|")[1] - line4_hash = hashed_lines[3].split("|")[1] + line2_hash = hashed_lines[1].split("|")[1] # This gives "2Fy" (line number + hash fragment) + line4_hash = hashed_lines[3].split("|")[1] # This gives "4Xj" (line number + hash fragment) + # Extract just the hash fragments (last 2 characters) + hash_fragment2 = line2_hash[-2:] # This gives "Fy" + hash_fragment4 = line4_hash[-2:] # This gives "Xj" # Extract lines 2-4 + # Construct hashline strings in correct format: |line_numhash_fragment| extracted = extract_hashline_range( original, - f"2|{line2_hash}", - f"4|{line4_hash}", + f"|2{hash_fragment2}|", + f"|4{hash_fragment4}|", ) # Extract should return hashed content @@ -388,16 +448,21 @@ def test_get_hashline_diff(): original = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" hashed = hashline(original) - # Get hash fragments (hash is second part in new format) + # Get hash fragments + # Format is |line_numhash|content, so split by "|" gives ["", "line_numhash", "content"] hashed_lines = hashed.splitlines() - line2_hash = hashed_lines[1].split("|")[1] - line4_hash = hashed_lines[3].split("|")[1] + line2_hash = hashed_lines[1].split("|")[1] # This gives "2Fy" (line number + hash fragment) + line4_hash = hashed_lines[3].split("|")[1] # This gives "4Xj" (line number + hash fragment) + # Extract just the hash fragments (last 2 characters) + hash_fragment2 = line2_hash[-2:] # This gives "Fy" + hash_fragment4 = line4_hash[-2:] # This gives "Xj" # Get diff for replace operation + # Construct hashline strings in 
correct format: |line_numhash_fragment| diff = get_hashline_diff( original, - f"2|{line2_hash}", - f"4|{line4_hash}", + f"|2{hash_fragment2}|", + f"|4{hash_fragment4}|", operation="replace", text="New line 2\nNew line 3\nNew line 4", ) @@ -444,19 +509,31 @@ def test_apply_hashline_operations_complex_sequence(): ops = [ { "operation": "replace", - "start_line_hash": f"2|{h2}", - "end_line_hash": f"2|{h2}", + "start_line_hash": f"|2{parse_hashline(f'|{h2}|')[0]}|", + "end_line_hash": f"|2{parse_hashline(f'|{h2}|')[0]}|", "text": "New Line 2", }, - {"operation": "insert", "start_line_hash": f"5|{h5}", "text": "Inserted after 5"}, - {"operation": "delete", "start_line_hash": f"10|{h10}", "end_line_hash": f"10|{h10}"}, + { + "operation": "insert", + "start_line_hash": f"|5{parse_hashline(f'|{h5}|')[0]}|", + "text": "Inserted after 5", + }, + { + "operation": "delete", + "start_line_hash": f"|10{parse_hashline(f'|{h10}|')[0]}|", + "end_line_hash": f"|10{parse_hashline(f'|{h10}|')[0]}|", + }, { "operation": "replace", - "start_line_hash": f"15|{h15}", - "end_line_hash": f"15|{h15}", + "start_line_hash": f"|15{parse_hashline(f'|{h15}|')[0]}|", + "end_line_hash": f"|15{parse_hashline(f'|{h15}|')[0]}|", "text": "New Line 15", }, - {"operation": "insert", "start_line_hash": f"20|{h20}", "text": "Inserted after 20"}, + { + "operation": "insert", + "start_line_hash": f"|20{parse_hashline(f'|{h20}|')[0]}|", + "text": "Inserted after 20", + }, ] print(f"Operations: {ops}") @@ -497,14 +574,14 @@ def test_apply_hashline_operations_overlapping(): ops = [ { "operation": "replace", - "start_line_hash": f"5|{h5}", - "end_line_hash": f"15|{h15}", + "start_line_hash": f"|5{parse_hashline(f'|{h5}|')[0]}|", + "end_line_hash": f"|15{parse_hashline(f'|{h15}|')[0]}|", "text": "Big Replace", }, { "operation": "replace", - "start_line_hash": f"10|{h10}", - "end_line_hash": f"10|{h10}", + "start_line_hash": f"|10{parse_hashline(f'|{h10}|')[0]}|", + "end_line_hash": 
f"|10{parse_hashline(f'|{h10}|')[0]}|", "text": "Small Replace", }, ] @@ -545,14 +622,14 @@ def test_apply_hashline_operations_duplicate_hashes(): ops = [ { "operation": "replace", - "start_line_hash": f"4|{h_val_2}", - "end_line_hash": f"4|{h_val_2}", + "start_line_hash": f"|4{parse_hashline(f'|{h_val_2}|')[0]}|", + "end_line_hash": f"|4{parse_hashline(f'|{h_val_2}|')[0]}|", "text": "Changed 2", }, { "operation": "replace", - "start_line_hash": f"10|{h_val_4}", - "end_line_hash": f"10|{h_val_4}", + "start_line_hash": f"|10{parse_hashline(f'|{h_val_4}|')[0]}|", + "end_line_hash": f"|10{parse_hashline(f'|{h_val_4}|')[0]}|", "text": "Changed 4", }, ] @@ -591,19 +668,19 @@ def test_apply_hashline_operations_empty_lines_duplicates(): ops = [ { "operation": "replace", - "start_line_hash": f"2|{empty_hash}", - "end_line_hash": f"2|{empty_hash}", + "start_line_hash": f"|2{parse_hashline(f'|{empty_hash}|')[0]}|", + "end_line_hash": f"|2{parse_hashline(f'|{empty_hash}|')[0]}|", "text": "# Comment 1", }, { "operation": "replace", - "start_line_hash": f"6|{empty_hash}", - "end_line_hash": f"6|{empty_hash}", + "start_line_hash": f"|6{parse_hashline(f'|{empty_hash}|')[0]}|", + "end_line_hash": f"|6{parse_hashline(f'|{empty_hash}|')[0]}|", "text": "# Comment 2", }, { "operation": "insert", - "start_line_hash": f"8|{empty_hash}", + "start_line_hash": f"|8{parse_hashline(f'|{empty_hash}|')[0]}|", "text": "# Inserted after empty line 8", }, ] @@ -663,20 +740,20 @@ def get_h(ln): ops = [ { "operation": "replace", - "start_line_hash": f"5|{get_h(5)}", - "end_line_hash": f"8|{get_h(8)}", + "start_line_hash": f"|5{parse_hashline(f'|{get_h(5)}|')[0]}|", + "end_line_hash": f"|8{parse_hashline(f'|{get_h(8)}|')[0]}|", "text": "Replacement Alpha", }, { "operation": "replace", - "start_line_hash": f"16|{get_h(16)}", - "end_line_hash": f"22|{get_h(22)}", + "start_line_hash": f"|16{parse_hashline(f'|{get_h(16)}|')[0]}|", + "end_line_hash": f"|22{parse_hashline(f'|{get_h(22)}|')[0]}|", "text": 
"Replacement Beta\nMore Beta", }, { "operation": "replace", - "start_line_hash": f"35|{get_h(35)}", - "end_line_hash": f"42|{get_h(42)}", + "start_line_hash": f"|35{parse_hashline(f'|{get_h(35)}|')[0]}|", + "end_line_hash": f"|42{parse_hashline(f'|{get_h(42)}|')[0]}|", "text": "Replacement Gamma", }, ] @@ -724,8 +801,16 @@ def get_h(ln): h_last = h_lines[2].split("|")[1] ops = [ - {"operation": "insert", "start_line_hash": f"1|{h_first}", "text": "Before First"}, - {"operation": "insert", "start_line_hash": f"3|{h_last}", "text": "After Last"}, + { + "operation": "insert", + "start_line_hash": f"|1{parse_hashline(f'|{h_first}|')[0]}|", + "text": "Before First", + }, + { + "operation": "insert", + "start_line_hash": f"|3{parse_hashline(f'|{h_last}|')[0]}|", + "text": "After Last", + }, ] modified, success, failed = apply_hashline_operations(original, ops) @@ -750,14 +835,14 @@ def test_apply_hashline_operations_mixed_success(): ops = [ { "operation": "replace", - "start_line_hash": f"1|{h1}", - "end_line_hash": f"1|{h1}", + "start_line_hash": f"|1{parse_hashline(f'|{h1}|')[0]}|", + "end_line_hash": f"|1{parse_hashline(f'|{h1}|')[0]}|", "text": "New 1", }, { "operation": "replace", - "start_line_hash": "99|zz", - "end_line_hash": "99|zz", + "start_line_hash": "|99zz|", + "end_line_hash": "|99zz|", "text": "Fail", }, ] @@ -875,8 +960,12 @@ def test_apply_hashline_operations_bidirectional_stitching(): operations = [ { - "start_line_hash": f"7|{line_7_hash}", # Line 7 (1-indexed) - D - "end_line_hash": f"10|{line_10_hash}", # Line 10 (1-indexed) - F + "start_line_hash": ( + f"|7{parse_hashline(f'|{line_7_hash}|')[0]}|" + ), # Line 7 (1-indexed) - D + "end_line_hash": ( + f"|10{parse_hashline(f'|{line_10_hash}|')[0]}|" + ), # Line 10 (1-indexed) - F "operation": "replace", "text": replacement_text, } diff --git a/tests/basic/test_io.py b/tests/basic/test_io.py index 0af08d0ef19..577adab5987 100644 --- a/tests/basic/test_io.py +++ b/tests/basic/test_io.py @@ -169,7 +169,8 
@@ def test_autocompleter_with_unicode_file(self): assert autocompleter.words == set(rel_fnames) @patch("builtins.input", return_value="test input") - def test_get_input_is_a_directory_error(self, mock_input): + @patch("cecli.io.InterruptibleInput", side_effect=RuntimeError) + def test_get_input_is_a_directory_error(self, mock_interruptible_input, mock_input): io = InputOutput(pretty=False, fancy_input=False) # Windows tests throw UnicodeDecodeError root = "/" rel_fnames = ["existing_file.txt"] diff --git a/tests/hooks/__init__.py b/tests/hooks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/hooks/test_base.py b/tests/hooks/test_base.py new file mode 100644 index 00000000000..83897d5164b --- /dev/null +++ b/tests/hooks/test_base.py @@ -0,0 +1,205 @@ +"""Tests for base hook classes.""" + + +import pytest + +from cecli.hooks import BaseHook, CommandHook, HookType + + +class TestHook(BaseHook): + """Test hook for unit testing.""" + + type = HookType.START + + async def execute(self, coder, metadata): + """Test execution.""" + return True + + +class TestHookWithReturn(BaseHook): + """Test hook that returns a specific value.""" + + type = HookType.PRE_TOOL + + def __init__(self, return_value=True, **kwargs): + super().__init__(**kwargs) + self.return_value = return_value + + async def execute(self, coder, metadata): + """Return the configured value.""" + return self.return_value + + +class TestHookWithException(BaseHook): + """Test hook that raises an exception.""" + + type = HookType.START + + async def execute(self, coder, metadata): + """Raise an exception.""" + raise ValueError("Test exception") + + +class TestBaseHook: + """Test BaseHook class.""" + + def test_hook_creation(self): + """Test basic hook creation.""" + hook = TestHook(name="test_hook", priority=5, enabled=True) + + assert hook.name == "test_hook" + assert hook.priority == 5 + assert hook.enabled is True + assert hook.type == HookType.START + + def 
test_hook_default_name(self): + """Test hook uses class name as default.""" + hook = TestHook() + assert hook.name == "TestHook" + + def test_hook_validation(self): + """Test hook type validation.""" + + class InvalidHook(BaseHook): + # Missing type attribute + pass + + async def execute(self, coder, metadata): + return True + + with pytest.raises(ValueError, match="must define a 'type' attribute"): + InvalidHook() + + def test_hook_repr(self): + """Test hook string representation.""" + hook = TestHook(name="test_hook", priority=10, enabled=False) + repr_str = repr(hook) + + assert "TestHook" in repr_str + assert "name='test_hook'" in repr_str + assert "type=HookType.START" in repr_str or "type=START" in repr_str + assert "priority=10" in repr_str + assert "enabled=False" in repr_str + + @pytest.mark.asyncio + async def test_hook_execution(self): + """Test hook execution.""" + hook = TestHook() + result = await hook.execute(None, {}) + + assert result is True + + @pytest.mark.asyncio + async def test_hook_with_return_value(self): + """Test hook with specific return value.""" + hook = TestHookWithReturn(return_value=False) + result = await hook.execute(None, {}) + + assert result is False + + @pytest.mark.asyncio + async def test_hook_with_exception(self): + """Test hook that raises exception.""" + hook = TestHookWithException() + + with pytest.raises(ValueError, match="Test exception"): + await hook.execute(None, {}) + + +class TestCommandHook: + """Test CommandHook class.""" + + def test_command_hook_creation(self): + """Test command hook creation.""" + hook = CommandHook( + command="echo test", + hook_type=HookType.START, + name="test_command", + priority=5, + enabled=True, + ) + + assert hook.name == "test_command" + assert hook.command == "echo test" + assert hook.type == HookType.START + assert hook.priority == 5 + assert hook.enabled is True + + def test_command_hook_defaults(self): + """Test command hook with defaults.""" + hook = CommandHook(command="echo 
test", hook_type=HookType.START) + hook.type = HookType.END + + assert hook.name == "CommandHook" # Default class name + assert hook.command == "echo test" + assert hook.type == HookType.END + assert hook.priority == 10 # Default priority + assert hook.enabled is True # Default enabled + + @pytest.mark.asyncio + async def test_command_hook_execution(self, mocker): + """Test command hook execution.""" + # Mock run_cmd to avoid actually running commands + mock_run_cmd = mocker.patch("cecli.hooks.base.run_cmd") + mock_run_cmd.return_value = (0, "test output") + + # Mock coder object + mock_coder = mocker.Mock() + mock_coder.io = mocker.Mock() + mock_coder.io.tool_error = mocker.Mock() + mock_coder.root = "/tmp" + mock_coder.verbose = False + + hook = CommandHook(command="echo {test_var}", hook_type=HookType.START) + + metadata = {"test_var": "hello"} + result = await hook.execute(mock_coder, metadata) + + # Check that command was formatted with metadata + mock_run_cmd.assert_called_once() + call_args = mock_run_cmd.call_args[0][0] + assert "echo hello" in call_args + + assert result == 0 # Return code + + @pytest.mark.asyncio + async def test_command_hook_timeout(self, mocker): + """Test command hook timeout.""" + import subprocess + + # Mock run_cmd to raise TimeoutExpired + mock_run_cmd = mocker.patch("cecli.hooks.base.run_cmd") + mock_run_cmd.side_effect = subprocess.TimeoutExpired("echo test", 30) + + # Mock coder object + mock_coder = mocker.Mock() + mock_coder.io = mocker.Mock() + mock_coder.io.tool_error = mocker.Mock() + mock_coder.root = "/tmp" + mock_coder.verbose = False + + hook = CommandHook(command="echo test", hook_type=HookType.START) + + result = await hook.execute(mock_coder, {}) + + assert result == 1 # Non-zero exit code on timeout + + @pytest.mark.asyncio + async def test_command_hook_exception(self, mocker): + """Test command hook with general exception.""" + # Mock run_cmd to raise general exception + mock_run_cmd = 
mocker.patch("cecli.hooks.base.run_cmd") + mock_run_cmd.side_effect = Exception("Test error") + + # Mock coder object + mock_coder = mocker.Mock() + mock_coder.io = mocker.Mock() + mock_coder.io.tool_error = mocker.Mock() + mock_coder.root = "/tmp" + mock_coder.verbose = False + + hook = CommandHook(command="echo test", hook_type=HookType.START) + + result = await hook.execute(mock_coder, {}) + + assert result == 1 # Non-zero exit code on exception diff --git a/tests/hooks/test_manager.py b/tests/hooks/test_manager.py new file mode 100644 index 00000000000..073ee50034c --- /dev/null +++ b/tests/hooks/test_manager.py @@ -0,0 +1,267 @@ +"""Tests for HookManager.""" + +import json + +import pytest + +from cecli.hooks import BaseHook, HookManager, HookType + + +class TestHook(BaseHook): + """Test hook for unit testing.""" + + type = HookType.START + + async def execute(self, coder, metadata): + """Test execution.""" + return True + + +class TestPreToolHook(BaseHook): + """Test hook for pre_tool type.""" + + type = HookType.PRE_TOOL + + async def execute(self, coder, metadata): + """Test execution.""" + return True + + +class TestHookManager: + """Test HookManager class.""" + + def setup_method(self): + """Set up test environment.""" + # Clear singleton instance + HookManager._instance = None + self.manager = HookManager() + + def test_singleton_pattern(self): + """Test that HookManager is a singleton.""" + manager1 = HookManager() + manager2 = HookManager() + + assert manager1 is manager2 + assert manager1._initialized is True + assert manager2._initialized is True + + def test_register_hook(self): + """Test hook registration.""" + hook = TestHook(name="test_hook") + self.manager.register_hook(hook) + + assert self.manager.hook_exists("test_hook") is True + assert "test_hook" in self.manager._hooks_by_name + assert hook in self.manager._hooks_by_type[HookType.START.value] + + def test_register_duplicate_hook(self): + """Test duplicate hook registration fails.""" + 
hook1 = TestHook(name="test_hook") + hook2 = TestHook(name="test_hook") # Same name + + self.manager.register_hook(hook1) + + with pytest.raises(ValueError, match="already exists"): + self.manager.register_hook(hook2) + + def test_get_hooks(self): + """Test getting hooks by type.""" + hook1 = TestHook(name="hook1", priority=10) + hook2 = TestHook(name="hook2", priority=5) # Higher priority + hook3 = TestPreToolHook(name="hook3", priority=10) + + self.manager.register_hook(hook1) + self.manager.register_hook(hook2) + self.manager.register_hook(hook3) + + # Get start hooks + start_hooks = self.manager.get_hooks(HookType.START.value) + assert len(start_hooks) == 2 + + # Should be sorted by priority (lower = higher priority) + assert start_hooks[0].name == "hook2" # priority 5 + assert start_hooks[1].name == "hook1" # priority 10 + + # Get pre_tool hooks + pre_tool_hooks = self.manager.get_hooks(HookType.PRE_TOOL.value) + assert len(pre_tool_hooks) == 1 + assert pre_tool_hooks[0].name == "hook3" + + # Get non-existent type + no_hooks = self.manager.get_hooks("non_existent_type") + assert len(no_hooks) == 0 + + def test_get_all_hooks(self): + """Test getting all hooks grouped by type.""" + hook1 = TestHook(name="hook1") + hook2 = TestHook(name="hook2") + hook3 = TestPreToolHook(name="hook3") + + self.manager.register_hook(hook1) + self.manager.register_hook(hook2) + self.manager.register_hook(hook3) + + all_hooks = self.manager.get_all_hooks() + + assert HookType.START.value in all_hooks + assert HookType.PRE_TOOL.value in all_hooks + assert len(all_hooks[HookType.START.value]) == 2 + assert len(all_hooks[HookType.PRE_TOOL.value]) == 1 + + def test_hook_exists(self): + """Test checking if hook exists.""" + hook = TestHook(name="test_hook") + + assert self.manager.hook_exists("test_hook") is False + + self.manager.register_hook(hook) + + assert self.manager.hook_exists("test_hook") is True + assert self.manager.hook_exists("non_existent") is False + + def 
test_enable_disable_hook(self): + """Test enabling and disabling hooks.""" + hook = TestHook(name="test_hook", enabled=False) + self.manager.register_hook(hook) + + # Initially disabled + start_hooks = self.manager.get_hooks(HookType.START.value) + assert len(start_hooks) == 0 # Disabled hooks not returned + + # Enable hook + result = self.manager.enable_hook("test_hook") + assert result is True + assert hook.enabled is True + + start_hooks = self.manager.get_hooks(HookType.START.value) + assert len(start_hooks) == 1 # Now enabled + + # Disable hook + result = self.manager.disable_hook("test_hook") + assert result is True + assert hook.enabled is False + + start_hooks = self.manager.get_hooks(HookType.START.value) + assert len(start_hooks) == 0 # Disabled again + + def test_enable_nonexistent_hook(self): + """Test enabling non-existent hook.""" + result = self.manager.enable_hook("non_existent") + assert result is False + + def test_disable_nonexistent_hook(self): + """Test disabling non-existent hook.""" + result = self.manager.disable_hook("non_existent") + assert result is False + + def test_state_persistence(self, tmp_path): + """Test hook state persistence.""" + # Create temporary state file + state_file = tmp_path / "hooks_state.json" + + # Monkey-patch state file location + self.manager._state_file = state_file + + # Create and register hooks + hook1 = TestHook(name="hook1", enabled=True) + hook2 = TestHook(name="hook2", enabled=False) + + self.manager.register_hook(hook1) + self.manager.register_hook(hook2) + + # Save state + self.manager._save_state() + + # Verify state file was created + assert state_file.exists() + + # Load and verify state + with open(state_file, "r") as f: + state = json.load(f) + + assert state["hook1"] is True + assert state["hook2"] is False + + # Create new manager instance to test loading + HookManager._instance = None + new_manager = HookManager() + new_manager._state_file = state_file + + # Register hooks with new manager + 
new_hook1 = TestHook(name="hook1", enabled=False) # Default disabled + new_hook2 = TestHook(name="hook2", enabled=True) # Default enabled + + new_manager.register_hook(new_hook1) + new_manager.register_hook(new_hook2) + + # Load state should override defaults + new_manager._load_state() + + assert new_hook1.enabled is True # Loaded from state + assert new_hook2.enabled is False # Loaded from state + + def test_clear(self): + """Test clearing all hooks.""" + hook1 = TestHook(name="hook1") + hook2 = TestHook(name="hook2") + + self.manager.register_hook(hook1) + self.manager.register_hook(hook2) + + assert len(self.manager._hooks_by_name) == 2 + assert len(self.manager._hooks_by_type[HookType.START.value]) == 2 + + self.manager.clear() + + assert len(self.manager._hooks_by_name) == 0 + assert len(self.manager._hooks_by_type) == 0 + + @pytest.mark.asyncio + async def test_call_hooks(self): + """Test calling hooks.""" + + # Create hooks with different return values + class TrueHook(BaseHook): + type = HookType.PRE_TOOL + + async def execute(self, coder, metadata): + return True + + class FalseHook(BaseHook): + type = HookType.PRE_TOOL + + async def execute(self, coder, metadata): + return False + + class ErrorHook(BaseHook): + type = HookType.PRE_TOOL + + async def execute(self, coder, metadata): + raise ValueError("Test error") + + true_hook = TrueHook(name="true_hook") + false_hook = FalseHook(name="false_hook") + error_hook = ErrorHook(name="error_hook") + + self.manager.register_hook(true_hook) + self.manager.register_hook(false_hook) + self.manager.register_hook(error_hook) + + # Test with all hooks enabled + result = await self.manager.call_hooks(HookType.PRE_TOOL.value, None, {}) + assert result is False # false_hook returns False + + # Disable false_hook + self.manager.disable_hook("false_hook") + + # Test with only true_hook and error_hook enabled + result = await self.manager.call_hooks(HookType.PRE_TOOL.value, None, {}) + assert result is True # Only 
true_hook returns True, error is caught + + # Disable all hooks + self.manager.disable_hook("true_hook") + self.manager.disable_hook("error_hook") + + # Test with no enabled hooks + result = await self.manager.call_hooks(HookType.PRE_TOOL.value, None, {}) + assert result is True # No hooks to run = success diff --git a/tests/hooks/test_registry.py b/tests/hooks/test_registry.py new file mode 100644 index 00000000000..3a10432a8c0 --- /dev/null +++ b/tests/hooks/test_registry.py @@ -0,0 +1,365 @@ +"""Tests for HookRegistry.""" + +from pathlib import Path + +import pytest # noqa: F401 +import yaml + +from cecli.hooks import BaseHook, HookManager, HookRegistry, HookType + + +class TestHook(BaseHook): + """Test hook for unit testing.""" + + type = HookType.START + + async def execute(self, coder, metadata): + """Test execution.""" + return True + + +class AnotherTestHook(BaseHook): + """Another test hook.""" + + type = HookType.END + + async def execute(self, coder, metadata): + """Test execution.""" + return True + + +class TestHookRegistry: + """Test HookRegistry class.""" + + def setup_method(self): + """Set up test environment.""" + # Clear singleton instance + HookManager._instance = None + self.manager = HookManager() + self.registry = HookRegistry(self.manager) + + def test_load_hooks_from_directory(self, tmp_path): + """Test loading hooks from a directory.""" + # Create a test hook file + hook_file = tmp_path / "test_hook.py" + hook_file.write_text(""" +from cecli.hooks import BaseHook, HookType + +class TestHook(BaseHook): + type = HookType.START + + async def execute(self, coder, metadata): + return True + +class AnotherHook(BaseHook): + type = HookType.END + + async def execute(self, coder, metadata): + return True +""") + + # Load hooks from directory + loaded = self.registry.load_hooks_from_directory(tmp_path) + + assert len(loaded) == 2 + assert "TestHook" in loaded + assert "AnotherHook" in loaded + + # Verify hooks were registered + assert 
self.manager.hook_exists("TestHook") + assert self.manager.hook_exists("AnotherHook") + + def test_load_hooks_from_empty_directory(self, tmp_path): + """Test loading hooks from empty directory.""" + loaded = self.registry.load_hooks_from_directory(tmp_path) + + assert loaded == [] + + def test_load_hooks_from_nonexistent_directory(self): + """Test loading hooks from non-existent directory.""" + non_existent = Path("/non/existent/directory") + loaded = self.registry.load_hooks_from_directory(non_existent) + + assert loaded == [] + + def test_load_hooks_from_config(self, tmp_path): + """Test loading hooks from YAML configuration.""" + # Create a test hook file first + hook_file = tmp_path / "test_hook.py" + hook_file.write_text(""" +from cecli.hooks import BaseHook, HookType + +class MyStartHook(BaseHook): + type = HookType.START + + async def execute(self, coder, metadata): + return True +""") + + # Create YAML config + config_file = tmp_path / "hooks.yml" + config = { + "hooks": { + "start": [ + { + "name": "MyStartHook", + "file": str(hook_file), + "priority": 5, + "enabled": True, + "description": "Test start hook", + } + ], + "end": [ + { + "name": "cleanup_hook", + "command": "echo 'Cleanup at {timestamp}'", + "priority": 10, + "enabled": False, + "description": "Test cleanup hook", + } + ], + } + } + + with open(config_file, "w") as f: + yaml.dump(config, f) + + # Load hooks from config + loaded = self.registry.load_hooks_from_config(config_file) + + assert len(loaded) == 2 + assert "MyStartHook" in loaded + assert "cleanup_hook" in loaded + + # Verify hooks were registered + assert self.manager.hook_exists("MyStartHook") + assert self.manager.hook_exists("cleanup_hook") + + # Verify properties + start_hook = self.manager._hooks_by_name["MyStartHook"] + assert start_hook.priority == 5 + assert start_hook.enabled is True + assert start_hook.description == "Test start hook" + + cleanup_hook = self.manager._hooks_by_name["cleanup_hook"] + assert 
cleanup_hook.priority == 10 + assert cleanup_hook.enabled is False + assert cleanup_hook.description == "Test cleanup hook" + + def test_load_hooks_from_invalid_config(self, tmp_path): + """Test loading hooks from invalid YAML.""" + config_file = tmp_path / "hooks.yml" + config_file.write_text("invalid: yaml: [") + + loaded = self.registry.load_hooks_from_config(config_file) + + assert loaded == [] + + def test_load_hooks_from_empty_config(self, tmp_path): + """Test loading hooks from empty config.""" + config_file = tmp_path / "hooks.yml" + config_file.write_text("") + + loaded = self.registry.load_hooks_from_config(config_file) + + assert loaded == [] + + def test_load_hooks_from_config_missing_hooks_key(self, tmp_path): + """Test loading hooks from config missing 'hooks' key.""" + config_file = tmp_path / "hooks.yml" + config = {"other_key": "value"} + + with open(config_file, "w") as f: + yaml.dump(config, f) + + loaded = self.registry.load_hooks_from_config(config_file) + + assert loaded == [] + + def test_load_hooks_from_config_invalid_hook_type(self, tmp_path): + """Test loading hooks with invalid hook type.""" + config_file = tmp_path / "hooks.yml" + config = { + "hooks": { + "invalid_type": [ # Not a valid HookType + {"name": "test_hook", "command": "echo test", "priority": 10, "enabled": True} + ] + } + } + + with open(config_file, "w") as f: + yaml.dump(config, f) + + loaded = self.registry.load_hooks_from_config(config_file) + + assert loaded == [] + + def test_load_hooks_from_config_missing_name(self, tmp_path): + """Test loading hooks with missing name.""" + config_file = tmp_path / "hooks.yml" + config = { + "hooks": { + "start": [ + { + # Missing name + "command": "echo test", + "priority": 10, + "enabled": True, + } + ] + } + } + + with open(config_file, "w") as f: + yaml.dump(config, f) + + loaded = self.registry.load_hooks_from_config(config_file) + + assert loaded == [] + + def test_load_hooks_from_config_missing_file_or_command(self, tmp_path): 
+ """Test loading hooks with missing file or command.""" + config_file = tmp_path / "hooks.yml" + config = { + "hooks": { + "start": [ + { + "name": "test_hook", + "priority": 10, + "enabled": True, + # Missing file or command + } + ] + } + } + + with open(config_file, "w") as f: + yaml.dump(config, f) + + loaded = self.registry.load_hooks_from_config(config_file) + + assert loaded == [] + + def test_load_hooks_from_config_nonexistent_file(self, tmp_path): + """Test loading hooks with non-existent file.""" + config_file = tmp_path / "hooks.yml" + config = { + "hooks": { + "start": [ + { + "name": "test_hook", + "file": "/non/existent/file.py", + "priority": 10, + "enabled": True, + } + ] + } + } + + with open(config_file, "w") as f: + yaml.dump(config, f) + + loaded = self.registry.load_hooks_from_config(config_file) + + assert loaded == [] + + def test_load_default_hooks(self, tmp_path, monkeypatch): + """Test loading hooks from default locations.""" + # Mock home directory + mock_home = tmp_path / "home" + mock_home.mkdir() + + # Create .cecli directory structure + cecli_dir = mock_home / ".cecli" + cecli_dir.mkdir() + hooks_dir = cecli_dir / "hooks" + hooks_dir.mkdir() + + # Create a hook file + hook_file = hooks_dir / "my_hook.py" + hook_file.write_text(""" +from cecli.hooks import BaseHook, HookType + +class MyHook(BaseHook): + type = HookType.START + + async def execute(self, coder, metadata): + return True +""") + + # Create YAML config + config_file = cecli_dir / "hooks.yml" + config = { + "hooks": { + "end": [ + {"name": "config_hook", "command": "echo test", "priority": 10, "enabled": True} + ] + } + } + + with open(config_file, "w") as f: + yaml.dump(config, f) + + # Monkey-patch Path.home + monkeypatch.setattr(Path, "home", lambda: mock_home) + + # Load default hooks + loaded = self.registry.load_default_hooks() + + assert len(loaded) == 2 + assert "MyHook" in loaded + assert "config_hook" in loaded + + def test_reload_hooks(self, tmp_path, 
monkeypatch): + """Test reloading hooks.""" + # Mock home directory + mock_home = tmp_path / "home" + mock_home.mkdir() + + # Create .cecli directory structure + cecli_dir = mock_home / ".cecli" + cecli_dir.mkdir() + hooks_dir = cecli_dir / "hooks" + hooks_dir.mkdir() + + # Create a hook file + hook_file = hooks_dir / "my_hook.py" + hook_file.write_text(""" +from cecli.hooks import BaseHook, HookType + +class MyHook(BaseHook): + type = HookType.START + + async def execute(self, coder, metadata): + return True +""") + + # Monkey-patch Path.home + monkeypatch.setattr(Path, "home", lambda: mock_home) + + # Load default hooks + loaded1 = self.registry.load_default_hooks() + assert len(loaded1) == 1 + assert "MyHook" in loaded1 + + # Update the hook file + hook_file.write_text(""" +from cecli.hooks import BaseHook, HookType + +class NewHook(BaseHook): + type = HookType.END + + async def execute(self, coder, metadata): + return True +""") + + # Reload hooks + loaded2 = self.registry.reload_hooks() + + assert len(loaded2) == 1 + assert "NewHook" in loaded2 + assert "MyHook" not in loaded2 # Old hook should be gone + + # Verify manager was cleared + assert not self.manager.hook_exists("MyHook") + assert self.manager.hook_exists("NewHook") diff --git a/tests/tools/test_show_numbered_context.py b/tests/tools/test_show_numbered_context.py index 0ed3b338da4..419956e6aa4 100644 --- a/tests/tools/test_show_numbered_context.py +++ b/tests/tools/test_show_numbered_context.py @@ -4,7 +4,6 @@ import pytest -from cecli.helpers.hashline import hashline from cecli.tools import show_numbered_context @@ -60,24 +59,14 @@ def test_pattern_with_zero_line_number_is_allowed(coder_with_file): ], ) - assert "beta" in result - assert "line 2" in result or "2 | beta" in result + # show_numbered_context now returns a static success message + assert "Successfully retrieved context" in result coder.io.tool_error.assert_not_called() def test_empty_pattern_uses_line_number(coder_with_file): coder, 
file_path = coder_with_file - # Calculate expected hashline for line 2 - content = file_path.read_text() - hashed_content = hashline(content) - # Extract hashline for line 2 - lines = hashed_content.splitlines() - line2_hashline = lines[1] # Index 1 is line 2 (0-indexed) - # hashline format is "{hash_fragment}:{line_num}|{line_content}" - # We need the full hashline (e.g., "BP:2|beta") - expected_hashline = line2_hashline - result = show_numbered_context.Tool.execute( coder, show=[ @@ -90,7 +79,8 @@ def test_empty_pattern_uses_line_number(coder_with_file): ], ) - assert expected_hashline in result + # show_numbered_context now returns a static success message + assert "Successfully retrieved context" in result coder.io.tool_error.assert_not_called()