#include <network/tcp_socket_handler.hpp>
#include <network/dns_handler.hpp>

#include <utils/timed_events.hpp>
#include <utils/scopeguard.hpp>
#include <network/poller.hpp>

#include <logger/logger.hpp>

#include <sys/socket.h>
#include <sys/types.h>
#include <stdexcept>
#include <unistd.h>
#include <errno.h>
#include <cstring>
#include <fcntl.h>

#ifdef BOTAN_FOUND
# include <botan/hex.h>
# include <botan/tls_exceptn.h>

// Definitions of the TLS state shared by every TCPSocketHandler.
Botan::AutoSeeded_RNG TCPSocketHandler::rng;
Botan::TLS::Policy TCPSocketHandler::policy;
Botan::TLS::Session_Manager_In_Memory TCPSocketHandler::session_manager(TCPSocketHandler::rng);

#endif

// Maximum number of buffers handed to a single sendmsg() call in on_send().
#ifndef UIO_FASTIOV
# define UIO_FASTIOV 8
#endif

using namespace std::string_literals;
using namespace std::chrono_literals;

namespace ph = std::placeholders;

// Build a handler that is neither connected nor connecting; the socket
// descriptor starts at -1 and is created later by init_socket().
TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller):
  SocketHandler(poller, -1),
  use_tls(false),
  connected(false),
  connecting(false),
  hostname_resolution_failed(false)
#ifdef BOTAN_FOUND
  ,credential_manager(this)
#endif
{}

TCPSocketHandler::~TCPSocketHandler()
{
  this->close();
}

// (Re)create the socket for the given address candidate, optionally bind
// it to this->bind_addr, enable keepalive and switch it to non-blocking
// mode.
//
// Throws std::runtime_error if the socket cannot be created or cannot be
// made non-blocking. A failure to bind or to enable keepalive is only
// logged.
void TCPSocketHandler::init_socket(const struct addrinfo* rp)
{
  if (this->socket != -1)
    ::close(this->socket);
  if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
    throw std::runtime_error("Could not create socket: "s + strerror(errno));

  // Bind the socket to a specific address, if specified
  if (!this->bind_addr.empty())
    {
      // Convert the address from string format to a sockaddr that can be
      // used in bind()
      struct addrinfo* result = nullptr;
      const int err = ::getaddrinfo(this->bind_addr.c_str(), nullptr, nullptr, &result);
      if (err != 0 || !result)
        log_error("Failed to bind socket to ", this->bind_addr, ": ",
                  gai_strerror(err));
      else
        {
          utils::ScopeGuard sg([result](){ freeaddrinfo(result); });
          // Try every candidate until one of them can be bound. The errno
          // of the last failed attempt is kept for the error message.
          const struct addrinfo* bound = nullptr;
          int bind_errno = 0;
          for (const struct addrinfo* candidate = result; candidate;
               candidate = candidate->ai_next)
            {
              if (::bind(this->socket, candidate->ai_addr,
                         candidate->ai_addrlen) == 0)
                {
                  bound = candidate;
                  break;
                }
              bind_errno = errno;
            }
          if (!bound)
            log_error("Failed to bind socket to ", this->bind_addr, ": ",
                      strerror(bind_errno));
          else
            log_info("Socket successfully bound to ", this->bind_addr);
        }
    }

  int optval = 1;
  if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1)
    log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno));

  // Set the socket on non-blocking mode.  This is useful to receive a EAGAIN
  // error when connect() would block, to not block the whole process if a
  // remote is not responsive.
  const int existing_flags = ::fcntl(this->socket, F_GETFL, 0);
  if ((existing_flags == -1) ||
      (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1))
    throw std::runtime_error("Could not initialize socket: "s + strerror(errno));
}

// Start, or continue, a connection to address:port.
//
// This function is called several times for one connection attempt:
//  - first to start the asynchronous hostname resolution;
//  - again (through connect()) by the resolver callbacks once the
//    resolution has succeeded or failed;
//  - again while a non-blocking connect() is in progress
//    (this->connecting is true), using the address saved in
//    this->addrinfo.
// On success on_connected() is called; on failure the handler is closed
// and on_connection_failed() is called with an error message.
void TCPSocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
{
  this->address = address;
  this->port = port;
  this->use_tls = tls;

  struct addrinfo* addr_res = nullptr;

  if (!this->connecting)
    {
      // Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if
      // this is the first call of this function.
      if (!this->resolver.is_resolved())
        {
          log_info("Trying to connect to ", address, ":", port);
          // Start the asynchronous process of resolving the hostname. Once
          // the addresses have been found and `resolved` has been set to true
          // (but connecting will still be false), TCPSocketHandler::connect()
          // needs to be called, again.
          this->resolver.resolve(address, port,
                                 [this](const struct addrinfo*)
                                 {
                                   log_debug("Resolution success, calling connect() again");
                                   this->connect();
                                 },
                                 [this](const char*)
                                 {
                                   log_debug("Resolution failed, calling connect() again");
                                   this->connect();
                                 });
          return;
        }
      else
        {
          // The c-ares resolved the hostname and the available addresses
          // were saved in the cares_addrinfo linked list. Now, just use
          // this list to try to connect.
          addr_res = this->resolver.get_result().get();
          if (!addr_res)
            {
              this->hostname_resolution_failed = true;
              // Copy the message before close(), which clears the resolver.
              const auto msg = this->resolver.get_error_message();
              this->close();
              this->on_connection_failed(msg);
              return ;
            }
        }
    }
  else
    {
      // This function is called again, use the saved addrinfo structure,
      // instead of re-doing the whole getaddrinfo process.
      addr_res = &this->addrinfo;
    }

  // errno of the last failed step, saved so that the logging and close()
  // calls below cannot overwrite it before it is reported.
  int last_error = 0;

  for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
    {
      if (!this->connecting)
        {
          try {
            this->init_socket(rp);
          }
          catch (const std::runtime_error& error) {
            last_error = errno;
            log_error("Failed to init socket: ", error.what());
            break;
          }
        }

      this->display_resolved_ip(rp);

      const int res = ::connect(this->socket, rp->ai_addr, rp->ai_addrlen);
      const int connect_errno = (res == 0) ? 0 : errno;

      if (res == 0 || connect_errno == EISCONN)
        {
          log_info("Connection success.");
          TimedEventsManager::instance().cancel("connection_timeout"s +
                                                std::to_string(this->socket));
          this->poller->add_socket_handler(this);
          this->connected = true;
          this->connecting = false;
#ifdef BOTAN_FOUND
          if (this->use_tls)
            this->start_tls();
#endif
          this->on_connected();
          return ;
        }
      else if (connect_errno == EINPROGRESS || connect_errno == EALREADY)
        {
          // retry this process later, when the socket
          // is ready to be written on.
          this->connecting = true;
          this->poller->add_socket_handler(this);
          this->poller->watch_send_events(this);
          // Save the addrinfo structure, to use it on the next call
          this->ai_addrlen = rp->ai_addrlen;
          memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
          memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
          this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr);
          this->addrinfo.ai_next = nullptr;
          // If the connection has not succeeded or failed in 5s, we consider
          // it to have failed
          TimedEventsManager::instance().add_event(
              TimedEvent(std::chrono::steady_clock::now() + 5s,
                         std::bind(&TCPSocketHandler::on_connection_timeout, this),
                         "connection_timeout"s + std::to_string(this->socket)));
          return ;
        }
      last_error = connect_errno;
      log_info("Connection failed: ", strerror(connect_errno));
    }
  log_error("All connection attempts failed.");
  this->close();
  this->on_connection_failed(strerror(last_error));
  return ;
}

// Callback of the "connection_timeout" timed event registered by connect().
void TCPSocketHandler::on_connection_timeout()
{
  this->close();
  this->on_connection_failed("connection timed out");
}

// Continue the connection process with the parameters saved by the last
// call to connect(address, port, tls).
void TCPSocketHandler::connect()
{
  this->connect(this->address, this->port, this->use_tls);
}

// Read the available data, through TLS if it is enabled.
void TCPSocketHandler::on_recv()
{
#ifdef BOTAN_FOUND
  if (this->use_tls)
    this->tls_recv();
  else
#endif
    this->plain_recv();
}

// Receive up to 4096 bytes of plain-text data and hand them to
// parse_in_buffer().
void TCPSocketHandler::plain_recv()
{
  static constexpr size_t buf_size = 4096;
  char buf[buf_size];
  // Receive directly into the buffer provided by get_receive_buffer(),
  // or into the local one if none is provided.
  void* recv_buf = this->get_receive_buffer(buf_size);

  if (recv_buf == nullptr)
    recv_buf = buf;

  const ssize_t size = this->do_recv(recv_buf, buf_size);

  if (size > 0)
    {
      if (buf == recv_buf)
        {
          // data needs to be placed in the in_buf string, because no buffer
          // was provided to receive that data directly. The in_buf buffer
          // will be handled in parse_in_buffer()
          this->in_buf.append(buf, static_cast<size_t>(size));
        }
      this->parse_in_buffer(size);
    }
}

// Call recv() on the socket and handle the end-of-stream and error cases.
//
// Returns the value returned by recv(): the number of bytes received, 0 if
// the peer closed the connection, or -1 on error. On 0, and on -1 with a
// non-transient error, the handler is closed and on_connection_close() or
// on_connection_failed() is called.
ssize_t TCPSocketHandler::do_recv(void* recv_buf, const size_t buf_size)
{
  const ssize_t size = ::recv(this->socket, recv_buf, buf_size, 0);
  if (0 == size)
    {
      this->on_connection_close("");
      this->close();
    }
  else if (-1 == size)
    {
      // Save errno right away: the logging and close() calls below may
      // overwrite it.
      const int recv_errno = errno;
      // The socket is non-blocking (see init_socket()), so these values
      // mean "try again later", not that the connection is broken.
      if (recv_errno == EAGAIN || recv_errno == EWOULDBLOCK || recv_errno == EINTR)
        return size;

      // Remember if we were connecting, or already connected when this
      // happened, because close() sets this->connecting to false
      const bool were_connecting = this->connecting;
      if (were_connecting)
        log_warning("Error connecting: ", strerror(recv_errno));
      else
        log_warning("Error while reading from socket: ", strerror(recv_errno));
      this->close();
      if (were_connecting)
        this->on_connection_failed(strerror(recv_errno));
      else
        this->on_connection_close(strerror(recv_errno));
    }
  return size;
}

// Send as much of out_buf as possible with a single sendmsg() call (at
// most UIO_FASTIOV strings), then drop what has been sent from out_buf.
// Send events stop being watched once out_buf is empty.
void TCPSocketHandler::on_send()
{
  struct iovec msg_iov[UIO_FASTIOV] = {};
  struct msghdr msg{nullptr, 0,
      msg_iov,
      0, nullptr, 0, 0};
  for (const std::string& s: this->out_buf)
    {
      // unconsting the content of s is ok, sendmsg will never modify it
      msg_iov[msg.msg_iovlen].iov_base = const_cast<char*>(s.data());
      msg_iov[msg.msg_iovlen].iov_len = s.size();
      msg.msg_iovlen++;
      if (msg.msg_iovlen == UIO_FASTIOV)
        break;
    }
  const ssize_t res = ::sendmsg(this->socket, &msg, MSG_NOSIGNAL);
  if (res < 0)
    {
      // Save errno right away: the logging call below may overwrite it.
      const int send_errno = errno;
      // The socket is non-blocking (see init_socket()), so these values
      // mean "try again later". out_buf is left untouched for the next
      // call.
      if (send_errno == EAGAIN || send_errno == EWOULDBLOCK || send_errno == EINTR)
        return ;
      log_error("sendmsg failed: ", strerror(send_errno));
      this->on_connection_close(strerror(send_errno));
      this->close();
    }
  else
    {
      // remove all the strings that were successfully sent.
      size_t remaining = static_cast<size_t>(res);
      auto it = this->out_buf.begin();
      while (it != this->out_buf.end() && remaining >= it->size())
        {
          remaining -= it->size();
          ++it;
        }
      // If one string has partially been sent, crop it in place so that
      // only its unsent tail stays in out_buf.
      if (it != this->out_buf.end() && remaining > 0)
        it->erase(0, remaining);
      this->out_buf.erase(this->out_buf.begin(), it);
      if (this->out_buf.empty())
        this->poller->stop_watching_send_events(this);
    }
}

// Close the socket (if any), unregister the handler from the poller and
// reset the connection state, the buffers, the port and the resolver.
void TCPSocketHandler::close()
{
  // The event name contains the socket descriptor, so it must be
  // cancelled before this->socket is reset to -1.
  TimedEventsManager::instance().cancel("connection_timeout"s +
                                        std::to_string(this->socket));
  if (this->connected || this->connecting)
    this->poller->remove_socket_handler(this->get_socket());
  if (this->socket != -1)
    {
      ::close(this->socket);
      this->socket = -1;
    }
  this->connected = false;
  this->connecting = false;
  this->in_buf.clear();
  this->out_buf.clear();
  this->port.clear();
  this->resolver.clear();
}

// Log, at debug level, the IPv4 or IPv6 address that is about to be tried.
// Other address families are not logged.
void TCPSocketHandler::display_resolved_ip(struct addrinfo* rp) const
{
  if (rp->ai_family == AF_INET)
    log_debug("Trying IPv4 address ", addr_to_string(rp));
  else if (rp->ai_family == AF_INET6)
    log_debug("Trying IPv6 address ", addr_to_string(rp));
}

// Queue data to be sent, through TLS if it is enabled. A TLS exception
// raised while sending closes the connection.
void TCPSocketHandler::send_data(std::string&& data)
{
#ifdef BOTAN_FOUND
  if (this->use_tls)
    {
      try {
        this->tls_send(std::move(data));
      } catch (const Botan::TLS::TLS_Exception& e) {
        this->on_connection_close("TLS error: "s + e.what());
        this->close();
      }
      return ;
    }
#endif
  this->raw_send(std::move(data));
}

// Append data to out_buf, without any TLS processing. Empty data is
// ignored. Send events are only watched if the handler is connected.
void TCPSocketHandler::raw_send(std::string&& data)
{
  if (data.empty())
    return ;
  this->out_buf.emplace_back(std::move(data));
  if (this->connected)
    this->poller->watch_send_events(this);
}

// Ask the poller to watch send events if the handler is connected and
// out_buf is not empty.
void TCPSocketHandler::send_pending_data()
{
  if (this->connected && !this->out_buf.empty())
    this->poller->watch_send_events(this);
}

bool TCPSocketHandler::is_connected() const
{
  return this->connected;
}

// The hostname resolution counts as being part of the connection process.
bool TCPSocketHandler::is_connecting() const
{
  return this->connecting || this->resolver.is_resolving();
}

// Default implementation: no buffer is provided, so plain_recv() receives
// into its own local buffer and appends the data to in_buf.
void* TCPSocketHandler::get_receive_buffer(const size_t) const
{
  return nullptr;
}

#ifdef BOTAN_FOUND
// Create the TLS client for the current connection. Its callbacks are
// bound to the tls_* member functions of this handler.
void TCPSocketHandler::start_tls()
{
  // NOTE(review): std::stoul() throws if this->port is not numeric
  // (for example a service name) -- confirm that callers only ever pass
  // numeric ports.
  Botan::TLS::Server_Information server_info(this->address, "irc", std::stoul(this->port));
  this->tls = std::make_unique<Botan::TLS::Client>(
      std::bind(&TCPSocketHandler::tls_output_fn, this, ph::_1, ph::_2),
      std::bind(&TCPSocketHandler::tls_data_cb, this, ph::_1, ph::_2),
      std::bind(&TCPSocketHandler::tls_alert_cb, this, ph::_1, ph::_2, ph::_3),
      std::bind(&TCPSocketHandler::tls_handshake_cb, this, ph::_1),
      session_manager, this->credential_manager, policy,
      rng, server_info, Botan::TLS::Protocol_Version::latest_tls_version());
}

// Receive up to 4096 bytes from the socket and feed them to the TLS
// client. on_tls_activated() is called if this data made the TLS session
// become active.
void TCPSocketHandler::tls_recv()
{
  static constexpr size_t buf_size = 4096;

  Botan::byte recv_buf[buf_size];

  const ssize_t size = this->do_recv(recv_buf, buf_size);
  if (size > 0)
    {
      const bool was_active = this->tls->is_active();
      try {
        this->tls->received_data(recv_buf, static_cast<size_t>(size));
      } catch (const Botan::TLS::TLS_Exception& e) {
        // May happen if the server sends malformed TLS data (buggy server,
        // or more probably we are just connected to a server that sends
        // plain-text)
        this->on_connection_close("TLS error: "s + e.what());
        this->close();
        return ;
      }
      if (!was_active && this->tls->is_active())
        this->on_tls_activated();
    }
}

// Send data through the TLS session if it is active. Otherwise the data
// is stored in pre_buf, which is flushed by the first call made once the
// session is active.
void TCPSocketHandler::tls_send(std::string&& data)
{
  // We may not be connected yet, or the tls session has
  // not yet been negotiated
  if (this->tls && this->tls->is_active())
    {
      if (!this->pre_buf.empty())
        {
          this->tls->send(this->pre_buf.data(), this->pre_buf.size());
          this->pre_buf.clear();
        }
      if (!data.empty())
        this->tls->send(reinterpret_cast<const Botan::byte*>(data.data()),
                        data.size());
    }
  else
    this->pre_buf.insert(this->pre_buf.end(), data.begin(), data.end());
}

// Callback given to the TLS client in start_tls(): appends the received
// bytes to in_buf and lets parse_in_buffer() handle them.
void TCPSocketHandler::tls_data_cb(const Botan::byte* data, size_t size)
{
  this->in_buf.append(reinterpret_cast<const char*>(data), size);
  if (!this->in_buf.empty())
    this->parse_in_buffer(size);
}

// Callback given to the TLS client in start_tls(): queues the bytes it
// produces to be sent on the socket.
void TCPSocketHandler::tls_output_fn(const Botan::byte* data, size_t size)
{
  this->raw_send(std::string(reinterpret_cast<const char*>(data), size));
}

// Callback given to the TLS client in start_tls(): alerts are only logged.
void TCPSocketHandler::tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t)
{
  log_debug("tls_alert: ", alert.type_string());
}

// Callback given to the TLS client in start_tls(): logs the details of the
// negotiated session. Always returns true.
bool TCPSocketHandler::tls_handshake_cb(const Botan::TLS::Session& session)
{
  log_debug("Handshake with ", session.server_info().hostname(), " complete.",
            " Version: ", session.version().to_string(),
            " using ", session.ciphersuite().to_string());
  if (!session.session_id().empty())
    log_debug("Session ID ", Botan::hex_encode(session.session_id()));
  if (!session.session_ticket().empty())
    log_debug("Session ticket ", Botan::hex_encode(session.session_ticket()));
  return true;
}

// Sending empty data makes tls_send() flush the content of pre_buf.
void TCPSocketHandler::on_tls_activated()
{
  this->send_data({});
}
#endif // BOTAN_FOUND