diff options
Diffstat (limited to 'src/network/socket_handler.cpp')
-rw-r--r-- | src/network/socket_handler.cpp | 172 |
1 files changed, 154 insertions, 18 deletions
diff --git a/src/network/socket_handler.cpp b/src/network/socket_handler.cpp index d509513..d989623 100644 --- a/src/network/socket_handler.cpp +++ b/src/network/socket_handler.cpp @@ -10,26 +10,39 @@ #include <unistd.h> #include <stdlib.h> #include <errno.h> +#include <netdb.h> #include <cstring> #include <fcntl.h> -#include <netdb.h> #include <stdio.h> #include <iostream> -using namespace std::string_literals; +#ifdef BOTAN_FOUND +# include <botan/hex.h> +#endif #ifndef UIO_FASTIOV # define UIO_FASTIOV 8 #endif +using namespace std::string_literals; + +namespace ph = std::placeholders; + SocketHandler::SocketHandler(std::shared_ptr<Poller> poller): socket(-1), poller(poller), + use_tls(false), connected(false), connecting(false) -{ -} +#ifdef BOTAN_FOUND + , + rng(), + credential_manager(), + policy(), + session_manager(rng) +#endif +{} void SocketHandler::init_socket(const struct addrinfo* rp) { @@ -47,10 +60,11 @@ void SocketHandler::init_socket(const struct addrinfo* rp) throw std::runtime_error("Could not initialize socket: "s + strerror(errno)); } -void SocketHandler::connect(const std::string& address, const std::string& port) +void SocketHandler::connect(const std::string& address, const std::string& port, const bool tls) { this->address = address; this->port = port; + this->use_tls = tls; utils::ScopeGuard sg; @@ -106,6 +120,10 @@ void SocketHandler::connect(const std::string& address, const std::string& port) this->poller->add_socket_handler(this); this->connected = true; this->connecting = false; +#ifdef BOTAN_FOUND + if (this->use_tls) + this->start_tls(); +#endif this->on_connected(); return ; } @@ -133,11 +151,21 @@ void SocketHandler::connect(const std::string& address, const std::string& port) void SocketHandler::connect() { - this->connect(this->address, this->port); + this->connect(this->address, this->port, this->use_tls); } void SocketHandler::on_recv() { +#ifdef BOTAN_FOUND + if (this->use_tls) + this->tls_recv(); + else +#endif + this->plain_recv(); +} + +void SocketHandler::plain_recv() +{ static constexpr size_t buf_size = 4096; char buf[buf_size]; void* recv_buf = this->get_receive_buffer(buf_size); @@ -145,6 +173,23 @@ void SocketHandler::on_recv() if (recv_buf == nullptr) recv_buf = buf; + const ssize_t size = this->do_recv(recv_buf, buf_size); + + if (size > 0) + { + if (buf == recv_buf) + { + // data needs to be placed in the in_buf string, because no buffer + // was provided to receive that data directly. The in_buf buffer + // will be handled in parse_in_buffer() + this->in_buf += std::string(buf, size); + } + this->parse_in_buffer(size); + } +} + +ssize_t SocketHandler::do_recv(void* recv_buf, const size_t buf_size) +{ ssize_t size = ::recv(this->socket, recv_buf, buf_size, 0); if (0 == size) { @@ -155,22 +200,17 @@ void SocketHandler::on_recv() { log_warning("Error while reading from socket: " << strerror(errno)); if (this->connecting) - this->on_connection_failed(strerror(errno)); + { + this->close(); + this->on_connection_failed(strerror(errno)); + } else - this->on_connection_close(); - this->close(); - } - else - { - if (buf == recv_buf) { - // data needs to be placed in the in_buf string, because no buffer - // was provided to receive that data directly. The in_buf buffer - // will be handled in parse_in_buffer() - this->in_buf += std::string(buf, size); + this->close(); + this->on_connection_close(); } - this->parse_in_buffer(size); } + return size; } void SocketHandler::on_send() @@ -242,6 +282,16 @@ socket_t SocketHandler::get_socket() const void SocketHandler::send_data(std::string&& data) { +#ifdef BOTAN_FOUND + if (this->use_tls) + this->tls_send(std::move(data)); + else +#endif + this->raw_send(std::move(data)); +} + +void SocketHandler::raw_send(std::string&& data) +{ if (data.empty()) return ; this->out_buf.emplace_back(std::move(data)); @@ -269,3 +319,89 @@ void* SocketHandler::get_receive_buffer(const size_t) const { return nullptr; } + +#ifdef BOTAN_FOUND +void SocketHandler::start_tls() +{ + Botan::TLS::Server_Information server_info(this->address, "irc", std::stoul(this->port)); + this->tls = std::make_unique<Botan::TLS::Client>( + std::bind(&SocketHandler::tls_output_fn, this, ph::_1, ph::_2), + std::bind(&SocketHandler::tls_data_cb, this, ph::_1, ph::_2), + std::bind(&SocketHandler::tls_alert_cb, this, ph::_1, ph::_2, ph::_3), + std::bind(&SocketHandler::tls_handshake_cb, this, ph::_1), + session_manager, credential_manager, policy, + rng, server_info, Botan::TLS::Protocol_Version::latest_tls_version()); +} + +void SocketHandler::tls_recv() +{ + static constexpr size_t buf_size = 4096; + char recv_buf[buf_size]; + + const ssize_t size = this->do_recv(recv_buf, buf_size); + if (size > 0) + { + const bool was_active = this->tls->is_active(); + this->tls->received_data(reinterpret_cast<const Botan::byte*>(recv_buf), + static_cast<size_t>(size)); + if (!was_active && this->tls->is_active()) + this->on_tls_activated(); + } +} + +void SocketHandler::tls_send(std::string&& data) +{ + if (this->tls->is_active()) + { + const bool was_active = this->tls->is_active(); + if (!this->pre_buf.empty()) + { + this->tls->send(reinterpret_cast<const Botan::byte*>(this->pre_buf.data()), + this->pre_buf.size()); + this->pre_buf = ""; + } + if (!data.empty()) + this->tls->send(reinterpret_cast<const Botan::byte*>(data.data()), + data.size()); + if (!was_active && this->tls->is_active()) + this->on_tls_activated(); + } + else + this->pre_buf += data; +} + +void SocketHandler::tls_data_cb(const Botan::byte* data, size_t size) +{ + this->in_buf += std::string(reinterpret_cast<const char*>(data), + size); + if (!this->in_buf.empty()) + this->parse_in_buffer(size); +} + +void SocketHandler::tls_output_fn(const Botan::byte* data, size_t size) +{ + this->raw_send(std::string(reinterpret_cast<const char*>(data), size)); +} + +void SocketHandler::tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t) +{ + log_debug("tls_alert: " << alert.type_string()); +} + +bool SocketHandler::tls_handshake_cb(const Botan::TLS::Session& session) +{ + log_debug("Handshake with " << session.server_info().hostname() << " complete." + << " Version: " << session.version().to_string() + << " using " << session.ciphersuite().to_string()); + if (!session.session_id().empty()) + log_debug("Session ID " << Botan::hex_encode(session.session_id())); + if (!session.session_ticket().empty()) + log_debug("Session ticket " << Botan::hex_encode(session.session_ticket())); + return true; +} + +void SocketHandler::on_tls_activated() +{ + this->send_data(""); +} +#endif // BOTAN_FOUND |