diff options
Diffstat (limited to 'louloulibs/network')
-rw-r--r-- | louloulibs/network/credentials_manager.cpp | 35 | ||||
-rw-r--r-- | louloulibs/network/credentials_manager.hpp | 15 | ||||
-rw-r--r-- | louloulibs/network/dns_handler.cpp | 126 | ||||
-rw-r--r-- | louloulibs/network/dns_handler.hpp | 45 | ||||
-rw-r--r-- | louloulibs/network/dns_socket_handler.cpp | 34 | ||||
-rw-r--r-- | louloulibs/network/dns_socket_handler.hpp | 34 | ||||
-rw-r--r-- | louloulibs/network/resolver.cpp | 282 | ||||
-rw-r--r-- | louloulibs/network/resolver.hpp | 47 | ||||
-rw-r--r-- | louloulibs/network/socket_handler.hpp | 6 | ||||
-rw-r--r-- | louloulibs/network/tcp_client_socket_handler.cpp | 257 | ||||
-rw-r--r-- | louloulibs/network/tcp_client_socket_handler.hpp | 82 | ||||
-rw-r--r-- | louloulibs/network/tcp_server_socket.hpp | 70 | ||||
-rw-r--r-- | louloulibs/network/tcp_socket_handler.cpp | 301 | ||||
-rw-r--r-- | louloulibs/network/tcp_socket_handler.hpp | 152 |
14 files changed, 835 insertions, 651 deletions
diff --git a/louloulibs/network/credentials_manager.cpp b/louloulibs/network/credentials_manager.cpp index ed04d24..289307b 100644 --- a/louloulibs/network/credentials_manager.cpp +++ b/louloulibs/network/credentials_manager.cpp @@ -37,6 +37,28 @@ void BasicCredentialsManager::set_trusted_fingerprint(const std::string& fingerp this->trusted_fingerprint = fingerprint; } +const std::string& BasicCredentialsManager::get_trusted_fingerprint() const +{ + return this->trusted_fingerprint; +} + +void check_tls_certificate(const std::vector<Botan::X509_Certificate>& certs, + const std::string& hostname, const std::string& trusted_fingerprint, + std::exception_ptr exc) +{ + + if (!trusted_fingerprint.empty() && !certs.empty() && + trusted_fingerprint == certs[0].fingerprint() && + certs[0].matches_dns_name(hostname)) + // We trust the certificate, based on the trusted fingerprint and + // the fact that the hostname matches + return; + + if (exc) + std::rethrow_exception(exc); +} + +#if BOTAN_VERSION_CODE < BOTAN_VERSION_CODE_FOR(1,11,34) void BasicCredentialsManager::verify_certificate_chain(const std::string& type, const std::string& purported_hostname, const std::vector<Botan::X509_Certificate>& certs) @@ -50,17 +72,14 @@ void BasicCredentialsManager::verify_certificate_chain(const std::string& type, catch (const std::exception& tls_exception) { log_warning("TLS certificate check failed: ", tls_exception.what()); - if (!this->trusted_fingerprint.empty() && !certs.empty() && - this->trusted_fingerprint == certs[0].fingerprint() && - certs[0].matches_dns_name(purported_hostname)) - // We trust the certificate, based on the trusted fingerprint and - // the fact that the hostname matches - return; - + std::exception_ptr exception_ptr{}; if (this->socket_handler->abort_on_invalid_cert()) - throw; + exception_ptr = std::current_exception(); + + check_tls_certificate(certs, purported_hostname, this->trusted_fingerprint, exception_ptr); } } +#endif bool BasicCredentialsManager::try_to_open_one_ca_bundle(const std::vector<std::string>& paths) { diff --git a/louloulibs/network/credentials_manager.hpp b/louloulibs/network/credentials_manager.hpp index 7557372..29ee024 100644 --- a/louloulibs/network/credentials_manager.hpp +++ b/louloulibs/network/credentials_manager.hpp @@ -9,6 +9,18 @@ class TCPSocketHandler; +/** + * If the given cert isn’t valid, based on the given hostname + * and fingerprint, then throws the exception if it’s non-empty. + * + * Must be called after the standard (from Botan) way of + * checking the certificate, if we want to also accept certificates based + * on a trusted fingerprint. + */ +void check_tls_certificate(const std::vector<Botan::X509_Certificate>& certs, + const std::string& hostname, const std::string& trusted_fingerprint, + std::exception_ptr exc); + class BasicCredentialsManager: public Botan::Credentials_Manager { public: @@ -19,12 +31,15 @@ public: BasicCredentialsManager& operator=(const BasicCredentialsManager&) = delete; BasicCredentialsManager& operator=(BasicCredentialsManager&&) = delete; +#if BOTAN_VERSION_CODE < BOTAN_VERSION_CODE_FOR(1,11,34) void verify_certificate_chain(const std::string& type, const std::string& purported_hostname, const std::vector<Botan::X509_Certificate>&) override final; +#endif std::vector<Botan::Certificate_Store*> trusted_certificate_authorities(const std::string& type, const std::string& context) override final; void set_trusted_fingerprint(const std::string& fingerprint); + const std::string& get_trusted_fingerprint() const; private: const TCPSocketHandler* const socket_handler; diff --git a/louloulibs/network/dns_handler.cpp b/louloulibs/network/dns_handler.cpp index fef0cfc..fbd2763 100644 --- a/louloulibs/network/dns_handler.cpp +++ b/louloulibs/network/dns_handler.cpp @@ -1,5 +1,5 @@ #include <louloulibs.h> -#ifdef CARES_FOUND +#ifdef UDNS_FOUND #include <network/dns_socket_handler.hpp> #include <network/dns_handler.hpp> @@ -7,124 +7,40 @@ #include <utils/timed_events.hpp> -#include <algorithm> -#include <stdexcept> +#include <udns.h> +#include <cerrno> +#include <cstring> -DNSHandler DNSHandler::instance; +class Resolver; using namespace std::string_literals; -DNSHandler::DNSHandler(): - socket_handlers{}, - channel{nullptr} -{ - int ares_error; - if ((ares_error = ::ares_library_init(ARES_LIB_INIT_ALL)) != 0) - throw std::runtime_error("Failed to initialize c-ares lib: "s + ares_strerror(ares_error)); - struct ares_options options = {}; - // The default timeout values are way too high - options.timeout = 1000; - options.tries = 3; - if ((ares_error = ::ares_init_options(&this->channel, - &options, - ARES_OPT_TIMEOUTMS|ARES_OPT_TRIES)) != ARES_SUCCESS) - throw std::runtime_error("Failed to initialize c-ares channel: "s + ares_strerror(ares_error)); -} -ares_channel& DNSHandler::get_channel() -{ - return this->channel; -} +std::unique_ptr<DNSSocketHandler> DNSHandler::socket_handler{}; -void DNSHandler::destroy() +DNSHandler::DNSHandler(std::shared_ptr<Poller> poller) { - this->remove_all_sockets_from_poller(); - this->socket_handlers.clear(); - ::ares_destroy(this->channel); - ::ares_library_cleanup(); + dns_init(nullptr, 0); + const auto socket = dns_open(nullptr); + if (socket == -1) + throw std::runtime_error("Failed to initialize udns socket: "s + strerror(errno)); + + DNSHandler::socket_handler = std::make_unique<DNSSocketHandler>(poller, socket); } -void DNSHandler::gethostbyname(const std::string& name, ares_host_callback callback, - void* data, int family) +void DNSHandler::destroy() { - ::ares_gethostbyname(this->channel, name.data(), family, - callback, data); + DNSHandler::socket_handler.reset(nullptr); + dns_close(nullptr); } -void DNSHandler::watch_dns_sockets(std::shared_ptr<Poller>& poller) +void DNSHandler::watch() { - fd_set readers; - fd_set writers; - - FD_ZERO(&readers); - FD_ZERO(&writers); - - int ndfs = ::ares_fds(this->channel, &readers, &writers); - // For each existing DNS socket, see if we are still supposed to watch it, - // if not then erase it - this->socket_handlers.erase( - std::remove_if(this->socket_handlers.begin(), this->socket_handlers.end(), - [&readers](const auto& dns_socket) - { - return !FD_ISSET(dns_socket->get_socket(), &readers); - }), - this->socket_handlers.end()); - - for (auto i = 0; i < ndfs; ++i) - { - bool read = FD_ISSET(i, &readers); - bool write = FD_ISSET(i, &writers); - // Look for the DNSSocketHandler with this fd - auto it = std::find_if(this->socket_handlers.begin(), - this->socket_handlers.end(), - [i](const auto& socket_handler) - { - return i == socket_handler->get_socket(); - }); - if (!read && !write) // No need to read or write to it - { // If found, erase it and stop watching it because it is not - // needed anymore - if (it != this->socket_handlers.end()) - // The socket destructor removes it from the poller - this->socket_handlers.erase(it); - } - else // We need to write and/or read to it - { // If not found, create it because we need to watch it - if (it == this->socket_handlers.end()) - { - this->socket_handlers.emplace(this->socket_handlers.begin(), - std::make_unique<DNSSocketHandler>(poller, *this, i)); - it = this->socket_handlers.begin(); - } - poller->add_socket_handler(it->get()); - if (write) - poller->watch_send_events(it->get()); - } - } - // Cancel previous timer, if any. - TimedEventsManager::instance().cancel("DNS timeout"); - struct timeval tv; - struct timeval* tvp; - tvp = ::ares_timeout(this->channel, NULL, &tv); - if (tvp) - { - auto future_time = std::chrono::steady_clock::now() + std::chrono::seconds(tvp->tv_sec) + \ - std::chrono::microseconds(tvp->tv_usec); - TimedEventsManager::instance().add_event(TimedEvent(std::move(future_time), - [this]() - { - for (auto& dns_socket_handler: this->socket_handlers) - dns_socket_handler->on_recv(); - }, - "DNS timeout")); - } + DNSHandler::socket_handler->watch(); } -void DNSHandler::remove_all_sockets_from_poller() +void DNSHandler::unwatch() { - for (const auto& socket_handler: this->socket_handlers) - { - socket_handler->remove_from_poller(); - } + DNSHandler::socket_handler->unwatch(); } -#endif /* CARES_FOUND */ +#endif /* UDNS_FOUND */ diff --git a/louloulibs/network/dns_handler.hpp b/louloulibs/network/dns_handler.hpp index fd1729d..78ffe4d 100644 --- a/louloulibs/network/dns_handler.hpp +++ b/louloulibs/network/dns_handler.hpp @@ -1,58 +1,37 @@ #pragma once #include <louloulibs.h> -#ifdef CARES_FOUND +#ifdef UDNS_FOUND -class TCPSocketHandler; class Poller; -class DNSSocketHandler; -# include <ares.h> -# include <memory> -# include <string> -# include <vector> +#include <network/dns_socket_handler.hpp> -/** - * Class managing DNS resolution. It should only be statically instanciated - * once in SocketHandler. It manages ares channel and calls various - * functions of that library. - */ +#include <string> +#include <vector> +#include <memory> class DNSHandler { public: - DNSHandler(); + explicit DNSHandler(std::shared_ptr<Poller> poller); ~DNSHandler() = default; + DNSHandler(const DNSHandler&) = delete; DNSHandler(DNSHandler&&) = delete; DNSHandler& operator=(const DNSHandler&) = delete; DNSHandler& operator=(DNSHandler&&) = delete; - void gethostbyname(const std::string& name, ares_host_callback callback, - void* socket_handler, int family); - /** - * Call ares_fds to know what fd needs to be watched by the poller, create - * or destroy DNSSocketHandlers depending on the result. - */ - void watch_dns_sockets(std::shared_ptr<Poller>& poller); - /** - * Destroy and stop watching all the DNS sockets. Then de-init the channel - * and library. - */ void destroy(); - void remove_all_sockets_from_poller(); - ares_channel& get_channel(); - static DNSHandler instance; + static void watch(); + static void unwatch(); private: /** - * The list of sockets that needs to be watched, according to the last - * call to ares_fds. DNSSocketHandlers are added to it or removed from it - * in the watch_dns_sockets() method + * Manager for the socket returned by udns, that we need to watch with the poller */ - std::vector<std::unique_ptr<DNSSocketHandler>> socket_handlers; - ares_channel channel; + static std::unique_ptr<DNSSocketHandler> socket_handler; }; -#endif /* CARES_FOUND */ +#endif /* UDNS_FOUND */ diff --git a/louloulibs/network/dns_socket_handler.cpp b/louloulibs/network/dns_socket_handler.cpp index 403a5be..ad744a9 100644 --- a/louloulibs/network/dns_socket_handler.cpp +++ b/louloulibs/network/dns_socket_handler.cpp @@ -1,38 +1,27 @@ #include <louloulibs.h> -#ifdef CARES_FOUND +#ifdef UDNS_FOUND #include <network/dns_socket_handler.hpp> #include <network/dns_handler.hpp> #include <network/poller.hpp> -#include <ares.h> +#include <udns.h> DNSSocketHandler::DNSSocketHandler(std::shared_ptr<Poller> poller, - DNSHandler& handler, const socket_t socket): - SocketHandler(poller, socket), - handler(handler) + SocketHandler(poller, socket) { + poller->add_socket_handler(this); } -void DNSSocketHandler::connect() +DNSSocketHandler::~DNSSocketHandler() { + this->unwatch(); } void DNSSocketHandler::on_recv() { - // always stop watching send and read events. We will re-watch them if the - // next call to ares_fds tell us to - this->handler.remove_all_sockets_from_poller(); - ::ares_process_fd(DNSHandler::instance.get_channel(), this->socket, ARES_SOCKET_BAD); -} - -void DNSSocketHandler::on_send() -{ - // always stop watching send and read events. We will re-watch them if the - // next call to ares_fds tell us to - this->handler.remove_all_sockets_from_poller(); - ::ares_process_fd(DNSHandler::instance.get_channel(), ARES_SOCKET_BAD, this->socket); + dns_ioevent(nullptr, 0); } bool DNSSocketHandler::is_connected() const @@ -40,10 +29,15 @@ bool DNSSocketHandler::is_connected() const return true; } -void DNSSocketHandler::remove_from_poller() +void DNSSocketHandler::unwatch() { if (this->poller->is_managing_socket(this->socket)) this->poller->remove_socket_handler(this->socket); } -#endif /* CARES_FOUND */ +void DNSSocketHandler::watch() +{ + this->poller->add_socket_handler(this); +} + +#endif /* UDNS_FOUND */ diff --git a/louloulibs/network/dns_socket_handler.hpp b/louloulibs/network/dns_socket_handler.hpp index 0570196..e12f145 100644 --- a/louloulibs/network/dns_socket_handler.hpp +++ b/louloulibs/network/dns_socket_handler.hpp @@ -1,49 +1,33 @@ #pragma once #include <louloulibs.h> -#ifdef CARES_FOUND +#ifdef UDNS_FOUND #include <network/socket_handler.hpp> -#include <ares.h> /** - * Manage a socket returned by ares_fds. We do not create, open or close the - * socket ourself: this is done by c-ares. We just call ares_process_fd() - * with the correct parameters, depending on what can be done on that socket - * (Poller reported it to be writable or readeable) + * Manage the UDP socket provided by udns, we do not create, open or close the + * socket ourself: this is done by udns. We only watch it for readability */ - -class DNSHandler; - class DNSSocketHandler: public SocketHandler { public: - explicit DNSSocketHandler(std::shared_ptr<Poller> poller, DNSHandler& handler, const socket_t socket); - ~DNSSocketHandler() = default; + explicit DNSSocketHandler(std::shared_ptr<Poller> poller, const socket_t socket); + ~DNSSocketHandler(); DNSSocketHandler(const DNSSocketHandler&) = delete; DNSSocketHandler(DNSSocketHandler&&) = delete; DNSSocketHandler& operator=(const DNSSocketHandler&) = delete; DNSSocketHandler& operator=(DNSSocketHandler&&) = delete; - /** - * Just call dns_process_fd, c-ares will do its work of send()ing or - * recv()ing the data it wants on that socket. - */ void on_recv() override final; - void on_send() override final; - /** - * Do nothing, because we are always considered to be connected, since the - * connection is done by c-ares and not by us. - */ - void connect() override final; + /** * Always true, see the comment for connect() */ bool is_connected() const override final; - void remove_from_poller(); -private: - DNSHandler& handler; + void watch(); + void unwatch(); }; -#endif // CARES_FOUND +#endif // UDNS_FOUND diff --git a/louloulibs/network/resolver.cpp b/louloulibs/network/resolver.cpp index 2987aaa..efb0cf0 100644 --- a/louloulibs/network/resolver.cpp +++ b/louloulibs/network/resolver.cpp @@ -1,17 +1,32 @@ #include <network/dns_handler.hpp> +#include <utils/timed_events.hpp> #include <network/resolver.hpp> #include <string.h> #include <arpa/inet.h> +#include <netinet/in.h> +#include <udns.h> + +#include <fstream> #include <cstdlib> +#include <sstream> +#include <chrono> +#include <map> using namespace std::string_literals; +static std::map<int, std::string> dns_error_messages { + {DNS_E_TEMPFAIL, "Timeout while contacting DNS servers"}, + {DNS_E_PROTOCOL, "Misformatted DNS reply"}, + {DNS_E_NXDOMAIN, "Domain name not found"}, + {DNS_E_NOMEM, "Out of memory"}, + {DNS_E_BADQUERY, "Misformatted domain name"} +}; + Resolver::Resolver(): -#ifdef CARES_FOUND +#ifdef UDNS_FOUND resolved4(false), resolved6(false), resolving(false), - cares_addrinfo(nullptr), port{}, #endif resolved(false), @@ -24,15 +39,44 @@ void Resolver::resolve(const std::string& hostname, const std::string& port, { this->error_cb = error_cb; this->success_cb = success_cb; -#ifdef CARES_FOUND +#ifdef UDNS_FOUND this->port = port; #endif this->start_resolving(hostname, port); } -#ifdef CARES_FOUND -void Resolver::start_resolving(const std::string& hostname, const std::string&) +int Resolver::call_getaddrinfo(const char *name, const char* port, int flags) +{ + struct addrinfo hints; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_flags = flags; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + struct addrinfo* addr_res = nullptr; + const int res = ::getaddrinfo(name, port, + &hints, &addr_res); + + if (res == 0 && addr_res) + { + if (!this->addr) + this->addr.reset(addr_res); + else + { // Append this result at the end of the linked list + struct addrinfo *rp = this->addr.get(); + while (rp->ai_next) + rp = rp->ai_next; + rp->ai_next = addr_res; + } + } + + return res; +} + +#ifdef UDNS_FOUND +void Resolver::start_resolving(const std::string& hostname, const std::string& port) { this->resolving = true; this->resolved = false; @@ -40,48 +84,139 @@ void Resolver::start_resolving(const std::string& hostname, const std::string&) this->resolved6 = false; this->error_msg.clear(); - this->cares_addrinfo = nullptr; + this->addr.reset(nullptr); - auto hostname4_resolved = [](void* arg, int status, int, - struct hostent* hostent) + // We first try to use it as an IP address directly. We tell getaddrinfo + // to NOT use any DNS resolution. + if (this->call_getaddrinfo(hostname.data(), port.data(), AI_NUMERICHOST) == 0) { - Resolver* resolver = static_cast<Resolver*>(arg); - resolver->on_hostname4_resolved(status, hostent); - }; - auto hostname6_resolved = [](void* arg, int status, int, - struct hostent* hostent) + this->on_resolved(); + return; + } + + // Then we look into /etc/hosts to translate the given hostname + const auto hosts = this->look_in_etc_hosts(hostname); + if (!hosts.empty()) + { + for (const auto &host: hosts) + this->call_getaddrinfo(host.data(), port.data(), AI_NUMERICHOST); + this->on_resolved(); + return; + } + + // And finally, we try a DNS resolution + auto hostname6_resolved = [](dns_ctx*, dns_rr_a6* result, void* data) + { + Resolver* resolver = static_cast<Resolver*>(data); + resolver->on_hostname6_resolved(result); + }; + + auto hostname4_resolved = [](dns_ctx*, dns_rr_a4* result, void* data) + { + Resolver* resolver = static_cast<Resolver*>(data); + resolver->on_hostname4_resolved(result); + }; + + DNSHandler::watch(); + auto res = dns_submit_a4(nullptr, hostname.data(), 0, hostname4_resolved, this); + if (!res) + this->on_hostname4_resolved(nullptr); + res = dns_submit_a6(nullptr, hostname.data(), 0, hostname6_resolved, this); + if (!res) + this->on_hostname6_resolved(nullptr); + + this->start_timer(); +} + +void Resolver::start_timer() +{ + const auto timeout = dns_timeouts(nullptr, -1, 0); + if (timeout < 0) + return; + TimedEvent event(std::chrono::steady_clock::now() + std::chrono::seconds(timeout), [this]() { this->start_timer(); }, "DNS"); + TimedEventsManager::instance().add_event(std::move(event)); +} + +std::vector<std::string> Resolver::look_in_etc_hosts(const std::string &hostname) +{ + std::ifstream hosts("/etc/hosts"); + std::string line; + + std::vector<std::string> results; + while (std::getline(hosts, line)) { - Resolver* resolver = static_cast<Resolver*>(arg); - resolver->on_hostname6_resolved(status, hostent); - }; - - DNSHandler::instance.gethostbyname(hostname, hostname6_resolved, - this, AF_INET6); - DNSHandler::instance.gethostbyname(hostname, hostname4_resolved, - this, AF_INET); + if (line.empty()) + continue; + + std::string ip; + std::istringstream line_stream(line); + line_stream >> ip; + if (ip.empty() || ip[0] == '#') + continue; + + std::string host; + while (line_stream >> host && !host.empty() && host[0] != '#') + { + if (hostname == host) + { + results.push_back(ip); + break; + } + } + } + return results; } -void Resolver::on_hostname4_resolved(int status, struct hostent* hostent) +void Resolver::on_hostname4_resolved(dns_rr_a4 *result) { + if (dns_active(nullptr) == 0) + DNSHandler::unwatch(); + this->resolved4 = true; - if (status == ARES_SUCCESS) - this->fill_ares_addrinfo4(hostent); + + const auto status = dns_status(nullptr); + + if (status >= 0 && result) + { + char buf[INET6_ADDRSTRLEN]; + + for (auto i = 0; i < result->dnsa4_nrr; ++i) + { + inet_ntop(AF_INET, &result->dnsa4_addr[i], buf, sizeof(buf)); + this->call_getaddrinfo(buf, this->port.data(), AI_NUMERICHOST); + } + } else - this->error_msg = ::ares_strerror(status); + { + const auto error = dns_error_messages.find(status); + if (error != end(dns_error_messages)) + this->error_msg = error->second; + } - if (this->resolved4 && this->resolved6) + if (this->resolved6 && this->resolved4) this->on_resolved(); } -void Resolver::on_hostname6_resolved(int status, struct hostent* hostent) +void Resolver::on_hostname6_resolved(dns_rr_a6 *result) { + if (dns_active(nullptr) == 0) + DNSHandler::unwatch(); + this->resolved6 = true; - if (status == ARES_SUCCESS) - this->fill_ares_addrinfo6(hostent); - else - this->error_msg = ::ares_strerror(status); + char buf[INET6_ADDRSTRLEN]; - if (this->resolved4 && this->resolved6) + const auto status = dns_status(nullptr); + + if (status >= 0 && result) + { + for (auto i = 0; i < result->dnsa6_nrr; ++i) + { + inet_ntop(AF_INET6, &result->dnsa6_addr[i], buf, sizeof(buf)); + this->call_getaddrinfo(buf, this->port.data(), AI_NUMERICHOST); + } + } + + if (this->resolved6 && this->resolved4) this->on_resolved(); } @@ -89,100 +224,26 @@ void Resolver::on_resolved() { this->resolved = true; this->resolving = false; - if (!this->cares_addrinfo) + if (!this->addr) { if (this->error_cb) this->error_cb(this->error_msg.data()); } else { - this->addr.reset(this->cares_addrinfo); if (this->success_cb) this->success_cb(this->addr.get()); } } -void Resolver::fill_ares_addrinfo4(const struct hostent* hostent) -{ - struct addrinfo* prev = this->cares_addrinfo; - struct in_addr** address = reinterpret_cast<struct in_addr**>(hostent->h_addr_list); - - while (*address) - { - // Create a new addrinfo list element, and fill it - struct addrinfo* current = new struct addrinfo; - current->ai_flags = 0; - current->ai_family = hostent->h_addrtype; - current->ai_socktype = SOCK_STREAM; - current->ai_protocol = 0; - current->ai_addrlen = sizeof(struct sockaddr_in); - - struct sockaddr_in* ai_addr = new struct sockaddr_in; - - ai_addr->sin_family = hostent->h_addrtype; - ai_addr->sin_port = htons(std::strtoul(this->port.data(), nullptr, 10)); - ai_addr->sin_addr.s_addr = (*address)->s_addr; - - current->ai_addr = reinterpret_cast<struct sockaddr*>(ai_addr); - current->ai_next = nullptr; - current->ai_canonname = nullptr; - - current->ai_next = prev; - this->cares_addrinfo = current; - prev = current; - ++address; - } -} - -void Resolver::fill_ares_addrinfo6(const struct hostent* hostent) -{ - struct addrinfo* prev = this->cares_addrinfo; - struct in6_addr** address = reinterpret_cast<struct in6_addr**>(hostent->h_addr_list); - - while (*address) - { - // Create a new addrinfo list element, and fill it - struct addrinfo* current = new struct addrinfo; - current->ai_flags = 0; - current->ai_family = hostent->h_addrtype; - current->ai_socktype = SOCK_STREAM; - current->ai_protocol = 0; - current->ai_addrlen = sizeof(struct sockaddr_in6); - - struct sockaddr_in6* ai_addr = new struct sockaddr_in6; - ai_addr->sin6_family = hostent->h_addrtype; - ai_addr->sin6_port = htons(std::strtoul(this->port.data(), nullptr, 10)); - ::memcpy(ai_addr->sin6_addr.s6_addr, (*address)->s6_addr, sizeof(ai_addr->sin6_addr.s6_addr)); - ai_addr->sin6_flowinfo = 0; - ai_addr->sin6_scope_id = 0; - - current->ai_addr = reinterpret_cast<struct sockaddr*>(ai_addr); - current->ai_canonname = nullptr; - - current->ai_next = prev; - this->cares_addrinfo = current; - prev = current; - ++address; - } -} - -#else // ifdef CARES_FOUND +#else // ifdef UDNS_FOUND void Resolver::start_resolving(const std::string& hostname, const std::string& port) { // If the resolution fails, the addr will be unset this->addr.reset(nullptr); - struct addrinfo hints; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_flags = 0; - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = 0; - - struct addrinfo* addr_res = nullptr; - const int res = ::getaddrinfo(hostname.data(), port.data(), - &hints, &addr_res); + const auto res = this->call_getaddrinfo(hostname.data(), port.data(), 0); this->resolved = true; @@ -194,12 +255,11 @@ void Resolver::start_resolving(const std::string& hostname, const std::string& p } else { - this->addr.reset(addr_res); if (this->success_cb) this->success_cb(this->addr.get()); } } -#endif // ifdef CARES_FOUND +#endif // ifdef UDNS_FOUND std::string addr_to_string(const struct addrinfo* rp) { diff --git a/louloulibs/network/resolver.hpp b/louloulibs/network/resolver.hpp index 29e6f3a..f516da5 100644 --- a/louloulibs/network/resolver.hpp +++ b/louloulibs/network/resolver.hpp @@ -1,38 +1,31 @@ #pragma once - #include "louloulibs.h" #include <functional> +#include <vector> #include <memory> #include <string> #include <sys/types.h> #include <sys/socket.h> #include <netdb.h> +#include <udns.h> class AddrinfoDeleter { public: void operator()(struct addrinfo* addr) { -#ifdef CARES_FOUND - while (addr) - { - delete addr->ai_addr; - auto next = addr->ai_next; - delete addr; - addr = next; - } -#else freeaddrinfo(addr); -#endif } }; + class Resolver { public: + using ErrorCallbackType = std::function<void(const char*)>; using SuccessCallbackType = std::function<void(const struct addrinfo*)>; @@ -45,7 +38,7 @@ public: bool is_resolving() const { -#ifdef CARES_FOUND +#ifdef UDNS_FOUND return this->resolving; #else return false; @@ -68,11 +61,10 @@ public: void clear() { -#ifdef CARES_FOUND +#ifdef UDNS_FOUND this->resolved6 = false; this->resolved4 = false; this->resolving = false; - this->cares_addrinfo = nullptr; this->port.clear(); #endif this->resolved = false; @@ -85,12 +77,18 @@ public: private: void start_resolving(const std::string& hostname, const std::string& port); -#ifdef CARES_FOUND - void on_hostname4_resolved(int status, struct hostent* hostent); - void on_hostname6_resolved(int status, struct hostent* hostent); + std::vector<std::string> look_in_etc_hosts(const std::string& hostname); + /** + * Call getaddrinfo() on the given hostname or IP, and append the result + * to our internal addrinfo list. Return getaddrinfo()’s return value. + */ + int call_getaddrinfo(const char* name, const char* port, int flags); - void fill_ares_addrinfo4(const struct hostent* hostent); - void fill_ares_addrinfo6(const struct hostent* hostent); +#ifdef UDNS_FOUND + void on_hostname4_resolved(dns_rr_a4 *result); + void on_hostname6_resolved(dns_rr_a6 *result); + + void start_timer(); void on_resolved(); @@ -99,14 +97,6 @@ private: bool resolving; - /** - * When using c-ares to resolve the host asynchronously, we need the - * c-ares callbacks to fill a structure (a struct addrinfo, for - * compatibility with getaddrinfo and the rest of the code that works when - * c-ares is not used) with all returned values (for example an IPv6 and - * an IPv4). The pointer is given to the unique_ptr to manage its lifetime. - */ - struct addrinfo* cares_addrinfo; std::string port; #endif @@ -117,7 +107,6 @@ private: bool resolved; std::string error_msg; - std::unique_ptr<struct addrinfo, AddrinfoDeleter> addr; ErrorCallbackType error_cb; @@ -125,5 +114,3 @@ private: }; std::string addr_to_string(const struct addrinfo* rp); - - diff --git a/louloulibs/network/socket_handler.hpp b/louloulibs/network/socket_handler.hpp index ea79a18..607a106 100644 --- a/louloulibs/network/socket_handler.hpp +++ b/louloulibs/network/socket_handler.hpp @@ -20,9 +20,9 @@ public: SocketHandler& operator=(const SocketHandler&) = delete; SocketHandler& operator=(SocketHandler&&) = delete; - virtual void on_recv() = 0; - virtual void on_send() = 0; - virtual void connect() = 0; + virtual void on_recv() {} + virtual void on_send() {} + virtual void connect() {} virtual bool is_connected() const = 0; socket_t get_socket() const diff --git a/louloulibs/network/tcp_client_socket_handler.cpp b/louloulibs/network/tcp_client_socket_handler.cpp new file mode 100644 index 0000000..4e6445c --- /dev/null +++ b/louloulibs/network/tcp_client_socket_handler.cpp @@ -0,0 +1,257 @@ +#include <network/tcp_client_socket_handler.hpp> +#include <utils/timed_events.hpp> +#include <utils/scopeguard.hpp> +#include <network/poller.hpp> + +#include <logger/logger.hpp> + +#include <cstring> +#include <unistd.h> +#include <fcntl.h> + +using namespace std::string_literals; + +TCPClientSocketHandler::TCPClientSocketHandler(std::shared_ptr<Poller> poller): + TCPSocketHandler(poller), + hostname_resolution_failed(false), + connected(false), + connecting(false) +{} + +TCPClientSocketHandler::~TCPClientSocketHandler() +{ + this->close(); +} + +void TCPClientSocketHandler::init_socket(const struct addrinfo* rp) +{ + if (this->socket != -1) + ::close(this->socket); + if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1) + throw std::runtime_error("Could not create socket: "s + std::strerror(errno)); + // Bind the socket to a specific address, if specified + if (!this->bind_addr.empty()) + { + // Convert the address from string format to a sockaddr that can be + // used in bind() + struct addrinfo* result; + int err = ::getaddrinfo(this->bind_addr.data(), nullptr, nullptr, &result); + if (err != 0 || !result) + log_error("Failed to bind socket to ", this->bind_addr, ": ", + gai_strerror(err)); + else + { + utils::ScopeGuard sg([result](){ freeaddrinfo(result); }); + struct addrinfo* rp; + for (rp = result; rp; rp = rp->ai_next) + { + if ((::bind(this->socket, + reinterpret_cast<const struct sockaddr*>(rp->ai_addr), + rp->ai_addrlen)) == 0) + break; + } + if (!rp) + log_error("Failed to bind socket to ", this->bind_addr, ": ", + strerror(errno)); + else + log_info("Socket successfully bound to ", this->bind_addr); + } + } + int optval = 1; + if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1) + log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno)); + // Set the socket on non-blocking mode. This is useful to receive a EAGAIN + // error when connect() would block, to not block the whole process if a + // remote is not responsive. + const int existing_flags = ::fcntl(this->socket, F_GETFL, 0); + if ((existing_flags == -1) || + (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1)) + throw std::runtime_error("Could not initialize socket: "s + std::strerror(errno)); +} + +void TCPClientSocketHandler::connect(const std::string& address, const std::string& port, const bool tls) +{ + this->address = address; + this->port = port; + this->use_tls = tls; + + struct addrinfo* addr_res; + + if (!this->connecting) + { + // Get the addrinfo from getaddrinfo (or using udns), only if + // this is the first call of this function. + if (!this->resolver.is_resolved()) + { + log_info("Trying to connect to ", address, ":", port); + // Start the asynchronous process of resolving the hostname. Once + // the addresses have been found and `resolved` has been set to true + // (but connecting will still be false), TCPClientSocketHandler::connect() + // needs to be called, again. + this->resolver.resolve(address, port, + [this](const struct addrinfo*) + { + log_debug("Resolution success, calling connect() again"); + this->connect(); + }, + [this](const char*) + { + log_debug("Resolution failed, calling connect() again"); + this->connect(); + }); + return; + } + else + { + // The DNS resolver resolved the hostname and the available addresses + // where saved in the addrinfo linked list. Now, just use + // this list to try to connect. + addr_res = this->resolver.get_result().get(); + if (!addr_res) + { + this->hostname_resolution_failed = true; + const auto msg = this->resolver.get_error_message(); + this->close(); + this->on_connection_failed(msg); + return ; + } + } + } + else + { // This function is called again, use the saved addrinfo structure, + // instead of re-doing the whole getaddrinfo process. + addr_res = &this->addrinfo; + } + + for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next) + { + if (!this->connecting) + { + try { + this->init_socket(rp); + } + catch (const std::runtime_error& error) { + log_error("Failed to init socket: ", error.what()); + break; + } + } + + this->display_resolved_ip(rp); + + if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0 + || errno == EISCONN) + { + log_info("Connection success."); + TimedEventsManager::instance().cancel("connection_timeout"s + + std::to_string(this->socket)); + this->poller->add_socket_handler(this); + this->connected = true; + this->connecting = false; +#ifdef BOTAN_FOUND + if (this->use_tls) + this->start_tls(this->address, this->port); +#endif + this->connection_date = std::chrono::system_clock::now(); + + // Get our local TCP port and store it + this->local_port = static_cast<uint16_t>(-1); + if (rp->ai_family == AF_INET6) + { + struct sockaddr_in6 a; + socklen_t l = sizeof(a); + if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1) + this->local_port = ntohs(a.sin6_port); + } + else if (rp->ai_family == AF_INET) + { + struct sockaddr_in a; + socklen_t l = sizeof(a); + if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1) + this->local_port = ntohs(a.sin_port); + } + + log_debug("Local port: ", this->local_port, ", and remote port: ", this->port); + + this->on_connected(); + return ; + } + else if (errno == EINPROGRESS || errno == EALREADY) + { // retry this process later, when the socket + // is ready to be written on. + this->connecting = true; + this->poller->add_socket_handler(this); + this->poller->watch_send_events(this); + // Save the addrinfo structure, to use it on the next call + this->ai_addrlen = rp->ai_addrlen; + memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen); + memcpy(&this->addrinfo, rp, sizeof(struct addrinfo)); + this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr); + this->addrinfo.ai_next = nullptr; + // If the connection has not succeeded or failed in 5s, we consider + // it to have failed + TimedEventsManager::instance().add_event( + TimedEvent(std::chrono::steady_clock::now() + 5s, + std::bind(&TCPClientSocketHandler::on_connection_timeout, this), + "connection_timeout"s + std::to_string(this->socket))); + return ; + } + log_info("Connection failed:", std::strerror(errno)); + } + log_error("All connection attempts failed."); + this->close(); + this->on_connection_failed(std::strerror(errno)); + return ; +} + +void TCPClientSocketHandler::on_connection_timeout() +{ + this->close(); + this->on_connection_failed("connection timed out"); +} + +void TCPClientSocketHandler::connect() +{ + this->connect(this->address, this->port, this->use_tls); +} + +void TCPClientSocketHandler::close() +{ + TimedEventsManager::instance().cancel("connection_timeout"s + + std::to_string(this->socket)); + + TCPSocketHandler::close(); + + this->connected = false; + this->connecting = false; + this->port.clear(); + this->resolver.clear(); +} + +void TCPClientSocketHandler::display_resolved_ip(struct addrinfo* rp) const +{ + if (rp->ai_family == AF_INET) + log_debug("Trying IPv4 address ", addr_to_string(rp)); + else if (rp->ai_family == AF_INET6) + log_debug("Trying IPv6 address ", addr_to_string(rp)); +} + +bool TCPClientSocketHandler::is_connected() const +{ + return this->connected; +} + +bool TCPClientSocketHandler::is_connecting() const +{ + return this->connecting || this->resolver.is_resolving(); +} + +std::string TCPClientSocketHandler::get_port() const +{ + return this->port; +} + +bool TCPClientSocketHandler::match_port_pairt(const uint16_t local, const uint16_t remote) const +{ + const uint16_t remote_port = static_cast<uint16_t>(std::stoi(this->port)); + return this->is_connected() && local == this->local_port && remote == remote_port; +} diff --git a/louloulibs/network/tcp_client_socket_handler.hpp b/louloulibs/network/tcp_client_socket_handler.hpp new file mode 100644 index 0000000..75e1364 --- /dev/null +++ b/louloulibs/network/tcp_client_socket_handler.hpp @@ -0,0 +1,82 @@ +#pragma once + +#include <network/tcp_socket_handler.hpp> + +class TCPClientSocketHandler: public TCPSocketHandler +{ + public: + TCPClientSocketHandler(std::shared_ptr<Poller> poller); + ~TCPClientSocketHandler(); + /** + * Connect to the remote server, and call on_connected() if this + * succeeds. If tls is true, we set use_tls to true and will also call + * start_tls() when the connection succeeds. + */ + void connect(const std::string& address, const std::string& port, const bool tls); + void connect() override final; + /** + * Called by a TimedEvent, when the connection did not succeed or fail + * after a given time. + */ + void on_connection_timeout(); + /** + * Called when the connection is successful. + */ + virtual void on_connected() = 0; + bool is_connected() const override; + bool is_connecting() const override; + + std::string get_port() const; + + void close() override final; + std::chrono::system_clock::time_point connection_date; + + /** + * Whether or not this connection is using the two given TCP ports. + */ + bool match_port_pairt(const uint16_t local, const uint16_t remote) const; + + protected: + bool hostname_resolution_failed; + /** + * Address to bind the socket to, before calling connect(). + * If empty, it’s equivalent to binding to INADDR_ANY. + */ + std::string bind_addr; + /** + * Display the resolved IP, just for information purpose. + */ + void display_resolved_ip(struct addrinfo* rp) const; + private: + /** + * Initialize the socket with the parameters contained in the given + * addrinfo structure. + */ + void init_socket(const struct addrinfo* rp); + /** + * DNS resolver + */ + Resolver resolver; + /** + * Keep the details of the addrinfo returned by the resolver that + * triggered a EINPROGRESS error when connect()ing to it, to reuse it + * directly when connect() is called again. + */ + struct addrinfo addrinfo{}; + struct sockaddr_in6 ai_addr{}; + socklen_t ai_addrlen{}; + + /** + * Hostname we are connected/connecting to + */ + std::string address; + /** + * Port we are connected/connecting to + */ + std::string port; + + uint16_t local_port{}; + + bool connected; + bool connecting; +}; diff --git a/louloulibs/network/tcp_server_socket.hpp b/louloulibs/network/tcp_server_socket.hpp new file mode 100644 index 0000000..7ea49ab --- /dev/null +++ b/louloulibs/network/tcp_server_socket.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include <network/socket_handler.hpp> +#include <network/poller.hpp> +#include <logger/logger.hpp> + +#include <string> + +#include <arpa/inet.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/ip.h> + +#include <cstring> +#include <cassert> + +template <typename RemoteSocketType> +class TcpSocketServer: public SocketHandler +{ + public: + TcpSocketServer(std::shared_ptr<Poller> poller, const uint16_t port): + SocketHandler(poller, -1) + { + if ((this->socket = ::socket(AF_INET6, SOCK_STREAM, 0)) == -1) + throw std::runtime_error(std::string{"Could not create socket: "} + std::strerror(errno)); + + int opt = 1; + if (::setsockopt(this->socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) + throw std::runtime_error(std::string{"Failed to set socket option: "} + std::strerror(errno)); + + struct sockaddr_in6 addr{}; + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(port); + addr.sin6_addr = IN6ADDR_ANY_INIT; + if ((::bind(this->socket, (const struct sockaddr*)&addr, sizeof(addr))) == -1) + { // If we can’t listen on this port, we just give up, but this is not fatal. + log_warning("Failed to bind on port ", std::to_string(port), ": ", std::strerror(errno)); + return; + } + + if ((::listen(this->socket, 10)) == -1) + throw std::runtime_error("listen() failed"); + + this->accept(); + } + ~TcpSocketServer() = default; + + void on_recv() override + { + // Accept a RemoteSocketType + int socket = ::accept(this->socket, nullptr, nullptr); + + auto client = std::make_unique<RemoteSocketType>(poller, socket, *this); + this->poller->add_socket_handler(client.get()); + this->sockets.push_back(std::move(client)); + } + + protected: + std::vector<std::unique_ptr<RemoteSocketType>> sockets; + + private: + void accept() + { + this->poller->add_socket_handler(this); + } + bool is_connected() const override + { + return true; + } +}; diff --git a/louloulibs/network/tcp_socket_handler.cpp b/louloulibs/network/tcp_socket_handler.cpp index 1dddde5..6aef2b1 100644 --- a/louloulibs/network/tcp_socket_handler.cpp +++ b/louloulibs/network/tcp_socket_handler.cpp @@ -1,8 +1,6 @@ #include <network/tcp_socket_handler.hpp> #include <network/dns_handler.hpp> -#include <utils/timed_events.hpp> -#include <utils/scopeguard.hpp> #include <network/poller.hpp> #include <logger/logger.hpp> @@ -12,16 +10,29 @@ #include <unistd.h> #include <errno.h> #include <cstring> -#include <fcntl.h> #ifdef BOTAN_FOUND # include <botan/hex.h> # include <botan/tls_exceptn.h> -Botan::AutoSeeded_RNG TCPSocketHandler::rng; -Botan::TLS::Policy TCPSocketHandler::policy; -Botan::TLS::Session_Manager_In_Memory TCPSocketHandler::session_manager(TCPSocketHandler::rng); - +namespace +{ + Botan::AutoSeeded_RNG& get_rng() + { + static Botan::AutoSeeded_RNG rng{}; + return rng; + } + BiboumiTLSPolicy& get_policy() + { + static BiboumiTLSPolicy policy{}; + return policy; + } + Botan::TLS::Session_Manager_In_Memory& get_session_manager() + { + static Botan::TLS::Session_Manager_In_Memory session_manager{get_rng()}; + return session_manager; + } +} #endif #ifndef UIO_FASTIOV @@ -35,10 +46,7 @@ namespace ph = std::placeholders; TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller): SocketHandler(poller, -1), - use_tls(false), - connected(false), - connecting(false), - hostname_resolution_failed(false) + use_tls(false) #ifdef BOTAN_FOUND ,credential_manager(this) #endif @@ -46,181 +54,13 @@ TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller): TCPSocketHandler::~TCPSocketHandler() { - this->close(); -} - - -void TCPSocketHandler::init_socket(const struct addrinfo* rp) -{ + if (this->poller->is_managing_socket(this->get_socket())) + this->poller->remove_socket_handler(this->get_socket()); if (this->socket != -1) - ::close(this->socket); - if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1) - throw std::runtime_error("Could not create socket: "s + strerror(errno)); - // Bind the socket to a specific address, if specified - if (!this->bind_addr.empty()) - { - // Convert the address from string format to a sockaddr that can be - // used in bind() - struct addrinfo* result; - int err = ::getaddrinfo(this->bind_addr.data(), nullptr, nullptr, &result); - if (err != 0 || !result) - log_error("Failed to bind socket to ", this->bind_addr, ": ", - gai_strerror(err)); - else - { - utils::ScopeGuard sg([result](){ freeaddrinfo(result); }); - struct addrinfo* rp; - int bind_error = 0; - for (rp = result; rp; rp = rp->ai_next) - { - if ((bind_error = ::bind(this->socket, - reinterpret_cast<const struct sockaddr*>(rp->ai_addr), - rp->ai_addrlen)) == 0) - break; - } - if (!rp) - log_error("Failed to bind socket to ", this->bind_addr, ": ", - strerror(errno)); - else - log_info("Socket successfully bound to ", this->bind_addr); - } - } - int optval = 1; - if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1) - log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno)); - // Set the socket on non-blocking mode. This is useful to receive a EAGAIN - // error when connect() would block, to not block the whole process if a - // remote is not responsive. - const int existing_flags = ::fcntl(this->socket, F_GETFL, 0); - if ((existing_flags == -1) || - (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1)) - throw std::runtime_error("Could not initialize socket: "s + strerror(errno)); -} - -void TCPSocketHandler::connect(const std::string& address, const std::string& port, const bool tls) -{ - this->address = address; - this->port = port; - this->use_tls = tls; - - struct addrinfo* addr_res; - - if (!this->connecting) - { - // Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if - // this is the first call of this function. - if (!this->resolver.is_resolved()) - { - log_info("Trying to connect to ", address, ":", port); - // Start the asynchronous process of resolving the hostname. Once - // the addresses have been found and `resolved` has been set to true - // (but connecting will still be false), TCPSocketHandler::connect() - // needs to be called, again. - this->resolver.resolve(address, port, - [this](const struct addrinfo*) - { - log_debug("Resolution success, calling connect() again"); - this->connect(); - }, - [this](const char*) - { - log_debug("Resolution failed, calling connect() again"); - this->connect(); - }); - return; - } - else - { - // The c-ares resolved the hostname and the available addresses - // where saved in the cares_addrinfo linked list. Now, just use - // this list to try to connect. - addr_res = this->resolver.get_result().get(); - if (!addr_res) - { - this->hostname_resolution_failed = true; - const auto msg = this->resolver.get_error_message(); - this->close(); - this->on_connection_failed(msg); - return ; - } - } - } - else - { // This function is called again, use the saved addrinfo structure, - // instead of re-doing the whole getaddrinfo process. - addr_res = &this->addrinfo; - } - - for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next) { - if (!this->connecting) - { - try { - this->init_socket(rp); - } - catch (const std::runtime_error& error) { - log_error("Failed to init socket: ", error.what()); - break; - } - } - - this->display_resolved_ip(rp); - - if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0 - || errno == EISCONN) - { - log_info("Connection success."); - TimedEventsManager::instance().cancel("connection_timeout"s + - std::to_string(this->socket)); - this->poller->add_socket_handler(this); - this->connected = true; - this->connecting = false; -#ifdef BOTAN_FOUND - if (this->use_tls) - this->start_tls(); -#endif - this->connection_date = std::chrono::system_clock::now(); - - this->on_connected(); - return ; - } - else if (errno == EINPROGRESS || errno == EALREADY) - { // retry this process later, when the socket - // is ready to be written on. - this->connecting = true; - this->poller->add_socket_handler(this); - this->poller->watch_send_events(this); - // Save the addrinfo structure, to use it on the next call - this->ai_addrlen = rp->ai_addrlen; - memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen); - memcpy(&this->addrinfo, rp, sizeof(struct addrinfo)); - this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr); - this->addrinfo.ai_next = nullptr; - // If the connection has not succeeded or failed in 5s, we consider - // it to have failed - TimedEventsManager::instance().add_event( - TimedEvent(std::chrono::steady_clock::now() + 5s, - std::bind(&TCPSocketHandler::on_connection_timeout, this), - "connection_timeout"s + std::to_string(this->socket))); - return ; - } - log_info("Connection failed:", strerror(errno)); + ::close(this->socket); + this->socket = -1; } - log_error("All connection attempts failed."); - this->close(); - this->on_connection_failed(strerror(errno)); - return ; -} - -void TCPSocketHandler::on_connection_timeout() -{ - this->close(); - this->on_connection_failed("connection timed out"); -} - -void TCPSocketHandler::connect() -{ - this->connect(this->address, this->port, this->use_tls); } void TCPSocketHandler::on_recv() @@ -267,13 +107,13 @@ ssize_t TCPSocketHandler::do_recv(void* recv_buf, const size_t buf_size) } else if (-1 == size) { - if (this->connecting) + if (this->is_connecting()) log_warning("Error connecting: ", strerror(errno)); else log_warning("Error while reading from socket: ", strerror(errno)); // Remember if we were connecting, or already connected when this // happened, because close() sets this->connecting to false - const auto were_connecting = this->connecting; + const auto were_connecting = this->is_connecting(); this->close(); if (were_connecting) this->on_connection_failed(strerror(errno)); @@ -333,29 +173,15 @@ void TCPSocketHandler::on_send() void TCPSocketHandler::close() { - TimedEventsManager::instance().cancel("connection_timeout"s + - std::to_string(this->socket)); - if (this->connected || this->connecting) + if (this->is_connected() || this->is_connecting()) this->poller->remove_socket_handler(this->get_socket()); if (this->socket != -1) { ::close(this->socket); this->socket = -1; } - this->connected = false; - this->connecting = false; this->in_buf.clear(); this->out_buf.clear(); - this->port.clear(); - this->resolver.clear(); -} - -void TCPSocketHandler::display_resolved_ip(struct addrinfo* rp) const -{ - if (rp->ai_family == AF_INET) - log_debug("Trying IPv4 address ", addr_to_string(rp)); - else if (rp->ai_family == AF_INET6) - log_debug("Trying IPv6 address ", addr_to_string(rp)); } void TCPSocketHandler::send_data(std::string&& data) @@ -379,52 +205,46 @@ void TCPSocketHandler::raw_send(std::string&& data) if (data.empty()) return ; this->out_buf.emplace_back(std::move(data)); - if (this->connected) + if (this->is_connected()) this->poller->watch_send_events(this); } void TCPSocketHandler::send_pending_data() { - if (this->connected && !this->out_buf.empty()) + if (this->is_connected() && !this->out_buf.empty()) this->poller->watch_send_events(this); } -bool TCPSocketHandler::is_connected() const -{ - return this->connected; -} - -bool TCPSocketHandler::is_connecting() const -{ - return this->connecting || this->resolver.is_resolving(); -} - bool TCPSocketHandler::is_using_tls() const { return this->use_tls; } -std::string TCPSocketHandler::get_port() const +void* TCPSocketHandler::get_receive_buffer(const size_t) const { - return this->port; + return nullptr; } -void* TCPSocketHandler::get_receive_buffer(const size_t) const +void TCPSocketHandler::consume_in_buffer(const std::size_t size) { - return nullptr; + this->in_buf = this->in_buf.substr(size, std::string::npos); } #ifdef BOTAN_FOUND -void TCPSocketHandler::start_tls() +void TCPSocketHandler::start_tls(const std::string& address, const std::string& port) { - Botan::TLS::Server_Information server_info(this->address, "irc", std::stoul(this->port)); + Botan::TLS::Server_Information server_info(address, "irc", std::stoul(port)); this->tls = std::make_unique<Botan::TLS::Client>( - std::bind(&TCPSocketHandler::tls_output_fn, this, ph::_1, ph::_2), - std::bind(&TCPSocketHandler::tls_data_cb, this, ph::_1, ph::_2), - std::bind(&TCPSocketHandler::tls_alert_cb, this, ph::_1, ph::_2, ph::_3), - std::bind(&TCPSocketHandler::tls_handshake_cb, this, ph::_1), - session_manager, this->credential_manager, policy, - rng, server_info, Botan::TLS::Protocol_Version::latest_tls_version()); +# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32) + *this, +# else + [this](const Botan::byte* data, size_t size) { this->tls_emit_data(data, size); }, + [this](const Botan::byte* data, size_t size) { this->tls_record_received(0, data, size); }, + [this](Botan::TLS::Alert alert, const Botan::byte*, size_t) { this->tls_alert(alert); }, + [this](const Botan::TLS::Session& session) { return this->tls_session_established(session); }, +# endif + get_session_manager(), this->credential_manager, get_policy(), + get_rng(), server_info, Botan::TLS::Protocol_Version::latest_tls_version()); } void TCPSocketHandler::tls_recv() @@ -475,7 +295,7 @@ void TCPSocketHandler::tls_send(std::string&& data) std::make_move_iterator(data.end())); } -void TCPSocketHandler::tls_data_cb(const Botan::byte* data, size_t size) +void TCPSocketHandler::tls_record_received(uint64_t, const Botan::byte *data, size_t size) { this->in_buf += std::string(reinterpret_cast<const char*>(data), size); @@ -483,17 +303,17 @@ void TCPSocketHandler::tls_data_cb(const Botan::byte* data, size_t size) this->parse_in_buffer(size); } -void TCPSocketHandler::tls_output_fn(const Botan::byte* data, size_t size) +void TCPSocketHandler::tls_emit_data(const Botan::byte *data, size_t size) { this->raw_send(std::string(reinterpret_cast<const char*>(data), size)); } -void TCPSocketHandler::tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t) +void TCPSocketHandler::tls_alert(Botan::TLS::Alert alert) { log_debug("tls_alert: ", alert.type_string()); } -bool TCPSocketHandler::tls_handshake_cb(const Botan::TLS::Session& session) +bool TCPSocketHandler::tls_session_established(const Botan::TLS::Session& session) { log_debug("Handshake with ", session.server_info().hostname(), " complete.", " Version: ", session.version().to_string(), @@ -505,6 +325,31 @@ bool TCPSocketHandler::tls_handshake_cb(const Botan::TLS::Session& session) return true; } +#if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,34) +void TCPSocketHandler::tls_verify_cert_chain(const std::vector<Botan::X509_Certificate>& cert_chain, + const std::vector<std::shared_ptr<const Botan::OCSP::Response>>& ocsp_responses, + const std::vector<Botan::Certificate_Store*>& trusted_roots, + Botan::Usage_Type usage, const std::string& hostname, + const Botan::TLS::Policy& policy) +{ + log_debug("Checking remote certificate for hostname ", hostname); + try + { + Botan::TLS::Callbacks::tls_verify_cert_chain(cert_chain, ocsp_responses, trusted_roots, usage, hostname, policy); + log_debug("Certificate is valid"); + } + catch (const std::exception& tls_exception) + { + log_warning("TLS certificate check failed: ", tls_exception.what()); + std::exception_ptr exception_ptr{}; + if (this->abort_on_invalid_cert()) + exception_ptr = std::current_exception(); + + check_tls_certificate(cert_chain, hostname, this->credential_manager.get_trusted_fingerprint(), exception_ptr); + } +} +#endif + void TCPSocketHandler::on_tls_activated() { this->send_data({}); diff --git a/louloulibs/network/tcp_socket_handler.hpp b/louloulibs/network/tcp_socket_handler.hpp index 20a3e5a..600405d 100644 --- a/louloulibs/network/tcp_socket_handler.hpp +++ b/louloulibs/network/tcp_socket_handler.hpp @@ -1,6 +1,5 @@ #pragma once - #include "louloulibs.h" #include <network/socket_handler.hpp> @@ -19,13 +18,44 @@ #include <string> #include <list> +#ifdef BOTAN_FOUND + +# include <botan/types.h> +# include <botan/botan.h> +# include <botan/tls_session_manager.h> + +class BiboumiTLSPolicy: public Botan::TLS::Policy +{ +public: +# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,33) + bool use_ecc_point_compression() const override + { + return true; + } + bool require_cert_revocation_info() const override + { + return false; + } +# endif +}; + +# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32) +# define BOTAN_TLS_CALLBACKS_OVERRIDE override final +# else +# define BOTAN_TLS_CALLBACKS_OVERRIDE +# endif +#endif + /** - * An interface, with a series of callbacks that should be implemented in - * subclasses that deal with a socket. These callbacks are called on various events - * (read/write/timeout, etc) when they are notified to a poller - * (select/poll/epoll etc) + * Does all the read/write, buffering etc. With optional tls. + * But doesn’t do any connect() or accept() or anything else. */ class TCPSocketHandler: public SocketHandler +#ifdef BOTAN_FOUND +# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32) + ,public Botan::TLS::Callbacks +# endif +#endif { protected: ~TCPSocketHandler(); @@ -37,13 +67,6 @@ public: TCPSocketHandler& operator=(TCPSocketHandler&&) = delete; /** - * Connect to the remote server, and call on_connected() if this - * succeeds. If tls is true, we set use_tls to true and will also call - * start_tls() when the connection succeeds. - */ - void connect(const std::string& address, const std::string& port, const bool tls); - void connect() override final; - /** * Reads raw data from the socket. And pass it to parse_in_buffer() * If we are using TLS on this connection, we call tls_recv() */ @@ -67,25 +90,7 @@ public: /** * Close the connection, remove us from the poller */ - void close(); - /** - * Called by a TimedEvent, when the connection did not succeed or fail - * after a given time. - */ - void on_connection_timeout(); - /** - * Called when the connection is successful. - */ - virtual void on_connected() = 0; - /** - * Called when the connection fails. Not when it is closed later, just at - * the connect() call. - */ - virtual void on_connection_failed(const std::string& reason) = 0; - /** - * Called when we detect a disconnection from the remote host. - */ - virtual void on_connection_close(const std::string& error) = 0; + virtual void close(); /** * Handle/consume (some of) the data received so far. The data to handle * may be in the in_buf buffer, or somewhere else, depending on what @@ -93,6 +98,9 @@ public: * should be truncated, only the unused data should be left untouched. * * The size argument is the size of the last chunk of data that was added to the buffer. + * + * The function should call consume_in_buffer, with the size that was consumed by the + * “parsing”, and thus to be removed from the input buffer. */ virtual void parse_in_buffer(const size_t size) = 0; #ifdef BOTAN_FOUND @@ -105,19 +113,10 @@ public: return true; } #endif - bool is_connected() const override final; - bool is_connecting() const; bool is_using_tls() const; - std::string get_port() const; - std::chrono::system_clock::time_point connection_date; private: /** - * Initialize the socket with the parameters contained in the given - * addrinfo structure. - */ - void init_socket(const struct addrinfo* rp); - /** * Reads from the socket into the provided buffer. If an error occurs * (read returns <= 0), the handling of the error is done here (close the * connection, log a message, etc). @@ -136,13 +135,16 @@ private: */ void raw_send(std::string&& data); + protected: + virtual bool is_connecting() const = 0; #ifdef BOTAN_FOUND /** * Create the TLS::Client object, with all the callbacks etc. This must be * called only when we know we are able to send TLS-encrypted data over * the socket. */ - void start_tls(); + void start_tls(const std::string& address, const std::string& port); + private: /** * An additional step to pass the data into our tls object to decrypt it * before passing it to parse_in_buffer. @@ -158,22 +160,31 @@ private: * Called by the tls object that some data has been decrypt. We call * parse_in_buffer() to handle that unencrypted data. */ - void tls_data_cb(const Botan::byte* data, size_t size); + void tls_record_received(uint64_t rec_no, const Botan::byte* data, size_t size) BOTAN_TLS_CALLBACKS_OVERRIDE; /** * Called by the tls object to indicate that some data has been encrypted * and is now ready to be sent on the socket as is. */ - void tls_output_fn(const Botan::byte* data, size_t size); + void tls_emit_data(const Botan::byte* data, size_t size) BOTAN_TLS_CALLBACKS_OVERRIDE; /** * Called by the tls object to indicate that a TLS alert has been * received. We don’t use it, we just log some message, at the moment. */ - void tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t); + void tls_alert(Botan::TLS::Alert alert) BOTAN_TLS_CALLBACKS_OVERRIDE; /** * Called by the tls object at the end of the TLS handshake. We don't do * anything here appart from logging the TLS session information. */ - bool tls_handshake_cb(const Botan::TLS::Session& session); + bool tls_session_established(const Botan::TLS::Session& session) BOTAN_TLS_CALLBACKS_OVERRIDE; + +#if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,34) + void tls_verify_cert_chain(const std::vector<Botan::X509_Certificate>& cert_chain, + const std::vector<std::shared_ptr<const Botan::OCSP::Response>>& ocsp_responses, + const std::vector<Botan::Certificate_Store*>& trusted_roots, + Botan::Usage_Type usage, + const std::string& hostname, + const Botan::TLS::Policy& policy) BOTAN_TLS_CALLBACKS_OVERRIDE; +#endif /** * Called whenever the tls session goes from inactive to active. This * means that the handshake has just been successfully done, and we can @@ -185,20 +196,11 @@ private: * Where data is added, when we want to send something to the client. */ std::vector<std::string> out_buf; +protected: /** - * DNS resolver - */ - Resolver resolver; - /** - * Keep the details of the addrinfo returned by the resolver that - * triggered a EINPROGRESS error when connect()ing to it, to reuse it - * directly when connect() is called again. + * Whether we are using TLS on this connection or not. */ - struct addrinfo addrinfo; - struct sockaddr_in6 ai_addr; - socklen_t ai_addrlen; - -protected: + bool use_tls; /** * Where data read from the socket is added until we can extract a full * and meaningful “message” from it. @@ -207,9 +209,9 @@ protected: */ std::string in_buf; /** - * Whether we are using TLS on this connection or not. + * Remove the given “size” first bytes from our in_buf. */ - bool use_tls; + void consume_in_buffer(const std::size_t size); /** * Provide a buffer in which data can be directly received. This can be * used to avoid copying data into in_buf before using it. If no buffer @@ -219,38 +221,12 @@ protected: */ virtual void* get_receive_buffer(const size_t size) const; /** - * Hostname we are connected/connecting to - */ - std::string address; - /** - * Port we are connected/connecting to - */ - std::string port; - - bool connected; - bool connecting; - - bool hostname_resolution_failed; - - /** - * Address to bind the socket to, before calling connect(). - * If empty, it’s equivalent to binding to INADDR_ANY. - */ - std::string bind_addr; - -private: - /** - * Display the resolved IP, just for information purpose. + * Called when we detect a disconnection from the remote host. */ - void display_resolved_ip(struct addrinfo* rp) const; + virtual void on_connection_close(const std::string&) {} + virtual void on_connection_failed(const std::string&) {} #ifdef BOTAN_FOUND - /** - * Botan stuff to manipulate a TLS session. - */ - static Botan::AutoSeeded_RNG rng; - static Botan::TLS::Policy policy; - static Botan::TLS::Session_Manager_In_Memory session_manager; protected: BasicCredentialsManager credential_manager; private: |