From b86547dc1ef407ca3838444533bc7145e32a0d90 Mon Sep 17 00:00:00 2001 From: Florent Le Coz Date: Wed, 9 Jul 2014 13:02:37 +0200 Subject: Implement async DNS resolution using c-ares fix #2533 --- src/config.h.cmake | 1 + src/main.cpp | 16 +++- src/network/dns_handler.cpp | 112 +++++++++++++++++++++++++ src/network/dns_handler.hpp | 62 ++++++++++++++ src/network/dns_socket_handler.cpp | 45 ++++++++++ src/network/dns_socket_handler.hpp | 46 ++++++++++ src/network/socket_handler.hpp | 2 + src/network/tcp_socket_handler.cpp | 166 +++++++++++++++++++++++++++++++++++-- src/network/tcp_socket_handler.hpp | 47 +++++++++-- src/xmpp/xmpp_component.cpp | 2 +- 10 files changed, 486 insertions(+), 13 deletions(-) create mode 100644 src/network/dns_handler.cpp create mode 100644 src/network/dns_handler.hpp create mode 100644 src/network/dns_socket_handler.cpp create mode 100644 src/network/dns_socket_handler.hpp (limited to 'src') diff --git a/src/config.h.cmake b/src/config.h.cmake index 8eb2d1c..18d546f 100644 --- a/src/config.h.cmake +++ b/src/config.h.cmake @@ -4,4 +4,5 @@ #cmakedefine SYSTEMD_FOUND #cmakedefine POLLER ${POLLER} #cmakedefine BOTAN_FOUND +#cmakedefine CARES_FOUND #cmakedefine BIBOUMI_VERSION "${BIBOUMI_VERSION}" diff --git a/src/main.cpp b/src/main.cpp index a67baf9..94c3cb5 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -11,6 +10,10 @@ #include +#ifdef CARES_FOUND +# include +#endif + // A flag set by the SIGINT signal handler. static volatile std::atomic stop(false); // Flag set by the SIGUSR1/2 signal handler. @@ -95,6 +98,10 @@ int main(int ac, char** av) xmpp_component->start(); + +#ifdef CARES_FOUND + DNSHandler::instance.watch_dns_sockets(p); +#endif auto timeout = TimedEventsManager::instance().get_timeout(); while (p->poll(timeout) != -1) { @@ -108,6 +115,9 @@ int main(int ac, char** av) exiting = true; stop.store(false); xmpp_component->shutdown(); +#ifdef CARES_FOUND + DNSHandler::instance.destroy(); +#endif // Cancel the timer for an potential reconnection TimedEventsManager::instance().cancel("XMPP reconnection"); } @@ -153,6 +163,10 @@ int main(int ac, char** av) xmpp_component->close(); if (exiting && p->size() == 1 && xmpp_component->is_document_open()) xmpp_component->close_document(); +#ifdef CARES_FOUND + if (!exiting) + DNSHandler::instance.watch_dns_sockets(p); +#endif if (exiting) // If we are exiting, do not wait for any timed event timeout = utils::no_timeout; else diff --git a/src/network/dns_handler.cpp b/src/network/dns_handler.cpp new file mode 100644 index 0000000..45bf626 --- /dev/null +++ b/src/network/dns_handler.cpp @@ -0,0 +1,112 @@ +#include +#ifdef CARES_FOUND + +#include +#include +#include +#include + +#include +#include + +DNSHandler DNSHandler::instance; + +using namespace std::string_literals; + +void on_hostname4_resolved(void* arg, int status, int, struct hostent* hostent) +{ + TCPSocketHandler* socket_handler = static_cast(arg); + socket_handler->on_hostname4_resolved(status, hostent); +} + +void on_hostname6_resolved(void* arg, int status, int, struct hostent* hostent) +{ + TCPSocketHandler* socket_handler = static_cast(arg); + socket_handler->on_hostname6_resolved(status, hostent); +} + +DNSHandler::DNSHandler() +{ + int ares_error; + if ((ares_error = ::ares_library_init(ARES_LIB_INIT_ALL)) != 0) + throw std::runtime_error("Failed to initialize c-ares lib: "s + ares_strerror(ares_error)); + if ((ares_error = ::ares_init(&this->channel)) != ARES_SUCCESS) + throw std::runtime_error("Failed to initialize c-ares channel: "s + ares_strerror(ares_error)); +} + +ares_channel& DNSHandler::get_channel() +{ + return this->channel; +} + +void DNSHandler::destroy() +{ + this->socket_handlers.clear(); + ::ares_destroy(this->channel); + ::ares_library_cleanup(); +} + +void DNSHandler::gethostbyname(const std::string& name, + TCPSocketHandler* socket_handler, int family) +{ + socket_handler->free_cares_addrinfo(); + if (family == AF_INET) + ::ares_gethostbyname(this->channel, name.data(), family, + &::on_hostname4_resolved, socket_handler); + else + ::ares_gethostbyname(this->channel, name.data(), family, + &::on_hostname6_resolved, socket_handler); +} + +void DNSHandler::watch_dns_sockets(std::shared_ptr& poller) +{ + fd_set readers; + fd_set writers; + + FD_ZERO(&readers); + FD_ZERO(&writers); + + int ndfs = ::ares_fds(this->channel, &readers, &writers); + // For each existing DNS socket, see if we are still supposed to watch it, + // if not then erase it + this->socket_handlers.erase( + std::remove_if(this->socket_handlers.begin(), this->socket_handlers.end(), + [&readers](const auto& dns_socket) + { + return !FD_ISSET(dns_socket->get_socket(), &readers); + }), + this->socket_handlers.end()); + + for (auto i = 0; i < ndfs; ++i) + { + bool read = FD_ISSET(i, &readers); + bool write = FD_ISSET(i, &writers); + // Look for the DNSSocketHandler with this fd + auto it = std::find_if(this->socket_handlers.begin(), + this->socket_handlers.end(), + [i](const auto& socket_handler) + { + return i == socket_handler->get_socket(); + }); + if (!read && !write) // No need to read or write to it + { // If found, erase it and stop watching it because it is not + // needed anymore + if (it != this->socket_handlers.end()) + // The socket destructor removes it from the poller + this->socket_handlers.erase(it); + } + else // We need to write and/or read to it + { // If not found, create it because we need to watch it + if (it == this->socket_handlers.end()) + { + this->socket_handlers.emplace_front(std::make_unique(poller, i)); + it = this->socket_handlers.begin(); + } + poller->add_socket_handler(it->get()); + if (write) + poller->watch_send_events(it->get()); + } + } +} + +#endif /* CARES_FOUND */ diff --git a/src/network/dns_handler.hpp b/src/network/dns_handler.hpp new file mode 100644 index 0000000..ec5b2fa --- /dev/null +++ b/src/network/dns_handler.hpp @@ -0,0 +1,62 @@ +#ifndef DNS_HANDLER_HPP_INCLUDED +#define DNS_HANDLER_HPP_INCLUDED + +#include +#ifdef CARES_FOUND + +class TCPSocketHandler; +class Poller; +class DNSSocketHandler; + +# include +# include +# include +# include + +void on_hostname4_resolved(void* arg, int status, int, struct hostent* hostent); +void on_hostname6_resolved(void* arg, int status, int, struct hostent* hostent); + +/** + * Class managing DNS resolution. It should only be statically instanciated + * once in SocketHandler. It manages ares channel and calls various + * functions of that library. + */ + +class DNSHandler +{ +public: + DNSHandler(); + ~DNSHandler() = default; + void gethostbyname(const std::string& name, TCPSocketHandler* socket_handler, + int family); + /** + * Call ares_fds to know what fd needs to be watched by the poller, create + * or destroy DNSSocketHandlers depending on the result. + */ + void watch_dns_sockets(std::shared_ptr& poller); + /** + * Destroy and stop watching all the DNS sockets. Then de-init the channel + * and library. + */ + void destroy(); + ares_channel& get_channel(); + + static DNSHandler instance; + +private: + /** + * The list of sockets that needs to be watched, according to the last + * call to ares_fds. DNSSocketHandlers are added to it or removed from it + * in the watch_dns_sockets() method + */ + std::list> socket_handlers; + ares_channel channel; + + DNSHandler(const DNSHandler&) = delete; + DNSHandler(DNSHandler&&) = delete; + DNSHandler& operator=(const DNSHandler&) = delete; + DNSHandler& operator=(DNSHandler&&) = delete; +}; + +#endif /* CARES_FOUND */ +#endif /* DNS_HANDLER_HPP_INCLUDED */ diff --git a/src/network/dns_socket_handler.cpp b/src/network/dns_socket_handler.cpp new file mode 100644 index 0000000..6563894 --- /dev/null +++ b/src/network/dns_socket_handler.cpp @@ -0,0 +1,45 @@ +#include +#ifdef CARES_FOUND + +#include +#include +#include + +#include + +DNSSocketHandler::DNSSocketHandler(std::shared_ptr poller, + const socket_t socket): + SocketHandler(poller, socket) +{ +} + +DNSSocketHandler::~DNSSocketHandler() +{ +} + +void DNSSocketHandler::connect() +{ +} + +void DNSSocketHandler::on_recv() +{ + // always stop watching send and read events. We will re-watch them if the + // next call to ares_fds tell us to + this->poller->remove_socket_handler(this->socket); + ::ares_process_fd(DNSHandler::instance.get_channel(), this->socket, ARES_SOCKET_BAD); +} + +void DNSSocketHandler::on_send() +{ + // always stop watching send and read events. We will re-watch them if the + // next call to ares_fds tell us to + this->poller->remove_socket_handler(this->socket); + ::ares_process_fd(DNSHandler::instance.get_channel(), ARES_SOCKET_BAD, this->socket); +} + +bool DNSSocketHandler::is_connected() const +{ + return true; +} + +#endif /* CARES_FOUND */ diff --git a/src/network/dns_socket_handler.hpp b/src/network/dns_socket_handler.hpp new file mode 100644 index 0000000..beb47d9 --- /dev/null +++ b/src/network/dns_socket_handler.hpp @@ -0,0 +1,46 @@ +#ifndef DNS_SOCKET_HANDLER_HPP +# define DNS_SOCKET_HANDLER_HPP + +#include +#ifdef CARES_FOUND + +#include +#include + +/** + * Manage a socket returned by ares_fds. We do not create, open or close the + * socket ourself: this is done by c-ares. We just call ares_process_fd() + * with the correct parameters, depending on what can be done on that socket + * (Poller reported it to be writable or readeable) + */ + +class DNSSocketHandler: public SocketHandler +{ +public: + explicit DNSSocketHandler(std::shared_ptr poller, const socket_t socket); + ~DNSSocketHandler(); + /** + * Just call dns_process_fd, c-ares will do its work of send()ing or + * recv()ing the data it wants on that socket. + */ + void on_recv() override final; + void on_send() override final; + /** + * Do nothing, because we are always considered to be connected, since the + * connection is done by c-ares and not by us. + */ + void connect() override final; + /** + * Always true, see the comment for connect() + */ + bool is_connected() const override final; + +private: + DNSSocketHandler(const DNSSocketHandler&) = delete; + DNSSocketHandler(DNSSocketHandler&&) = delete; + DNSSocketHandler& operator=(const DNSSocketHandler&) = delete; + DNSSocketHandler& operator=(DNSSocketHandler&&) = delete; +}; + +#endif // CARES_FOUND +#endif // DNS_SOCKET_HANDLER_HPP diff --git a/src/network/socket_handler.hpp b/src/network/socket_handler.hpp index 9a894a4..0858474 100644 --- a/src/network/socket_handler.hpp +++ b/src/network/socket_handler.hpp @@ -1,6 +1,7 @@ #ifndef SOCKET_HANDLER_HPP # define SOCKET_HANDLER_HPP +#include #include class Poller; @@ -19,6 +20,7 @@ public: virtual void on_send() = 0; virtual void connect() = 0; virtual bool is_connected() const = 0; + socket_t get_socket() const { return this->socket; } diff --git a/src/network/tcp_socket_handler.cpp b/src/network/tcp_socket_handler.cpp index 1d1eaa7..e9984e3 100644 --- a/src/network/tcp_socket_handler.cpp +++ b/src/network/tcp_socket_handler.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -42,8 +43,22 @@ TCPSocketHandler::TCPSocketHandler(std::shared_ptr poller): use_tls(false), connected(false), connecting(false) +#ifdef CARES_FOUND + ,resolved(false), + resolved4(false), + resolved6(false), + cares_addrinfo(nullptr), + cares_error() +#endif {} +TCPSocketHandler::~TCPSocketHandler() +{ +#ifdef CARES_FOUND + this->free_cares_addrinfo(); +#endif +} + void TCPSocketHandler::init_socket(const struct addrinfo* rp) { if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1) @@ -72,9 +87,35 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po if (!this->connecting) { + // Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if + // this is the first call of this function. +#ifdef CARES_FOUND + if (!this->resolved) + { + log_info("Trying to connect to " << address << ":" << port); + // Start the asynchronous process of resolving the hostname. Once + // the addresses have been found and `resolved` has been set to true + // (but connecting will still be false), TCPSocketHandler::connect() + // needs to be called, again. + DNSHandler::instance.gethostbyname(address, this, AF_INET6); + DNSHandler::instance.gethostbyname(address, this, AF_INET); + return; + } + else + { + // The c-ares resolved the hostname and the available addresses + // where saved in the cares_addrinfo linked list. Now, just use + // this list to try to connect. + addr_res = this->cares_addrinfo; + if (!addr_res) + { + this->close(); + this->on_connection_failed(this->cares_error); + return ; + } + } +#else log_info("Trying to connect to " << address << ":" << port); - // Get the addrinfo from getaddrinfo, only if this is the first call - // of this function. struct addrinfo hints; memset(&hints, 0, sizeof(struct addrinfo)); hints.ai_flags = 0; @@ -94,6 +135,7 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po // Make sure the alloced structure is always freed at the end of the // function sg.add_callback([&addr_res](){ freeaddrinfo(addr_res); }); +#endif } else { // This function is called again, use the saved addrinfo structure, @@ -144,9 +186,9 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po // If the connection has not succeeded or failed in 5s, we consider // it to have failed TimedEventsManager::instance().add_event( - TimedEvent(std::chrono::steady_clock::now() + 5s, - std::bind(&TCPSocketHandler::on_connection_timeout, this), - "connection_timeout"s + std::to_string(this->socket))); + TimedEvent(std::chrono::steady_clock::now() + 5s, + std::bind(&TCPSocketHandler::on_connection_timeout, this), + "connection_timeout"s + std::to_string(this->socket))); return ; } log_info("Connection failed:" << strerror(errno)); @@ -321,7 +363,11 @@ bool TCPSocketHandler::is_connected() const bool TCPSocketHandler::is_connecting() const { +#ifdef CARES_FOUND + return this->connecting || !this->resolved; +#else return this->connecting; +#endif } void* TCPSocketHandler::get_receive_buffer(const size_t) const @@ -413,4 +459,114 @@ void TCPSocketHandler::on_tls_activated() { this->send_data(""); } + #endif // BOTAN_FOUND + +#ifdef CARES_FOUND + +void TCPSocketHandler::on_hostname4_resolved(int status, struct hostent* hostent) +{ + this->resolved4 = true; + if (status == ARES_SUCCESS) + this->fill_ares_addrinfo4(hostent); + else + this->cares_error = ::ares_strerror(status); + + if (this->resolved4 && this->resolved6) + { + this->resolved = true; + this->connect(); + } +} + +void TCPSocketHandler::on_hostname6_resolved(int status, struct hostent* hostent) +{ + this->resolved6 = true; + if (status == ARES_SUCCESS) + this->fill_ares_addrinfo6(hostent); + else + this->cares_error = ::ares_strerror(status); + + if (this->resolved4 && this->resolved6) + { + this->resolved = true; + this->connect(); + } +} + +void TCPSocketHandler::fill_ares_addrinfo4(const struct hostent* hostent) +{ + struct addrinfo* prev = this->cares_addrinfo; + struct in_addr** address = reinterpret_cast(hostent->h_addr_list); + + while (*address) + { + // Create a new addrinfo list element, and fill it + struct addrinfo* current = new struct addrinfo; + current->ai_flags = 0; + current->ai_family = hostent->h_addrtype; + current->ai_socktype = SOCK_STREAM; + current->ai_protocol = 0; + current->ai_addrlen = sizeof(struct sockaddr_in); + + struct sockaddr_in* addr = new struct sockaddr_in; + addr->sin_family = hostent->h_addrtype; + addr->sin_port = htons(strtoul(this->port.data(), nullptr, 10)); + addr->sin_addr.s_addr = (*address)->s_addr; + + current->ai_addr = reinterpret_cast(addr); + current->ai_next = nullptr; + current->ai_canonname = nullptr; + + current->ai_next = prev; + this->cares_addrinfo = current; + prev = current; + ++address; + } +} + +void TCPSocketHandler::fill_ares_addrinfo6(const struct hostent* hostent) +{ + struct addrinfo* prev = this->cares_addrinfo; + struct in6_addr** address = reinterpret_cast(hostent->h_addr_list); + + while (*address) + { + // Create a new addrinfo list element, and fill it + struct addrinfo* current = new struct addrinfo; + current->ai_flags = 0; + current->ai_family = hostent->h_addrtype; + current->ai_socktype = SOCK_STREAM; + current->ai_protocol = 0; + current->ai_addrlen = sizeof(struct sockaddr_in6); + + struct sockaddr_in6* addr = new struct sockaddr_in6; + addr->sin6_family = hostent->h_addrtype; + addr->sin6_port = htons(strtoul(this->port.data(), nullptr, 10)); + ::memcpy(addr->sin6_addr.s6_addr, (*address)->s6_addr, 16); + addr->sin6_flowinfo = 0; + addr->sin6_scope_id = 0; + + current->ai_addr = reinterpret_cast(addr); + current->ai_next = nullptr; + current->ai_canonname = nullptr; + + current->ai_next = prev; + this->cares_addrinfo = current; + prev = current; + ++address; + } +} + +void TCPSocketHandler::free_cares_addrinfo() +{ + while (this->cares_addrinfo) + { + delete this->cares_addrinfo->ai_addr; + auto next = this->cares_addrinfo->ai_next; + delete this->cares_addrinfo; + this->cares_addrinfo = next; + } +} + +#endif // CARES_FOUND diff --git a/src/network/tcp_socket_handler.hpp b/src/network/tcp_socket_handler.hpp index 6d4bbe4..7f10cff 100644 --- a/src/network/tcp_socket_handler.hpp +++ b/src/network/tcp_socket_handler.hpp @@ -17,6 +17,10 @@ #include "config.h" +#ifdef CARES_FOUND +# include +#endif + #ifdef BOTAN_FOUND # include # include @@ -44,7 +48,7 @@ public: class TCPSocketHandler: public SocketHandler { protected: - ~TCPSocketHandler() {} + ~TCPSocketHandler(); public: explicit TCPSocketHandler(std::shared_ptr poller); @@ -54,16 +58,16 @@ public: * start_tls() when the connection succeeds. */ void connect(const std::string& address, const std::string& port, const bool tls); - void connect(); + void connect() override final; /** * Reads raw data from the socket. And pass it to parse_in_buffer() * If we are using TLS on this connection, we call tls_recv() */ - void on_recv(); + void on_recv() override final; /** * Write as much data from out_buf as possible, in the socket. */ - void on_send(); + void on_send() override final; /** * Add the given data to out_buf and tell our poller that we want to be * notified when a send event is ready. @@ -107,9 +111,19 @@ public: * The size argument is the size of the last chunk of data that was added to the buffer. */ virtual void parse_in_buffer(const size_t size) = 0; - bool is_connected() const; + bool is_connected() const override final; bool is_connecting() const; +#ifdef CARES_FOUND + void on_hostname4_resolved(int status, struct hostent* hostent); + void on_hostname6_resolved(int status, struct hostent* hostent); + + void free_cares_addrinfo(); + + void fill_ares_addrinfo4(const struct hostent* hostent); + void fill_ares_addrinfo6(const struct hostent* hostent); +#endif + private: /** * Initialize the socket with the parameters contained in the given @@ -185,7 +199,7 @@ private: */ std::list out_buf; /** - * Keep the details of the addrinfo the triggered a EINPROGRESS error when + * Keep the details of the addrinfo that triggered a EINPROGRESS error when * connect()ing to it, to reuse it directly when connect() is called * again. */ @@ -225,6 +239,27 @@ protected: bool connected; bool connecting; +#ifdef CARES_FOUND + /** + * Whether or not the DNS resolution was successfully done + */ + bool resolved; + bool resolved4; + bool resolved6; + /** + * When using c-ares to resolve the host asynchronously, we need the + * c-ares callback to fill a structure (a struct addrinfo, for + * compatibility with getaddrinfo and the rest of the code that works when + * c-ares is not used) with all returned values (for example an IPv6 and + * an IPv4). The next call of connect() will then try all these values + * (exactly like we do with the result of getaddrinfo) and save the one + * that worked (or returned EINPROGRESS) in the other struct addrinfo (see + * the members addrinfo, ai_addrlen, and ai_addr). + */ + struct addrinfo* cares_addrinfo; + std::string cares_error; +#endif // CARES_FOUND + private: TCPSocketHandler(const TCPSocketHandler&) = delete; TCPSocketHandler(TCPSocketHandler&&) = delete; diff --git a/src/xmpp/xmpp_component.cpp b/src/xmpp/xmpp_component.cpp index 841ead4..1df1e5d 100644 --- a/src/xmpp/xmpp_component.cpp +++ b/src/xmpp/xmpp_component.cpp @@ -70,7 +70,7 @@ XmppComponent::~XmppComponent() void XmppComponent::start() { - this->connect("127.0.0.1", Config::get("port", "5347"), false); + this->connect("localhost", Config::get("port", "5347"), false); } bool XmppComponent::is_document_open() const -- cgit v1.2.3