From 191702a185efc2c904dad9350c603cff20dcb5bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maxime=20=E2=80=9Cpep=E2=80=9D=20Buquet?= <pep@bouah.net>
Date: Fri, 11 Mar 2022 16:45:10 +0100
Subject: plugin_e2ee: Ensure all encrypted messages we handle are processed
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
---
 poezio/plugin_e2ee.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 50 insertions(+), 4 deletions(-)

diff --git a/poezio/plugin_e2ee.py b/poezio/plugin_e2ee.py
index 83bee4b6..eba68769 100644
--- a/poezio/plugin_e2ee.py
+++ b/poezio/plugin_e2ee.py
@@ -23,6 +23,8 @@ from typing import (
 
 from slixmpp import InvalidJID, JID, Message
 from slixmpp.xmlstream import StanzaBase
+from slixmpp.xmlstream.handler import CoroutineCallback
+from slixmpp.xmlstream.matcher import MatchXPath
 from poezio.tabs import (
     ChatTab,
     ConversationTab,
@@ -36,6 +38,7 @@ from poezio.theming import get_theme, dump_tuple
 from poezio.config import config
 from poezio.decorators import command_args_parser
 
+import asyncio
 from asyncio import iscoroutinefunction
 
 import logging
@@ -118,7 +121,9 @@ class E2EEPlugin(BasePlugin):
 
     #: Used to figure out what messages to attempt decryption for. Also used
     #: in combination with `tag_whitelist` to avoid removing encrypted tags
-    #: before sending.
+    #: before sending. If multiple tags are present, a handler will be
+    #: registered for each invididual tag/ns pair under <message/>, as opposed
+    #: to a single handler for all tags combined.
     encrypted_tags: Optional[List[Tuple[str, str]]] = None
 
     # Static map, to be able to limit to one encryption mechanism per tab at a
@@ -152,6 +157,16 @@ class E2EEPlugin(BasePlugin):
         self.api.add_event_handler('conversation_msg', self._decrypt_wrapper, priority=0)
         self.api.add_event_handler('private_msg', self._decrypt_wrapper, priority=0)
 
+        # Register a handler for each invididual tag/ns pair in encrypted_tags
+        # as well. as _msg handlers only include messages with a <body/>.
+        if self.encrypted_tags is not None:
+            default_ns = self.core.xmpp.default_ns
+            for i, (namespace, tag) in enumerate(self.encrypted_tags):
+                self.core.xmpp.register_handler(CoroutineCallback(f'EncryptedTag{i}',
+                    MatchXPath(f'{{{default_ns}}}message/{{{namespace}}}{tag}'),
+                    self._decrypt_encryptedtag,
+                ))
+
         # Ensure encryption is done after everything, so that whatever can be
         # encrypted is encrypted, and no plain element slips in.
         # Using a stream filter might be a bit too much, but at least we're
@@ -359,7 +374,7 @@ class E2EEPlugin(BasePlugin):
             return None
         return result
 
-    async def _decrypt_wrapper(self, stanza: Message, tab: ChatTabs) -> Optional[Message]:
+    async def _decrypt_wrapper(self, stanza: Message, tab: Optional[ChatTabs]) -> Optional[Message]:
         """
         Wrapper around _decrypt() to handle errors and display the message after encryption.
         """
@@ -381,7 +396,38 @@ class E2EEPlugin(BasePlugin):
             return None
         return result
 
-    async def _decrypt(self, message: Message, tab: ChatTabs, passthrough: bool = True) -> None:
+    async def _decrypt_encryptedtag(self, stanza: Message) -> None:
+        """
+        Handler to decrypt encrypted_tags elements that are matched separately
+        from other messages because the default 'message' handler that we use
+        only matches messages containing a <body/>.
+        """
+        # If the message contains a body, it will already be handled by the
+        # other handler. If not, pass it to the handler.
+        if stanza.xml.find(f'{{{self.core.xmpp.default_ns}}}body') is not None:
+            return None
+
+        mfrom = stanza['from']
+
+        # Find what tab this message corresponds to.
+        if stanza['type'] == 'groupchat':  # MUC
+            tab = self.core.tabs.by_name_and_class(
+                name=mfrom.bare, cls=MucTab,
+            )
+        elif self.core.handler.is_known_muc_pm(stanza, mfrom):  # MUC-PM
+            tab = self.core.tabs.by_name_and_class(
+                name=mfrom.full, cls=PrivateTab,
+            )
+        else:  # 1:1
+            tab = self.core.get_conversation_by_jid(
+                jid=JID(mfrom.bare),
+                create=False,
+                fallback_barejid=True,
+            )
+        log.debug('Found tab %r for encrypted message', tab)
+        await self._decrypt_wrapper(stanza, tab)
+
+    async def _decrypt(self, message: Message, tab: Optional[ChatTabs], passthrough: bool = True) -> None:
 
         has_eme: bool = False
         if message.xml.find(f'{{{EME_NS}}}{EME_TAG}') is not None and \
@@ -575,7 +621,7 @@ class E2EEPlugin(BasePlugin):
         option_name = f'{self.encryption_short_name}:{fingerprint}'
         return config.getstr(option=option_name, section=jid)
 
-    async def decrypt(self, message: Message, jid: Optional[JID], tab: ChatTab):
+    async def decrypt(self, message: Message, jid: Optional[JID], tab: Optional[ChatTab]):
         """Decryption method
 
         This is a method the plugin must implement.  It is expected that this
-- 
cgit v1.2.3