summaryrefslogtreecommitdiff
path: root/src/network/poller.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/network/poller.cpp')
-rw-r--r--src/network/poller.cpp238
1 files changed, 238 insertions, 0 deletions
diff --git a/src/network/poller.cpp b/src/network/poller.cpp
new file mode 100644
index 0000000..0f02cc5
--- /dev/null
+++ b/src/network/poller.cpp
@@ -0,0 +1,238 @@
+#include <network/poller.hpp>
+#include <logger/logger.hpp>
+#include <utils/timed_events.hpp>
+
+#include <cassert>
+#include <cerrno>
+#include <stdio.h>
+#include <signal.h>
+#include <unistd.h>
+
+#include <cstring>
+#include <iostream>
+#include <stdexcept>
+
+Poller::Poller()
+{
+#if POLLER == POLL
+ this->nfds = 0;
+#elif POLLER == EPOLL
+ this->epfd = ::epoll_create1(0);
+ if (this->epfd == -1)
+ {
+ log_error("epoll failed: ", strerror(errno));
+ throw std::runtime_error("Could not create epoll instance");
+ }
+#endif
+}
+
+Poller::~Poller()
+{
+#if POLLER == EPOLL
+ if (this->epfd > 0)
+ ::close(this->epfd);
+#endif
+}
+
+void Poller::add_socket_handler(SocketHandler* socket_handler)
+{
+ // Don't do anything if the socket is already managed
+ const auto it = this->socket_handlers.find(socket_handler->get_socket());
+ if (it != this->socket_handlers.end())
+ return ;
+
+ this->socket_handlers.emplace(socket_handler->get_socket(), socket_handler);
+
+ // We always watch all sockets for receive events
+#if POLLER == POLL
+ this->fds[this->nfds].fd = socket_handler->get_socket();
+ this->fds[this->nfds].events = POLLIN;
+ this->nfds++;
+#endif
+#if POLLER == EPOLL
+ struct epoll_event event = {EPOLLIN, {socket_handler}};
+ const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_ADD, socket_handler->get_socket(), &event);
+ if (res == -1)
+ {
+ log_error("epoll_ctl failed: ", strerror(errno));
+ throw std::runtime_error("Could not add socket to epoll");
+ }
+#endif
+}
+
+void Poller::remove_socket_handler(const socket_t socket)
+{
+ const auto it = this->socket_handlers.find(socket);
+ if (it == this->socket_handlers.end())
+ throw std::runtime_error("Trying to remove a SocketHandler that is not managed");
+ this->socket_handlers.erase(it);
+
+#if POLLER == POLL
+ for (size_t i = 0; i < this->nfds; i++)
+ {
+ if (this->fds[i].fd == socket)
+ {
+ // Move all subsequent pollfd by one on the left, erasing the
+ // value of the one we remove
+ for (size_t j = i; j < this->nfds - 1; ++j)
+ {
+ this->fds[j].fd = this->fds[j+1].fd;
+ this->fds[j].events= this->fds[j+1].events;
+ }
+ this->nfds--;
+ }
+ }
+#elif POLLER == EPOLL
+ const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_DEL, socket, nullptr);
+ if (res == -1)
+ {
+ log_error("epoll_ctl failed: ", strerror(errno));
+ throw std::runtime_error("Could not remove socket from epoll");
+ }
+#endif
+}
+
+void Poller::watch_send_events(SocketHandler* socket_handler)
+{
+#if POLLER == POLL
+ for (size_t i = 0; i < this->nfds; ++i)
+ {
+ if (this->fds[i].fd == socket_handler->get_socket())
+ {
+ this->fds[i].events = POLLIN|POLLOUT;
+ return;
+ }
+ }
+ throw std::runtime_error("Cannot watch a non-registered socket for send events");
+#elif POLLER == EPOLL
+ struct epoll_event event = {EPOLLIN|EPOLLOUT, {socket_handler}};
+ const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_MOD, socket_handler->get_socket(), &event);
+ if (res == -1)
+ {
+ log_error("epoll_ctl failed: ", strerror(errno));
+ throw std::runtime_error("Could not modify socket flags in epoll");
+ }
+#endif
+}
+
+void Poller::stop_watching_send_events(SocketHandler* socket_handler)
+{
+#if POLLER == POLL
+ for (size_t i = 0; i <= this->nfds; ++i)
+ {
+ if (this->fds[i].fd == socket_handler->get_socket())
+ {
+ this->fds[i].events = POLLIN;
+ return;
+ }
+ }
+ throw std::runtime_error("Cannot watch a non-registered socket for send events");
+#elif POLLER == EPOLL
+ struct epoll_event event = {EPOLLIN, {socket_handler}};
+ const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_MOD, socket_handler->get_socket(), &event);
+ if (res == -1)
+ {
+ log_error("epoll_ctl failed: ", strerror(errno));
+ throw std::runtime_error("Could not modify socket flags in epoll");
+ }
+#endif
+}
+
+int Poller::poll(const std::chrono::milliseconds& timeout)
+{
+ if (this->socket_handlers.empty() && timeout == utils::no_timeout)
+ return -1;
+#if POLLER == POLL
+ // Convert our nice timeout into this ugly struct
+ struct timespec timeout_ts;
+ struct timespec* timeout_tsp;
+ if (timeout > 0s)
+ {
+ auto seconds = std::chrono::duration_cast<std::chrono::seconds>(timeout);
+ timeout_ts.tv_sec = seconds.count();
+ timeout_ts.tv_nsec = std::chrono::duration_cast<std::chrono::nanoseconds>(timeout - seconds).count();
+ timeout_tsp = &timeout_ts;
+ }
+ else
+ timeout_tsp = nullptr;
+
+ // Unblock all signals, only during the ppoll call
+ sigset_t empty_signal_set;
+ sigemptyset(&empty_signal_set);
+ int nb_events = ::ppoll(this->fds, this->nfds, timeout_tsp,
+ &empty_signal_set);
+ if (nb_events < 0)
+ {
+ if (errno == EINTR)
+ return true;
+ log_error("poll failed: ", strerror(errno));
+ throw std::runtime_error("Poll failed");
+ }
+ // We cannot possibly have more ready events than the number of fds we are
+ // watching
+ assert(static_cast<unsigned int>(nb_events) <= this->nfds);
+ for (size_t i = 0; i < this->nfds && nb_events != 0; ++i)
+ {
+ auto socket_handler = this->socket_handlers.at(this->fds[i].fd);
+ if (this->fds[i].revents == 0)
+ continue;
+ else if (this->fds[i].revents & POLLIN && socket_handler->is_connected())
+ {
+ socket_handler->on_recv();
+ nb_events--;
+ }
+ else if (this->fds[i].revents & POLLOUT && socket_handler->is_connected())
+ {
+ socket_handler->on_send();
+ nb_events--;
+ }
+ else if (this->fds[i].revents & POLLOUT ||
+ this->fds[i].revents & POLLIN)
+ {
+ socket_handler->connect();
+ nb_events--;
+ }
+ }
+ return 1;
+#elif POLLER == EPOLL
+ static const size_t max_events = 12;
+ struct epoll_event revents[max_events];
+ // Unblock all signals, only during the epoll_pwait call
+ sigset_t empty_signal_set{};
+ sigemptyset(&empty_signal_set);
+
+ int real_timeout = std::numeric_limits<int>::max();
+ if (timeout.count() < real_timeout) // Just avoid any potential int overflow
+ real_timeout = static_cast<int>(timeout.count());
+ const int nb_events = ::epoll_pwait(this->epfd, revents, max_events, real_timeout,
+ &empty_signal_set);
+ if (nb_events == -1)
+ {
+ if (errno == EINTR)
+ return 0;
+ log_error("epoll wait: ", strerror(errno));
+ throw std::runtime_error("Epoll_wait failed");
+ }
+ for (int i = 0; i < nb_events; ++i)
+ {
+ auto socket_handler = static_cast<SocketHandler*>(revents[i].data.ptr);
+ if (revents[i].events & EPOLLIN && socket_handler->is_connected())
+ socket_handler->on_recv();
+ else if (revents[i].events & EPOLLOUT && socket_handler->is_connected())
+ socket_handler->on_send();
+ else if (revents[i].events & EPOLLOUT)
+ socket_handler->connect();
+ }
+ return nb_events;
+#endif
+}
+
+size_t Poller::size() const
+{
+ return this->socket_handlers.size();
+}
+
+bool Poller::is_managing_socket(const socket_t socket) const
+{
+ return (this->socket_handlers.find(socket) != this->socket_handlers.end());
+}