summaryrefslogtreecommitdiff
path: root/louloulibs/network/tcp_socket_handler.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'louloulibs/network/tcp_socket_handler.cpp')
m---------louloulibs0
-rw-r--r--louloulibs/network/tcp_socket_handler.cpp590
2 files changed, 590 insertions, 0 deletions
diff --git a/louloulibs b/louloulibs
deleted file mode 160000
-Subproject 0f3c1183e2bf0941ae2bffd3f31577bce4f3001
diff --git a/louloulibs/network/tcp_socket_handler.cpp b/louloulibs/network/tcp_socket_handler.cpp
new file mode 100644
index 0000000..f647b86
--- /dev/null
+++ b/louloulibs/network/tcp_socket_handler.cpp
@@ -0,0 +1,590 @@
+#include <network/tcp_socket_handler.hpp>
+#include <network/dns_handler.hpp>
+
+#include <utils/timed_events.hpp>
+#include <utils/scopeguard.hpp>
+#include <network/poller.hpp>
+
+#include <logger/logger.hpp>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <stdexcept>
+#include <unistd.h>
+#include <errno.h>
+#include <netdb.h>
+#include <cstring>
+#include <fcntl.h>
+
+#include <iostream>
+
+#ifdef BOTAN_FOUND
+# include <botan/hex.h>
+
+Botan::AutoSeeded_RNG TCPSocketHandler::rng;
+Permissive_Credentials_Manager TCPSocketHandler::credential_manager;
+Botan::TLS::Policy TCPSocketHandler::policy;
+Botan::TLS::Session_Manager_In_Memory TCPSocketHandler::session_manager(TCPSocketHandler::rng);
+
+#endif
+
+#ifndef UIO_FASTIOV
+# define UIO_FASTIOV 8
+#endif
+
+using namespace std::string_literals;
+using namespace std::chrono_literals;
+
+namespace ph = std::placeholders;
+
+TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller):
+ SocketHandler(poller, -1),
+ use_tls(false),
+ connected(false),
+ connecting(false)
+#ifdef CARES_FOUND
+ ,resolving(false),
+ resolved(false),
+ resolved4(false),
+ resolved6(false),
+ cares_addrinfo(nullptr),
+ cares_error()
+#endif
+{}
+
+TCPSocketHandler::~TCPSocketHandler()
+{
+#ifdef CARES_FOUND
+ this->free_cares_addrinfo();
+#endif
+}
+
+void TCPSocketHandler::init_socket(const struct addrinfo* rp)
+{
+ if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
+ throw std::runtime_error("Could not create socket: "s + strerror(errno));
+ int optval = 1;
+ if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1)
+ log_warning("Failed to enable TCP keepalive on socket: " << strerror(errno));
+ // Set the socket on non-blocking mode. This is useful to receive a EAGAIN
+ // error when connect() would block, to not block the whole process if a
+ // remote is not responsive.
+ const int existing_flags = ::fcntl(this->socket, F_GETFL, 0);
+ if ((existing_flags == -1) ||
+ (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1))
+ throw std::runtime_error("Could not initialize socket: "s + strerror(errno));
+}
+
+void TCPSocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
+{
+ this->address = address;
+ this->port = port;
+ this->use_tls = tls;
+
+ utils::ScopeGuard sg;
+
+ struct addrinfo* addr_res;
+
+ if (!this->connecting)
+ {
+ // Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if
+ // this is the first call of this function.
+#ifdef CARES_FOUND
+ if (!this->resolved)
+ {
+ log_info("Trying to connect to " << address << ":" << port);
+ // Start the asynchronous process of resolving the hostname. Once
+ // the addresses have been found and `resolved` has been set to true
+ // (but connecting will still be false), TCPSocketHandler::connect()
+ // needs to be called, again.
+ this->resolving = true;
+ DNSHandler::instance.gethostbyname(address, this, AF_INET6);
+ DNSHandler::instance.gethostbyname(address, this, AF_INET);
+ return;
+ }
+ else
+ {
+ // The c-ares resolved the hostname and the available addresses
+ // where saved in the cares_addrinfo linked list. Now, just use
+ // this list to try to connect.
+ addr_res = this->cares_addrinfo;
+ if (!addr_res)
+ {
+ this->close();
+ this->on_connection_failed(this->cares_error);
+ return ;
+ }
+ }
+#else
+ log_info("Trying to connect to " << address << ":" << port);
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(struct addrinfo));
+ hints.ai_flags = 0;
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = 0;
+
+ const int res = ::getaddrinfo(address.c_str(), port.c_str(), &hints, &addr_res);
+
+ if (res != 0)
+ {
+ log_warning("getaddrinfo failed: "s + gai_strerror(res));
+ this->close();
+ this->on_connection_failed(gai_strerror(res));
+ return ;
+ }
+ // Make sure the alloced structure is always freed at the end of the
+ // function
+ sg.add_callback([&addr_res](){ freeaddrinfo(addr_res); });
+#endif
+ }
+ else
+ { // This function is called again, use the saved addrinfo structure,
+ // instead of re-doing the whole getaddrinfo process.
+ addr_res = &this->addrinfo;
+ }
+
+ for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
+ {
+ if (!this->connecting)
+ {
+ try {
+ this->init_socket(rp);
+ }
+ catch (const std::runtime_error& error) {
+ log_error("Failed to init socket: " << error.what());
+ break;
+ }
+ }
+ if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0
+ || errno == EISCONN)
+ {
+ log_info("Connection success.");
+ TimedEventsManager::instance().cancel("connection_timeout"s +
+ std::to_string(this->socket));
+ this->poller->add_socket_handler(this);
+ this->connected = true;
+ this->connecting = false;
+#ifdef BOTAN_FOUND
+ if (this->use_tls)
+ this->start_tls();
+#endif
+ this->on_connected();
+ return ;
+ }
+ else if (errno == EINPROGRESS || errno == EALREADY)
+ { // retry this process later, when the socket
+ // is ready to be written on.
+ this->connecting = true;
+ this->poller->add_socket_handler(this);
+ this->poller->watch_send_events(this);
+ // Save the addrinfo structure, to use it on the next call
+ this->ai_addrlen = rp->ai_addrlen;
+ memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
+ memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
+ this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr);
+ this->addrinfo.ai_next = nullptr;
+ // If the connection has not succeeded or failed in 5s, we consider
+ // it to have failed
+ TimedEventsManager::instance().add_event(
+ TimedEvent(std::chrono::steady_clock::now() + 5s,
+ std::bind(&TCPSocketHandler::on_connection_timeout, this),
+ "connection_timeout"s + std::to_string(this->socket)));
+ return ;
+ }
+ log_info("Connection failed:" << strerror(errno));
+ }
+ log_error("All connection attempts failed.");
+ this->close();
+ this->on_connection_failed(strerror(errno));
+ return ;
+}
+
+void TCPSocketHandler::on_connection_timeout()
+{
+ this->close();
+ this->on_connection_failed("connection timed out");
+}
+
+void TCPSocketHandler::connect()
+{
+ this->connect(this->address, this->port, this->use_tls);
+}
+
+void TCPSocketHandler::on_recv()
+{
+#ifdef BOTAN_FOUND
+ if (this->use_tls)
+ this->tls_recv();
+ else
+#endif
+ this->plain_recv();
+}
+
+void TCPSocketHandler::plain_recv()
+{
+ static constexpr size_t buf_size = 4096;
+ char buf[buf_size];
+ void* recv_buf = this->get_receive_buffer(buf_size);
+
+ if (recv_buf == nullptr)
+ recv_buf = buf;
+
+ const ssize_t size = this->do_recv(recv_buf, buf_size);
+
+ if (size > 0)
+ {
+ if (buf == recv_buf)
+ {
+ // data needs to be placed in the in_buf string, because no buffer
+ // was provided to receive that data directly. The in_buf buffer
+ // will be handled in parse_in_buffer()
+ this->in_buf += std::string(buf, size);
+ }
+ this->parse_in_buffer(size);
+ }
+}
+
+ssize_t TCPSocketHandler::do_recv(void* recv_buf, const size_t buf_size)
+{
+ ssize_t size = ::recv(this->socket, recv_buf, buf_size, 0);
+ if (0 == size)
+ {
+ this->on_connection_close("");
+ this->close();
+ }
+ else if (-1 == size)
+ {
+ log_warning("Error while reading from socket: " << strerror(errno));
+ // Remember if we were connecting, or already connected when this
+ // happened, because close() sets this->connecting to false
+ const auto were_connecting = this->connecting;
+ this->close();
+ if (were_connecting)
+ this->on_connection_failed(strerror(errno));
+ else
+ this->on_connection_close(strerror(errno));
+ }
+ return size;
+}
+
+void TCPSocketHandler::on_send()
+{
+ struct iovec msg_iov[UIO_FASTIOV] = {};
+ struct msghdr msg{nullptr, 0,
+ msg_iov,
+ 0, nullptr, 0, 0};
+ for (std::string& s: this->out_buf)
+ {
+ // unconsting the content of s is ok, sendmsg will never modify it
+ msg_iov[msg.msg_iovlen].iov_base = const_cast<char*>(s.data());
+ msg_iov[msg.msg_iovlen].iov_len = s.size();
+ if (++msg.msg_iovlen == UIO_FASTIOV)
+ break;
+ }
+ ssize_t res = ::sendmsg(this->socket, &msg, MSG_NOSIGNAL);
+ if (res < 0)
+ {
+ log_error("sendmsg failed: " << strerror(errno));
+ this->on_connection_close(strerror(errno));
+ this->close();
+ }
+ else
+ {
+ // remove all the strings that were successfully sent.
+ for (auto it = this->out_buf.begin();
+ it != this->out_buf.end();)
+ {
+ if (static_cast<size_t>(res) >= (*it).size())
+ {
+ res -= (*it).size();
+ it = this->out_buf.erase(it);
+ }
+ else
+ {
+ // If one string has partially been sent, we use substr to
+ // crop it
+ if (res > 0)
+ (*it) = (*it).substr(res, std::string::npos);
+ break;
+ }
+ }
+ if (this->out_buf.empty())
+ this->poller->stop_watching_send_events(this);
+ }
+}
+
+void TCPSocketHandler::close()
+{
+ TimedEventsManager::instance().cancel("connection_timeout"s +
+ std::to_string(this->socket));
+ if (this->connected || this->connecting)
+ this->poller->remove_socket_handler(this->get_socket());
+ if (this->socket != -1)
+ {
+ ::close(this->socket);
+ this->socket = -1;
+ }
+ this->connected = false;
+ this->connecting = false;
+#ifdef CARES_FOUND
+ this->resolving = false;
+ this->resolved = false;
+ this->resolved4 = false;
+ this->resolved6 = false;
+ this->free_cares_addrinfo();
+ this->cares_error.clear();
+#endif
+ this->in_buf.clear();
+ this->out_buf.clear();
+ this->port.clear();
+}
+
+void TCPSocketHandler::send_data(std::string&& data)
+{
+#ifdef BOTAN_FOUND
+ if (this->use_tls)
+ this->tls_send(std::move(data));
+ else
+#endif
+ this->raw_send(std::move(data));
+}
+
+void TCPSocketHandler::raw_send(std::string&& data)
+{
+ if (data.empty())
+ return ;
+ this->out_buf.emplace_back(std::move(data));
+ if (this->connected)
+ this->poller->watch_send_events(this);
+}
+
+void TCPSocketHandler::send_pending_data()
+{
+ if (this->connected && !this->out_buf.empty())
+ this->poller->watch_send_events(this);
+}
+
+bool TCPSocketHandler::is_connected() const
+{
+ return this->connected;
+}
+
+bool TCPSocketHandler::is_connecting() const
+{
+#ifdef CARES_FOUND
+ return this->connecting || this->resolving;
+#else
+ return this->connecting;
+#endif
+}
+
+void* TCPSocketHandler::get_receive_buffer(const size_t) const
+{
+ return nullptr;
+}
+
+#ifdef BOTAN_FOUND
+void TCPSocketHandler::start_tls()
+{
+ Botan::TLS::Server_Information server_info(this->address, "irc", std::stoul(this->port));
+ this->tls = std::make_unique<Botan::TLS::Client>(
+ std::bind(&TCPSocketHandler::tls_output_fn, this, ph::_1, ph::_2),
+ std::bind(&TCPSocketHandler::tls_data_cb, this, ph::_1, ph::_2),
+ std::bind(&TCPSocketHandler::tls_alert_cb, this, ph::_1, ph::_2, ph::_3),
+ std::bind(&TCPSocketHandler::tls_handshake_cb, this, ph::_1),
+ session_manager, credential_manager, policy,
+ rng, server_info, Botan::TLS::Protocol_Version::latest_tls_version());
+}
+
+void TCPSocketHandler::tls_recv()
+{
+ static constexpr size_t buf_size = 4096;
+ char recv_buf[buf_size];
+
+ const ssize_t size = this->do_recv(recv_buf, buf_size);
+ if (size > 0)
+ {
+ const bool was_active = this->tls->is_active();
+ this->tls->received_data(reinterpret_cast<const Botan::byte*>(recv_buf),
+ static_cast<size_t>(size));
+ if (!was_active && this->tls->is_active())
+ this->on_tls_activated();
+ }
+}
+
+void TCPSocketHandler::tls_send(std::string&& data)
+{
+ if (this->tls->is_active())
+ {
+ const bool was_active = this->tls->is_active();
+ if (!this->pre_buf.empty())
+ {
+ this->tls->send(reinterpret_cast<const Botan::byte*>(this->pre_buf.data()),
+ this->pre_buf.size());
+ this->pre_buf = "";
+ }
+ if (!data.empty())
+ this->tls->send(reinterpret_cast<const Botan::byte*>(data.data()),
+ data.size());
+ if (!was_active && this->tls->is_active())
+ this->on_tls_activated();
+ }
+ else
+ this->pre_buf += data;
+}
+
+void TCPSocketHandler::tls_data_cb(const Botan::byte* data, size_t size)
+{
+ this->in_buf += std::string(reinterpret_cast<const char*>(data),
+ size);
+ if (!this->in_buf.empty())
+ this->parse_in_buffer(size);
+}
+
+void TCPSocketHandler::tls_output_fn(const Botan::byte* data, size_t size)
+{
+ this->raw_send(std::string(reinterpret_cast<const char*>(data), size));
+}
+
+void TCPSocketHandler::tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t)
+{
+ log_debug("tls_alert: " << alert.type_string());
+}
+
+bool TCPSocketHandler::tls_handshake_cb(const Botan::TLS::Session& session)
+{
+ log_debug("Handshake with " << session.server_info().hostname() << " complete."
+ << " Version: " << session.version().to_string()
+ << " using " << session.ciphersuite().to_string());
+ if (!session.session_id().empty())
+ log_debug("Session ID " << Botan::hex_encode(session.session_id()));
+ if (!session.session_ticket().empty())
+ log_debug("Session ticket " << Botan::hex_encode(session.session_ticket()));
+ return true;
+}
+
+void TCPSocketHandler::on_tls_activated()
+{
+ this->send_data("");
+}
+
+void Permissive_Credentials_Manager::verify_certificate_chain(const std::string& type,
+ const std::string& purported_hostname,
+ const std::vector<Botan::X509_Certificate>&)
+{ // TODO: Offer the admin to disallow connection on untrusted
+ // certificates
+ log_debug("Checking remote certificate (" << type << ") for hostname " << purported_hostname);
+}
+
+#endif // BOTAN_FOUND
+
+#ifdef CARES_FOUND
+
+void TCPSocketHandler::on_hostname4_resolved(int status, struct hostent* hostent)
+{
+ this->resolved4 = true;
+ if (status == ARES_SUCCESS)
+ this->fill_ares_addrinfo4(hostent);
+ else
+ this->cares_error = ::ares_strerror(status);
+
+ if (this->resolved4 && this->resolved6)
+ {
+ this->resolved = true;
+ this->resolving = false;
+ this->connect();
+ }
+}
+
+void TCPSocketHandler::on_hostname6_resolved(int status, struct hostent* hostent)
+{
+ this->resolved6 = true;
+ if (status == ARES_SUCCESS)
+ this->fill_ares_addrinfo6(hostent);
+ else
+ this->cares_error = ::ares_strerror(status);
+
+ if (this->resolved4 && this->resolved6)
+ {
+ this->resolved = true;
+ this->resolving = false;
+ this->connect();
+ }
+}
+
+void TCPSocketHandler::fill_ares_addrinfo4(const struct hostent* hostent)
+{
+ struct addrinfo* prev = this->cares_addrinfo;
+ struct in_addr** address = reinterpret_cast<struct in_addr**>(hostent->h_addr_list);
+
+ while (*address)
+ {
+ // Create a new addrinfo list element, and fill it
+ struct addrinfo* current = new struct addrinfo;
+ current->ai_flags = 0;
+ current->ai_family = hostent->h_addrtype;
+ current->ai_socktype = SOCK_STREAM;
+ current->ai_protocol = 0;
+ current->ai_addrlen = sizeof(struct sockaddr_in);
+
+ struct sockaddr_in* addr = new struct sockaddr_in;
+ addr->sin_family = hostent->h_addrtype;
+ addr->sin_port = htons(strtoul(this->port.data(), nullptr, 10));
+ addr->sin_addr.s_addr = (*address)->s_addr;
+
+ current->ai_addr = reinterpret_cast<struct sockaddr*>(addr);
+ current->ai_next = nullptr;
+ current->ai_canonname = nullptr;
+
+ current->ai_next = prev;
+ this->cares_addrinfo = current;
+ prev = current;
+ ++address;
+ }
+}
+
+void TCPSocketHandler::fill_ares_addrinfo6(const struct hostent* hostent)
+{
+ struct addrinfo* prev = this->cares_addrinfo;
+ struct in6_addr** address = reinterpret_cast<struct in6_addr**>(hostent->h_addr_list);
+
+ while (*address)
+ {
+ // Create a new addrinfo list element, and fill it
+ struct addrinfo* current = new struct addrinfo;
+ current->ai_flags = 0;
+ current->ai_family = hostent->h_addrtype;
+ current->ai_socktype = SOCK_STREAM;
+ current->ai_protocol = 0;
+ current->ai_addrlen = sizeof(struct sockaddr_in6);
+
+ struct sockaddr_in6* addr = new struct sockaddr_in6;
+ addr->sin6_family = hostent->h_addrtype;
+ addr->sin6_port = htons(strtoul(this->port.data(), nullptr, 10));
+ ::memcpy(addr->sin6_addr.s6_addr, (*address)->s6_addr, 16);
+ addr->sin6_flowinfo = 0;
+ addr->sin6_scope_id = 0;
+
+ current->ai_addr = reinterpret_cast<struct sockaddr*>(addr);
+ current->ai_next = nullptr;
+ current->ai_canonname = nullptr;
+
+ current->ai_next = prev;
+ this->cares_addrinfo = current;
+ prev = current;
+ ++address;
+ }
+}
+
+void TCPSocketHandler::free_cares_addrinfo()
+{
+ while (this->cares_addrinfo)
+ {
+ delete this->cares_addrinfo->ai_addr;
+ auto next = this->cares_addrinfo->ai_next;
+ delete this->cares_addrinfo;
+ this->cares_addrinfo = next;
+ }
+}
+
+#endif // CARES_FOUND