summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormathieui <mathieui@mathieui.net>2021-01-29 14:54:25 +0100
committerLink Mauve <linkmauve@linkmauve.fr>2021-02-03 15:22:09 +0100
commit695b2ee09a89634159584a8ab519adb363d82f4e (patch)
treec47d1c486f10677bc36eb3cf1c57abfaa2f31996
parentf5ad5199aeaa020e0d6a723341cd578a53c10850 (diff)
downloadpoezio-695b2ee09a89634159584a8ab519adb363d82f4e.tar.gz
poezio-695b2ee09a89634159584a8ab519adb363d82f4e.tar.bz2
poezio-695b2ee09a89634159584a8ab519adb363d82f4e.tar.xz
poezio-695b2ee09a89634159584a8ab519adb363d82f4e.zip
decorators: make decorators work with coroutines
Tried the least ugly solution I could thing of.
-rw-r--r--poezio/decorators.py101
1 files changed, 67 insertions, 34 deletions
diff --git a/poezio/decorators.py b/poezio/decorators.py
index 6a853446..4b5ef1dc 100644
--- a/poezio/decorators.py
+++ b/poezio/decorators.py
@@ -3,11 +3,13 @@ Module containing various decorators
"""
from __future__ import annotations
+from asyncio import iscoroutinefunction
from typing import (
cast,
Any,
Callable,
+ Dict,
List,
Optional,
TypeVar,
@@ -21,6 +23,37 @@ if TYPE_CHECKING:
T = TypeVar('T', bound=Callable[..., Any])
+BeforeFunc = Callable[[List[Any], Dict[str, Any]], Any]
+AfterFunc = Callable[[List[Any], Dict[str, Any]], Any]
+
+def wrap_generic(func: Callable, before: BeforeFunc=None, after: AfterFunc=None):
+ """
+ Generic wrapper which can both wrap coroutines and normal functions.
+ """
+ def wrap(*args, **kwargs):
+ args = list(args)
+ if before is not None:
+ result = before(args, kwargs)
+ if result is not None:
+ return result
+ result = func(*args, **kwargs)
+ if after is not None:
+ result = after(result, args, kwargs)
+ return result
+
+ async def awrap(*args, **kwargs):
+ args = list(args)
+ if before is not None:
+ result = before(args, kwargs)
+ if result is not None:
+ return result
+ result = await func(*args, **kwargs)
+ if after is not None:
+ result = after(result, args, kwargs)
+ return result
+ if iscoroutinefunction(func):
+ return awrap
+ return wrap
class RefreshWrapper:
@@ -32,12 +65,12 @@ class RefreshWrapper:
Decorator to refresh the UI if the wrapped function
returns True
"""
-
- def wrap(*args: Any, **kwargs: Any) -> Any:
- ret = func(*args, **kwargs)
- if self.core and ret:
+ def after(result: Any, args, kwargs) -> Any:
+ if self.core and result:
self.core.refresh_window()
- return ret
+ return result
+
+ wrap = wrap_generic(func, after=after)
return cast(T, wrap)
@@ -45,13 +78,12 @@ class RefreshWrapper:
"""
Decorator that refreshs the UI no matter what after the function
"""
-
- def wrap(*args: Any, **kwargs: Any) -> Any:
- ret = func(*args, **kwargs)
+ def after(result: Any, args, kwargs) -> Any:
if self.core:
self.core.refresh_window()
- return ret
+ return result
+ wrap = wrap_generic(func, after=after)
return cast(T, wrap)
def update(self, func: T) -> T:
@@ -59,12 +91,11 @@ class RefreshWrapper:
Decorator that only updates the screen
"""
- def wrap(*args: Any, **kwargs: Any) -> Any:
- ret = func(*args, **kwargs)
+ def after(result: Any, args, kwargs) -> Any:
if self.core:
self.core.doupdate()
- return ret
-
+ return result
+ wrap = wrap_generic(func, after=after)
return cast(T, wrap)
@@ -82,21 +113,18 @@ class CommandArgParser:
"""Just call the function with a single string, which is the original string
untouched
"""
-
- def wrap(self: Any, args: Any, *a: Any, **kw: Any) -> Any:
- return func(self, args, *a, **kw)
-
- return cast(T, wrap)
+ return func
@staticmethod
def ignored(func: T) -> T:
"""
- Call the function without any argument
+ Call the function without textual arguments
"""
+ def before(args: List[Any], kwargs: Dict[Any, Any]) -> None:
+ if len(args) >= 2:
+ del args[1]
- def wrap(self: Any, args: Any = None, *a: Any, **kw: Any) -> Any:
- return func(self, *a, **kw)
-
+ wrap = wrap_generic(func, before=before)
return cast(T, wrap)
@staticmethod
@@ -149,14 +177,16 @@ class CommandArgParser:
default_args_outer = defaults or []
def first(func: T) -> T:
- def second(self: Any, args: str, *a: Any, **kw: Any) -> Any:
+ def before(args: List, kwargs: Dict[str, Any]) -> Any:
default_args = default_args_outer
- if args and args.strip():
- split_args = common.shell_split(args)
+ cmdargs = args[1]
+ if cmdargs and cmdargs.strip():
+ split_args = common.shell_split(cmdargs)
else:
split_args = []
if len(split_args) < mandatory:
- return func(self, None, *a, **kw)
+ args[1] = None
+ return
res, split_args = split_args[:mandatory], split_args[
mandatory:]
if optional == -1:
@@ -171,22 +201,25 @@ class CommandArgParser:
res += default_args
if split_args and res and not ignore_trailing_arguments:
res[-1] += " " + " ".join(split_args)
- return func(self, res, *a, **kw)
-
- return cast(T, second)
+ args[1] = res
+ return
+ wrap = wrap_generic(func, before=before)
+ return cast(T, wrap)
return first
-
command_args_parser = CommandArgParser()
def deny_anonymous(func: Callable) -> Callable:
"""Decorator to disable commands when using an anonymous account."""
- def wrap(self: RosterInfoTab, *args: Any, **kwargs: Any) -> Any:
- if self.core.xmpp.anon:
- return self.core.information(
+
+ def before(args: Any, kwargs: Any) -> Any:
+ core = args[0].core
+ if core.xmpp.anon:
+ core.information(
'This command is not available for anonymous accounts.',
'Info'
)
- return func(self, *args, **kwargs)
+ return False
+ wrap = wrap_generic(func, before=before)
return cast(T, wrap)