Skip to content

Commit 0d82ceb

Browse files
committed
Overload error decorator to allow command error handling / error handlers with context
1 parent 9d2b531 commit 0d82ceb

File tree

2 files changed

+138
-11
lines changed

2 files changed

+138
-11
lines changed

matrix/registry.py

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,20 @@
22
import logging
33

44
from collections import defaultdict
5-
from typing import Any, Callable, Coroutine, Optional, Type, Union, Dict, List
5+
from typing import (
6+
TypeVar,
7+
Any,
8+
Callable,
9+
Coroutine,
10+
Literal,
11+
Optional,
12+
Type,
13+
Union,
14+
Dict,
15+
List,
16+
cast,
17+
overload,
18+
)
619

720
from nio import (
821
Event,
@@ -25,6 +38,8 @@
2538
ErrorCallback = Callable[[Exception], Coroutine]
2639
CommandErrorCallback = Callable[[Context, Exception], Coroutine[Any, Any, Any]]
2740

41+
F = TypeVar("F", ErrorCallback, CommandErrorCallback)
42+
2843

2944
class Registry:
3045
"""
@@ -356,42 +371,92 @@ def wrapper(f: Callback) -> Callback:
356371

357372
return wrapper
358373

374+
@overload
375+
def error(
376+
self,
377+
exception: Optional[type[Exception]] = None,
378+
*,
379+
context: Literal[True],
380+
) -> Callable[[CommandErrorCallback], CommandErrorCallback]: ...
381+
382+
@overload
383+
def error(
384+
self,
385+
exception: Optional[type[Exception]] = None,
386+
*,
387+
context: Literal[False] = ...,
388+
) -> Callable[[ErrorCallback], ErrorCallback]: ...
389+
359390
def error(
360-
self, exception: Optional[type[Exception]] = None
361-
) -> Callable[[ErrorCallback], ErrorCallback]:
391+
self,
392+
exception: Optional[type[Exception]] = None,
393+
*,
394+
context: bool = False,
395+
) -> Union[
396+
Callable[[ErrorCallback], ErrorCallback],
397+
Callable[[CommandErrorCallback], CommandErrorCallback],
398+
]:
362399
"""Decorator to register an error handler.
363400
364401
If an exception type is provided, the handler is only invoked for
365402
that specific exception. If omitted, the handler acts as a generic
366403
fallback for any unhandled error.
367404
405+
Set ``context=True`` to receive the command context alongside the error,
406+
useful for command-specific errors where you want to reply to the user.
407+
368408
## Example
369409
370410
```python
411+
@bot.error(CommandNotFoundError, context=True)
412+
async def on_command_not_found(ctx, error):
413+
await ctx.reply("Command not found!")
414+
371415
@bot.error(ValueError)
372416
async def on_value_error(error):
373-
await room.send(f"Bad value: {error}")
417+
pass
374418
375419
@bot.error()
376420
async def on_any_error(error):
377-
await room.send(f"Something went wrong: {error}")
421+
pass
378422
```
379423
"""
380424

381-
def wrapper(func: ErrorCallback) -> ErrorCallback:
425+
def wrapper(
426+
func: F,
427+
) -> F:
382428
if not inspect.iscoroutinefunction(func):
383429
raise TypeError("Error handlers must be coroutines")
384430

385-
if exception:
386-
self._error_handlers[exception] = func
431+
if context:
432+
self._register_command_error(
433+
cast(CommandErrorCallback, func), exception
434+
)
387435
else:
388-
self._fallback_error_handler = func
436+
self._register_error(cast(ErrorCallback, func), exception)
437+
389438
logger.debug(
390439
"registered error handler '%s' on %s",
391440
func.__name__,
392441
type(self).__name__,
393442
)
394-
395443
return func
396444

397445
return wrapper
446+
447+
def _register_error(
448+
self, func: ErrorCallback, exception: Optional[type[Exception]] = None
449+
) -> None:
450+
if not exception:
451+
self._error_handlers[exception] = func
452+
else:
453+
self._fallback_error_handler = func
454+
455+
def _register_command_error(
456+
self,
457+
func: CommandErrorCallback,
458+
exception: Optional[type[Exception]] = None,
459+
) -> None:
460+
if not exception:
461+
exception = Exception
462+
self._command_error_handlers[exception] = func

tests/test_registry.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from matrix.registry import Registry
66
from matrix.command import Command
77
from matrix.group import Group
8-
from matrix.errors import AlreadyRegisteredError
8+
from matrix.errors import AlreadyRegisteredError, CommandNotFoundError, CheckError
99

1010

1111
@pytest.fixture
@@ -344,6 +344,68 @@ async def second_handler(error):
344344
assert registry._error_handlers[ValueError] is second_handler
345345

346346

347+
def test_register_command_error_handler_with_exception_type__expect_handler_in_dict(
348+
registry: Registry,
349+
):
350+
@registry.error(CommandNotFoundError, context=True)
351+
async def on_command_not_found(ctx, error):
352+
pass
353+
354+
assert (
355+
registry._command_error_handlers[CommandNotFoundError] is on_command_not_found
356+
)
357+
358+
359+
def test_register_command_error_handler_with_non_coroutine__expect_type_error(
360+
registry: Registry,
361+
):
362+
with pytest.raises(TypeError):
363+
364+
@registry.error(CommandNotFoundError, context=True)
365+
def sync_handler(ctx, error):
366+
pass
367+
368+
369+
def test_register_multiple_command_error_handlers__expect_all_in_dict(
370+
registry: Registry,
371+
):
372+
@registry.error(CommandNotFoundError, context=True)
373+
async def on_command_not_found(ctx, error):
374+
pass
375+
376+
@registry.error(CheckError, context=True)
377+
async def on_check_error(ctx, error):
378+
pass
379+
380+
assert CommandNotFoundError in registry._command_error_handlers
381+
assert CheckError in registry._command_error_handlers
382+
383+
384+
def test_register_command_error_handler_overwrites_previous__expect_latest_handler(
385+
registry: Registry,
386+
):
387+
@registry.error(CommandNotFoundError, context=True)
388+
async def first_handler(ctx, error):
389+
pass
390+
391+
@registry.error(CommandNotFoundError, context=True)
392+
async def second_handler(ctx, error):
393+
pass
394+
395+
assert registry._command_error_handlers[CommandNotFoundError] is second_handler
396+
397+
398+
def test_register_error_with_context_false__expect_handler_in_error_handlers(
399+
registry: Registry,
400+
):
401+
@registry.error(ValueError, context=False)
402+
async def on_value_error(error):
403+
pass
404+
405+
assert registry._error_handlers[ValueError] is on_value_error
406+
assert ValueError not in registry._command_error_handlers
407+
408+
347409
def test_commands_property_with_empty_registry__expect_empty_dict(registry: Registry):
348410
assert registry.commands == {}
349411

0 commit comments

Comments
 (0)