summaryrefslogtreecommitdiff
path: root/louloulibs
diff options
context:
space:
mode:
Diffstat (limited to 'louloulibs')
-rw-r--r--louloulibs/CMakeLists.txt33
-rw-r--r--louloulibs/cmake/Modules/FindBOTAN.cmake4
-rw-r--r--louloulibs/cmake/Modules/FindCARES.cmake37
-rw-r--r--louloulibs/cmake/Modules/FindUDNS.cmake37
-rw-r--r--louloulibs/config/config.cpp7
-rw-r--r--louloulibs/config/config.hpp1
-rw-r--r--louloulibs/logger/logger.cpp8
-rw-r--r--louloulibs/logger/logger.hpp18
-rw-r--r--louloulibs/louloulibs.h.cmake2
-rw-r--r--louloulibs/network/credentials_manager.cpp35
-rw-r--r--louloulibs/network/credentials_manager.hpp15
-rw-r--r--louloulibs/network/dns_handler.cpp126
-rw-r--r--louloulibs/network/dns_handler.hpp45
-rw-r--r--louloulibs/network/dns_socket_handler.cpp34
-rw-r--r--louloulibs/network/dns_socket_handler.hpp34
-rw-r--r--louloulibs/network/resolver.cpp282
-rw-r--r--louloulibs/network/resolver.hpp47
-rw-r--r--louloulibs/network/socket_handler.hpp6
-rw-r--r--louloulibs/network/tcp_client_socket_handler.cpp257
-rw-r--r--louloulibs/network/tcp_client_socket_handler.hpp82
-rw-r--r--louloulibs/network/tcp_server_socket.hpp70
-rw-r--r--louloulibs/network/tcp_socket_handler.cpp301
-rw-r--r--louloulibs/network/tcp_socket_handler.hpp152
-rw-r--r--louloulibs/utils/encoding.cpp1
-rw-r--r--louloulibs/utils/time.cpp2
-rw-r--r--louloulibs/xmpp/adhoc_command.cpp35
-rw-r--r--louloulibs/xmpp/adhoc_commands_handler.cpp35
-rw-r--r--louloulibs/xmpp/xmpp_component.cpp536
-rw-r--r--louloulibs/xmpp/xmpp_component.hpp8
-rw-r--r--louloulibs/xmpp/xmpp_stanza.hpp14
30 files changed, 1222 insertions, 1042 deletions
diff --git a/louloulibs/CMakeLists.txt b/louloulibs/CMakeLists.txt
index 908c35f..f672833 100644
--- a/louloulibs/CMakeLists.txt
+++ b/louloulibs/CMakeLists.txt
@@ -6,10 +6,6 @@ set(${PROJECT_NAME}_VERSION_SUFFIX "~dev")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++1y -pedantic -Wall -Wextra")
-# Define a __FILENAME__ macro to get the filename of each file, instead of
-# the full path as in __FILE__
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILENAME__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"'")
-
#
## Look for external libraries
#
@@ -37,10 +33,10 @@ elseif(NOT WITHOUT_BOTAN)
find_package(BOTAN)
endif()
-if(WITH_CARES)
- find_package(CARES REQUIRED)
-elseif(NOT WITHOUT_CARES)
- find_package(CARES)
+if(WITH_UDNS)
+ find_package(UDNS REQUIRED)
+elseif(NOT WITHOUT_UDNS)
+ find_package(UDNS)
endif()
# To be able to include the config.h file generated by cmake
@@ -72,10 +68,10 @@ if(BOTAN_FOUND)
set(BOTAN_INCLUDE_DIRS ${BOTAN_INCLUDE_DIRS} PARENT_SCOPE)
endif()
-if(CARES_FOUND)
- include_directories(${CARES_INCLUDE_DIRS})
- set(CARES_FOUND ${CARES_FOUND} PARENT_SCOPE)
- set(CARES_INCLUDE_DIRS ${CARES_INCLUDE_DIRS} PARENT_SCOPE)
+if(UDNS_FOUND)
+ include_directories(${UDNS_INCLUDE_DIRS})
+ set(UDNS_FOUND ${UDNS_FOUND} PARENT_SCOPE)
+ set(UDNS_INCLUDE_DIRS ${UDNS_INCLUDE_DIRS} PARENT_SCOPE)
endif()
set(POLLER_DOCSTRING "Choose the poller between POLL and EPOLL (Linux-only)")
@@ -103,7 +99,6 @@ target_link_libraries(utils ${ICONV_LIBRARIES})
file(GLOB source_config
config/*.[hc]pp)
add_library(config STATIC ${source_config})
-target_link_libraries(config utils)
#
## logger
@@ -123,8 +118,8 @@ target_link_libraries(network logger)
if(BOTAN_FOUND)
target_link_libraries(network ${BOTAN_LIBRARIES})
endif()
-if(CARES_FOUND)
- target_link_libraries(network ${CARES_LIBRARIES})
+if(UDNS_FOUND)
+ target_link_libraries(network ${UDNS_LIBRARIES})
endif()
#
@@ -143,6 +138,14 @@ if(SYSTEMD_FOUND)
target_link_libraries(xmpplib ${SYSTEMD_LIBRARIES})
endif()
+# Define a __FILENAME__ macro with the relative path (from the base project directory)
+# of each source file
+file(GLOB_RECURSE source_all *.[hc]pp)
+foreach(file ${source_all})
+ file(RELATIVE_PATH shorter_file ${CMAKE_CURRENT_SOURCE_DIR} ${file})
+ set_property(SOURCE ${file} APPEND PROPERTY COMPILE_DEFINITIONS __FILENAME__="${shorter_file}")
+endforeach()
+
#
## Check if we have std::get_time
#
diff --git a/louloulibs/cmake/Modules/FindBOTAN.cmake b/louloulibs/cmake/Modules/FindBOTAN.cmake
index a12bd35..26069f4 100644
--- a/louloulibs/cmake/Modules/FindBOTAN.cmake
+++ b/louloulibs/cmake/Modules/FindBOTAN.cmake
@@ -16,10 +16,10 @@
# This file is in the public domain
find_path(BOTAN_INCLUDE_DIRS NAMES botan/botan.h
- PATH_SUFFIXES botan-1.11
+ PATH_SUFFIXES botan-2 botan-1.11
DOC "The botan include directory")
-find_library(BOTAN_LIBRARIES NAMES botan botan-1.11
+find_library(BOTAN_LIBRARIES NAMES botan botan-2 botan-1.11
DOC "The botan library")
# Use some standard module to handle the QUIETLY and REQUIRED arguments, and
diff --git a/louloulibs/cmake/Modules/FindCARES.cmake b/louloulibs/cmake/Modules/FindCARES.cmake
deleted file mode 100644
index c4c757a..0000000
--- a/louloulibs/cmake/Modules/FindCARES.cmake
+++ /dev/null
@@ -1,37 +0,0 @@
-# - Find c-ares
-# Find the c-ares library, and more particularly the stringprep header.
-#
-# This module defines the following variables:
-# CARES_FOUND - True if library and include directory are found
-# If set to TRUE, the following are also defined:
-# CARES_INCLUDE_DIRS - The directory where to find the header file
-# CARES_LIBRARIES - Where to find the library file
-#
-# For conveniance, these variables are also set. They have the same values
-# than the variables above. The user can thus choose his/her prefered way
-# to write them.
-# CARES_INCLUDE_DIR
-# CARES_LIBRARY
-#
-# This file is in the public domain
-
-if(NOT CARES_FOUND)
- find_path(CARES_INCLUDE_DIRS NAMES ares.h
- DOC "The c-ares include directory")
-
- find_library(CARES_LIBRARIES NAMES cares
- DOC "The c-ares library")
-
- # Use some standard module to handle the QUIETLY and REQUIRED arguments, and
- # set CARES_FOUND to TRUE if these two variables are set.
- include(FindPackageHandleStandardArgs)
- find_package_handle_standard_args(CARES REQUIRED_VARS CARES_LIBRARIES CARES_INCLUDE_DIRS)
-
- # Compatibility for all the ways of writing these variables
- if(CARES_FOUND)
- set(CARES_INCLUDE_DIR ${CARES_INCLUDE_DIRS})
- set(CARES_LIBRARY ${CARES_LIBRARIES})
- endif()
-endif()
-
-mark_as_advanced(CARES_INCLUDE_DIRS CARES_LIBRARIES)
diff --git a/louloulibs/cmake/Modules/FindUDNS.cmake b/louloulibs/cmake/Modules/FindUDNS.cmake
new file mode 100644
index 0000000..1d32cd3
--- /dev/null
+++ b/louloulibs/cmake/Modules/FindUDNS.cmake
@@ -0,0 +1,37 @@
+# - Find udns
+# Find the udns library
+#
+# This module defines the following variables:
+# UDNS_FOUND - True if library and include directory are found
+# If set to TRUE, the following are also defined:
+# UDNS_INCLUDE_DIRS - The directory where to find the header file
+# UDNS_LIBRARIES - Where to find the library file
+#
+# For conveniance, these variables are also set. They have the same values
+# as the variables above. The user can thus choose his/her prefered way
+# to write them.
+# UDNS_INCLUDE_DIR
+# UDNS_LIBRARY
+#
+# This file is in the public domain
+
+if(NOT UDNS_FOUND)
+ find_path(UDNS_INCLUDE_DIRS NAMES udns.h
+ DOC "The udns include directory")
+
+ find_library(UDNS_LIBRARIES NAMES udns
+ DOC "The udns library")
+
+ # Use some standard module to handle the QUIETLY and REQUIRED arguments, and
+ # set UDNS_FOUND to TRUE if these two variables are set.
+ include(FindPackageHandleStandardArgs)
+ find_package_handle_standard_args(UDNS REQUIRED_VARS UDNS_LIBRARIES UDNS_INCLUDE_DIRS)
+
+ # Compatibility for all the ways of writing these variables
+ if(UDNS_FOUND)
+ set(UDNS_INCLUDE_DIR ${UDNS_INCLUDE_DIRS})
+ set(UDNS_LIBRARY ${UDNS_LIBRARIES})
+ endif()
+endif()
+
+mark_as_advanced(UDNS_INCLUDE_DIRS UDNS_LIBRARIES)
diff --git a/louloulibs/config/config.cpp b/louloulibs/config/config.cpp
index 417981d..24a1c87 100644
--- a/louloulibs/config/config.cpp
+++ b/louloulibs/config/config.cpp
@@ -1,8 +1,7 @@
#include <config/config.hpp>
-#include <logger/logger.hpp>
+#include <iostream>
#include <cstring>
-#include <sstream>
#include <cstdlib>
@@ -66,7 +65,7 @@ bool Config::read_conf(const std::string& name)
std::ifstream file(Config::filename.data());
if (!file.is_open())
{
- log_error("Error while opening file ", filename, " for reading: ", strerror(errno));
+ std::cerr << "Error while opening file " << filename << " for reading: " << strerror(errno) << std::endl;
return false;
}
@@ -96,7 +95,7 @@ void Config::save_to_file()
std::ofstream file(Config::filename.data());
if (file.fail())
{
- log_error("Could not save config file.");
+ std::cerr << "Could not save config file." << std::endl;
return ;
}
for (const auto& it: Config::values)
diff --git a/louloulibs/config/config.hpp b/louloulibs/config/config.hpp
index 6728df8..4e01281 100644
--- a/louloulibs/config/config.hpp
+++ b/louloulibs/config/config.hpp
@@ -15,7 +15,6 @@
#pragma once
-
#include <functional>
#include <fstream>
#include <memory>
diff --git a/louloulibs/logger/logger.cpp b/louloulibs/logger/logger.cpp
index 7336579..92a3d9b 100644
--- a/louloulibs/logger/logger.cpp
+++ b/louloulibs/logger/logger.cpp
@@ -3,14 +3,18 @@
Logger::Logger(const int log_level):
log_level(log_level),
- stream(std::cout.rdbuf())
+ stream(std::cout.rdbuf()),
+ null_buffer{},
+ null_stream{&null_buffer}
{
}
Logger::Logger(const int log_level, const std::string& log_file):
log_level(log_level),
ofstream(log_file.data(), std::ios_base::app),
- stream(ofstream.rdbuf())
+ stream(ofstream.rdbuf()),
+ null_buffer{},
+ null_stream{&null_buffer}
{
}
diff --git a/louloulibs/logger/logger.hpp b/louloulibs/logger/logger.hpp
index 0893c77..b3284a6 100644
--- a/louloulibs/logger/logger.hpp
+++ b/louloulibs/logger/logger.hpp
@@ -33,15 +33,15 @@
# define __FILENAME__ __FILE__
#endif
+
/**
- * Juste a structure representing a stream doing nothing with its input.
+ * A buffer, used to construct an ostream that does nothing
+ * when we output data in it
*/
-class nullstream: public std::ostream
+class NullBuffer: public std::streambuf
{
-public:
- nullstream():
- std::ostream(0)
- { }
+ public:
+ int overflow(int c) { return c; }
};
class Logger
@@ -59,9 +59,11 @@ public:
private:
const int log_level;
- std::ofstream ofstream;
- nullstream null_stream;
+ std::ofstream ofstream{};
std::ostream stream;
+
+ NullBuffer null_buffer;
+ std::ostream null_stream;
};
#define WHERE __FILENAME__, ":", __LINE__, ":\t"
diff --git a/louloulibs/louloulibs.h.cmake b/louloulibs/louloulibs.h.cmake
index 6131b70..ebb9b9a 100644
--- a/louloulibs/louloulibs.h.cmake
+++ b/louloulibs/louloulibs.h.cmake
@@ -4,7 +4,7 @@
#cmakedefine SYSTEMD_FOUND
#cmakedefine POLLER ${POLLER}
#cmakedefine BOTAN_FOUND
-#cmakedefine CARES_FOUND
+#cmakedefine UDNS_FOUND
#cmakedefine SOFTWARE_VERSION "${SOFTWARE_VERSION}"
#cmakedefine PROJECT_NAME "${PROJECT_NAME}"
#cmakedefine HAS_GET_TIME
diff --git a/louloulibs/network/credentials_manager.cpp b/louloulibs/network/credentials_manager.cpp
index ed04d24..289307b 100644
--- a/louloulibs/network/credentials_manager.cpp
+++ b/louloulibs/network/credentials_manager.cpp
@@ -37,6 +37,28 @@ void BasicCredentialsManager::set_trusted_fingerprint(const std::string& fingerp
this->trusted_fingerprint = fingerprint;
}
+const std::string& BasicCredentialsManager::get_trusted_fingerprint() const
+{
+ return this->trusted_fingerprint;
+}
+
+void check_tls_certificate(const std::vector<Botan::X509_Certificate>& certs,
+ const std::string& hostname, const std::string& trusted_fingerprint,
+ std::exception_ptr exc)
+{
+
+ if (!trusted_fingerprint.empty() && !certs.empty() &&
+ trusted_fingerprint == certs[0].fingerprint() &&
+ certs[0].matches_dns_name(hostname))
+ // We trust the certificate, based on the trusted fingerprint and
+ // the fact that the hostname matches
+ return;
+
+ if (exc)
+ std::rethrow_exception(exc);
+}
+
+#if BOTAN_VERSION_CODE < BOTAN_VERSION_CODE_FOR(1,11,34)
void BasicCredentialsManager::verify_certificate_chain(const std::string& type,
const std::string& purported_hostname,
const std::vector<Botan::X509_Certificate>& certs)
@@ -50,17 +72,14 @@ void BasicCredentialsManager::verify_certificate_chain(const std::string& type,
catch (const std::exception& tls_exception)
{
log_warning("TLS certificate check failed: ", tls_exception.what());
- if (!this->trusted_fingerprint.empty() && !certs.empty() &&
- this->trusted_fingerprint == certs[0].fingerprint() &&
- certs[0].matches_dns_name(purported_hostname))
- // We trust the certificate, based on the trusted fingerprint and
- // the fact that the hostname matches
- return;
-
+ std::exception_ptr exception_ptr{};
if (this->socket_handler->abort_on_invalid_cert())
- throw;
+ exception_ptr = std::current_exception();
+
+ check_tls_certificate(certs, purported_hostname, this->trusted_fingerprint, exception_ptr);
}
}
+#endif
bool BasicCredentialsManager::try_to_open_one_ca_bundle(const std::vector<std::string>& paths)
{
diff --git a/louloulibs/network/credentials_manager.hpp b/louloulibs/network/credentials_manager.hpp
index 7557372..29ee024 100644
--- a/louloulibs/network/credentials_manager.hpp
+++ b/louloulibs/network/credentials_manager.hpp
@@ -9,6 +9,18 @@
class TCPSocketHandler;
+/**
+ * If the given cert isn’t valid, based on the given hostname
+ * and fingerprint, then throws the exception if it’s non-empty.
+ *
+ * Must be called after the standard (from Botan) way of
+ * checking the certificate, if we want to also accept certificates based
+ * on a trusted fingerprint.
+ */
+void check_tls_certificate(const std::vector<Botan::X509_Certificate>& certs,
+ const std::string& hostname, const std::string& trusted_fingerprint,
+ std::exception_ptr exc);
+
class BasicCredentialsManager: public Botan::Credentials_Manager
{
public:
@@ -19,12 +31,15 @@ public:
BasicCredentialsManager& operator=(const BasicCredentialsManager&) = delete;
BasicCredentialsManager& operator=(BasicCredentialsManager&&) = delete;
+#if BOTAN_VERSION_CODE < BOTAN_VERSION_CODE_FOR(1,11,34)
void verify_certificate_chain(const std::string& type,
const std::string& purported_hostname,
const std::vector<Botan::X509_Certificate>&) override final;
+#endif
std::vector<Botan::Certificate_Store*> trusted_certificate_authorities(const std::string& type,
const std::string& context) override final;
void set_trusted_fingerprint(const std::string& fingerprint);
+ const std::string& get_trusted_fingerprint() const;
private:
const TCPSocketHandler* const socket_handler;
diff --git a/louloulibs/network/dns_handler.cpp b/louloulibs/network/dns_handler.cpp
index fef0cfc..fbd2763 100644
--- a/louloulibs/network/dns_handler.cpp
+++ b/louloulibs/network/dns_handler.cpp
@@ -1,5 +1,5 @@
#include <louloulibs.h>
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
#include <network/dns_socket_handler.hpp>
#include <network/dns_handler.hpp>
@@ -7,124 +7,40 @@
#include <utils/timed_events.hpp>
-#include <algorithm>
-#include <stdexcept>
+#include <udns.h>
+#include <cerrno>
+#include <cstring>
-DNSHandler DNSHandler::instance;
+class Resolver;
using namespace std::string_literals;
-DNSHandler::DNSHandler():
- socket_handlers{},
- channel{nullptr}
-{
- int ares_error;
- if ((ares_error = ::ares_library_init(ARES_LIB_INIT_ALL)) != 0)
- throw std::runtime_error("Failed to initialize c-ares lib: "s + ares_strerror(ares_error));
- struct ares_options options = {};
- // The default timeout values are way too high
- options.timeout = 1000;
- options.tries = 3;
- if ((ares_error = ::ares_init_options(&this->channel,
- &options,
- ARES_OPT_TIMEOUTMS|ARES_OPT_TRIES)) != ARES_SUCCESS)
- throw std::runtime_error("Failed to initialize c-ares channel: "s + ares_strerror(ares_error));
-}
-ares_channel& DNSHandler::get_channel()
-{
- return this->channel;
-}
+std::unique_ptr<DNSSocketHandler> DNSHandler::socket_handler{};
-void DNSHandler::destroy()
+DNSHandler::DNSHandler(std::shared_ptr<Poller> poller)
{
- this->remove_all_sockets_from_poller();
- this->socket_handlers.clear();
- ::ares_destroy(this->channel);
- ::ares_library_cleanup();
+ dns_init(nullptr, 0);
+ const auto socket = dns_open(nullptr);
+ if (socket == -1)
+ throw std::runtime_error("Failed to initialize udns socket: "s + strerror(errno));
+
+ DNSHandler::socket_handler = std::make_unique<DNSSocketHandler>(poller, socket);
}
-void DNSHandler::gethostbyname(const std::string& name, ares_host_callback callback,
- void* data, int family)
+void DNSHandler::destroy()
{
- ::ares_gethostbyname(this->channel, name.data(), family,
- callback, data);
+ DNSHandler::socket_handler.reset(nullptr);
+ dns_close(nullptr);
}
-void DNSHandler::watch_dns_sockets(std::shared_ptr<Poller>& poller)
+void DNSHandler::watch()
{
- fd_set readers;
- fd_set writers;
-
- FD_ZERO(&readers);
- FD_ZERO(&writers);
-
- int ndfs = ::ares_fds(this->channel, &readers, &writers);
- // For each existing DNS socket, see if we are still supposed to watch it,
- // if not then erase it
- this->socket_handlers.erase(
- std::remove_if(this->socket_handlers.begin(), this->socket_handlers.end(),
- [&readers](const auto& dns_socket)
- {
- return !FD_ISSET(dns_socket->get_socket(), &readers);
- }),
- this->socket_handlers.end());
-
- for (auto i = 0; i < ndfs; ++i)
- {
- bool read = FD_ISSET(i, &readers);
- bool write = FD_ISSET(i, &writers);
- // Look for the DNSSocketHandler with this fd
- auto it = std::find_if(this->socket_handlers.begin(),
- this->socket_handlers.end(),
- [i](const auto& socket_handler)
- {
- return i == socket_handler->get_socket();
- });
- if (!read && !write) // No need to read or write to it
- { // If found, erase it and stop watching it because it is not
- // needed anymore
- if (it != this->socket_handlers.end())
- // The socket destructor removes it from the poller
- this->socket_handlers.erase(it);
- }
- else // We need to write and/or read to it
- { // If not found, create it because we need to watch it
- if (it == this->socket_handlers.end())
- {
- this->socket_handlers.emplace(this->socket_handlers.begin(),
- std::make_unique<DNSSocketHandler>(poller, *this, i));
- it = this->socket_handlers.begin();
- }
- poller->add_socket_handler(it->get());
- if (write)
- poller->watch_send_events(it->get());
- }
- }
- // Cancel previous timer, if any.
- TimedEventsManager::instance().cancel("DNS timeout");
- struct timeval tv;
- struct timeval* tvp;
- tvp = ::ares_timeout(this->channel, NULL, &tv);
- if (tvp)
- {
- auto future_time = std::chrono::steady_clock::now() + std::chrono::seconds(tvp->tv_sec) + \
- std::chrono::microseconds(tvp->tv_usec);
- TimedEventsManager::instance().add_event(TimedEvent(std::move(future_time),
- [this]()
- {
- for (auto& dns_socket_handler: this->socket_handlers)
- dns_socket_handler->on_recv();
- },
- "DNS timeout"));
- }
+ DNSHandler::socket_handler->watch();
}
-void DNSHandler::remove_all_sockets_from_poller()
+void DNSHandler::unwatch()
{
- for (const auto& socket_handler: this->socket_handlers)
- {
- socket_handler->remove_from_poller();
- }
+ DNSHandler::socket_handler->unwatch();
}
-#endif /* CARES_FOUND */
+#endif /* UDNS_FOUND */
diff --git a/louloulibs/network/dns_handler.hpp b/louloulibs/network/dns_handler.hpp
index fd1729d..78ffe4d 100644
--- a/louloulibs/network/dns_handler.hpp
+++ b/louloulibs/network/dns_handler.hpp
@@ -1,58 +1,37 @@
#pragma once
#include <louloulibs.h>
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
-class TCPSocketHandler;
class Poller;
-class DNSSocketHandler;
-# include <ares.h>
-# include <memory>
-# include <string>
-# include <vector>
+#include <network/dns_socket_handler.hpp>
-/**
- * Class managing DNS resolution. It should only be statically instanciated
- * once in SocketHandler. It manages ares channel and calls various
- * functions of that library.
- */
+#include <string>
+#include <vector>
+#include <memory>
class DNSHandler
{
public:
- DNSHandler();
+ explicit DNSHandler(std::shared_ptr<Poller> poller);
~DNSHandler() = default;
+
DNSHandler(const DNSHandler&) = delete;
DNSHandler(DNSHandler&&) = delete;
DNSHandler& operator=(const DNSHandler&) = delete;
DNSHandler& operator=(DNSHandler&&) = delete;
- void gethostbyname(const std::string& name, ares_host_callback callback,
- void* socket_handler, int family);
- /**
- * Call ares_fds to know what fd needs to be watched by the poller, create
- * or destroy DNSSocketHandlers depending on the result.
- */
- void watch_dns_sockets(std::shared_ptr<Poller>& poller);
- /**
- * Destroy and stop watching all the DNS sockets. Then de-init the channel
- * and library.
- */
void destroy();
- void remove_all_sockets_from_poller();
- ares_channel& get_channel();
- static DNSHandler instance;
+ static void watch();
+ static void unwatch();
private:
/**
- * The list of sockets that needs to be watched, according to the last
- * call to ares_fds. DNSSocketHandlers are added to it or removed from it
- * in the watch_dns_sockets() method
+ * Manager for the socket returned by udns, that we need to watch with the poller
*/
- std::vector<std::unique_ptr<DNSSocketHandler>> socket_handlers;
- ares_channel channel;
+ static std::unique_ptr<DNSSocketHandler> socket_handler;
};
-#endif /* CARES_FOUND */
+#endif /* UDNS_FOUND */
diff --git a/louloulibs/network/dns_socket_handler.cpp b/louloulibs/network/dns_socket_handler.cpp
index 403a5be..ad744a9 100644
--- a/louloulibs/network/dns_socket_handler.cpp
+++ b/louloulibs/network/dns_socket_handler.cpp
@@ -1,38 +1,27 @@
#include <louloulibs.h>
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
#include <network/dns_socket_handler.hpp>
#include <network/dns_handler.hpp>
#include <network/poller.hpp>
-#include <ares.h>
+#include <udns.h>
DNSSocketHandler::DNSSocketHandler(std::shared_ptr<Poller> poller,
- DNSHandler& handler,
const socket_t socket):
- SocketHandler(poller, socket),
- handler(handler)
+ SocketHandler(poller, socket)
{
+ poller->add_socket_handler(this);
}
-void DNSSocketHandler::connect()
+DNSSocketHandler::~DNSSocketHandler()
{
+ this->unwatch();
}
void DNSSocketHandler::on_recv()
{
- // always stop watching send and read events. We will re-watch them if the
- // next call to ares_fds tell us to
- this->handler.remove_all_sockets_from_poller();
- ::ares_process_fd(DNSHandler::instance.get_channel(), this->socket, ARES_SOCKET_BAD);
-}
-
-void DNSSocketHandler::on_send()
-{
- // always stop watching send and read events. We will re-watch them if the
- // next call to ares_fds tell us to
- this->handler.remove_all_sockets_from_poller();
- ::ares_process_fd(DNSHandler::instance.get_channel(), ARES_SOCKET_BAD, this->socket);
+ dns_ioevent(nullptr, 0);
}
bool DNSSocketHandler::is_connected() const
@@ -40,10 +29,15 @@ bool DNSSocketHandler::is_connected() const
return true;
}
-void DNSSocketHandler::remove_from_poller()
+void DNSSocketHandler::unwatch()
{
if (this->poller->is_managing_socket(this->socket))
this->poller->remove_socket_handler(this->socket);
}
-#endif /* CARES_FOUND */
+void DNSSocketHandler::watch()
+{
+ this->poller->add_socket_handler(this);
+}
+
+#endif /* UDNS_FOUND */
diff --git a/louloulibs/network/dns_socket_handler.hpp b/louloulibs/network/dns_socket_handler.hpp
index 0570196..e12f145 100644
--- a/louloulibs/network/dns_socket_handler.hpp
+++ b/louloulibs/network/dns_socket_handler.hpp
@@ -1,49 +1,33 @@
#pragma once
#include <louloulibs.h>
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
#include <network/socket_handler.hpp>
-#include <ares.h>
/**
- * Manage a socket returned by ares_fds. We do not create, open or close the
- * socket ourself: this is done by c-ares. We just call ares_process_fd()
- * with the correct parameters, depending on what can be done on that socket
- * (Poller reported it to be writable or readeable)
+ * Manage the UDP socket provided by udns, we do not create, open or close the
+ * socket ourself: this is done by udns. We only watch it for readability
*/
-
-class DNSHandler;
-
class DNSSocketHandler: public SocketHandler
{
public:
- explicit DNSSocketHandler(std::shared_ptr<Poller> poller, DNSHandler& handler, const socket_t socket);
- ~DNSSocketHandler() = default;
+ explicit DNSSocketHandler(std::shared_ptr<Poller> poller, const socket_t socket);
+ ~DNSSocketHandler();
DNSSocketHandler(const DNSSocketHandler&) = delete;
DNSSocketHandler(DNSSocketHandler&&) = delete;
DNSSocketHandler& operator=(const DNSSocketHandler&) = delete;
DNSSocketHandler& operator=(DNSSocketHandler&&) = delete;
- /**
- * Just call dns_process_fd, c-ares will do its work of send()ing or
- * recv()ing the data it wants on that socket.
- */
void on_recv() override final;
- void on_send() override final;
- /**
- * Do nothing, because we are always considered to be connected, since the
- * connection is done by c-ares and not by us.
- */
- void connect() override final;
+
/**
* Always true, see the comment for connect()
*/
bool is_connected() const override final;
- void remove_from_poller();
-private:
- DNSHandler& handler;
+ void watch();
+ void unwatch();
};
-#endif // CARES_FOUND
+#endif // UDNS_FOUND
diff --git a/louloulibs/network/resolver.cpp b/louloulibs/network/resolver.cpp
index 2987aaa..efb0cf0 100644
--- a/louloulibs/network/resolver.cpp
+++ b/louloulibs/network/resolver.cpp
@@ -1,17 +1,32 @@
#include <network/dns_handler.hpp>
+#include <utils/timed_events.hpp>
#include <network/resolver.hpp>
#include <string.h>
#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <udns.h>
+
+#include <fstream>
#include <cstdlib>
+#include <sstream>
+#include <chrono>
+#include <map>
using namespace std::string_literals;
+static std::map<int, std::string> dns_error_messages {
+ {DNS_E_TEMPFAIL, "Timeout while contacting DNS servers"},
+ {DNS_E_PROTOCOL, "Misformatted DNS reply"},
+ {DNS_E_NXDOMAIN, "Domain name not found"},
+ {DNS_E_NOMEM, "Out of memory"},
+ {DNS_E_BADQUERY, "Misformatted domain name"}
+};
+
Resolver::Resolver():
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
resolved4(false),
resolved6(false),
resolving(false),
- cares_addrinfo(nullptr),
port{},
#endif
resolved(false),
@@ -24,15 +39,44 @@ void Resolver::resolve(const std::string& hostname, const std::string& port,
{
this->error_cb = error_cb;
this->success_cb = success_cb;
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
this->port = port;
#endif
this->start_resolving(hostname, port);
}
-#ifdef CARES_FOUND
-void Resolver::start_resolving(const std::string& hostname, const std::string&)
+int Resolver::call_getaddrinfo(const char *name, const char* port, int flags)
+{
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(struct addrinfo));
+ hints.ai_flags = flags;
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = 0;
+
+ struct addrinfo* addr_res = nullptr;
+ const int res = ::getaddrinfo(name, port,
+ &hints, &addr_res);
+
+ if (res == 0 && addr_res)
+ {
+ if (!this->addr)
+ this->addr.reset(addr_res);
+ else
+ { // Append this result at the end of the linked list
+ struct addrinfo *rp = this->addr.get();
+ while (rp->ai_next)
+ rp = rp->ai_next;
+ rp->ai_next = addr_res;
+ }
+ }
+
+ return res;
+}
+
+#ifdef UDNS_FOUND
+void Resolver::start_resolving(const std::string& hostname, const std::string& port)
{
this->resolving = true;
this->resolved = false;
@@ -40,48 +84,139 @@ void Resolver::start_resolving(const std::string& hostname, const std::string&)
this->resolved6 = false;
this->error_msg.clear();
- this->cares_addrinfo = nullptr;
+ this->addr.reset(nullptr);
- auto hostname4_resolved = [](void* arg, int status, int,
- struct hostent* hostent)
+ // We first try to use it as an IP address directly. We tell getaddrinfo
+ // to NOT use any DNS resolution.
+ if (this->call_getaddrinfo(hostname.data(), port.data(), AI_NUMERICHOST) == 0)
{
- Resolver* resolver = static_cast<Resolver*>(arg);
- resolver->on_hostname4_resolved(status, hostent);
- };
- auto hostname6_resolved = [](void* arg, int status, int,
- struct hostent* hostent)
+ this->on_resolved();
+ return;
+ }
+
+ // Then we look into /etc/hosts to translate the given hostname
+ const auto hosts = this->look_in_etc_hosts(hostname);
+ if (!hosts.empty())
+ {
+ for (const auto &host: hosts)
+ this->call_getaddrinfo(host.data(), port.data(), AI_NUMERICHOST);
+ this->on_resolved();
+ return;
+ }
+
+ // And finally, we try a DNS resolution
+ auto hostname6_resolved = [](dns_ctx*, dns_rr_a6* result, void* data)
+ {
+ Resolver* resolver = static_cast<Resolver*>(data);
+ resolver->on_hostname6_resolved(result);
+ };
+
+ auto hostname4_resolved = [](dns_ctx*, dns_rr_a4* result, void* data)
+ {
+ Resolver* resolver = static_cast<Resolver*>(data);
+ resolver->on_hostname4_resolved(result);
+ };
+
+ DNSHandler::watch();
+ auto res = dns_submit_a4(nullptr, hostname.data(), 0, hostname4_resolved, this);
+ if (!res)
+ this->on_hostname4_resolved(nullptr);
+ res = dns_submit_a6(nullptr, hostname.data(), 0, hostname6_resolved, this);
+ if (!res)
+ this->on_hostname6_resolved(nullptr);
+
+ this->start_timer();
+}
+
+void Resolver::start_timer()
+{
+ const auto timeout = dns_timeouts(nullptr, -1, 0);
+ if (timeout < 0)
+ return;
+ TimedEvent event(std::chrono::steady_clock::now() + std::chrono::seconds(timeout), [this]() { this->start_timer(); }, "DNS");
+ TimedEventsManager::instance().add_event(std::move(event));
+}
+
+std::vector<std::string> Resolver::look_in_etc_hosts(const std::string &hostname)
+{
+ std::ifstream hosts("/etc/hosts");
+ std::string line;
+
+ std::vector<std::string> results;
+ while (std::getline(hosts, line))
{
- Resolver* resolver = static_cast<Resolver*>(arg);
- resolver->on_hostname6_resolved(status, hostent);
- };
-
- DNSHandler::instance.gethostbyname(hostname, hostname6_resolved,
- this, AF_INET6);
- DNSHandler::instance.gethostbyname(hostname, hostname4_resolved,
- this, AF_INET);
+ if (line.empty())
+ continue;
+
+ std::string ip;
+ std::istringstream line_stream(line);
+ line_stream >> ip;
+ if (ip.empty() || ip[0] == '#')
+ continue;
+
+ std::string host;
+ while (line_stream >> host && !host.empty() && host[0] != '#')
+ {
+ if (hostname == host)
+ {
+ results.push_back(ip);
+ break;
+ }
+ }
+ }
+ return results;
}
-void Resolver::on_hostname4_resolved(int status, struct hostent* hostent)
+void Resolver::on_hostname4_resolved(dns_rr_a4 *result)
{
+ if (dns_active(nullptr) == 0)
+ DNSHandler::unwatch();
+
this->resolved4 = true;
- if (status == ARES_SUCCESS)
- this->fill_ares_addrinfo4(hostent);
+
+ const auto status = dns_status(nullptr);
+
+ if (status >= 0 && result)
+ {
+ char buf[INET6_ADDRSTRLEN];
+
+ for (auto i = 0; i < result->dnsa4_nrr; ++i)
+ {
+ inet_ntop(AF_INET, &result->dnsa4_addr[i], buf, sizeof(buf));
+ this->call_getaddrinfo(buf, this->port.data(), AI_NUMERICHOST);
+ }
+ }
else
- this->error_msg = ::ares_strerror(status);
+ {
+ const auto error = dns_error_messages.find(status);
+ if (error != end(dns_error_messages))
+ this->error_msg = error->second;
+ }
- if (this->resolved4 && this->resolved6)
+ if (this->resolved6 && this->resolved4)
this->on_resolved();
}
-void Resolver::on_hostname6_resolved(int status, struct hostent* hostent)
+void Resolver::on_hostname6_resolved(dns_rr_a6 *result)
{
+ if (dns_active(nullptr) == 0)
+ DNSHandler::unwatch();
+
this->resolved6 = true;
- if (status == ARES_SUCCESS)
- this->fill_ares_addrinfo6(hostent);
- else
- this->error_msg = ::ares_strerror(status);
+ char buf[INET6_ADDRSTRLEN];
- if (this->resolved4 && this->resolved6)
+ const auto status = dns_status(nullptr);
+
+ if (status >= 0 && result)
+ {
+ for (auto i = 0; i < result->dnsa6_nrr; ++i)
+ {
+ inet_ntop(AF_INET6, &result->dnsa6_addr[i], buf, sizeof(buf));
+ this->call_getaddrinfo(buf, this->port.data(), AI_NUMERICHOST);
+ }
+ }
+
+ if (this->resolved6 && this->resolved4)
this->on_resolved();
}
@@ -89,100 +224,26 @@ void Resolver::on_resolved()
{
this->resolved = true;
this->resolving = false;
- if (!this->cares_addrinfo)
+ if (!this->addr)
{
if (this->error_cb)
this->error_cb(this->error_msg.data());
}
else
{
- this->addr.reset(this->cares_addrinfo);
if (this->success_cb)
this->success_cb(this->addr.get());
}
}
-void Resolver::fill_ares_addrinfo4(const struct hostent* hostent)
-{
- struct addrinfo* prev = this->cares_addrinfo;
- struct in_addr** address = reinterpret_cast<struct in_addr**>(hostent->h_addr_list);
-
- while (*address)
- {
- // Create a new addrinfo list element, and fill it
- struct addrinfo* current = new struct addrinfo;
- current->ai_flags = 0;
- current->ai_family = hostent->h_addrtype;
- current->ai_socktype = SOCK_STREAM;
- current->ai_protocol = 0;
- current->ai_addrlen = sizeof(struct sockaddr_in);
-
- struct sockaddr_in* ai_addr = new struct sockaddr_in;
-
- ai_addr->sin_family = hostent->h_addrtype;
- ai_addr->sin_port = htons(std::strtoul(this->port.data(), nullptr, 10));
- ai_addr->sin_addr.s_addr = (*address)->s_addr;
-
- current->ai_addr = reinterpret_cast<struct sockaddr*>(ai_addr);
- current->ai_next = nullptr;
- current->ai_canonname = nullptr;
-
- current->ai_next = prev;
- this->cares_addrinfo = current;
- prev = current;
- ++address;
- }
-}
-
-void Resolver::fill_ares_addrinfo6(const struct hostent* hostent)
-{
- struct addrinfo* prev = this->cares_addrinfo;
- struct in6_addr** address = reinterpret_cast<struct in6_addr**>(hostent->h_addr_list);
-
- while (*address)
- {
- // Create a new addrinfo list element, and fill it
- struct addrinfo* current = new struct addrinfo;
- current->ai_flags = 0;
- current->ai_family = hostent->h_addrtype;
- current->ai_socktype = SOCK_STREAM;
- current->ai_protocol = 0;
- current->ai_addrlen = sizeof(struct sockaddr_in6);
-
- struct sockaddr_in6* ai_addr = new struct sockaddr_in6;
- ai_addr->sin6_family = hostent->h_addrtype;
- ai_addr->sin6_port = htons(std::strtoul(this->port.data(), nullptr, 10));
- ::memcpy(ai_addr->sin6_addr.s6_addr, (*address)->s6_addr, sizeof(ai_addr->sin6_addr.s6_addr));
- ai_addr->sin6_flowinfo = 0;
- ai_addr->sin6_scope_id = 0;
-
- current->ai_addr = reinterpret_cast<struct sockaddr*>(ai_addr);
- current->ai_canonname = nullptr;
-
- current->ai_next = prev;
- this->cares_addrinfo = current;
- prev = current;
- ++address;
- }
-}
-
-#else // ifdef CARES_FOUND
+#else // ifdef UDNS_FOUND
void Resolver::start_resolving(const std::string& hostname, const std::string& port)
{
// If the resolution fails, the addr will be unset
this->addr.reset(nullptr);
- struct addrinfo hints;
- memset(&hints, 0, sizeof(struct addrinfo));
- hints.ai_flags = 0;
- hints.ai_family = AF_UNSPEC;
- hints.ai_socktype = SOCK_STREAM;
- hints.ai_protocol = 0;
-
- struct addrinfo* addr_res = nullptr;
- const int res = ::getaddrinfo(hostname.data(), port.data(),
- &hints, &addr_res);
+ const auto res = this->call_getaddrinfo(hostname.data(), port.data(), 0);
this->resolved = true;
@@ -194,12 +255,11 @@ void Resolver::start_resolving(const std::string& hostname, const std::string& p
}
else
{
- this->addr.reset(addr_res);
if (this->success_cb)
this->success_cb(this->addr.get());
}
}
-#endif // ifdef CARES_FOUND
+#endif // ifdef UDNS_FOUND
std::string addr_to_string(const struct addrinfo* rp)
{
diff --git a/louloulibs/network/resolver.hpp b/louloulibs/network/resolver.hpp
index 29e6f3a..f516da5 100644
--- a/louloulibs/network/resolver.hpp
+++ b/louloulibs/network/resolver.hpp
@@ -1,38 +1,31 @@
#pragma once
-
#include "louloulibs.h"
#include <functional>
+#include <vector>
#include <memory>
#include <string>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
+#include <udns.h>
class AddrinfoDeleter
{
public:
void operator()(struct addrinfo* addr)
{
-#ifdef CARES_FOUND
- while (addr)
- {
- delete addr->ai_addr;
- auto next = addr->ai_next;
- delete addr;
- addr = next;
- }
-#else
freeaddrinfo(addr);
-#endif
}
};
+
class Resolver
{
public:
+
using ErrorCallbackType = std::function<void(const char*)>;
using SuccessCallbackType = std::function<void(const struct addrinfo*)>;
@@ -45,7 +38,7 @@ public:
bool is_resolving() const
{
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
return this->resolving;
#else
return false;
@@ -68,11 +61,10 @@ public:
void clear()
{
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
this->resolved6 = false;
this->resolved4 = false;
this->resolving = false;
- this->cares_addrinfo = nullptr;
this->port.clear();
#endif
this->resolved = false;
@@ -85,12 +77,18 @@ public:
private:
void start_resolving(const std::string& hostname, const std::string& port);
-#ifdef CARES_FOUND
- void on_hostname4_resolved(int status, struct hostent* hostent);
- void on_hostname6_resolved(int status, struct hostent* hostent);
+ std::vector<std::string> look_in_etc_hosts(const std::string& hostname);
+ /**
+ * Call getaddrinfo() on the given hostname or IP, and append the result
+ * to our internal addrinfo list. Return getaddrinfo()’s return value.
+ */
+ int call_getaddrinfo(const char* name, const char* port, int flags);
- void fill_ares_addrinfo4(const struct hostent* hostent);
- void fill_ares_addrinfo6(const struct hostent* hostent);
+#ifdef UDNS_FOUND
+ void on_hostname4_resolved(dns_rr_a4 *result);
+ void on_hostname6_resolved(dns_rr_a6 *result);
+
+ void start_timer();
void on_resolved();
@@ -99,14 +97,6 @@ private:
bool resolving;
- /**
- * When using c-ares to resolve the host asynchronously, we need the
- * c-ares callbacks to fill a structure (a struct addrinfo, for
- * compatibility with getaddrinfo and the rest of the code that works when
- * c-ares is not used) with all returned values (for example an IPv6 and
- * an IPv4). The pointer is given to the unique_ptr to manage its lifetime.
- */
- struct addrinfo* cares_addrinfo;
std::string port;
#endif
@@ -117,7 +107,6 @@ private:
bool resolved;
std::string error_msg;
-
std::unique_ptr<struct addrinfo, AddrinfoDeleter> addr;
ErrorCallbackType error_cb;
@@ -125,5 +114,3 @@ private:
};
std::string addr_to_string(const struct addrinfo* rp);
-
-
diff --git a/louloulibs/network/socket_handler.hpp b/louloulibs/network/socket_handler.hpp
index ea79a18..607a106 100644
--- a/louloulibs/network/socket_handler.hpp
+++ b/louloulibs/network/socket_handler.hpp
@@ -20,9 +20,9 @@ public:
SocketHandler& operator=(const SocketHandler&) = delete;
SocketHandler& operator=(SocketHandler&&) = delete;
- virtual void on_recv() = 0;
- virtual void on_send() = 0;
- virtual void connect() = 0;
+ virtual void on_recv() {}
+ virtual void on_send() {}
+ virtual void connect() {}
virtual bool is_connected() const = 0;
socket_t get_socket() const
diff --git a/louloulibs/network/tcp_client_socket_handler.cpp b/louloulibs/network/tcp_client_socket_handler.cpp
new file mode 100644
index 0000000..4e6445c
--- /dev/null
+++ b/louloulibs/network/tcp_client_socket_handler.cpp
@@ -0,0 +1,257 @@
+#include <network/tcp_client_socket_handler.hpp>
+#include <utils/timed_events.hpp>
+#include <utils/scopeguard.hpp>
+#include <network/poller.hpp>
+
+#include <logger/logger.hpp>
+
+#include <cstring>
+#include <unistd.h>
+#include <fcntl.h>
+
+using namespace std::string_literals;
+
+TCPClientSocketHandler::TCPClientSocketHandler(std::shared_ptr<Poller> poller):
+ TCPSocketHandler(poller),
+ hostname_resolution_failed(false),
+ connected(false),
+ connecting(false)
+{}
+
+TCPClientSocketHandler::~TCPClientSocketHandler()
+{
+ this->close();
+}
+
+void TCPClientSocketHandler::init_socket(const struct addrinfo* rp)
+{
+ if (this->socket != -1)
+ ::close(this->socket);
+ if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
+ throw std::runtime_error("Could not create socket: "s + std::strerror(errno));
+ // Bind the socket to a specific address, if specified
+ if (!this->bind_addr.empty())
+ {
+ // Convert the address from string format to a sockaddr that can be
+ // used in bind()
+ struct addrinfo* result;
+ int err = ::getaddrinfo(this->bind_addr.data(), nullptr, nullptr, &result);
+ if (err != 0 || !result)
+ log_error("Failed to bind socket to ", this->bind_addr, ": ",
+ gai_strerror(err));
+ else
+ {
+ utils::ScopeGuard sg([result](){ freeaddrinfo(result); });
+ struct addrinfo* rp;
+ for (rp = result; rp; rp = rp->ai_next)
+ {
+ if ((::bind(this->socket,
+ reinterpret_cast<const struct sockaddr*>(rp->ai_addr),
+ rp->ai_addrlen)) == 0)
+ break;
+ }
+ if (!rp)
+ log_error("Failed to bind socket to ", this->bind_addr, ": ",
+ strerror(errno));
+ else
+ log_info("Socket successfully bound to ", this->bind_addr);
+ }
+ }
+ int optval = 1;
+ if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1)
+ log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno));
+ // Set the socket on non-blocking mode. This is useful to receive a EAGAIN
+ // error when connect() would block, to not block the whole process if a
+ // remote is not responsive.
+ const int existing_flags = ::fcntl(this->socket, F_GETFL, 0);
+ if ((existing_flags == -1) ||
+ (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1))
+ throw std::runtime_error("Could not initialize socket: "s + std::strerror(errno));
+}
+
+void TCPClientSocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
+{
+ this->address = address;
+ this->port = port;
+ this->use_tls = tls;
+
+ struct addrinfo* addr_res;
+
+ if (!this->connecting)
+ {
+ // Get the addrinfo from getaddrinfo (or using udns), only if
+ // this is the first call of this function.
+ if (!this->resolver.is_resolved())
+ {
+ log_info("Trying to connect to ", address, ":", port);
+ // Start the asynchronous process of resolving the hostname. Once
+ // the addresses have been found and `resolved` has been set to true
+ // (but connecting will still be false), TCPClientSocketHandler::connect()
+ // needs to be called, again.
+ this->resolver.resolve(address, port,
+ [this](const struct addrinfo*)
+ {
+ log_debug("Resolution success, calling connect() again");
+ this->connect();
+ },
+ [this](const char*)
+ {
+ log_debug("Resolution failed, calling connect() again");
+ this->connect();
+ });
+ return;
+ }
+ else
+ {
+ // The DNS resolver resolved the hostname and the available addresses
+ // where saved in the addrinfo linked list. Now, just use
+ // this list to try to connect.
+ addr_res = this->resolver.get_result().get();
+ if (!addr_res)
+ {
+ this->hostname_resolution_failed = true;
+ const auto msg = this->resolver.get_error_message();
+ this->close();
+ this->on_connection_failed(msg);
+ return ;
+ }
+ }
+ }
+ else
+ { // This function is called again, use the saved addrinfo structure,
+ // instead of re-doing the whole getaddrinfo process.
+ addr_res = &this->addrinfo;
+ }
+
+ for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
+ {
+ if (!this->connecting)
+ {
+ try {
+ this->init_socket(rp);
+ }
+ catch (const std::runtime_error& error) {
+ log_error("Failed to init socket: ", error.what());
+ break;
+ }
+ }
+
+ this->display_resolved_ip(rp);
+
+ if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0
+ || errno == EISCONN)
+ {
+ log_info("Connection success.");
+ TimedEventsManager::instance().cancel("connection_timeout"s +
+ std::to_string(this->socket));
+ this->poller->add_socket_handler(this);
+ this->connected = true;
+ this->connecting = false;
+#ifdef BOTAN_FOUND
+ if (this->use_tls)
+ this->start_tls(this->address, this->port);
+#endif
+ this->connection_date = std::chrono::system_clock::now();
+
+ // Get our local TCP port and store it
+ this->local_port = static_cast<uint16_t>(-1);
+ if (rp->ai_family == AF_INET6)
+ {
+ struct sockaddr_in6 a;
+ socklen_t l = sizeof(a);
+ if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1)
+ this->local_port = ntohs(a.sin6_port);
+ }
+ else if (rp->ai_family == AF_INET)
+ {
+ struct sockaddr_in a;
+ socklen_t l = sizeof(a);
+ if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1)
+ this->local_port = ntohs(a.sin_port);
+ }
+
+ log_debug("Local port: ", this->local_port, ", and remote port: ", this->port);
+
+ this->on_connected();
+ return ;
+ }
+ else if (errno == EINPROGRESS || errno == EALREADY)
+ { // retry this process later, when the socket
+ // is ready to be written on.
+ this->connecting = true;
+ this->poller->add_socket_handler(this);
+ this->poller->watch_send_events(this);
+ // Save the addrinfo structure, to use it on the next call
+ this->ai_addrlen = rp->ai_addrlen;
+ memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
+ memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
+ this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr);
+ this->addrinfo.ai_next = nullptr;
+ // If the connection has not succeeded or failed in 5s, we consider
+ // it to have failed
+ TimedEventsManager::instance().add_event(
+ TimedEvent(std::chrono::steady_clock::now() + 5s,
+ std::bind(&TCPClientSocketHandler::on_connection_timeout, this),
+ "connection_timeout"s + std::to_string(this->socket)));
+ return ;
+ }
+ log_info("Connection failed:", std::strerror(errno));
+ }
+ log_error("All connection attempts failed.");
+ this->close();
+ this->on_connection_failed(std::strerror(errno));
+ return ;
+}
+
+void TCPClientSocketHandler::on_connection_timeout()
+{
+ this->close();
+ this->on_connection_failed("connection timed out");
+}
+
+void TCPClientSocketHandler::connect()
+{
+ this->connect(this->address, this->port, this->use_tls);
+}
+
+void TCPClientSocketHandler::close()
+{
+ TimedEventsManager::instance().cancel("connection_timeout"s +
+ std::to_string(this->socket));
+
+ TCPSocketHandler::close();
+
+ this->connected = false;
+ this->connecting = false;
+ this->port.clear();
+ this->resolver.clear();
+}
+
+void TCPClientSocketHandler::display_resolved_ip(struct addrinfo* rp) const
+{
+ if (rp->ai_family == AF_INET)
+ log_debug("Trying IPv4 address ", addr_to_string(rp));
+ else if (rp->ai_family == AF_INET6)
+ log_debug("Trying IPv6 address ", addr_to_string(rp));
+}
+
+bool TCPClientSocketHandler::is_connected() const
+{
+ return this->connected;
+}
+
+bool TCPClientSocketHandler::is_connecting() const
+{
+ return this->connecting || this->resolver.is_resolving();
+}
+
+std::string TCPClientSocketHandler::get_port() const
+{
+ return this->port;
+}
+
+bool TCPClientSocketHandler::match_port_pairt(const uint16_t local, const uint16_t remote) const
+{
+ const uint16_t remote_port = static_cast<uint16_t>(std::stoi(this->port));
+ return this->is_connected() && local == this->local_port && remote == remote_port;
+}
diff --git a/louloulibs/network/tcp_client_socket_handler.hpp b/louloulibs/network/tcp_client_socket_handler.hpp
new file mode 100644
index 0000000..75e1364
--- /dev/null
+++ b/louloulibs/network/tcp_client_socket_handler.hpp
@@ -0,0 +1,82 @@
+#pragma once
+
+#include <network/tcp_socket_handler.hpp>
+
+class TCPClientSocketHandler: public TCPSocketHandler
+{
+ public:
+ TCPClientSocketHandler(std::shared_ptr<Poller> poller);
+ ~TCPClientSocketHandler();
+ /**
+ * Connect to the remote server, and call on_connected() if this
+ * succeeds. If tls is true, we set use_tls to true and will also call
+ * start_tls() when the connection succeeds.
+ */
+ void connect(const std::string& address, const std::string& port, const bool tls);
+ void connect() override final;
+ /**
+ * Called by a TimedEvent, when the connection did not succeed or fail
+ * after a given time.
+ */
+ void on_connection_timeout();
+ /**
+ * Called when the connection is successful.
+ */
+ virtual void on_connected() = 0;
+ bool is_connected() const override;
+ bool is_connecting() const override;
+
+ std::string get_port() const;
+
+ void close() override final;
+ std::chrono::system_clock::time_point connection_date;
+
+ /**
+ * Whether or not this connection is using the two given TCP ports.
+ */
+ bool match_port_pairt(const uint16_t local, const uint16_t remote) const;
+
+ protected:
+ bool hostname_resolution_failed;
+ /**
+ * Address to bind the socket to, before calling connect().
+ * If empty, it’s equivalent to binding to INADDR_ANY.
+ */
+ std::string bind_addr;
+ /**
+ * Display the resolved IP, just for information purpose.
+ */
+ void display_resolved_ip(struct addrinfo* rp) const;
+ private:
+ /**
+ * Initialize the socket with the parameters contained in the given
+ * addrinfo structure.
+ */
+ void init_socket(const struct addrinfo* rp);
+ /**
+ * DNS resolver
+ */
+ Resolver resolver;
+ /**
+ * Keep the details of the addrinfo returned by the resolver that
+ * triggered a EINPROGRESS error when connect()ing to it, to reuse it
+ * directly when connect() is called again.
+ */
+ struct addrinfo addrinfo{};
+ struct sockaddr_in6 ai_addr{};
+ socklen_t ai_addrlen{};
+
+ /**
+ * Hostname we are connected/connecting to
+ */
+ std::string address;
+ /**
+ * Port we are connected/connecting to
+ */
+ std::string port;
+
+ uint16_t local_port{};
+
+ bool connected;
+ bool connecting;
+};
diff --git a/louloulibs/network/tcp_server_socket.hpp b/louloulibs/network/tcp_server_socket.hpp
new file mode 100644
index 0000000..7ea49ab
--- /dev/null
+++ b/louloulibs/network/tcp_server_socket.hpp
@@ -0,0 +1,70 @@
+#pragma once
+
+#include <network/socket_handler.hpp>
+#include <network/poller.hpp>
+#include <logger/logger.hpp>
+
+#include <string>
+
+#include <arpa/inet.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/ip.h>
+
+#include <cstring>
+#include <cassert>
+
+template <typename RemoteSocketType>
+class TcpSocketServer: public SocketHandler
+{
+ public:
+ TcpSocketServer(std::shared_ptr<Poller> poller, const uint16_t port):
+ SocketHandler(poller, -1)
+ {
+ if ((this->socket = ::socket(AF_INET6, SOCK_STREAM, 0)) == -1)
+ throw std::runtime_error(std::string{"Could not create socket: "} + std::strerror(errno));
+
+ int opt = 1;
+ if (::setsockopt(this->socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1)
+ throw std::runtime_error(std::string{"Failed to set socket option: "} + std::strerror(errno));
+
+ struct sockaddr_in6 addr{};
+ addr.sin6_family = AF_INET6;
+ addr.sin6_port = htons(port);
+ addr.sin6_addr = IN6ADDR_ANY_INIT;
+ if ((::bind(this->socket, (const struct sockaddr*)&addr, sizeof(addr))) == -1)
+ { // If we can’t listen on this port, we just give up, but this is not fatal.
+ log_warning("Failed to bind on port ", std::to_string(port), ": ", std::strerror(errno));
+ return;
+ }
+
+ if ((::listen(this->socket, 10)) == -1)
+ throw std::runtime_error("listen() failed");
+
+ this->accept();
+ }
+ ~TcpSocketServer() = default;
+
+ void on_recv() override
+ {
+ // Accept a RemoteSocketType
+ int socket = ::accept(this->socket, nullptr, nullptr);
+
+ auto client = std::make_unique<RemoteSocketType>(poller, socket, *this);
+ this->poller->add_socket_handler(client.get());
+ this->sockets.push_back(std::move(client));
+ }
+
+ protected:
+ std::vector<std::unique_ptr<RemoteSocketType>> sockets;
+
+ private:
+ void accept()
+ {
+ this->poller->add_socket_handler(this);
+ }
+ bool is_connected() const override
+ {
+ return true;
+ }
+};
diff --git a/louloulibs/network/tcp_socket_handler.cpp b/louloulibs/network/tcp_socket_handler.cpp
index 1dddde5..6aef2b1 100644
--- a/louloulibs/network/tcp_socket_handler.cpp
+++ b/louloulibs/network/tcp_socket_handler.cpp
@@ -1,8 +1,6 @@
#include <network/tcp_socket_handler.hpp>
#include <network/dns_handler.hpp>
-#include <utils/timed_events.hpp>
-#include <utils/scopeguard.hpp>
#include <network/poller.hpp>
#include <logger/logger.hpp>
@@ -12,16 +10,29 @@
#include <unistd.h>
#include <errno.h>
#include <cstring>
-#include <fcntl.h>
#ifdef BOTAN_FOUND
# include <botan/hex.h>
# include <botan/tls_exceptn.h>
-Botan::AutoSeeded_RNG TCPSocketHandler::rng;
-Botan::TLS::Policy TCPSocketHandler::policy;
-Botan::TLS::Session_Manager_In_Memory TCPSocketHandler::session_manager(TCPSocketHandler::rng);
-
+namespace
+{
+ Botan::AutoSeeded_RNG& get_rng()
+ {
+ static Botan::AutoSeeded_RNG rng{};
+ return rng;
+ }
+ BiboumiTLSPolicy& get_policy()
+ {
+ static BiboumiTLSPolicy policy{};
+ return policy;
+ }
+ Botan::TLS::Session_Manager_In_Memory& get_session_manager()
+ {
+ static Botan::TLS::Session_Manager_In_Memory session_manager{get_rng()};
+ return session_manager;
+ }
+}
#endif
#ifndef UIO_FASTIOV
@@ -35,10 +46,7 @@ namespace ph = std::placeholders;
TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller):
SocketHandler(poller, -1),
- use_tls(false),
- connected(false),
- connecting(false),
- hostname_resolution_failed(false)
+ use_tls(false)
#ifdef BOTAN_FOUND
,credential_manager(this)
#endif
@@ -46,181 +54,13 @@ TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller):
TCPSocketHandler::~TCPSocketHandler()
{
- this->close();
-}
-
-
-void TCPSocketHandler::init_socket(const struct addrinfo* rp)
-{
+ if (this->poller->is_managing_socket(this->get_socket()))
+ this->poller->remove_socket_handler(this->get_socket());
if (this->socket != -1)
- ::close(this->socket);
- if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
- throw std::runtime_error("Could not create socket: "s + strerror(errno));
- // Bind the socket to a specific address, if specified
- if (!this->bind_addr.empty())
- {
- // Convert the address from string format to a sockaddr that can be
- // used in bind()
- struct addrinfo* result;
- int err = ::getaddrinfo(this->bind_addr.data(), nullptr, nullptr, &result);
- if (err != 0 || !result)
- log_error("Failed to bind socket to ", this->bind_addr, ": ",
- gai_strerror(err));
- else
- {
- utils::ScopeGuard sg([result](){ freeaddrinfo(result); });
- struct addrinfo* rp;
- int bind_error = 0;
- for (rp = result; rp; rp = rp->ai_next)
- {
- if ((bind_error = ::bind(this->socket,
- reinterpret_cast<const struct sockaddr*>(rp->ai_addr),
- rp->ai_addrlen)) == 0)
- break;
- }
- if (!rp)
- log_error("Failed to bind socket to ", this->bind_addr, ": ",
- strerror(errno));
- else
- log_info("Socket successfully bound to ", this->bind_addr);
- }
- }
- int optval = 1;
- if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1)
- log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno));
- // Set the socket on non-blocking mode. This is useful to receive a EAGAIN
- // error when connect() would block, to not block the whole process if a
- // remote is not responsive.
- const int existing_flags = ::fcntl(this->socket, F_GETFL, 0);
- if ((existing_flags == -1) ||
- (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1))
- throw std::runtime_error("Could not initialize socket: "s + strerror(errno));
-}
-
-void TCPSocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
-{
- this->address = address;
- this->port = port;
- this->use_tls = tls;
-
- struct addrinfo* addr_res;
-
- if (!this->connecting)
- {
- // Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if
- // this is the first call of this function.
- if (!this->resolver.is_resolved())
- {
- log_info("Trying to connect to ", address, ":", port);
- // Start the asynchronous process of resolving the hostname. Once
- // the addresses have been found and `resolved` has been set to true
- // (but connecting will still be false), TCPSocketHandler::connect()
- // needs to be called, again.
- this->resolver.resolve(address, port,
- [this](const struct addrinfo*)
- {
- log_debug("Resolution success, calling connect() again");
- this->connect();
- },
- [this](const char*)
- {
- log_debug("Resolution failed, calling connect() again");
- this->connect();
- });
- return;
- }
- else
- {
- // The c-ares resolved the hostname and the available addresses
- // where saved in the cares_addrinfo linked list. Now, just use
- // this list to try to connect.
- addr_res = this->resolver.get_result().get();
- if (!addr_res)
- {
- this->hostname_resolution_failed = true;
- const auto msg = this->resolver.get_error_message();
- this->close();
- this->on_connection_failed(msg);
- return ;
- }
- }
- }
- else
- { // This function is called again, use the saved addrinfo structure,
- // instead of re-doing the whole getaddrinfo process.
- addr_res = &this->addrinfo;
- }
-
- for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
{
- if (!this->connecting)
- {
- try {
- this->init_socket(rp);
- }
- catch (const std::runtime_error& error) {
- log_error("Failed to init socket: ", error.what());
- break;
- }
- }
-
- this->display_resolved_ip(rp);
-
- if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0
- || errno == EISCONN)
- {
- log_info("Connection success.");
- TimedEventsManager::instance().cancel("connection_timeout"s +
- std::to_string(this->socket));
- this->poller->add_socket_handler(this);
- this->connected = true;
- this->connecting = false;
-#ifdef BOTAN_FOUND
- if (this->use_tls)
- this->start_tls();
-#endif
- this->connection_date = std::chrono::system_clock::now();
-
- this->on_connected();
- return ;
- }
- else if (errno == EINPROGRESS || errno == EALREADY)
- { // retry this process later, when the socket
- // is ready to be written on.
- this->connecting = true;
- this->poller->add_socket_handler(this);
- this->poller->watch_send_events(this);
- // Save the addrinfo structure, to use it on the next call
- this->ai_addrlen = rp->ai_addrlen;
- memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
- memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
- this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr);
- this->addrinfo.ai_next = nullptr;
- // If the connection has not succeeded or failed in 5s, we consider
- // it to have failed
- TimedEventsManager::instance().add_event(
- TimedEvent(std::chrono::steady_clock::now() + 5s,
- std::bind(&TCPSocketHandler::on_connection_timeout, this),
- "connection_timeout"s + std::to_string(this->socket)));
- return ;
- }
- log_info("Connection failed:", strerror(errno));
+ ::close(this->socket);
+ this->socket = -1;
}
- log_error("All connection attempts failed.");
- this->close();
- this->on_connection_failed(strerror(errno));
- return ;
-}
-
-void TCPSocketHandler::on_connection_timeout()
-{
- this->close();
- this->on_connection_failed("connection timed out");
-}
-
-void TCPSocketHandler::connect()
-{
- this->connect(this->address, this->port, this->use_tls);
}
void TCPSocketHandler::on_recv()
@@ -267,13 +107,13 @@ ssize_t TCPSocketHandler::do_recv(void* recv_buf, const size_t buf_size)
}
else if (-1 == size)
{
- if (this->connecting)
+ if (this->is_connecting())
log_warning("Error connecting: ", strerror(errno));
else
log_warning("Error while reading from socket: ", strerror(errno));
// Remember if we were connecting, or already connected when this
// happened, because close() sets this->connecting to false
- const auto were_connecting = this->connecting;
+ const auto were_connecting = this->is_connecting();
this->close();
if (were_connecting)
this->on_connection_failed(strerror(errno));
@@ -333,29 +173,15 @@ void TCPSocketHandler::on_send()
void TCPSocketHandler::close()
{
- TimedEventsManager::instance().cancel("connection_timeout"s +
- std::to_string(this->socket));
- if (this->connected || this->connecting)
+ if (this->is_connected() || this->is_connecting())
this->poller->remove_socket_handler(this->get_socket());
if (this->socket != -1)
{
::close(this->socket);
this->socket = -1;
}
- this->connected = false;
- this->connecting = false;
this->in_buf.clear();
this->out_buf.clear();
- this->port.clear();
- this->resolver.clear();
-}
-
-void TCPSocketHandler::display_resolved_ip(struct addrinfo* rp) const
-{
- if (rp->ai_family == AF_INET)
- log_debug("Trying IPv4 address ", addr_to_string(rp));
- else if (rp->ai_family == AF_INET6)
- log_debug("Trying IPv6 address ", addr_to_string(rp));
}
void TCPSocketHandler::send_data(std::string&& data)
@@ -379,52 +205,46 @@ void TCPSocketHandler::raw_send(std::string&& data)
if (data.empty())
return ;
this->out_buf.emplace_back(std::move(data));
- if (this->connected)
+ if (this->is_connected())
this->poller->watch_send_events(this);
}
void TCPSocketHandler::send_pending_data()
{
- if (this->connected && !this->out_buf.empty())
+ if (this->is_connected() && !this->out_buf.empty())
this->poller->watch_send_events(this);
}
-bool TCPSocketHandler::is_connected() const
-{
- return this->connected;
-}
-
-bool TCPSocketHandler::is_connecting() const
-{
- return this->connecting || this->resolver.is_resolving();
-}
-
bool TCPSocketHandler::is_using_tls() const
{
return this->use_tls;
}
-std::string TCPSocketHandler::get_port() const
+void* TCPSocketHandler::get_receive_buffer(const size_t) const
{
- return this->port;
+ return nullptr;
}
-void* TCPSocketHandler::get_receive_buffer(const size_t) const
+void TCPSocketHandler::consume_in_buffer(const std::size_t size)
{
- return nullptr;
+ this->in_buf = this->in_buf.substr(size, std::string::npos);
}
#ifdef BOTAN_FOUND
-void TCPSocketHandler::start_tls()
+void TCPSocketHandler::start_tls(const std::string& address, const std::string& port)
{
- Botan::TLS::Server_Information server_info(this->address, "irc", std::stoul(this->port));
+ Botan::TLS::Server_Information server_info(address, "irc", std::stoul(port));
this->tls = std::make_unique<Botan::TLS::Client>(
- std::bind(&TCPSocketHandler::tls_output_fn, this, ph::_1, ph::_2),
- std::bind(&TCPSocketHandler::tls_data_cb, this, ph::_1, ph::_2),
- std::bind(&TCPSocketHandler::tls_alert_cb, this, ph::_1, ph::_2, ph::_3),
- std::bind(&TCPSocketHandler::tls_handshake_cb, this, ph::_1),
- session_manager, this->credential_manager, policy,
- rng, server_info, Botan::TLS::Protocol_Version::latest_tls_version());
+# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32)
+ *this,
+# else
+ [this](const Botan::byte* data, size_t size) { this->tls_emit_data(data, size); },
+ [this](const Botan::byte* data, size_t size) { this->tls_record_received(0, data, size); },
+ [this](Botan::TLS::Alert alert, const Botan::byte*, size_t) { this->tls_alert(alert); },
+ [this](const Botan::TLS::Session& session) { return this->tls_session_established(session); },
+# endif
+ get_session_manager(), this->credential_manager, get_policy(),
+ get_rng(), server_info, Botan::TLS::Protocol_Version::latest_tls_version());
}
void TCPSocketHandler::tls_recv()
@@ -475,7 +295,7 @@ void TCPSocketHandler::tls_send(std::string&& data)
std::make_move_iterator(data.end()));
}
-void TCPSocketHandler::tls_data_cb(const Botan::byte* data, size_t size)
+void TCPSocketHandler::tls_record_received(uint64_t, const Botan::byte *data, size_t size)
{
this->in_buf += std::string(reinterpret_cast<const char*>(data),
size);
@@ -483,17 +303,17 @@ void TCPSocketHandler::tls_data_cb(const Botan::byte* data, size_t size)
this->parse_in_buffer(size);
}
-void TCPSocketHandler::tls_output_fn(const Botan::byte* data, size_t size)
+void TCPSocketHandler::tls_emit_data(const Botan::byte *data, size_t size)
{
this->raw_send(std::string(reinterpret_cast<const char*>(data), size));
}
-void TCPSocketHandler::tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t)
+void TCPSocketHandler::tls_alert(Botan::TLS::Alert alert)
{
log_debug("tls_alert: ", alert.type_string());
}
-bool TCPSocketHandler::tls_handshake_cb(const Botan::TLS::Session& session)
+bool TCPSocketHandler::tls_session_established(const Botan::TLS::Session& session)
{
log_debug("Handshake with ", session.server_info().hostname(), " complete.",
" Version: ", session.version().to_string(),
@@ -505,6 +325,31 @@ bool TCPSocketHandler::tls_handshake_cb(const Botan::TLS::Session& session)
return true;
}
+#if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,34)
+void TCPSocketHandler::tls_verify_cert_chain(const std::vector<Botan::X509_Certificate>& cert_chain,
+ const std::vector<std::shared_ptr<const Botan::OCSP::Response>>& ocsp_responses,
+ const std::vector<Botan::Certificate_Store*>& trusted_roots,
+ Botan::Usage_Type usage, const std::string& hostname,
+ const Botan::TLS::Policy& policy)
+{
+ log_debug("Checking remote certificate for hostname ", hostname);
+ try
+ {
+ Botan::TLS::Callbacks::tls_verify_cert_chain(cert_chain, ocsp_responses, trusted_roots, usage, hostname, policy);
+ log_debug("Certificate is valid");
+ }
+ catch (const std::exception& tls_exception)
+ {
+ log_warning("TLS certificate check failed: ", tls_exception.what());
+ std::exception_ptr exception_ptr{};
+ if (this->abort_on_invalid_cert())
+ exception_ptr = std::current_exception();
+
+ check_tls_certificate(cert_chain, hostname, this->credential_manager.get_trusted_fingerprint(), exception_ptr);
+ }
+}
+#endif
+
void TCPSocketHandler::on_tls_activated()
{
this->send_data({});
diff --git a/louloulibs/network/tcp_socket_handler.hpp b/louloulibs/network/tcp_socket_handler.hpp
index 20a3e5a..600405d 100644
--- a/louloulibs/network/tcp_socket_handler.hpp
+++ b/louloulibs/network/tcp_socket_handler.hpp
@@ -1,6 +1,5 @@
#pragma once
-
#include "louloulibs.h"
#include <network/socket_handler.hpp>
@@ -19,13 +18,44 @@
#include <string>
#include <list>
+#ifdef BOTAN_FOUND
+
+# include <botan/types.h>
+# include <botan/botan.h>
+# include <botan/tls_session_manager.h>
+
+class BiboumiTLSPolicy: public Botan::TLS::Policy
+{
+public:
+# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,33)
+ bool use_ecc_point_compression() const override
+ {
+ return true;
+ }
+ bool require_cert_revocation_info() const override
+ {
+ return false;
+ }
+# endif
+};
+
+# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32)
+# define BOTAN_TLS_CALLBACKS_OVERRIDE override final
+# else
+# define BOTAN_TLS_CALLBACKS_OVERRIDE
+# endif
+#endif
+
/**
- * An interface, with a series of callbacks that should be implemented in
- * subclasses that deal with a socket. These callbacks are called on various events
- * (read/write/timeout, etc) when they are notified to a poller
- * (select/poll/epoll etc)
+ * Does all the read/write, buffering etc. With optional tls.
+ * But doesn’t do any connect() or accept() or anything else.
*/
class TCPSocketHandler: public SocketHandler
+#ifdef BOTAN_FOUND
+# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32)
+ ,public Botan::TLS::Callbacks
+# endif
+#endif
{
protected:
~TCPSocketHandler();
@@ -37,13 +67,6 @@ public:
TCPSocketHandler& operator=(TCPSocketHandler&&) = delete;
/**
- * Connect to the remote server, and call on_connected() if this
- * succeeds. If tls is true, we set use_tls to true and will also call
- * start_tls() when the connection succeeds.
- */
- void connect(const std::string& address, const std::string& port, const bool tls);
- void connect() override final;
- /**
* Reads raw data from the socket. And pass it to parse_in_buffer()
* If we are using TLS on this connection, we call tls_recv()
*/
@@ -67,25 +90,7 @@ public:
/**
* Close the connection, remove us from the poller
*/
- void close();
- /**
- * Called by a TimedEvent, when the connection did not succeed or fail
- * after a given time.
- */
- void on_connection_timeout();
- /**
- * Called when the connection is successful.
- */
- virtual void on_connected() = 0;
- /**
- * Called when the connection fails. Not when it is closed later, just at
- * the connect() call.
- */
- virtual void on_connection_failed(const std::string& reason) = 0;
- /**
- * Called when we detect a disconnection from the remote host.
- */
- virtual void on_connection_close(const std::string& error) = 0;
+ virtual void close();
/**
* Handle/consume (some of) the data received so far. The data to handle
* may be in the in_buf buffer, or somewhere else, depending on what
@@ -93,6 +98,9 @@ public:
* should be truncated, only the unused data should be left untouched.
*
* The size argument is the size of the last chunk of data that was added to the buffer.
+ *
+ * The function should call consume_in_buffer, with the size that was consumed by the
+ * “parsing”, and thus to be removed from the input buffer.
*/
virtual void parse_in_buffer(const size_t size) = 0;
#ifdef BOTAN_FOUND
@@ -105,19 +113,10 @@ public:
return true;
}
#endif
- bool is_connected() const override final;
- bool is_connecting() const;
bool is_using_tls() const;
- std::string get_port() const;
- std::chrono::system_clock::time_point connection_date;
private:
/**
- * Initialize the socket with the parameters contained in the given
- * addrinfo structure.
- */
- void init_socket(const struct addrinfo* rp);
- /**
* Reads from the socket into the provided buffer. If an error occurs
* (read returns <= 0), the handling of the error is done here (close the
* connection, log a message, etc).
@@ -136,13 +135,16 @@ private:
*/
void raw_send(std::string&& data);
+ protected:
+ virtual bool is_connecting() const = 0;
#ifdef BOTAN_FOUND
/**
* Create the TLS::Client object, with all the callbacks etc. This must be
* called only when we know we are able to send TLS-encrypted data over
* the socket.
*/
- void start_tls();
+ void start_tls(const std::string& address, const std::string& port);
+ private:
/**
* An additional step to pass the data into our tls object to decrypt it
* before passing it to parse_in_buffer.
@@ -158,22 +160,31 @@ private:
* Called by the tls object that some data has been decrypt. We call
* parse_in_buffer() to handle that unencrypted data.
*/
- void tls_data_cb(const Botan::byte* data, size_t size);
+ void tls_record_received(uint64_t rec_no, const Botan::byte* data, size_t size) BOTAN_TLS_CALLBACKS_OVERRIDE;
/**
* Called by the tls object to indicate that some data has been encrypted
* and is now ready to be sent on the socket as is.
*/
- void tls_output_fn(const Botan::byte* data, size_t size);
+ void tls_emit_data(const Botan::byte* data, size_t size) BOTAN_TLS_CALLBACKS_OVERRIDE;
/**
* Called by the tls object to indicate that a TLS alert has been
* received. We don’t use it, we just log some message, at the moment.
*/
- void tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t);
+ void tls_alert(Botan::TLS::Alert alert) BOTAN_TLS_CALLBACKS_OVERRIDE;
/**
* Called by the tls object at the end of the TLS handshake. We don't do
* anything here appart from logging the TLS session information.
*/
- bool tls_handshake_cb(const Botan::TLS::Session& session);
+ bool tls_session_established(const Botan::TLS::Session& session) BOTAN_TLS_CALLBACKS_OVERRIDE;
+
+#if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,34)
+ void tls_verify_cert_chain(const std::vector<Botan::X509_Certificate>& cert_chain,
+ const std::vector<std::shared_ptr<const Botan::OCSP::Response>>& ocsp_responses,
+ const std::vector<Botan::Certificate_Store*>& trusted_roots,
+ Botan::Usage_Type usage,
+ const std::string& hostname,
+ const Botan::TLS::Policy& policy) BOTAN_TLS_CALLBACKS_OVERRIDE;
+#endif
/**
* Called whenever the tls session goes from inactive to active. This
* means that the handshake has just been successfully done, and we can
@@ -185,20 +196,11 @@ private:
* Where data is added, when we want to send something to the client.
*/
std::vector<std::string> out_buf;
+protected:
/**
- * DNS resolver
- */
- Resolver resolver;
- /**
- * Keep the details of the addrinfo returned by the resolver that
- * triggered a EINPROGRESS error when connect()ing to it, to reuse it
- * directly when connect() is called again.
+ * Whether we are using TLS on this connection or not.
*/
- struct addrinfo addrinfo;
- struct sockaddr_in6 ai_addr;
- socklen_t ai_addrlen;
-
-protected:
+ bool use_tls;
/**
* Where data read from the socket is added until we can extract a full
* and meaningful “message” from it.
@@ -207,9 +209,9 @@ protected:
*/
std::string in_buf;
/**
- * Whether we are using TLS on this connection or not.
+ * Remove the given “size” first bytes from our in_buf.
*/
- bool use_tls;
+ void consume_in_buffer(const std::size_t size);
/**
* Provide a buffer in which data can be directly received. This can be
* used to avoid copying data into in_buf before using it. If no buffer
@@ -219,38 +221,12 @@ protected:
*/
virtual void* get_receive_buffer(const size_t size) const;
/**
- * Hostname we are connected/connecting to
- */
- std::string address;
- /**
- * Port we are connected/connecting to
- */
- std::string port;
-
- bool connected;
- bool connecting;
-
- bool hostname_resolution_failed;
-
- /**
- * Address to bind the socket to, before calling connect().
- * If empty, it’s equivalent to binding to INADDR_ANY.
- */
- std::string bind_addr;
-
-private:
- /**
- * Display the resolved IP, just for information purpose.
+ * Called when we detect a disconnection from the remote host.
*/
- void display_resolved_ip(struct addrinfo* rp) const;
+ virtual void on_connection_close(const std::string&) {}
+ virtual void on_connection_failed(const std::string&) {}
#ifdef BOTAN_FOUND
- /**
- * Botan stuff to manipulate a TLS session.
- */
- static Botan::AutoSeeded_RNG rng;
- static Botan::TLS::Policy policy;
- static Botan::TLS::Session_Manager_In_Memory session_manager;
protected:
BasicCredentialsManager credential_manager;
private:
diff --git a/louloulibs/utils/encoding.cpp b/louloulibs/utils/encoding.cpp
index 60f2212..087095f 100644
--- a/louloulibs/utils/encoding.cpp
+++ b/louloulibs/utils/encoding.cpp
@@ -7,6 +7,7 @@
#include <assert.h>
#include <string.h>
#include <iconv.h>
+#include <cerrno>
#include <map>
#include <bitset>
diff --git a/louloulibs/utils/time.cpp b/louloulibs/utils/time.cpp
index afd6117..e3c49ed 100644
--- a/louloulibs/utils/time.cpp
+++ b/louloulibs/utils/time.cpp
@@ -24,7 +24,7 @@ std::time_t parse_datetime(const std::string& stamp)
std::tm t = {};
#ifdef HAS_GET_TIME
std::istringstream ss(stamp);
- ss.imbue(std::locale("en_US.utf-8"));
+ ss.imbue(std::locale("en_US.UTF-8"));
std::string timezone;
ss >> std::get_time(&t, format) >> timezone;
diff --git a/louloulibs/xmpp/adhoc_command.cpp b/louloulibs/xmpp/adhoc_command.cpp
index 99701d7..825cc92 100644
--- a/louloulibs/xmpp/adhoc_command.cpp
+++ b/louloulibs/xmpp/adhoc_command.cpp
@@ -18,30 +18,24 @@ bool AdhocCommand::is_admin_only() const
void PingStep1(XmppComponent&, AdhocSession&, XmlNode& command_node)
{
- XmlNode note("note");
+ XmlSubNode note(command_node, "note");
note["type"] = "info";
note.set_inner("Pong");
- command_node.add_child(std::move(note));
}
void HelloStep1(XmppComponent&, AdhocSession&, XmlNode& command_node)
{
- XmlNode x("jabber:x:data:x");
+ XmlSubNode x(command_node, "jabber:x:data:x");
x["type"] = "form";
- XmlNode title("title");
+ XmlSubNode title(x, "title");
title.set_inner("Configure your name.");
- x.add_child(std::move(title));
- XmlNode instructions("instructions");
+ XmlSubNode instructions(x, "instructions");
instructions.set_inner("Please provide your name.");
- x.add_child(std::move(instructions));
- XmlNode name_field("field");
+ XmlSubNode name_field(x, "field");
name_field["var"] = "name";
name_field["type"] = "text-single";
name_field["label"] = "Your name";
- XmlNode required("required");
- name_field.add_child(std::move(required));
- x.add_child(std::move(name_field));
- command_node.add_child(std::move(x));
+ XmlSubNode required(name_field, "required");
}
void HelloStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node)
@@ -60,21 +54,19 @@ void HelloStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node)
{
if (const XmlNode* value = name_field->get_child("value", "jabber:x:data"))
{
- XmlNode note("note");
- note["type"] = "info";
- note.set_inner("Hello "s + value->get_inner() + "!"s);
+ const std::string value_str = value->get_inner();
command_node.delete_all_children();
- command_node.add_child(std::move(note));
+ XmlSubNode note(command_node, "note");
+ note["type"] = "info";
+ note.set_inner("Hello "s + value_str + "!"s);
return;
}
}
}
command_node.delete_all_children();
- XmlNode error(ADHOC_NS":error");
+ XmlSubNode error(command_node, ADHOC_NS":error");
error["type"] = "modify";
- XmlNode condition(STANZA_NS":bad-request");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
+ XmlSubNode condition(error, STANZA_NS":bad-request");
session.terminate();
}
@@ -82,8 +74,7 @@ void Reload(XmppComponent&, AdhocSession&, XmlNode& command_node)
{
::reload_process();
command_node.delete_all_children();
- XmlNode note("note");
+ XmlSubNode note(command_node, "note");
note["type"] = "info";
note.set_inner("Configuration reloaded.");
- command_node.add_child(std::move(note));
}
diff --git a/louloulibs/xmpp/adhoc_commands_handler.cpp b/louloulibs/xmpp/adhoc_commands_handler.cpp
index 540cac0..040d0ff 100644
--- a/louloulibs/xmpp/adhoc_commands_handler.cpp
+++ b/louloulibs/xmpp/adhoc_commands_handler.cpp
@@ -36,20 +36,16 @@ XmlNode AdhocCommandsHandler::handle_request(const std::string& executor_jid, co
auto command_it = this->commands.find(node);
if (command_it == this->commands.end())
{
- XmlNode error(ADHOC_NS":error");
+ XmlSubNode error(command_node, ADHOC_NS":error");
error["type"] = "cancel";
- XmlNode condition(STANZA_NS":item-not-found");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
+ XmlSubNode condition(error, STANZA_NS":item-not-found");
}
else if (command_it->second.is_admin_only() &&
Config::get("admin", "") != jid.local + "@" + jid.domain)
{
- XmlNode error(ADHOC_NS":error");
+ XmlSubNode error(command_node, ADHOC_NS":error");
error["type"] = "cancel";
- XmlNode condition(STANZA_NS":forbidden");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
+ XmlSubNode condition(error, STANZA_NS":forbidden");
}
else
{
@@ -66,15 +62,8 @@ XmlNode AdhocCommandsHandler::handle_request(const std::string& executor_jid, co
"adhocsession"s + sessionid + executor_jid));
}
auto session_it = this->sessions.find(std::make_pair(sessionid, executor_jid));
- if (session_it == this->sessions.end())
- {
- XmlNode error(ADHOC_NS":error");
- error["type"] = "modify";
- XmlNode condition(STANZA_NS":bad-request");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
- }
- else if (action == "execute" || action == "next" || action == "complete")
+ if ((session_it != this->sessions.end()) &&
+ (action == "execute" || action == "next" || action == "complete"))
{
// execute the step
AdhocSession& session = session_it->second;
@@ -90,10 +79,8 @@ XmlNode AdhocCommandsHandler::handle_request(const std::string& executor_jid, co
else
{
command_node["status"] = "executing";
- XmlNode actions("actions");
- XmlNode next("next");
- actions.add_child(std::move(next));
- command_node.add_child(std::move(actions));
+ XmlSubNode actions(command_node, "actions");
+ XmlSubNode next(actions, "next");
}
}
else if (action == "cancel")
@@ -104,11 +91,9 @@ XmlNode AdhocCommandsHandler::handle_request(const std::string& executor_jid, co
}
else // unsupported action
{
- XmlNode error(ADHOC_NS":error");
+ XmlSubNode error(command_node, ADHOC_NS":error");
error["type"] = "modify";
- XmlNode condition(STANZA_NS":bad-request");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
+ XmlSubNode condition(error, STANZA_NS":bad-request");
}
}
return command_node;
diff --git a/louloulibs/xmpp/xmpp_component.cpp b/louloulibs/xmpp/xmpp_component.cpp
index fa8b0a5..e1b6131 100644
--- a/louloulibs/xmpp/xmpp_component.cpp
+++ b/louloulibs/xmpp/xmpp_component.cpp
@@ -39,7 +39,7 @@ static std::set<std::string> kickable_errors{
};
XmppComponent::XmppComponent(std::shared_ptr<Poller> poller, const std::string& hostname, const std::string& secret):
- TCPSocketHandler(poller),
+ TCPClientSocketHandler(poller),
ever_auth(false),
first_connection_try(true),
secret(secret),
@@ -172,12 +172,13 @@ void XmppComponent::on_stanza(const Stanza& stanza)
void XmppComponent::send_stream_error(const std::string& name, const std::string& explanation)
{
- XmlNode node("stream:error", nullptr);
- XmlNode error(name, nullptr);
- error["xmlns"] = STREAM_NS;
- if (!explanation.empty())
- error.set_inner(explanation);
- node.add_child(std::move(error));
+ Stanza node("stream:error");
+ {
+ XmlSubNode error(node, name);
+ error["xmlns"] = STREAM_NS;
+ if (!explanation.empty())
+ error.set_inner(explanation);
+ }
this->send_stanza(node);
}
@@ -187,31 +188,34 @@ void XmppComponent::send_stanza_error(const std::string& kind, const std::string
const bool fulljid)
{
Stanza node(kind);
- if (!to.empty())
- node["to"] = to;
- if (!from.empty())
+ {
+ if (!to.empty())
+ node["to"] = to;
+ if (!from.empty())
+ {
+ if (fulljid)
+ node["from"] = from;
+ else
+ node["from"] = from + "@" + this->served_hostname;
+ }
+ if (!id.empty())
+ node["id"] = id;
+ node["type"] = "error";
{
- if (fulljid)
- node["from"] = from;
- else
- node["from"] = from + "@" + this->served_hostname;
+ XmlSubNode error(node, "error");
+ error["type"] = error_type;
+ {
+ XmlSubNode inner_error(error, defined_condition);
+ inner_error["xmlns"] = STANZA_NS;
+ }
+ if (!text.empty())
+ {
+ XmlSubNode text_node(error, "text");
+ text_node["xmlns"] = STANZA_NS;
+ text_node.set_inner(text);
+ }
}
- if (!id.empty())
- node["id"] = id;
- node["type"] = "error";
- XmlNode error("error");
- error["type"] = error_type;
- XmlNode inner_error(defined_condition);
- inner_error["xmlns"] = STANZA_NS;
- error.add_child(std::move(inner_error));
- if (!text.empty())
- {
- XmlNode text_node("text");
- text_node["xmlns"] = STANZA_NS;
- text_node.set_inner(text);
- error.add_child(std::move(text_node));
- }
- node.add_child(std::move(error));
+ }
this->send_stanza(node);
}
@@ -264,38 +268,33 @@ void* XmppComponent::get_receive_buffer(const size_t size) const
void XmppComponent::send_message(const std::string& from, Xmpp::body&& body, const std::string& to,
const std::string& type, const bool fulljid, const bool nocopy)
{
- XmlNode node("message");
- node["to"] = to;
- if (fulljid)
- node["from"] = from;
- else
- node["from"] = from + "@" + this->served_hostname;
- if (!type.empty())
- node["type"] = type;
- XmlNode body_node("body");
- body_node.set_inner(std::get<0>(body));
- node.add_child(std::move(body_node));
- if (std::get<1>(body))
- {
- XmlNode html("html");
- html["xmlns"] = XHTMLIM_NS;
- // Pass the ownership of the pointer to this xmlnode
- html.add_child(std::move(std::get<1>(body)));
- node.add_child(std::move(html));
- }
-
- if (nocopy)
- {
- XmlNode private_node("private");
- private_node["xmlns"] = "urn:xmpp:carbons:2";
- node.add_child(std::move(private_node));
-
- XmlNode nocopy("no-copy");
- nocopy["xmlns"] = "urn:xmpp:hints";
- node.add_child(std::move(nocopy));
- }
-
- this->send_stanza(node);
+ Stanza message("message");
+ {
+ message["to"] = to;
+ if (fulljid)
+ message["from"] = from;
+ else
+ message["from"] = from + "@" + this->served_hostname;
+ if (!type.empty())
+ message["type"] = type;
+ XmlSubNode body_node(message, "body");
+ body_node.set_inner(std::get<0>(body));
+ if (std::get<1>(body))
+ {
+ XmlSubNode html(message, "html");
+ html["xmlns"] = XHTMLIM_NS;
+ // Pass the ownership of the pointer to this xmlnode
+ html.add_child(std::move(std::get<1>(body)));
+ }
+ if (nocopy)
+ {
+ XmlSubNode private_node(message, "private");
+ private_node["xmlns"] = "urn:xmpp:carbons:2";
+ XmlSubNode nocopy(message, "no-copy");
+ nocopy["xmlns"] = "urn:xmpp:hints";
+ }
+ }
+ this->send_stanza(message);
}
void XmppComponent::send_user_join(const std::string& from,
@@ -306,34 +305,33 @@ void XmppComponent::send_user_join(const std::string& from,
const std::string& to,
const bool self)
{
- XmlNode node("presence");
- node["to"] = to;
- node["from"] = from + "@" + this->served_hostname + "/" + nick;
-
- XmlNode x("x");
- x["xmlns"] = MUC_USER_NS;
-
- XmlNode item("item");
- if (!affiliation.empty())
- item["affiliation"] = affiliation;
- if (!role.empty())
- item["role"] = role;
- if (!realjid.empty())
- {
- const std::string preped_jid = jidprep(realjid);
- if (!preped_jid.empty())
- item["jid"] = preped_jid;
- }
- x.add_child(std::move(item));
-
- if (self)
- {
- XmlNode status("status");
- status["code"] = "110";
- x.add_child(std::move(status));
- }
- node.add_child(std::move(x));
- this->send_stanza(node);
+ Stanza presence("presence");
+ {
+ presence["to"] = to;
+ presence["from"] = from + "@" + this->served_hostname + "/" + nick;
+
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+
+ XmlSubNode item(x, "item");
+ if (!affiliation.empty())
+ item["affiliation"] = affiliation;
+ if (!role.empty())
+ item["role"] = role;
+ if (!realjid.empty())
+ {
+ const std::string preped_jid = jidprep(realjid);
+ if (!preped_jid.empty())
+ item["jid"] = preped_jid;
+ }
+
+ if (self)
+ {
+ XmlSubNode status(x, "status");
+ status["code"] = "110";
+ }
+ }
+ this->send_stanza(presence);
}
void XmppComponent::send_invalid_room_error(const std::string& muc_name,
@@ -341,44 +339,43 @@ void XmppComponent::send_invalid_room_error(const std::string& muc_name,
const std::string& to)
{
Stanza presence("presence");
- if (!muc_name.empty())
- presence["from"] = muc_name + "@" + this->served_hostname + "/" + nick;
- else
- presence["from"] = this->served_hostname;
- presence["to"] = to;
- presence["type"] = "error";
- XmlNode x("x");
- x["xmlns"] = MUC_NS;
- presence.add_child(std::move(x));
- XmlNode error("error");
- error["by"] = muc_name + "@" + this->served_hostname;
- error["type"] = "cancel";
- XmlNode item_not_found("item-not-found");
- item_not_found["xmlns"] = STANZA_NS;
- error.add_child(std::move(item_not_found));
- XmlNode text("text");
- text["xmlns"] = STANZA_NS;
- text["xml:lang"] = "en";
- text.set_inner(muc_name +
- " is not a valid IRC channel name. A correct room jid is of the form: #<chan>%<server>@" +
- this->served_hostname);
- error.add_child(std::move(text));
- presence.add_child(std::move(error));
+ {
+ if (!muc_name.empty ())
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + nick;
+ else
+ presence["from"] = this->served_hostname;
+ presence["to"] = to;
+ presence["type"] = "error";
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_NS;
+ XmlSubNode error(presence, "error");
+ error["by"] = muc_name + "@" + this->served_hostname;
+ error["type"] = "cancel";
+ XmlSubNode item_not_found(error, "item-not-found");
+ item_not_found["xmlns"] = STANZA_NS;
+ XmlSubNode text(error, "text");
+ text["xmlns"] = STANZA_NS;
+ text["xml:lang"] = "en";
+ text.set_inner(muc_name +
+ " is not a valid IRC channel name. A correct room jid is of the form: #<chan>%<server>@" +
+ this->served_hostname);
+ }
this->send_stanza(presence);
}
void XmppComponent::send_topic(const std::string& from, Xmpp::body&& topic, const std::string& to, const std::string& who)
{
- XmlNode message("message");
- message["to"] = to;
- if (who.empty())
- message["from"] = from + "@" + this->served_hostname;
- else
- message["from"] = from + "@" + this->served_hostname + "/" + who;
- message["type"] = "groupchat";
- XmlNode subject("subject");
- subject.set_inner(std::get<0>(topic));
- message.add_child(std::move(subject));
+ Stanza message("message");
+ {
+ message["to"] = to;
+ if (who.empty())
+ message["from"] = from + "@" + this->served_hostname;
+ else
+ message["from"] = from + "@" + this->served_hostname + "/" + who;
+ message["type"] = "groupchat";
+ XmlSubNode subject(message, "subject");
+ subject.set_inner(std::get<0>(topic));
+ }
this->send_stanza(message);
}
@@ -391,16 +388,18 @@ void XmppComponent::send_muc_message(const std::string& muc_name, const std::str
else // Message from the room itself
message["from"] = muc_name + "@" + this->served_hostname;
message["type"] = "groupchat";
- XmlNode body("body");
- body.set_inner(std::get<0>(xmpp_body));
- message.add_child(std::move(body));
+
+ {
+ XmlSubNode body(message, "body");
+ body.set_inner(std::get<0>(xmpp_body));
+ }
+
if (std::get<1>(xmpp_body))
{
- XmlNode html("html");
+ XmlSubNode html(message, "html");
html["xmlns"] = XHTMLIM_NS;
// Pass the ownership of the pointer to this xmlnode
html.add_child(std::move(std::get<1>(xmpp_body)));
- message.add_child(std::move(html));
}
this->send_stanza(message);
}
@@ -415,41 +414,41 @@ void XmppComponent::send_history_message(const std::string& muc_name, const std:
message["from"] = muc_name + "@" + this->served_hostname;
message["type"] = "groupchat";
- XmlNode body("body");
- body.set_inner(body_txt);
- message.add_child(std::move(body));
+ {
+ XmlSubNode body(message, "body");
+ body.set_inner(body_txt);
+ }
+ {
+ XmlSubNode delay(message, "delay");
+ delay["xmlns"] = DELAY_NS;
+ delay["from"] = muc_name + "@" + this->served_hostname;
+ delay["stamp"] = utils::to_string(timestamp);
+ }
- XmlNode delay("delay");
- delay["xmlns"] = DELAY_NS;
- delay["from"] = muc_name + "@" + this->served_hostname;
- delay["stamp"] = utils::to_string(timestamp);
-
- message.add_child(std::move(delay));
this->send_stanza(message);
}
void XmppComponent::send_muc_leave(const std::string& muc_name, std::string&& nick, Xmpp::body&& message, const std::string& jid_to, const bool self)
{
Stanza presence("presence");
- presence["to"] = jid_to;
- presence["from"] = muc_name + "@" + this->served_hostname + "/" + nick;
- presence["type"] = "unavailable";
- const std::string message_str = std::get<0>(message);
- XmlNode x("x");
- x["xmlns"] = MUC_USER_NS;
- if (self)
- {
- XmlNode status("status");
- status["code"] = "110";
- x.add_child(std::move(status));
- }
- presence.add_child(std::move(x));
- if (!message_str.empty())
- {
- XmlNode status("status");
- status.set_inner(message_str);
- presence.add_child(std::move(status));
- }
+ {
+ presence["to"] = jid_to;
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + nick;
+ presence["type"] = "unavailable";
+ const std::string message_str = std::get<0>(message);
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+ if (self)
+ {
+ XmlSubNode status(x, "status");
+ status["code"] = "110";
+ }
+ if (!message_str.empty())
+ {
+ XmlSubNode status(presence, "status");
+ status.set_inner(message_str);
+ }
+ }
this->send_stanza(presence);
}
@@ -462,24 +461,22 @@ void XmppComponent::send_nick_change(const std::string& muc_name,
const bool self)
{
Stanza presence("presence");
- presence["to"] = jid_to;
- presence["from"] = muc_name + "@" + this->served_hostname + "/" + old_nick;
- presence["type"] = "unavailable";
- XmlNode x("x");
- x["xmlns"] = MUC_USER_NS;
- XmlNode item("item");
- item["nick"] = new_nick;
- x.add_child(std::move(item));
- XmlNode status("status");
- status["code"] = "303";
- x.add_child(std::move(status));
- if (self)
- {
- XmlNode status2("status");
- status2["code"] = "110";
- x.add_child(std::move(status2));
- }
- presence.add_child(std::move(x));
+ {
+ presence["to"] = jid_to;
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + old_nick;
+ presence["type"] = "unavailable";
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+ XmlSubNode item(x, "item");
+ item["nick"] = new_nick;
+ XmlSubNode status(x, "status");
+ status["code"] = "303";
+ if (self)
+ {
+ XmlSubNode status(x, "status");
+ status["code"] = "110";
+ }
+ }
this->send_stanza(presence);
this->send_user_join(muc_name, new_nick, "", affiliation, role, jid_to, self);
@@ -489,32 +486,28 @@ void XmppComponent::kick_user(const std::string& muc_name, const std::string& ta
const std::string& author, const std::string& jid_to, const bool self)
{
Stanza presence("presence");
- presence["from"] = muc_name + "@" + this->served_hostname + "/" + target;
- presence["to"] = jid_to;
- presence["type"] = "unavailable";
- XmlNode x("x");
- x["xmlns"] = MUC_USER_NS;
- XmlNode item("item");
- item["affiliation"] = "none";
- item["role"] = "none";
- XmlNode actor("actor");
- actor["nick"] = author;
- actor["jid"] = author; // backward compatibility with old clients
- item.add_child(std::move(actor));
- XmlNode reason("reason");
- reason.set_inner(txt);
- item.add_child(std::move(reason));
- x.add_child(std::move(item));
- XmlNode status("status");
- status["code"] = "307";
- x.add_child(std::move(status));
- if (self)
- {
- XmlNode status("status");
- status["code"] = "110";
- x.add_child(std::move(status));
- }
- presence.add_child(std::move(x));
+ {
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + target;
+ presence["to"] = jid_to;
+ presence["type"] = "unavailable";
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+ XmlSubNode item(x, "item");
+ item["affiliation"] = "none";
+ item["role"] = "none";
+ XmlSubNode actor(item, "actor");
+ actor["nick"] = author;
+ actor["jid"] = author; // backward compatibility with old clients
+ XmlSubNode reason(item, "reason");
+ reason.set_inner(txt);
+ XmlSubNode status(x, "status");
+ status["code"] = "307";
+ if (self)
+ {
+ XmlSubNode status(x, "status");
+ status["code"] = "110";
+ }
+ }
this->send_stanza(presence);
}
@@ -524,24 +517,29 @@ void XmppComponent::send_presence_error(const std::string& muc_name,
const std::string& type,
const std::string& condition,
const std::string& error_code,
- const std::string& /* text */)
+ const std::string& text)
{
Stanza presence("presence");
- presence["from"] = muc_name + "@" + this->served_hostname + "/" + nickname;
- presence["to"] = jid_to;
- presence["type"] = "error";
- XmlNode x("x");
- x["xmlns"] = MUC_NS;
- presence.add_child(std::move(x));
- XmlNode error("error");
- error["by"] = muc_name + "@" + this->served_hostname;
- error["type"] = type;
- if (!error_code.empty())
- error["code"] = error_code;
- XmlNode subnode(condition);
- subnode["xmlns"] = STANZA_NS;
- error.add_child(std::move(subnode));
- presence.add_child(std::move(error));
+ {
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + nickname;
+ presence["to"] = jid_to;
+ presence["type"] = "error";
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_NS;
+ XmlSubNode error(presence, "error");
+ error["by"] = muc_name + "@" + this->served_hostname;
+ error["type"] = type;
+ if (!text.empty())
+ {
+ XmlSubNode text_node(error, "text");
+ text_node["xmlns"] = STANZA_NS;
+ text_node.set_inner(text);
+ }
+ if (!error_code.empty())
+ error["code"] = error_code;
+ XmlSubNode subnode(error, condition);
+ subnode["xmlns"] = STANZA_NS;
+ }
this->send_stanza(presence);
}
@@ -552,15 +550,15 @@ void XmppComponent::send_affiliation_role_change(const std::string& muc_name,
const std::string& jid_to)
{
Stanza presence("presence");
- presence["from"] = muc_name + "@" + this->served_hostname + "/" + target;
- presence["to"] = jid_to;
- XmlNode x("x");
- x["xmlns"] = MUC_USER_NS;
- XmlNode item("item");
- item["affiliation"] = affiliation;
- item["role"] = role;
- x.add_child(std::move(item));
- presence.add_child(std::move(x));
+ {
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + target;
+ presence["to"] = jid_to;
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+ XmlSubNode item(x, "item");
+ item["affiliation"] = affiliation;
+ item["role"] = role;
+ }
this->send_stanza(presence);
}
@@ -572,27 +570,30 @@ void XmppComponent::send_version(const std::string& id, const std::string& jid_t
iq["id"] = id;
iq["to"] = jid_to;
iq["from"] = jid_from;
- XmlNode query("query");
- query["xmlns"] = VERSION_NS;
- if (version.empty())
- {
- XmlNode name("name");
- name.set_inner("biboumi");
- query.add_child(std::move(name));
- XmlNode version("version");
- version.set_inner(SOFTWARE_VERSION);
- query.add_child(std::move(version));
- XmlNode os("os");
- os.set_inner(SYSTEM_NAME);
- query.add_child(std::move(os));
+ {
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = VERSION_NS;
+ if (version.empty())
+ {
+ {
+ XmlSubNode name(query, "name");
+ name.set_inner("biboumi");
+ }
+ {
+ XmlSubNode version(query, "version");
+ version.set_inner(SOFTWARE_VERSION);
+ }
+ {
+ XmlSubNode os(query, "os");
+ os.set_inner(SYSTEM_NAME);
+ }
}
- else
+ else
{
- XmlNode name("name");
+ XmlSubNode name(query, "name");
name.set_inner(version);
- query.add_child(std::move(name));
}
- iq.add_child(std::move(query));
+ }
this->send_stanza(iq);
}
@@ -601,24 +602,24 @@ void XmppComponent::send_adhoc_commands_list(const std::string& id, const std::s
const bool with_admin_only, const AdhocCommandsHandler& adhoc_handler)
{
Stanza iq("iq");
- iq["type"] = "result";
- iq["id"] = id;
- iq["to"] = requester_jid;
- iq["from"] = from_jid;
- XmlNode query("query");
- query["xmlns"] = DISCO_ITEMS_NS;
- query["node"] = ADHOC_NS;
- for (const auto& kv: adhoc_handler.get_commands())
- {
- if (kv.second.is_admin_only() && !with_admin_only)
- continue;
- XmlNode item("item");
- item["jid"] = from_jid;
- item["node"] = kv.first;
- item["name"] = kv.second.name;
- query.add_child(std::move(item));
- }
- iq.add_child(std::move(query));
+ {
+ iq["type"] = "result";
+ iq["id"] = id;
+ iq["to"] = requester_jid;
+ iq["from"] = from_jid;
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = DISCO_ITEMS_NS;
+ query["node"] = ADHOC_NS;
+ for (const auto &kv: adhoc_handler.get_commands())
+ {
+ if (kv.second.is_admin_only() && !with_admin_only)
+ continue;
+ XmlSubNode item(query, "item");
+ item["jid"] = from_jid;
+ item["node"] = kv.first;
+ item["name"] = kv.second.name;
+ }
+ }
this->send_stanza(iq);
}
@@ -626,13 +627,14 @@ void XmppComponent::send_iq_version_request(const std::string& from,
const std::string& jid_to)
{
Stanza iq("iq");
- iq["type"] = "get";
- iq["id"] = "version_"s + XmppComponent::next_id();
- iq["from"] = from + "@" + this->served_hostname;
- iq["to"] = jid_to;
- XmlNode query("query");
- query["xmlns"] = VERSION_NS;
- iq.add_child(std::move(query));
+ {
+ iq["type"] = "get";
+ iq["id"] = "version_"s + XmppComponent::next_id();
+ iq["from"] = from + "@" + this->served_hostname;
+ iq["to"] = jid_to;
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = VERSION_NS;
+ }
this->send_stanza(iq);
}
diff --git a/louloulibs/xmpp/xmpp_component.hpp b/louloulibs/xmpp/xmpp_component.hpp
index 5f5f937..a9bac0f 100644
--- a/louloulibs/xmpp/xmpp_component.hpp
+++ b/louloulibs/xmpp/xmpp_component.hpp
@@ -2,7 +2,7 @@
#include <xmpp/adhoc_commands_handler.hpp>
-#include <network/tcp_socket_handler.hpp>
+#include <network/tcp_client_socket_handler.hpp>
#include <xmpp/xmpp_parser.hpp>
#include <xmpp/body.hpp>
@@ -40,7 +40,7 @@
*
* TODO: implement XEP-0225: Component Connections
*/
-class XmppComponent: public TCPSocketHandler
+class XmppComponent: public TCPClientSocketHandler
{
public:
explicit XmppComponent(std::shared_ptr<Poller> poller, const std::string& hostname, const std::string& secret);
@@ -179,10 +179,6 @@ public:
const std::string& role,
const std::string& jid_to);
/**
- * Send a result IQ with the gateway disco informations.
- */
- void send_self_disco_info(const std::string& id, const std::string& jid_to);
- /**
* Send a result IQ with the given version, or the gateway version if the
* passed string is empty.
*/
diff --git a/louloulibs/xmpp/xmpp_stanza.hpp b/louloulibs/xmpp/xmpp_stanza.hpp
index 4ca758e..f4b3948 100644
--- a/louloulibs/xmpp/xmpp_stanza.hpp
+++ b/louloulibs/xmpp/xmpp_stanza.hpp
@@ -143,4 +143,18 @@ std::ostream& operator<<(std::ostream& os, const XmlNode& node);
*/
using Stanza = XmlNode;
+class XmlSubNode: public XmlNode
+{
+public:
+ XmlSubNode(XmlNode& parent_ref, const std::string& name):
+ XmlNode(name),
+ parent_to_add(parent_ref)
+ {}
+ ~XmlSubNode()
+ {
+ this->parent_to_add.add_child(std::move(*this));
+ }
+private:
+ XmlNode& parent_to_add;
+}; \ No newline at end of file