summaryrefslogtreecommitdiff
path: root/slixmpp/xmlstream
diff options
context:
space:
mode:
Diffstat (limited to 'slixmpp/xmlstream')
-rw-r--r--slixmpp/xmlstream/asyncio.py22
-rw-r--r--slixmpp/xmlstream/cert.py22
-rw-r--r--slixmpp/xmlstream/handler/base.py32
-rw-r--r--slixmpp/xmlstream/handler/callback.py28
-rw-r--r--slixmpp/xmlstream/handler/collector.py40
-rw-r--r--slixmpp/xmlstream/handler/coroutine_callback.py39
-rw-r--r--slixmpp/xmlstream/handler/waiter.py45
-rw-r--r--slixmpp/xmlstream/handler/xmlcallback.py5
-rw-r--r--slixmpp/xmlstream/handler/xmlwaiter.py5
-rw-r--r--slixmpp/xmlstream/matcher/base.py9
-rw-r--r--slixmpp/xmlstream/matcher/id.py8
-rw-r--r--slixmpp/xmlstream/matcher/idsender.py23
-rw-r--r--slixmpp/xmlstream/matcher/many.py5
-rw-r--r--slixmpp/xmlstream/matcher/stanzapath.py25
-rw-r--r--slixmpp/xmlstream/matcher/xmlmask.py28
-rw-r--r--slixmpp/xmlstream/matcher/xpath.py17
-rw-r--r--slixmpp/xmlstream/resolver.py70
-rw-r--r--slixmpp/xmlstream/stanzabase.py370
-rw-r--r--slixmpp/xmlstream/tostring.py21
-rw-r--r--slixmpp/xmlstream/xmlstream.py470
20 files changed, 797 insertions, 487 deletions
diff --git a/slixmpp/xmlstream/asyncio.py b/slixmpp/xmlstream/asyncio.py
deleted file mode 100644
index b42b366a..00000000
--- a/slixmpp/xmlstream/asyncio.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""
-asyncio-related utilities
-"""
-
-import asyncio
-from functools import wraps
-
-def future_wrapper(func):
- """
- Make sure the result of a function call is an asyncio.Future()
- object.
- """
- @wraps(func)
- def wrapper(*args, **kwargs):
- result = func(*args, **kwargs)
- if isinstance(result, asyncio.Future):
- return result
- future = asyncio.Future()
- future.set_result(result)
- return future
-
- return wrapper
diff --git a/slixmpp/xmlstream/cert.py b/slixmpp/xmlstream/cert.py
index 28ef585a..41679a7e 100644
--- a/slixmpp/xmlstream/cert.py
+++ b/slixmpp/xmlstream/cert.py
@@ -1,5 +1,6 @@
import logging
from datetime import datetime, timedelta
+from typing import Dict, Set, Tuple, Optional
# Make a call to strptime before starting threads to
# prevent thread safety issues.
@@ -32,13 +33,13 @@ class CertificateError(Exception):
pass
-def decode_str(data):
+def decode_str(data: bytes) -> str:
encoding = 'utf-16-be' if isinstance(data, BMPString) else 'utf-8'
return bytes(data).decode(encoding)
-def extract_names(raw_cert):
- results = {'CN': set(),
+def extract_names(raw_cert: bytes) -> Dict[str, Set[str]]:
+ results: Dict[str, Set[str]] = {'CN': set(),
'DNS': set(),
'SRV': set(),
'URI': set(),
@@ -96,7 +97,7 @@ def extract_names(raw_cert):
return results
-def extract_dates(raw_cert):
+def extract_dates(raw_cert: bytes) -> Tuple[Optional[datetime], Optional[datetime]]:
if not HAVE_PYASN1:
log.warning("Could not find pyasn1 and pyasn1_modules. " + \
"SSL certificate expiration COULD NOT BE VERIFIED.")
@@ -125,24 +126,29 @@ def extract_dates(raw_cert):
return not_before, not_after
-def get_ttl(raw_cert):
+def get_ttl(raw_cert: bytes) -> Optional[timedelta]:
not_before, not_after = extract_dates(raw_cert)
- if not_after is None:
+ if not_after is None or not_before is None:
return None
return not_after - datetime.utcnow()
-def verify(expected, raw_cert):
+def verify(expected: str, raw_cert: bytes) -> Optional[bool]:
if not HAVE_PYASN1:
log.warning("Could not find pyasn1 and pyasn1_modules. " + \
"SSL certificate COULD NOT BE VERIFIED.")
- return
+ return None
not_before, not_after = extract_dates(raw_cert)
cert_names = extract_names(raw_cert)
now = datetime.utcnow()
+ if not not_before or not not_after:
+ raise CertificateError(
+ "Error while checking the dates of the certificate"
+ )
+
if not_before > now:
raise CertificateError(
'Certificate has not entered its valid date range.')
diff --git a/slixmpp/xmlstream/handler/base.py b/slixmpp/xmlstream/handler/base.py
index 1e657777..0bae5674 100644
--- a/slixmpp/xmlstream/handler/base.py
+++ b/slixmpp/xmlstream/handler/base.py
@@ -4,10 +4,19 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
+from __future__ import annotations
+
import weakref
+from weakref import ReferenceType
+from typing import Optional, TYPE_CHECKING, Union
+from slixmpp.xmlstream.matcher.base import MatcherBase
+from xml.etree.ElementTree import Element
+
+if TYPE_CHECKING:
+ from slixmpp.xmlstream import XMLStream, StanzaBase
-class BaseHandler(object):
+class BaseHandler:
"""
Base class for stream handlers. Stream handlers are matched with
@@ -26,8 +35,13 @@ class BaseHandler(object):
:param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream`
instance that the handle will respond to.
"""
+ name: str
+ stream: Optional[ReferenceType[XMLStream]]
+ _destroy: bool
+ _matcher: MatcherBase
+ _payload: Optional[StanzaBase]
- def __init__(self, name, matcher, stream=None):
+ def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None):
#: The name of the handler
self.name = name
@@ -41,33 +55,33 @@ class BaseHandler(object):
self._payload = None
self._matcher = matcher
- def match(self, xml):
+ def match(self, xml: StanzaBase) -> bool:
"""Compare a stanza or XML object with the handler's matcher.
:param xml: An XML or
- :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object
+ :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object
"""
return self._matcher.match(xml)
- def prerun(self, payload):
+ def prerun(self, payload: StanzaBase) -> None:
"""Prepare the handler for execution while the XML
stream is being processed.
- :param payload: A :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
+ :param payload: A :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
object.
"""
self._payload = payload
- def run(self, payload):
+ def run(self, payload: StanzaBase) -> None:
"""Execute the handler after XML stream processing and during the
main event loop.
- :param payload: A :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
+ :param payload: A :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
object.
"""
self._payload = payload
- def check_delete(self):
+ def check_delete(self) -> bool:
"""Check if the handler should be removed from the list
of stream handlers.
"""
diff --git a/slixmpp/xmlstream/handler/callback.py b/slixmpp/xmlstream/handler/callback.py
index 93cec6b7..50dd4c66 100644
--- a/slixmpp/xmlstream/handler/callback.py
+++ b/slixmpp/xmlstream/handler/callback.py
@@ -1,10 +1,17 @@
-
# slixmpp.xmlstream.handler.callback
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
+from __future__ import annotations
+
+from typing import Optional, Callable, Any, TYPE_CHECKING
from slixmpp.xmlstream.handler.base import BaseHandler
+from slixmpp.xmlstream.matcher.base import MatcherBase
+
+if TYPE_CHECKING:
+ from slixmpp.xmlstream.stanzabase import StanzaBase
+ from slixmpp.xmlstream.xmlstream import XMLStream
class Callback(BaseHandler):
@@ -28,8 +35,6 @@ class Callback(BaseHandler):
:param matcher: A :class:`~slixmpp.xmlstream.matcher.base.MatcherBase`
derived object for matching stanza objects.
:param pointer: The function to execute during callback.
- :param bool thread: **DEPRECATED.** Remains only for
- backwards compatibility.
:param bool once: Indicates if the handler should be used only
once. Defaults to False.
:param bool instream: Indicates if the callback should be executed
@@ -38,31 +43,36 @@ class Callback(BaseHandler):
:param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream`
instance this handler should monitor.
"""
+ _once: bool
+ _instream: bool
- def __init__(self, name, matcher, pointer, thread=False,
- once=False, instream=False, stream=None):
+ def __init__(self, name: str, matcher: MatcherBase,
+ pointer: Callable[[StanzaBase], Any],
+ once: bool = False, instream: bool = False,
+ stream: Optional[XMLStream] = None):
BaseHandler.__init__(self, name, matcher, stream)
+ self._pointer: Callable[[StanzaBase], Any] = pointer
self._pointer = pointer
self._once = once
self._instream = instream
- def prerun(self, payload):
+ def prerun(self, payload: StanzaBase) -> None:
"""Execute the callback during stream processing, if
the callback was created with ``instream=True``.
:param payload: The matched
- :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
+ :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
"""
if self._once:
self._destroy = True
if self._instream:
self.run(payload, True)
- def run(self, payload, instream=False):
+ def run(self, payload: StanzaBase, instream: bool = False) -> None:
"""Execute the callback function with the matched stanza payload.
:param payload: The matched
- :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
+ :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
:param bool instream: Force the handler to execute during stream
processing. This should only be used by
:meth:`prerun()`. Defaults to ``False``.
diff --git a/slixmpp/xmlstream/handler/collector.py b/slixmpp/xmlstream/handler/collector.py
index 8d012873..a5ee109c 100644
--- a/slixmpp/xmlstream/handler/collector.py
+++ b/slixmpp/xmlstream/handler/collector.py
@@ -4,11 +4,17 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2012 Nathanael C. Fritz, Lance J.T. Stout
# :license: MIT, see LICENSE for more details
+from __future__ import annotations
+
import logging
-from queue import Queue, Empty
+from typing import List, Optional, TYPE_CHECKING
+from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.handler.base import BaseHandler
+from slixmpp.xmlstream.matcher.base import MatcherBase
+if TYPE_CHECKING:
+ from slixmpp.xmlstream.xmlstream import XMLStream
log = logging.getLogger(__name__)
@@ -27,35 +33,35 @@ class Collector(BaseHandler):
:param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream`
instance this handler should monitor.
"""
+ _stanzas: List[StanzaBase]
- def __init__(self, name, matcher, stream=None):
+ def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None):
BaseHandler.__init__(self, name, matcher, stream=stream)
- self._payload = Queue()
+ self._stanzas = []
- def prerun(self, payload):
+ def prerun(self, payload: StanzaBase) -> None:
"""Store the matched stanza when received during processing.
:param payload: The matched
- :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
+ :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
"""
- self._payload.put(payload)
+ self._stanzas.append(payload)
- def run(self, payload):
+ def run(self, payload: StanzaBase) -> None:
"""Do not process this handler during the main event loop."""
pass
- def stop(self):
+ def stop(self) -> List[StanzaBase]:
"""
Stop collection of matching stanzas, and return the ones that
have been stored so far.
"""
+ stream_ref = self.stream
+ if stream_ref is None:
+ raise ValueError('stop() called without a stream!')
+ stream = stream_ref()
+ if stream is None:
+ raise ValueError('stop() called without a stream!')
self._destroy = True
- results = []
- try:
- while True:
- results.append(self._payload.get(False))
- except Empty:
- pass
-
- self.stream().remove_handler(self.name)
- return results
+ stream.remove_handler(self.name)
+ return self._stanzas
diff --git a/slixmpp/xmlstream/handler/coroutine_callback.py b/slixmpp/xmlstream/handler/coroutine_callback.py
index 6568ba9f..524cca54 100644
--- a/slixmpp/xmlstream/handler/coroutine_callback.py
+++ b/slixmpp/xmlstream/handler/coroutine_callback.py
@@ -4,8 +4,19 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
+from __future__ import annotations
+
+from asyncio import iscoroutinefunction, ensure_future
+from typing import Optional, Callable, Awaitable, TYPE_CHECKING
+
+from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.handler.base import BaseHandler
-from slixmpp.xmlstream.asyncio import asyncio
+from slixmpp.xmlstream.matcher.base import MatcherBase
+
+CoroutineFunction = Callable[[StanzaBase], Awaitable[None]]
+
+if TYPE_CHECKING:
+ from slixmpp.xmlstream.xmlstream import XMLStream
class CoroutineCallback(BaseHandler):
@@ -34,45 +45,49 @@ class CoroutineCallback(BaseHandler):
instance this handler should monitor.
"""
- def __init__(self, name, matcher, pointer, once=False,
- instream=False, stream=None):
+ _once: bool
+ _instream: bool
+
+ def __init__(self, name: str, matcher: MatcherBase,
+ pointer: CoroutineFunction, once: bool = False,
+ instream: bool = False, stream: Optional[XMLStream] = None):
BaseHandler.__init__(self, name, matcher, stream)
- if not asyncio.iscoroutinefunction(pointer):
+ if not iscoroutinefunction(pointer):
raise ValueError("Given function is not a coroutine")
- async def pointer_wrapper(stanza, *args, **kwargs):
+ async def pointer_wrapper(stanza: StanzaBase) -> None:
try:
- await pointer(stanza, *args, **kwargs)
+ await pointer(stanza)
except Exception as e:
stanza.exception(e)
- self._pointer = pointer_wrapper
+ self._pointer: CoroutineFunction = pointer_wrapper
self._once = once
self._instream = instream
- def prerun(self, payload):
+ def prerun(self, payload: StanzaBase) -> None:
"""Execute the callback during stream processing, if
the callback was created with ``instream=True``.
:param payload: The matched
- :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
+ :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
"""
if self._once:
self._destroy = True
if self._instream:
self.run(payload, True)
- def run(self, payload, instream=False):
+ def run(self, payload: StanzaBase, instream: bool = False) -> None:
"""Execute the callback function with the matched stanza payload.
:param payload: The matched
- :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
+ :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
:param bool instream: Force the handler to execute during stream
processing. This should only be used by
:meth:`prerun()`. Defaults to ``False``.
"""
if not self._instream or instream:
- asyncio.ensure_future(self._pointer(payload))
+ ensure_future(self._pointer(payload))
if self._once:
self._destroy = True
del self._pointer
diff --git a/slixmpp/xmlstream/handler/waiter.py b/slixmpp/xmlstream/handler/waiter.py
index 758cd1f1..dde49754 100644
--- a/slixmpp/xmlstream/handler/waiter.py
+++ b/slixmpp/xmlstream/handler/waiter.py
@@ -4,13 +4,20 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
+from __future__ import annotations
+
import logging
-import asyncio
-from asyncio import Queue, wait_for, TimeoutError
+from asyncio import Event, wait_for, TimeoutError
+from typing import Optional, TYPE_CHECKING, Union
+from xml.etree.ElementTree import Element
import slixmpp
+from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.handler.base import BaseHandler
+from slixmpp.xmlstream.matcher.base import MatcherBase
+if TYPE_CHECKING:
+ from slixmpp.xmlstream.xmlstream import XMLStream
log = logging.getLogger(__name__)
@@ -28,24 +35,27 @@ class Waiter(BaseHandler):
:param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream`
instance this handler should monitor.
"""
+ _event: Event
- def __init__(self, name, matcher, stream=None):
+ def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None):
BaseHandler.__init__(self, name, matcher, stream=stream)
- self._payload = Queue()
+ self._event = Event()
- def prerun(self, payload):
+ def prerun(self, payload: StanzaBase) -> None:
"""Store the matched stanza when received during processing.
:param payload: The matched
- :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
+ :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
"""
- self._payload.put_nowait(payload)
+ if not self._event.is_set():
+ self._event.set()
+ self._payload = payload
- def run(self, payload):
+ def run(self, payload: StanzaBase) -> None:
"""Do not process this handler during the main event loop."""
pass
- async def wait(self, timeout=None):
+ async def wait(self, timeout: Optional[int] = None) -> Optional[StanzaBase]:
"""Block an event handler while waiting for a stanza to arrive.
Be aware that this will impact performance if called from a
@@ -59,17 +69,24 @@ class Waiter(BaseHandler):
:class:`~slixmpp.xmlstream.xmlstream.XMLStream.response_timeout`
value.
"""
+ stream_ref = self.stream
+ if stream_ref is None:
+ raise ValueError('wait() called without a stream')
+ stream = stream_ref()
+ if stream is None:
+ raise ValueError('wait() called without a stream')
if timeout is None:
timeout = slixmpp.xmlstream.RESPONSE_TIMEOUT
- stanza = None
try:
- stanza = await self._payload.get()
+ await wait_for(
+ self._event.wait(), timeout, loop=stream.loop
+ )
except TimeoutError:
log.warning("Timed out waiting for %s", self.name)
- self.stream().remove_handler(self.name)
- return stanza
+ stream.remove_handler(self.name)
+ return self._payload
- def check_delete(self):
+ def check_delete(self) -> bool:
"""Always remove waiters after use."""
return True
diff --git a/slixmpp/xmlstream/handler/xmlcallback.py b/slixmpp/xmlstream/handler/xmlcallback.py
index b44b2da1..c1adc815 100644
--- a/slixmpp/xmlstream/handler/xmlcallback.py
+++ b/slixmpp/xmlstream/handler/xmlcallback.py
@@ -4,6 +4,7 @@
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.xmlstream.handler import Callback
+from slixmpp.xmlstream.stanzabase import StanzaBase
class XMLCallback(Callback):
@@ -17,7 +18,7 @@ class XMLCallback(Callback):
run -- Overrides Callback.run
"""
- def run(self, payload, instream=False):
+ def run(self, payload: StanzaBase, instream: bool = False) -> None:
"""
Execute the callback function with the matched stanza's
XML contents, instead of the stanza itself.
@@ -30,4 +31,4 @@ class XMLCallback(Callback):
stream processing. Used only by prerun.
Defaults to False.
"""
- Callback.run(self, payload.xml, instream)
+ Callback.run(self, payload.xml, instream) # type: ignore
diff --git a/slixmpp/xmlstream/handler/xmlwaiter.py b/slixmpp/xmlstream/handler/xmlwaiter.py
index 6eb6577e..f730efec 100644
--- a/slixmpp/xmlstream/handler/xmlwaiter.py
+++ b/slixmpp/xmlstream/handler/xmlwaiter.py
@@ -3,6 +3,7 @@
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
+from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.handler import Waiter
@@ -17,7 +18,7 @@ class XMLWaiter(Waiter):
prerun -- Overrides Waiter.prerun
"""
- def prerun(self, payload):
+ def prerun(self, payload: StanzaBase) -> None:
"""
Store the XML contents of the stanza to return to the
waiting event handler.
@@ -27,4 +28,4 @@ class XMLWaiter(Waiter):
Arguments:
payload -- The matched stanza object.
"""
- Waiter.prerun(self, payload.xml)
+ Waiter.prerun(self, payload.xml) # type: ignore
diff --git a/slixmpp/xmlstream/matcher/base.py b/slixmpp/xmlstream/matcher/base.py
index e2560d2b..552269c5 100644
--- a/slixmpp/xmlstream/matcher/base.py
+++ b/slixmpp/xmlstream/matcher/base.py
@@ -1,10 +1,13 @@
-
# slixmpp.xmlstream.matcher.base
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
+from typing import Any
+from slixmpp.xmlstream.stanzabase import StanzaBase
+
+
class MatcherBase(object):
"""
@@ -15,10 +18,10 @@ class MatcherBase(object):
:param criteria: Object to compare some aspect of a stanza against.
"""
- def __init__(self, criteria):
+ def __init__(self, criteria: Any):
self._criteria = criteria
- def match(self, xml):
+ def match(self, xml: StanzaBase) -> bool:
"""Check if a stanza matches the stored criteria.
Meant to be overridden.
diff --git a/slixmpp/xmlstream/matcher/id.py b/slixmpp/xmlstream/matcher/id.py
index 44df2c18..e4e7ad4e 100644
--- a/slixmpp/xmlstream/matcher/id.py
+++ b/slixmpp/xmlstream/matcher/id.py
@@ -5,6 +5,7 @@
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from slixmpp.xmlstream.matcher.base import MatcherBase
+from slixmpp.xmlstream.stanzabase import StanzaBase
class MatcherId(MatcherBase):
@@ -13,12 +14,13 @@ class MatcherId(MatcherBase):
The ID matcher selects stanzas that have the same stanza 'id'
interface value as the desired ID.
"""
+ _criteria: str
- def match(self, xml):
+ def match(self, xml: StanzaBase) -> bool:
"""Compare the given stanza's ``'id'`` attribute to the stored
``id`` value.
- :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
+ :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to compare against.
"""
- return xml['id'] == self._criteria
+ return bool(xml['id'] == self._criteria)
diff --git a/slixmpp/xmlstream/matcher/idsender.py b/slixmpp/xmlstream/matcher/idsender.py
index 8f5d0303..572f9d87 100644
--- a/slixmpp/xmlstream/matcher/idsender.py
+++ b/slixmpp/xmlstream/matcher/idsender.py
@@ -4,7 +4,19 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
+
from slixmpp.xmlstream.matcher.base import MatcherBase
+from slixmpp.xmlstream.stanzabase import StanzaBase
+from slixmpp.jid import JID
+from slixmpp.types import TypedDict
+
+from typing import Dict
+
+
+class CriteriaType(TypedDict):
+ self: JID
+ peer: JID
+ id: str
class MatchIDSender(MatcherBase):
@@ -14,25 +26,26 @@ class MatchIDSender(MatcherBase):
interface value as the desired ID, and that the 'from' value is one
of a set of approved entities that can respond to a request.
"""
+ _criteria: CriteriaType
- def match(self, xml):
+ def match(self, xml: StanzaBase) -> bool:
"""Compare the given stanza's ``'id'`` attribute to the stored
``id`` value, and verify the sender's JID.
- :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
+ :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to compare against.
"""
selfjid = self._criteria['self']
peerjid = self._criteria['peer']
- allowed = {}
+ allowed: Dict[str, bool] = {}
allowed[''] = True
allowed[selfjid.bare] = True
- allowed[selfjid.host] = True
+ allowed[selfjid.domain] = True
allowed[peerjid.full] = True
allowed[peerjid.bare] = True
- allowed[peerjid.host] = True
+ allowed[peerjid.domain] = True
_from = xml['from']
diff --git a/slixmpp/xmlstream/matcher/many.py b/slixmpp/xmlstream/matcher/many.py
index e8ad54e7..dae45463 100644
--- a/slixmpp/xmlstream/matcher/many.py
+++ b/slixmpp/xmlstream/matcher/many.py
@@ -3,7 +3,9 @@
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
+from typing import Iterable
from slixmpp.xmlstream.matcher.base import MatcherBase
+from slixmpp.xmlstream.stanzabase import StanzaBase
class MatchMany(MatcherBase):
@@ -18,8 +20,9 @@ class MatchMany(MatcherBase):
Methods:
match -- Overrides MatcherBase.match.
"""
+ _criteria: Iterable[MatcherBase]
- def match(self, xml):
+ def match(self, xml: StanzaBase) -> bool:
"""
Match a stanza against multiple criteria. The match is successful
if one of the criteria matches.
diff --git a/slixmpp/xmlstream/matcher/stanzapath.py b/slixmpp/xmlstream/matcher/stanzapath.py
index 1bf3fa8e..b7de3ee2 100644
--- a/slixmpp/xmlstream/matcher/stanzapath.py
+++ b/slixmpp/xmlstream/matcher/stanzapath.py
@@ -4,8 +4,9 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
+from typing import cast, List
from slixmpp.xmlstream.matcher.base import MatcherBase
-from slixmpp.xmlstream.stanzabase import fix_ns
+from slixmpp.xmlstream.stanzabase import fix_ns, StanzaBase
class StanzaPath(MatcherBase):
@@ -17,22 +18,28 @@ class StanzaPath(MatcherBase):
:param criteria: Object to compare some aspect of a stanza against.
"""
-
- def __init__(self, criteria):
- self._criteria = fix_ns(criteria, split=True,
- propagate_ns=False,
- default_ns='jabber:client')
+ _criteria: List[str]
+ _raw_criteria: str
+
+ def __init__(self, criteria: str):
+ self._criteria = cast(
+ List[str],
+ fix_ns(
+ criteria, split=True, propagate_ns=False,
+ default_ns='jabber:client'
+ )
+ )
self._raw_criteria = criteria
- def match(self, stanza):
+ def match(self, stanza: StanzaBase) -> bool:
"""
Compare a stanza against a "stanza path". A stanza path is similar to
an XPath expression, but uses the stanza's interfaces and plugins
instead of the underlying XML. See the documentation for the stanza
- :meth:`~slixmpp.xmlstream.stanzabase.ElementBase.match()` method
+ :meth:`~slixmpp.xmlstream.stanzabase.StanzaBase.match()` method
for more information.
- :param stanza: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
+ :param stanza: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to compare against.
"""
return stanza.match(self._criteria) or stanza.match(self._raw_criteria)
diff --git a/slixmpp/xmlstream/matcher/xmlmask.py b/slixmpp/xmlstream/matcher/xmlmask.py
index d50b706e..b63e0f05 100644
--- a/slixmpp/xmlstream/matcher/xmlmask.py
+++ b/slixmpp/xmlstream/matcher/xmlmask.py
@@ -1,4 +1,3 @@
-
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
@@ -6,8 +5,9 @@
import logging
from xml.parsers.expat import ExpatError
+from xml.etree.ElementTree import Element
-from slixmpp.xmlstream.stanzabase import ET
+from slixmpp.xmlstream.stanzabase import ET, StanzaBase
from slixmpp.xmlstream.matcher.base import MatcherBase
@@ -33,32 +33,33 @@ class MatchXMLMask(MatcherBase):
:param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML
object or XML string to use as a mask.
"""
+ _criteria: Element
- def __init__(self, criteria, default_ns='jabber:client'):
+ def __init__(self, criteria: str, default_ns: str = 'jabber:client'):
MatcherBase.__init__(self, criteria)
if isinstance(criteria, str):
- self._criteria = ET.fromstring(self._criteria)
+ self._criteria = ET.fromstring(criteria)
self.default_ns = default_ns
- def setDefaultNS(self, ns):
+ def setDefaultNS(self, ns: str) -> None:
"""Set the default namespace to use during comparisons.
:param ns: The new namespace to use as the default.
"""
self.default_ns = ns
- def match(self, xml):
+ def match(self, xml: StanzaBase) -> bool:
"""Compare a stanza object or XML object against the stored XML mask.
Overrides MatcherBase.match.
:param xml: The stanza object or XML object to compare against.
"""
- if hasattr(xml, 'xml'):
- xml = xml.xml
- return self._mask_cmp(xml, self._criteria, True)
+ real_xml = xml.xml
+ return self._mask_cmp(real_xml, self._criteria, True)
- def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'):
+ def _mask_cmp(self, source: Element, mask: Element, use_ns: bool = False,
+ default_ns: str = '__no_ns__') -> bool:
"""Compare an XML object against an XML mask.
:param source: The :class:`~xml.etree.ElementTree.Element` XML object
@@ -75,13 +76,6 @@ class MatchXMLMask(MatcherBase):
# If the element was not found. May happen during recursive calls.
return False
- # Convert the mask to an XML object if it is a string.
- if not hasattr(mask, 'attrib'):
- try:
- mask = ET.fromstring(mask)
- except ExpatError:
- log.warning("Expat error: %s\nIn parsing: %s", '', mask)
-
mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
if source.tag not in [mask.tag, mask_ns_tag]:
return False
diff --git a/slixmpp/xmlstream/matcher/xpath.py b/slixmpp/xmlstream/matcher/xpath.py
index b7503b73..bd41b60a 100644
--- a/slixmpp/xmlstream/matcher/xpath.py
+++ b/slixmpp/xmlstream/matcher/xpath.py
@@ -4,7 +4,8 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
-from slixmpp.xmlstream.stanzabase import ET, fix_ns
+from typing import cast
+from slixmpp.xmlstream.stanzabase import ET, fix_ns, StanzaBase
from slixmpp.xmlstream.matcher.base import MatcherBase
@@ -17,23 +18,23 @@ class MatchXPath(MatcherBase):
If the value of :data:`IGNORE_NS` is set to ``True``, then XPath
expressions will be matched without using namespaces.
"""
+ _criteria: str
- def __init__(self, criteria):
- self._criteria = fix_ns(criteria)
+ def __init__(self, criteria: str):
+ self._criteria = cast(str, fix_ns(criteria))
- def match(self, xml):
+ def match(self, xml: StanzaBase) -> bool:
"""
Compare a stanza's XML contents to an XPath expression.
If the value of :data:`IGNORE_NS` is set to ``True``, then XPath
expressions will be matched without using namespaces.
- :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
+ :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to compare against.
"""
- if hasattr(xml, 'xml'):
- xml = xml.xml
+ real_xml = xml.xml
x = ET.Element('x')
- x.append(xml)
+ x.append(real_xml)
return x.find(self._criteria) is not None
diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py
index 97798353..e524da3b 100644
--- a/slixmpp/xmlstream/resolver.py
+++ b/slixmpp/xmlstream/resolver.py
@@ -1,18 +1,32 @@
-
# slixmpp.xmlstream.dns
# ~~~~~~~~~~~~~~~~~~~~~~~
# :copyright: (c) 2012 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
-from slixmpp.xmlstream.asyncio import asyncio
import socket
+import sys
import logging
import random
+from asyncio import Future, AbstractEventLoop
+from typing import Optional, Tuple, Dict, List, Iterable, cast
+from slixmpp.types import Protocol
log = logging.getLogger(__name__)
+class AnswerProtocol(Protocol):
+ host: str
+ priority: int
+ weight: int
+ port: int
+
+
+class ResolverProtocol(Protocol):
+ def query(self, query: str, querytype: str) -> Future:
+ ...
+
+
#: Global flag indicating the availability of the ``aiodns`` package.
#: Installing ``aiodns`` can be done via:
#:
@@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False
try:
import aiodns
AIODNS_AVAILABLE = True
-except ImportError as e:
- log.debug("Could not find aiodns package. " + \
+except ImportError:
+ log.debug("Could not find aiodns package. "
"Not all features will be available")
-def default_resolver(loop):
+def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]:
"""Return a basic DNS resolver object.
:returns: A :class:`aiodns.DNSResolver` object if aiodns
@@ -41,8 +55,11 @@ def default_resolver(loop):
return None
-async def resolve(host, port=None, service=None, proto='tcp',
- resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
+async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
+ service: Optional[str] = None, proto: str = 'tcp',
+ resolver: Optional[ResolverProtocol] = None,
+ use_ipv6: bool = True,
+ use_aiodns: bool = True) -> List[Tuple[str, str, int]]:
"""Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol
@@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp',
if not use_ipv6:
log.debug("DNS: Use of IPv6 has been disabled.")
- if resolver is None and AIODNS_AVAILABLE and use_aiodns:
- resolver = aiodns.DNSResolver(loop=loop)
+ if resolver is None and use_aiodns:
+ resolver = default_resolver(loop=loop)
# An IPv6 literal is allowed to be enclosed in square brackets, but
# the brackets must be stripped in order to process the literal;
@@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
try:
# If `host` is an IPv4 literal, we can return it immediately.
- ipv4 = socket.inet_aton(host)
+ socket.inet_aton(host)
return [(host, host, port)]
except socket.error:
pass
@@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
# Likewise, If `host` is an IPv6 literal, we can return
# it immediately.
if hasattr(socket, 'inet_pton'):
- ipv6 = socket.inet_pton(socket.AF_INET6, host)
+ socket.inet_pton(socket.AF_INET6, host)
return [(host, host, port)]
except (socket.error, ValueError):
pass
@@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp',
return results
-async def get_A(host, resolver=None, use_aiodns=True, loop=None):
+
+async def get_A(host: str, *, loop: AbstractEventLoop,
+ resolver: Optional[ResolverProtocol] = None,
+ use_aiodns: bool = True) -> List[str]:
"""Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# getaddrinfo() method.
if resolver is None or not use_aiodns:
try:
- recs = await loop.getaddrinfo(host, None,
+ inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET,
type=socket.SOCK_STREAM)
- return [rec[4][0] for rec in recs]
+ return [rec[4][0] for rec in inet_recs]
except socket.gaierror:
log.debug("DNS: Error retrieving A address info for %s." % host)
return []
@@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns:
future = resolver.query(host, 'A')
try:
- recs = await future
+ recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
recs = []
return [rec.host for rec in recs]
-async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
+async def get_AAAA(host: str, *, loop: AbstractEventLoop,
+ resolver: Optional[ResolverProtocol] = None,
+ use_aiodns: bool = True) -> List[str]:
"""Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return []
try:
- recs = await loop.getaddrinfo(host, None,
+ inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET6,
type=socket.SOCK_STREAM)
- return [rec[4][0] for rec in recs]
+ return [rec[4][0] for rec in inet_recs]
except (OSError, socket.gaierror):
log.debug("DNS: Error retrieving AAAA address " + \
"info for %s." % host)
@@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns:
future = resolver.query(host, 'AAAA')
try:
- recs = await future
+ recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
recs = []
return [rec.host for rec in recs]
-async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
+
+async def get_SRV(host: str, port: int, service: str,
+ proto: str = 'tcp',
+ resolver: Optional[ResolverProtocol] = None,
+ use_aiodns: bool = True) -> List[Tuple[str, int]]:
"""Perform SRV record resolution for a given host.
.. note::
@@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr
try:
future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV')
- recs = await future
+ recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return []
- answers = {}
+ answers: Dict[int, List[AnswerProtocol]] = {}
for rec in recs:
if rec.priority not in answers:
answers[rec.priority] = []
diff --git a/slixmpp/xmlstream/stanzabase.py b/slixmpp/xmlstream/stanzabase.py
index 7679f73a..2f2faa8d 100644
--- a/slixmpp/xmlstream/stanzabase.py
+++ b/slixmpp/xmlstream/stanzabase.py
@@ -1,4 +1,3 @@
-
# slixmpp.xmlstream.stanzabase
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# module implements a wrapper layer for XML objects
@@ -11,13 +10,34 @@ from __future__ import annotations
import copy
import logging
import weakref
-from typing import Optional
+from typing import (
+ cast,
+ Any,
+ Callable,
+ ClassVar,
+ Coroutine,
+ Dict,
+ List,
+ Iterable,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TYPE_CHECKING,
+ Union,
+)
+from weakref import ReferenceType
from xml.etree import ElementTree as ET
+from slixmpp.types import JidStr
from slixmpp.xmlstream import JID
from slixmpp.xmlstream.tostring import tostring
+if TYPE_CHECKING:
+ from slixmpp.xmlstream import XMLStream
+
+
log = logging.getLogger(__name__)
@@ -28,7 +48,8 @@ XML_TYPE = type(ET.Element('xml'))
XML_NS = 'http://www.w3.org/XML/1998/namespace'
-def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False):
+def register_stanza_plugin(stanza: Type[ElementBase], plugin: Type[ElementBase],
+ iterable: bool = False, overrides: bool = False) -> None:
"""
Associate a stanza object as a plugin for another stanza.
@@ -85,15 +106,15 @@ def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False):
stanza.plugin_overrides[interface] = plugin.plugin_attrib
-def multifactory(stanza, plugin_attrib):
+def multifactory(stanza: Type[ElementBase], plugin_attrib: str) -> Type[ElementBase]:
"""
Returns a ElementBase class for handling reoccuring child stanzas
"""
- def plugin_filter(self):
+ def plugin_filter(self: Multi) -> Callable[..., bool]:
return lambda x: isinstance(x, self._multistanza)
- def plugin_lang_filter(self, lang):
+ def plugin_lang_filter(self: Multi, lang: Optional[str]) -> Callable[..., bool]:
return lambda x: isinstance(x, self._multistanza) and \
x['lang'] == lang
@@ -101,31 +122,41 @@ def multifactory(stanza, plugin_attrib):
"""
Template class for multifactory
"""
- def setup(self, xml=None):
+ _multistanza: Type[ElementBase]
+
+ def setup(self, xml: Optional[ET.Element] = None) -> bool:
self.xml = ET.Element('')
+ return False
- def get_multi(self, lang=None):
- parent = self.parent()
+ def get_multi(self: Multi, lang: Optional[str] = None) -> List[ElementBase]:
+ parent = fail_without_parent(self)
if not lang or lang == '*':
res = filter(plugin_filter(self), parent)
else:
- res = filter(plugin_filter(self, lang), parent)
+ res = filter(plugin_lang_filter(self, lang), parent)
return list(res)
- def set_multi(self, val, lang=None):
- parent = self.parent()
+ def set_multi(self: Multi, val: Iterable[ElementBase], lang: Optional[str] = None) -> None:
+ parent = fail_without_parent(self)
del_multi = getattr(self, 'del_%s' % plugin_attrib)
del_multi(lang)
for sub in val:
parent.append(sub)
- def del_multi(self, lang=None):
- parent = self.parent()
+ def fail_without_parent(self: Multi) -> ElementBase:
+ parent = None
+ if self.parent:
+ parent = self.parent()
+ if not parent:
+ raise ValueError('No stanza parent for multifactory')
+ return parent
+
+ def del_multi(self: Multi, lang: Optional[str] = None) -> None:
+ parent = fail_without_parent(self)
if not lang or lang == '*':
- res = filter(plugin_filter(self), parent)
+ res = list(filter(plugin_filter(self), parent))
else:
- res = filter(plugin_filter(self, lang), parent)
- res = list(res)
+ res = list(filter(plugin_lang_filter(self, lang), parent))
if not res:
del parent.plugins[(plugin_attrib, None)]
parent.loaded_plugins.remove(plugin_attrib)
@@ -149,7 +180,8 @@ def multifactory(stanza, plugin_attrib):
return Multi
-def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''):
+def fix_ns(xpath: str, split: bool = False, propagate_ns: bool = True,
+ default_ns: str = '') -> Union[str, List[str]]:
"""Apply the stanza's namespace to elements in an XPath expression.
:param string xpath: The XPath expression to fix with namespaces.
@@ -275,12 +307,12 @@ class ElementBase(object):
#: The XML tag name of the element, not including any namespace
#: prefixes. For example, an :class:`ElementBase` object for
#: ``<message />`` would use ``name = 'message'``.
- name = 'stanza'
+ name: ClassVar[str] = 'stanza'
#: The XML namespace for the element. Given ``<foo xmlns="bar" />``,
#: then ``namespace = "bar"`` should be used. The default namespace
#: is ``jabber:client`` since this is being used in an XMPP library.
- namespace = 'jabber:client'
+ namespace: str = 'jabber:client'
#: For :class:`ElementBase` subclasses which are intended to be used
#: as plugins, the ``plugin_attrib`` value defines the plugin name.
@@ -290,7 +322,7 @@ class ElementBase(object):
#: register_stanza_plugin(Message, FooPlugin)
#: msg = Message()
#: msg['foo']['an_interface_from_the_foo_plugin']
- plugin_attrib = 'plugin'
+ plugin_attrib: ClassVar[str] = 'plugin'
#: For :class:`ElementBase` subclasses that are intended to be an
#: iterable group of items, the ``plugin_multi_attrib`` value defines
@@ -300,29 +332,29 @@ class ElementBase(object):
#: # Given stanza class Foo, with plugin_multi_attrib = 'foos'
#: parent['foos']
#: filter(isinstance(item, Foo), parent['substanzas'])
- plugin_multi_attrib = ''
+ plugin_multi_attrib: ClassVar[str] = ''
#: The set of keys that the stanza provides for accessing and
#: manipulating the underlying XML object. This set may be augmented
#: with the :attr:`plugin_attrib` value of any registered
#: stanza plugins.
- interfaces = {'type', 'to', 'from', 'id', 'payload'}
+ interfaces: ClassVar[Set[str]] = {'type', 'to', 'from', 'id', 'payload'}
#: A subset of :attr:`interfaces` which maps interfaces to direct
#: subelements of the underlying XML object. Using this set, the text
#: of these subelements may be set, retrieved, or removed without
#: needing to define custom methods.
- sub_interfaces = set()
+ sub_interfaces: ClassVar[Set[str]] = set()
#: A subset of :attr:`interfaces` which maps the presence of
#: subelements to boolean values. Using this set allows for quickly
#: checking for the existence of empty subelements like ``<required />``.
#:
#: .. versionadded:: 1.1
- bool_interfaces = set()
+ bool_interfaces: ClassVar[Set[str]] = set()
#: .. versionadded:: 1.1.2
- lang_interfaces = set()
+ lang_interfaces: ClassVar[Set[str]] = set()
#: In some cases you may wish to override the behaviour of one of the
#: parent stanza's interfaces. The ``overrides`` list specifies the
@@ -336,7 +368,7 @@ class ElementBase(object):
#: be affected.
#:
#: .. versionadded:: 1.0-Beta5
- overrides = []
+ overrides: ClassVar[List[str]] = []
#: If you need to add a new interface to an existing stanza, you
#: can create a plugin and set ``is_extension = True``. Be sure
@@ -346,7 +378,7 @@ class ElementBase(object):
#: parent stanza will be passed to the plugin directly.
#:
#: .. versionadded:: 1.0-Beta5
- is_extension = False
+ is_extension: ClassVar[bool] = False
#: A map of interface operations to the overriding functions.
#: For example, after overriding the ``set`` operation for
@@ -355,15 +387,15 @@ class ElementBase(object):
#: {'set_body': <some function>}
#:
#: .. versionadded: 1.0-Beta5
- plugin_overrides = {}
+ plugin_overrides: ClassVar[Dict[str, str]] = {}
#: A mapping of the :attr:`plugin_attrib` values of registered
#: plugins to their respective classes.
- plugin_attrib_map = {}
+ plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {}
#: A mapping of root element tag names (in ``'{namespace}elementname'``
#: format) to the plugin classes responsible for them.
- plugin_tag_map = {}
+ plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {}
#: The set of stanza classes that can be iterated over using
#: the 'substanzas' interface. Classes are added to this set
@@ -372,17 +404,26 @@ class ElementBase(object):
#: register_stanza_plugin(DiscoInfo, DiscoItem, iterable=True)
#:
#: .. versionadded:: 1.0-Beta5
- plugin_iterables = set()
+ plugin_iterables: ClassVar[Set[Type[ElementBase]]] = set()
#: The default XML namespace: ``http://www.w3.org/XML/1998/namespace``.
- xml_ns = XML_NS
-
- def __init__(self, xml=None, parent=None):
+ xml_ns: ClassVar[str] = XML_NS
+
+ plugins: Dict[Tuple[str, Optional[str]], ElementBase]
+ #: The underlying XML object for the stanza. It is a standard
+ #: :class:`xml.etree.ElementTree` object.
+ xml: ET.Element
+ _index: int
+ loaded_plugins: Set[str]
+ iterables: List[ElementBase]
+ tag: str
+ parent: Optional[ReferenceType[ElementBase]]
+
+ def __init__(self, xml: Optional[ET.Element] = None, parent: Union[Optional[ElementBase], ReferenceType[ElementBase]] = None):
self._index = 0
- #: The underlying XML object for the stanza. It is a standard
- #: :class:`xml.etree.ElementTree` object.
- self.xml = xml
+ if xml is not None:
+ self.xml = xml
#: An ordered dictionary of plugin stanzas, mapped by their
#: :attr:`plugin_attrib` value.
@@ -419,7 +460,7 @@ class ElementBase(object):
existing_xml=child,
reuse=False)
- def setup(self, xml=None):
+ def setup(self, xml: Optional[ET.Element] = None) -> bool:
"""Initialize the stanza's XML contents.
Will return ``True`` if XML was generated according to the stanza's
@@ -429,29 +470,31 @@ class ElementBase(object):
:param xml: An existing XML object to use for the stanza's content
instead of generating new XML.
"""
- if self.xml is None:
+ if hasattr(self, 'xml'):
+ return False
+ if not hasattr(self, 'xml') and xml is not None:
self.xml = xml
+ return False
- last_xml = self.xml
- if self.xml is None:
- # Generate XML from the stanza definition
- for ename in self.name.split('/'):
- new = ET.Element("{%s}%s" % (self.namespace, ename))
- if self.xml is None:
- self.xml = new
- else:
- last_xml.append(new)
- last_xml = new
- if self.parent is not None:
- self.parent().xml.append(self.xml)
- # We had to generate XML
- return True
- else:
- # We did not generate XML
- return False
+ # Generate XML from the stanza definition
+ last_xml = ET.Element('')
+ for ename in self.name.split('/'):
+ new = ET.Element("{%s}%s" % (self.namespace, ename))
+ if not hasattr(self, 'xml'):
+ self.xml = new
+ else:
+ last_xml.append(new)
+ last_xml = new
+ if self.parent is not None:
+ parent = self.parent()
+ if parent:
+ parent.xml.append(self.xml)
+
+ # We had to generate XML
+ return True
- def enable(self, attrib, lang=None):
+ def enable(self, attrib: str, lang: Optional[str] = None) -> ElementBase:
"""Enable and initialize a stanza plugin.
Alias for :meth:`init_plugin`.
@@ -487,7 +530,10 @@ class ElementBase(object):
else:
return None if check else self.init_plugin(name, lang)
- def init_plugin(self, attrib, lang=None, existing_xml=None, element=None, reuse=True):
+ def init_plugin(self, attrib: str, lang: Optional[str] = None,
+ existing_xml: Optional[ET.Element] = None,
+ reuse: bool = True,
+ element: Optional[ElementBase] = None) -> ElementBase:
"""Enable and initialize a stanza plugin.
:param string attrib: The :attr:`plugin_attrib` value of the
@@ -525,7 +571,7 @@ class ElementBase(object):
return plugin
- def _get_stanza_values(self):
+ def _get_stanza_values(self) -> Dict[str, Any]:
"""Return A JSON/dictionary version of the XML content
exposed through the stanza's interfaces::
@@ -567,7 +613,7 @@ class ElementBase(object):
values['substanzas'] = iterables
return values
- def _set_stanza_values(self, values):
+ def _set_stanza_values(self, values: Dict[str, Any]) -> ElementBase:
"""Set multiple stanza interface values using a dictionary.
Stanza plugin values may be set using nested dictionaries.
@@ -623,7 +669,7 @@ class ElementBase(object):
plugin.values = value
return self
- def __getitem__(self, full_attrib):
+ def __getitem__(self, full_attrib: str) -> Any:
"""Return the value of a stanza interface using dict-like syntax.
Example::
@@ -688,7 +734,7 @@ class ElementBase(object):
else:
return ''
- def __setitem__(self, attrib, value):
+ def __setitem__(self, attrib: str, value: Any) -> Any:
"""Set the value of a stanza interface using dictionary-like syntax.
Example::
@@ -773,7 +819,7 @@ class ElementBase(object):
plugin[full_attrib] = value
return self
- def __delitem__(self, attrib):
+ def __delitem__(self, attrib: str) -> Any:
"""Delete the value of a stanza interface using dict-like syntax.
Example::
@@ -851,7 +897,7 @@ class ElementBase(object):
pass
return self
- def _set_attr(self, name, value):
+ def _set_attr(self, name: str, value: Optional[JidStr]) -> None:
"""Set the value of a top level attribute of the XML object.
If the new value is None or an empty string, then the attribute will
@@ -868,7 +914,7 @@ class ElementBase(object):
value = str(value)
self.xml.attrib[name] = value
- def _del_attr(self, name):
+ def _del_attr(self, name: str) -> None:
"""Remove a top level attribute of the XML object.
:param name: The name of the attribute.
@@ -876,7 +922,7 @@ class ElementBase(object):
if name in self.xml.attrib:
del self.xml.attrib[name]
- def _get_attr(self, name, default=''):
+ def _get_attr(self, name: str, default: str = '') -> str:
"""Return the value of a top level attribute of the XML object.
In case the attribute has not been set, a default value can be
@@ -889,7 +935,8 @@ class ElementBase(object):
"""
return self.xml.attrib.get(name, default)
- def _get_sub_text(self, name, default='', lang=None):
+ def _get_sub_text(self, name: str, default: str = '',
+ lang: Optional[str] = None) -> Union[str, Dict[str, str]]:
"""Return the text contents of a sub element.
In case the element does not exist, or it has no textual content,
@@ -900,7 +947,7 @@ class ElementBase(object):
:param default: Optional default to return if the element does
not exists. An empty string is returned otherwise.
"""
- name = self._fix_ns(name)
+ name = cast(str, self._fix_ns(name))
if lang == '*':
return self._get_all_sub_text(name, default, None)
@@ -924,8 +971,9 @@ class ElementBase(object):
return result
return default
- def _get_all_sub_text(self, name, default='', lang=None):
- name = self._fix_ns(name)
+ def _get_all_sub_text(self, name: str, default: str = '',
+ lang: Optional[str] = None) -> Dict[str, str]:
+ name = cast(str, self._fix_ns(name))
default_lang = self.get_lang()
results = {}
@@ -935,10 +983,16 @@ class ElementBase(object):
stanza_lang = stanza.attrib.get('{%s}lang' % XML_NS,
default_lang)
if not lang or lang == '*' or stanza_lang == lang:
- results[stanza_lang] = stanza.text
+ if stanza.text is None:
+ text = default
+ else:
+ text = stanza.text
+ results[stanza_lang] = text
return results
- def _set_sub_text(self, name, text=None, keep=False, lang=None):
+ def _set_sub_text(self, name: str, text: Optional[str] = None,
+ keep: bool = False,
+ lang: Optional[str] = None) -> Optional[ET.Element]:
"""Set the text contents of a sub element.
In case the element does not exist, a element will be created,
@@ -959,15 +1013,16 @@ class ElementBase(object):
lang = default_lang
if not text and not keep:
- return self._del_sub(name, lang=lang)
+ self._del_sub(name, lang=lang)
+ return None
- path = self._fix_ns(name, split=True)
+ path = cast(List[str], self._fix_ns(name, split=True))
name = path[-1]
- parent = self.xml
+ parent: Optional[ET.Element] = self.xml
# The first goal is to find the parent of the subelement, or, if
# we can't find that, the closest grandparent element.
- missing_path = []
+ missing_path: List[str] = []
search_order = path[:-1]
while search_order:
parent = self.xml.find('/'.join(search_order))
@@ -1008,15 +1063,17 @@ class ElementBase(object):
parent.append(element)
return element
- def _set_all_sub_text(self, name, values, keep=False, lang=None):
- self._del_sub(name, lang)
+ def _set_all_sub_text(self, name: str, values: Dict[str, str],
+ keep: bool = False,
+ lang: Optional[str] = None) -> None:
+ self._del_sub(name, lang=lang)
for value_lang, value in values.items():
if not lang or lang == '*' or value_lang == lang:
self._set_sub_text(name, text=value,
keep=keep,
lang=value_lang)
- def _del_sub(self, name, all=False, lang=None):
+ def _del_sub(self, name: str, all: bool = False, lang: Optional[str] = None) -> None:
"""Remove sub elements that match the given name or XPath.
If the element is in a path, then any parent elements that become
@@ -1034,11 +1091,11 @@ class ElementBase(object):
if not lang:
lang = default_lang
- parent = self.xml
+ parent: Optional[ET.Element] = self.xml
for level, _ in enumerate(path):
# Generate the paths to the target elements and their parent.
element_path = "/".join(path[:len(path) - level])
- parent_path = "/".join(path[:len(path) - level - 1])
+ parent_path: Optional[str] = "/".join(path[:len(path) - level - 1])
elements = self.xml.findall(element_path)
if parent_path == '':
@@ -1061,7 +1118,7 @@ class ElementBase(object):
# after deleting the first level of elements.
return
- def match(self, xpath):
+ def match(self, xpath: Union[str, List[str]]) -> bool:
"""Compare a stanza object with an XPath-like expression.
If the XPath matches the contents of the stanza object, the match
@@ -1127,7 +1184,7 @@ class ElementBase(object):
# Everything matched.
return True
- def get(self, key, default=None):
+ def get(self, key: str, default: Optional[Any] = None) -> Any:
"""Return the value of a stanza interface.
If the found value is None or an empty string, return the supplied
@@ -1144,7 +1201,7 @@ class ElementBase(object):
return default
return value
- def keys(self):
+ def keys(self) -> List[str]:
"""Return the names of all stanza interfaces provided by the
stanza object.
@@ -1158,7 +1215,7 @@ class ElementBase(object):
out.append('substanzas')
return out
- def append(self, item):
+ def append(self, item: Union[ET.Element, ElementBase]) -> ElementBase:
"""Append either an XML object or a substanza to this stanza object.
If a substanza object is appended, it will be added to the list
@@ -1189,7 +1246,7 @@ class ElementBase(object):
return self
- def appendxml(self, xml):
+ def appendxml(self, xml: ET.Element) -> ElementBase:
"""Append an XML object to the stanza's XML.
The added XML will not be included in the list of
@@ -1200,7 +1257,7 @@ class ElementBase(object):
self.xml.append(xml)
return self
- def pop(self, index=0):
+ def pop(self, index: int = 0) -> ElementBase:
"""Remove and return the last substanza in the list of
iterable substanzas.
@@ -1212,11 +1269,11 @@ class ElementBase(object):
self.xml.remove(substanza.xml)
return substanza
- def next(self):
+ def next(self) -> ElementBase:
"""Return the next iterable substanza."""
return self.__next__()
- def clear(self):
+ def clear(self) -> ElementBase:
"""Remove all XML element contents and plugins.
Any attribute values will be preserved.
@@ -1229,7 +1286,7 @@ class ElementBase(object):
return self
@classmethod
- def tag_name(cls):
+ def tag_name(cls) -> str:
"""Return the namespaced name of the stanza's root element.
The format for the tag name is::
@@ -1241,29 +1298,32 @@ class ElementBase(object):
"""
return "{%s}%s" % (cls.namespace, cls.name)
- def get_lang(self, lang=None):
+ def get_lang(self, lang: Optional[str] = None) -> str:
result = self.xml.attrib.get('{%s}lang' % XML_NS, '')
- if not result and self.parent and self.parent():
- return self.parent()['lang']
+ if not result and self.parent:
+ parent = self.parent()
+ if parent:
+ return cast(str, parent['lang'])
return result
- def set_lang(self, lang):
+ def set_lang(self, lang: Optional[str]) -> None:
self.del_lang()
attr = '{%s}lang' % XML_NS
if lang:
self.xml.attrib[attr] = lang
- def del_lang(self):
+ def del_lang(self) -> None:
attr = '{%s}lang' % XML_NS
if attr in self.xml.attrib:
del self.xml.attrib[attr]
- def _fix_ns(self, xpath, split=False, propagate_ns=True):
+ def _fix_ns(self, xpath: str, split: bool = False,
+ propagate_ns: bool = True) -> Union[str, List[str]]:
return fix_ns(xpath, split=split,
propagate_ns=propagate_ns,
default_ns=self.namespace)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
"""Compare the stanza object with another to test for equality.
Stanzas are equal if their interfaces return the same values,
@@ -1290,7 +1350,7 @@ class ElementBase(object):
# must be equal.
return True
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
"""Compare the stanza object with another to test for inequality.
Stanzas are not equal if their interfaces return different values,
@@ -1300,16 +1360,16 @@ class ElementBase(object):
"""
return not self.__eq__(other)
- def __bool__(self):
+ def __bool__(self) -> bool:
"""Stanza objects should be treated as True in boolean contexts.
"""
return True
- def __len__(self):
+ def __len__(self) -> int:
"""Return the number of iterable substanzas in this stanza."""
return len(self.iterables)
- def __iter__(self):
+ def __iter__(self) -> ElementBase:
"""Return an iterator object for the stanza's substanzas.
The iterator is the stanza object itself. Attempting to use two
@@ -1318,7 +1378,7 @@ class ElementBase(object):
self._index = 0
return self
- def __next__(self):
+ def __next__(self) -> ElementBase:
"""Return the next iterable substanza."""
self._index += 1
if self._index > len(self.iterables):
@@ -1326,13 +1386,16 @@ class ElementBase(object):
raise StopIteration
return self.iterables[self._index - 1]
- def __copy__(self):
+ def __copy__(self) -> ElementBase:
"""Return a copy of the stanza object that does not share the same
underlying XML object.
"""
- return self.__class__(xml=copy.deepcopy(self.xml), parent=self.parent)
+ return self.__class__(
+ xml=copy.deepcopy(self.xml),
+ parent=self.parent,
+ )
- def __str__(self, top_level_ns=True):
+ def __str__(self, top_level_ns: bool = True) -> str:
"""Return a string serialization of the underlying XML object.
.. seealso:: :ref:`tostring`
@@ -1343,12 +1406,33 @@ class ElementBase(object):
return tostring(self.xml, xmlns='',
top_level=True)
- def __repr__(self):
+ def __repr__(self) -> str:
"""Use the stanza's serialized XML as its representation."""
return self.__str__()
# Compatibility.
_get_plugin = get_plugin
+ get_stanza_values = _get_stanza_values
+ set_stanza_values = _set_stanza_values
+
+ #: A JSON/dictionary version of the XML content exposed through
+ #: the stanza interfaces::
+ #:
+ #: >>> msg = Message()
+ #: >>> msg.values
+ #: {'body': '', 'from': , 'mucnick': '', 'mucroom': '',
+ #: 'to': , 'type': 'normal', 'id': '', 'subject': ''}
+ #:
+ #: Likewise, assigning to the :attr:`values` will change the XML
+ #: content::
+ #:
+ #: >>> msg = Message()
+ #: >>> msg.values = {'body': 'Hi!', 'to': 'user@example.com'}
+ #: >>> msg
+ #: '<message to="user@example.com"><body>Hi!</body></message>'
+ #:
+ #: Child stanzas are exposed as nested dictionaries.
+ values = property(_get_stanza_values, _set_stanza_values) # type: ignore
class StanzaBase(ElementBase):
@@ -1386,9 +1470,14 @@ class StanzaBase(ElementBase):
#: The default XMPP client namespace
namespace = 'jabber:client'
-
- def __init__(self, stream=None, xml=None, stype=None,
- sto=None, sfrom=None, sid=None, parent=None, recv=False):
+ types: ClassVar[Set[str]] = set()
+
+ def __init__(self, stream: Optional[XMLStream] = None,
+ xml: Optional[ET.Element] = None,
+ stype: Optional[str] = None,
+ sto: Optional[JidStr] = None, sfrom: Optional[JidStr] = None,
+ sid: Optional[str] = None,
+ parent: Optional[ElementBase] = None, recv: bool = False):
self.stream = stream
if stream is not None:
self.namespace = stream.default_ns
@@ -1403,7 +1492,7 @@ class StanzaBase(ElementBase):
self['id'] = sid
self.tag = "{%s}%s" % (self.namespace, self.name)
- def set_type(self, value):
+ def set_type(self, value: str) -> StanzaBase:
"""Set the stanza's ``'type'`` attribute.
Only type values contained in :attr:`types` are accepted.
@@ -1414,11 +1503,11 @@ class StanzaBase(ElementBase):
self.xml.attrib['type'] = value
return self
- def get_to(self):
+ def get_to(self) -> JID:
"""Return the value of the stanza's ``'to'`` attribute."""
return JID(self._get_attr('to'))
- def set_to(self, value):
+ def set_to(self, value: JidStr) -> None:
"""Set the ``'to'`` attribute of the stanza.
:param value: A string or :class:`slixmpp.xmlstream.JID` object
@@ -1426,11 +1515,11 @@ class StanzaBase(ElementBase):
"""
return self._set_attr('to', str(value))
- def get_from(self):
+ def get_from(self) -> JID:
"""Return the value of the stanza's ``'from'`` attribute."""
return JID(self._get_attr('from'))
- def set_from(self, value):
+ def set_from(self, value: JidStr) -> None:
"""Set the 'from' attribute of the stanza.
:param from: A string or JID object representing the sender's JID.
@@ -1438,11 +1527,11 @@ class StanzaBase(ElementBase):
"""
return self._set_attr('from', str(value))
- def get_payload(self):
+ def get_payload(self) -> List[ET.Element]:
"""Return a list of XML objects contained in the stanza."""
return list(self.xml)
- def set_payload(self, value):
+ def set_payload(self, value: Union[List[ElementBase], ElementBase]) -> StanzaBase:
"""Add XML content to the stanza.
:param value: Either an XML or a stanza object, or a list
@@ -1454,12 +1543,12 @@ class StanzaBase(ElementBase):
self.append(val)
return self
- def del_payload(self):
+ def del_payload(self) -> StanzaBase:
"""Remove the XML contents of the stanza."""
self.clear()
return self
- def reply(self, clear=True):
+ def reply(self, clear: bool = True) -> StanzaBase:
"""Prepare the stanza for sending a reply.
Swaps the ``'from'`` and ``'to'`` attributes.
@@ -1475,7 +1564,7 @@ class StanzaBase(ElementBase):
new_stanza = copy.copy(self)
# if it's a component, use from
if self.stream and hasattr(self.stream, "is_component") and \
- self.stream.is_component:
+ getattr(self.stream, 'is_component'):
new_stanza['from'], new_stanza['to'] = self['to'], self['from']
else:
new_stanza['to'] = self['from']
@@ -1484,19 +1573,19 @@ class StanzaBase(ElementBase):
new_stanza.clear()
return new_stanza
- def error(self):
+ def error(self) -> StanzaBase:
"""Set the stanza's type to ``'error'``."""
self['type'] = 'error'
return self
- def unhandled(self):
+ def unhandled(self) -> None:
"""Called if no handlers have been registered to process this stanza.
Meant to be overridden.
"""
pass
- def exception(self, e):
+ def exception(self, e: Exception) -> None:
"""Handle exceptions raised during stanza processing.
Meant to be overridden.
@@ -1504,18 +1593,21 @@ class StanzaBase(ElementBase):
log.exception('Error handling {%s}%s stanza', self.namespace,
self.name)
- def send(self):
+ def send(self) -> None:
"""Queue the stanza to be sent on the XML stream."""
- self.stream.send(self)
+ if self.stream is not None:
+ self.stream.send(self)
+ else:
+ log.error("Tried to send stanza without a stream: %s", self)
- def __copy__(self):
+ def __copy__(self) -> StanzaBase:
"""Return a copy of the stanza object that does not share the
same underlying XML object, but does share the same XML stream.
"""
return self.__class__(xml=copy.deepcopy(self.xml),
stream=self.stream)
- def __str__(self, top_level_ns=False):
+ def __str__(self, top_level_ns: bool = False) -> str:
"""Serialize the stanza's XML to a string.
:param bool top_level_ns: Display the top-most namespace.
@@ -1525,27 +1617,3 @@ class StanzaBase(ElementBase):
return tostring(self.xml, xmlns=xmlns,
stream=self.stream,
top_level=(self.stream is None))
-
-
-#: A JSON/dictionary version of the XML content exposed through
-#: the stanza interfaces::
-#:
-#: >>> msg = Message()
-#: >>> msg.values
-#: {'body': '', 'from': , 'mucnick': '', 'mucroom': '',
-#: 'to': , 'type': 'normal', 'id': '', 'subject': ''}
-#:
-#: Likewise, assigning to the :attr:`values` will change the XML
-#: content::
-#:
-#: >>> msg = Message()
-#: >>> msg.values = {'body': 'Hi!', 'to': 'user@example.com'}
-#: >>> msg
-#: '<message to="user@example.com"><body>Hi!</body></message>'
-#:
-#: Child stanzas are exposed as nested dictionaries.
-ElementBase.values = property(ElementBase._get_stanza_values,
- ElementBase._set_stanza_values)
-
-ElementBase.get_stanza_values = ElementBase._get_stanza_values
-ElementBase.set_stanza_values = ElementBase._set_stanza_values
diff --git a/slixmpp/xmlstream/tostring.py b/slixmpp/xmlstream/tostring.py
index efac124e..447c9017 100644
--- a/slixmpp/xmlstream/tostring.py
+++ b/slixmpp/xmlstream/tostring.py
@@ -1,4 +1,3 @@
-
# slixmpp.xmlstream.tostring
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# This module converts XML objects into Unicode strings and
@@ -7,11 +6,20 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
+from __future__ import annotations
+
+from typing import Optional, Set, TYPE_CHECKING
+from xml.etree.ElementTree import Element
+if TYPE_CHECKING:
+ from slixmpp.xmlstream import XMLStream
+
XML_NS = 'http://www.w3.org/XML/1998/namespace'
-def tostring(xml=None, xmlns='', stream=None, outbuffer='',
- top_level=False, open_only=False, namespaces=None):
+def tostring(xml: Optional[Element] = None, xmlns: str = '',
+ stream: Optional[XMLStream] = None, outbuffer: str = '',
+ top_level: bool = False, open_only: bool = False,
+ namespaces: Optional[Set[str]] = None) -> str:
"""Serialize an XML object to a Unicode string.
If an outer xmlns is provided using ``xmlns``, then the current element's
@@ -35,6 +43,8 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='',
:rtype: Unicode string
"""
+ if xml is None:
+ return ''
# Add previous results to the start of the output.
output = [outbuffer]
@@ -123,11 +133,12 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='',
# Remove namespaces introduced in this context. This is necessary
# because the namespaces object continues to be shared with other
# contexts.
- namespaces.remove(ns)
+ if namespaces is not None:
+ namespaces.remove(ns)
return ''.join(output)
-def escape(text, use_cdata=False):
+def escape(text: str, use_cdata: bool = False) -> str:
"""Convert special characters in XML to escape sequences.
:param string text: The XML text to convert.
diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py
index 8d90abf8..30f99071 100644
--- a/slixmpp/xmlstream/xmlstream.py
+++ b/slixmpp/xmlstream/xmlstream.py
@@ -9,17 +9,24 @@
# :license: MIT, see LICENSE for more details
from typing import (
Any,
+ Dict,
+ Awaitable,
+ Generator,
Coroutine,
Callable,
- Iterable,
Iterator,
List,
Optional,
Set,
Union,
Tuple,
+ TypeVar,
+ NoReturn,
+ Type,
+ cast,
)
+import asyncio
import functools
import logging
import socket as Socket
@@ -27,30 +34,66 @@ import ssl
import weakref
import uuid
-from asyncio import iscoroutinefunction, wait, Future
from contextlib import contextmanager
import xml.etree.ElementTree as ET
+from asyncio import (
+ AbstractEventLoop,
+ BaseTransport,
+ Future,
+ Task,
+ TimerHandle,
+ Transport,
+ iscoroutinefunction,
+ wait,
+)
-from slixmpp.xmlstream.asyncio import asyncio
-from slixmpp.xmlstream import tostring
+from slixmpp.types import FilterString
+from slixmpp.xmlstream.tostring import tostring
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
from slixmpp.xmlstream.resolver import resolve, default_resolver
+from slixmpp.xmlstream.handler.base import BaseHandler
+
+T = TypeVar('T')
#: The time in seconds to wait before timing out waiting for response stanzas.
RESPONSE_TIMEOUT = 30
log = logging.getLogger(__name__)
+
+
class ContinueQueue(Exception):
"""
Exception raised in the send queue to "continue" from within an inner loop
"""
+
class NotConnectedError(Exception):
"""
Raised when we try to send something over the wire but we are not
connected.
"""
+
+_T = TypeVar('_T', str, ElementBase, StanzaBase)
+
+
+SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
+AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]
+
+
+Filter = Union[
+ SyncFilter,
+ AsyncFilter,
+]
+
+_FiltersDict = Dict[str, List[Filter]]
+
+Handler = Callable[[Any], Union[
+ Any,
+ Coroutine[Any, Any, Any]
+]]
+
+
class XMLStream(asyncio.BaseProtocol):
"""
An XML stream connection manager and event dispatcher.
@@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0.
"""
- def __init__(self, host='', port=0):
- # The asyncio.Transport object provided by the connection_made()
- # callback when we are connected
- self.transport = None
+ transport: Optional[Transport]
- # The socket that is used internally by the transport object
- self.socket = None
+ # The socket that is used internally by the transport object
+ socket: Optional[ssl.SSLSocket]
+
+ # The backoff of the connect routine (increases exponentially
+ # after each failure)
+ _connect_loop_wait: float
+
+ parser: Optional[ET.XMLPullParser]
+ xml_depth: int
+ xml_root: Optional[ET.Element]
+
+ force_starttls: Optional[bool]
+ disable_starttls: Optional[bool]
+
+ waiting_queue: asyncio.Queue
+
+ # A dict of {name: handle}
+ scheduled_events: Dict[str, TimerHandle]
+
+ ssl_context: ssl.SSLContext
+
+ # The event to trigger when the create_connection() succeeds. It can
+ # be "connected" or "tls_success" depending on the step we are at.
+ event_when_connected: str
+
+ #: The list of accepted ciphers, in OpenSSL Format.
+ #: It might be useful to override it for improved security
+ #: over the python defaults.
+ ciphers: Optional[str]
+
+ #: Path to a file containing certificates for verifying the
+ #: server SSL certificate. A non-``None`` value will trigger
+ #: certificate checking.
+ #:
+ #: .. note::
+ #:
+ #: On Mac OS X, certificates in the system keyring will
+ #: be consulted, even if they are not in the provided file.
+ ca_certs: Optional[str]
+
+ #: Path to a file containing a client certificate to use for
+ #: authenticating via SASL EXTERNAL. If set, there must also
+ #: be a corresponding `:attr:keyfile` value.
+ certfile: Optional[str]
+
+ #: Path to a file containing the private key for the selected
+ #: client certificate to use for authenticating via SASL EXTERNAL.
+ keyfile: Optional[str]
+
+ # The asyncio event loop
+ _loop: Optional[AbstractEventLoop]
+
+ #: The default port to return when querying DNS records.
+ default_port: int
+
+ #: The domain to try when querying DNS records.
+ default_domain: str
+
+ #: The expected name of the server, for validation.
+ _expected_server_name: str
+ _service_name: str
+
+ #: The desired, or actual, address of the connected server.
+ address: Tuple[str, int]
+
+ #: Enable connecting to the server directly over SSL, in
+ #: particular when the service provides two ports: one for
+ #: non-SSL traffic and another for SSL traffic.
+ use_ssl: bool
+
+ #: If set to ``True``, attempt to use IPv6.
+ use_ipv6: bool
+
+ #: If set to ``True``, allow using the ``dnspython`` DNS library
+ #: if available. If set to ``False``, the builtin DNS resolver
+ #: will be used, even if ``dnspython`` is installed.
+ use_aiodns: bool
- # The backoff of the connect routine (increases exponentially
- # after each failure)
+ #: Use CDATA for escaping instead of XML entities. Defaults
+ #: to ``False``.
+ use_cdata: bool
+
+ #: The default namespace of the stream content, not of the
+ #: stream wrapper it
+ default_ns: str
+
+ default_lang: Optional[str]
+ peer_default_lang: Optional[str]
+
+ #: The namespace of the enveloping stream element.
+ stream_ns: str
+
+ #: The default opening tag for the stream element.
+ stream_header: str
+
+ #: The default closing tag for the stream element.
+ stream_footer: str
+
+ #: If ``True``, periodically send a whitespace character over the
+ #: wire to keep the connection alive. Mainly useful for connections
+ #: traversing NAT.
+ whitespace_keepalive: bool
+
+ #: The default interval between keepalive signals when
+ #: :attr:`whitespace_keepalive` is enabled.
+ whitespace_keepalive_interval: int
+
+ #: Flag for controlling if the session can be considered ended
+ #: if the connection is terminated.
+ end_session_on_disconnect: bool
+
+ #: A mapping of XML namespaces to well-known prefixes.
+ namespace_map: dict
+
+ __root_stanza: List[Type[StanzaBase]]
+ __handlers: List[BaseHandler]
+ __event_handlers: Dict[str, List[Tuple[Handler, bool]]]
+ __filters: _FiltersDict
+
+ # Current connection attempt (Future)
+ _current_connection_attempt: Optional[Future]
+
+ #: A list of DNS results that have not yet been tried.
+ _dns_answers: Optional[Iterator[Tuple[str, str, int]]]
+
+ #: The service name to check with DNS SRV records. For
+ #: example, setting this to ``'xmpp-client'`` would query the
+ #: ``_xmpp-client._tcp`` service.
+ dns_service: Optional[str]
+
+ #: The reason why we are disconnecting from the server
+ disconnect_reason: Optional[str]
+
+ #: An asyncio Future being done when the stream is disconnected.
+ disconnected: Future
+
+ # If the session has been started or not
+ _session_started: bool
+ # If we want to bypass the send() check (e.g. unit tests)
+ _always_send_everything: bool
+
+ _run_out_filters: Optional[Future]
+ __slow_tasks: List[Task]
+ __queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]
+
+ def __init__(self, host: str = '', port: int = 0):
+ self.transport = None
+ self.socket = None
self._connect_loop_wait = 0
self.parser = None
@@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
- # The event to trigger when the create_connection() succeeds. It can
- # be "connected" or "tls_success" depending on the step we are at.
self.event_when_connected = "connected"
- #: The list of accepted ciphers, in OpenSSL Format.
- #: It might be useful to override it for improved security
- #: over the python defaults.
self.ciphers = None
- #: Path to a file containing certificates for verifying the
- #: server SSL certificate. A non-``None`` value will trigger
- #: certificate checking.
- #:
- #: .. note::
- #:
- #: On Mac OS X, certificates in the system keyring will
- #: be consulted, even if they are not in the provided file.
self.ca_certs = None
- #: Path to a file containing a client certificate to use for
- #: authenticating via SASL EXTERNAL. If set, there must also
- #: be a corresponding `:attr:keyfile` value.
- self.certfile = None
-
- #: Path to a file containing the private key for the selected
- #: client certificate to use for authenticating via SASL EXTERNAL.
self.keyfile = None
- self._der_cert = None
-
- # The asyncio event loop
self._loop = None
- #: The default port to return when querying DNS records.
self.default_port = int(port)
-
- #: The domain to try when querying DNS records.
self.default_domain = ''
- #: The expected name of the server, for validation.
self._expected_server_name = ''
self._service_name = ''
- #: The desired, or actual, address of the connected server.
self.address = (host, int(port))
- #: Enable connecting to the server directly over SSL, in
- #: particular when the service provides two ports: one for
- #: non-SSL traffic and another for SSL traffic.
self.use_ssl = False
-
- #: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True
- #: If set to ``True``, allow using the ``dnspython`` DNS library
- #: if available. If set to ``False``, the builtin DNS resolver
- #: will be used, even if ``dnspython`` is installed.
self.use_aiodns = True
-
- #: Use CDATA for escaping instead of XML entities. Defaults
- #: to ``False``.
self.use_cdata = False
- #: The default namespace of the stream content, not of the
- #: stream wrapper itself.
self.default_ns = ''
self.default_lang = None
self.peer_default_lang = None
- #: The namespace of the enveloping stream element.
self.stream_ns = ''
-
- #: The default opening tag for the stream element.
self.stream_header = "<stream>"
-
- #: The default closing tag for the stream element.
self.stream_footer = "</stream>"
- #: If ``True``, periodically send a whitespace character over the
- #: wire to keep the connection alive. Mainly useful for connections
- #: traversing NAT.
self.whitespace_keepalive = True
-
- #: The default interval between keepalive signals when
- #: :attr:`whitespace_keepalive` is enabled.
self.whitespace_keepalive_interval = 300
- #: Flag for controlling if the session can be considered ended
- #: if the connection is terminated.
self.end_session_on_disconnect = True
-
- #: A mapping of XML namespaces to well-known prefixes.
self.namespace_map = {StanzaBase.xml_ns: 'xml'}
self.__root_stanza = []
self.__handlers = []
self.__event_handlers = {}
- self.__filters = {'in': [], 'out': [], 'out_sync': []}
+ self.__filters = {
+ 'in': [], 'out': [], 'out_sync': []
+ }
- # Current connection attempt (Future)
self._current_connection_attempt = None
- #: A list of DNS results that have not yet been tried.
- self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
-
- #: The service name to check with DNS SRV records. For
- #: example, setting this to ``'xmpp-client'`` would query the
- #: ``_xmpp-client._tcp`` service.
+ self._dns_answers = None
self.dns_service = None
- #: The reason why we are disconnecting from the server
self.disconnect_reason = None
-
- #: An asyncio Future being done when the stream is disconnected.
- self.disconnected: Future = Future()
-
- # If the session has been started or not
+ self.disconnected = Future()
self._session_started = False
- # If we want to bypass the send() check (e.g. unit tests)
self._always_send_everything = False
self.add_event_handler('disconnected', self._remove_schedules)
@@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('session_start', self._set_session_start)
self.add_event_handler('session_resumed', self._set_session_start)
- self._run_out_filters: Optional[Future] = None
- self.__slow_tasks: List[Future] = []
- self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = []
+ self._run_out_filters = None
+ self.__slow_tasks = []
+ self.__queued_stanzas = []
@property
- def loop(self):
+ def loop(self) -> AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_event_loop()
return self._loop
@loop.setter
- def loop(self, value):
+ def loop(self, value: AbstractEventLoop) -> None:
self._loop = value
- def new_id(self):
+ def new_id(self) -> str:
"""Generate and return a new stream ID in hexadecimal form.
Many stanzas, handlers, or matchers may require unique
@@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return uuid.uuid4().hex
- def _set_session_start(self, event):
+ def _set_session_start(self, event: Any) -> None:
"""
On session start, queue all pending stanzas to be sent.
"""
@@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol):
self.waiting_queue.put_nowait(stanza)
self.__queued_stanzas = []
- def _set_disconnected(self, event):
+ def _set_disconnected(self, event: Any) -> None:
self._session_started = False
- def _set_disconnected_future(self):
+ def _set_disconnected_future(self) -> None:
"""Set the self.disconnected future on disconnect"""
if not self.disconnected.done():
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
- def connect(self, host='', port=0, use_ssl=False,
- force_starttls=True, disable_starttls=False):
+ def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
+ force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
"""Create a new socket and connect to the server.
:param host: The name of the desired server for the connection.
@@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
- async def _connect_routine(self):
+ async def _connect_routine(self) -> None:
self.event_when_connected = "connected"
if self._connect_loop_wait > 0:
@@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
# and try (host, port) as a last resort
self._dns_answers = None
+ ssl_context: Optional[ssl.SSLContext]
if self.use_ssl:
ssl_context = self.get_ssl_context()
else:
@@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
- def process(self, *, forever=True, timeout=None):
+ def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the
socket(s), calling various registered callbacks, calling expired
timers, handling signal events, etc). If timeout is None, this
@@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.loop.run_until_complete(self.disconnected)
else:
- tasks = [asyncio.sleep(timeout, loop=self.loop)]
+ tasks: List[Future] = [asyncio.sleep(timeout, loop=self.loop)]
if not forever:
tasks.append(self.disconnected)
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
- def init_parser(self):
+ def init_parser(self) -> None:
"""init the XML parser. The parser must always be reset for each new
connexion
"""
@@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol):
self.xml_root = None
self.parser = ET.XMLPullParser(("start", "end"))
- def connection_made(self, transport):
+ def connection_made(self, transport: BaseTransport) -> None:
"""Called when the TCP connection has been established with the server
"""
self.event(self.event_when_connected)
- self.transport = transport
+ self.transport = cast(Transport, transport)
+ if self.transport is None:
+ raise ValueError("Transport cannot be none")
self.socket = self.transport.get_extra_info(
"ssl_object",
default=self.transport.get_extra_info("socket")
@@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol):
self.send_raw(self.stream_header)
self._dns_answers = None
- def data_received(self, data):
+ def data_received(self, data: bytes) -> None:
"""Called when incoming data is received on the socket.
We feed that data to the parser and the see if this produced any XML
@@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol):
self.send(error)
self.disconnect()
- def is_connecting(self):
+ def is_connecting(self) -> bool:
return self._current_connection_attempt is not None
- def is_connected(self):
+ def is_connected(self) -> bool:
return self.transport is not None
- def eof_received(self):
+ def eof_received(self) -> None:
"""When the TCP connection is properly closed by the remote end
"""
self.event("eof_received")
- def connection_lost(self, exception):
+ def connection_lost(self, exception: Optional[BaseException]) -> None:
"""On any kind of disconnection, initiated by us or not. This signals the
closure of the TCP connection
"""
@@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol):
self._reset_sendq()
self.event('session_end')
self._set_disconnected_future()
- self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
+ self.event("disconnected", self.disconnect_reason or exception)
- def cancel_connection_attempt(self):
+ def cancel_connection_attempt(self) -> None:
"""
Immediately cancel the current create_connection() Future.
This is useful when a client using slixmpp tries to connect
@@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol):
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
# schedule call below. It would fortunately be converted to `1` later
# down the call chain. Praise the implicit casts lord.
- if wait == True:
+ if wait is True:
wait = 2.0
if self.transport:
@@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self._set_disconnected_future()
self.event("disconnected", reason)
- future = Future()
+ future: Future = Future()
future.set_result(None)
return future
- async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
+ async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
"""Wait until the send queue is empty before disconnecting"""
try:
await asyncio.wait_for(
@@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = reason
await self._end_stream_wait(wait)
- async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
+ async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
"""
Run abort() if we do not received the disconnected event
after a waiting time.
@@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol):
# that means the disconnect has already been handled
pass
- def abort(self):
+ def abort(self) -> None:
"""
Forcibly close the connection
"""
@@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol):
self.transport.abort()
self.event("killed")
- def reconnect(self, wait=2.0, reason="Reconnecting"):
+ def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
"""Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect()
"""
log.debug("reconnecting...")
- async def handler(event):
+ async def handler(event: Any) -> None:
# We yield here to allow synchronous handlers to work first
await asyncio.sleep(0, loop=self.loop)
self.connect()
self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason)
- def configure_socket(self):
+ def configure_socket(self) -> None:
"""Set timeout and other options for self.socket.
Meant to be overridden.
"""
pass
- def configure_dns(self, resolver, domain=None, port=None):
+ def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
"""
Configure and set options for a :class:`~dns.resolver.Resolver`
instance, and other DNS related tasks. For example, you
@@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
- def get_ssl_context(self):
+ def get_ssl_context(self) -> ssl.SSLContext:
"""
Get SSL context.
"""
@@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context
- async def start_tls(self):
+ async def start_tls(self) -> bool:
"""Perform handshakes for TLS.
If the handshake is successful, the XML stream will need
to be restarted.
"""
+ if self.transport is None:
+ raise ValueError("Transport should not be None")
self.event_when_connected = "tls_success"
ssl_context = self.get_ssl_context()
try:
@@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol):
self.connection_made(transp)
return True
- def _start_keepalive(self, event):
+ def _start_keepalive(self, event: Any) -> None:
"""Begin sending whitespace periodically to keep the connection alive.
May be disabled by setting::
@@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol):
args=(' ',),
repeat=True)
- def _remove_schedules(self, event):
+ def _remove_schedules(self, event: Any) -> None:
"""Remove some schedules that become pointless when disconnected"""
self.cancel_schedule('Whitespace Keepalive')
- def start_stream_handler(self, xml):
+ def start_stream_handler(self, xml: ET.Element) -> None:
"""Perform any initialization actions, such as handshakes,
once the stream header has been sent.
@@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
- def register_stanza(self, stanza_class):
+ def register_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Add a stanza object class as a known root stanza.
A root stanza is one that appears as a direct child of the stream's
@@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.__root_stanza.append(stanza_class)
- def remove_stanza(self, stanza_class):
+ def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Remove a stanza from being a known root stanza.
A root stanza is one that appears as a direct child of the stream's
@@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.__root_stanza.remove(stanza_class)
- def add_filter(self, mode, handler, order=None):
+ def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None:
"""Add a filter for incoming or outgoing stanzas.
These filters are applied before incoming stanzas are
@@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.__filters[mode].append(handler)
- def del_filter(self, mode, handler):
+ def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None:
"""Remove an incoming or outgoing filter."""
self.__filters[mode].remove(handler)
- def register_handler(self, handler, before=None, after=None):
+ def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None:
"""Add a stream event handler that will be executed when a matching
stanza is received.
@@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__handlers.append(handler)
handler.stream = weakref.ref(self)
- def remove_handler(self, name):
+ def remove_handler(self, name: str) -> bool:
"""Remove any stream event handlers with the given name.
:param name: The name of the handler.
@@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol):
try:
return next(self._dns_answers)
except StopIteration:
- return
+ return None
- def add_event_handler(self, name, pointer, disposable=False):
+ def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None:
"""Add a custom event handler that will be executed whenever
its event is manually triggered.
@@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__event_handlers[name] = []
self.__event_handlers[name].append((pointer, disposable))
- def del_event_handler(self, name, pointer):
+ def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None:
"""Remove a function as a handler for an event.
:param name: The name of the event.
@@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol):
# Need to keep handlers that do not use
# the given function pointer
- def filter_pointers(handler):
+ def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool:
return handler[0] != pointer
self.__event_handlers[name] = list(filter(
filter_pointers,
self.__event_handlers[name]))
- def event_handled(self, name):
+ def event_handled(self, name: str) -> int:
"""Returns the number of registered handlers for an event.
:param name: The name of the event to check.
"""
return len(self.__event_handlers.get(name, []))
- async def event_async(self, name: str, data: Any = {}):
+ async def event_async(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event, but await coroutines immediately.
This event generator should only be called in situations when
@@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol):
except Exception as e:
self.exception(e)
- def event(self, name: str, data: Any = {}):
+ def event(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event.
Coroutine handlers are wrapped into a future and sent into the
event loop for their execution, and not awaited.
@@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol):
# If the callback is a coroutine, schedule it instead of
# running it directly
if iscoroutinefunction(handler_callback):
- async def handler_callback_routine(cb):
+ async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None:
try:
await cb(data)
except Exception as e:
@@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol):
except ValueError:
pass
- def schedule(self, name, seconds, callback, args=tuple(),
- kwargs={}, repeat=False):
+ def schedule(self, name: str, seconds: int, callback: Callable[..., None],
+ args: Tuple[Any, ...] = tuple(),
+ kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None:
"""Schedule a callback function to execute after a given delay.
:param name: A unique name for the scheduled callback.
@@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol):
# canceling scheduled_events[name]
self.scheduled_events[name] = handle
- def cancel_schedule(self, name):
+ def cancel_schedule(self, name: str) -> None:
try:
handle = self.scheduled_events.pop(name)
handle.cancel()
except KeyError:
log.debug("Tried to cancel unscheduled event: %s" % (name,))
- def _safe_cb_run(self, name, cb):
+ def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None:
log.debug('Scheduled event: %s', name)
try:
cb()
except Exception as e:
self.exception(e)
- def _execute_and_reschedule(self, name, cb, seconds):
+ def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None:
"""Simple method that calls the given callback, and then schedule itself to
be called after the given number of seconds.
"""
@@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol):
name, cb, seconds)
self.scheduled_events[name] = handle
- def _execute_and_unschedule(self, name, cb):
+ def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None:
"""
Execute the callback and remove the handler for it.
"""
@@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol):
if name in self.scheduled_events:
del self.scheduled_events[name]
- def incoming_filter(self, xml):
+ def incoming_filter(self, xml: ET.Element) -> ET.Element:
"""Filter incoming XML objects before they are processed.
Possible uses include remapping namespaces, or correcting elements
@@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return xml
- def _reset_sendq(self):
+ def _reset_sendq(self) -> None:
"""Clear sending tasks on session end"""
# Cancel all pending slow send tasks
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
@@ -1043,7 +1166,7 @@ class XMLStream(asyncio.BaseProtocol):
async def _continue_slow_send(
self,
task: asyncio.Task,
- already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]]
+ already_used: Set[Filter]
) -> None:
"""
Used when an item in the send queue has taken too long to process.
@@ -1060,14 +1183,16 @@ class XMLStream(asyncio.BaseProtocol):
if filter in already_used:
continue
if iscoroutinefunction(filter):
- data = await filter(data)
+ data = await filter(data) # type: ignore
else:
+ filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
return
- if isinstance(data, ElementBase):
+ if isinstance(data, StanzaBase):
for filter in self.__filters['out_sync']:
+ filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
return
@@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.send_raw(data)
- async def run_filters(self):
+ async def run_filters(self) -> NoReturn:
"""
Background loop that processes stanzas to send.
"""
while True:
+ data: Optional[Union[StanzaBase, str]]
(data, use_filters) = await self.waiting_queue.get()
try:
- if isinstance(data, ElementBase):
+ if isinstance(data, StanzaBase):
if use_filters:
already_run_filters = set()
for filter in self.__filters['out']:
already_run_filters.add(filter)
if iscoroutinefunction(filter):
+ filter = cast(AsyncFilter, filter)
task = asyncio.create_task(filter(data))
completed, pending = await wait(
{task},
@@ -1108,21 +1235,26 @@ class XMLStream(asyncio.BaseProtocol):
"Slow coroutine, rescheduling filters"
)
data = task.result()
- else:
+ elif isinstance(data, StanzaBase):
+ filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
raise ContinueQueue('Empty stanza')
- if isinstance(data, ElementBase):
+ if isinstance(data, StanzaBase):
if use_filters:
for filter in self.__filters['out_sync']:
+ filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
raise ContinueQueue('Empty stanza')
- str_data = tostring(data.xml, xmlns=self.default_ns,
- stream=self, top_level=True)
+ if isinstance(data, StanzaBase):
+ str_data = tostring(data.xml, xmlns=self.default_ns,
+ stream=self, top_level=True)
+ else:
+ str_data = data
self.send_raw(str_data)
- else:
+ elif isinstance(data, (str, bytes)):
self.send_raw(data)
except ContinueQueue as exc:
log.debug('Stanza in send queue not sent: %s', exc)
@@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol):
log.error('Exception raised in send queue:', exc_info=True)
self.waiting_queue.task_done()
- def send(self, data, use_filters=True):
+ def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None:
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
- :param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
+ :param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to send on the stream.
:param bool use_filters: Indicates if outgoing filters should be
applied to the given stanza data. Disabling
@@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol):
return
self.waiting_queue.put_nowait((data, use_filters))
- def send_xml(self, data):
+ def send_xml(self, data: ET.Element) -> None:
"""Send an XML object on the stream
:param data: The :class:`~xml.etree.ElementTree.Element` XML object
to send on the stream.
"""
- return self.send(tostring(data))
+ self.send(tostring(data))
- def send_raw(self, data):
+ def send_raw(self, data: Union[str, bytes]) -> None:
"""Send raw data across the stream.
:param string data: Any bytes or utf-8 string value.
@@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol):
data = data.encode('utf-8')
self.transport.write(data)
- def _build_stanza(self, xml, default_ns=None):
+ def _build_stanza(self, xml: ET.Element,
+ default_ns: Optional[str] = None) -> StanzaBase:
"""Create a stanza object from a given XML object.
If a specialized stanza type is not found for the XML, then
@@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol):
stanza['lang'] = self.peer_default_lang
return stanza
- def _spawn_event(self, xml):
+ def _spawn_event(self, xml: ET.Element) -> None:
"""
Analyze incoming XML stanzas and convert them into stanza
objects if applicable and queue stream events to be processed
@@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol):
# Convert the raw XML object into a stanza object. If no registered
# stanza type applies, a generic StanzaBase stanza will be used.
- stanza = self._build_stanza(xml)
+ stanza: Optional[StanzaBase] = self._build_stanza(xml)
for filter in self.__filters['in']:
if stanza is not None:
+ filter = cast(SyncFilter, filter)
stanza = filter(stanza)
if stanza is None:
return
@@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol):
if not handled:
stanza.unhandled()
- def exception(self, exception):
+ def exception(self, exception: Exception) -> None:
"""Process an unknown exception.
Meant to be overridden.
@@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
- async def wait_until(self, event: str, timeout=30) -> Any:
+ async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any:
"""Utility method to wake on the next firing of an event.
(Registers a disposable handler on it)
@@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol):
:param int timeout: Timeout
:raises: :class:`asyncio.TimeoutError` when the timeout is reached
"""
- fut = asyncio.Future()
+ fut: Future = asyncio.Future()
- def result_handler(event_data):
+ def result_handler(event_data: Any) -> None:
if not fut.done():
fut.set_result(event_data)
else:
@@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol):
return await asyncio.wait_for(fut, timeout)
@contextmanager
- def event_handler(self, event: str, handler: Callable):
+ def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]:
"""
Context manager that adds then removes an event handler.
"""
self.add_event_handler(event, handler)
try:
yield
- except Exception as exc:
+ except Exception:
raise
finally:
self.del_event_handler(event, handler)
- def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future:
+ def wrap(self, coroutine: Coroutine[None, None, T]) -> Future:
"""Make a Future out of a coroutine with the current loop.
:param coroutine: The coroutine to wrap.