From 0280343ced6c520700c3ca508e2d04c6b512d319 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?louiz=E2=80=99?= <louiz@louiz.org>
Date: Sat, 10 Feb 2018 19:51:59 +0100
Subject: =?UTF-8?q?Handle=20the=20=E2=80=9Cafter=E2=80=9D=20RSM=20value=20?=
 =?UTF-8?q?to=20page=20through=20results?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/database/database.cpp      | 35 ++++++++++++++++++++++++++++++++++-
 src/database/database.hpp      |  9 ++++++++-
 src/xmpp/biboumi_component.cpp | 10 +++++++++-
 tests/end_to_end/__main__.py   | 24 ++++++++++++++++++++++++
 4 files changed, 75 insertions(+), 3 deletions(-)

diff --git a/src/database/database.cpp b/src/database/database.cpp
index c43ace4..2d6fbbd 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -165,7 +165,7 @@ std::string Database::store_muc_message(const std::string& owner, const std::str
 }
 
 std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server,
-                                                   int limit, const std::string& start, const std::string& end)
+                                                   int limit, const std::string& start, const std::string& end, const Id::real_type after_id)
 {
   auto request = Database::muc_log_lines.select();
   request.where() << Database::Owner{} << "=" << owner << \
@@ -184,6 +184,10 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne
       if (end_time != -1)
         request << " and " << Database::Date{} << "<=" << end_time;
     }
+  if (after_id != Id::unset_value)
+    {
+      request << " and " << Id{} << ">" << after_id;
+    }
 
   if (limit >= 0)
     request.limit() << limit;
@@ -218,6 +222,35 @@ std::vector<Database::MucLogLine> Database::get_muc_most_recent_logs(const std::
   return {result.crbegin(), result.crend()};
 }
 
+Database::MucLogLine Database::get_muc_log(const std::string& owner, const std::string& chan_name, const std::string& server,
+                                           const std::string& uuid, const std::string& start, const std::string& end)
+{
+  auto request = Database::muc_log_lines.select();
+  request.where() << Database::Owner{} << "=" << owner << \
+          " and " << Database::IrcChanName{} << "=" << chan_name << \
+          " and " << Database::IrcServerName{} << "=" << server << \
+          " and " << Database::Uuid{} << "=" << uuid;
+
+  if (!start.empty())
+    {
+      const auto start_time = utils::parse_datetime(start);
+      if (start_time != -1)
+        request << " and " << Database::Date{} << ">=" << start_time;
+    }
+  if (!end.empty())
+    {
+      const auto end_time = utils::parse_datetime(end);
+      if (end_time != -1)
+        request << " and " << Database::Date{} << "<=" << end_time;
+    }
+
+  auto result = request.execute(*Database::db);
+
+  if (result.empty())
+    throw Database::RecordNotFound{};
+  return result.front();
+}
+
 void Database::add_roster_item(const std::string& local, const std::string& remote)
 {
   auto roster_item = Database::roster.row();
diff --git a/src/database/database.hpp b/src/database/database.hpp
index f9aed2b..810af16 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -120,15 +120,22 @@ class Database
                                                                                   const std::string& channel);
   /**
    * Get all the lines between (optional) start and end dates, with a (optional) limit.
+   * If after_id is set, only the records after it will be returned.
    */
   static std::vector<MucLogLine> get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server,
-                                              int limit=-1, const std::string& start="", const std::string& end="");
+                                              int limit=-1, const std::string& start="", const std::string& end="",
+                                              const Id::real_type after_id=Id::unset_value);
 
   /**
    * Get the most recent messages from the archive, with optional limit and start date
    */
   static std::vector<MucLogLine> get_muc_most_recent_logs(const std::string& owner, const std::string& chan_name, const std::string& server,
                                               int limit=-1, const std::string& start="");
+  /**
+   * Get just one single record matching the given uuid, between (optional) end and start.
+   * If it does not exist (or is not between end and start), throw a RecordNotFound exception.
+   */
+  static MucLogLine get_muc_log(const std::string& owner, const std::string& chan_name, const std::string& server, const std::string& uuid, const std::string& start="", const std::string& end="");
   static std::string store_muc_message(const std::string& owner, const std::string& chan_name, const std::string& server_name,
                                        time_point date, const std::string& body, const std::string& nick);
 
diff --git a/src/xmpp/biboumi_component.cpp b/src/xmpp/biboumi_component.cpp
index 250007e..cd6d570 100644
--- a/src/xmpp/biboumi_component.cpp
+++ b/src/xmpp/biboumi_component.cpp
@@ -715,11 +715,19 @@ bool BiboumiComponent::handle_mam_request(const Stanza& stanza)
           }
         const XmlNode* set = query->get_child("set", RSM_NS);
         int limit = -1;
+        Id::real_type after_id{Id::unset_value};
         if (set)
           {
             const XmlNode* max = set->get_child("max", RSM_NS);
             if (max)
               limit = std::atoi(max->get_inner().data());
+            const XmlNode* after = set->get_child("after", RSM_NS);
+            if (after)
+              {
+                auto after_record = Database::get_muc_log(from.bare(), iid.get_local(), iid.get_server(),
+                                                          after->get_inner(), start, end);
+                after_id = after_record.col<Id>();
+              }
           }
         // Do not send more than 100 messages, even if the client asked for more,
         // or if it didn’t specify any limit.
@@ -729,7 +737,7 @@ bool BiboumiComponent::handle_mam_request(const Stanza& stanza)
         if ((limit == -1 && start.empty() && end.empty())
             || limit > 100)
           limit = 101;
-        auto lines = Database::get_muc_logs(from.bare(), iid.get_local(), iid.get_server(), limit, start, end);
+        auto lines = Database::get_muc_logs(from.bare(), iid.get_local(), iid.get_server(), limit, start, end, after_id);
         bool complete = true;
         if (lines.size() > 100)
           {
diff --git a/tests/end_to_end/__main__.py b/tests/end_to_end/__main__.py
index 7fa779e..3875a7e 100644
--- a/tests/end_to_end/__main__.py
+++ b/tests/end_to_end/__main__.py
@@ -2164,6 +2164,30 @@ if __name__ == '__main__':
                              "!/iq//mam:fin[@complete='true']",
                              "/iq//mam:fin")),
 
+                     # Retrieve the next page, using the “after” thingy
+                    partial(send_stanza, "<iq to='#foo%{irc_server_one}' from='{jid_one}/{resource_one}' type='set' id='id2'><query xmlns='urn:xmpp:mam:2' queryid='qid2' ><set xmlns='http://jabber.org/protocol/rsm'><after>{last_uuid}</after></set></query></iq>"),
+
+                    partial(expect_stanza,
+                            ("/message/mam:result[@queryid='qid2']/forward:forwarded/delay:delay",
+                            "/message/mam:result[@queryid='qid2']/forward:forwarded/client:message[@from='#foo%{irc_server_one}/{nick_one}'][@type='groupchat']/client:body[text()='101']")
+                            ),
+                  ] + 47 * [
+                    partial(expect_stanza,
+                            ("/message/mam:result[@queryid='qid2']/forward:forwarded/delay:delay",
+                            "/message/mam:result[@queryid='qid2']/forward:forwarded/client:message[@from='#foo%{irc_server_one}/{nick_one}'][@type='groupchat']/client:body")
+                            ),
+                  ] + [
+                    partial(expect_stanza,
+                            ("/message/mam:result[@queryid='qid2']/forward:forwarded/delay:delay",
+                            "/message/mam:result[@queryid='qid2']/forward:forwarded/client:message[@from='#foo%{irc_server_one}/{nick_one}'][@type='groupchat']/client:body[text()='149']"),
+                            after = partial(save_value, "last_uuid", partial(extract_attribute, "/message/mam:result", "id"))
+                            ),
+                     # And it should not be marked as complete
+                    partial(expect_stanza,
+                            ("/iq[@type='result'][@id='id2'][@from='#foo%{irc_server_one}'][@to='{jid_one}/{resource_one}']",
+                             "/iq/mam:fin/rsm:set/rsm:last[text()='{last_uuid}']",
+                             "/iq//mam:fin[@complete='true']",
+                             "/iq//mam:fin")),
                   ]),
         Scenario("channel_history_on_fixed_server",
                  [
-- 
cgit v1.2.3