|
2 | 2 | import logging |
3 | 3 |
|
4 | 4 | from collections import defaultdict |
5 | | -from typing import Any, Callable, Coroutine, Optional, Type, Union, Dict, List |
| 5 | +from typing import ( |
| 6 | + TypeVar, |
| 7 | + Any, |
| 8 | + Callable, |
| 9 | + Coroutine, |
| 10 | + Literal, |
| 11 | + Optional, |
| 12 | + Type, |
| 13 | + Union, |
| 14 | + Dict, |
| 15 | + List, |
| 16 | + cast, |
| 17 | + overload, |
| 18 | +) |
6 | 19 |
|
7 | 20 | from nio import ( |
8 | 21 | Event, |
|
25 | 38 | ErrorCallback = Callable[[Exception], Coroutine] |
26 | 39 | CommandErrorCallback = Callable[[Context, Exception], Coroutine[Any, Any, Any]] |
27 | 40 |
|
| 41 | +F = TypeVar("F", ErrorCallback, CommandErrorCallback) |
| 42 | + |
28 | 43 |
|
29 | 44 | class Registry: |
30 | 45 | """ |
@@ -356,42 +371,92 @@ def wrapper(f: Callback) -> Callback: |
356 | 371 |
|
357 | 372 | return wrapper |
358 | 373 |
|
| 374 | + @overload |
| 375 | + def error( |
| 376 | + self, |
| 377 | + exception: Optional[type[Exception]] = None, |
| 378 | + *, |
| 379 | + context: Literal[True], |
| 380 | + ) -> Callable[[CommandErrorCallback], CommandErrorCallback]: ... |
| 381 | + |
| 382 | + @overload |
| 383 | + def error( |
| 384 | + self, |
| 385 | + exception: Optional[type[Exception]] = None, |
| 386 | + *, |
| 387 | + context: Literal[False] = ..., |
| 388 | + ) -> Callable[[ErrorCallback], ErrorCallback]: ... |
| 389 | + |
359 | 390 | def error( |
360 | | - self, exception: Optional[type[Exception]] = None |
361 | | - ) -> Callable[[ErrorCallback], ErrorCallback]: |
| 391 | + self, |
| 392 | + exception: Optional[type[Exception]] = None, |
| 393 | + *, |
| 394 | + context: bool = False, |
| 395 | + ) -> Union[ |
| 396 | + Callable[[ErrorCallback], ErrorCallback], |
| 397 | + Callable[[CommandErrorCallback], CommandErrorCallback], |
| 398 | + ]: |
362 | 399 | """Decorator to register an error handler. |
363 | 400 |
|
364 | 401 | If an exception type is provided, the handler is only invoked for |
365 | 402 | that specific exception. If omitted, the handler acts as a generic |
366 | 403 | fallback for any unhandled error. |
367 | 404 |
|
| 405 | + Set ``context=True`` to receive the command context alongside the error, |
| 406 | + useful for command-specific errors where you want to reply to the user. |
| 407 | +
|
368 | 408 | ## Example |
369 | 409 |
|
370 | 410 | ```python |
| 411 | + @bot.error(CommandNotFoundError, context=True) |
| 412 | + async def on_command_not_found(ctx, error): |
| 413 | + await ctx.reply("Command not found!") |
| 414 | +
|
371 | 415 | @bot.error(ValueError) |
372 | 416 | async def on_value_error(error): |
373 | | - await room.send(f"Bad value: {error}") |
| 417 | + pass |
374 | 418 |
|
375 | 419 | @bot.error() |
376 | 420 | async def on_any_error(error): |
377 | | - await room.send(f"Something went wrong: {error}") |
| 421 | + pass |
378 | 422 | ``` |
379 | 423 | """ |
380 | 424 |
|
381 | | - def wrapper(func: ErrorCallback) -> ErrorCallback: |
| 425 | + def wrapper( |
| 426 | + func: F, |
| 427 | + ) -> F: |
382 | 428 | if not inspect.iscoroutinefunction(func): |
383 | 429 | raise TypeError("Error handlers must be coroutines") |
384 | 430 |
|
385 | | - if exception: |
386 | | - self._error_handlers[exception] = func |
| 431 | + if context: |
| 432 | + self._register_command_error( |
| 433 | + cast(CommandErrorCallback, func), exception |
| 434 | + ) |
387 | 435 | else: |
388 | | - self._fallback_error_handler = func |
| 436 | + self._register_error(cast(ErrorCallback, func), exception) |
| 437 | + |
389 | 438 | logger.debug( |
390 | 439 | "registered error handler '%s' on %s", |
391 | 440 | func.__name__, |
392 | 441 | type(self).__name__, |
393 | 442 | ) |
394 | | - |
395 | 443 | return func |
396 | 444 |
|
397 | 445 | return wrapper |
| 446 | + |
| 447 | + def _register_error( |
| 448 | + self, func: ErrorCallback, exception: Optional[type[Exception]] = None |
| 449 | + ) -> None: |
| 450 | + if not exception: |
| 451 | + self._error_handlers[exception] = func |
| 452 | + else: |
| 453 | + self._fallback_error_handler = func |
| 454 | + |
| 455 | + def _register_command_error( |
| 456 | + self, |
| 457 | + func: CommandErrorCallback, |
| 458 | + exception: Optional[type[Exception]] = None, |
| 459 | + ) -> None: |
| 460 | + if not exception: |
| 461 | + exception = Exception |
| 462 | + self._command_error_handlers[exception] = func |
0 commit comments