summaryrefslogtreecommitdiff
path: root/poezio/decorators.py
diff options
context:
space:
mode:
Diffstat (limited to 'poezio/decorators.py')
-rw-r--r--poezio/decorators.py161
1 files changed, 113 insertions, 48 deletions
diff --git a/poezio/decorators.py b/poezio/decorators.py
index 4b5d0320..9342161f 100644
--- a/poezio/decorators.py
+++ b/poezio/decorators.py
@@ -1,54 +1,106 @@
"""
Module containing various decorators
"""
-from typing import Any, Callable, List, Optional
+
+from __future__ import annotations
+from asyncio import iscoroutinefunction
+
+from typing import (
+ cast,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ TypeVar,
+ TYPE_CHECKING,
+)
from poezio import common
+if TYPE_CHECKING:
+ from poezio.core.core import Core
+
+
+T = TypeVar('T', bound=Callable[..., Any])
+
+
+BeforeFunc = Optional[Callable[[List[Any], Dict[str, Any]], Any]]
+AfterFunc = Optional[Callable[[Any, List[Any], Dict[str, Any]], Any]]
+
+
+def wrap_generic(func: Callable, before: BeforeFunc = None, after: AfterFunc = None):
+ """
+ Generic wrapper which can both wrap coroutines and normal functions.
+ """
+ def wrap(*args, **kwargs):
+ args = list(args)
+ if before is not None:
+ result = before(args, kwargs)
+ if result is not None:
+ return result
+ result = func(*args, **kwargs)
+ if after is not None:
+ result = after(result, args, kwargs)
+ return result
+
+ async def awrap(*args, **kwargs):
+ args = list(args)
+ if before is not None:
+ result = before(args, kwargs)
+ if result is not None:
+ return result
+ result = await func(*args, **kwargs)
+ if after is not None:
+ result = after(result, args, kwargs)
+ return result
+ if iscoroutinefunction(func):
+ return awrap
+ return wrap
class RefreshWrapper:
- def __init__(self):
+ core: Optional[Core]
+
+ def __init__(self) -> None:
self.core = None
- def conditional(self, func: Callable) -> Callable:
+ def conditional(self, func: T) -> T:
"""
Decorator to refresh the UI if the wrapped function
returns True
"""
+ def after(result: Any, args, kwargs) -> Any:
+ if self.core is not None and result:
+ self.core.refresh_window() # pylint: disable=no-member
+ return result
- def wrap(*args, **kwargs):
- ret = func(*args, **kwargs)
- if self.core and ret:
- self.core.refresh_window()
- return ret
+ wrap = wrap_generic(func, after=after)
- return wrap
+ return cast(T, wrap)
- def always(self, func: Callable) -> Callable:
+ def always(self, func: T) -> T:
"""
Decorator that refreshs the UI no matter what after the function
"""
+ def after(result: Any, args, kwargs) -> Any:
+ if self.core is not None:
+ self.core.refresh_window() # pylint: disable=no-member
+ return result
- def wrap(*args, **kwargs):
- ret = func(*args, **kwargs)
- if self.core:
- self.core.refresh_window()
- return ret
-
- return wrap
+ wrap = wrap_generic(func, after=after)
+ return cast(T, wrap)
- def update(self, func: Callable) -> Callable:
+ def update(self, func: T) -> T:
"""
Decorator that only updates the screen
"""
- def wrap(*args, **kwargs):
- ret = func(*args, **kwargs)
- if self.core:
- self.core.doupdate()
- return ret
-
- return wrap
+ def after(result: Any, args, kwargs) -> Any:
+ if self.core is not None:
+ self.core.doupdate() # pylint: disable=no-member
+ return result
+ wrap = wrap_generic(func, after=after)
+ return cast(T, wrap)
refresh_wrapper = RefreshWrapper()
@@ -61,32 +113,29 @@ class CommandArgParser:
"""
@staticmethod
- def raw(func: Callable) -> Callable:
+ def raw(func: T) -> T:
"""Just call the function with a single string, which is the original string
untouched
"""
-
- def wrap(self, args, *a, **kw):
- return func(self, args, *a, **kw)
-
- return wrap
+ return func
@staticmethod
- def ignored(func: Callable) -> Callable:
+ def ignored(func: T) -> T:
"""
- Call the function without any argument
+ Call the function without textual arguments
"""
+ def before(args: List[Any], kwargs: Dict[Any, Any]) -> None:
+ if len(args) >= 2:
+ del args[1]
- def wrap(self, args=None, *a, **kw):
- return func(self, *a, **kw)
-
- return wrap
+ wrap = wrap_generic(func, before=before)
+ return cast(T, wrap)
@staticmethod
def quoted(mandatory: int,
- optional=0,
+ optional: int = 0,
defaults: Optional[List[Any]] = None,
- ignore_trailing_arguments=False):
+ ignore_trailing_arguments: bool = False) -> Callable[[T], T]:
"""The function receives a list with a number of arguments that is between
the numbers `mandatory` and `optional`.
@@ -131,15 +180,17 @@ class CommandArgParser:
"""
default_args_outer = defaults or []
- def first(func: Callable):
- def second(self, args: str, *a, **kw):
+ def first(func: T) -> T:
+ def before(args: List, kwargs: Dict[str, Any]) -> Any:
default_args = default_args_outer
- if args and args.strip():
- split_args = common.shell_split(args)
+ cmdargs = args[1]
+ if cmdargs and cmdargs.strip():
+ split_args = common.shell_split(cmdargs)
else:
split_args = []
if len(split_args) < mandatory:
- return func(self, None, *a, **kw)
+ args[1] = None
+ return
res, split_args = split_args[:mandatory], split_args[
mandatory:]
if optional == -1:
@@ -154,11 +205,25 @@ class CommandArgParser:
res += default_args
if split_args and res and not ignore_trailing_arguments:
res[-1] += " " + " ".join(split_args)
- return func(self, res, *a, **kw)
+ args[1] = res
+ return
+ wrap = wrap_generic(func, before=before)
+ return cast(T, wrap)
+ return first
- return second
+command_args_parser = CommandArgParser()
- return first
+def deny_anonymous(func: T) -> T:
+ """Decorator to disable commands when using an anonymous account."""
-command_args_parser = CommandArgParser()
+ def before(args: Any, kwargs: Any) -> Any:
+ core = args[0].core
+ if core.xmpp.anon:
+ core.information(
+ 'This command is not available for anonymous accounts.',
+ 'Info'
+ )
+ return False
+ wrap = wrap_generic(func, before=before)
+ return cast(T, wrap)