#include <network/tcp_client_socket_handler.hpp>
#include <utils/timed_events.hpp>
#include <utils/scopeguard.hpp>
#include <network/poller.hpp>

#include <logger/logger.hpp>

#include <cerrno>
#include <cstring>
#include <stdexcept>
#include <string>

#include <unistd.h>
#include <fcntl.h>

using namespace std::string_literals;

/**
 * Build a client handler attached to `poller`. The handler starts out
 * neither connecting nor connected, with no hostname resolution failure
 * recorded.
 */
TCPClientSocketHandler::TCPClientSocketHandler(std::shared_ptr<Poller>& poller):
  TCPSocketHandler(poller),
  hostname_resolution_failed(false),
  connected(false),
  connecting(false)
{}

/**
 * Release everything through close(). This cancels the pending connection
 * timeout, closes the socket and clears the resolver state.
 */
TCPClientSocketHandler::~TCPClientSocketHandler()
{
  this->close();
}

/**
 * (Re)create this->socket for the family, type and protocol described by
 * `rp`. Any socket that is already open is closed first.
 *
 * - If this->bind_addr is non-empty, the new socket is bound to it. Only
 *   numeric addresses are accepted (AI_NUMERICHOST). A bind failure is
 *   logged and is not fatal.
 * - TCP keepalive (SO_KEEPALIVE) is enabled. A failure here only logs a
 *   warning.
 * - The socket is switched to non-blocking mode.
 *
 * @param rp Address entry whose ai_family/ai_socktype/ai_protocol are used
 *           to create the socket.
 * @throws std::runtime_error if the socket cannot be created or cannot be
 *         made non-blocking.
 */
void TCPClientSocketHandler::init_socket(const struct addrinfo* rp)
{
  if (this->socket != -1)
    ::close(this->socket);
  if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
    throw std::runtime_error("Could not create socket: "s + std::strerror(errno));

  // Bind the socket to a specific address, if specified
  if (!this->bind_addr.empty())
    {
      // Convert the address from string format to a sockaddr that can be
      // used in bind()
      struct addrinfo* result = nullptr;
      struct addrinfo hints{};  // value-initialized: every field is zero
      hints.ai_flags = AI_NUMERICHOST;
      hints.ai_family = AF_UNSPEC;
      const int err = ::getaddrinfo(this->bind_addr.data(), nullptr, &hints, &result);
      if (err != 0 || !result)
        log_error("Failed to bind socket to ", this->bind_addr, ": ", gai_strerror(err));
      else
        {
          utils::ScopeGuard sg([result](){ freeaddrinfo(result); });
          // Named bind_rp so it does not shadow the `rp` parameter.
          struct addrinfo* bind_rp;
          for (bind_rp = result; bind_rp; bind_rp = bind_rp->ai_next)
            {
              if (::bind(this->socket, bind_rp->ai_addr, bind_rp->ai_addrlen) == 0)
                break;
            }
          // Nothing between the last bind() and here touches errno.
          if (!bind_rp)
            log_error("Failed to bind socket to ", this->bind_addr, ": ", strerror(errno));
          else
            log_info("Socket successfully bound to ", this->bind_addr);
        }
    }

  int optval = 1;
  if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1)
    log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno));

  // Set the socket on non-blocking mode. This is useful to receive a EAGAIN
  // error when connect() would block, to not block the whole process if a
  // remote is not responsive.
  const int existing_flags = ::fcntl(this->socket, F_GETFL, 0);
  if ((existing_flags == -1) ||
      (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1))
    throw std::runtime_error("Could not initialize socket: "s + std::strerror(errno));
}

/**
 * Start, or continue, an asynchronous connection to address:port.
 *
 * This function is called repeatedly and moves through these stages:
 *  1. If the hostname is not resolved yet, resolution is started through
 *     this->resolver and the function returns. Both the success callback
 *     and the failure callback call connect() again.
 *  2. Once resolution is done, each returned address is tried in turn. A
 *     socket is created (init_socket()) and a non-blocking ::connect() is
 *     issued.
 *     - Success (connect() returns 0 or fails with EISCONN): the timeout
 *       event is cancelled and the handler is registered with the poller.
 *       TLS is started if requested (only when BOTAN_FOUND is defined).
 *       The local port is recorded and on_connected() is called.
 *     - EINPROGRESS/EALREADY: the address is saved in this->addrinfo and
 *       the poller watches for send events. A 5 second timeout event is
 *       scheduled, and the function returns with connecting == true. The
 *       next call retries ::connect() on that single saved address only.
 *     - Any other error, or an init_socket() failure: the next address is
 *       tried.
 *  3. If resolution failed or every attempt failed, close() is called and
 *     on_connection_failed() receives the reason.
 *
 * @param address Hostname or address to connect to (stored in this->address).
 * @param port    Port or service as a string (stored in this->port).
 * @param tls     Whether to start TLS once connected. It only has an effect
 *                when BOTAN_FOUND is defined.
 */
void TCPClientSocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
{
  this->address = address;
  this->port = port;
  this->use_tls = tls;

  struct addrinfo* addr_res = nullptr;

  if (!this->connecting)
    {
      // Get the addrinfo from getaddrinfo (or using udns), only if
      // this is the first call of this function.
      if (!this->resolver.is_resolved())
        {
          log_info("Trying to connect to ", address, ":", port);
          // Start the asynchronous process of resolving the hostname. Once
          // the addresses have been found and `resolved` has been set to true
          // (but connecting will still be false), TCPClientSocketHandler::connect()
          // needs to be called, again.
          this->resolver.resolve(address, port,
                                 [this](const struct addrinfo*)
                                 {
                                   log_debug("Resolution success, calling connect() again");
                                   this->connect();
                                 },
                                 [this](const char*)
                                 {
                                   log_debug("Resolution failed, calling connect() again");
                                   this->connect();
                                 });
          return;
        }
      else
        {
          // The DNS resolver resolved the hostname and the available addresses
          // were saved in the addrinfo linked list. Now, just use
          // this list to try to connect.
          addr_res = this->resolver.get_result().get();
          if (!addr_res)
            {
              this->hostname_resolution_failed = true;
              // Fetch the message before close(), which clears the resolver.
              const auto msg = this->resolver.get_error_message();
              this->close();
              this->on_connection_failed(msg);
              return;
            }
        }
    }
  else
    {
      // This function is called again, use the saved addrinfo structure,
      // instead of re-doing the whole getaddrinfo process.
      addr_res = &this->addrinfo;
    }

  // Reason for the most recent failed attempt, reported to
  // on_connection_failed() if no address works. It is captured right away
  // because logging and close() may overwrite errno before the report.
  std::string last_error;

  for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
    {
      if (!this->connecting)
        {
          try
            {
              this->init_socket(rp);
            }
          catch (const std::runtime_error& error)
            {
              log_error("Failed to init socket: ", error.what());
              last_error = error.what();
              // This address may be unusable (e.g. unsupported family)
              // while the following ones still work: keep trying.
              continue;
            }
        }

      this->display_resolved_ip(rp);

      const int connect_res = ::connect(this->socket, rp->ai_addr, rp->ai_addrlen);
      const int connect_errno = errno;  // saved before any other call can clobber it
      if (connect_res == 0 || connect_errno == EISCONN)
        {
          log_info("Connection success.");
          TimedEventsManager::instance().cancel("connection_timeout" +
                                                std::to_string(this->socket));
          this->poller->add_socket_handler(this);
          this->connected = true;
          this->connecting = false;
#ifdef BOTAN_FOUND
          if (this->use_tls)
            this->start_tls(this->address, this->port);
#endif
          this->connection_date = std::chrono::system_clock::now();

          // Get our local TCP port and store it. It stays at the max
          // uint16_t value if getsockname() fails or the family is neither
          // IPv4 nor IPv6.
          this->local_port = static_cast<uint16_t>(-1);
          if (rp->ai_family == AF_INET6)
            {
              struct sockaddr_in6 a{};
              socklen_t l = sizeof(a);
              if (::getsockname(this->socket, reinterpret_cast<struct sockaddr*>(&a), &l) != -1)
                this->local_port = ntohs(a.sin6_port);
            }
          else if (rp->ai_family == AF_INET)
            {
              struct sockaddr_in a{};
              socklen_t l = sizeof(a);
              if (::getsockname(this->socket, reinterpret_cast<struct sockaddr*>(&a), &l) != -1)
                this->local_port = ntohs(a.sin_port);
            }

          log_debug("Local port: ", this->local_port, ", and remote port: ", this->port);

          this->on_connected();
          return;
        }
      else if (connect_errno == EINPROGRESS || connect_errno == EALREADY)
        {   // retry this process later, when the socket
            // is ready to be written on.
          this->connecting = true;
          this->poller->add_socket_handler(this);
          this->poller->watch_send_events(this);
          // Save the addrinfo structure, to use it on the next call. The
          // copy points at our own copy of the sockaddr and has no
          // successor, so the next call only retries this address.
          this->ai_addrlen = rp->ai_addrlen;
          memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
          memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
          this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr);
          this->addrinfo.ai_next = nullptr;
          // If the connection has not succeeded or failed in 5s, we consider
          // it to have failed
          TimedEventsManager::instance().add_event(
              TimedEvent(std::chrono::steady_clock::now() + 5s,
                         std::bind(&TCPClientSocketHandler::on_connection_timeout, this),
                         "connection_timeout" + std::to_string(this->socket)));
          return;
        }
      last_error = std::strerror(connect_errno);
      log_info("Connection failed:", last_error);
    }
  log_error("All connection attempts failed.");
  this->close();
  this->on_connection_failed(last_error);
}

/**
 * Timed-event callback scheduled by connect() when a non-blocking connect
 * is still pending. It closes the handler and reports
 * "connection timed out".
 */
void TCPClientSocketHandler::on_connection_timeout()
{
  this->close();
  this->on_connection_failed("connection timed out");
}

/**
 * Re-enter connect() with the address, port and TLS flag stored by a
 * previous call. The resolver callbacks in connect() use this.
 */
void TCPClientSocketHandler::connect()
{
  this->connect(this->address, this->port, this->use_tls);
}

/**
 * Tear the connection down. This cancels the pending connection timeout
 * for the current socket, closes the socket through
 * TCPSocketHandler::close(), resets the connected/connecting flags, and
 * clears the stored port and the resolver. this->address and
 * this->use_tls are left as they are.
 */
void TCPClientSocketHandler::close()
{
  // The timeout event name is keyed on the socket fd, so build it before
  // the base close() (which presumably resets this->socket — verify).
  TimedEventsManager::instance().cancel("connection_timeout" +
                                        std::to_string(this->socket));
  TCPSocketHandler::close();
  this->connected = false;
  this->connecting = false;
  this->port.clear();
  this->resolver.clear();
}

/**
 * Log (at debug level) the IPv4 or IPv6 address about to be tried.
 * Other address families are not logged.
 */
void TCPClientSocketHandler::display_resolved_ip(struct addrinfo* rp) const
{
  if (rp->ai_family == AF_INET)
    log_debug("Trying IPv4 address ", addr_to_string(rp));
  else if (rp->ai_family == AF_INET6)
    log_debug("Trying IPv6 address ", addr_to_string(rp));
}

/** @return true once connect() has completed successfully and until close(). */
bool TCPClientSocketHandler::is_connected() const
{
  return this->connected;
}

/**
 * @return true while either a non-blocking connect is pending or the
 *         hostname is still being resolved.
 */
bool TCPClientSocketHandler::is_connecting() const
{
  return this->connecting || this->resolver.is_resolving();
}

/** @return the port string given to the last connect() call (cleared by close()). */
std::string TCPClientSocketHandler::get_port() const
{
  return this->port;
}

/**
 * Check whether this handler is the connected endpoint of a
 * (local port, remote port) pair.
 *
 * @param local  Local TCP port to compare with the one recorded at connection.
 * @param remote Remote TCP port to compare with this->port.
 * @return true only if connected and both ports match. Returns false,
 *         without throwing, when this->port is empty (e.g. after close()),
 *         is not a number, or is out of the 0-65535 range.
 */
bool TCPClientSocketHandler::match_port_pairt(const uint16_t local, const uint16_t remote) const
{
  // Check the cheap conditions first. close() clears this->port, so on a
  // closed handler the string is empty and must not reach std::stoi().
  if (!this->is_connected() || local != this->local_port)
    return false;

  int remote_port = 0;
  try
    {
      remote_port = std::stoi(this->port);
    }
  catch (const std::exception&)
    {
      // std::invalid_argument (e.g. a service name) or std::out_of_range
      return false;
    }
  if (remote_port < 0 || remote_port > 65535)
    return false;
  return remote == static_cast<uint16_t>(remote_port);
}