From 3afb63a650b8b925ce1ba722dd42b7418f623713 Mon Sep 17 00:00:00 2001
From: Florent Le Coz <louiz@louiz.org>
Date: Sat, 21 Dec 2013 21:04:41 +0100
Subject: Shutdown cleanly on SIGINT

---
 src/bridge/bridge.cpp       |  8 +++++++
 src/bridge/bridge.hpp       |  4 ++++
 src/irc/irc_client.cpp      | 20 +++++++++++++++++
 src/irc/irc_client.hpp      |  6 +++++
 src/main.cpp                | 53 ++++++++++++++++++++++++++++++++++++++++++---
 src/network/poller.cpp      | 45 ++++++++++++++++++++------------------
 src/network/poller.hpp      | 14 +++++++++---
 src/xmpp/xmpp_component.cpp | 18 ++++++++++++++-
 src/xmpp/xmpp_component.hpp | 13 ++++++++++-
 9 files changed, 152 insertions(+), 29 deletions(-)

(limited to 'src')

diff --git a/src/bridge/bridge.cpp b/src/bridge/bridge.cpp
index 973e095..606cb02 100644
--- a/src/bridge/bridge.cpp
+++ b/src/bridge/bridge.cpp
@@ -21,6 +21,14 @@ Bridge::~Bridge()
 {
 }
 
+void Bridge::shutdown()
+{
+  for (auto it = this->irc_clients.begin(); it != this->irc_clients.end(); ++it)
+  {
+    it->second->send_quit_command();
+  }
+}
+
 Xmpp::body Bridge::make_xmpp_body(const std::string& str)
 {
   std::string res;
diff --git a/src/bridge/bridge.hpp b/src/bridge/bridge.hpp
index bbbca95..7a36b59 100644
--- a/src/bridge/bridge.hpp
+++ b/src/bridge/bridge.hpp
@@ -24,6 +24,10 @@ class Bridge
 public:
   explicit Bridge(const std::string& user_jid, XmppComponent* xmpp, Poller* poller);
   ~Bridge();
+  /**
+   * QUIT all connected IRC servers.
+   */
+  void shutdown();
 
   static Xmpp::body make_xmpp_body(const std::string& str);
   /***
diff --git a/src/irc/irc_client.cpp b/src/irc/irc_client.cpp
index ed98653..2115bdc 100644
--- a/src/irc/irc_client.cpp
+++ b/src/irc/irc_client.cpp
@@ -118,6 +118,11 @@ void IrcClient::send_kick_command(const std::string& chan_name, const std::strin
   this->send_message(IrcMessage("KICK", {chan_name, target, reason}));
 }
 
+void IrcClient::send_quit_command()
+{
+  this->send_message(IrcMessage("QUIT", {"gateway shutdown"}));
+}
+
 void IrcClient::send_join_command(const std::string& chan_name)
 {
   if (this->welcomed == false)
@@ -310,6 +315,21 @@ void IrcClient::on_part(const IrcMessage& message)
     }
 }
 
+void IrcClient::on_error(const IrcMessage& message)
+{
+  const std::string leave_message = message.arguments[0];
+  // The user is out of all the channels
+  for (auto it = this->channels.begin(); it != this->channels.end(); ++it)
+  {
+    Iid iid;
+    iid.chan = it->first;
+    iid.server = this->hostname;
+    IrcChannel* channel = it->second.get();
+    std::string own_nick = channel->get_self()->nick;
+    this->bridge->send_muc_leave(std::move(iid), std::move(own_nick), leave_message, true);
+  }
+}
+
 void IrcClient::on_quit(const IrcMessage& message)
 {
   std::string txt;
diff --git a/src/irc/irc_client.hpp b/src/irc/irc_client.hpp
index 4749cac..4038cdf 100644
--- a/src/irc/irc_client.hpp
+++ b/src/irc/irc_client.hpp
@@ -95,6 +95,10 @@ public:
    * Send the KICK irc command
    */
   void send_kick_command(const std::string& chan_name, const std::string& target, const std::string& reason);
+  /**
+   * Send the QUIT irc command
+   */
+  void send_quit_command();
   /**
    * Forward the server message received from IRC to the XMPP component
    */
@@ -139,6 +143,7 @@ public:
    */
   void on_welcome_message(const IrcMessage& message);
   void on_part(const IrcMessage& message);
+  void on_error(const IrcMessage& message);
   void on_nick(const IrcMessage& message);
   void on_kick(const IrcMessage& message);
   void on_mode(const IrcMessage& message);
@@ -216,6 +221,7 @@ static const std::unordered_map<std::string, irc_callback_t> irc_callbacks = {
   {"366", &IrcClient::on_channel_completely_joined},
   {"001", &IrcClient::on_welcome_message},
   {"PART", &IrcClient::on_part},
+  {"ERROR", &IrcClient::on_error},
   {"QUIT", &IrcClient::on_quit},
   {"NICK", &IrcClient::on_nick},
   {"MODE", &IrcClient::on_mode},
diff --git a/src/main.cpp b/src/main.cpp
index 2da180b..6c9560c 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -1,9 +1,19 @@
 #include <xmpp/xmpp_component.hpp>
 #include <network/poller.hpp>
 #include <config/config.hpp>
+#include <logger/logger.hpp>
 
 #include <iostream>
 #include <memory>
+#include <atomic>
+
+#include <signal.h>
+
+// A flag set by the SIGINT signal handler.
+volatile std::atomic<bool> stop(false);
+// A flag indicating that we are wanting to exit the process. i.e: if this
+// flag is set and all connections are closed, we can exit properly.
+static bool exiting = false;
 
 /**
  * Provide an helpful message to help the user write a minimal working
@@ -20,6 +30,11 @@ int config_help(const std::string& missing_option)
   return 1;
 }
 
+static void sigint_handler(int, siginfo_t*, void*)
+{
+  stop = true;
+}
+
 int main(int ac, char** av)
 {
   if (ac > 1)
@@ -44,8 +59,40 @@ int main(int ac, char** av)
 
   Poller p;
   p.add_socket_handler(xmpp_component);
-  xmpp_component->start();
-  while (p.poll())
-    ;
+  if (!xmpp_component->start())
+  {
+    log_info("Exiting");
+    return -1;
+  }
+
+  // Install the signals used to exit the process cleanly, or reload the
+  // config
+  sigset_t mask;
+  sigemptyset(&mask);
+  struct sigaction on_sig;
+  on_sig.sa_sigaction = &sigint_handler;
+  on_sig.sa_mask = mask;
+  // we want to catch that signal only once.
+  // Sending SIGINT again will "force" an exit
+  on_sig.sa_flags = SA_RESETHAND;
+  sigaction(SIGINT, &on_sig, nullptr);
+  sigaction(SIGTERM, &on_sig, nullptr);
+
+  const std::chrono::milliseconds timeout(-1);
+  while (p.poll(timeout) != -1 || !exiting)
+  {
+    if (stop)
+    {
+      log_info("Signal received, exiting...");
+      exiting = true;
+      stop = false;
+      xmpp_component->shutdown();
+    }
+    // If the only existing connection is the one to the XMPP component:
+    // close the XMPP stream.
+    if (exiting && p.size() == 1 && xmpp_component->is_document_open())
+      xmpp_component->close_document();
+  }
+  log_info("All connection cleanely closed, have a nice day.");
   return 0;
 }
diff --git a/src/network/poller.cpp b/src/network/poller.cpp
index 71c7172..919ceb0 100644
--- a/src/network/poller.cpp
+++ b/src/network/poller.cpp
@@ -9,7 +9,6 @@
 Poller::Poller()
 {
 #if POLLER == POLL
-  memset(this->fds, 0, sizeof(this->fds));
   this->nfds = 0;
 #elif POLLER == EPOLL
   this->epfd = ::epoll_create1(0);
@@ -42,9 +41,7 @@ void Poller::add_socket_handler(std::shared_ptr<SocketHandler> socket_handler)
   this->nfds++;
 #endif
 #if POLLER == EPOLL
-  struct epoll_event event;
-  event.data.ptr = socket_handler.get();
-  event.events = EPOLLIN;
+  struct epoll_event event = {EPOLLIN, {socket_handler.get()}};
   const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_ADD, socket_handler->get_socket(), &event);
   if (res == -1)
     {
@@ -99,9 +96,7 @@ void Poller::watch_send_events(SocketHandler* socket_handler)
     }
   throw std::runtime_error("Cannot watch a non-registered socket for send events");
 #elif POLLER == EPOLL
-  struct epoll_event event;
-  event.data.ptr = socket_handler;
-  event.events = EPOLLIN|EPOLLOUT;
+  struct epoll_event event = {EPOLLIN|EPOLLOUT, {socket_handler}};
   const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_MOD, socket_handler->get_socket(), &event);
   if (res == -1)
     {
@@ -124,9 +119,7 @@ void Poller::stop_watching_send_events(SocketHandler* socket_handler)
     }
   throw std::runtime_error("Cannot watch a non-registered socket for send events");
 #elif POLLER == EPOLL
-  struct epoll_event event;
-  event.data.ptr = socket_handler;
-  event.events = EPOLLIN;
+  struct epoll_event event = {EPOLLIN, {socket_handler}};
   const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_MOD, socket_handler->get_socket(), &event);
   if (res == -1)
     {
@@ -136,21 +129,23 @@ void Poller::stop_watching_send_events(SocketHandler* socket_handler)
 #endif
 }
 
-bool Poller::poll()
+int Poller::poll(const std::chrono::milliseconds& timeout)
 {
+  if (this->socket_handlers.size() == 0)
+    return -1;
 #if POLLER == POLL
-  if (this->nfds == 0)
-    return false;
-  int res = ::poll(this->fds, this->nfds, -1);
-  if (res < 0)
+  int nb_events = ::poll(this->fds, this->nfds, timeout.count());
+  if (nb_events < 0)
     {
+      if (errno == EINTR)
+        return true;
       perror("poll");
       throw std::runtime_error("Poll failed");
     }
   // We cannot possibly have more ready events than the number of fds we are
   // watching
-  assert(static_cast<unsigned int>(res) <= this->nfds);
-  for (size_t i = 0; i <= this->nfds && res != 0; ++i)
+  assert(static_cast<unsigned int>(nb_events) <= this->nfds);
+  for (size_t i = 0; i <= this->nfds && nb_events != 0; ++i)
     {
       if (this->fds[i].revents == 0)
         continue;
@@ -158,21 +153,24 @@ bool Poller::poll()
         {
           auto socket_handler = this->socket_handlers.at(this->fds[i].fd);
           socket_handler->on_recv();
-          res--;
+          nb_events--;
         }
       else if (this->fds[i].revents & POLLOUT)
         {
           auto socket_handler = this->socket_handlers.at(this->fds[i].fd);
           socket_handler->on_send();
-          res--;
+          nb_events--;
         }
     }
+  return 1;
 #elif POLLER == EPOLL
   static const size_t max_events = 12;
   struct epoll_event revents[max_events];
-  const int nb_events = epoll_wait(this->epfd, revents, max_events, -1);
+  const int nb_events = ::epoll_wait(this->epfd, revents, max_events, timeout.count());
   if (nb_events == -1)
     {
+      if (errno == EINTR)
+        return 0;
       perror("epoll_wait");
       throw std::runtime_error("Epoll_wait failed");
     }
@@ -184,6 +182,11 @@ bool Poller::poll()
       if (revents[i].events & EPOLLOUT)
         socket_handler->on_send();
     }
+  return nb_events;
 #endif
-  return true;
+}
+
+size_t Poller::size() const
+{
+  return this->socket_handlers.size();
 }
diff --git a/src/network/poller.hpp b/src/network/poller.hpp
index fe52fda..dc087a2 100644
--- a/src/network/poller.hpp
+++ b/src/network/poller.hpp
@@ -5,6 +5,7 @@
 
 #include <unordered_map>
 #include <memory>
+#include <chrono>
 
 #define POLL 1
 #define EPOLL 2
@@ -58,10 +59,17 @@ public:
   void stop_watching_send_events(SocketHandler* socket_handler);
   /**
    * Wait for all watched events, and call the SocketHandlers' callbacks
-   * when one is ready.
-   * Returns false if there are 0 SocketHandler in the list.
+   * when one is ready.  Returns if nothing happened before the provided
+   * timeout.  If the timeout is 0, it waits forever.  If there is no
+   * watched event, returns -1 immediately, ignoring the timeout value.
+   * Otherwise, returns the number of event handled. If 0 is returned this
+   * means that we were interrupted by a signal, or the timeout occured.
    */
-  bool poll();
+  int poll(const std::chrono::milliseconds& timeout);
+  /**
+   * Returns the number of SocketHandlers managed by the poller.
+   */
+  size_t size() const;
 
 private:
   /**
diff --git a/src/xmpp/xmpp_component.cpp b/src/xmpp/xmpp_component.cpp
index 433f87a..dc77934 100644
--- a/src/xmpp/xmpp_component.cpp
+++ b/src/xmpp/xmpp_component.cpp
@@ -24,7 +24,8 @@
 XmppComponent::XmppComponent(const std::string& hostname, const std::string& secret):
   served_hostname(hostname),
   secret(secret),
-  authenticated(false)
+  authenticated(false),
+  doc_open(false)
 {
   this->parser.add_stream_open_callback(std::bind(&XmppComponent::on_remote_stream_open, this,
                                                   std::placeholders::_1));
@@ -51,6 +52,11 @@ bool XmppComponent::start()
   return this->connect("127.0.0.1", "5347");
 }
 
+bool XmppComponent::is_document_open() const
+{
+  return this->doc_open;
+}
+
 void XmppComponent::send_stanza(const Stanza& stanza)
 {
   std::string str = stanza.to_string();
@@ -66,6 +72,7 @@ void XmppComponent::on_connected()
   node["xmlns:stream"] = STREAM_NS;
   node["to"] = this->served_hostname;
   this->send_stanza(node);
+  this->doc_open = true;
 }
 
 void XmppComponent::on_connection_close()
@@ -79,6 +86,14 @@ void XmppComponent::parse_in_buffer()
   this->in_buf.clear();
 }
 
+void XmppComponent::shutdown()
+{
+  for (auto it = this->bridges.begin(); it != this->bridges.end(); ++it)
+  {
+    it->second->shutdown();
+  }
+}
+
 void XmppComponent::on_remote_stream_open(const XmlNode& node)
 {
   log_debug("XMPP DOCUMENT OPEN: " << node.to_string());
@@ -145,6 +160,7 @@ void XmppComponent::close_document()
 {
   log_debug("XMPP SENDING: </stream:stream>");
   this->send_data("</stream:stream>");
+  this->doc_open = false;
 }
 
 void XmppComponent::handle_handshake(const Stanza& stanza)
diff --git a/src/xmpp/xmpp_component.hpp b/src/xmpp/xmpp_component.hpp
index 1a7fc6b..1952e19 100644
--- a/src/xmpp/xmpp_component.hpp
+++ b/src/xmpp/xmpp_component.hpp
@@ -23,7 +23,14 @@ public:
   void on_connected() override final;
   void on_connection_close() override final;
   void parse_in_buffer() override final;
-
+  /**
+   * Send a "close" message to all our connected peers.  That message
+   * depends on the protocol used (this may be a QUIT irc message, or a
+   * <stream/>, etc).  We may also directly close the connection, or we may
+   * wait for the remote peer to acknowledge it before closing.
+   */
+  void shutdown();
+  bool is_document_open() const;
   /**
    * Connect to the XMPP server.
    * Returns false if we failed to connect
@@ -115,6 +122,10 @@ private:
   std::string served_hostname;
   std::string secret;
   bool authenticated;
+  /**
+   * Whether or not OUR XMPP document is open
+   */
+  bool doc_open;
 
   std::unordered_map<std::string, std::function<void(const Stanza&)>> stanza_handlers;
 
-- 
cgit v1.2.3