diff options
63 files changed, 1156 insertions, 779 deletions
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3aa76989..48c6be9a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,7 +1,17 @@ stages: + - lint - test - trigger +mypy: + stage: lint + tags: + - docker + image: python:3 + script: + - pip3 install mypy + - mypy slixmpp + test: stage: test tags: diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..edf0d65f --- /dev/null +++ b/mypy.ini @@ -0,0 +1,15 @@ +[mypy] +check_untyped_defs = False +ignore_missing_imports = True + +[mypy-slixmpp.types] +ignore_errors = True + +[mypy-slixmpp.thirdparty.*] +ignore_errors = True + +[mypy-slixmpp.plugins.*] +ignore_errors = True + +[mypy-slixmpp.plugins.base] +ignore_errors = False @@ -83,6 +83,7 @@ setup( url='https://lab.louiz.org/poezio/slixmpp', license='MIT', platforms=['any'], + package_data={'slixmpp': ['py.typed']}, packages=packages, ext_modules=ext_modules, install_requires=['aiodns>=1.0', 'pyasn1', 'pyasn1_modules', 'typing_extensions; python_version < "3.8.0"'], diff --git a/slixmpp/__init__.py b/slixmpp/__init__.py index 769a9e31..403c9299 100644 --- a/slixmpp/__init__.py +++ b/slixmpp/__init__.py @@ -19,7 +19,6 @@ from slixmpp.xmlstream.stanzabase import ET, ElementBase, register_stanza_plugin from slixmpp.xmlstream.handler import * from slixmpp.xmlstream import XMLStream from slixmpp.xmlstream.matcher import * -from slixmpp.xmlstream.asyncio import asyncio, future_wrapper from slixmpp.basexmpp import BaseXMPP from slixmpp.clientxmpp import ClientXMPP from slixmpp.componentxmpp import ComponentXMPP diff --git a/slixmpp/api.py b/slixmpp/api.py index 39fed490..949954cd 100644 --- a/slixmpp/api.py +++ b/slixmpp/api.py @@ -21,7 +21,7 @@ class APIWrapper(object): if name not in self.api.settings: self.api.settings[name] = {} - def __getattr__(self, attr): + def __getattr__(self, attr: str): """Curry API management commands with the API name.""" if attr == 'name': return self.name @@ -33,13 +33,13 @@ class APIWrapper(object): return register(handler, self.name, op, jid, node, default) return partial elif attr == 'register_default': - def partial(handler, op, jid=None, node=None): + def partial1(handler, op, jid=None, node=None): return getattr(self.api, attr)(handler, self.name, op) - return partial + return partial1 elif attr in ('run', 'restore_default', 'unregister'): - def partial(*args, **kwargs): + def partial2(*args, **kwargs): return getattr(self.api, attr)(self.name, *args, **kwargs) - return partial + return partial2 return None def __getitem__(self, attr): @@ -82,7 +82,7 @@ class APIRegistry(object): """Return a wrapper object that targets a specific API.""" return APIWrapper(self, ctype) - def purge(self, ctype: str): + def purge(self, ctype: str) -> None: """Remove all information for a given API.""" del self.settings[ctype] del self._handler_defaults[ctype] @@ -131,22 +131,23 @@ class APIRegistry(object): jid = JID(jid) elif jid == JID(''): jid = self.xmpp.boundjid + assert jid is not None if node is None: node = '' if self.xmpp.is_component: if self.settings[ctype].get('component_bare', False): - jid = jid.bare + jid_str = jid.bare else: - jid = jid.full + jid_str = jid.full else: if self.settings[ctype].get('client_bare', False): - jid = jid.bare + jid_str = jid.bare else: - jid = jid.full + jid_str = jid.full - jid = JID(jid) + jid = JID(jid_str) handler = self._handlers[ctype][op]['node'].get((jid, node), None) if handler is None: @@ -167,8 +168,11 @@ class APIRegistry(object): # To preserve backward compatibility, drop the ifrom # parameter for existing handlers that don't understand it. return handler(jid, node, args) + future = Future() + future.set_result(None) + return future - def register(self, handler: APIHandler, ctype: str, op: str, + def register(self, handler: Optional[APIHandler], ctype: str, op: str, jid: Optional[JID] = None, node: Optional[str] = None, default: bool = False): """Register an API callback, with JID+node specificity. diff --git a/slixmpp/basexmpp.py b/slixmpp/basexmpp.py index 25aa0d75..cd228312 100644 --- a/slixmpp/basexmpp.py +++ b/slixmpp/basexmpp.py @@ -45,10 +45,11 @@ log = logging.getLogger(__name__) from slixmpp.types import ( - PresenceShows, PresenceTypes, MessageTypes, IqTypes, + JidStr, + OptJidStr, ) if TYPE_CHECKING: @@ -263,9 +264,9 @@ class BaseXMPP(XMLStream): if not pconfig: pconfig = self.plugin_config.get(plugin, {}) - if not self.plugin.registered(plugin): + if not self.plugin.registered(plugin): # type: ignore load_plugin(plugin, module) - self.plugin.enable(plugin, pconfig) + self.plugin.enable(plugin, pconfig) # type: ignore def register_plugins(self): """Register and initialize all built-in plugins. @@ -298,25 +299,25 @@ class BaseXMPP(XMLStream): """Return a plugin given its name, if it has been registered.""" return self.plugin.get(key, default) - def Message(self, *args, **kwargs) -> Message: + def Message(self, *args, **kwargs) -> stanza.Message: """Create a Message stanza associated with this stream.""" msg = Message(self, *args, **kwargs) msg['lang'] = self.default_lang return msg - def Iq(self, *args, **kwargs) -> Iq: + def Iq(self, *args, **kwargs) -> stanza.Iq: """Create an Iq stanza associated with this stream.""" return Iq(self, *args, **kwargs) - def Presence(self, *args, **kwargs) -> Presence: + def Presence(self, *args, **kwargs) -> stanza.Presence: """Create a Presence stanza associated with this stream.""" pres = Presence(self, *args, **kwargs) pres['lang'] = self.default_lang return pres - def make_iq(self, id: str = "0", ifrom: Optional[JID] = None, - ito: Optional[JID] = None, itype: Optional[IqTypes] = None, - iquery: Optional[str] = None) -> Iq: + def make_iq(self, id: str = "0", ifrom: OptJidStr = None, + ito: OptJidStr = None, itype: Optional[IqTypes] = None, + iquery: Optional[str] = None) -> stanza.Iq: """Create a new :class:`~.Iq` stanza with a given Id and from JID. :param id: An ideally unique ID value for this stanza thread. @@ -339,8 +340,8 @@ class BaseXMPP(XMLStream): return iq def make_iq_get(self, queryxmlns: Optional[str] =None, - ito: Optional[JID] = None, ifrom: Optional[JID] = None, - iq: Optional[Iq] = None) -> Iq: + ito: OptJidStr = None, ifrom: OptJidStr = None, + iq: Optional[stanza.Iq] = None) -> stanza.Iq: """Create an :class:`~.Iq` stanza of type ``'get'``. Optionally, a query element may be added. @@ -364,8 +365,8 @@ class BaseXMPP(XMLStream): return iq def make_iq_result(self, id: Optional[str] = None, - ito: Optional[JID] = None, ifrom: Optional[JID] = None, - iq: Optional[Iq] = None) -> Iq: + ito: OptJidStr = None, ifrom: OptJidStr = None, + iq: Optional[stanza.Iq] = None) -> stanza.Iq: """ Create an :class:`~.Iq` stanza of type ``'result'`` with the given ID value. @@ -391,8 +392,8 @@ class BaseXMPP(XMLStream): return iq def make_iq_set(self, sub: Optional[Union[ElementBase, ET.Element]] = None, - ito: Optional[JID] = None, ifrom: Optional[JID] = None, - iq: Optional[Iq] = None) -> Iq: + ito: OptJidStr = None, ifrom: OptJidStr = None, + iq: Optional[stanza.Iq] = None) -> stanza.Iq: """ Create an :class:`~.Iq` stanza of type ``'set'``. @@ -414,7 +415,7 @@ class BaseXMPP(XMLStream): if not iq: iq = self.Iq() iq['type'] = 'set' - if sub != None: + if sub is not None: iq.append(sub) if ito: iq['to'] = ito @@ -453,9 +454,9 @@ class BaseXMPP(XMLStream): iq['from'] = ifrom return iq - def make_iq_query(self, iq: Optional[Iq] = None, xmlns: str = '', - ito: Optional[JID] = None, - ifrom: Optional[JID] = None) -> Iq: + def make_iq_query(self, iq: Optional[stanza.Iq] = None, xmlns: str = '', + ito: OptJidStr = None, + ifrom: OptJidStr = None) -> stanza.Iq: """ Create or modify an :class:`~.Iq` stanza to use the given query namespace. @@ -477,7 +478,7 @@ class BaseXMPP(XMLStream): iq['from'] = ifrom return iq - def make_query_roster(self, iq: Optional[Iq] = None) -> ET.Element: + def make_query_roster(self, iq: Optional[stanza.Iq] = None) -> ET.Element: """Create a roster query element. :param iq: Optionally use an existing stanza instead @@ -487,11 +488,11 @@ class BaseXMPP(XMLStream): iq['query'] = 'jabber:iq:roster' return ET.Element("{jabber:iq:roster}query") - def make_message(self, mto: JID, mbody: Optional[str] = None, + def make_message(self, mto: JidStr, mbody: Optional[str] = None, msubject: Optional[str] = None, mtype: Optional[MessageTypes] = None, - mhtml: Optional[str] = None, mfrom: Optional[JID] = None, - mnick: Optional[str] = None) -> Message: + mhtml: Optional[str] = None, mfrom: OptJidStr = None, + mnick: Optional[str] = None) -> stanza.Message: """ Create and initialize a new :class:`~.Message` stanza. @@ -516,13 +517,13 @@ class BaseXMPP(XMLStream): message['html']['body'] = mhtml return message - def make_presence(self, pshow: Optional[PresenceShows] = None, + def make_presence(self, pshow: Optional[str] = None, pstatus: Optional[str] = None, ppriority: Optional[int] = None, - pto: Optional[JID] = None, + pto: OptJidStr = None, ptype: Optional[PresenceTypes] = None, - pfrom: Optional[JID] = None, - pnick: Optional[str] = None) -> Presence: + pfrom: OptJidStr = None, + pnick: Optional[str] = None) -> stanza.Presence: """ Create and initialize a new :class:`~.Presence` stanza. @@ -548,7 +549,7 @@ class BaseXMPP(XMLStream): def send_message(self, mto: JID, mbody: Optional[str] = None, msubject: Optional[str] = None, mtype: Optional[MessageTypes] = None, - mhtml: Optional[str] = None, mfrom: Optional[JID] = None, + mhtml: Optional[str] = None, mfrom: OptJidStr = None, mnick: Optional[str] = None): """ Create, initialize, and send a new @@ -568,12 +569,12 @@ class BaseXMPP(XMLStream): self.make_message(mto, mbody, msubject, mtype, mhtml, mfrom, mnick).send() - def send_presence(self, pshow: Optional[PresenceShows] = None, + def send_presence(self, pshow: Optional[str] = None, pstatus: Optional[str] = None, ppriority: Optional[int] = None, - pto: Optional[JID] = None, + pto: OptJidStr = None, ptype: Optional[PresenceTypes] = None, - pfrom: Optional[JID] = None, + pfrom: OptJidStr = None, pnick: Optional[str] = None): """ Create, initialize, and send a new @@ -590,8 +591,9 @@ class BaseXMPP(XMLStream): self.make_presence(pshow, pstatus, ppriority, pto, ptype, pfrom, pnick).send() - def send_presence_subscription(self, pto, pfrom=None, - ptype='subscribe', pnick=None): + def send_presence_subscription(self, pto: JidStr, pfrom: OptJidStr = None, + ptype: PresenceTypes='subscribe', pnick: + Optional[str] = None): """ Create, initialize, and send a new :class:`~.Presence` stanza of @@ -608,62 +610,62 @@ class BaseXMPP(XMLStream): pnick=pnick).send() @property - def jid(self): + def jid(self) -> str: """Attribute accessor for bare jid""" log.warning("jid property deprecated. Use boundjid.bare") return self.boundjid.bare @jid.setter - def jid(self, value): + def jid(self, value: str): log.warning("jid property deprecated. Use boundjid.bare") self.boundjid.bare = value @property - def fulljid(self): + def fulljid(self) -> str: """Attribute accessor for full jid""" log.warning("fulljid property deprecated. Use boundjid.full") return self.boundjid.full @fulljid.setter - def fulljid(self, value): + def fulljid(self, value: str): log.warning("fulljid property deprecated. Use boundjid.full") self.boundjid.full = value @property - def resource(self): + def resource(self) -> str: """Attribute accessor for jid resource""" log.warning("resource property deprecated. Use boundjid.resource") return self.boundjid.resource @resource.setter - def resource(self, value): + def resource(self, value: str): log.warning("fulljid property deprecated. Use boundjid.resource") self.boundjid.resource = value @property - def username(self): + def username(self) -> str: """Attribute accessor for jid usernode""" log.warning("username property deprecated. Use boundjid.user") return self.boundjid.user @username.setter - def username(self, value): + def username(self, value: str): log.warning("username property deprecated. Use boundjid.user") self.boundjid.user = value @property - def server(self): + def server(self) -> str: """Attribute accessor for jid host""" log.warning("server property deprecated. Use boundjid.host") return self.boundjid.server @server.setter - def server(self, value): + def server(self, value: str): log.warning("server property deprecated. Use boundjid.host") self.boundjid.server = value @property - def auto_authorize(self): + def auto_authorize(self) -> Optional[bool]: """Auto accept or deny subscription requests. If ``True``, auto accept subscription requests. @@ -673,11 +675,11 @@ class BaseXMPP(XMLStream): return self.roster.auto_authorize @auto_authorize.setter - def auto_authorize(self, value): + def auto_authorize(self, value: Optional[bool]): self.roster.auto_authorize = value @property - def auto_subscribe(self): + def auto_subscribe(self) -> bool: """Auto send requests for mutual subscriptions. If ``True``, auto send mutual subscription requests. @@ -685,21 +687,21 @@ class BaseXMPP(XMLStream): return self.roster.auto_subscribe @auto_subscribe.setter - def auto_subscribe(self, value): + def auto_subscribe(self, value: bool): self.roster.auto_subscribe = value - def set_jid(self, jid): + def set_jid(self, jid: JidStr): """Rip a JID apart and claim it as our own.""" log.debug("setting jid to %s", jid) self.boundjid = JID(jid) - def getjidresource(self, fulljid): + def getjidresource(self, fulljid: str): if '/' in fulljid: return fulljid.split('/', 1)[-1] else: return '' - def getjidbare(self, fulljid): + def getjidbare(self, fulljid: str): return fulljid.split('/', 1)[0] def _handle_session_start(self, event): diff --git a/slixmpp/clientxmpp.py b/slixmpp/clientxmpp.py index 37b4c590..754db100 100644 --- a/slixmpp/clientxmpp.py +++ b/slixmpp/clientxmpp.py @@ -8,23 +8,18 @@ # :license: MIT, see LICENSE for more details import asyncio import logging +from typing import Optional, Any, Callable, Tuple, Dict, Set, List from slixmpp.jid import JID -from slixmpp.stanza import StreamFeatures +from slixmpp.stanza import StreamFeatures, Iq from slixmpp.basexmpp import BaseXMPP from slixmpp.exceptions import XMPPError +from slixmpp.types import JidStr from slixmpp.xmlstream import XMLStream +from slixmpp.xmlstream.stanzabase import StanzaBase from slixmpp.xmlstream.matcher import StanzaPath, MatchXPath from slixmpp.xmlstream.handler import Callback, CoroutineCallback -# Flag indicating if DNS SRV records are available for use. -try: - import dns.resolver -except ImportError: - DNSPYTHON = False -else: - DNSPYTHON = True - log = logging.getLogger(__name__) @@ -53,7 +48,7 @@ class ClientXMPP(BaseXMPP): :param escape_quotes: **Deprecated.** """ - def __init__(self, jid, password, plugin_config=None, + def __init__(self, jid: JidStr, password: str, plugin_config=None, plugin_whitelist=None, escape_quotes=True, sasl_mech=None, lang='en', **kwargs): if not plugin_whitelist: @@ -69,7 +64,7 @@ class ClientXMPP(BaseXMPP): self.default_port = 5222 self.default_lang = lang - self.credentials = {} + self.credentials: Dict[str, str] = {} self.password = password @@ -81,9 +76,9 @@ class ClientXMPP(BaseXMPP): "version='1.0'") self.stream_footer = "</stream:stream>" - self.features = set() - self._stream_feature_handlers = {} - self._stream_feature_order = [] + self.features: Set[str] = set() + self._stream_feature_handlers: Dict[str, Tuple[Callable, bool]] = {} + self._stream_feature_order: List[Tuple[int, str]] = [] self.dns_service = 'xmpp-client' @@ -100,10 +95,14 @@ class ClientXMPP(BaseXMPP): self.register_stanza(StreamFeatures) self.register_handler( - CoroutineCallback('Stream Features', - MatchXPath('{%s}features' % self.stream_ns), - self._handle_stream_features)) - def roster_push_filter(iq): + CoroutineCallback( + 'Stream Features', + MatchXPath('{%s}features' % self.stream_ns), + self._handle_stream_features, # type: ignore + ) + ) + + def roster_push_filter(iq: StanzaBase) -> None: from_ = iq['from'] if from_ and from_ != JID('') and from_ != self.boundjid.bare: reply = iq.reply() @@ -131,15 +130,16 @@ class ClientXMPP(BaseXMPP): self['feature_mechanisms'].use_mech = sasl_mech @property - def password(self): + def password(self) -> str: return self.credentials.get('password', '') @password.setter - def password(self, value): + def password(self, value: str) -> None: self.credentials['password'] = value - def connect(self, address=tuple(), use_ssl=False, - force_starttls=True, disable_starttls=False): + def connect(self, address: Optional[Tuple[str, int]] = None, # type: ignore + use_ssl: bool = False, force_starttls: bool = True, + disable_starttls: bool = False) -> None: """Connect to the XMPP server. When no address is given, a SRV lookup for the server will @@ -161,14 +161,15 @@ class ClientXMPP(BaseXMPP): # XMPP client port and allow SRV lookup. if address: self.dns_service = None + host, port = address else: - address = (self.boundjid.host, 5222) + host, port = (self.boundjid.host, 5222) self.dns_service = 'xmpp-client' - return XMLStream.connect(self, address[0], address[1], use_ssl=use_ssl, + return XMLStream.connect(self, host, port, use_ssl=use_ssl, force_starttls=force_starttls, disable_starttls=disable_starttls) - def register_feature(self, name, handler, restart=False, order=5000): + def register_feature(self, name: str, handler: Callable, restart: bool = False, order: int = 5000) -> None: """Register a stream feature handler. :param name: The name of the stream feature. @@ -183,13 +184,13 @@ class ClientXMPP(BaseXMPP): self._stream_feature_order.append((order, name)) self._stream_feature_order.sort() - def unregister_feature(self, name, order): + def unregister_feature(self, name: str, order: int) -> None: if name in self._stream_feature_handlers: del self._stream_feature_handlers[name] self._stream_feature_order.remove((order, name)) self._stream_feature_order.sort() - def update_roster(self, jid, **kwargs): + def update_roster(self, jid: JID, **kwargs) -> None: """Add or change a roster item. :param jid: The JID of the entry to modify. @@ -251,7 +252,7 @@ class ClientXMPP(BaseXMPP): return iq.send(callback, timeout, timeout_callback) - def _reset_connection_state(self, event=None): + def _reset_connection_state(self, event: Optional[Any] = None) -> None: #TODO: Use stream state here self.authenticated = False self.sessionstarted = False @@ -259,7 +260,7 @@ class ClientXMPP(BaseXMPP): self.bindfail = False self.features = set() - async def _handle_stream_features(self, features): + async def _handle_stream_features(self, features: StreamFeatures) -> Optional[bool]: """Process the received stream features. :param features: The features stanza. @@ -277,8 +278,9 @@ class ClientXMPP(BaseXMPP): return True log.debug('Finished processing stream features.') self.event('stream_negotiated') + return None - def _handle_roster(self, iq): + def _handle_roster(self, iq: Iq) -> None: """Update the roster after receiving a roster stanza. :param iq: The roster stanza. @@ -310,7 +312,7 @@ class ClientXMPP(BaseXMPP): resp.enable('roster') resp.send() - def _handle_session_bind(self, jid): + def _handle_session_bind(self, jid: JID) -> None: """Set the client roster to the JID set by the server. :param :class:`slixmpp.xmlstream.jid.JID` jid: The bound JID as diff --git a/slixmpp/features/feature_bind/bind.py b/slixmpp/features/feature_bind/bind.py index 6c30d9f2..75bbcfa2 100644 --- a/slixmpp/features/feature_bind/bind.py +++ b/slixmpp/features/feature_bind/bind.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2011 Nathanael C. Fritz # This file is part of Slixmpp. @@ -11,6 +10,7 @@ from slixmpp.stanza import Iq, StreamFeatures from slixmpp.features.feature_bind import stanza from slixmpp.xmlstream import register_stanza_plugin from slixmpp.plugins import BasePlugin +from typing import ClassVar, Set log = logging.getLogger(__name__) @@ -20,7 +20,7 @@ class FeatureBind(BasePlugin): name = 'feature_bind' description = 'RFC 6120: Stream Feature: Resource Binding' - dependencies = set() + dependencies: ClassVar[Set[str]] = set() stanza = stanza def plugin_init(self): diff --git a/slixmpp/features/feature_mechanisms/mechanisms.py b/slixmpp/features/feature_mechanisms/mechanisms.py index dfdca010..db2d73d0 100644 --- a/slixmpp/features/feature_mechanisms/mechanisms.py +++ b/slixmpp/features/feature_mechanisms/mechanisms.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2011 Nathanael C. Fritz # This file is part of Slixmpp. @@ -15,6 +14,8 @@ from slixmpp.xmlstream.matcher import MatchXPath from slixmpp.xmlstream.handler import Callback from slixmpp.features.feature_mechanisms import stanza +from typing import ClassVar, Set + log = logging.getLogger(__name__) @@ -23,7 +24,7 @@ class FeatureMechanisms(BasePlugin): name = 'feature_mechanisms' description = 'RFC 6120: Stream Feature: SASL' - dependencies = set() + dependencies: ClassVar[Set[str]] = set() stanza = stanza default_config = { 'use_mech': None, diff --git a/slixmpp/features/feature_mechanisms/stanza/abort.py b/slixmpp/features/feature_mechanisms/stanza/abort.py index afa75cde..a430fa0f 100644 --- a/slixmpp/features/feature_mechanisms/stanza/abort.py +++ b/slixmpp/features/feature_mechanisms/stanza/abort.py @@ -1,9 +1,9 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2011 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. from slixmpp.xmlstream import StanzaBase +from typing import ClassVar, Set class Abort(StanzaBase): @@ -13,7 +13,7 @@ class Abort(StanzaBase): name = 'abort' namespace = 'urn:ietf:params:xml:ns:xmpp-sasl' - interfaces = set() + interfaces: ClassVar[Set[str]] = set() plugin_attrib = name def setup(self, xml): diff --git a/slixmpp/features/feature_preapproval/preapproval.py b/slixmpp/features/feature_preapproval/preapproval.py index 5dceeb06..46fbaeec 100644 --- a/slixmpp/features/feature_preapproval/preapproval.py +++ b/slixmpp/features/feature_preapproval/preapproval.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2012 Nathanael C. Fritz # This file is part of Slixmpp. @@ -9,6 +8,7 @@ from slixmpp.stanza import StreamFeatures from slixmpp.features.feature_preapproval import stanza from slixmpp.xmlstream import register_stanza_plugin from slixmpp.plugins.base import BasePlugin +from typing import ClassVar, Set log = logging.getLogger(__name__) @@ -18,7 +18,7 @@ class FeaturePreApproval(BasePlugin): name = 'feature_preapproval' description = 'RFC 6121: Stream Feature: Subscription Pre-Approval' - dependences = set() + dependencies: ClassVar[Set[str]] = set() stanza = stanza def plugin_init(self): diff --git a/slixmpp/features/feature_preapproval/stanza.py b/slixmpp/features/feature_preapproval/stanza.py index b5d68f8e..630a0b7e 100644 --- a/slixmpp/features/feature_preapproval/stanza.py +++ b/slixmpp/features/feature_preapproval/stanza.py @@ -1,14 +1,14 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2012 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. from slixmpp.xmlstream import ElementBase +from typing import ClassVar, Set class PreApproval(ElementBase): name = 'sub' namespace = 'urn:xmpp:features:pre-approval' - interfaces = set() + interfaces: ClassVar[Set[str]] = set() plugin_attrib = 'preapproval' diff --git a/slixmpp/features/feature_rosterver/rosterver.py b/slixmpp/features/feature_rosterver/rosterver.py index ae980388..1dea6878 100644 --- a/slixmpp/features/feature_rosterver/rosterver.py +++ b/slixmpp/features/feature_rosterver/rosterver.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2012 Nathanael C. Fritz # This file is part of Slixmpp. @@ -9,6 +8,7 @@ from slixmpp.stanza import StreamFeatures from slixmpp.features.feature_rosterver import stanza from slixmpp.xmlstream import register_stanza_plugin from slixmpp.plugins.base import BasePlugin +from typing import ClassVar, Set log = logging.getLogger(__name__) @@ -18,7 +18,7 @@ class FeatureRosterVer(BasePlugin): name = 'feature_rosterver' description = 'RFC 6121: Stream Feature: Roster Versioning' - dependences = set() + dependences: ClassVar[Set[str]] = set() stanza = stanza def plugin_init(self): diff --git a/slixmpp/features/feature_rosterver/stanza.py b/slixmpp/features/feature_rosterver/stanza.py index 3696d89a..428dbb26 100644 --- a/slixmpp/features/feature_rosterver/stanza.py +++ b/slixmpp/features/feature_rosterver/stanza.py @@ -1,14 +1,14 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2012 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. from slixmpp.xmlstream import ElementBase +from typing import Set, ClassVar class RosterVer(ElementBase): name = 'ver' namespace = 'urn:xmpp:features:rosterver' - interfaces = set() + interfaces: ClassVar[Set[str]] = set() plugin_attrib = 'rosterver' diff --git a/slixmpp/features/feature_session/session.py b/slixmpp/features/feature_session/session.py index 0f9b9dd7..83e6c220 100644 --- a/slixmpp/features/feature_session/session.py +++ b/slixmpp/features/feature_session/session.py @@ -11,6 +11,7 @@ from slixmpp.xmlstream import register_stanza_plugin from slixmpp.plugins import BasePlugin from slixmpp.features.feature_session import stanza +from typing import ClassVar, Set log = logging.getLogger(__name__) @@ -20,7 +21,7 @@ class FeatureSession(BasePlugin): name = 'feature_session' description = 'RFC 3920: Stream Feature: Start Session' - dependencies = set() + dependencies: ClassVar[Set[str]] = set() stanza = stanza def plugin_init(self): diff --git a/slixmpp/features/feature_starttls/stanza.py b/slixmpp/features/feature_starttls/stanza.py index 5552cf61..70979402 100644 --- a/slixmpp/features/feature_starttls/stanza.py +++ b/slixmpp/features/feature_starttls/stanza.py @@ -4,39 +4,47 @@ # This file is part of Slixmpp. # See the file LICENSE for copying permission. from slixmpp.xmlstream import StanzaBase, ElementBase +from typing import Set, ClassVar -class STARTTLS(ElementBase): - - """ +class STARTTLS(StanzaBase): """ + .. code-block:: xml + + <starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/> + + """ name = 'starttls' namespace = 'urn:ietf:params:xml:ns:xmpp-tls' interfaces = {'required'} plugin_attrib = name def get_required(self): - """ - """ return True class Proceed(StanzaBase): - - """ """ + .. code-block:: xml + + <proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/> + + """ name = 'proceed' namespace = 'urn:ietf:params:xml:ns:xmpp-tls' - interfaces = set() + interfaces: ClassVar[Set[str]] = set() class Failure(StanzaBase): - - """ """ + .. code-block:: xml + + <failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/> + + """ name = 'failure' namespace = 'urn:ietf:params:xml:ns:xmpp-tls' - interfaces = set() + interfaces: ClassVar[Set[str]] = set() diff --git a/slixmpp/features/feature_starttls/starttls.py b/slixmpp/features/feature_starttls/starttls.py index fe793a2d..318d4a5e 100644 --- a/slixmpp/features/feature_starttls/starttls.py +++ b/slixmpp/features/feature_starttls/starttls.py @@ -12,6 +12,8 @@ from slixmpp.xmlstream.matcher import MatchXPath from slixmpp.xmlstream.handler import CoroutineCallback from slixmpp.features.feature_starttls import stanza +from typing import ClassVar, Set + log = logging.getLogger(__name__) @@ -20,7 +22,7 @@ class FeatureSTARTTLS(BasePlugin): name = 'feature_starttls' description = 'RFC 6120: Stream Feature: STARTTLS' - dependencies = set() + dependencies: ClassVar[Set[str]] = set() stanza = stanza def plugin_init(self): @@ -52,7 +54,7 @@ class FeatureSTARTTLS(BasePlugin): elif self.xmpp.disable_starttls: return False else: - self.xmpp.send(features['starttls']) + self.xmpp.send(stanza.STARTTLS()) return True async def _handle_starttls_proceed(self, proceed): diff --git a/slixmpp/jid.py b/slixmpp/jid.py index ee5ef987..adde95a4 100644 --- a/slixmpp/jid.py +++ b/slixmpp/jid.py @@ -351,48 +351,49 @@ class JID: else self._bare) @property - def node(self) -> str: - return self._node - - @property - def domain(self) -> str: - return self._domain - - @property - def resource(self) -> str: - return self._resource - - @property def bare(self) -> str: return self._bare + @bare.setter + def bare(self, value: str): + node, domain, resource = _parse_jid(value) + assert not resource + self._node = node + self._domain = domain + self._update_bare_full() + + @property - def full(self) -> str: - return self._full + def node(self) -> str: + return self._node @node.setter def node(self, value: str): self._node = _validate_node(value) self._update_bare_full() + @property + def domain(self) -> str: + return self._domain + @domain.setter def domain(self, value: str): self._domain = _validate_domain(value) self._update_bare_full() - @bare.setter - def bare(self, value: str): - node, domain, resource = _parse_jid(value) - assert not resource - self._node = node - self._domain = domain - self._update_bare_full() + @property + def resource(self) -> str: + return self._resource @resource.setter def resource(self, value: str): self._resource = _validate_resource(value) self._update_bare_full() + @property + def full(self) -> str: + return self._full + @full.setter def full(self, value: str): self._node, self._domain, self._resource = _parse_jid(value) diff --git a/slixmpp/plugins/base.py b/slixmpp/plugins/base.py index afdb5339..2aaf1b99 100644 --- a/slixmpp/plugins/base.py +++ b/slixmpp/plugins/base.py @@ -12,6 +12,8 @@ import copy import logging import threading +from typing import Any, Dict, Set, ClassVar + log = logging.getLogger(__name__) @@ -250,17 +252,17 @@ class BasePlugin(object): #: A short name for the plugin based on the implemented specification. #: For example, a plugin for XEP-0030 would use `'xep_0030'`. - name = '' + name: str = '' #: A longer name for the plugin, describing its purpose. For example, #: a plugin for XEP-0030 would use `'Service Discovery'` as its #: description value. - description = '' + description: str = '' #: Some plugins may depend on others in order to function properly. #: Any plugin names included in :attr:`~BasePlugin.dependencies` will #: be initialized as needed if this plugin is enabled. - dependencies = set() + dependencies: ClassVar[Set[str]] = set() #: The basic, standard configuration for the plugin, which may #: be overridden when initializing the plugin. The configuration @@ -268,7 +270,7 @@ class BasePlugin(object): #: the plugin. For example, including the configuration field 'foo' #: would mean accessing `plugin.foo` returns the current value of #: `plugin.config['foo']`. - default_config = {} + default_config: ClassVar[Dict[str, Any]] = {} def __init__(self, xmpp, config=None): self.xmpp = xmpp diff --git a/slixmpp/plugins/xep_0012/last_activity.py b/slixmpp/plugins/xep_0012/last_activity.py index 61531431..56905de0 100644 --- a/slixmpp/plugins/xep_0012/last_activity.py +++ b/slixmpp/plugins/xep_0012/last_activity.py @@ -11,11 +11,11 @@ from typing import ( Optional ) -from slixmpp.plugins import BasePlugin, register_plugin -from slixmpp import future_wrapper, JID +from slixmpp.plugins import BasePlugin +from slixmpp import JID from slixmpp.stanza import Iq from slixmpp.exceptions import XMPPError -from slixmpp.xmlstream import JID, register_stanza_plugin +from slixmpp.xmlstream import register_stanza_plugin from slixmpp.xmlstream.handler import CoroutineCallback from slixmpp.xmlstream.matcher import StanzaPath from slixmpp.plugins.xep_0012 import stanza, LastActivity diff --git a/slixmpp/plugins/xep_0054/vcard_temp.py b/slixmpp/plugins/xep_0054/vcard_temp.py index 460013b8..c909f6cd 100644 --- a/slixmpp/plugins/xep_0054/vcard_temp.py +++ b/slixmpp/plugins/xep_0054/vcard_temp.py @@ -4,7 +4,6 @@ # This file is part of Slixmpp. # See the file LICENSE for copying permission. import logging -from asyncio import Future from typing import Optional from slixmpp import JID @@ -15,7 +14,6 @@ from slixmpp.xmlstream.handler import CoroutineCallback from slixmpp.xmlstream.matcher import StanzaPath from slixmpp.plugins import BasePlugin from slixmpp.plugins.xep_0054 import VCardTemp, stanza -from slixmpp import future_wrapper log = logging.getLogger(__name__) diff --git a/slixmpp/plugins/xep_0070/confirm.py b/slixmpp/plugins/xep_0070/confirm.py index 334f78d4..1edde8d9 100644 --- a/slixmpp/plugins/xep_0070/confirm.py +++ b/slixmpp/plugins/xep_0070/confirm.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2015 Emmanuel Gil Peyrot # This file is part of Slixmpp. @@ -7,11 +6,10 @@ import asyncio import logging from uuid import uuid4 -from slixmpp.plugins import BasePlugin, register_plugin -from slixmpp import future_wrapper, Iq, Message -from slixmpp.exceptions import XMPPError, IqError, IqTimeout +from slixmpp.plugins import BasePlugin +from slixmpp import Iq, Message from slixmpp.jid import JID -from slixmpp.xmlstream import JID, register_stanza_plugin +from slixmpp.xmlstream import register_stanza_plugin from slixmpp.xmlstream.handler import Callback from slixmpp.xmlstream.matcher import StanzaPath from slixmpp.plugins.xep_0070 import stanza, Confirm @@ -52,7 +50,6 @@ class XEP_0070(BasePlugin): def session_bind(self, jid): self.xmpp['xep_0030'].add_feature('http://jabber.org/protocol/http-auth') - @future_wrapper def ask_confirm(self, jid, id, url, method, *, ifrom=None, message=None): jid = JID(jid) if jid.resource: @@ -70,7 +67,9 @@ class XEP_0070(BasePlugin): if message is not None: stanza['body'] = message.format(id=id, url=url, method=method) stanza.send() - return stanza + fut = asyncio.Future() + fut.set_result(stanza) + return fut else: return stanza.send() diff --git a/slixmpp/plugins/xep_0153/vcard_avatar.py b/slixmpp/plugins/xep_0153/vcard_avatar.py index e2d98b0a..23709c25 100644 --- a/slixmpp/plugins/xep_0153/vcard_avatar.py +++ b/slixmpp/plugins/xep_0153/vcard_avatar.py @@ -17,7 +17,6 @@ from slixmpp.exceptions import XMPPError, IqTimeout, IqError from slixmpp.xmlstream import register_stanza_plugin, ElementBase from slixmpp.plugins.base import BasePlugin from slixmpp.plugins.xep_0153 import stanza, VCardTempUpdate -from slixmpp import future_wrapper log = logging.getLogger(__name__) diff --git a/slixmpp/plugins/xep_0163.py b/slixmpp/plugins/xep_0163.py index d8ab8c8e..46ca4235 100644 --- a/slixmpp/plugins/xep_0163.py +++ b/slixmpp/plugins/xep_0163.py @@ -3,10 +3,11 @@ # Copyright (C) 2011 Nathanael C. Fritz, Lance J.T. Stout # This file is part of Slixmpp. # See the file LICENSE for copying permission. +import asyncio import logging from typing import Optional, Callable -from slixmpp import asyncio, JID +from slixmpp import JID from slixmpp.xmlstream import register_stanza_plugin, ElementBase from slixmpp.plugins.base import BasePlugin, register_plugin from slixmpp.plugins.xep_0004.stanza import Form diff --git a/slixmpp/plugins/xep_0199/ping.py b/slixmpp/plugins/xep_0199/ping.py index 89303ad9..03d272dd 100644 --- a/slixmpp/plugins/xep_0199/ping.py +++ b/slixmpp/plugins/xep_0199/ping.py @@ -3,6 +3,7 @@ # Copyright (C) 2010 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. +import asyncio import time import logging @@ -11,7 +12,6 @@ from typing import Optional, Callable, List from slixmpp.jid import JID from slixmpp.stanza import Iq -from slixmpp import asyncio from slixmpp.exceptions import IqError, IqTimeout from slixmpp.xmlstream import register_stanza_plugin from slixmpp.xmlstream.matcher import StanzaPath diff --git a/slixmpp/plugins/xep_0231/bob.py b/slixmpp/plugins/xep_0231/bob.py index 30722208..5614b5b0 100644 --- a/slixmpp/plugins/xep_0231/bob.py +++ b/slixmpp/plugins/xep_0231/bob.py @@ -9,14 +9,13 @@ import hashlib from asyncio import Future from typing import Optional -from slixmpp import future_wrapper, JID +from slixmpp import JID from slixmpp.stanza import Iq, Message, Presence -from slixmpp.exceptions import XMPPError from slixmpp.xmlstream.handler import CoroutineCallback from slixmpp.xmlstream.matcher import StanzaPath from slixmpp.xmlstream import register_stanza_plugin from slixmpp.plugins.base import BasePlugin -from slixmpp.plugins.xep_0231 import stanza, BitsOfBinary +from slixmpp.plugins.xep_0231 import BitsOfBinary log = logging.getLogger(__name__) diff --git a/slixmpp/plugins/xep_0325/control.py b/slixmpp/plugins/xep_0325/control.py index 734b3204..467e10d7 100644 --- a/slixmpp/plugins/xep_0325/control.py +++ b/slixmpp/plugins/xep_0325/control.py @@ -5,10 +5,10 @@ # Copyright (C) 2013 Sustainable Innovation, Joachim.lindborg@sust.se, bjorn.westrom@consoden.se # This file is part of Slixmpp. # See the file LICENSE for copying permission. +import asyncio import logging import time -from slixmpp import asyncio from functools import partial from slixmpp.xmlstream import JID from slixmpp.xmlstream.handler import Callback diff --git a/slixmpp/pluginsdict.py b/slixmpp/pluginsdict.py index d9954f51..28ac4a0a 100644 --- a/slixmpp/pluginsdict.py +++ b/slixmpp/pluginsdict.py @@ -20,6 +20,7 @@ from slixmpp.plugins.xep_0030 import XEP_0030 from slixmpp.plugins.xep_0033 import XEP_0033 from slixmpp.plugins.xep_0045 import XEP_0045 from slixmpp.plugins.xep_0047 import XEP_0047 +from slixmpp.plugins.xep_0048 import XEP_0048 from slixmpp.plugins.xep_0049 import XEP_0049 from slixmpp.plugins.xep_0050 import XEP_0050 from slixmpp.plugins.xep_0054 import XEP_0054 @@ -112,6 +113,7 @@ class PluginsDict(TypedDict): xep_0033: XEP_0033 xep_0045: XEP_0045 xep_0047: XEP_0047 + xep_0048: XEP_0048 xep_0049: XEP_0049 xep_0050: XEP_0050 xep_0054: XEP_0054 diff --git a/slixmpp/py.typed b/slixmpp/py.typed new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/slixmpp/py.typed diff --git a/slixmpp/stanza/error.py b/slixmpp/stanza/error.py index 54ace5a8..76c6b9cc 100644 --- a/slixmpp/stanza/error.py +++ b/slixmpp/stanza/error.py @@ -1,8 +1,9 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2010 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. +from __future__ import annotations +from typing import Optional, Dict, Type, ClassVar from slixmpp.xmlstream import ElementBase, ET @@ -49,10 +50,10 @@ class Error(ElementBase): name = 'error' plugin_attrib = 'error' interfaces = {'code', 'condition', 'text', 'type', - 'gone', 'redirect', 'by'} + 'gone', 'redirect', 'by'} sub_interfaces = {'text'} - plugin_attrib_map = {} - plugin_tag_map = {} + plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {} + plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {} conditions = {'bad-request', 'conflict', 'feature-not-implemented', 'forbidden', 'gone', 'internal-server-error', 'item-not-found', 'jid-malformed', 'not-acceptable', @@ -62,10 +63,10 @@ class Error(ElementBase): 'remote-server-timeout', 'resource-constraint', 'service-unavailable', 'subscription-required', 'undefined-condition', 'unexpected-request'} - condition_ns = 'urn:ietf:params:xml:ns:xmpp-stanzas' + condition_ns: str = 'urn:ietf:params:xml:ns:xmpp-stanzas' types = {'cancel', 'continue', 'modify', 'auth', 'wait'} - def setup(self, xml=None): + def setup(self, xml: Optional[ET.Element] = None): """ Populate the stanza object using an optional XML object. @@ -82,9 +83,11 @@ class Error(ElementBase): self['type'] = 'cancel' self['condition'] = 'feature-not-implemented' if self.parent is not None: - self.parent()['type'] = 'error' + parent = self.parent() + if parent: + parent['type'] = 'error' - def get_condition(self): + def get_condition(self) -> str: """Return the condition element's name.""" for child in self.xml: if "{%s}" % self.condition_ns in child.tag: @@ -93,7 +96,7 @@ class Error(ElementBase): return cond return '' - def set_condition(self, value): + def set_condition(self, value: str) -> Error: """ Set the tag name of the condition element. @@ -105,7 +108,7 @@ class Error(ElementBase): self.xml.append(ET.Element("{%s}%s" % (self.condition_ns, value))) return self - def del_condition(self): + def del_condition(self) -> Error: """Remove the condition element.""" for child in self.xml: if "{%s}" % self.condition_ns in child.tag: @@ -139,14 +142,14 @@ class Error(ElementBase): def get_redirect(self): return self._get_sub_text('{%s}redirect' % self.condition_ns, '') - def set_gone(self, value): + def set_gone(self, value: str): if value: del self['condition'] return self._set_sub_text('{%s}gone' % self.condition_ns, value) elif self['condition'] == 'gone': del self['condition'] - def set_redirect(self, value): + def set_redirect(self, value: str): if value: del self['condition'] ns = self.condition_ns diff --git a/slixmpp/stanza/handshake.py b/slixmpp/stanza/handshake.py index c58f69aa..70f890be 100644 --- a/slixmpp/stanza/handshake.py +++ b/slixmpp/stanza/handshake.py @@ -4,6 +4,7 @@ # See the file LICENSE for copying permission. from slixmpp.xmlstream import StanzaBase +from typing import Optional class Handshake(StanzaBase): @@ -18,7 +19,7 @@ class Handshake(StanzaBase): def set_value(self, value: str): self.xml.text = value - def get_value(self) -> str: + def get_value(self) -> Optional[str]: return self.xml.text def del_value(self): diff --git a/slixmpp/stanza/iq.py b/slixmpp/stanza/iq.py index 044c9df8..34e56f60 100644 --- a/slixmpp/stanza/iq.py +++ b/slixmpp/stanza/iq.py @@ -3,10 +3,10 @@ # Copyright (C) 2010 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. +import asyncio from slixmpp.stanza.rootstanza import RootStanza from slixmpp.xmlstream import StanzaBase, ET -from slixmpp.xmlstream.handler import Waiter, Callback, CoroutineCallback -from slixmpp.xmlstream.asyncio import asyncio +from slixmpp.xmlstream.handler import Callback, CoroutineCallback from slixmpp.xmlstream.matcher import MatchIDSender, MatcherId from slixmpp.exceptions import IqTimeout, IqError diff --git a/slixmpp/stanza/message.py b/slixmpp/stanza/message.py index debfb380..50d32ff0 100644 --- a/slixmpp/stanza/message.py +++ b/slixmpp/stanza/message.py @@ -61,8 +61,10 @@ class Message(RootStanza): """ StanzaBase.__init__(self, *args, **kwargs) if not recv and self['id'] == '': - if self.stream is not None and self.stream.use_message_ids: - self['id'] = self.stream.new_id() + if self.stream: + use_ids = getattr(self.stream, 'use_message_ids', None) + if use_ids: + self['id'] = self.stream.new_id() else: del self['origin_id'] @@ -93,8 +95,10 @@ class Message(RootStanza): self.xml.attrib['id'] = value - if self.stream and not self.stream.use_origin_id: - return None + if self.stream: + use_orig_ids = getattr(self.stream, 'use_origin_id', None) + if not use_orig_ids: + return None sub = self.xml.find(ORIGIN_NAME) if sub is not None: diff --git a/slixmpp/stanza/presence.py b/slixmpp/stanza/presence.py index 022e7133..d77ce1e4 100644 --- a/slixmpp/stanza/presence.py +++ b/slixmpp/stanza/presence.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2010 Nathanael C. Fritz # This file is part of Slixmpp. @@ -61,7 +60,7 @@ class Presence(RootStanza): 'subscribed', 'unsubscribe', 'unsubscribed'} showtypes = {'dnd', 'chat', 'xa', 'away'} - def __init__(self, *args, recv=False, **kwargs): + def __init__(self, *args, recv: bool = False, **kwargs): """ Initialize a new <presence /> stanza with an optional 'id' value. @@ -69,10 +68,12 @@ class Presence(RootStanza): """ StanzaBase.__init__(self, *args, **kwargs) if not recv and self['id'] == '': - if self.stream is not None and self.stream.use_presence_ids: - self['id'] = self.stream.new_id() + if self.stream: + use_ids = getattr(self.stream, 'use_presence_ids', None) + if use_ids: + self['id'] = self.stream.new_id() - def set_show(self, show): + def set_show(self, show: str): """ Set the value of the <show> element. @@ -84,7 +85,7 @@ class Presence(RootStanza): self._set_sub_text('show', text=show) return self - def get_type(self): + def get_type(self) -> str: """ Return the value of the <presence> stanza's type attribute, or the value of the <show> element if valid. @@ -96,7 +97,7 @@ class Presence(RootStanza): out = 'available' return out - def set_type(self, value): + def set_type(self, value: str): """ Set the type attribute's value, and the <show> element if applicable. @@ -119,7 +120,7 @@ class Presence(RootStanza): self._del_attr('type') self._del_sub('show') - def set_priority(self, value): + def set_priority(self, value: int): """ Set the entity's priority value. Some server use priority to determine message routing behavior. diff --git a/slixmpp/stanza/stream_error.py b/slixmpp/stanza/stream_error.py index 0e728c8e..d0eadd5b 100644 --- a/slixmpp/stanza/stream_error.py +++ b/slixmpp/stanza/stream_error.py @@ -4,7 +4,8 @@ # This file is part of Slixmpp. # See the file LICENSE for copying permission. from slixmpp.stanza.error import Error -from slixmpp.xmlstream import StanzaBase +from slixmpp.xmlstream import StanzaBase, ET +from typing import Optional, Dict, Union class StreamError(Error, StanzaBase): @@ -62,19 +63,20 @@ class StreamError(Error, StanzaBase): 'system-shutdown', 'undefined-condition', 'unsupported-encoding', 'unsupported-feature', 'unsupported-stanza-type', 'unsupported-version'} - condition_ns = 'urn:ietf:params:xml:ns:xmpp-streams' + condition_ns: str = 'urn:ietf:params:xml:ns:xmpp-streams' - def get_see_other_host(self): + def get_see_other_host(self) -> Union[str, Dict[str, str]]: ns = self.condition_ns return self._get_sub_text('{%s}see-other-host' % ns, '') - def set_see_other_host(self, value): + def set_see_other_host(self, value: str) -> Optional[ET.Element]: if value: del self['condition'] ns = self.condition_ns return self._set_sub_text('{%s}see-other-host' % ns, value) elif self['condition'] == 'see-other-host': del self['condition'] + return None - def del_see_other_host(self): + def del_see_other_host(self) -> None: self._del_sub('{%s}see-other-host' % self.condition_ns) diff --git a/slixmpp/stanza/stream_features.py b/slixmpp/stanza/stream_features.py index 7362f17b..a3e05a1a 100644 --- a/slixmpp/stanza/stream_features.py +++ b/slixmpp/stanza/stream_features.py @@ -3,7 +3,8 @@ # Copyright (C) 2010 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. -from slixmpp.xmlstream import StanzaBase +from slixmpp.xmlstream import StanzaBase, ElementBase +from typing import ClassVar, Dict, Type class StreamFeatures(StanzaBase): @@ -15,8 +16,8 @@ class StreamFeatures(StanzaBase): namespace = 'http://etherx.jabber.org/streams' interfaces = {'features', 'required', 'optional'} sub_interfaces = interfaces - plugin_tag_map = {} - plugin_attrib_map = {} + plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {} + plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {} def setup(self, xml): StanzaBase.setup(self, xml) diff --git a/slixmpp/test/integration.py b/slixmpp/test/integration.py index e8093107..fe33fa57 100644 --- a/slixmpp/test/integration.py +++ b/slixmpp/test/integration.py @@ -11,7 +11,7 @@ except ImportError: # Python < 3.8 # just to make sure the imports do not break, but # not usable. - from unittest import TestCase as IsolatedAsyncioTestCase + from unittest import TestCase as IsolatedAsyncioTestCase # type: ignore from typing import ( Dict, List, diff --git a/slixmpp/test/slixtest.py b/slixmpp/test/slixtest.py index 7c700fd2..0d05a4ac 100644 --- a/slixmpp/test/slixtest.py +++ b/slixmpp/test/slixtest.py @@ -17,9 +17,7 @@ from slixmpp.xmlstream.matcher import StanzaPath, MatcherId, MatchIDSender from slixmpp.xmlstream.matcher import MatchXMLMask, MatchXPath import asyncio -cls = asyncio.get_event_loop().__class__ -cls.idle_call = lambda self, callback: callback() class SlixTest(unittest.TestCase): diff --git a/slixmpp/types.py b/slixmpp/types.py index 453d25e3..336ab7d8 100644 --- a/slixmpp/types.py +++ b/slixmpp/types.py @@ -16,11 +16,13 @@ try: from typing import ( Literal, TypedDict, + Protocol, ) except ImportError: from typing_extensions import ( Literal, TypedDict, + Protocol, ) from slixmpp.jid import JID @@ -78,3 +80,11 @@ JidStr = Union[str, JID] OptJidStr = Optional[Union[str, JID]] MAMDefault = Literal['always', 'never', 'roster'] + +FilterString = Literal['in', 'out', 'out_sync'] + +__all__ = [ + 'Protocol', 'TypedDict', 'Literal', 'OptJid', 'JidStr', 'MAMDefault', + 'PresenceTypes', 'PresenceShows', 'MessageTypes', 'IqTypes', 'MucRole', + 'MucAffiliation', 'FilterString', +] diff --git a/slixmpp/util/cache.py b/slixmpp/util/cache.py index 23592404..b7042a56 100644 --- a/slixmpp/util/cache.py +++ b/slixmpp/util/cache.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2018 Emmanuel Gil Peyrot # This file is part of Slixmpp. @@ -6,8 +5,11 @@ import os import logging +from typing import Callable, Optional, Any + log = logging.getLogger(__name__) + class Cache: def retrieve(self, key): raise NotImplementedError @@ -16,7 +18,8 @@ class Cache: raise NotImplementedError def remove(self, key): - raise NotImplemented + raise NotImplementedError + class PerJidCache: def retrieve_by_jid(self, jid, key): @@ -28,6 +31,7 @@ class PerJidCache: def remove_by_jid(self, jid, key): raise NotImplementedError + class MemoryCache(Cache): def __init__(self): self.cache = {} @@ -44,6 +48,7 @@ class MemoryCache(Cache): del self.cache[key] return True + class MemoryPerJidCache(PerJidCache): def __init__(self): self.cache = {} @@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache): del cache[key] return True + class FileSystemStorage: - def __init__(self, encode, decode, binary): + def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool): self.encode = encode if encode is not None else lambda x: x self.decode = decode if decode is not None else lambda x: x self.read = 'rb' if binary else 'r' self.write = 'wb' if binary else 'w' - def _retrieve(self, directory, key): + def _retrieve(self, directory: str, key: str): filename = os.path.join(directory, key.replace('/', '_')) try: with open(filename, self.read) as cache_file: @@ -86,7 +92,7 @@ class FileSystemStorage: log.debug('Removing %s entry', key) self._remove(directory, key) - def _store(self, directory, key, value): + def _store(self, directory: str, key: str, value): filename = os.path.join(directory, key.replace('/', '_')) try: os.makedirs(directory, exist_ok=True) @@ -99,7 +105,7 @@ class FileSystemStorage: except Exception: log.debug('Failed to encode %s to cache:', key, exc_info=True) - def _remove(self, directory, key): + def _remove(self, directory: str, key: str): filename = os.path.join(directory, key.replace('/', '_')) try: os.remove(filename) @@ -108,8 +114,9 @@ class FileSystemStorage: return False return True + class FileSystemCache(Cache, FileSystemStorage): - def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False): + def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False): FileSystemStorage.__init__(self, encode, decode, binary) self.base_dir = os.path.join(directory, cache_type) @@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage): def remove(self, key): return self._remove(self.base_dir, key) + class FileSystemPerJidCache(PerJidCache, FileSystemStorage): - def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False): + def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False): FileSystemStorage.__init__(self, encode, decode, binary) self.base_dir = os.path.join(directory, cache_type) diff --git a/slixmpp/util/misc_ops.py b/slixmpp/util/misc_ops.py index 1dcd6e3f..ed16d347 100644 --- a/slixmpp/util/misc_ops.py +++ b/slixmpp/util/misc_ops.py @@ -2,15 +2,19 @@ import builtins import sys import hashlib +from typing import Optional, Union, Callable, List -def unicode(text): +bytes_ = builtins.bytes # alias the stdlib type but ew + + +def unicode(text: Union[bytes_, str]) -> str: if not isinstance(text, str): return text.decode('utf-8') else: return text -def bytes(text): +def bytes(text: Optional[Union[str, bytes_]]) -> bytes_: """ Convert Unicode text to UTF-8 encoded bytes. @@ -34,7 +38,7 @@ def bytes(text): return builtins.bytes(text, encoding='utf-8') -def quote(text): +def quote(text: Union[str, bytes_]) -> bytes_: """ Enclose in quotes and escape internal slashes and double quotes. @@ -44,7 +48,7 @@ def quote(text): return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"' -def num_to_bytes(num): +def num_to_bytes(num: int) -> bytes_: """ Convert an integer into a four byte sequence. @@ -58,21 +62,21 @@ def num_to_bytes(num): return bval -def bytes_to_num(bval): +def bytes_to_num(bval: bytes_) -> int: """ Convert a four byte sequence to an integer. :param bytes bval: A four byte sequence to turn into an integer. """ num = 0 - num += ord(bval[0] << 24) - num += ord(bval[1] << 16) - num += ord(bval[2] << 8) - num += ord(bval[3]) + num += (bval[0] << 24) + num += (bval[1] << 16) + num += (bval[2] << 8) + num += (bval[3]) return num -def XOR(x, y): +def XOR(x: bytes_, y: bytes_) -> bytes_: """ Return the results of an XOR operation on two equal length byte strings. @@ -85,7 +89,7 @@ def XOR(x, y): return builtins.bytes([a ^ b for a, b in zip(x, y)]) -def hash(name): +def hash(name: str) -> Optional[Callable]: """ Return a hash function implementing the given algorithm. @@ -102,7 +106,7 @@ def hash(name): return None -def hashes(): +def hashes() -> List[str]: """ Return a list of available hashing algorithms. @@ -115,28 +119,3 @@ def hashes(): t += ['MD2'] hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')] return t + hashes - - -def setdefaultencoding(encoding): - """ - Set the current default string encoding used by the Unicode implementation. - - Actually calls sys.setdefaultencoding under the hood - see the docs for that - for more details. This method exists only as a way to call find/call it - even after it has been 'deleted' when the site module is executed. - - :param string encoding: An encoding name, compatible with sys.setdefaultencoding - """ - func = getattr(sys, 'setdefaultencoding', None) - if func is None: - import gc - import types - for obj in gc.get_objects(): - if (isinstance(obj, types.BuiltinFunctionType) - and obj.__name__ == 'setdefaultencoding'): - func = obj - break - if func is None: - raise RuntimeError("Could not find setdefaultencoding") - sys.setdefaultencoding = func - return func(encoding) diff --git a/slixmpp/util/sasl/client.py b/slixmpp/util/sasl/client.py index 7c9d38e0..7565db6b 100644 --- a/slixmpp/util/sasl/client.py +++ b/slixmpp/util/sasl/client.py @@ -1,4 +1,3 @@ - # slixmpp.util.sasl.client # ~~~~~~~~~~~~~~~~~~~~~~~~~~ # This module was originally based on Dave Cridland's Suelta library. @@ -6,9 +5,11 @@ # :copryight: (c) 2004-2013 David Alan Cridland # :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout # :license: MIT, see LICENSE for more details +from __future__ import annotations import logging import stringprep +from typing import Iterable, Set, Callable, Dict, Any, Optional, Type from slixmpp.util import hashes, bytes, stringprep_profiles @@ -16,11 +17,11 @@ log = logging.getLogger(__name__) #: Global registry mapping mechanism names to implementation classes. -MECHANISMS = {} +MECHANISMS: Dict[str, Type[Mech]] = {} #: Global registry mapping mechanism names to security scores. -MECH_SEC_SCORES = {} +MECH_SEC_SCORES: Dict[str, int] = {} #: The SASLprep profile of stringprep used to validate simple username @@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create( unassigned=[stringprep.in_table_a1]) -def sasl_mech(score): +def sasl_mech(score: int): sec_score = score - def register(mech): + + def register(mech: Type[Mech]): n = 0 mech.score = sec_score if mech.use_hashes: @@ -99,9 +101,9 @@ class Mech(object): score = -1 use_hashes = False channel_binding = False - required_credentials = set() - optional_credentials = set() - security = set() + required_credentials: Set[str] = set() + optional_credentials: Set[str] = set() + security: Set[str] = set() def __init__(self, name, credentials, security_settings): self.credentials = credentials @@ -118,7 +120,14 @@ class Mech(object): return b'' -def choose(mech_list, credentials, security_settings, limit=None, min_mech=None): +CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]] +SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]] + + +def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback, + security_settings: SecurityCallback, + limit: Optional[Iterable[Type[Mech]]] = None, + min_mech: Optional[str] = None) -> Mech: available_mechs = set(MECHANISMS.keys()) if limit is None: limit = set(mech_list) @@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None) mech_list = mech_list.intersection(limit) available_mechs = available_mechs.intersection(mech_list) - best_score = MECH_SEC_SCORES.get(min_mech, -1) + if min_mech is None: + best_score = -1 + else: + best_score = MECH_SEC_SCORES.get(min_mech, -1) best_mech = None for name in available_mechs: if name in MECH_SEC_SCORES: diff --git a/slixmpp/util/sasl/mechanisms.py b/slixmpp/util/sasl/mechanisms.py index 53f39395..d53caec8 100644 --- a/slixmpp/util/sasl/mechanisms.py +++ b/slixmpp/util/sasl/mechanisms.py @@ -11,6 +11,9 @@ import hmac import random from base64 import b64encode, b64decode +from typing import List, Dict, Optional + +bytes_ = bytes from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes from slixmpp.util.sasl.client import sasl_mech, Mech, \ @@ -63,7 +66,7 @@ class PLAIN(Mech): if not self.security_settings['encrypted_plain']: raise SASLCancelled('PLAIN with encryption') - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> bytes_: authzid = self.credentials['authzid'] authcid = self.credentials['username'] password = self.credentials['password'] @@ -148,7 +151,7 @@ class CRAM(Mech): required_credentials = {'username', 'password'} security = {'encrypted', 'unencrypted_cram'} - def setup(self, name): + def setup(self, name: str): self.hash_name = name[5:] self.hash = hash(self.hash_name) if self.hash is None: @@ -157,14 +160,14 @@ class CRAM(Mech): if not self.security_settings['unencrypted_cram']: raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name) - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: if not challenge: return None username = self.credentials['username'] password = self.credentials['password'] - mac = hmac.HMAC(key=password, digestmod=self.hash) + mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore mac.update(challenge) return username + b' ' + bytes(mac.hexdigest()) @@ -201,43 +204,42 @@ class SCRAM(Mech): def HMAC(self, key, msg): return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest() - def Hi(self, text, salt, iterations): - text = bytes(text) - ui1 = self.HMAC(text, salt + b'\0\0\0\01') + def Hi(self, text: str, salt: bytes_, iterations: int): + text_enc = bytes(text) + ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01') ui = ui1 for i in range(iterations - 1): - ui1 = self.HMAC(text, ui1) + ui1 = self.HMAC(text_enc, ui1) ui = XOR(ui, ui1) return ui - def H(self, text): + def H(self, text: str) -> bytes_: return self.hash(text).digest() - def saslname(self, value): - value = value.decode("utf-8") - escaped = [] + def saslname(self, value_b: bytes_) -> bytes_: + value = value_b.decode("utf-8") + escaped: List[str] = [] for char in value: if char == ',': - escaped += b'=2C' + escaped.append('=2C') elif char == '=': - escaped += b'=3D' + escaped.append('=3D') else: - escaped += char + escaped.append(char) return "".join(escaped).encode("utf-8") - def parse(self, challenge): + def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]: items = {} for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]: items[key] = value return items - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b''): steps = [self.process_1, self.process_2, self.process_3] return steps[self.step](challenge) - def process_1(self, challenge): + def process_1(self, challenge: bytes_) -> bytes_: self.step = 1 - data = {} self.cnonce = bytes(('%s' % random.random())[2:]) @@ -263,7 +265,7 @@ class SCRAM(Mech): return self.client_first_message - def process_2(self, challenge): + def process_2(self, challenge: bytes_) -> bytes_: self.step = 2 data = self.parse(challenge) @@ -304,7 +306,7 @@ class SCRAM(Mech): return client_final_message - def process_3(self, challenge): + def process_3(self, challenge: bytes_) -> bytes_: data = self.parse(challenge) verifier = data.get(b'v', None) error = data.get(b'e', 'Unknown error') @@ -345,17 +347,16 @@ class DIGEST(Mech): self.cnonce = b'' self.nonce_count = 1 - def parse(self, challenge=b''): - data = {} + def parse(self, challenge: bytes_ = b''): + data: Dict[str, bytes_] = {} var_name = b'' var_value = b'' # States: var, new_var, end, quote, escaped_quote state = 'var' - - for char in challenge: - char = bytes([char]) + for char_int in challenge: + char = bytes_([char_int]) if state == 'var': if char.isspace(): @@ -401,14 +402,14 @@ class DIGEST(Mech): state = 'var' return data - def MAC(self, key, seq, msg): + def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_: mac = hmac.HMAC(key=key, digestmod=self.hash) seqnum = num_to_bytes(seq) mac.update(seqnum) mac.update(msg) return mac.digest()[:10] + b'\x00\x01' + seqnum - def A1(self): + def A1(self) -> bytes_: username = self.credentials['username'] password = self.credentials['password'] authzid = self.credentials['authzid'] @@ -423,13 +424,13 @@ class DIGEST(Mech): return bytes(a1) - def A2(self, prefix=b''): + def A2(self, prefix: bytes_ = b'') -> bytes_: a2 = prefix + b':' + self.digest_uri() if self.qop in (b'auth-int', b'auth-conf'): a2 += b':00000000000000000000000000000000' return bytes(a2) - def response(self, prefix=b''): + def response(self, prefix: bytes_ = b'') -> bytes_: nc = bytes('%08x' % self.nonce_count) a1 = bytes(self.hash(self.A1()).hexdigest().lower()) @@ -439,7 +440,7 @@ class DIGEST(Mech): return bytes(self.hash(a1 + b':' + s).hexdigest().lower()) - def digest_uri(self): + def digest_uri(self) -> bytes_: serv_type = self.credentials['service'] serv_name = self.credentials['service-name'] host = self.credentials['host'] @@ -449,7 +450,7 @@ class DIGEST(Mech): uri += b'/' + serv_name return uri - def respond(self): + def respond(self) -> bytes_: data = { 'username': quote(self.credentials['username']), 'authzid': quote(self.credentials['authzid']), @@ -469,7 +470,7 @@ class DIGEST(Mech): resp += b',' + bytes(key) + b'=' + bytes(value) return resp[1:] - def process(self, challenge=b''): + def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: if not challenge: if self.cnonce and self.nonce and self.nonce_count and self.qop: self.nonce_count += 1 @@ -480,6 +481,7 @@ class DIGEST(Mech): if 'rspauth' in data: if data['rspauth'] != self.response(): raise SASLMutualAuthFailed() + return None else: self.nonce_count = 1 self.cnonce = bytes('%s' % random.random())[2:] diff --git a/slixmpp/xmlstream/asyncio.py b/slixmpp/xmlstream/asyncio.py deleted file mode 100644 index b42b366a..00000000 --- a/slixmpp/xmlstream/asyncio.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -asyncio-related utilities -""" - -import asyncio -from functools import wraps - -def future_wrapper(func): - """ - Make sure the result of a function call is an asyncio.Future() - object. - """ - @wraps(func) - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - if isinstance(result, asyncio.Future): - return result - future = asyncio.Future() - future.set_result(result) - return future - - return wrapper diff --git a/slixmpp/xmlstream/cert.py b/slixmpp/xmlstream/cert.py index 28ef585a..41679a7e 100644 --- a/slixmpp/xmlstream/cert.py +++ b/slixmpp/xmlstream/cert.py @@ -1,5 +1,6 @@ import logging from datetime import datetime, timedelta +from typing import Dict, Set, Tuple, Optional # Make a call to strptime before starting threads to # prevent thread safety issues. @@ -32,13 +33,13 @@ class CertificateError(Exception): pass -def decode_str(data): +def decode_str(data: bytes) -> str: encoding = 'utf-16-be' if isinstance(data, BMPString) else 'utf-8' return bytes(data).decode(encoding) -def extract_names(raw_cert): - results = {'CN': set(), +def extract_names(raw_cert: bytes) -> Dict[str, Set[str]]: + results: Dict[str, Set[str]] = {'CN': set(), 'DNS': set(), 'SRV': set(), 'URI': set(), @@ -96,7 +97,7 @@ def extract_names(raw_cert): return results -def extract_dates(raw_cert): +def extract_dates(raw_cert: bytes) -> Tuple[Optional[datetime], Optional[datetime]]: if not HAVE_PYASN1: log.warning("Could not find pyasn1 and pyasn1_modules. " + \ "SSL certificate expiration COULD NOT BE VERIFIED.") @@ -125,24 +126,29 @@ def extract_dates(raw_cert): return not_before, not_after -def get_ttl(raw_cert): +def get_ttl(raw_cert: bytes) -> Optional[timedelta]: not_before, not_after = extract_dates(raw_cert) - if not_after is None: + if not_after is None or not_before is None: return None return not_after - datetime.utcnow() -def verify(expected, raw_cert): +def verify(expected: str, raw_cert: bytes) -> Optional[bool]: if not HAVE_PYASN1: log.warning("Could not find pyasn1 and pyasn1_modules. " + \ "SSL certificate COULD NOT BE VERIFIED.") - return + return None not_before, not_after = extract_dates(raw_cert) cert_names = extract_names(raw_cert) now = datetime.utcnow() + if not not_before or not not_after: + raise CertificateError( + "Error while checking the dates of the certificate" + ) + if not_before > now: raise CertificateError( 'Certificate has not entered its valid date range.') diff --git a/slixmpp/xmlstream/handler/base.py b/slixmpp/xmlstream/handler/base.py index 1e657777..0bae5674 100644 --- a/slixmpp/xmlstream/handler/base.py +++ b/slixmpp/xmlstream/handler/base.py @@ -4,10 +4,19 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details +from __future__ import annotations + import weakref +from weakref import ReferenceType +from typing import Optional, TYPE_CHECKING, Union +from slixmpp.xmlstream.matcher.base import MatcherBase +from xml.etree.ElementTree import Element + +if TYPE_CHECKING: + from slixmpp.xmlstream import XMLStream, StanzaBase -class BaseHandler(object): +class BaseHandler: """ Base class for stream handlers. Stream handlers are matched with @@ -26,8 +35,13 @@ class BaseHandler(object): :param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream` instance that the handle will respond to. """ + name: str + stream: Optional[ReferenceType[XMLStream]] + _destroy: bool + _matcher: MatcherBase + _payload: Optional[StanzaBase] - def __init__(self, name, matcher, stream=None): + def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None): #: The name of the handler self.name = name @@ -41,33 +55,33 @@ class BaseHandler(object): self._payload = None self._matcher = matcher - def match(self, xml): + def match(self, xml: StanzaBase) -> bool: """Compare a stanza or XML object with the handler's matcher. :param xml: An XML or - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object """ return self._matcher.match(xml) - def prerun(self, payload): + def prerun(self, payload: StanzaBase) -> None: """Prepare the handler for execution while the XML stream is being processed. - :param payload: A :class:`~slixmpp.xmlstream.stanzabase.ElementBase` + :param payload: A :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. """ self._payload = payload - def run(self, payload): + def run(self, payload: StanzaBase) -> None: """Execute the handler after XML stream processing and during the main event loop. - :param payload: A :class:`~slixmpp.xmlstream.stanzabase.ElementBase` + :param payload: A :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. """ self._payload = payload - def check_delete(self): + def check_delete(self) -> bool: """Check if the handler should be removed from the list of stream handlers. """ diff --git a/slixmpp/xmlstream/handler/callback.py b/slixmpp/xmlstream/handler/callback.py index 93cec6b7..50dd4c66 100644 --- a/slixmpp/xmlstream/handler/callback.py +++ b/slixmpp/xmlstream/handler/callback.py @@ -1,10 +1,17 @@ - # slixmpp.xmlstream.handler.callback # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details +from __future__ import annotations + +from typing import Optional, Callable, Any, TYPE_CHECKING from slixmpp.xmlstream.handler.base import BaseHandler +from slixmpp.xmlstream.matcher.base import MatcherBase + +if TYPE_CHECKING: + from slixmpp.xmlstream.stanzabase import StanzaBase + from slixmpp.xmlstream.xmlstream import XMLStream class Callback(BaseHandler): @@ -28,8 +35,6 @@ class Callback(BaseHandler): :param matcher: A :class:`~slixmpp.xmlstream.matcher.base.MatcherBase` derived object for matching stanza objects. :param pointer: The function to execute during callback. - :param bool thread: **DEPRECATED.** Remains only for - backwards compatibility. :param bool once: Indicates if the handler should be used only once. Defaults to False. :param bool instream: Indicates if the callback should be executed @@ -38,31 +43,36 @@ class Callback(BaseHandler): :param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream` instance this handler should monitor. """ + _once: bool + _instream: bool - def __init__(self, name, matcher, pointer, thread=False, - once=False, instream=False, stream=None): + def __init__(self, name: str, matcher: MatcherBase, + pointer: Callable[[StanzaBase], Any], + once: bool = False, instream: bool = False, + stream: Optional[XMLStream] = None): BaseHandler.__init__(self, name, matcher, stream) + self._pointer: Callable[[StanzaBase], Any] = pointer self._pointer = pointer self._once = once self._instream = instream - def prerun(self, payload): + def prerun(self, payload: StanzaBase) -> None: """Execute the callback during stream processing, if the callback was created with ``instream=True``. :param payload: The matched - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. """ if self._once: self._destroy = True if self._instream: self.run(payload, True) - def run(self, payload, instream=False): + def run(self, payload: StanzaBase, instream: bool = False) -> None: """Execute the callback function with the matched stanza payload. :param payload: The matched - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. :param bool instream: Force the handler to execute during stream processing. This should only be used by :meth:`prerun()`. Defaults to ``False``. diff --git a/slixmpp/xmlstream/handler/collector.py b/slixmpp/xmlstream/handler/collector.py index 8d012873..a5ee109c 100644 --- a/slixmpp/xmlstream/handler/collector.py +++ b/slixmpp/xmlstream/handler/collector.py @@ -4,11 +4,17 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2012 Nathanael C. Fritz, Lance J.T. Stout # :license: MIT, see LICENSE for more details +from __future__ import annotations + import logging -from queue import Queue, Empty +from typing import List, Optional, TYPE_CHECKING +from slixmpp.xmlstream.stanzabase import StanzaBase from slixmpp.xmlstream.handler.base import BaseHandler +from slixmpp.xmlstream.matcher.base import MatcherBase +if TYPE_CHECKING: + from slixmpp.xmlstream.xmlstream import XMLStream log = logging.getLogger(__name__) @@ -27,35 +33,35 @@ class Collector(BaseHandler): :param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream` instance this handler should monitor. """ + _stanzas: List[StanzaBase] - def __init__(self, name, matcher, stream=None): + def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None): BaseHandler.__init__(self, name, matcher, stream=stream) - self._payload = Queue() + self._stanzas = [] - def prerun(self, payload): + def prerun(self, payload: StanzaBase) -> None: """Store the matched stanza when received during processing. :param payload: The matched - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. """ - self._payload.put(payload) + self._stanzas.append(payload) - def run(self, payload): + def run(self, payload: StanzaBase) -> None: """Do not process this handler during the main event loop.""" pass - def stop(self): + def stop(self) -> List[StanzaBase]: """ Stop collection of matching stanzas, and return the ones that have been stored so far. """ + stream_ref = self.stream + if stream_ref is None: + raise ValueError('stop() called without a stream!') + stream = stream_ref() + if stream is None: + raise ValueError('stop() called without a stream!') self._destroy = True - results = [] - try: - while True: - results.append(self._payload.get(False)) - except Empty: - pass - - self.stream().remove_handler(self.name) - return results + stream.remove_handler(self.name) + return self._stanzas diff --git a/slixmpp/xmlstream/handler/coroutine_callback.py b/slixmpp/xmlstream/handler/coroutine_callback.py index 6568ba9f..524cca54 100644 --- a/slixmpp/xmlstream/handler/coroutine_callback.py +++ b/slixmpp/xmlstream/handler/coroutine_callback.py @@ -4,8 +4,19 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details +from __future__ import annotations + +from asyncio import iscoroutinefunction, ensure_future +from typing import Optional, Callable, Awaitable, TYPE_CHECKING + +from slixmpp.xmlstream.stanzabase import StanzaBase from slixmpp.xmlstream.handler.base import BaseHandler -from slixmpp.xmlstream.asyncio import asyncio +from slixmpp.xmlstream.matcher.base import MatcherBase + +CoroutineFunction = Callable[[StanzaBase], Awaitable[None]] + +if TYPE_CHECKING: + from slixmpp.xmlstream.xmlstream import XMLStream class CoroutineCallback(BaseHandler): @@ -34,45 +45,49 @@ class CoroutineCallback(BaseHandler): instance this handler should monitor. """ - def __init__(self, name, matcher, pointer, once=False, - instream=False, stream=None): + _once: bool + _instream: bool + + def __init__(self, name: str, matcher: MatcherBase, + pointer: CoroutineFunction, once: bool = False, + instream: bool = False, stream: Optional[XMLStream] = None): BaseHandler.__init__(self, name, matcher, stream) - if not asyncio.iscoroutinefunction(pointer): + if not iscoroutinefunction(pointer): raise ValueError("Given function is not a coroutine") - async def pointer_wrapper(stanza, *args, **kwargs): + async def pointer_wrapper(stanza: StanzaBase) -> None: try: - await pointer(stanza, *args, **kwargs) + await pointer(stanza) except Exception as e: stanza.exception(e) - self._pointer = pointer_wrapper + self._pointer: CoroutineFunction = pointer_wrapper self._once = once self._instream = instream - def prerun(self, payload): + def prerun(self, payload: StanzaBase) -> None: """Execute the callback during stream processing, if the callback was created with ``instream=True``. :param payload: The matched - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. """ if self._once: self._destroy = True if self._instream: self.run(payload, True) - def run(self, payload, instream=False): + def run(self, payload: StanzaBase, instream: bool = False) -> None: """Execute the callback function with the matched stanza payload. :param payload: The matched - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. :param bool instream: Force the handler to execute during stream processing. This should only be used by :meth:`prerun()`. Defaults to ``False``. """ if not self._instream or instream: - asyncio.ensure_future(self._pointer(payload)) + ensure_future(self._pointer(payload)) if self._once: self._destroy = True del self._pointer diff --git a/slixmpp/xmlstream/handler/waiter.py b/slixmpp/xmlstream/handler/waiter.py index 758cd1f1..dde49754 100644 --- a/slixmpp/xmlstream/handler/waiter.py +++ b/slixmpp/xmlstream/handler/waiter.py @@ -4,13 +4,20 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details +from __future__ import annotations + import logging -import asyncio -from asyncio import Queue, wait_for, TimeoutError +from asyncio import Event, wait_for, TimeoutError +from typing import Optional, TYPE_CHECKING, Union +from xml.etree.ElementTree import Element import slixmpp +from slixmpp.xmlstream.stanzabase import StanzaBase from slixmpp.xmlstream.handler.base import BaseHandler +from slixmpp.xmlstream.matcher.base import MatcherBase +if TYPE_CHECKING: + from slixmpp.xmlstream.xmlstream import XMLStream log = logging.getLogger(__name__) @@ -28,24 +35,27 @@ class Waiter(BaseHandler): :param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream` instance this handler should monitor. """ + _event: Event - def __init__(self, name, matcher, stream=None): + def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None): BaseHandler.__init__(self, name, matcher, stream=stream) - self._payload = Queue() + self._event = Event() - def prerun(self, payload): + def prerun(self, payload: StanzaBase) -> None: """Store the matched stanza when received during processing. :param payload: The matched - :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. + :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. """ - self._payload.put_nowait(payload) + if not self._event.is_set(): + self._event.set() + self._payload = payload - def run(self, payload): + def run(self, payload: StanzaBase) -> None: """Do not process this handler during the main event loop.""" pass - async def wait(self, timeout=None): + async def wait(self, timeout: Optional[int] = None) -> Optional[StanzaBase]: """Block an event handler while waiting for a stanza to arrive. Be aware that this will impact performance if called from a @@ -59,17 +69,24 @@ class Waiter(BaseHandler): :class:`~slixmpp.xmlstream.xmlstream.XMLStream.response_timeout` value. """ + stream_ref = self.stream + if stream_ref is None: + raise ValueError('wait() called without a stream') + stream = stream_ref() + if stream is None: + raise ValueError('wait() called without a stream') if timeout is None: timeout = slixmpp.xmlstream.RESPONSE_TIMEOUT - stanza = None try: - stanza = await self._payload.get() + await wait_for( + self._event.wait(), timeout, loop=stream.loop + ) except TimeoutError: log.warning("Timed out waiting for %s", self.name) - self.stream().remove_handler(self.name) - return stanza + stream.remove_handler(self.name) + return self._payload - def check_delete(self): + def check_delete(self) -> bool: """Always remove waiters after use.""" return True diff --git a/slixmpp/xmlstream/handler/xmlcallback.py b/slixmpp/xmlstream/handler/xmlcallback.py index b44b2da1..c1adc815 100644 --- a/slixmpp/xmlstream/handler/xmlcallback.py +++ b/slixmpp/xmlstream/handler/xmlcallback.py @@ -4,6 +4,7 @@ # This file is part of Slixmpp. # See the file LICENSE for copying permission. from slixmpp.xmlstream.handler import Callback +from slixmpp.xmlstream.stanzabase import StanzaBase class XMLCallback(Callback): @@ -17,7 +18,7 @@ class XMLCallback(Callback): run -- Overrides Callback.run """ - def run(self, payload, instream=False): + def run(self, payload: StanzaBase, instream: bool = False) -> None: """ Execute the callback function with the matched stanza's XML contents, instead of the stanza itself. @@ -30,4 +31,4 @@ class XMLCallback(Callback): stream processing. Used only by prerun. Defaults to False. """ - Callback.run(self, payload.xml, instream) + Callback.run(self, payload.xml, instream) # type: ignore diff --git a/slixmpp/xmlstream/handler/xmlwaiter.py b/slixmpp/xmlstream/handler/xmlwaiter.py index 6eb6577e..f730efec 100644 --- a/slixmpp/xmlstream/handler/xmlwaiter.py +++ b/slixmpp/xmlstream/handler/xmlwaiter.py @@ -3,6 +3,7 @@ # Copyright (C) 2010 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. +from slixmpp.xmlstream.stanzabase import StanzaBase from slixmpp.xmlstream.handler import Waiter @@ -17,7 +18,7 @@ class XMLWaiter(Waiter): prerun -- Overrides Waiter.prerun """ - def prerun(self, payload): + def prerun(self, payload: StanzaBase) -> None: """ Store the XML contents of the stanza to return to the waiting event handler. @@ -27,4 +28,4 @@ class XMLWaiter(Waiter): Arguments: payload -- The matched stanza object. """ - Waiter.prerun(self, payload.xml) + Waiter.prerun(self, payload.xml) # type: ignore diff --git a/slixmpp/xmlstream/matcher/base.py b/slixmpp/xmlstream/matcher/base.py index e2560d2b..552269c5 100644 --- a/slixmpp/xmlstream/matcher/base.py +++ b/slixmpp/xmlstream/matcher/base.py @@ -1,10 +1,13 @@ - # slixmpp.xmlstream.matcher.base # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details +from typing import Any +from slixmpp.xmlstream.stanzabase import StanzaBase + + class MatcherBase(object): """ @@ -15,10 +18,10 @@ class MatcherBase(object): :param criteria: Object to compare some aspect of a stanza against. """ - def __init__(self, criteria): + def __init__(self, criteria: Any): self._criteria = criteria - def match(self, xml): + def match(self, xml: StanzaBase) -> bool: """Check if a stanza matches the stored criteria. Meant to be overridden. diff --git a/slixmpp/xmlstream/matcher/id.py b/slixmpp/xmlstream/matcher/id.py index 44df2c18..e4e7ad4e 100644 --- a/slixmpp/xmlstream/matcher/id.py +++ b/slixmpp/xmlstream/matcher/id.py @@ -5,6 +5,7 @@ # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details from slixmpp.xmlstream.matcher.base import MatcherBase +from slixmpp.xmlstream.stanzabase import StanzaBase class MatcherId(MatcherBase): @@ -13,12 +14,13 @@ class MatcherId(MatcherBase): The ID matcher selects stanzas that have the same stanza 'id' interface value as the desired ID. """ + _criteria: str - def match(self, xml): + def match(self, xml: StanzaBase) -> bool: """Compare the given stanza's ``'id'`` attribute to the stored ``id`` value. - :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` + :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` stanza to compare against. """ - return xml['id'] == self._criteria + return bool(xml['id'] == self._criteria) diff --git a/slixmpp/xmlstream/matcher/idsender.py b/slixmpp/xmlstream/matcher/idsender.py index 8f5d0303..572f9d87 100644 --- a/slixmpp/xmlstream/matcher/idsender.py +++ b/slixmpp/xmlstream/matcher/idsender.py @@ -4,7 +4,19 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details + from slixmpp.xmlstream.matcher.base import MatcherBase +from slixmpp.xmlstream.stanzabase import StanzaBase +from slixmpp.jid import JID +from slixmpp.types import TypedDict + +from typing import Dict + + +class CriteriaType(TypedDict): + self: JID + peer: JID + id: str class MatchIDSender(MatcherBase): @@ -14,25 +26,26 @@ class MatchIDSender(MatcherBase): interface value as the desired ID, and that the 'from' value is one of a set of approved entities that can respond to a request. """ + _criteria: CriteriaType - def match(self, xml): + def match(self, xml: StanzaBase) -> bool: """Compare the given stanza's ``'id'`` attribute to the stored ``id`` value, and verify the sender's JID. - :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` + :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` stanza to compare against. """ selfjid = self._criteria['self'] peerjid = self._criteria['peer'] - allowed = {} + allowed: Dict[str, bool] = {} allowed[''] = True allowed[selfjid.bare] = True - allowed[selfjid.host] = True + allowed[selfjid.domain] = True allowed[peerjid.full] = True allowed[peerjid.bare] = True - allowed[peerjid.host] = True + allowed[peerjid.domain] = True _from = xml['from'] diff --git a/slixmpp/xmlstream/matcher/many.py b/slixmpp/xmlstream/matcher/many.py index e8ad54e7..dae45463 100644 --- a/slixmpp/xmlstream/matcher/many.py +++ b/slixmpp/xmlstream/matcher/many.py @@ -3,7 +3,9 @@ # Copyright (C) 2010 Nathanael C. Fritz # This file is part of Slixmpp. # See the file LICENSE for copying permission. +from typing import Iterable from slixmpp.xmlstream.matcher.base import MatcherBase +from slixmpp.xmlstream.stanzabase import StanzaBase class MatchMany(MatcherBase): @@ -18,8 +20,9 @@ class MatchMany(MatcherBase): Methods: match -- Overrides MatcherBase.match. """ + _criteria: Iterable[MatcherBase] - def match(self, xml): + def match(self, xml: StanzaBase) -> bool: """ Match a stanza against multiple criteria. The match is successful if one of the criteria matches. diff --git a/slixmpp/xmlstream/matcher/stanzapath.py b/slixmpp/xmlstream/matcher/stanzapath.py index 1bf3fa8e..b7de3ee2 100644 --- a/slixmpp/xmlstream/matcher/stanzapath.py +++ b/slixmpp/xmlstream/matcher/stanzapath.py @@ -4,8 +4,9 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details +from typing import cast, List from slixmpp.xmlstream.matcher.base import MatcherBase -from slixmpp.xmlstream.stanzabase import fix_ns +from slixmpp.xmlstream.stanzabase import fix_ns, StanzaBase class StanzaPath(MatcherBase): @@ -17,22 +18,28 @@ class StanzaPath(MatcherBase): :param criteria: Object to compare some aspect of a stanza against. """ - - def __init__(self, criteria): - self._criteria = fix_ns(criteria, split=True, - propagate_ns=False, - default_ns='jabber:client') + _criteria: List[str] + _raw_criteria: str + + def __init__(self, criteria: str): + self._criteria = cast( + List[str], + fix_ns( + criteria, split=True, propagate_ns=False, + default_ns='jabber:client' + ) + ) self._raw_criteria = criteria - def match(self, stanza): + def match(self, stanza: StanzaBase) -> bool: """ Compare a stanza against a "stanza path". A stanza path is similar to an XPath expression, but uses the stanza's interfaces and plugins instead of the underlying XML. See the documentation for the stanza - :meth:`~slixmpp.xmlstream.stanzabase.ElementBase.match()` method + :meth:`~slixmpp.xmlstream.stanzabase.StanzaBase.match()` method for more information. - :param stanza: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` + :param stanza: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` stanza to compare against. """ return stanza.match(self._criteria) or stanza.match(self._raw_criteria) diff --git a/slixmpp/xmlstream/matcher/xmlmask.py b/slixmpp/xmlstream/matcher/xmlmask.py index d50b706e..b63e0f05 100644 --- a/slixmpp/xmlstream/matcher/xmlmask.py +++ b/slixmpp/xmlstream/matcher/xmlmask.py @@ -1,4 +1,3 @@ - # Slixmpp: The Slick XMPP Library # Copyright (C) 2010 Nathanael C. Fritz # This file is part of Slixmpp. @@ -6,8 +5,9 @@ import logging from xml.parsers.expat import ExpatError +from xml.etree.ElementTree import Element -from slixmpp.xmlstream.stanzabase import ET +from slixmpp.xmlstream.stanzabase import ET, StanzaBase from slixmpp.xmlstream.matcher.base import MatcherBase @@ -33,32 +33,33 @@ class MatchXMLMask(MatcherBase): :param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML object or XML string to use as a mask. """ + _criteria: Element - def __init__(self, criteria, default_ns='jabber:client'): + def __init__(self, criteria: str, default_ns: str = 'jabber:client'): MatcherBase.__init__(self, criteria) if isinstance(criteria, str): - self._criteria = ET.fromstring(self._criteria) + self._criteria = ET.fromstring(criteria) self.default_ns = default_ns - def setDefaultNS(self, ns): + def setDefaultNS(self, ns: str) -> None: """Set the default namespace to use during comparisons. :param ns: The new namespace to use as the default. """ self.default_ns = ns - def match(self, xml): + def match(self, xml: StanzaBase) -> bool: """Compare a stanza object or XML object against the stored XML mask. Overrides MatcherBase.match. :param xml: The stanza object or XML object to compare against. """ - if hasattr(xml, 'xml'): - xml = xml.xml - return self._mask_cmp(xml, self._criteria, True) + real_xml = xml.xml + return self._mask_cmp(real_xml, self._criteria, True) - def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'): + def _mask_cmp(self, source: Element, mask: Element, use_ns: bool = False, + default_ns: str = '__no_ns__') -> bool: """Compare an XML object against an XML mask. :param source: The :class:`~xml.etree.ElementTree.Element` XML object @@ -75,13 +76,6 @@ class MatchXMLMask(MatcherBase): # If the element was not found. May happen during recursive calls. return False - # Convert the mask to an XML object if it is a string. - if not hasattr(mask, 'attrib'): - try: - mask = ET.fromstring(mask) - except ExpatError: - log.warning("Expat error: %s\nIn parsing: %s", '', mask) - mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag) if source.tag not in [mask.tag, mask_ns_tag]: return False diff --git a/slixmpp/xmlstream/matcher/xpath.py b/slixmpp/xmlstream/matcher/xpath.py index b7503b73..bd41b60a 100644 --- a/slixmpp/xmlstream/matcher/xpath.py +++ b/slixmpp/xmlstream/matcher/xpath.py @@ -4,7 +4,8 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details -from slixmpp.xmlstream.stanzabase import ET, fix_ns +from typing import cast +from slixmpp.xmlstream.stanzabase import ET, fix_ns, StanzaBase from slixmpp.xmlstream.matcher.base import MatcherBase @@ -17,23 +18,23 @@ class MatchXPath(MatcherBase): If the value of :data:`IGNORE_NS` is set to ``True``, then XPath expressions will be matched without using namespaces. """ + _criteria: str - def __init__(self, criteria): - self._criteria = fix_ns(criteria) + def __init__(self, criteria: str): + self._criteria = cast(str, fix_ns(criteria)) - def match(self, xml): + def match(self, xml: StanzaBase) -> bool: """ Compare a stanza's XML contents to an XPath expression. If the value of :data:`IGNORE_NS` is set to ``True``, then XPath expressions will be matched without using namespaces. - :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` + :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` stanza to compare against. """ - if hasattr(xml, 'xml'): - xml = xml.xml + real_xml = xml.xml x = ET.Element('x') - x.append(xml) + x.append(real_xml) return x.find(self._criteria) is not None diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py index 97798353..e524da3b 100644 --- a/slixmpp/xmlstream/resolver.py +++ b/slixmpp/xmlstream/resolver.py @@ -1,18 +1,32 @@ - # slixmpp.xmlstream.dns # ~~~~~~~~~~~~~~~~~~~~~~~ # :copyright: (c) 2012 Nathanael C. Fritz # :license: MIT, see LICENSE for more details -from slixmpp.xmlstream.asyncio import asyncio import socket +import sys import logging import random +from asyncio import Future, AbstractEventLoop +from typing import Optional, Tuple, Dict, List, Iterable, cast +from slixmpp.types import Protocol log = logging.getLogger(__name__) +class AnswerProtocol(Protocol): + host: str + priority: int + weight: int + port: int + + +class ResolverProtocol(Protocol): + def query(self, query: str, querytype: str) -> Future: + ... + + #: Global flag indicating the availability of the ``aiodns`` package. #: Installing ``aiodns`` can be done via: #: @@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False try: import aiodns AIODNS_AVAILABLE = True -except ImportError as e: - log.debug("Could not find aiodns package. " + \ +except ImportError: + log.debug("Could not find aiodns package. " "Not all features will be available") -def default_resolver(loop): +def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]: """Return a basic DNS resolver object. :returns: A :class:`aiodns.DNSResolver` object if aiodns @@ -41,8 +55,11 @@ def default_resolver(loop): return None -async def resolve(host, port=None, service=None, proto='tcp', - resolver=None, use_ipv6=True, use_aiodns=True, loop=None): +async def resolve(host: str, port: int, *, loop: AbstractEventLoop, + service: Optional[str] = None, proto: str = 'tcp', + resolver: Optional[ResolverProtocol] = None, + use_ipv6: bool = True, + use_aiodns: bool = True) -> List[Tuple[str, str, int]]: """Peform DNS resolution for a given hostname. Resolution may perform SRV record lookups if a service and protocol @@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp', if not use_ipv6: log.debug("DNS: Use of IPv6 has been disabled.") - if resolver is None and AIODNS_AVAILABLE and use_aiodns: - resolver = aiodns.DNSResolver(loop=loop) + if resolver is None and use_aiodns: + resolver = default_resolver(loop=loop) # An IPv6 literal is allowed to be enclosed in square brackets, but # the brackets must be stripped in order to process the literal; @@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp', try: # If `host` is an IPv4 literal, we can return it immediately. - ipv4 = socket.inet_aton(host) + socket.inet_aton(host) return [(host, host, port)] except socket.error: pass @@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp', # Likewise, If `host` is an IPv6 literal, we can return # it immediately. if hasattr(socket, 'inet_pton'): - ipv6 = socket.inet_pton(socket.AF_INET6, host) + socket.inet_pton(socket.AF_INET6, host) return [(host, host, port)] except (socket.error, ValueError): pass @@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp', return results -async def get_A(host, resolver=None, use_aiodns=True, loop=None): + +async def get_A(host: str, *, loop: AbstractEventLoop, + resolver: Optional[ResolverProtocol] = None, + use_aiodns: bool = True) -> List[str]: """Lookup DNS A records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None): # getaddrinfo() method. if resolver is None or not use_aiodns: try: - recs = await loop.getaddrinfo(host, None, + inet_recs = await loop.getaddrinfo(host, None, family=socket.AF_INET, type=socket.SOCK_STREAM) - return [rec[4][0] for rec in recs] + return [rec[4][0] for rec in inet_recs] except socket.gaierror: log.debug("DNS: Error retrieving A address info for %s." % host) return [] @@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None): # Using aiodns: future = resolver.query(host, 'A') try: - recs = await future + recs = cast(Iterable[AnswerProtocol], await future) except Exception as e: log.debug('DNS: Exception while querying for %s A records: %s', host, e) recs = [] return [rec.host for rec in recs] -async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): +async def get_AAAA(host: str, *, loop: AbstractEventLoop, + resolver: Optional[ResolverProtocol] = None, + use_aiodns: bool = True) -> List[str]: """Lookup DNS AAAA records for a given host. If ``resolver`` is not provided, or is ``None``, then resolution will @@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host) return [] try: - recs = await loop.getaddrinfo(host, None, + inet_recs = await loop.getaddrinfo(host, None, family=socket.AF_INET6, type=socket.SOCK_STREAM) - return [rec[4][0] for rec in recs] + return [rec[4][0] for rec in inet_recs] except (OSError, socket.gaierror): log.debug("DNS: Error retrieving AAAA address " + \ "info for %s." % host) @@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): # Using aiodns: future = resolver.query(host, 'AAAA') try: - recs = await future + recs = cast(Iterable[AnswerProtocol], await future) except Exception as e: log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e) recs = [] return [rec.host for rec in recs] -async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True): + +async def get_SRV(host: str, port: int, service: str, + proto: str = 'tcp', + resolver: Optional[ResolverProtocol] = None, + use_aiodns: bool = True) -> List[Tuple[str, int]]: """Perform SRV record resolution for a given host. .. note:: @@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr try: future = resolver.query('_%s._%s.%s' % (service, proto, host), 'SRV') - recs = await future + recs = cast(Iterable[AnswerProtocol], await future) except Exception as e: log.debug('DNS: Exception while querying for %s SRV records: %s', host, e) return [] - answers = {} + answers: Dict[int, List[AnswerProtocol]] = {} for rec in recs: if rec.priority not in answers: answers[rec.priority] = [] diff --git a/slixmpp/xmlstream/stanzabase.py b/slixmpp/xmlstream/stanzabase.py index 7679f73a..2f2faa8d 100644 --- a/slixmpp/xmlstream/stanzabase.py +++ b/slixmpp/xmlstream/stanzabase.py @@ -1,4 +1,3 @@ - # slixmpp.xmlstream.stanzabase # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # module implements a wrapper layer for XML objects @@ -11,13 +10,34 @@ from __future__ import annotations import copy import logging import weakref -from typing import Optional +from typing import ( + cast, + Any, + Callable, + ClassVar, + Coroutine, + Dict, + List, + Iterable, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) +from weakref import ReferenceType from xml.etree import ElementTree as ET +from slixmpp.types import JidStr from slixmpp.xmlstream import JID from slixmpp.xmlstream.tostring import tostring +if TYPE_CHECKING: + from slixmpp.xmlstream import XMLStream + + log = logging.getLogger(__name__) @@ -28,7 +48,8 @@ XML_TYPE = type(ET.Element('xml')) XML_NS = 'http://www.w3.org/XML/1998/namespace' -def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False): +def register_stanza_plugin(stanza: Type[ElementBase], plugin: Type[ElementBase], + iterable: bool = False, overrides: bool = False) -> None: """ Associate a stanza object as a plugin for another stanza. @@ -85,15 +106,15 @@ def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False): stanza.plugin_overrides[interface] = plugin.plugin_attrib -def multifactory(stanza, plugin_attrib): +def multifactory(stanza: Type[ElementBase], plugin_attrib: str) -> Type[ElementBase]: """ Returns a ElementBase class for handling reoccuring child stanzas """ - def plugin_filter(self): + def plugin_filter(self: Multi) -> Callable[..., bool]: return lambda x: isinstance(x, self._multistanza) - def plugin_lang_filter(self, lang): + def plugin_lang_filter(self: Multi, lang: Optional[str]) -> Callable[..., bool]: return lambda x: isinstance(x, self._multistanza) and \ x['lang'] == lang @@ -101,31 +122,41 @@ def multifactory(stanza, plugin_attrib): """ Template class for multifactory """ - def setup(self, xml=None): + _multistanza: Type[ElementBase] + + def setup(self, xml: Optional[ET.Element] = None) -> bool: self.xml = ET.Element('') + return False - def get_multi(self, lang=None): - parent = self.parent() + def get_multi(self: Multi, lang: Optional[str] = None) -> List[ElementBase]: + parent = fail_without_parent(self) if not lang or lang == '*': res = filter(plugin_filter(self), parent) else: - res = filter(plugin_filter(self, lang), parent) + res = filter(plugin_lang_filter(self, lang), parent) return list(res) - def set_multi(self, val, lang=None): - parent = self.parent() + def set_multi(self: Multi, val: Iterable[ElementBase], lang: Optional[str] = None) -> None: + parent = fail_without_parent(self) del_multi = getattr(self, 'del_%s' % plugin_attrib) del_multi(lang) for sub in val: parent.append(sub) - def del_multi(self, lang=None): - parent = self.parent() + def fail_without_parent(self: Multi) -> ElementBase: + parent = None + if self.parent: + parent = self.parent() + if not parent: + raise ValueError('No stanza parent for multifactory') + return parent + + def del_multi(self: Multi, lang: Optional[str] = None) -> None: + parent = fail_without_parent(self) if not lang or lang == '*': - res = filter(plugin_filter(self), parent) + res = list(filter(plugin_filter(self), parent)) else: - res = filter(plugin_filter(self, lang), parent) - res = list(res) + res = list(filter(plugin_lang_filter(self, lang), parent)) if not res: del parent.plugins[(plugin_attrib, None)] parent.loaded_plugins.remove(plugin_attrib) @@ -149,7 +180,8 @@ def multifactory(stanza, plugin_attrib): return Multi -def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''): +def fix_ns(xpath: str, split: bool = False, propagate_ns: bool = True, + default_ns: str = '') -> Union[str, List[str]]: """Apply the stanza's namespace to elements in an XPath expression. :param string xpath: The XPath expression to fix with namespaces. @@ -275,12 +307,12 @@ class ElementBase(object): #: The XML tag name of the element, not including any namespace #: prefixes. For example, an :class:`ElementBase` object for #: ``<message />`` would use ``name = 'message'``. - name = 'stanza' + name: ClassVar[str] = 'stanza' #: The XML namespace for the element. Given ``<foo xmlns="bar" />``, #: then ``namespace = "bar"`` should be used. The default namespace #: is ``jabber:client`` since this is being used in an XMPP library. - namespace = 'jabber:client' + namespace: str = 'jabber:client' #: For :class:`ElementBase` subclasses which are intended to be used #: as plugins, the ``plugin_attrib`` value defines the plugin name. @@ -290,7 +322,7 @@ class ElementBase(object): #: register_stanza_plugin(Message, FooPlugin) #: msg = Message() #: msg['foo']['an_interface_from_the_foo_plugin'] - plugin_attrib = 'plugin' + plugin_attrib: ClassVar[str] = 'plugin' #: For :class:`ElementBase` subclasses that are intended to be an #: iterable group of items, the ``plugin_multi_attrib`` value defines @@ -300,29 +332,29 @@ class ElementBase(object): #: # Given stanza class Foo, with plugin_multi_attrib = 'foos' #: parent['foos'] #: filter(isinstance(item, Foo), parent['substanzas']) - plugin_multi_attrib = '' + plugin_multi_attrib: ClassVar[str] = '' #: The set of keys that the stanza provides for accessing and #: manipulating the underlying XML object. This set may be augmented #: with the :attr:`plugin_attrib` value of any registered #: stanza plugins. - interfaces = {'type', 'to', 'from', 'id', 'payload'} + interfaces: ClassVar[Set[str]] = {'type', 'to', 'from', 'id', 'payload'} #: A subset of :attr:`interfaces` which maps interfaces to direct #: subelements of the underlying XML object. Using this set, the text #: of these subelements may be set, retrieved, or removed without #: needing to define custom methods. - sub_interfaces = set() + sub_interfaces: ClassVar[Set[str]] = set() #: A subset of :attr:`interfaces` which maps the presence of #: subelements to boolean values. Using this set allows for quickly #: checking for the existence of empty subelements like ``<required />``. #: #: .. versionadded:: 1.1 - bool_interfaces = set() + bool_interfaces: ClassVar[Set[str]] = set() #: .. versionadded:: 1.1.2 - lang_interfaces = set() + lang_interfaces: ClassVar[Set[str]] = set() #: In some cases you may wish to override the behaviour of one of the #: parent stanza's interfaces. The ``overrides`` list specifies the @@ -336,7 +368,7 @@ class ElementBase(object): #: be affected. #: #: .. versionadded:: 1.0-Beta5 - overrides = [] + overrides: ClassVar[List[str]] = [] #: If you need to add a new interface to an existing stanza, you #: can create a plugin and set ``is_extension = True``. Be sure @@ -346,7 +378,7 @@ class ElementBase(object): #: parent stanza will be passed to the plugin directly. #: #: .. versionadded:: 1.0-Beta5 - is_extension = False + is_extension: ClassVar[bool] = False #: A map of interface operations to the overriding functions. #: For example, after overriding the ``set`` operation for @@ -355,15 +387,15 @@ class ElementBase(object): #: {'set_body': <some function>} #: #: .. versionadded: 1.0-Beta5 - plugin_overrides = {} + plugin_overrides: ClassVar[Dict[str, str]] = {} #: A mapping of the :attr:`plugin_attrib` values of registered #: plugins to their respective classes. - plugin_attrib_map = {} + plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {} #: A mapping of root element tag names (in ``'{namespace}elementname'`` #: format) to the plugin classes responsible for them. - plugin_tag_map = {} + plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {} #: The set of stanza classes that can be iterated over using #: the 'substanzas' interface. Classes are added to this set @@ -372,17 +404,26 @@ class ElementBase(object): #: register_stanza_plugin(DiscoInfo, DiscoItem, iterable=True) #: #: .. versionadded:: 1.0-Beta5 - plugin_iterables = set() + plugin_iterables: ClassVar[Set[Type[ElementBase]]] = set() #: The default XML namespace: ``http://www.w3.org/XML/1998/namespace``. - xml_ns = XML_NS - - def __init__(self, xml=None, parent=None): + xml_ns: ClassVar[str] = XML_NS + + plugins: Dict[Tuple[str, Optional[str]], ElementBase] + #: The underlying XML object for the stanza. It is a standard + #: :class:`xml.etree.ElementTree` object. + xml: ET.Element + _index: int + loaded_plugins: Set[str] + iterables: List[ElementBase] + tag: str + parent: Optional[ReferenceType[ElementBase]] + + def __init__(self, xml: Optional[ET.Element] = None, parent: Union[Optional[ElementBase], ReferenceType[ElementBase]] = None): self._index = 0 - #: The underlying XML object for the stanza. It is a standard - #: :class:`xml.etree.ElementTree` object. - self.xml = xml + if xml is not None: + self.xml = xml #: An ordered dictionary of plugin stanzas, mapped by their #: :attr:`plugin_attrib` value. @@ -419,7 +460,7 @@ class ElementBase(object): existing_xml=child, reuse=False) - def setup(self, xml=None): + def setup(self, xml: Optional[ET.Element] = None) -> bool: """Initialize the stanza's XML contents. Will return ``True`` if XML was generated according to the stanza's @@ -429,29 +470,31 @@ class ElementBase(object): :param xml: An existing XML object to use for the stanza's content instead of generating new XML. """ - if self.xml is None: + if hasattr(self, 'xml'): + return False + if not hasattr(self, 'xml') and xml is not None: self.xml = xml + return False - last_xml = self.xml - if self.xml is None: - # Generate XML from the stanza definition - for ename in self.name.split('/'): - new = ET.Element("{%s}%s" % (self.namespace, ename)) - if self.xml is None: - self.xml = new - else: - last_xml.append(new) - last_xml = new - if self.parent is not None: - self.parent().xml.append(self.xml) - # We had to generate XML - return True - else: - # We did not generate XML - return False + # Generate XML from the stanza definition + last_xml = ET.Element('') + for ename in self.name.split('/'): + new = ET.Element("{%s}%s" % (self.namespace, ename)) + if not hasattr(self, 'xml'): + self.xml = new + else: + last_xml.append(new) + last_xml = new + if self.parent is not None: + parent = self.parent() + if parent: + parent.xml.append(self.xml) + + # We had to generate XML + return True - def enable(self, attrib, lang=None): + def enable(self, attrib: str, lang: Optional[str] = None) -> ElementBase: """Enable and initialize a stanza plugin. Alias for :meth:`init_plugin`. @@ -487,7 +530,10 @@ class ElementBase(object): else: return None if check else self.init_plugin(name, lang) - def init_plugin(self, attrib, lang=None, existing_xml=None, element=None, reuse=True): + def init_plugin(self, attrib: str, lang: Optional[str] = None, + existing_xml: Optional[ET.Element] = None, + reuse: bool = True, + element: Optional[ElementBase] = None) -> ElementBase: """Enable and initialize a stanza plugin. :param string attrib: The :attr:`plugin_attrib` value of the @@ -525,7 +571,7 @@ class ElementBase(object): return plugin - def _get_stanza_values(self): + def _get_stanza_values(self) -> Dict[str, Any]: """Return A JSON/dictionary version of the XML content exposed through the stanza's interfaces:: @@ -567,7 +613,7 @@ class ElementBase(object): values['substanzas'] = iterables return values - def _set_stanza_values(self, values): + def _set_stanza_values(self, values: Dict[str, Any]) -> ElementBase: """Set multiple stanza interface values using a dictionary. Stanza plugin values may be set using nested dictionaries. @@ -623,7 +669,7 @@ class ElementBase(object): plugin.values = value return self - def __getitem__(self, full_attrib): + def __getitem__(self, full_attrib: str) -> Any: """Return the value of a stanza interface using dict-like syntax. Example:: @@ -688,7 +734,7 @@ class ElementBase(object): else: return '' - def __setitem__(self, attrib, value): + def __setitem__(self, attrib: str, value: Any) -> Any: """Set the value of a stanza interface using dictionary-like syntax. Example:: @@ -773,7 +819,7 @@ class ElementBase(object): plugin[full_attrib] = value return self - def __delitem__(self, attrib): + def __delitem__(self, attrib: str) -> Any: """Delete the value of a stanza interface using dict-like syntax. Example:: @@ -851,7 +897,7 @@ class ElementBase(object): pass return self - def _set_attr(self, name, value): + def _set_attr(self, name: str, value: Optional[JidStr]) -> None: """Set the value of a top level attribute of the XML object. If the new value is None or an empty string, then the attribute will @@ -868,7 +914,7 @@ class ElementBase(object): value = str(value) self.xml.attrib[name] = value - def _del_attr(self, name): + def _del_attr(self, name: str) -> None: """Remove a top level attribute of the XML object. :param name: The name of the attribute. @@ -876,7 +922,7 @@ class ElementBase(object): if name in self.xml.attrib: del self.xml.attrib[name] - def _get_attr(self, name, default=''): + def _get_attr(self, name: str, default: str = '') -> str: """Return the value of a top level attribute of the XML object. In case the attribute has not been set, a default value can be @@ -889,7 +935,8 @@ class ElementBase(object): """ return self.xml.attrib.get(name, default) - def _get_sub_text(self, name, default='', lang=None): + def _get_sub_text(self, name: str, default: str = '', + lang: Optional[str] = None) -> Union[str, Dict[str, str]]: """Return the text contents of a sub element. In case the element does not exist, or it has no textual content, @@ -900,7 +947,7 @@ class ElementBase(object): :param default: Optional default to return if the element does not exists. An empty string is returned otherwise. """ - name = self._fix_ns(name) + name = cast(str, self._fix_ns(name)) if lang == '*': return self._get_all_sub_text(name, default, None) @@ -924,8 +971,9 @@ class ElementBase(object): return result return default - def _get_all_sub_text(self, name, default='', lang=None): - name = self._fix_ns(name) + def _get_all_sub_text(self, name: str, default: str = '', + lang: Optional[str] = None) -> Dict[str, str]: + name = cast(str, self._fix_ns(name)) default_lang = self.get_lang() results = {} @@ -935,10 +983,16 @@ class ElementBase(object): stanza_lang = stanza.attrib.get('{%s}lang' % XML_NS, default_lang) if not lang or lang == '*' or stanza_lang == lang: - results[stanza_lang] = stanza.text + if stanza.text is None: + text = default + else: + text = stanza.text + results[stanza_lang] = text return results - def _set_sub_text(self, name, text=None, keep=False, lang=None): + def _set_sub_text(self, name: str, text: Optional[str] = None, + keep: bool = False, + lang: Optional[str] = None) -> Optional[ET.Element]: """Set the text contents of a sub element. In case the element does not exist, a element will be created, @@ -959,15 +1013,16 @@ class ElementBase(object): lang = default_lang if not text and not keep: - return self._del_sub(name, lang=lang) + self._del_sub(name, lang=lang) + return None - path = self._fix_ns(name, split=True) + path = cast(List[str], self._fix_ns(name, split=True)) name = path[-1] - parent = self.xml + parent: Optional[ET.Element] = self.xml # The first goal is to find the parent of the subelement, or, if # we can't find that, the closest grandparent element. - missing_path = [] + missing_path: List[str] = [] search_order = path[:-1] while search_order: parent = self.xml.find('/'.join(search_order)) @@ -1008,15 +1063,17 @@ class ElementBase(object): parent.append(element) return element - def _set_all_sub_text(self, name, values, keep=False, lang=None): - self._del_sub(name, lang) + def _set_all_sub_text(self, name: str, values: Dict[str, str], + keep: bool = False, + lang: Optional[str] = None) -> None: + self._del_sub(name, lang=lang) for value_lang, value in values.items(): if not lang or lang == '*' or value_lang == lang: self._set_sub_text(name, text=value, keep=keep, lang=value_lang) - def _del_sub(self, name, all=False, lang=None): + def _del_sub(self, name: str, all: bool = False, lang: Optional[str] = None) -> None: """Remove sub elements that match the given name or XPath. If the element is in a path, then any parent elements that become @@ -1034,11 +1091,11 @@ class ElementBase(object): if not lang: lang = default_lang - parent = self.xml + parent: Optional[ET.Element] = self.xml for level, _ in enumerate(path): # Generate the paths to the target elements and their parent. element_path = "/".join(path[:len(path) - level]) - parent_path = "/".join(path[:len(path) - level - 1]) + parent_path: Optional[str] = "/".join(path[:len(path) - level - 1]) elements = self.xml.findall(element_path) if parent_path == '': @@ -1061,7 +1118,7 @@ class ElementBase(object): # after deleting the first level of elements. return - def match(self, xpath): + def match(self, xpath: Union[str, List[str]]) -> bool: """Compare a stanza object with an XPath-like expression. If the XPath matches the contents of the stanza object, the match @@ -1127,7 +1184,7 @@ class ElementBase(object): # Everything matched. return True - def get(self, key, default=None): + def get(self, key: str, default: Optional[Any] = None) -> Any: """Return the value of a stanza interface. If the found value is None or an empty string, return the supplied @@ -1144,7 +1201,7 @@ class ElementBase(object): return default return value - def keys(self): + def keys(self) -> List[str]: """Return the names of all stanza interfaces provided by the stanza object. @@ -1158,7 +1215,7 @@ class ElementBase(object): out.append('substanzas') return out - def append(self, item): + def append(self, item: Union[ET.Element, ElementBase]) -> ElementBase: """Append either an XML object or a substanza to this stanza object. If a substanza object is appended, it will be added to the list @@ -1189,7 +1246,7 @@ class ElementBase(object): return self - def appendxml(self, xml): + def appendxml(self, xml: ET.Element) -> ElementBase: """Append an XML object to the stanza's XML. The added XML will not be included in the list of @@ -1200,7 +1257,7 @@ class ElementBase(object): self.xml.append(xml) return self - def pop(self, index=0): + def pop(self, index: int = 0) -> ElementBase: """Remove and return the last substanza in the list of iterable substanzas. @@ -1212,11 +1269,11 @@ class ElementBase(object): self.xml.remove(substanza.xml) return substanza - def next(self): + def next(self) -> ElementBase: """Return the next iterable substanza.""" return self.__next__() - def clear(self): + def clear(self) -> ElementBase: """Remove all XML element contents and plugins. Any attribute values will be preserved. @@ -1229,7 +1286,7 @@ class ElementBase(object): return self @classmethod - def tag_name(cls): + def tag_name(cls) -> str: """Return the namespaced name of the stanza's root element. The format for the tag name is:: @@ -1241,29 +1298,32 @@ class ElementBase(object): """ return "{%s}%s" % (cls.namespace, cls.name) - def get_lang(self, lang=None): + def get_lang(self, lang: Optional[str] = None) -> str: result = self.xml.attrib.get('{%s}lang' % XML_NS, '') - if not result and self.parent and self.parent(): - return self.parent()['lang'] + if not result and self.parent: + parent = self.parent() + if parent: + return cast(str, parent['lang']) return result - def set_lang(self, lang): + def set_lang(self, lang: Optional[str]) -> None: self.del_lang() attr = '{%s}lang' % XML_NS if lang: self.xml.attrib[attr] = lang - def del_lang(self): + def del_lang(self) -> None: attr = '{%s}lang' % XML_NS if attr in self.xml.attrib: del self.xml.attrib[attr] - def _fix_ns(self, xpath, split=False, propagate_ns=True): + def _fix_ns(self, xpath: str, split: bool = False, + propagate_ns: bool = True) -> Union[str, List[str]]: return fix_ns(xpath, split=split, propagate_ns=propagate_ns, default_ns=self.namespace) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Compare the stanza object with another to test for equality. Stanzas are equal if their interfaces return the same values, @@ -1290,7 +1350,7 @@ class ElementBase(object): # must be equal. return True - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Compare the stanza object with another to test for inequality. Stanzas are not equal if their interfaces return different values, @@ -1300,16 +1360,16 @@ class ElementBase(object): """ return not self.__eq__(other) - def __bool__(self): + def __bool__(self) -> bool: """Stanza objects should be treated as True in boolean contexts. """ return True - def __len__(self): + def __len__(self) -> int: """Return the number of iterable substanzas in this stanza.""" return len(self.iterables) - def __iter__(self): + def __iter__(self) -> ElementBase: """Return an iterator object for the stanza's substanzas. The iterator is the stanza object itself. Attempting to use two @@ -1318,7 +1378,7 @@ class ElementBase(object): self._index = 0 return self - def __next__(self): + def __next__(self) -> ElementBase: """Return the next iterable substanza.""" self._index += 1 if self._index > len(self.iterables): @@ -1326,13 +1386,16 @@ class ElementBase(object): raise StopIteration return self.iterables[self._index - 1] - def __copy__(self): + def __copy__(self) -> ElementBase: """Return a copy of the stanza object that does not share the same underlying XML object. """ - return self.__class__(xml=copy.deepcopy(self.xml), parent=self.parent) + return self.__class__( + xml=copy.deepcopy(self.xml), + parent=self.parent, + ) - def __str__(self, top_level_ns=True): + def __str__(self, top_level_ns: bool = True) -> str: """Return a string serialization of the underlying XML object. .. seealso:: :ref:`tostring` @@ -1343,12 +1406,33 @@ class ElementBase(object): return tostring(self.xml, xmlns='', top_level=True) - def __repr__(self): + def __repr__(self) -> str: """Use the stanza's serialized XML as its representation.""" return self.__str__() # Compatibility. _get_plugin = get_plugin + get_stanza_values = _get_stanza_values + set_stanza_values = _set_stanza_values + + #: A JSON/dictionary version of the XML content exposed through + #: the stanza interfaces:: + #: + #: >>> msg = Message() + #: >>> msg.values + #: {'body': '', 'from': , 'mucnick': '', 'mucroom': '', + #: 'to': , 'type': 'normal', 'id': '', 'subject': ''} + #: + #: Likewise, assigning to the :attr:`values` will change the XML + #: content:: + #: + #: >>> msg = Message() + #: >>> msg.values = {'body': 'Hi!', 'to': 'user@example.com'} + #: >>> msg + #: '<message to="user@example.com"><body>Hi!</body></message>' + #: + #: Child stanzas are exposed as nested dictionaries. + values = property(_get_stanza_values, _set_stanza_values) # type: ignore class StanzaBase(ElementBase): @@ -1386,9 +1470,14 @@ class StanzaBase(ElementBase): #: The default XMPP client namespace namespace = 'jabber:client' - - def __init__(self, stream=None, xml=None, stype=None, - sto=None, sfrom=None, sid=None, parent=None, recv=False): + types: ClassVar[Set[str]] = set() + + def __init__(self, stream: Optional[XMLStream] = None, + xml: Optional[ET.Element] = None, + stype: Optional[str] = None, + sto: Optional[JidStr] = None, sfrom: Optional[JidStr] = None, + sid: Optional[str] = None, + parent: Optional[ElementBase] = None, recv: bool = False): self.stream = stream if stream is not None: self.namespace = stream.default_ns @@ -1403,7 +1492,7 @@ class StanzaBase(ElementBase): self['id'] = sid self.tag = "{%s}%s" % (self.namespace, self.name) - def set_type(self, value): + def set_type(self, value: str) -> StanzaBase: """Set the stanza's ``'type'`` attribute. Only type values contained in :attr:`types` are accepted. @@ -1414,11 +1503,11 @@ class StanzaBase(ElementBase): self.xml.attrib['type'] = value return self - def get_to(self): + def get_to(self) -> JID: """Return the value of the stanza's ``'to'`` attribute.""" return JID(self._get_attr('to')) - def set_to(self, value): + def set_to(self, value: JidStr) -> None: """Set the ``'to'`` attribute of the stanza. :param value: A string or :class:`slixmpp.xmlstream.JID` object @@ -1426,11 +1515,11 @@ class StanzaBase(ElementBase): """ return self._set_attr('to', str(value)) - def get_from(self): + def get_from(self) -> JID: """Return the value of the stanza's ``'from'`` attribute.""" return JID(self._get_attr('from')) - def set_from(self, value): + def set_from(self, value: JidStr) -> None: """Set the 'from' attribute of the stanza. :param from: A string or JID object representing the sender's JID. @@ -1438,11 +1527,11 @@ class StanzaBase(ElementBase): """ return self._set_attr('from', str(value)) - def get_payload(self): + def get_payload(self) -> List[ET.Element]: """Return a list of XML objects contained in the stanza.""" return list(self.xml) - def set_payload(self, value): + def set_payload(self, value: Union[List[ElementBase], ElementBase]) -> StanzaBase: """Add XML content to the stanza. :param value: Either an XML or a stanza object, or a list @@ -1454,12 +1543,12 @@ class StanzaBase(ElementBase): self.append(val) return self - def del_payload(self): + def del_payload(self) -> StanzaBase: """Remove the XML contents of the stanza.""" self.clear() return self - def reply(self, clear=True): + def reply(self, clear: bool = True) -> StanzaBase: """Prepare the stanza for sending a reply. Swaps the ``'from'`` and ``'to'`` attributes. @@ -1475,7 +1564,7 @@ class StanzaBase(ElementBase): new_stanza = copy.copy(self) # if it's a component, use from if self.stream and hasattr(self.stream, "is_component") and \ - self.stream.is_component: + getattr(self.stream, 'is_component'): new_stanza['from'], new_stanza['to'] = self['to'], self['from'] else: new_stanza['to'] = self['from'] @@ -1484,19 +1573,19 @@ class StanzaBase(ElementBase): new_stanza.clear() return new_stanza - def error(self): + def error(self) -> StanzaBase: """Set the stanza's type to ``'error'``.""" self['type'] = 'error' return self - def unhandled(self): + def unhandled(self) -> None: """Called if no handlers have been registered to process this stanza. Meant to be overridden. """ pass - def exception(self, e): + def exception(self, e: Exception) -> None: """Handle exceptions raised during stanza processing. Meant to be overridden. @@ -1504,18 +1593,21 @@ class StanzaBase(ElementBase): log.exception('Error handling {%s}%s stanza', self.namespace, self.name) - def send(self): + def send(self) -> None: """Queue the stanza to be sent on the XML stream.""" - self.stream.send(self) + if self.stream is not None: + self.stream.send(self) + else: + log.error("Tried to send stanza without a stream: %s", self) - def __copy__(self): + def __copy__(self) -> StanzaBase: """Return a copy of the stanza object that does not share the same underlying XML object, but does share the same XML stream. """ return self.__class__(xml=copy.deepcopy(self.xml), stream=self.stream) - def __str__(self, top_level_ns=False): + def __str__(self, top_level_ns: bool = False) -> str: """Serialize the stanza's XML to a string. :param bool top_level_ns: Display the top-most namespace. @@ -1525,27 +1617,3 @@ class StanzaBase(ElementBase): return tostring(self.xml, xmlns=xmlns, stream=self.stream, top_level=(self.stream is None)) - - -#: A JSON/dictionary version of the XML content exposed through -#: the stanza interfaces:: -#: -#: >>> msg = Message() -#: >>> msg.values -#: {'body': '', 'from': , 'mucnick': '', 'mucroom': '', -#: 'to': , 'type': 'normal', 'id': '', 'subject': ''} -#: -#: Likewise, assigning to the :attr:`values` will change the XML -#: content:: -#: -#: >>> msg = Message() -#: >>> msg.values = {'body': 'Hi!', 'to': 'user@example.com'} -#: >>> msg -#: '<message to="user@example.com"><body>Hi!</body></message>' -#: -#: Child stanzas are exposed as nested dictionaries. -ElementBase.values = property(ElementBase._get_stanza_values, - ElementBase._set_stanza_values) - -ElementBase.get_stanza_values = ElementBase._get_stanza_values -ElementBase.set_stanza_values = ElementBase._set_stanza_values diff --git a/slixmpp/xmlstream/tostring.py b/slixmpp/xmlstream/tostring.py index efac124e..447c9017 100644 --- a/slixmpp/xmlstream/tostring.py +++ b/slixmpp/xmlstream/tostring.py @@ -1,4 +1,3 @@ - # slixmpp.xmlstream.tostring # ~~~~~~~~~~~~~~~~~~~~~~~~~~ # This module converts XML objects into Unicode strings and @@ -7,11 +6,20 @@ # Part of Slixmpp: The Slick XMPP Library # :copyright: (c) 2011 Nathanael C. Fritz # :license: MIT, see LICENSE for more details +from __future__ import annotations + +from typing import Optional, Set, TYPE_CHECKING +from xml.etree.ElementTree import Element +if TYPE_CHECKING: + from slixmpp.xmlstream import XMLStream + XML_NS = 'http://www.w3.org/XML/1998/namespace' -def tostring(xml=None, xmlns='', stream=None, outbuffer='', - top_level=False, open_only=False, namespaces=None): +def tostring(xml: Optional[Element] = None, xmlns: str = '', + stream: Optional[XMLStream] = None, outbuffer: str = '', + top_level: bool = False, open_only: bool = False, + namespaces: Optional[Set[str]] = None) -> str: """Serialize an XML object to a Unicode string. If an outer xmlns is provided using ``xmlns``, then the current element's @@ -35,6 +43,8 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='', :rtype: Unicode string """ + if xml is None: + return '' # Add previous results to the start of the output. output = [outbuffer] @@ -123,11 +133,12 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='', # Remove namespaces introduced in this context. This is necessary # because the namespaces object continues to be shared with other # contexts. - namespaces.remove(ns) + if namespaces is not None: + namespaces.remove(ns) return ''.join(output) -def escape(text, use_cdata=False): +def escape(text: str, use_cdata: bool = False) -> str: """Convert special characters in XML to escape sequences. :param string text: The XML text to convert. diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 8d90abf8..30f99071 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -9,17 +9,24 @@ # :license: MIT, see LICENSE for more details from typing import ( Any, + Dict, + Awaitable, + Generator, Coroutine, Callable, - Iterable, Iterator, List, Optional, Set, Union, Tuple, + TypeVar, + NoReturn, + Type, + cast, ) +import asyncio import functools import logging import socket as Socket @@ -27,30 +34,66 @@ import ssl import weakref import uuid -from asyncio import iscoroutinefunction, wait, Future from contextlib import contextmanager import xml.etree.ElementTree as ET +from asyncio import ( + AbstractEventLoop, + BaseTransport, + Future, + Task, + TimerHandle, + Transport, + iscoroutinefunction, + wait, +) -from slixmpp.xmlstream.asyncio import asyncio -from slixmpp.xmlstream import tostring +from slixmpp.types import FilterString +from slixmpp.xmlstream.tostring import tostring from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase from slixmpp.xmlstream.resolver import resolve, default_resolver +from slixmpp.xmlstream.handler.base import BaseHandler + +T = TypeVar('T') #: The time in seconds to wait before timing out waiting for response stanzas. RESPONSE_TIMEOUT = 30 log = logging.getLogger(__name__) + + class ContinueQueue(Exception): """ Exception raised in the send queue to "continue" from within an inner loop """ + class NotConnectedError(Exception): """ Raised when we try to send something over the wire but we are not connected. """ + +_T = TypeVar('_T', str, ElementBase, StanzaBase) + + +SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]] +AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]] + + +Filter = Union[ + SyncFilter, + AsyncFilter, +] + +_FiltersDict = Dict[str, List[Filter]] + +Handler = Callable[[Any], Union[ + Any, + Coroutine[Any, Any, Any] +]] + + class XMLStream(asyncio.BaseProtocol): """ An XML stream connection manager and event dispatcher. @@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol): :param int port: The port to use for the connection. Defaults to 0. """ - def __init__(self, host='', port=0): - # The asyncio.Transport object provided by the connection_made() - # callback when we are connected - self.transport = None + transport: Optional[Transport] - # The socket that is used internally by the transport object - self.socket = None + # The socket that is used internally by the transport object + socket: Optional[ssl.SSLSocket] + + # The backoff of the connect routine (increases exponentially + # after each failure) + _connect_loop_wait: float + + parser: Optional[ET.XMLPullParser] + xml_depth: int + xml_root: Optional[ET.Element] + + force_starttls: Optional[bool] + disable_starttls: Optional[bool] + + waiting_queue: asyncio.Queue + + # A dict of {name: handle} + scheduled_events: Dict[str, TimerHandle] + + ssl_context: ssl.SSLContext + + # The event to trigger when the create_connection() succeeds. It can + # be "connected" or "tls_success" depending on the step we are at. + event_when_connected: str + + #: The list of accepted ciphers, in OpenSSL Format. + #: It might be useful to override it for improved security + #: over the python defaults. + ciphers: Optional[str] + + #: Path to a file containing certificates for verifying the + #: server SSL certificate. A non-``None`` value will trigger + #: certificate checking. + #: + #: .. note:: + #: + #: On Mac OS X, certificates in the system keyring will + #: be consulted, even if they are not in the provided file. + ca_certs: Optional[str] + + #: Path to a file containing a client certificate to use for + #: authenticating via SASL EXTERNAL. If set, there must also + #: be a corresponding `:attr:keyfile` value. + certfile: Optional[str] + + #: Path to a file containing the private key for the selected + #: client certificate to use for authenticating via SASL EXTERNAL. + keyfile: Optional[str] + + # The asyncio event loop + _loop: Optional[AbstractEventLoop] + + #: The default port to return when querying DNS records. + default_port: int + + #: The domain to try when querying DNS records. + default_domain: str + + #: The expected name of the server, for validation. + _expected_server_name: str + _service_name: str + + #: The desired, or actual, address of the connected server. + address: Tuple[str, int] + + #: Enable connecting to the server directly over SSL, in + #: particular when the service provides two ports: one for + #: non-SSL traffic and another for SSL traffic. + use_ssl: bool + + #: If set to ``True``, attempt to use IPv6. + use_ipv6: bool + + #: If set to ``True``, allow using the ``dnspython`` DNS library + #: if available. If set to ``False``, the builtin DNS resolver + #: will be used, even if ``dnspython`` is installed. + use_aiodns: bool - # The backoff of the connect routine (increases exponentially - # after each failure) + #: Use CDATA for escaping instead of XML entities. Defaults + #: to ``False``. + use_cdata: bool + + #: The default namespace of the stream content, not of the + #: stream wrapper it + default_ns: str + + default_lang: Optional[str] + peer_default_lang: Optional[str] + + #: The namespace of the enveloping stream element. + stream_ns: str + + #: The default opening tag for the stream element. + stream_header: str + + #: The default closing tag for the stream element. + stream_footer: str + + #: If ``True``, periodically send a whitespace character over the + #: wire to keep the connection alive. Mainly useful for connections + #: traversing NAT. + whitespace_keepalive: bool + + #: The default interval between keepalive signals when + #: :attr:`whitespace_keepalive` is enabled. + whitespace_keepalive_interval: int + + #: Flag for controlling if the session can be considered ended + #: if the connection is terminated. + end_session_on_disconnect: bool + + #: A mapping of XML namespaces to well-known prefixes. + namespace_map: dict + + __root_stanza: List[Type[StanzaBase]] + __handlers: List[BaseHandler] + __event_handlers: Dict[str, List[Tuple[Handler, bool]]] + __filters: _FiltersDict + + # Current connection attempt (Future) + _current_connection_attempt: Optional[Future] + + #: A list of DNS results that have not yet been tried. + _dns_answers: Optional[Iterator[Tuple[str, str, int]]] + + #: The service name to check with DNS SRV records. For + #: example, setting this to ``'xmpp-client'`` would query the + #: ``_xmpp-client._tcp`` service. + dns_service: Optional[str] + + #: The reason why we are disconnecting from the server + disconnect_reason: Optional[str] + + #: An asyncio Future being done when the stream is disconnected. + disconnected: Future + + # If the session has been started or not + _session_started: bool + # If we want to bypass the send() check (e.g. unit tests) + _always_send_everything: bool + + _run_out_filters: Optional[Future] + __slow_tasks: List[Task] + __queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]] + + def __init__(self, host: str = '', port: int = 0): + self.transport = None + self.socket = None self._connect_loop_wait = 0 self.parser = None @@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol): self.ssl_context.check_hostname = False self.ssl_context.verify_mode = ssl.CERT_NONE - # The event to trigger when the create_connection() succeeds. It can - # be "connected" or "tls_success" depending on the step we are at. self.event_when_connected = "connected" - #: The list of accepted ciphers, in OpenSSL Format. - #: It might be useful to override it for improved security - #: over the python defaults. self.ciphers = None - #: Path to a file containing certificates for verifying the - #: server SSL certificate. A non-``None`` value will trigger - #: certificate checking. - #: - #: .. note:: - #: - #: On Mac OS X, certificates in the system keyring will - #: be consulted, even if they are not in the provided file. self.ca_certs = None - #: Path to a file containing a client certificate to use for - #: authenticating via SASL EXTERNAL. If set, there must also - #: be a corresponding `:attr:keyfile` value. - self.certfile = None - - #: Path to a file containing the private key for the selected - #: client certificate to use for authenticating via SASL EXTERNAL. self.keyfile = None - self._der_cert = None - - # The asyncio event loop self._loop = None - #: The default port to return when querying DNS records. self.default_port = int(port) - - #: The domain to try when querying DNS records. self.default_domain = '' - #: The expected name of the server, for validation. self._expected_server_name = '' self._service_name = '' - #: The desired, or actual, address of the connected server. self.address = (host, int(port)) - #: Enable connecting to the server directly over SSL, in - #: particular when the service provides two ports: one for - #: non-SSL traffic and another for SSL traffic. self.use_ssl = False - - #: If set to ``True``, attempt to use IPv6. self.use_ipv6 = True - #: If set to ``True``, allow using the ``dnspython`` DNS library - #: if available. If set to ``False``, the builtin DNS resolver - #: will be used, even if ``dnspython`` is installed. self.use_aiodns = True - - #: Use CDATA for escaping instead of XML entities. Defaults - #: to ``False``. self.use_cdata = False - #: The default namespace of the stream content, not of the - #: stream wrapper itself. self.default_ns = '' self.default_lang = None self.peer_default_lang = None - #: The namespace of the enveloping stream element. self.stream_ns = '' - - #: The default opening tag for the stream element. self.stream_header = "<stream>" - - #: The default closing tag for the stream element. self.stream_footer = "</stream>" - #: If ``True``, periodically send a whitespace character over the - #: wire to keep the connection alive. Mainly useful for connections - #: traversing NAT. self.whitespace_keepalive = True - - #: The default interval between keepalive signals when - #: :attr:`whitespace_keepalive` is enabled. self.whitespace_keepalive_interval = 300 - #: Flag for controlling if the session can be considered ended - #: if the connection is terminated. self.end_session_on_disconnect = True - - #: A mapping of XML namespaces to well-known prefixes. self.namespace_map = {StanzaBase.xml_ns: 'xml'} self.__root_stanza = [] self.__handlers = [] self.__event_handlers = {} - self.__filters = {'in': [], 'out': [], 'out_sync': []} + self.__filters = { + 'in': [], 'out': [], 'out_sync': [] + } - # Current connection attempt (Future) self._current_connection_attempt = None - #: A list of DNS results that have not yet been tried. - self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None - - #: The service name to check with DNS SRV records. For - #: example, setting this to ``'xmpp-client'`` would query the - #: ``_xmpp-client._tcp`` service. + self._dns_answers = None self.dns_service = None - #: The reason why we are disconnecting from the server self.disconnect_reason = None - - #: An asyncio Future being done when the stream is disconnected. - self.disconnected: Future = Future() - - # If the session has been started or not + self.disconnected = Future() self._session_started = False - # If we want to bypass the send() check (e.g. unit tests) self._always_send_everything = False self.add_event_handler('disconnected', self._remove_schedules) @@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol): self.add_event_handler('session_start', self._set_session_start) self.add_event_handler('session_resumed', self._set_session_start) - self._run_out_filters: Optional[Future] = None - self.__slow_tasks: List[Future] = [] - self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = [] + self._run_out_filters = None + self.__slow_tasks = [] + self.__queued_stanzas = [] @property - def loop(self): + def loop(self) -> AbstractEventLoop: if self._loop is None: self._loop = asyncio.get_event_loop() return self._loop @loop.setter - def loop(self, value): + def loop(self, value: AbstractEventLoop) -> None: self._loop = value - def new_id(self): + def new_id(self) -> str: """Generate and return a new stream ID in hexadecimal form. Many stanzas, handlers, or matchers may require unique @@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol): """ return uuid.uuid4().hex - def _set_session_start(self, event): + def _set_session_start(self, event: Any) -> None: """ On session start, queue all pending stanzas to be sent. """ @@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol): self.waiting_queue.put_nowait(stanza) self.__queued_stanzas = [] - def _set_disconnected(self, event): + def _set_disconnected(self, event: Any) -> None: self._session_started = False - def _set_disconnected_future(self): + def _set_disconnected_future(self) -> None: """Set the self.disconnected future on disconnect""" if not self.disconnected.done(): self.disconnected.set_result(True) self.disconnected = asyncio.Future() - def connect(self, host='', port=0, use_ssl=False, - force_starttls=True, disable_starttls=False): + def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False, + force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None: """Create a new socket and connect to the server. :param host: The name of the desired server for the connection. @@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol): loop=self.loop, ) - async def _connect_routine(self): + async def _connect_routine(self) -> None: self.event_when_connected = "connected" if self._connect_loop_wait > 0: @@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol): # and try (host, port) as a last resort self._dns_answers = None + ssl_context: Optional[ssl.SSLContext] if self.use_ssl: ssl_context = self.get_ssl_context() else: @@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol): loop=self.loop, ) - def process(self, *, forever=True, timeout=None): + def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None: """Process all the available XMPP events (receiving or sending data on the socket(s), calling various registered callbacks, calling expired timers, handling signal events, etc). If timeout is None, this @@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol): else: self.loop.run_until_complete(self.disconnected) else: - tasks = [asyncio.sleep(timeout, loop=self.loop)] + tasks: List[Future] = [asyncio.sleep(timeout, loop=self.loop)] if not forever: tasks.append(self.disconnected) self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) - def init_parser(self): + def init_parser(self) -> None: """init the XML parser. The parser must always be reset for each new connexion """ @@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol): self.xml_root = None self.parser = ET.XMLPullParser(("start", "end")) - def connection_made(self, transport): + def connection_made(self, transport: BaseTransport) -> None: """Called when the TCP connection has been established with the server """ self.event(self.event_when_connected) - self.transport = transport + self.transport = cast(Transport, transport) + if self.transport is None: + raise ValueError("Transport cannot be none") self.socket = self.transport.get_extra_info( "ssl_object", default=self.transport.get_extra_info("socket") @@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol): self.send_raw(self.stream_header) self._dns_answers = None - def data_received(self, data): + def data_received(self, data: bytes) -> None: """Called when incoming data is received on the socket. We feed that data to the parser and the see if this produced any XML @@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol): self.send(error) self.disconnect() - def is_connecting(self): + def is_connecting(self) -> bool: return self._current_connection_attempt is not None - def is_connected(self): + def is_connected(self) -> bool: return self.transport is not None - def eof_received(self): + def eof_received(self) -> None: """When the TCP connection is properly closed by the remote end """ self.event("eof_received") - def connection_lost(self, exception): + def connection_lost(self, exception: Optional[BaseException]) -> None: """On any kind of disconnection, initiated by us or not. This signals the closure of the TCP connection """ @@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol): self._reset_sendq() self.event('session_end') self._set_disconnected_future() - self.event("disconnected", self.disconnect_reason or exception and exception.strerror) + self.event("disconnected", self.disconnect_reason or exception) - def cancel_connection_attempt(self): + def cancel_connection_attempt(self) -> None: """ Immediately cancel the current create_connection() Future. This is useful when a client using slixmpp tries to connect @@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol): # `disconnect(wait=True)` for ages. This doesn't mean anything to the # schedule call below. It would fortunately be converted to `1` later # down the call chain. Praise the implicit casts lord. - if wait == True: + if wait is True: wait = 2.0 if self.transport: @@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol): else: self._set_disconnected_future() self.event("disconnected", reason) - future = Future() + future: Future = Future() future.set_result(None) return future - async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float): + async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None: """Wait until the send queue is empty before disconnecting""" try: await asyncio.wait_for( @@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol): self.disconnect_reason = reason await self._end_stream_wait(wait) - async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None): + async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None: """ Run abort() if we do not received the disconnected event after a waiting time. @@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol): # that means the disconnect has already been handled pass - def abort(self): + def abort(self) -> None: """ Forcibly close the connection """ @@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol): self.transport.abort() self.event("killed") - def reconnect(self, wait=2.0, reason="Reconnecting"): + def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None: """Calls disconnect(), and once we are disconnected (after the timeout, or when the server acknowledgement is received), call connect() """ log.debug("reconnecting...") - async def handler(event): + async def handler(event: Any) -> None: # We yield here to allow synchronous handlers to work first await asyncio.sleep(0, loop=self.loop) self.connect() self.add_event_handler('disconnected', handler, disposable=True) self.disconnect(wait, reason) - def configure_socket(self): + def configure_socket(self) -> None: """Set timeout and other options for self.socket. Meant to be overridden. """ pass - def configure_dns(self, resolver, domain=None, port=None): + def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None: """ Configure and set options for a :class:`~dns.resolver.Resolver` instance, and other DNS related tasks. For example, you @@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol): """ pass - def get_ssl_context(self): + def get_ssl_context(self) -> ssl.SSLContext: """ Get SSL context. """ @@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol): return self.ssl_context - async def start_tls(self): + async def start_tls(self) -> bool: """Perform handshakes for TLS. If the handshake is successful, the XML stream will need to be restarted. """ + if self.transport is None: + raise ValueError("Transport should not be None") self.event_when_connected = "tls_success" ssl_context = self.get_ssl_context() try: @@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol): self.connection_made(transp) return True - def _start_keepalive(self, event): + def _start_keepalive(self, event: Any) -> None: """Begin sending whitespace periodically to keep the connection alive. May be disabled by setting:: @@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol): args=(' ',), repeat=True) - def _remove_schedules(self, event): + def _remove_schedules(self, event: Any) -> None: """Remove some schedules that become pointless when disconnected""" self.cancel_schedule('Whitespace Keepalive') - def start_stream_handler(self, xml): + def start_stream_handler(self, xml: ET.Element) -> None: """Perform any initialization actions, such as handshakes, once the stream header has been sent. @@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol): """ pass - def register_stanza(self, stanza_class): + def register_stanza(self, stanza_class: Type[StanzaBase]) -> None: """Add a stanza object class as a known root stanza. A root stanza is one that appears as a direct child of the stream's @@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol): """ self.__root_stanza.append(stanza_class) - def remove_stanza(self, stanza_class): + def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None: """Remove a stanza from being a known root stanza. A root stanza is one that appears as a direct child of the stream's @@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol): """ self.__root_stanza.remove(stanza_class) - def add_filter(self, mode, handler, order=None): + def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None: """Add a filter for incoming or outgoing stanzas. These filters are applied before incoming stanzas are @@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol): else: self.__filters[mode].append(handler) - def del_filter(self, mode, handler): + def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None: """Remove an incoming or outgoing filter.""" self.__filters[mode].remove(handler) - def register_handler(self, handler, before=None, after=None): + def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None: """Add a stream event handler that will be executed when a matching stanza is received. @@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol): self.__handlers.append(handler) handler.stream = weakref.ref(self) - def remove_handler(self, name): + def remove_handler(self, name: str) -> bool: """Remove any stream event handlers with the given name. :param name: The name of the handler. @@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol): try: return next(self._dns_answers) except StopIteration: - return + return None - def add_event_handler(self, name, pointer, disposable=False): + def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None: """Add a custom event handler that will be executed whenever its event is manually triggered. @@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol): self.__event_handlers[name] = [] self.__event_handlers[name].append((pointer, disposable)) - def del_event_handler(self, name, pointer): + def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None: """Remove a function as a handler for an event. :param name: The name of the event. @@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol): # Need to keep handlers that do not use # the given function pointer - def filter_pointers(handler): + def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool: return handler[0] != pointer self.__event_handlers[name] = list(filter( filter_pointers, self.__event_handlers[name])) - def event_handled(self, name): + def event_handled(self, name: str) -> int: """Returns the number of registered handlers for an event. :param name: The name of the event to check. """ return len(self.__event_handlers.get(name, [])) - async def event_async(self, name: str, data: Any = {}): + async def event_async(self, name: str, data: Any = {}) -> None: """Manually trigger a custom event, but await coroutines immediately. This event generator should only be called in situations when @@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol): except Exception as e: self.exception(e) - def event(self, name: str, data: Any = {}): + def event(self, name: str, data: Any = {}) -> None: """Manually trigger a custom event. Coroutine handlers are wrapped into a future and sent into the event loop for their execution, and not awaited. @@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol): # If the callback is a coroutine, schedule it instead of # running it directly if iscoroutinefunction(handler_callback): - async def handler_callback_routine(cb): + async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None: try: await cb(data) except Exception as e: @@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol): except ValueError: pass - def schedule(self, name, seconds, callback, args=tuple(), - kwargs={}, repeat=False): + def schedule(self, name: str, seconds: int, callback: Callable[..., None], + args: Tuple[Any, ...] = tuple(), + kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None: """Schedule a callback function to execute after a given delay. :param name: A unique name for the scheduled callback. @@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol): # canceling scheduled_events[name] self.scheduled_events[name] = handle - def cancel_schedule(self, name): + def cancel_schedule(self, name: str) -> None: try: handle = self.scheduled_events.pop(name) handle.cancel() except KeyError: log.debug("Tried to cancel unscheduled event: %s" % (name,)) - def _safe_cb_run(self, name, cb): + def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None: log.debug('Scheduled event: %s', name) try: cb() except Exception as e: self.exception(e) - def _execute_and_reschedule(self, name, cb, seconds): + def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None: """Simple method that calls the given callback, and then schedule itself to be called after the given number of seconds. """ @@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol): name, cb, seconds) self.scheduled_events[name] = handle - def _execute_and_unschedule(self, name, cb): + def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None: """ Execute the callback and remove the handler for it. """ @@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol): if name in self.scheduled_events: del self.scheduled_events[name] - def incoming_filter(self, xml): + def incoming_filter(self, xml: ET.Element) -> ET.Element: """Filter incoming XML objects before they are processed. Possible uses include remapping namespaces, or correcting elements @@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol): """ return xml - def _reset_sendq(self): + def _reset_sendq(self) -> None: """Clear sending tasks on session end""" # Cancel all pending slow send tasks log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks)) @@ -1043,7 +1166,7 @@ class XMLStream(asyncio.BaseProtocol): async def _continue_slow_send( self, task: asyncio.Task, - already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]] + already_used: Set[Filter] ) -> None: """ Used when an item in the send queue has taken too long to process. @@ -1060,14 +1183,16 @@ class XMLStream(asyncio.BaseProtocol): if filter in already_used: continue if iscoroutinefunction(filter): - data = await filter(data) + data = await filter(data) # type: ignore else: + filter = cast(SyncFilter, filter) data = filter(data) if data is None: return - if isinstance(data, ElementBase): + if isinstance(data, StanzaBase): for filter in self.__filters['out_sync']: + filter = cast(SyncFilter, filter) data = filter(data) if data is None: return @@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol): else: self.send_raw(data) - async def run_filters(self): + async def run_filters(self) -> NoReturn: """ Background loop that processes stanzas to send. """ while True: + data: Optional[Union[StanzaBase, str]] (data, use_filters) = await self.waiting_queue.get() try: - if isinstance(data, ElementBase): + if isinstance(data, StanzaBase): if use_filters: already_run_filters = set() for filter in self.__filters['out']: already_run_filters.add(filter) if iscoroutinefunction(filter): + filter = cast(AsyncFilter, filter) task = asyncio.create_task(filter(data)) completed, pending = await wait( {task}, @@ -1108,21 +1235,26 @@ class XMLStream(asyncio.BaseProtocol): "Slow coroutine, rescheduling filters" ) data = task.result() - else: + elif isinstance(data, StanzaBase): + filter = cast(SyncFilter, filter) data = filter(data) if data is None: raise ContinueQueue('Empty stanza') - if isinstance(data, ElementBase): + if isinstance(data, StanzaBase): if use_filters: for filter in self.__filters['out_sync']: + filter = cast(SyncFilter, filter) data = filter(data) if data is None: raise ContinueQueue('Empty stanza') - str_data = tostring(data.xml, xmlns=self.default_ns, - stream=self, top_level=True) + if isinstance(data, StanzaBase): + str_data = tostring(data.xml, xmlns=self.default_ns, + stream=self, top_level=True) + else: + str_data = data self.send_raw(str_data) - else: + elif isinstance(data, (str, bytes)): self.send_raw(data) except ContinueQueue as exc: log.debug('Stanza in send queue not sent: %s', exc) @@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol): log.error('Exception raised in send queue:', exc_info=True) self.waiting_queue.task_done() - def send(self, data, use_filters=True): + def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None: """A wrapper for :meth:`send_raw()` for sending stanza objects. - :param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` + :param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` stanza to send on the stream. :param bool use_filters: Indicates if outgoing filters should be applied to the given stanza data. Disabling @@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol): return self.waiting_queue.put_nowait((data, use_filters)) - def send_xml(self, data): + def send_xml(self, data: ET.Element) -> None: """Send an XML object on the stream :param data: The :class:`~xml.etree.ElementTree.Element` XML object to send on the stream. """ - return self.send(tostring(data)) + self.send(tostring(data)) - def send_raw(self, data): + def send_raw(self, data: Union[str, bytes]) -> None: """Send raw data across the stream. :param string data: Any bytes or utf-8 string value. @@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol): data = data.encode('utf-8') self.transport.write(data) - def _build_stanza(self, xml, default_ns=None): + def _build_stanza(self, xml: ET.Element, + default_ns: Optional[str] = None) -> StanzaBase: """Create a stanza object from a given XML object. If a specialized stanza type is not found for the XML, then @@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol): stanza['lang'] = self.peer_default_lang return stanza - def _spawn_event(self, xml): + def _spawn_event(self, xml: ET.Element) -> None: """ Analyze incoming XML stanzas and convert them into stanza objects if applicable and queue stream events to be processed @@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol): # Convert the raw XML object into a stanza object. If no registered # stanza type applies, a generic StanzaBase stanza will be used. - stanza = self._build_stanza(xml) + stanza: Optional[StanzaBase] = self._build_stanza(xml) for filter in self.__filters['in']: if stanza is not None: + filter = cast(SyncFilter, filter) stanza = filter(stanza) if stanza is None: return @@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol): if not handled: stanza.unhandled() - def exception(self, exception): + def exception(self, exception: Exception) -> None: """Process an unknown exception. Meant to be overridden. @@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol): """ pass - async def wait_until(self, event: str, timeout=30) -> Any: + async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any: """Utility method to wake on the next firing of an event. (Registers a disposable handler on it) @@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol): :param int timeout: Timeout :raises: :class:`asyncio.TimeoutError` when the timeout is reached """ - fut = asyncio.Future() + fut: Future = asyncio.Future() - def result_handler(event_data): + def result_handler(event_data: Any) -> None: if not fut.done(): fut.set_result(event_data) else: @@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol): return await asyncio.wait_for(fut, timeout) @contextmanager - def event_handler(self, event: str, handler: Callable): + def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]: """ Context manager that adds then removes an event handler. """ self.add_event_handler(event, handler) try: yield - except Exception as exc: + except Exception: raise finally: self.del_event_handler(event, handler) - def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future: + def wrap(self, coroutine: Coroutine[None, None, T]) -> Future: """Make a Future out of a coroutine with the current loop. :param coroutine: The coroutine to wrap. |