diff --git a/matrix/registry.py b/matrix/registry.py index f8a6a66..472e835 100644 --- a/matrix/registry.py +++ b/matrix/registry.py @@ -2,7 +2,20 @@ import logging from collections import defaultdict -from typing import Any, Callable, Coroutine, Optional, Type, Union, Dict, List +from typing import ( + TypeVar, + Any, + Callable, + Coroutine, + Literal, + Optional, + Type, + Union, + Dict, + List, + cast, + overload, +) from nio import ( Event, @@ -25,6 +38,8 @@ ErrorCallback = Callable[[Exception], Coroutine] CommandErrorCallback = Callable[[Context, Exception], Coroutine[Any, Any, Any]] +F = TypeVar("F", ErrorCallback, CommandErrorCallback) + class Registry: """ @@ -356,42 +371,92 @@ def wrapper(f: Callback) -> Callback: return wrapper + @overload + def error( + self, + exception: Optional[type[Exception]] = None, + *, + context: Literal[True], + ) -> Callable[[CommandErrorCallback], CommandErrorCallback]: ... + + @overload + def error( + self, + exception: Optional[type[Exception]] = None, + *, + context: Literal[False] = ..., + ) -> Callable[[ErrorCallback], ErrorCallback]: ... + def error( - self, exception: Optional[type[Exception]] = None - ) -> Callable[[ErrorCallback], ErrorCallback]: + self, + exception: Optional[type[Exception]] = None, + *, + context: bool = False, + ) -> Union[ + Callable[[ErrorCallback], ErrorCallback], + Callable[[CommandErrorCallback], CommandErrorCallback], + ]: """Decorator to register an error handler. If an exception type is provided, the handler is only invoked for that specific exception. If omitted, the handler acts as a generic fallback for any unhandled error. + Set ``context=True`` to receive the command context alongside the error, + useful for command-specific errors where you want to reply to the user. + ## Example ```python + @bot.error(CommandNotFoundError, context=True) + async def on_command_not_found(ctx, error): + await ctx.reply("Command not found!") + @bot.error(ValueError) async def on_value_error(error): - await room.send(f"Bad value: {error}") + pass @bot.error() async def on_any_error(error): - await room.send(f"Something went wrong: {error}") + pass ``` """ - def wrapper(func: ErrorCallback) -> ErrorCallback: + def wrapper( + func: F, + ) -> F: if not inspect.iscoroutinefunction(func): raise TypeError("Error handlers must be coroutines") - if exception: - self._error_handlers[exception] = func + if context: + self._register_command_error( + cast(CommandErrorCallback, func), exception + ) else: - self._fallback_error_handler = func + self._register_error(cast(ErrorCallback, func), exception) + logger.debug( "registered error handler '%s' on %s", func.__name__, type(self).__name__, ) - return func return wrapper + + def _register_error( + self, func: ErrorCallback, exception: Optional[type[Exception]] = None + ) -> None: + if exception: + self._error_handlers[exception] = func + else: + self._fallback_error_handler = func + + def _register_command_error( + self, + func: CommandErrorCallback, + exception: Optional[type[Exception]] = None, + ) -> None: + if not exception: + exception = Exception + self._command_error_handlers[exception] = func diff --git a/tests/test_registry.py b/tests/test_registry.py index ebf3a60..6a69dca 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -5,7 +5,7 @@ from matrix.registry import Registry from matrix.command import Command from matrix.group import Group -from matrix.errors import AlreadyRegisteredError +from matrix.errors import AlreadyRegisteredError, CommandNotFoundError, CheckError @pytest.fixture @@ -344,6 +344,68 @@ async def second_handler(error): assert registry._error_handlers[ValueError] is second_handler +def test_register_command_error_handler_with_exception_type__expect_handler_in_dict( + registry: Registry, +): + @registry.error(CommandNotFoundError, context=True) + async def on_command_not_found(ctx, error): + pass + + assert ( + registry._command_error_handlers[CommandNotFoundError] is on_command_not_found + ) + + +def test_register_command_error_handler_with_non_coroutine__expect_type_error( + registry: Registry, +): + with pytest.raises(TypeError): + + @registry.error(CommandNotFoundError, context=True) + def sync_handler(ctx, error): + pass + + +def test_register_multiple_command_error_handlers__expect_all_in_dict( + registry: Registry, +): + @registry.error(CommandNotFoundError, context=True) + async def on_command_not_found(ctx, error): + pass + + @registry.error(CheckError, context=True) + async def on_check_error(ctx, error): + pass + + assert CommandNotFoundError in registry._command_error_handlers + assert CheckError in registry._command_error_handlers + + +def test_register_command_error_handler_overwrites_previous__expect_latest_handler( + registry: Registry, +): + @registry.error(CommandNotFoundError, context=True) + async def first_handler(ctx, error): + pass + + @registry.error(CommandNotFoundError, context=True) + async def second_handler(ctx, error): + pass + + assert registry._command_error_handlers[CommandNotFoundError] is second_handler + + +def test_register_error_with_context_false__expect_handler_in_error_handlers( + registry: Registry, +): + @registry.error(ValueError, context=False) + async def on_value_error(error): + pass + + assert registry._error_handlers[ValueError] is on_value_error + assert ValueError not in registry._command_error_handlers + + def test_commands_property_with_empty_registry__expect_empty_dict(registry: Registry): assert registry.commands == {}