diff options
Diffstat (limited to 'poezio/decorators.py')
-rw-r--r-- | poezio/decorators.py | 161 |
1 files changed, 113 insertions, 48 deletions
diff --git a/poezio/decorators.py b/poezio/decorators.py index 4b5d0320..9342161f 100644 --- a/poezio/decorators.py +++ b/poezio/decorators.py @@ -1,54 +1,106 @@ """ Module containing various decorators """ -from typing import Any, Callable, List, Optional + +from __future__ import annotations +from asyncio import iscoroutinefunction + +from typing import ( + cast, + Any, + Callable, + Dict, + List, + Optional, + TypeVar, + TYPE_CHECKING, +) from poezio import common +if TYPE_CHECKING: + from poezio.core.core import Core + + +T = TypeVar('T', bound=Callable[..., Any]) + + +BeforeFunc = Optional[Callable[[List[Any], Dict[str, Any]], Any]] +AfterFunc = Optional[Callable[[Any, List[Any], Dict[str, Any]], Any]] + + +def wrap_generic(func: Callable, before: BeforeFunc = None, after: AfterFunc = None): + """ + Generic wrapper which can both wrap coroutines and normal functions. + """ + def wrap(*args, **kwargs): + args = list(args) + if before is not None: + result = before(args, kwargs) + if result is not None: + return result + result = func(*args, **kwargs) + if after is not None: + result = after(result, args, kwargs) + return result + + async def awrap(*args, **kwargs): + args = list(args) + if before is not None: + result = before(args, kwargs) + if result is not None: + return result + result = await func(*args, **kwargs) + if after is not None: + result = after(result, args, kwargs) + return result + if iscoroutinefunction(func): + return awrap + return wrap class RefreshWrapper: - def __init__(self): + core: Optional[Core] + + def __init__(self) -> None: self.core = None - def conditional(self, func: Callable) -> Callable: + def conditional(self, func: T) -> T: """ Decorator to refresh the UI if the wrapped function returns True """ + def after(result: Any, args, kwargs) -> Any: + if self.core is not None and result: + self.core.refresh_window() # pylint: disable=no-member + return result - def wrap(*args, **kwargs): - ret = func(*args, **kwargs) - if self.core and ret: - self.core.refresh_window() - return ret + wrap = wrap_generic(func, after=after) - return wrap + return cast(T, wrap) - def always(self, func: Callable) -> Callable: + def always(self, func: T) -> T: """ Decorator that refreshs the UI no matter what after the function """ + def after(result: Any, args, kwargs) -> Any: + if self.core is not None: + self.core.refresh_window() # pylint: disable=no-member + return result - def wrap(*args, **kwargs): - ret = func(*args, **kwargs) - if self.core: - self.core.refresh_window() - return ret - - return wrap + wrap = wrap_generic(func, after=after) + return cast(T, wrap) - def update(self, func: Callable) -> Callable: + def update(self, func: T) -> T: """ Decorator that only updates the screen """ - def wrap(*args, **kwargs): - ret = func(*args, **kwargs) - if self.core: - self.core.doupdate() - return ret - - return wrap + def after(result: Any, args, kwargs) -> Any: + if self.core is not None: + self.core.doupdate() # pylint: disable=no-member + return result + wrap = wrap_generic(func, after=after) + return cast(T, wrap) refresh_wrapper = RefreshWrapper() @@ -61,32 +113,29 @@ class CommandArgParser: """ @staticmethod - def raw(func: Callable) -> Callable: + def raw(func: T) -> T: """Just call the function with a single string, which is the original string untouched """ - - def wrap(self, args, *a, **kw): - return func(self, args, *a, **kw) - - return wrap + return func @staticmethod - def ignored(func: Callable) -> Callable: + def ignored(func: T) -> T: """ - Call the function without any argument + Call the function without textual arguments """ + def before(args: List[Any], kwargs: Dict[Any, Any]) -> None: + if len(args) >= 2: + del args[1] - def wrap(self, args=None, *a, **kw): - return func(self, *a, **kw) - - return wrap + wrap = wrap_generic(func, before=before) + return cast(T, wrap) @staticmethod def quoted(mandatory: int, - optional=0, + optional: int = 0, defaults: Optional[List[Any]] = None, - ignore_trailing_arguments=False): + ignore_trailing_arguments: bool = False) -> Callable[[T], T]: """The function receives a list with a number of arguments that is between the numbers `mandatory` and `optional`. @@ -131,15 +180,17 @@ class CommandArgParser: """ default_args_outer = defaults or [] - def first(func: Callable): - def second(self, args: str, *a, **kw): + def first(func: T) -> T: + def before(args: List, kwargs: Dict[str, Any]) -> Any: default_args = default_args_outer - if args and args.strip(): - split_args = common.shell_split(args) + cmdargs = args[1] + if cmdargs and cmdargs.strip(): + split_args = common.shell_split(cmdargs) else: split_args = [] if len(split_args) < mandatory: - return func(self, None, *a, **kw) + args[1] = None + return res, split_args = split_args[:mandatory], split_args[ mandatory:] if optional == -1: @@ -154,11 +205,25 @@ class CommandArgParser: res += default_args if split_args and res and not ignore_trailing_arguments: res[-1] += " " + " ".join(split_args) - return func(self, res, *a, **kw) + args[1] = res + return + wrap = wrap_generic(func, before=before) + return cast(T, wrap) + return first - return second +command_args_parser = CommandArgParser() - return first +def deny_anonymous(func: T) -> T: + """Decorator to disable commands when using an anonymous account.""" -command_args_parser = CommandArgParser() + def before(args: Any, kwargs: Any) -> Any: + core = args[0].core + if core.xmpp.anon: + core.information( + 'This command is not available for anonymous accounts.', + 'Info' + ) + return False + wrap = wrap_generic(func, before=before) + return cast(T, wrap) |