diff options
Diffstat (limited to 'poezio/mam.py')
-rw-r--r-- | poezio/mam.py | 305 |
1 files changed, 141 insertions, 164 deletions
diff --git a/poezio/mam.py b/poezio/mam.py index 0f745f30..7cb1d369 100644 --- a/poezio/mam.py +++ b/poezio/mam.py @@ -1,102 +1,107 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - """ Query and control an archive of messages stored on a server using XEP-0313: Message Archive Management(MAM). """ -import random +from __future__ import annotations + +import logging from datetime import datetime, timedelta, timezone from hashlib import md5 -from typing import Optional, Callable - -from slixmpp import JID +from typing import ( + Any, + AsyncIterable, + Dict, + List, + Optional, +) + +from slixmpp import JID, Message as SMessage from slixmpp.exceptions import IqError, IqTimeout from poezio.theming import get_theme from poezio import tabs -from poezio import xhtml, colors -from poezio.config import config -from poezio.text_buffer import TextBuffer +from poezio import colors +from poezio.common import to_utc +from poezio.ui.types import ( + BaseMessage, + Message, +) + +log = logging.getLogger(__name__) class DiscoInfoException(Exception): pass class MAMQueryException(Exception): pass class NoMAMSupportException(Exception): pass -def add_line( - tab, - text_buffer: TextBuffer, +def make_line( + tab: tabs.ChatTab, text: str, time: datetime, - nick: str, - top: bool, - ) -> None: + jid: JID, + identifier: str = '', + nick: str = '' + ) -> Message: """Adds a textual entry in the TextBuffer""" # Convert to local timezone time = time.replace(tzinfo=timezone.utc).astimezone(tz=None) time = time.replace(tzinfo=None) - deterministic = config.get_by_tabname('deterministic_nick_colors', tab.jid.bare) if isinstance(tab, tabs.MucTab): - nick = nick.split('/')[1] + nick = jid.resource user = tab.get_user_by_name(nick) - if deterministic: - if user: - color = user.color - else: - theme = get_theme() - if theme.ccg_palette: - fg_color = colors.ccg_text_to_color(theme.ccg_palette, nick) - color = fg_color, -1 - else: - mod = len(theme.LIST_COLOR_NICKNAMES) - nick_pos = int(md5(nick.encode('utf-8')).hexdigest(), 16) % mod - color = theme.LIST_COLOR_NICKNAMES[nick_pos] + if user: + color = user.color else: - color = random.choice(list(xhtml.colors)) - color = xhtml.colors.get(color) - color = (color, -1) + theme = get_theme() + if theme.ccg_palette: + fg_color = colors.ccg_text_to_color(theme.ccg_palette, nick) + color = fg_color, -1 + else: + mod = len(theme.LIST_COLOR_NICKNAMES) + nick_pos = int(md5(nick.encode('utf-8')).hexdigest(), 16) % mod + color = theme.LIST_COLOR_NICKNAMES[nick_pos] else: - nick = nick.split('/')[0] - color = get_theme().COLOR_OWN_NICK - text_buffer.add_message( + if jid.bare == tab.core.xmpp.boundjid.bare: + if not nick: + nick = tab.core.own_nick + color = get_theme().COLOR_OWN_NICK + else: + color = get_theme().COLOR_REMOTE_USER + if not nick: + nick = tab.get_nick() + return Message( txt=text, + identifier=identifier, time=time, nickname=nick, nick_color=color, history=True, user=None, - highlight=False, - top=top, - identifier=None, - str_time=None, - jid=None, ) - -async def query( +async def get_mam_iterator( core, groupchat: bool, remote_jid: JID, amount: int, - reverse: bool, - start: Optional[datetime] = None, - end: Optional[datetime] = None, + reverse: bool = True, + start: Optional[str] = None, + end: Optional[str] = None, before: Optional[str] = None, - callback: Optional[Callable] = None, - ) -> None: + ) -> AsyncIterable[SMessage]: + """Get an async iterator for this mam query""" try: query_jid = remote_jid if groupchat else JID(core.xmpp.boundjid.bare) iq = await core.xmpp.plugin['xep_0030'].get_info(jid=query_jid) except (IqError, IqTimeout): - raise DiscoInfoException + raise DiscoInfoException() if 'urn:xmpp:mam:2' not in iq['disco_info'].get_features(): - raise NoMAMSupportException + raise NoMAMSupportException() - args = { + args: Dict[str, Any] = { 'iterator': True, 'reverse': reverse, } @@ -106,129 +111,101 @@ async def query( else: args['with_jid'] = remote_jid - args['rsm'] = {'max': amount} - if reverse: - if before is not None: - args['rsm']['before'] = before - else: - args['end'] = end - else: - args['rsm']['start'] = start - if before is not None: - args['rsm']['end'] = end - try: - results = core.xmpp['xep_0313'].retrieve(**args) - except (IqError, IqTimeout): - raise MAMQueryException - if callback is not None: - callback(results) + if amount > 0: + args['rsm'] = {'max': amount} + args['start'] = start + args['end'] = end + return core.xmpp['xep_0313'].retrieve(**args) + + +def _parse_message(msg: SMessage) -> Dict: + """Parse info inside a MAM forwarded message""" + forwarded = msg['mam_result']['forwarded'] + message = forwarded['stanza'] + return { + 'time': forwarded['delay']['stamp'], + 'jid': message['from'], + 'text': message['body'], + 'identifier': message['origin-id'] + } - return results +def _ignore_private_message(stanza: SMessage, filter_jid: Optional[JID]) -> bool: + """Returns True if a MUC-PM should be ignored, as prosody returns + all PMs within the same room. + """ + if filter_jid is None: + return False + sent = stanza['from'].bare != filter_jid.bare + if sent and stanza['to'].full != filter_jid.full: + return True + elif not sent and stanza['from'].full != filter_jid.full: + return True + return False -async def add_messages_to_buffer(tab, top: bool, results, amount: int) -> bool: - """Prepends or appends messages to the tab text_buffer""" - text_buffer = tab._text_buffer +async def retrieve_messages(tab: tabs.ChatTab, + results: AsyncIterable[SMessage], + amount: int = 100) -> List[BaseMessage]: + """Run the MAM query and put messages in order""" msg_count = 0 msgs = [] - async for rsm in results: - if top: + to_add = [] + tab_is_private = isinstance(tab, tabs.PrivateTab) + filter_jid = None + if tab_is_private: + filter_jid = tab.jid + try: + async for rsm in results: for msg in rsm['mam']['results']: - if msg['mam_result']['forwarded']['stanza'] \ - .xml.find('{%s}%s' % ('jabber:client', 'body')) is not None: - msgs.append(msg) - if msg_count == amount: - tab.core.refresh_window() - return False + stanza = msg['mam_result']['forwarded']['stanza'] + if stanza.xml.find('{%s}%s' % ('jabber:client', 'body')) is not None: + if _ignore_private_message(stanza, filter_jid): + continue + args = _parse_message(msg) + msgs.append(make_line(tab, **args)) + for msg in reversed(msgs): + to_add.append(msg) msg_count += 1 - msgs.reverse() - for msg in msgs: - forwarded = msg['mam_result']['forwarded'] - timestamp = forwarded['delay']['stamp'] - message = forwarded['stanza'] - tab.last_stanza_id = msg['mam_result']['id'] - nick = str(message['from']) - add_line(tab, text_buffer, message['body'], timestamp, nick, top) - else: - for msg in rsm['mam']['results']: - forwarded = msg['mam_result']['forwarded'] - timestamp = forwarded['delay']['stamp'] - message = forwarded['stanza'] - nick = str(message['from']) - add_line(tab, text_buffer, message['body'], timestamp, nick, top) - tab.core.refresh_window() - return False - - -async def fetch_history(tab, end: Optional[datetime] = None, amount: Optional[int] = None): + if msg_count == amount: + to_add.reverse() + return to_add + msgs = [] + to_add.reverse() + return to_add + except (IqError, IqTimeout) as exc: + log.debug('Unable to complete MAM query: %s', exc, exc_info=True) + raise MAMQueryException('Query interrupted') + + +async def fetch_history(tab: tabs.ChatTab, + start: Optional[datetime] = None, + end: Optional[datetime] = None, + amount: int = 100) -> List[BaseMessage]: remote_jid = tab.jid - before = tab.last_stanza_id + if not end: + for msg in tab._text_buffer.messages: + if isinstance(msg, Message): + end = msg.time + end -= timedelta(microseconds=1) + break if end is None: end = datetime.now() - tzone = datetime.now().astimezone().tzinfo - end = end.replace(tzinfo=tzone).astimezone(tz=timezone.utc) - end = end.replace(tzinfo=None) - end = datetime.strftime(end, '%Y-%m-%dT%H:%M:%SZ') - - if amount >= 100: - amount = 99 - - groupchat = isinstance(tab, tabs.MucTab) - - results = await query( - tab.core, - groupchat, - remote_jid, - amount, + end = to_utc(end) + end_str = datetime.strftime(end, '%Y-%m-%dT%H:%M:%SZ') + + start_str = None + if start is not None: + start = to_utc(start) + start_str = datetime.strftime(start, '%Y-%m-%dT%H:%M:%SZ') + + mam_iterator = await get_mam_iterator( + core=tab.core, + groupchat=isinstance(tab, tabs.MucTab), + remote_jid=remote_jid, + amount=amount, + end=end_str, + start=start_str, reverse=True, - end=end, - before=before, ) - query_status = await add_messages_to_buffer(tab, True, results, amount) - tab.query_status = query_status - - -async def on_tab_open(tab) -> None: - amount = 2 * tab.text_win.height - end = datetime.now() - tab.query_status = True - for message in tab._text_buffer.messages: - time = message.time - if time < end: - end = time - end = end + timedelta(seconds=-1) - try: - await fetch_history(tab, end=end, amount=amount) - except (NoMAMSupportException, MAMQueryException, DiscoInfoException): - tab.query_status = False - return None - - -async def on_scroll_up(tab) -> None: - tw = tab.text_win - - # If position in the tab is < two screen pages, then fetch MAM, so that we - # keep some prefetched margin. A first page should also be prefetched on - # join if not already available. - total, pos, height = len(tw.built_lines), tw.pos, tw.height - rest = (total - pos) // height - # Not resetting the state of query_status here, it is changed only after the - # query is complete (in fetch_history) - # This is done to stop message repetition, eg: if the user presses PageUp continuously. - tab.query_status = True - - if rest > 1: - return None - - try: - # XXX: Do we want to fetch a possibly variable number of messages? - # (InfoTab changes height depending on the type of messages, see - # `information_buffer_popup_on`). - await fetch_history(tab, amount=height) - except NoMAMSupportException: - tab.core.information('MAM not supported for %r' % tab.jid, 'Info') - return None - except (MAMQueryException, DiscoInfoException): - tab.core.information('An error occured when fetching MAM for %r' % tab.jid, 'Error') - return None + return await retrieve_messages(tab, mam_iterator, amount) |