diff options
-rw-r--r-- | slixmpp/plugins/xep_0030/disco.py | 28 | ||||
-rw-r--r-- | slixmpp/plugins/xep_0363/http_upload.py | 16 |
2 files changed, 33 insertions, 11 deletions
diff --git a/slixmpp/plugins/xep_0030/disco.py b/slixmpp/plugins/xep_0030/disco.py index 166ad981..ea9a33f4 100644 --- a/slixmpp/plugins/xep_0030/disco.py +++ b/slixmpp/plugins/xep_0030/disco.py @@ -6,6 +6,7 @@ See the file LICENSE for copying permission. """ +import asyncio import logging from slixmpp import Iq @@ -123,6 +124,8 @@ class XEP_0030(BasePlugin): for op in self._disco_ops: self.api.register(getattr(self.static, op), op, default=True) + self.domain_infos = {} + def session_bind(self, jid): self.add_feature('http://jabber.org/protocol/disco#info') @@ -295,6 +298,31 @@ class XEP_0030(BasePlugin): 'cached': cached} return self.api['has_identity'](jid, node, ifrom, data) + async def get_info_from_domain(self, domain=None, timeout=None, + cached=True, callback=None, **kwargs): + if domain is None: + domain = self.xmpp.boundjid.domain + + if not cached or domain not in self.domain_infos: + infos = [self.get_info( + domain, timeout=timeout, **kwargs)] + iq_items = await self.get_items( + domain, timeout=timeout, **kwargs) + items = iq_items['disco_items']['items'] + infos += [ + self.get_info(item[0], timeout=timeout, **kwargs) + for item in items] + info_futures, _ = await asyncio.wait(infos, timeout=timeout) + + self.domain_infos[domain] = [ + future.result() for future in info_futures] + + results = self.domain_infos[domain] + + if callback is not None: + callback(results) + return results + @future_wrapper def get_info(self, jid=None, node=None, local=None, cached=None, **kwargs): diff --git a/slixmpp/plugins/xep_0363/http_upload.py b/slixmpp/plugins/xep_0363/http_upload.py index 0cca2d08..65894975 100644 --- a/slixmpp/plugins/xep_0363/http_upload.py +++ b/slixmpp/plugins/xep_0363/http_upload.py @@ -68,16 +68,10 @@ class XEP_0363(BasePlugin): def _handle_request(self, iq): self.xmpp.event('http_upload_request', iq) - async def find_upload_service(self, ifrom=None, timeout=None, callback=None, - timeout_callback=None): - infos = [self.xmpp['xep_0030'].get_info(self.xmpp.boundjid.domain)] - iq_items = await self.xmpp['xep_0030'].get_items( - self.xmpp.boundjid.domain, timeout=timeout) - items = iq_items['disco_items']['items'] - infos += [self.xmpp['xep_0030'].get_info(item[0]) for item in items] - info_futures, _ = await asyncio.wait(infos, timeout=timeout) - for future in info_futures: - info = future.result() + async def find_upload_service(self, timeout=None): + results = await self.xmpp['xep_0030'].get_info_from_domain() + + for info in results: for identity in info['disco_info']['identities']: if identity[0] == 'store' and identity[1] == 'file': return info @@ -100,7 +94,7 @@ class XEP_0363(BasePlugin): callback=None, timeout_callback=None): ''' Helper function which does all of the uploading process. ''' if self.upload_service is None: - info_iq = await self.find_upload_service(ifrom=ifrom, timeout=timeout) + info_iq = await self.find_upload_service(timeout=timeout) if info_iq is None: raise UploadServiceNotFound() self.upload_service = info_iq['from'] |