From 4960cffcb49ce9b996b820103e48702d694239a1 Mon Sep 17 00:00:00 2001 From: mathieui Date: Sun, 14 Feb 2021 11:56:20 +0100 Subject: XEP-0115: API changes - ``get_verstring``, ``get_caps`` are now coroutines. - ``assign_verstring``, ``cache_caps`` now return a Future. side-effect: fix supports() and has_identity() broken since forever --- slixmpp/plugins/xep_0115/caps.py | 70 +++++++++++++++++++++++++------------- slixmpp/plugins/xep_0115/static.py | 12 +++---- 2 files changed, 53 insertions(+), 29 deletions(-) (limited to 'slixmpp/plugins') diff --git a/slixmpp/plugins/xep_0115/caps.py b/slixmpp/plugins/xep_0115/caps.py index 5f71165c..75c96410 100644 --- a/slixmpp/plugins/xep_0115/caps.py +++ b/slixmpp/plugins/xep_0115/caps.py @@ -7,6 +7,8 @@ import logging import hashlib import base64 +from asyncio import Future + from slixmpp import __version__ from slixmpp.stanza import StreamFeatures, Presence, Iq from slixmpp.xmlstream import register_stanza_plugin, JID @@ -104,14 +106,14 @@ class XEP_0115(BasePlugin): def session_bind(self, jid): self.xmpp['xep_0030'].add_feature(stanza.Capabilities.namespace) - def _filter_add_caps(self, stanza): + async def _filter_add_caps(self, stanza): if not isinstance(stanza, Presence) or not self.broadcast: return stanza if stanza['type'] not in ('available', 'chat', 'away', 'dnd', 'xa'): return stanza - ver = self.get_verstring(stanza['from']) + ver = await self.get_verstring(stanza['from']) if ver: stanza['caps']['node'] = self.caps_node stanza['caps']['hash'] = self.hash @@ -145,13 +147,13 @@ class XEP_0115(BasePlugin): ver = pres['caps']['ver'] - existing_verstring = self.get_verstring(pres['from'].full) + existing_verstring = await self.get_verstring(pres['from'].full) if str(existing_verstring) == str(ver): return - existing_caps = self.get_caps(verstring=ver) + existing_caps = await self.get_caps(verstring=ver) if existing_caps is not None: - self.assign_verstring(pres['from'], ver) + await self.assign_verstring(pres['from'], ver) return ifrom = pres['to'] if self.xmpp.is_component else None @@ -174,13 +176,13 @@ class XEP_0115(BasePlugin): if isinstance(caps, Iq): caps = caps['disco_info'] - if self._validate_caps(caps, pres['caps']['hash'], - pres['caps']['ver']): - self.assign_verstring(pres['from'], pres['caps']['ver']) + if await self._validate_caps(caps, pres['caps']['hash'], + pres['caps']['ver']): + await self.assign_verstring(pres['from'], pres['caps']['ver']) except XMPPError: log.debug("Could not retrieve disco#info results for caps for %s", node) - def _validate_caps(self, caps, hash, check_verstring): + async def _validate_caps(self, caps, hash, check_verstring): # Check Identities full_ids = caps.get_identities(dedupe=False) deduped_ids = caps.get_identities() @@ -232,7 +234,7 @@ class XEP_0115(BasePlugin): verstring, check_verstring)) return False - self.cache_caps(verstring, caps) + await self.cache_caps(verstring, caps) return True def generate_verstring(self, info, hash): @@ -290,12 +292,13 @@ class XEP_0115(BasePlugin): if isinstance(info, Iq): info = info['disco_info'] ver = self.generate_verstring(info, self.hash) - self.xmpp['xep_0030'].set_info( - jid=jid, - node='%s#%s' % (self.caps_node, ver), - info=info) - self.cache_caps(ver, info) - self.assign_verstring(jid, ver) + await self.xmpp['xep_0030'].set_info( + jid=jid, + node='%s#%s' % (self.caps_node, ver), + info=info + ) + await self.cache_caps(ver, info) + await self.assign_verstring(jid, ver) if self.xmpp.sessionstarted and self.broadcast: if self.xmpp.is_component or preserve: @@ -306,32 +309,53 @@ class XEP_0115(BasePlugin): except XMPPError: return - def get_verstring(self, jid=None): + def get_verstring(self, jid=None) -> Future: + """Get the stored verstring for a JID. + + .. versionchanged:: 1.8.0 + This function now returns a Future. + """ if jid in ('', None): jid = self.xmpp.boundjid.full if isinstance(jid, JID): jid = jid.full return self.api['get_verstring'](jid) - def assign_verstring(self, jid=None, verstring=None): + def assign_verstring(self, jid=None, verstring=None) -> Future: + """Assign a vertification string to a jid. + + .. versionchanged:: 1.8.0 + This function now returns a Future. + """ if jid in (None, ''): jid = self.xmpp.boundjid.full if isinstance(jid, JID): jid = jid.full return self.api['assign_verstring'](jid, args={ - 'verstring': verstring}) + 'verstring': verstring + }) - def cache_caps(self, verstring=None, info=None): + def cache_caps(self, verstring=None, info=None) -> Future: + """Add caps to the cache. + + .. versionchanged:: 1.8.0 + This function now returns a Future. + """ data = {'verstring': verstring, 'info': info} return self.api['cache_caps'](args=data) - def get_caps(self, jid=None, verstring=None): + async def get_caps(self, jid=None, verstring=None): + """Get caps for a JID. + + .. versionchanged:: 1.8.0 + This function is now a coroutine. + """ if verstring is None: if jid is not None: - verstring = self.get_verstring(jid) + verstring = await self.get_verstring(jid) else: return None if isinstance(jid, JID): jid = jid.full data = {'verstring': verstring} - return self.api['get_caps'](jid, args=data) + return await self.api['get_caps'](jid, args=data) diff --git a/slixmpp/plugins/xep_0115/static.py b/slixmpp/plugins/xep_0115/static.py index 2461d2e3..74f2beb8 100644 --- a/slixmpp/plugins/xep_0115/static.py +++ b/slixmpp/plugins/xep_0115/static.py @@ -32,7 +32,7 @@ class StaticCaps(object): self.static = static self.jid_vers = {} - def supports(self, jid, node, ifrom, data): + async def supports(self, jid, node, ifrom, data): """ Check if a JID supports a given feature. @@ -65,8 +65,8 @@ class StaticCaps(object): return True try: - info = self.disco.get_info(jid=jid, node=node, - ifrom=ifrom, **data) + info = await self.disco.get_info(jid=jid, node=node, + ifrom=ifrom, **data) info = self.disco._wrap(ifrom, jid, info, True) return feature in info['disco_info']['features'] except IqError: @@ -74,7 +74,7 @@ class StaticCaps(object): except IqTimeout: return None - def has_identity(self, jid, node, ifrom, data): + async def has_identity(self, jid, node, ifrom, data): """ Check if a JID has a given identity. @@ -110,8 +110,8 @@ class StaticCaps(object): return True try: - info = self.disco.get_info(jid=jid, node=node, - ifrom=ifrom, **data) + info = await self.disco.get_info(jid=jid, node=node, + ifrom=ifrom, **data) info = self.disco._wrap(ifrom, jid, info, True) return identity in map(trunc, info['disco_info']['identities']) except IqError: -- cgit v1.2.3