diff options
Diffstat (limited to 'poezio/mam.py')
-rw-r--r-- | poezio/mam.py | 211 |
1 files changed, 211 insertions, 0 deletions
diff --git a/poezio/mam.py b/poezio/mam.py new file mode 100644 index 00000000..7cb1d369 --- /dev/null +++ b/poezio/mam.py @@ -0,0 +1,211 @@ +""" + Query and control an archive of messages stored on a server using + XEP-0313: Message Archive Management(MAM). +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timedelta, timezone +from hashlib import md5 +from typing import ( + Any, + AsyncIterable, + Dict, + List, + Optional, +) + +from slixmpp import JID, Message as SMessage +from slixmpp.exceptions import IqError, IqTimeout +from poezio.theming import get_theme +from poezio import tabs +from poezio import colors +from poezio.common import to_utc +from poezio.ui.types import ( + BaseMessage, + Message, +) + + +log = logging.getLogger(__name__) + +class DiscoInfoException(Exception): pass +class MAMQueryException(Exception): pass +class NoMAMSupportException(Exception): pass + + +def make_line( + tab: tabs.ChatTab, + text: str, + time: datetime, + jid: JID, + identifier: str = '', + nick: str = '' + ) -> Message: + """Adds a textual entry in the TextBuffer""" + + # Convert to local timezone + time = time.replace(tzinfo=timezone.utc).astimezone(tz=None) + time = time.replace(tzinfo=None) + + if isinstance(tab, tabs.MucTab): + nick = jid.resource + user = tab.get_user_by_name(nick) + if user: + color = user.color + else: + theme = get_theme() + if theme.ccg_palette: + fg_color = colors.ccg_text_to_color(theme.ccg_palette, nick) + color = fg_color, -1 + else: + mod = len(theme.LIST_COLOR_NICKNAMES) + nick_pos = int(md5(nick.encode('utf-8')).hexdigest(), 16) % mod + color = theme.LIST_COLOR_NICKNAMES[nick_pos] + else: + if jid.bare == tab.core.xmpp.boundjid.bare: + if not nick: + nick = tab.core.own_nick + color = get_theme().COLOR_OWN_NICK + else: + color = get_theme().COLOR_REMOTE_USER + if not nick: + nick = tab.get_nick() + return Message( + txt=text, + identifier=identifier, + time=time, + nickname=nick, + nick_color=color, + history=True, + user=None, + ) + +async def get_mam_iterator( + core, + groupchat: bool, + remote_jid: JID, + amount: int, + reverse: bool = True, + start: Optional[str] = None, + end: Optional[str] = None, + before: Optional[str] = None, + ) -> AsyncIterable[SMessage]: + """Get an async iterator for this mam query""" + try: + query_jid = remote_jid if groupchat else JID(core.xmpp.boundjid.bare) + iq = await core.xmpp.plugin['xep_0030'].get_info(jid=query_jid) + except (IqError, IqTimeout): + raise DiscoInfoException() + if 'urn:xmpp:mam:2' not in iq['disco_info'].get_features(): + raise NoMAMSupportException() + + args: Dict[str, Any] = { + 'iterator': True, + 'reverse': reverse, + } + + if groupchat: + args['jid'] = remote_jid + else: + args['with_jid'] = remote_jid + + if amount > 0: + args['rsm'] = {'max': amount} + args['start'] = start + args['end'] = end + return core.xmpp['xep_0313'].retrieve(**args) + + +def _parse_message(msg: SMessage) -> Dict: + """Parse info inside a MAM forwarded message""" + forwarded = msg['mam_result']['forwarded'] + message = forwarded['stanza'] + return { + 'time': forwarded['delay']['stamp'], + 'jid': message['from'], + 'text': message['body'], + 'identifier': message['origin-id'] + } + + +def _ignore_private_message(stanza: SMessage, filter_jid: Optional[JID]) -> bool: + """Returns True if a MUC-PM should be ignored, as prosody returns + all PMs within the same room. + """ + if filter_jid is None: + return False + sent = stanza['from'].bare != filter_jid.bare + if sent and stanza['to'].full != filter_jid.full: + return True + elif not sent and stanza['from'].full != filter_jid.full: + return True + return False + + +async def retrieve_messages(tab: tabs.ChatTab, + results: AsyncIterable[SMessage], + amount: int = 100) -> List[BaseMessage]: + """Run the MAM query and put messages in order""" + msg_count = 0 + msgs = [] + to_add = [] + tab_is_private = isinstance(tab, tabs.PrivateTab) + filter_jid = None + if tab_is_private: + filter_jid = tab.jid + try: + async for rsm in results: + for msg in rsm['mam']['results']: + stanza = msg['mam_result']['forwarded']['stanza'] + if stanza.xml.find('{%s}%s' % ('jabber:client', 'body')) is not None: + if _ignore_private_message(stanza, filter_jid): + continue + args = _parse_message(msg) + msgs.append(make_line(tab, **args)) + for msg in reversed(msgs): + to_add.append(msg) + msg_count += 1 + if msg_count == amount: + to_add.reverse() + return to_add + msgs = [] + to_add.reverse() + return to_add + except (IqError, IqTimeout) as exc: + log.debug('Unable to complete MAM query: %s', exc, exc_info=True) + raise MAMQueryException('Query interrupted') + + +async def fetch_history(tab: tabs.ChatTab, + start: Optional[datetime] = None, + end: Optional[datetime] = None, + amount: int = 100) -> List[BaseMessage]: + remote_jid = tab.jid + if not end: + for msg in tab._text_buffer.messages: + if isinstance(msg, Message): + end = msg.time + end -= timedelta(microseconds=1) + break + if end is None: + end = datetime.now() + end = to_utc(end) + end_str = datetime.strftime(end, '%Y-%m-%dT%H:%M:%SZ') + + start_str = None + if start is not None: + start = to_utc(start) + start_str = datetime.strftime(start, '%Y-%m-%dT%H:%M:%SZ') + + mam_iterator = await get_mam_iterator( + core=tab.core, + groupchat=isinstance(tab, tabs.MucTab), + remote_jid=remote_jid, + amount=amount, + end=end_str, + start=start_str, + reverse=True, + ) + return await retrieve_messages(tab, mam_iterator, amount) |