diff options
Diffstat (limited to 'louloulibs')
60 files changed, 5968 insertions, 0 deletions
diff --git a/louloulibs/CMakeLists.txt b/louloulibs/CMakeLists.txt new file mode 100644 index 0000000..bf53504 --- /dev/null +++ b/louloulibs/CMakeLists.txt @@ -0,0 +1,146 @@ +cmake_minimum_required(VERSION 2.6) + +set(${PROJECT_NAME}_VERSION_MAJOR 1) +set(${PROJECT_NAME}_VERSION_MINOR 0) +set(${PROJECT_NAME}_VERSION_SUFFIX "~dev") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++1y -pedantic -Wall -Wextra") + +# Define a __FILENAME__ macro to get the filename of each file, instead of +# the full path as in __FILE__ +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILENAME__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"'") + +# +## Look for external libraries +# +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/") +include(FindEXPAT) +find_package(EXPAT REQUIRED) +find_package(ICONV REQUIRED) +find_package(LIBUUID REQUIRED) + +if(WITH_LIBIDN) + find_package(LIBIDN REQUIRED) +elseif(NOT WITHOUT_LIBIDN) + find_package(LIBIDN) +endif() + +if(WITH_SYSTEMD) + find_package(SYSTEMD REQUIRED) +elseif(NOT WITHOUT_SYSTEMD) + find_package(SYSTEMD) +endif() + +if(WITH_BOTAN) + find_package(BOTAN REQUIRED) +elseif(NOT WITHOUT_BOTAN) + find_package(BOTAN) +endif() + +if(WITH_CARES) + find_package(CARES REQUIRED) +elseif(NOT WITHOUT_CARES) + find_package(CARES) +endif() + +# To be able to include the config.h file generated by cmake +include_directories("${CMAKE_CURRENT_BINARY_DIR}") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}") +include_directories(${EXPAT_INCLUDE_DIRS}) +include_directories(${ICONV_INCLUDE_DIRS}) +include_directories(${LIBUUID_INCLUDE_DIRS}) + +set(EXPAT_INCLUDE_DIRS ${EXPAT_INCLUDE_DIRS} PARENT_SCOPE) +set(ICONV_INCLUDE_DIRS ${ICONV_INCLUDE_DIRS} PARENT_SCOPE) +set(LIBUUID_INCLUDE_DIRS ${LIBUUID_INCLUDE_DIRS} PARENT_SCOPE) + +if(LIBIDN_FOUND) + include_directories(${LIBIDN_INCLUDE_DIRS}) + set(LIBDIN_FOUND ${LIBDIN_FOUND} PARENT_SCOPE) + set(LIBDIN_INCLUDE_DIRS ${LIBDIN_INCLUDE_DIRS} PARENT_SCOPE) +endif() + +if(SYSTEMD_FOUND) + include_directories(${SYSTEMD_INCLUDE_DIRS}) + set(SYSTEMD_FOUND ${SYSTEMD_FOUND} PARENT_SCOPE) + set(SYSTEMD_INCLUDE_DIRS ${SYSTEMD_INCLUDE_DIRS} PARENT_SCOPE) +endif() + +if(BOTAN_FOUND) + include_directories(SYSTEM ${BOTAN_INCLUDE_DIRS}) + set(BOTAN_FOUND ${BOTAN_FOUND} PARENT_SCOPE) + set(BOTAN_INCLUDE_DIRS ${BOTAN_INCLUDE_DIRS} PARENT_SCOPE) +endif() + +if(CARES_FOUND) + include_directories(${CARES_INCLUDE_DIRS}) + set(CARES_FOUND ${CARES_FOUND} PARENT_SCOPE) + set(CARES_INCLUDE_DIRS ${CARES_INCLUDE_DIRS} PARENT_SCOPE) +endif() + +set(POLLER_DOCSTRING "Choose the poller between POLL and EPOLL (Linux-only)") +if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + set(POLLER "EPOLL" CACHE STRING ${POLLER_DOCSTRING}) +else() + set(POLLER "POLL" CACHE STRING ${POLLER_DOCSTRING}) +endif() +if((NOT ${POLLER} MATCHES "POLL") AND + (NOT ${POLLER} MATCHES "EPOLL")) + message(FATAL_ERROR "POLLER must be either POLL or EPOLL") +endif() + +# +## utils +# +file(GLOB source_utils + utils/*.[hc]pp) +add_library(utils STATIC ${source_utils}) +target_link_libraries(utils ${ICONV_LIBRARIES}) + +# +## config +# +file(GLOB source_config + config/*.[hc]pp) +add_library(config STATIC ${source_config}) +target_link_libraries(config utils) + +# +## logger +# +file(GLOB source_logger + logger/*.[hc]pp) +add_library(logger STATIC ${source_logger}) +target_link_libraries(logger config) + +# +## network +# +file(GLOB source_network + network/*.[hc]pp) +add_library(network STATIC ${source_network}) +target_link_libraries(network logger) +if(BOTAN_FOUND) + target_link_libraries(network ${BOTAN_LIBRARIES}) +endif() +if(CARES_FOUND) + target_link_libraries(network ${CARES_LIBRARIES}) +endif() + +# +## xmpplib +# +file(GLOB source_xmpplib + xmpp/*.[hc]pp) +add_library(xmpplib STATIC ${source_xmpplib}) +target_link_libraries(xmpplib network utils logger + ${EXPAT_LIBRARIES} + ${LIBUUID_LIBRARIES}) +if(LIBIDN_FOUND) + target_link_libraries(xmpplib ${LIBIDN_LIBRARIES}) +endif() +if(SYSTEMD_FOUND) + target_link_libraries(xmpplib ${SYSTEMD_LIBRARIES}) +endif() + +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/louloulibs.h.cmake ${CMAKE_BINARY_DIR}/src/louloulibs.h) diff --git a/louloulibs/cmake/Modules/FindBOTAN.cmake b/louloulibs/cmake/Modules/FindBOTAN.cmake new file mode 100644 index 0000000..a12bd35 --- /dev/null +++ b/louloulibs/cmake/Modules/FindBOTAN.cmake @@ -0,0 +1,35 @@ +# - Find botan +# Find the botan cryptographic library +# +# This module defines the following variables: +# BOTAN_FOUND - True if library and include directory are found +# If set to TRUE, the following are also defined: +# BOTAN_INCLUDE_DIRS - The directory where to find the header file +# BOTAN_LIBRARIES - Where to find the library file +# +# For conveniance, these variables are also set. They have the same values +# than the variables above. The user can thus choose his/her prefered way +# to write them. +# BOTAN_LIBRARY +# BOTAN_INCLUDE_DIR +# +# This file is in the public domain + +find_path(BOTAN_INCLUDE_DIRS NAMES botan/botan.h + PATH_SUFFIXES botan-1.11 + DOC "The botan include directory") + +find_library(BOTAN_LIBRARIES NAMES botan botan-1.11 + DOC "The botan library") + +# Use some standard module to handle the QUIETLY and REQUIRED arguments, and +# set BOTAN_FOUND to TRUE if these two variables are set. +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(BOTAN REQUIRED_VARS BOTAN_LIBRARIES BOTAN_INCLUDE_DIRS) + +if(BOTAN_FOUND) + set(BOTAN_LIBRARY ${BOTAN_LIBRARIES}) + set(BOTAN_INCLUDE_DIR ${BOTAN_INCLUDE_DIRS}) +endif() + +mark_as_advanced(BOTAN_INCLUDE_DIRS BOTAN_LIBRARIES) diff --git a/louloulibs/cmake/Modules/FindCARES.cmake b/louloulibs/cmake/Modules/FindCARES.cmake new file mode 100644 index 0000000..c4c757a --- /dev/null +++ b/louloulibs/cmake/Modules/FindCARES.cmake @@ -0,0 +1,37 @@ +# - Find c-ares +# Find the c-ares library, and more particularly the stringprep header. +# +# This module defines the following variables: +# CARES_FOUND - True if library and include directory are found +# If set to TRUE, the following are also defined: +# CARES_INCLUDE_DIRS - The directory where to find the header file +# CARES_LIBRARIES - Where to find the library file +# +# For conveniance, these variables are also set. They have the same values +# than the variables above. The user can thus choose his/her prefered way +# to write them. +# CARES_INCLUDE_DIR +# CARES_LIBRARY +# +# This file is in the public domain + +if(NOT CARES_FOUND) + find_path(CARES_INCLUDE_DIRS NAMES ares.h + DOC "The c-ares include directory") + + find_library(CARES_LIBRARIES NAMES cares + DOC "The c-ares library") + + # Use some standard module to handle the QUIETLY and REQUIRED arguments, and + # set CARES_FOUND to TRUE if these two variables are set. + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(CARES REQUIRED_VARS CARES_LIBRARIES CARES_INCLUDE_DIRS) + + # Compatibility for all the ways of writing these variables + if(CARES_FOUND) + set(CARES_INCLUDE_DIR ${CARES_INCLUDE_DIRS}) + set(CARES_LIBRARY ${CARES_LIBRARIES}) + endif() +endif() + +mark_as_advanced(CARES_INCLUDE_DIRS CARES_LIBRARIES) diff --git a/louloulibs/cmake/Modules/FindICONV.cmake b/louloulibs/cmake/Modules/FindICONV.cmake new file mode 100644 index 0000000..7ca173f --- /dev/null +++ b/louloulibs/cmake/Modules/FindICONV.cmake @@ -0,0 +1,60 @@ +# - Find iconv +# Find the iconv (character set conversion) library +# +# This module defines the following variables: +# ICONV_FOUND - True if library and include directory are found +# If set to TRUE, the following are also defined: +# ICONV_INCLUDE_DIRS - The directory where to find the header file +# ICONV_LIBRARIES - Where to find the library file +# ICONV_SECOND_ARGUMENT_IS_CONST - The second argument for iconv() is const +# +# For conveniance, these variables are also set. They have the same values +# than the variables above. The user can thus choose his/her prefered way +# to write them. +# ICONV_LIBRARY +# ICONV_INCLUDE_DIR +# +# This file is in the public domain + +find_path(ICONV_INCLUDE_DIRS NAMES iconv.h + DOC "The iconv include directory") + +find_library(ICONV_LIBRARIES NAMES iconv libiconv libiconv-2 c + DOC "The iconv library") + +# Use some standard module to handle the QUIETLY and REQUIRED arguments, and +# set ICONV_FOUND to TRUE if these two variables are set. +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Iconv REQUIRED_VARS ICONV_LIBRARIES ICONV_INCLUDE_DIRS) + +# Check if the prototype is +# size_t iconv(iconv_t cd, char** inbuf, size_t* inbytesleft, +# char** outbuf, size_t* outbytesleft); +# or +# size_t iconv (iconv_t cd, const char** inbuf, size_t* inbytesleft, +# char** outbuf, size_t* outbytesleft); +if(ICONV_FOUND) + include(CheckCXXSourceCompiles) + + # Set the parameters needed to compile the following code. + set(CMAKE_REQUIRED_INCLUDES ${ICONV_INCLUDE_DIRS}) + set(CMAKE_REQUIRED_LIBRARIES ${ICONV_LIBRARIES}) + + check_cxx_source_compiles(" + #include <iconv.h> + int main(){ + iconv_t conv = 0; + const char* in = 0; + size_t ilen = 0; + char* out = 0; + size_t olen = 0; + iconv(conv, &in, &ilen, &out, &olen); + return 0;}" + ICONV_SECOND_ARGUMENT_IS_CONST) + +# Compatibility for all the ways of writing these variables + set(ICONV_LIBRARY ${ICONV_LIBRARIES}) + set(ICONV_INCLUDE_DIR ${ICONV_INCLUDE_DIRS}) +endif() + +mark_as_advanced(ICONV_INCLUDE_DIRS ICONV_LIBRARIES ICONV_SECOND_ARGUMENT_IS_CONST)
\ No newline at end of file diff --git a/louloulibs/cmake/Modules/FindLIBIDN.cmake b/louloulibs/cmake/Modules/FindLIBIDN.cmake new file mode 100644 index 0000000..611a6a8 --- /dev/null +++ b/louloulibs/cmake/Modules/FindLIBIDN.cmake @@ -0,0 +1,41 @@ +# - Find libidn +# Find the libidn library, and more particularly the stringprep header. +# +# This module defines the following variables: +# LIBIDN_FOUND - True if library and include directory are found +# If set to TRUE, the following are also defined: +# LIBIDN_INCLUDE_DIRS - The directory where to find the header file +# LIBIDN_LIBRARIES - Where to find the library file +# +# For conveniance, these variables are also set. They have the same values +# than the variables above. The user can thus choose his/her prefered way +# to write them. +# LIBIDN_INCLUDE_DIR +# LIBIDN_LIBRARY +# +# This file is in the public domain + +include(FindPkgConfig) +pkg_check_modules(LIBIDN libidn) + +if(NOT LIBIDN_FOUND) + find_path(LIBIDN_INCLUDE_DIRS NAMES stringprep.h + DOC "The libidn include directory") + + # The library containing the stringprep module is libidn + find_library(LIBIDN_LIBRARIES NAMES idn + DOC "The libidn library") + + # Use some standard module to handle the QUIETLY and REQUIRED arguments, and + # set LIBIDN_FOUND to TRUE if these two variables are set. + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(LIBIDN REQUIRED_VARS LIBIDN_LIBRARIES LIBIDN_INCLUDE_DIRS) + + # Compatibility for all the ways of writing these variables + if(LIBIDN_FOUND) + set(LIBIDN_INCLUDE_DIR ${LIBIDN_INCLUDE_DIRS}) + set(LIBIDN_LIBRARY ${LIBIDN_LIBRARIES}) + endif() +endif() + +mark_as_advanced(LIBIDN_INCLUDE_DIRS LIBIDN_LIBRARIES)
\ No newline at end of file diff --git a/louloulibs/cmake/Modules/FindLIBUUID.cmake b/louloulibs/cmake/Modules/FindLIBUUID.cmake new file mode 100644 index 0000000..17d3c42 --- /dev/null +++ b/louloulibs/cmake/Modules/FindLIBUUID.cmake @@ -0,0 +1,41 @@ +# - Find libuuid +# Find the libuuid library +# +# This module defines the following variables: +# LIBUUID_FOUND - True if library and include directory are found +# If set to TRUE, the following are also defined: +# LIBUUID_INCLUDE_DIRS - The directory where to find the header file +# LIBUUID_LIBRARIES - Where to find the library file +# +# For conveniance, these variables are also set. They have the same values +# than the variables above. The user can thus choose his/her prefered way +# to write them. +# LIBUUID_INCLUDE_DIR +# LIBUUID_LIBRARY +# +# This file is in the public domain + +include(FindPkgConfig) +pkg_check_modules(LIBUUID uuid) + +if(NOT LIBUUID_FOUND) + find_path(LIBUUID_INCLUDE_DIRS NAMES uuid.h + PATH_SUFFIXES uuid + DOC "The libuuid include directory") + + find_library(LIBUUID_LIBRARIES NAMES uuid + DOC "The libuuid library") + + # Use some standard module to handle the QUIETLY and REQUIRED arguments, and + # set LIBUUID_FOUND to TRUE if these two variables are set. + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(LIBUUID REQUIRED_VARS LIBUUID_LIBRARIES LIBUUID_INCLUDE_DIRS) + + # Compatibility for all the ways of writing these variables + if(LIBUUID_FOUND) + set(LIBUUID_INCLUDE_DIR ${LIBUUID_INCLUDE_DIRS}) + set(LIBUUID_LIBRARY ${LIBUUID_LIBRARIES}) + endif() +endif() + +mark_as_advanced(LIBUUID_INCLUDE_DIRS LIBUUID_LIBRARIES) diff --git a/louloulibs/cmake/Modules/FindSYSTEMD.cmake b/louloulibs/cmake/Modules/FindSYSTEMD.cmake new file mode 100644 index 0000000..c7decde --- /dev/null +++ b/louloulibs/cmake/Modules/FindSYSTEMD.cmake @@ -0,0 +1,39 @@ +# - Find SystemdDaemon +# Find the systemd daemon library +# +# This module defines the following variables: +# SYSTEMD_FOUND - True if library and include directory are found +# If set to TRUE, the following are also defined: +# SYSTEMD_INCLUDE_DIRS - The directory where to find the header file +# SYSTEMD_LIBRARIES - Where to find the library file +# +# For conveniance, these variables are also set. They have the same values +# than the variables above. The user can thus choose his/her prefered way +# to write them. +# SYSTEMD_LIBRARY +# SYSTEMD_INCLUDE_DIR +# +# This file is in the public domain + +include(FindPkgConfig) +pkg_check_modules(SYSTEMD libsystemd) + +if(NOT SYSTEMD_FOUND) + find_path(SYSTEMD_INCLUDE_DIRS NAMES systemd/sd-daemon.h + DOC "The Systemd include directory") + + find_library(SYSTEMD_LIBRARIES NAMES systemd + DOC "The Systemd library") + + # Use some standard module to handle the QUIETLY and REQUIRED arguments, and + # set SYSTEMD_FOUND to TRUE if these two variables are set. + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(SYSTEMD REQUIRED_VARS SYSTEMD_LIBRARIES SYSTEMD_INCLUDE_DIRS) + + if(SYSTEMD_FOUND) + set(SYSTEMD_LIBRARY ${SYSTEMD_LIBRARIES}) + set(SYSTEMD_INCLUDE_DIR ${SYSTEMD_INCLUDE_DIRS}) + endif() +endif() + +mark_as_advanced(SYSTEMD_INCLUDE_DIRS SYSTEMD_LIBRARIES)
\ No newline at end of file diff --git a/louloulibs/config/config.cpp b/louloulibs/config/config.cpp new file mode 100644 index 0000000..417981d --- /dev/null +++ b/louloulibs/config/config.cpp @@ -0,0 +1,104 @@ +#include <config/config.hpp> +#include <logger/logger.hpp> + +#include <cstring> +#include <sstream> + +#include <cstdlib> + +std::string Config::filename{}; +std::map<std::string, std::string> Config::values{}; +std::vector<t_config_changed_callback> Config::callbacks{}; + +std::string Config::get(const std::string& option, const std::string& def) +{ + auto it = Config::values.find(option); + + if (it == Config::values.end()) + return def; + return it->second; +} + +int Config::get_int(const std::string& option, const int& def) +{ + std::string res = Config::get(option, ""); + if (!res.empty()) + return std::atoi(res.c_str()); + else + return def; +} + +void Config::set(const std::string& option, const std::string& value, bool save) +{ + Config::values[option] = value; + if (save) + { + Config::save_to_file(); + Config::trigger_configuration_change(); + } +} + +void Config::connect(t_config_changed_callback callback) +{ + Config::callbacks.push_back(callback); +} + +void Config::clear() +{ + Config::values.clear(); +} + +/** + * Private methods + */ +void Config::trigger_configuration_change() +{ + std::vector<t_config_changed_callback>::iterator it; + for (it = Config::callbacks.begin(); it < Config::callbacks.end(); ++it) + (*it)(); +} + +bool Config::read_conf(const std::string& name) +{ + if (!name.empty()) + Config::filename = name; + + std::ifstream file(Config::filename.data()); + if (!file.is_open()) + { + log_error("Error while opening file ", filename, " for reading: ", strerror(errno)); + return false; + } + + Config::clear(); + + std::string line; + size_t pos; + std::string option; + std::string value; + while (file.good()) + { + std::getline(file, line); + if (line == "" || line[0] == '#') + continue ; + pos = line.find('='); + if (pos == std::string::npos) + continue ; + option = line.substr(0, pos); + value = line.substr(pos+1); + Config::values[option] = value; + } + return true; +} + +void Config::save_to_file() +{ + std::ofstream file(Config::filename.data()); + if (file.fail()) + { + log_error("Could not save config file."); + return ; + } + for (const auto& it: Config::values) + file << it.first << "=" << it.second << '\n'; +} diff --git a/louloulibs/config/config.hpp b/louloulibs/config/config.hpp new file mode 100644 index 0000000..6728df8 --- /dev/null +++ b/louloulibs/config/config.hpp @@ -0,0 +1,94 @@ +/** + * Read the config file and save all the values in a map. + * Also, a singleton. + * + * Use Config::filename = "bla" to set the filename you want to use. + * + * If you want to exit if the file does not exist when it is open for + * reading, set Config::file_must_exist = true. + * + * Config::get() can then be used to access the values in the conf. + * + * Use Config::close() when you're done getting/setting value. This will + * save the config into the file. + */ + +#pragma once + + +#include <functional> +#include <fstream> +#include <memory> +#include <vector> +#include <string> +#include <map> + +typedef std::function<void()> t_config_changed_callback; + +class Config +{ +public: + Config() = default; + ~Config() = default; + Config(const Config&) = delete; + Config& operator=(const Config&) = delete; + Config(Config&&) = delete; + Config& operator=(Config&&) = delete; + + /** + * returns a value from the config. If it doesn’t exist, use + * the second argument as the default. + */ + static std::string get(const std::string&, const std::string&); + /** + * returns a value from the config. If it doesn’t exist, use + * the second argument as the default. + */ + static int get_int(const std::string&, const int&); + /** + * Set a value for the given option. And write all the config + * in the file from which it was read if save is true. + */ + static void set(const std::string&, const std::string&, bool save = false); + /** + * Adds a function to a list. This function will be called whenever a + * configuration change occurs (when set() is called, or when the initial + * conf is read) + */ + static void connect(t_config_changed_callback); + /** + * Destroy the instance, forcing it to be recreated (with potentially + * different parameters) the next time it’s needed. + */ + static void clear(); + /** + * Read the configuration file at the given path. + */ + static bool read_conf(const std::string& name=""); + /** + * Get the filename + */ + static const std::string& get_filename() + { return Config::filename; } + +private: + /** + * Set the value of the filename to use, before calling any method. + */ + static std::string filename; + /** + * Write all the config values into the configuration file + */ + static void save_to_file(); + /** + * Call all the callbacks previously registered using connect(). + * This is used to notify any class that a configuration change occured. + */ + static void trigger_configuration_change(); + + static std::map<std::string, std::string> values; + static std::vector<t_config_changed_callback> callbacks; + +}; + + diff --git a/louloulibs/logger/logger.cpp b/louloulibs/logger/logger.cpp new file mode 100644 index 0000000..7336579 --- /dev/null +++ b/louloulibs/logger/logger.cpp @@ -0,0 +1,38 @@ +#include <logger/logger.hpp> +#include <config/config.hpp> + +Logger::Logger(const int log_level): + log_level(log_level), + stream(std::cout.rdbuf()) +{ +} + +Logger::Logger(const int log_level, const std::string& log_file): + log_level(log_level), + ofstream(log_file.data(), std::ios_base::app), + stream(ofstream.rdbuf()) +{ +} + +std::unique_ptr<Logger>& Logger::instance() +{ + static std::unique_ptr<Logger> instance; + + if (!instance) + { + const std::string log_file = Config::get("log_file", ""); + const int log_level = Config::get_int("log_level", 0); + if (log_file.empty()) + instance = std::make_unique<Logger>(log_level); + else + instance = std::make_unique<Logger>(log_level, log_file); + } + return instance; +} + +std::ostream& Logger::get_stream(const int lvl) +{ + if (lvl >= this->log_level) + return this->stream; + return this->null_stream; +} diff --git a/louloulibs/logger/logger.hpp b/louloulibs/logger/logger.hpp new file mode 100644 index 0000000..0893c77 --- /dev/null +++ b/louloulibs/logger/logger.hpp @@ -0,0 +1,126 @@ +#pragma once + + +/** + * Singleton used in logger macros to write into files or stdout, with + * various levels of severity. + * Only the macros should be used. + * @class Logger + */ + +#include <memory> +#include <iostream> +#include <fstream> + +#define debug_lvl 0 +#define info_lvl 1 +#define warning_lvl 2 +#define error_lvl 3 + +#include "louloulibs.h" +#ifdef SYSTEMD_FOUND +# include <systemd/sd-daemon.h> +#else +# define SD_DEBUG "[DEBUG]: " +# define SD_INFO "[INFO]: " +# define SD_WARNING "[WARNING]: " +# define SD_ERR "[ERROR]: " +#endif + +// Macro defined to get the filename instead of the full path. But if it is +// not properly defined by the build system, we fallback to __FILE__ +#ifndef __FILENAME__ +# define __FILENAME__ __FILE__ +#endif + +/** + * Juste a structure representing a stream doing nothing with its input. + */ +class nullstream: public std::ostream +{ +public: + nullstream(): + std::ostream(0) + { } +}; + +class Logger +{ +public: + static std::unique_ptr<Logger>& instance(); + std::ostream& get_stream(const int); + Logger(const int log_level, const std::string& log_file); + Logger(const int log_level); + + Logger(const Logger&) = delete; + Logger& operator=(const Logger&) = delete; + Logger(Logger&&) = delete; + Logger& operator=(Logger&&) = delete; + +private: + const int log_level; + std::ofstream ofstream; + nullstream null_stream; + std::ostream stream; +}; + +#define WHERE __FILENAME__, ":", __LINE__, ":\t" + +namespace logging_details +{ + template <typename T> + void log(std::ostream& os, const T& arg) + { + os << arg << std::endl; + } + + template <typename T, typename... U> + void log(std::ostream& os, const T& first, U&&... rest) + { + os << first; + log(os, std::forward<U>(rest)...); + } + + template <typename... U> + void log_debug(U&&... args) + { + auto& os = Logger::instance()->get_stream(debug_lvl); + os << SD_DEBUG; + log(os, std::forward<U>(args)...); + } + + template <typename... U> + void log_info(U&&... args) + { + auto& os = Logger::instance()->get_stream(info_lvl); + os << SD_INFO; + log(os, std::forward<U>(args)...); + } + + template <typename... U> + void log_warning(U&&... args) + { + auto& os = Logger::instance()->get_stream(warning_lvl); + os << SD_WARNING; + log(os, std::forward<U>(args)...); + } + + template <typename... U> + void log_error(U&&... args) + { + auto& os = Logger::instance()->get_stream(error_lvl); + os << SD_ERR; + log(os, std::forward<U>(args)...); + } +} + +#define log_info(...) logging_details::log_info(WHERE, __VA_ARGS__) + +#define log_warning(...) logging_details::log_warning(WHERE, __VA_ARGS__) + +#define log_error(...) logging_details::log_error(WHERE, __VA_ARGS__) + +#define log_debug(...) logging_details::log_debug(WHERE, __VA_ARGS__) + + + diff --git a/louloulibs/louloulibs.h.cmake b/louloulibs/louloulibs.h.cmake new file mode 100644 index 0000000..2feaf4e --- /dev/null +++ b/louloulibs/louloulibs.h.cmake @@ -0,0 +1,9 @@ +#define SYSTEM_NAME "${CMAKE_SYSTEM}" +#cmakedefine ICONV_SECOND_ARGUMENT_IS_CONST +#cmakedefine LIBIDN_FOUND +#cmakedefine SYSTEMD_FOUND +#cmakedefine POLLER ${POLLER} +#cmakedefine BOTAN_FOUND +#cmakedefine CARES_FOUND +#cmakedefine SOFTWARE_VERSION "${SOFTWARE_VERSION}" +#cmakedefine PROJECT_NAME "${PROJECT_NAME}"
\ No newline at end of file diff --git a/louloulibs/network/credentials_manager.cpp b/louloulibs/network/credentials_manager.cpp new file mode 100644 index 0000000..ee83c3b --- /dev/null +++ b/louloulibs/network/credentials_manager.cpp @@ -0,0 +1,116 @@ +#include "louloulibs.h" + +#ifdef BOTAN_FOUND +#include <network/tcp_socket_handler.hpp> +#include <network/credentials_manager.hpp> +#include <logger/logger.hpp> +#include <botan/tls_exceptn.h> +#include <config/config.hpp> + +#ifdef USE_DATABASE +# include <database/database.hpp> +#endif + +/** + * TODO find a standard way to find that out. + */ +static const std::vector<std::string> default_cert_files = { + "/etc/ssl/certs/ca-bundle.crt", + "/etc/pki/tls/certs/ca-bundle.crt", + "/etc/ssl/certs/ca-certificates.crt", + "/etc/ca-certificates/extracted/tls-ca-bundle.pem" +}; + +Botan::Certificate_Store_In_Memory BasicCredentialsManager::certificate_store; +bool BasicCredentialsManager::certs_loaded = false; + +BasicCredentialsManager::BasicCredentialsManager(const TCPSocketHandler* const socket_handler): + Botan::Credentials_Manager(), + socket_handler(socket_handler), + trusted_fingerprint{} +{ + this->load_certs(); +} + +void BasicCredentialsManager::set_trusted_fingerprint(const std::string& fingerprint) +{ + this->trusted_fingerprint = fingerprint; +} + +void BasicCredentialsManager::verify_certificate_chain(const std::string& type, + const std::string& purported_hostname, + const std::vector<Botan::X509_Certificate>& certs) +{ + log_debug("Checking remote certificate (", type, ") for hostname ", purported_hostname); + try + { + Botan::Credentials_Manager::verify_certificate_chain(type, purported_hostname, certs); + log_debug("Certificate is valid"); + } + catch (const std::exception& tls_exception) + { + log_warning("TLS certificate check failed: ", tls_exception.what()); + if (!this->trusted_fingerprint.empty() && !certs.empty() && + this->trusted_fingerprint == certs[0].fingerprint() && + certs[0].matches_dns_name(purported_hostname)) + // We trust the certificate, based on the trusted fingerprint and + // the fact that the hostname matches + return; + + if (this->socket_handler->abort_on_invalid_cert()) + throw; + } +} + +void BasicCredentialsManager::load_certs() +{ + // Only load the certificates the first time + if (BasicCredentialsManager::certs_loaded) + return; + const std::string conf_path = Config::get("ca_file", ""); + std::vector<std::string> paths; + if (conf_path.empty()) + paths = default_cert_files; + else + paths.push_back(conf_path); + for (const auto& path: paths) + { + try + { + Botan::DataSource_Stream bundle(path); + log_debug("Using ca bundle: ", path); + while (!bundle.end_of_data() && bundle.check_available(27)) + { + // TODO: remove this work-around for Botan 1.11.29 + // https://github.com/randombit/botan/issues/438#issuecomment-192866796 + // Note that every certificate that fails to be transcoded into latin-1 + // will be ignored. As a result, some TLS connection may be refused + // because the certificate is signed by an issuer that was ignored. + try { + const Botan::X509_Certificate cert(bundle); + BasicCredentialsManager::certificate_store.add_certificate(cert); + } catch (const Botan::Decoding_Error& error) + { + continue; + } + } + // Only use the first file that can successfully be read. + goto success; + } + catch (Botan::Stream_IO_Error& e) + { + log_debug(e.what()); + } + } + // If we could not open one of the files, print a warning + log_warning("The CA could not be loaded, TLS negociation will probably fail."); + success: + BasicCredentialsManager::certs_loaded = true; +} + +std::vector<Botan::Certificate_Store*> BasicCredentialsManager::trusted_certificate_authorities(const std::string&, const std::string&) +{ + return {&this->certificate_store}; +} + +#endif diff --git a/louloulibs/network/credentials_manager.hpp b/louloulibs/network/credentials_manager.hpp new file mode 100644 index 0000000..0fc4b89 --- /dev/null +++ b/louloulibs/network/credentials_manager.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "louloulibs.h" + +#ifdef BOTAN_FOUND + +#include <botan/botan.h> +#include <botan/tls_client.h> + +class TCPSocketHandler; + +class BasicCredentialsManager: public Botan::Credentials_Manager +{ +public: + BasicCredentialsManager(const TCPSocketHandler* const socket_handler); + + BasicCredentialsManager(BasicCredentialsManager&&) = delete; + BasicCredentialsManager(const BasicCredentialsManager&) = delete; + BasicCredentialsManager& operator=(const BasicCredentialsManager&) = delete; + BasicCredentialsManager& operator=(BasicCredentialsManager&&) = delete; + + void verify_certificate_chain(const std::string& type, + const std::string& purported_hostname, + const std::vector<Botan::X509_Certificate>&) override final; + std::vector<Botan::Certificate_Store*> trusted_certificate_authorities(const std::string& type, + const std::string& context) override final; + void set_trusted_fingerprint(const std::string& fingerprint); + +private: + const TCPSocketHandler* const socket_handler; + + static void load_certs(); + static Botan::Certificate_Store_In_Memory certificate_store; + static bool certs_loaded; + std::string trusted_fingerprint; +}; + +#endif //BOTAN_FOUND + diff --git a/louloulibs/network/dns_handler.cpp b/louloulibs/network/dns_handler.cpp new file mode 100644 index 0000000..e267944 --- /dev/null +++ b/louloulibs/network/dns_handler.cpp @@ -0,0 +1,134 @@ +#include <louloulibs.h> +#ifdef CARES_FOUND + +#include <network/dns_socket_handler.hpp> +#include <network/dns_handler.hpp> +#include <network/poller.hpp> + +#include <utils/timed_events.hpp> + +#include <algorithm> +#include <stdexcept> + +DNSHandler DNSHandler::instance; + +using namespace std::string_literals; +DNSHandler::DNSHandler(): + socket_handlers{}, + channel{nullptr} +{ + int ares_error; + if ((ares_error = ::ares_library_init(ARES_LIB_INIT_ALL)) != 0) + throw std::runtime_error("Failed to initialize c-ares lib: "s + ares_strerror(ares_error)); + struct ares_options options = {}; + // The default timeout values are way too high + options.timeout = 1000; + options.tries = 3; + if ((ares_error = ::ares_init_options(&this->channel, + &options, + ARES_OPT_TIMEOUTMS|ARES_OPT_TRIES)) != ARES_SUCCESS) + throw std::runtime_error("Failed to initialize c-ares channel: "s + ares_strerror(ares_error)); +} + +ares_channel& DNSHandler::get_channel() +{ + return this->channel; +} + +void DNSHandler::destroy() +{ + this->remove_all_sockets_from_poller(); + this->socket_handlers.clear(); + ::ares_destroy(this->channel); + ::ares_library_cleanup(); +} + +void DNSHandler::gethostbyname(const std::string& name, ares_host_callback callback, + void* data, int family) +{ + if (family == AF_INET) + ::ares_gethostbyname(this->channel, name.data(), family, + callback, data); + else + ::ares_gethostbyname(this->channel, name.data(), family, + callback, data); +} + +void DNSHandler::watch_dns_sockets(std::shared_ptr<Poller>& poller) +{ + fd_set readers; + fd_set writers; + + FD_ZERO(&readers); + FD_ZERO(&writers); + + int ndfs = ::ares_fds(this->channel, &readers, &writers); + // For each existing DNS socket, see if we are still supposed to watch it, + // if not then erase it + this->socket_handlers.erase( + std::remove_if(this->socket_handlers.begin(), this->socket_handlers.end(), + [&readers](const auto& dns_socket) + { + return !FD_ISSET(dns_socket->get_socket(), &readers); + }), + this->socket_handlers.end()); + + for (auto i = 0; i < ndfs; ++i) + { + bool read = FD_ISSET(i, &readers); + bool write = FD_ISSET(i, &writers); + // Look for the DNSSocketHandler with this fd + auto it = std::find_if(this->socket_handlers.begin(), + this->socket_handlers.end(), + [i](const auto& socket_handler) + { + return i == socket_handler->get_socket(); + }); + if (!read && !write) // No need to read or write to it + { // If found, erase it and stop watching it because it is not + // needed anymore + if (it != this->socket_handlers.end()) + // The socket destructor removes it from the poller + this->socket_handlers.erase(it); + } + else // We need to write and/or read to it + { // If not found, create it because we need to watch it + if (it == this->socket_handlers.end()) + { + this->socket_handlers.emplace(this->socket_handlers.begin(), + std::make_unique<DNSSocketHandler>(poller, *this, i)); + it = this->socket_handlers.begin(); + } + poller->add_socket_handler(it->get()); + if (write) + poller->watch_send_events(it->get()); + } + } + // Cancel previous timer, if any. + TimedEventsManager::instance().cancel("DNS timeout"); + struct timeval tv; + struct timeval* tvp; + tvp = ::ares_timeout(this->channel, NULL, &tv); + if (tvp) + { + auto future_time = std::chrono::steady_clock::now() + std::chrono::seconds(tvp->tv_sec) + \ + std::chrono::microseconds(tvp->tv_usec); + TimedEventsManager::instance().add_event(TimedEvent(std::move(future_time), + [this]() + { + for (auto& dns_socket_handler: this->socket_handlers) + dns_socket_handler->on_recv(); + }, + "DNS timeout")); + } +} + +void DNSHandler::remove_all_sockets_from_poller() +{ + for (const auto& socket_handler: this->socket_handlers) + { + socket_handler->remove_from_poller(); + } +} + +#endif /* CARES_FOUND */ diff --git a/louloulibs/network/dns_handler.hpp b/louloulibs/network/dns_handler.hpp new file mode 100644 index 0000000..fd1729d --- /dev/null +++ b/louloulibs/network/dns_handler.hpp @@ -0,0 +1,58 @@ +#pragma once + +#include <louloulibs.h> +#ifdef CARES_FOUND + +class TCPSocketHandler; +class Poller; +class DNSSocketHandler; + +# include <ares.h> +# include <memory> +# include <string> +# include <vector> + +/** + * Class managing DNS resolution. It should only be statically instanciated + * once in SocketHandler. It manages ares channel and calls various + * functions of that library. + */ + +class DNSHandler +{ +public: + DNSHandler(); + ~DNSHandler() = default; + DNSHandler(const DNSHandler&) = delete; + DNSHandler(DNSHandler&&) = delete; + DNSHandler& operator=(const DNSHandler&) = delete; + DNSHandler& operator=(DNSHandler&&) = delete; + + void gethostbyname(const std::string& name, ares_host_callback callback, + void* socket_handler, int family); + /** + * Call ares_fds to know what fd needs to be watched by the poller, create + * or destroy DNSSocketHandlers depending on the result. + */ + void watch_dns_sockets(std::shared_ptr<Poller>& poller); + /** + * Destroy and stop watching all the DNS sockets. Then de-init the channel + * and library. + */ + void destroy(); + void remove_all_sockets_from_poller(); + ares_channel& get_channel(); + + static DNSHandler instance; + +private: + /** + * The list of sockets that needs to be watched, according to the last + * call to ares_fds. DNSSocketHandlers are added to it or removed from it + * in the watch_dns_sockets() method + */ + std::vector<std::unique_ptr<DNSSocketHandler>> socket_handlers; + ares_channel channel; +}; + +#endif /* CARES_FOUND */ diff --git a/louloulibs/network/dns_socket_handler.cpp b/louloulibs/network/dns_socket_handler.cpp new file mode 100644 index 0000000..5fd08cb --- /dev/null +++ b/louloulibs/network/dns_socket_handler.cpp @@ -0,0 +1,48 @@ +#include <louloulibs.h> +#ifdef CARES_FOUND + +#include <network/dns_socket_handler.hpp> +#include <network/dns_handler.hpp> +#include <network/poller.hpp> + +#include <ares.h> + +DNSSocketHandler::DNSSocketHandler(std::shared_ptr<Poller> poller, + DNSHandler& handler, + const socket_t socket): + SocketHandler(poller, socket), + handler(handler) +{ +} + +void DNSSocketHandler::connect() +{ +} + +void DNSSocketHandler::on_recv() +{ + // always stop watching send and read events. We will re-watch them if the + // next call to ares_fds tell us to + this->handler.remove_all_sockets_from_poller(); + ::ares_process_fd(DNSHandler::instance.get_channel(), this->socket, ARES_SOCKET_BAD); +} + +void DNSSocketHandler::on_send() +{ + // always stop watching send and read events. We will re-watch them if the + // next call to ares_fds tell us to + this->handler.remove_all_sockets_from_poller(); + ::ares_process_fd(DNSHandler::instance.get_channel(), ARES_SOCKET_BAD, this->socket); +} + +bool DNSSocketHandler::is_connected() const +{ + return true; +} + +void DNSSocketHandler::remove_from_poller() +{ + this->poller->remove_socket_handler(this->socket); +} + +#endif /* CARES_FOUND */ diff --git a/louloulibs/network/dns_socket_handler.hpp b/louloulibs/network/dns_socket_handler.hpp new file mode 100644 index 0000000..0570196 --- /dev/null +++ b/louloulibs/network/dns_socket_handler.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include <louloulibs.h> +#ifdef CARES_FOUND + +#include <network/socket_handler.hpp> +#include <ares.h> + +/** + * Manage a socket returned by ares_fds. We do not create, open or close the + * socket ourself: this is done by c-ares. We just call ares_process_fd() + * with the correct parameters, depending on what can be done on that socket + * (Poller reported it to be writable or readeable) + */ + +class DNSHandler; + +class DNSSocketHandler: public SocketHandler +{ +public: + explicit DNSSocketHandler(std::shared_ptr<Poller> poller, DNSHandler& handler, const socket_t socket); + ~DNSSocketHandler() = default; + DNSSocketHandler(const DNSSocketHandler&) = delete; + DNSSocketHandler(DNSSocketHandler&&) = delete; + DNSSocketHandler& operator=(const DNSSocketHandler&) = delete; + DNSSocketHandler& operator=(DNSSocketHandler&&) = delete; + + /** + * Just call dns_process_fd, c-ares will do its work of send()ing or + * recv()ing the data it wants on that socket. + */ + void on_recv() override final; + void on_send() override final; + /** + * Do nothing, because we are always considered to be connected, since the + * connection is done by c-ares and not by us. + */ + void connect() override final; + /** + * Always true, see the comment for connect() + */ + bool is_connected() const override final; + void remove_from_poller(); + +private: + DNSHandler& handler; +}; + +#endif // CARES_FOUND diff --git a/louloulibs/network/poller.cpp b/louloulibs/network/poller.cpp new file mode 100644 index 0000000..8a6fd97 --- /dev/null +++ b/louloulibs/network/poller.cpp @@ -0,0 +1,228 @@ +#include <network/poller.hpp> +#include <logger/logger.hpp> +#include <utils/timed_events.hpp> + +#include <assert.h> +#include <errno.h> +#include <stdio.h> +#include <signal.h> +#include <unistd.h> + +#include <cstring> +#include <iostream> +#include <stdexcept> + +Poller::Poller() +{ +#if POLLER == POLL + this->nfds = 0; +#elif POLLER == EPOLL + this->epfd = ::epoll_create1(0); + if (this->epfd == -1) + { + log_error("epoll failed: ", strerror(errno)); + throw std::runtime_error("Could not create epoll instance"); + } +#endif +} + +Poller::~Poller() +{ +#if POLLER == EPOLL + if (this->epfd > 0) + ::close(this->epfd); +#endif +} + +void Poller::add_socket_handler(SocketHandler* socket_handler) +{ + // Don't do anything if the socket is already managed + const auto it = this->socket_handlers.find(socket_handler->get_socket()); + if (it != this->socket_handlers.end()) + return ; + + this->socket_handlers.emplace(socket_handler->get_socket(), socket_handler); + + // We always watch all sockets for receive events +#if POLLER == POLL + this->fds[this->nfds].fd = socket_handler->get_socket(); + this->fds[this->nfds].events = POLLIN; + this->nfds++; +#endif +#if POLLER == EPOLL + struct epoll_event event = {EPOLLIN, {socket_handler}}; + const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_ADD, socket_handler->get_socket(), &event); + if (res == -1) + { + log_error("epoll_ctl failed: ", strerror(errno)); + throw std::runtime_error("Could not add socket to epoll"); + } +#endif +} + +void Poller::remove_socket_handler(const socket_t socket) +{ + const auto it = this->socket_handlers.find(socket); + if (it == this->socket_handlers.end()) + throw std::runtime_error("Trying to remove a SocketHandler that is not managed"); + this->socket_handlers.erase(it); + +#if POLLER == POLL + for (size_t i = 0; i < this->nfds; i++) + { + if (this->fds[i].fd == socket) + { + // Move all subsequent pollfd by one on the left, erasing the + // value of the one we remove + for (size_t j = i; j < this->nfds - 1; ++j) + { + this->fds[j].fd = this->fds[j+1].fd; + this->fds[j].events= this->fds[j+1].events; + } + this->nfds--; + } + } +#elif POLLER == EPOLL + const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_DEL, socket, nullptr); + if (res == -1) + { + log_error("epoll_ctl failed: ", strerror(errno)); + throw std::runtime_error("Could not remove socket from epoll"); + } +#endif +} + +void Poller::watch_send_events(SocketHandler* socket_handler) +{ +#if POLLER == POLL + for (size_t i = 0; i <= this->nfds; ++i) + { + if (this->fds[i].fd == socket_handler->get_socket()) + { + this->fds[i].events = POLLIN|POLLOUT; + return; + } + } + throw std::runtime_error("Cannot watch a non-registered socket for send events"); +#elif POLLER == EPOLL + struct epoll_event event = {EPOLLIN|EPOLLOUT, {socket_handler}}; + const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_MOD, socket_handler->get_socket(), &event); + if (res == -1) + { + log_error("epoll_ctl failed: ", strerror(errno)); + throw std::runtime_error("Could not modify socket flags in epoll"); + } +#endif +} + +void Poller::stop_watching_send_events(SocketHandler* socket_handler) +{ +#if POLLER == POLL + for (size_t i = 0; i <= this->nfds; ++i) + { + if (this->fds[i].fd == socket_handler->get_socket()) + { + this->fds[i].events = POLLIN; + return; + } + } + throw std::runtime_error("Cannot watch a non-registered socket for send events"); +#elif POLLER == EPOLL + struct epoll_event event = {EPOLLIN, {socket_handler}}; + const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_MOD, socket_handler->get_socket(), &event); + if (res == -1) + { + log_error("epoll_ctl failed: ", strerror(errno)); + throw std::runtime_error("Could not modify socket flags in epoll"); + } +#endif +} + +int Poller::poll(const std::chrono::milliseconds& timeout) +{ + if (this->socket_handlers.empty() && timeout == utils::no_timeout) + return -1; +#if POLLER == POLL + // Convert our nice timeout into this ugly struct + struct timespec timeout_ts; + struct timespec* timeout_tsp; + if (timeout > 0s) + { + auto seconds = std::chrono::duration_cast<std::chrono::seconds>(timeout); + timeout_ts.tv_sec = seconds.count(); + timeout_ts.tv_nsec = std::chrono::duration_cast<std::chrono::nanoseconds>(timeout - seconds).count(); + timeout_tsp = &timeout_ts; + } + else + timeout_tsp = nullptr; + + // Unblock all signals, only during the ppoll call + sigset_t empty_signal_set; + sigemptyset(&empty_signal_set); + int nb_events = ::ppoll(this->fds, this->nfds, timeout_tsp, + &empty_signal_set); + if (nb_events < 0) + { + if (errno == EINTR) + return true; + log_error("poll failed: ", strerror(errno)); + throw std::runtime_error("Poll failed"); + } + // We cannot possibly have more ready events than the number of fds we are + // watching + assert(static_cast<unsigned int>(nb_events) <= this->nfds); + for (size_t i = 0; i <= this->nfds && nb_events != 0; ++i) + { + auto socket_handler = this->socket_handlers.at(this->fds[i].fd); + if (this->fds[i].revents == 0) + continue; + else if (this->fds[i].revents & POLLIN && socket_handler->is_connected()) + { + socket_handler->on_recv(); + nb_events--; + } + else if (this->fds[i].revents & POLLOUT && socket_handler->is_connected()) + { + socket_handler->on_send(); + nb_events--; + } + else if (this->fds[i].revents & POLLOUT) + { + socket_handler->connect(); + nb_events--; + } + } + return 1; +#elif POLLER == EPOLL + static const size_t max_events = 12; + struct epoll_event revents[max_events]; + // Unblock all signals, only during the epoll_pwait call + sigset_t empty_signal_set; + sigemptyset(&empty_signal_set); + const int nb_events = ::epoll_pwait(this->epfd, revents, max_events, timeout.count(), + &empty_signal_set); + if (nb_events == -1) + { + if (errno == EINTR) + return 0; + log_error("epoll wait: ", strerror(errno)); + throw std::runtime_error("Epoll_wait failed"); + } + for (int i = 0; i < nb_events; ++i) + { + auto socket_handler = static_cast<SocketHandler*>(revents[i].data.ptr); + if (revents[i].events & EPOLLIN && socket_handler->is_connected()) + socket_handler->on_recv(); + else if (revents[i].events & EPOLLOUT && socket_handler->is_connected()) + socket_handler->on_send(); + else if (revents[i].events & EPOLLOUT) + socket_handler->connect(); + } + return nb_events; +#endif +} + +size_t Poller::size() const +{ + return this->socket_handlers.size(); +} diff --git a/louloulibs/network/poller.hpp b/louloulibs/network/poller.hpp new file mode 100644 index 0000000..fc1a1a1 --- /dev/null +++ b/louloulibs/network/poller.hpp @@ -0,0 +1,94 @@ +#pragma once + + +#include <network/socket_handler.hpp> + +#include <unordered_map> +#include <memory> +#include <chrono> + +#define POLL 1 +#define EPOLL 2 +#define KQUEUE 3 +#include <louloulibs.h> +#ifndef POLLER + #define POLLER POLL +#endif + +#if POLLER == POLL + #include <poll.h> + #define MAX_POLL_FD_NUMBER 4096 +#elif POLLER == EPOLL + #include <sys/epoll.h> +#else + #error Invalid POLLER value +#endif + +/** + * We pass some SocketHandlers to this Poller, which uses + * poll/epoll/kqueue/select etc to wait for events on these SocketHandlers, + * and call the callbacks when event occurs. + * + * TODO: support these pollers: + * - kqueue(2) + */ + +class Poller +{ +public: + explicit Poller(); + ~Poller(); + Poller(const Poller&) = delete; + Poller(Poller&&) = delete; + Poller& operator=(const Poller&) = delete; + Poller& operator=(Poller&&) = delete; + /** + * Add a SocketHandler to be monitored by this Poller. All receive events + * are always automatically watched. + */ + void add_socket_handler(SocketHandler* socket_handler); + /** + * Remove (and stop managing) a SocketHandler, designated by the given socket_t. + */ + void remove_socket_handler(const socket_t socket); + /** + * Signal the poller that he needs to watch for send events for the given + * SocketHandler. + */ + void watch_send_events(SocketHandler* socket_handler); + /** + * Signal the poller that he needs to stop watching for send events for + * this SocketHandler. + */ + void stop_watching_send_events(SocketHandler* socket_handler); + /** + * Wait for all watched events, and call the SocketHandlers' callbacks + * when one is ready. Returns if nothing happened before the provided + * timeout. If the timeout is 0, it waits forever. If there is no + * watched event, returns -1 immediately, ignoring the timeout value. + * Otherwise, returns the number of event handled. If 0 is returned this + * means that we were interrupted by a signal, or the timeout occured. + */ + int poll(const std::chrono::milliseconds& timeout); + /** + * Returns the number of SocketHandlers managed by the poller. + */ + size_t size() const; + +private: + /** + * A "list" of all the SocketHandlers that we manage, indexed by socket, + * because that's what is returned by select/poll/etc when an event + * occures. + */ + std::unordered_map<socket_t, SocketHandler*> socket_handlers; + +#if POLLER == POLL + struct pollfd fds[MAX_POLL_FD_NUMBER]; + nfds_t nfds; +#elif POLLER == EPOLL + int epfd; +#endif +}; + + diff --git a/louloulibs/network/resolver.cpp b/louloulibs/network/resolver.cpp new file mode 100644 index 0000000..9d6de23 --- /dev/null +++ b/louloulibs/network/resolver.cpp @@ -0,0 +1,214 @@ +#include <network/dns_handler.hpp> +#include <network/resolver.hpp> +#include <string.h> +#include <arpa/inet.h> + +using namespace std::string_literals; + +Resolver::Resolver(): +#ifdef CARES_FOUND + resolved4(false), + resolved6(false), + resolving(false), + cares_addrinfo(nullptr), + port{}, +#endif + resolved(false), + error_msg{} +{ +} + +void Resolver::resolve(const std::string& hostname, const std::string& port, + SuccessCallbackType success_cb, ErrorCallbackType error_cb) +{ + this->error_cb = error_cb; + this->success_cb = success_cb; +#ifdef CARES_FOUND + this->port = port; +#endif + + this->start_resolving(hostname, port); +} + +#ifdef CARES_FOUND +void Resolver::start_resolving(const std::string& hostname, const std::string&) +{ + this->resolving = true; + this->resolved = false; + this->resolved4 = false; + this->resolved6 = false; + + this->error_msg.clear(); + this->cares_addrinfo = nullptr; + + auto hostname4_resolved = [](void* arg, int status, int, + struct hostent* hostent) + { + Resolver* resolver = static_cast<Resolver*>(arg); + resolver->on_hostname4_resolved(status, hostent); + }; + auto hostname6_resolved = [](void* arg, int status, int, + struct hostent* hostent) + { + Resolver* resolver = static_cast<Resolver*>(arg); + resolver->on_hostname6_resolved(status, hostent); + }; + + DNSHandler::instance.gethostbyname(hostname, hostname6_resolved, + this, AF_INET6); + DNSHandler::instance.gethostbyname(hostname, hostname4_resolved, + this, AF_INET); +} + +void Resolver::on_hostname4_resolved(int status, struct hostent* hostent) +{ + this->resolved4 = true; + if (status == ARES_SUCCESS) + this->fill_ares_addrinfo4(hostent); + else + this->error_msg = ::ares_strerror(status); + + if (this->resolved4 && this->resolved6) + this->on_resolved(); +} + +void Resolver::on_hostname6_resolved(int status, struct hostent* hostent) +{ + this->resolved6 = true; + if (status == ARES_SUCCESS) + this->fill_ares_addrinfo6(hostent); + else + this->error_msg = ::ares_strerror(status); + + if (this->resolved4 && this->resolved6) + this->on_resolved(); +} + +void Resolver::on_resolved() +{ + this->resolved = true; + this->resolving = false; + if (!this->cares_addrinfo) + { + if (this->error_cb) + this->error_cb(this->error_msg.data()); + } + else + { + this->addr.reset(this->cares_addrinfo); + if (this->success_cb) + this->success_cb(this->addr.get()); + } +} + +void Resolver::fill_ares_addrinfo4(const struct hostent* hostent) +{ + struct addrinfo* prev = this->cares_addrinfo; + struct in_addr** address = reinterpret_cast<struct in_addr**>(hostent->h_addr_list); + + while (*address) + { + // Create a new addrinfo list element, and fill it + struct addrinfo* current = new struct addrinfo; + current->ai_flags = 0; + current->ai_family = hostent->h_addrtype; + current->ai_socktype = SOCK_STREAM; + current->ai_protocol = 0; + current->ai_addrlen = sizeof(struct sockaddr_in); + + struct sockaddr_in* addr = new struct sockaddr_in; + addr->sin_family = hostent->h_addrtype; + addr->sin_port = htons(strtoul(this->port.data(), nullptr, 10)); + addr->sin_addr.s_addr = (*address)->s_addr; + + current->ai_addr = reinterpret_cast<struct sockaddr*>(addr); + current->ai_next = nullptr; + current->ai_canonname = nullptr; + + current->ai_next = prev; + this->cares_addrinfo = current; + prev = current; + ++address; + } +} + +void Resolver::fill_ares_addrinfo6(const struct hostent* hostent) +{ + struct addrinfo* prev = this->cares_addrinfo; + struct in6_addr** address = reinterpret_cast<struct in6_addr**>(hostent->h_addr_list); + + while (*address) + { + // Create a new addrinfo list element, and fill it + struct addrinfo* current = new struct addrinfo; + current->ai_flags = 0; + current->ai_family = hostent->h_addrtype; + current->ai_socktype = SOCK_STREAM; + current->ai_protocol = 0; + current->ai_addrlen = sizeof(struct sockaddr_in6); + + struct sockaddr_in6* addr = new struct sockaddr_in6; + addr->sin6_family = hostent->h_addrtype; + addr->sin6_port = htons(strtoul(this->port.data(), nullptr, 10)); + ::memcpy(addr->sin6_addr.s6_addr, (*address)->s6_addr, 16); + addr->sin6_flowinfo = 0; + addr->sin6_scope_id = 0; + + current->ai_addr = reinterpret_cast<struct sockaddr*>(addr); + current->ai_canonname = nullptr; + + current->ai_next = prev; + this->cares_addrinfo = current; + prev = current; + ++address; + } +} + +#else // ifdef CARES_FOUND + +void Resolver::start_resolving(const std::string& hostname, const std::string& port) +{ + // If the resolution fails, the addr will be unset + this->addr.reset(nullptr); + + struct addrinfo hints; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_flags = 0; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + struct addrinfo* addr_res = nullptr; + const int res = ::getaddrinfo(hostname.data(), port.data(), + &hints, &addr_res); + + this->resolved = true; + + if (res != 0) + { + this->error_msg = gai_strerror(res); + if (this->error_cb) + this->error_cb(this->error_msg.data()); + } + else + { + this->addr.reset(addr_res); + if (this->success_cb) + this->success_cb(this->addr.get()); + } +} +#endif // ifdef CARES_FOUND + +std::string addr_to_string(const struct addrinfo* rp) +{ + char buf[INET6_ADDRSTRLEN]; + if (rp->ai_family == AF_INET) + return ::inet_ntop(rp->ai_family, + &reinterpret_cast<sockaddr_in*>(rp->ai_addr)->sin_addr, + buf, sizeof(buf)); + else if (rp->ai_family == AF_INET6) + return ::inet_ntop(rp->ai_family, + &reinterpret_cast<sockaddr_in6*>(rp->ai_addr)->sin6_addr, + buf, sizeof(buf)); + return {}; +} diff --git a/louloulibs/network/resolver.hpp b/louloulibs/network/resolver.hpp new file mode 100644 index 0000000..afe6e2b --- /dev/null +++ b/louloulibs/network/resolver.hpp @@ -0,0 +1,128 @@ +#pragma once + + +#include "louloulibs.h" + +#include <functional> +#include <memory> +#include <string> + +#include <sys/types.h> +#include <sys/socket.h> +#include <netdb.h> + +struct AddrinfoDeleter +{ + void operator()(struct addrinfo* addr) + { +#ifdef CARES_FOUND + while (addr) + { + delete addr->ai_addr; + auto next = addr->ai_next; + delete addr; + addr = next; + } +#else + freeaddrinfo(addr); +#endif + } +}; + +class Resolver +{ +public: + using ErrorCallbackType = std::function<void(const char*)>; + using SuccessCallbackType = std::function<void(const struct addrinfo*)>; + + Resolver(); + ~Resolver() = default; + Resolver(const Resolver&) = delete; + Resolver(Resolver&&) = delete; + Resolver& operator=(const Resolver&) = delete; + Resolver& operator=(Resolver&&) = delete; + + bool is_resolving() const + { +#ifdef CARES_FOUND + return this->resolving; +#else + return false; +#endif + } + + bool is_resolved() const + { + return this->resolved; + } + + const auto& get_result() const + { + return this->addr; + } + std::string get_error_message() const + { + return this->error_msg; + } + + void clear() + { +#ifdef CARES_FOUND + this->resolved6 = false; + this->resolved4 = false; + this->resolving = false; + this->cares_addrinfo = nullptr; + this->port.clear(); +#endif + this->resolved = false; + this->addr.reset(); + this->error_msg.clear(); + } + + void resolve(const std::string& hostname, const std::string& port, + SuccessCallbackType success_cb, ErrorCallbackType error_cb); + +private: + void start_resolving(const std::string& hostname, const std::string& port); +#ifdef CARES_FOUND + void on_hostname4_resolved(int status, struct hostent* hostent); + void on_hostname6_resolved(int status, struct hostent* hostent); + + void fill_ares_addrinfo4(const struct hostent* hostent); + void fill_ares_addrinfo6(const struct hostent* hostent); + + void on_resolved(); + + bool resolved4; + bool resolved6; + + bool resolving; + + /** + * When using c-ares to resolve the host asynchronously, we need the + * c-ares callbacks to fill a structure (a struct addrinfo, for + * compatibility with getaddrinfo and the rest of the code that works when + * c-ares is not used) with all returned values (for example an IPv6 and + * an IPv4). The pointer is given to the unique_ptr to manage its lifetime. + */ + struct addrinfo* cares_addrinfo; + std::string port; + +#endif + /** + * Tells if we finished the resolution process. It doesn't indicate if it + * was successful (it is true even if the result is an error). + */ + bool resolved; + std::string error_msg; + + + std::unique_ptr<struct addrinfo, AddrinfoDeleter> addr; + + ErrorCallbackType error_cb; + SuccessCallbackType success_cb; +}; + +std::string addr_to_string(const struct addrinfo* rp); + + diff --git a/louloulibs/network/socket_handler.hpp b/louloulibs/network/socket_handler.hpp new file mode 100644 index 0000000..eeb41fe --- /dev/null +++ b/louloulibs/network/socket_handler.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include <louloulibs.h> +#include <memory> + +class Poller; + +using socket_t = int; + +class SocketHandler +{ +public: + explicit SocketHandler(std::shared_ptr<Poller> poller, const socket_t socket): + poller(poller), + socket(socket) + {} + virtual ~SocketHandler() {} + SocketHandler(const SocketHandler&) = delete; + SocketHandler(SocketHandler&&) = delete; + SocketHandler& operator=(const SocketHandler&) = delete; + SocketHandler& operator=(SocketHandler&&) = delete; + + virtual void on_recv() = 0; + virtual void on_send() = 0; + virtual void connect() = 0; + virtual bool is_connected() const = 0; + + socket_t get_socket() const + { return this->socket; } + +protected: + /** + * A pointer to the poller that manages us, because we need to communicate + * with it. + */ + std::shared_ptr<Poller> poller; + /** + * The handled socket. + */ + socket_t socket; +}; + diff --git a/louloulibs/network/tcp_socket_handler.cpp b/louloulibs/network/tcp_socket_handler.cpp new file mode 100644 index 0000000..5420b1c --- /dev/null +++ b/louloulibs/network/tcp_socket_handler.cpp @@ -0,0 +1,501 @@ +#include <network/tcp_socket_handler.hpp> +#include <network/dns_handler.hpp> + +#include <utils/timed_events.hpp> +#include <utils/scopeguard.hpp> +#include <network/poller.hpp> + +#include <logger/logger.hpp> +#include <sys/socket.h> +#include <sys/types.h> +#include <stdexcept> +#include <unistd.h> +#include <errno.h> +#include <cstring> +#include <fcntl.h> + +#ifdef BOTAN_FOUND +# include <botan/hex.h> +# include <botan/tls_exceptn.h> + +Botan::AutoSeeded_RNG TCPSocketHandler::rng; +Botan::TLS::Policy TCPSocketHandler::policy; +Botan::TLS::Session_Manager_In_Memory TCPSocketHandler::session_manager(TCPSocketHandler::rng); + +#endif + +#ifndef UIO_FASTIOV +# define UIO_FASTIOV 8 +#endif + +using namespace std::string_literals; +using namespace std::chrono_literals; + +namespace ph = std::placeholders; + +TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller): + SocketHandler(poller, -1), + use_tls(false), + connected(false), + connecting(false), + hostname_resolution_failed(false) +#ifdef BOTAN_FOUND + ,credential_manager(this) +#endif +{} + +TCPSocketHandler::~TCPSocketHandler() +{ + this->close(); +} + + +void TCPSocketHandler::init_socket(const struct addrinfo* rp) +{ + if (this->socket != -1) + ::close(this->socket); + if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1) + throw std::runtime_error("Could not create socket: "s + strerror(errno)); + // Bind the socket to a specific address, if specified + if (!this->bind_addr.empty()) + { + // Convert the address from string format to a sockaddr that can be + // used in bind() + struct addrinfo* result; + int err = ::getaddrinfo(this->bind_addr.data(), nullptr, nullptr, &result); + if (err != 0 || !result) + log_error("Failed to bind socket to ", this->bind_addr, ": ", + gai_strerror(err)); + else + { + utils::ScopeGuard sg([result](){ freeaddrinfo(result); }); + struct addrinfo* rp; + int bind_error = 0; + for (rp = result; rp; rp = rp->ai_next) + { + if ((bind_error = ::bind(this->socket, + reinterpret_cast<const struct sockaddr*>(rp->ai_addr), + rp->ai_addrlen)) == 0) + break; + } + if (!rp) + log_error("Failed to bind socket to ", this->bind_addr, ": ", + strerror(bind_error)); + else + log_info("Socket successfully bound to ", this->bind_addr); + } + } + int optval = 1; + if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1) + log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno)); + // Set the socket on non-blocking mode. This is useful to receive a EAGAIN + // error when connect() would block, to not block the whole process if a + // remote is not responsive. + const int existing_flags = ::fcntl(this->socket, F_GETFL, 0); + if ((existing_flags == -1) || + (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1)) + throw std::runtime_error("Could not initialize socket: "s + strerror(errno)); +} + +void TCPSocketHandler::connect(const std::string& address, const std::string& port, const bool tls) +{ + this->address = address; + this->port = port; + this->use_tls = tls; + + utils::ScopeGuard sg; + + struct addrinfo* addr_res; + + if (!this->connecting) + { + // Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if + // this is the first call of this function. + if (!this->resolver.is_resolved()) + { + log_info("Trying to connect to ", address, ":", port); + // Start the asynchronous process of resolving the hostname. Once + // the addresses have been found and `resolved` has been set to true + // (but connecting will still be false), TCPSocketHandler::connect() + // needs to be called, again. + this->resolver.resolve(address, port, + [this](const struct addrinfo*) + { + log_debug("Resolution success, calling connect() again"); + this->connect(); + }, + [this](const char*) + { + log_debug("Resolution failed, calling connect() again"); + this->connect(); + }); + return; + } + else + { + // The c-ares resolved the hostname and the available addresses + // where saved in the cares_addrinfo linked list. Now, just use + // this list to try to connect. + addr_res = this->resolver.get_result().get(); + if (!addr_res) + { + this->hostname_resolution_failed = true; + const auto msg = this->resolver.get_error_message(); + this->close(); + this->on_connection_failed(msg); + return ; + } + } + } + else + { // This function is called again, use the saved addrinfo structure, + // instead of re-doing the whole getaddrinfo process. + addr_res = &this->addrinfo; + } + + for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next) + { + if (!this->connecting) + { + try { + this->init_socket(rp); + } + catch (const std::runtime_error& error) { + log_error("Failed to init socket: ", error.what()); + break; + } + } + + this->display_resolved_ip(rp); + + if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0 + || errno == EISCONN) + { + log_info("Connection success."); + TimedEventsManager::instance().cancel("connection_timeout"s + + std::to_string(this->socket)); + this->poller->add_socket_handler(this); + this->connected = true; + this->connecting = false; +#ifdef BOTAN_FOUND + if (this->use_tls) + this->start_tls(); +#endif + this->on_connected(); + return ; + } + else if (errno == EINPROGRESS || errno == EALREADY) + { // retry this process later, when the socket + // is ready to be written on. + this->connecting = true; + this->poller->add_socket_handler(this); + this->poller->watch_send_events(this); + // Save the addrinfo structure, to use it on the next call + this->ai_addrlen = rp->ai_addrlen; + memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen); + memcpy(&this->addrinfo, rp, sizeof(struct addrinfo)); + this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr); + this->addrinfo.ai_next = nullptr; + // If the connection has not succeeded or failed in 5s, we consider + // it to have failed + TimedEventsManager::instance().add_event( + TimedEvent(std::chrono::steady_clock::now() + 5s, + std::bind(&TCPSocketHandler::on_connection_timeout, this), + "connection_timeout"s + std::to_string(this->socket))); + return ; + } + log_info("Connection failed:", strerror(errno)); + } + log_error("All connection attempts failed."); + this->close(); + this->on_connection_failed(strerror(errno)); + return ; +} + +void TCPSocketHandler::on_connection_timeout() +{ + this->close(); + this->on_connection_failed("connection timed out"); +} + +void TCPSocketHandler::connect() +{ + this->connect(this->address, this->port, this->use_tls); +} + +void TCPSocketHandler::on_recv() +{ +#ifdef BOTAN_FOUND + if (this->use_tls) + this->tls_recv(); + else +#endif + this->plain_recv(); +} + +void TCPSocketHandler::plain_recv() +{ + static constexpr size_t buf_size = 4096; + char buf[buf_size]; + void* recv_buf = this->get_receive_buffer(buf_size); + + if (recv_buf == nullptr) + recv_buf = buf; + + const ssize_t size = this->do_recv(recv_buf, buf_size); + + if (size > 0) + { + if (buf == recv_buf) + { + // data needs to be placed in the in_buf string, because no buffer + // was provided to receive that data directly. The in_buf buffer + // will be handled in parse_in_buffer() + this->in_buf += std::string(buf, size); + } + this->parse_in_buffer(size); + } +} + +ssize_t TCPSocketHandler::do_recv(void* recv_buf, const size_t buf_size) +{ + ssize_t size = ::recv(this->socket, recv_buf, buf_size, 0); + if (0 == size) + { + this->on_connection_close(""); + this->close(); + } + else if (-1 == size) + { + if (this->connecting) + log_warning("Error connecting: ", strerror(errno)); + else + log_warning("Error while reading from socket: ", strerror(errno)); + // Remember if we were connecting, or already connected when this + // happened, because close() sets this->connecting to false + const auto were_connecting = this->connecting; + this->close(); + if (were_connecting) + this->on_connection_failed(strerror(errno)); + else + this->on_connection_close(strerror(errno)); + } + return size; +} + +void TCPSocketHandler::on_send() +{ + struct iovec msg_iov[UIO_FASTIOV] = {}; + struct msghdr msg{nullptr, 0, + msg_iov, + 0, nullptr, 0, 0}; + for (const std::string& s: this->out_buf) + { + // unconsting the content of s is ok, sendmsg will never modify it + msg_iov[msg.msg_iovlen].iov_base = const_cast<char*>(s.data()); + msg_iov[msg.msg_iovlen].iov_len = s.size(); + if (++msg.msg_iovlen == UIO_FASTIOV) + break; + } + ssize_t res = ::sendmsg(this->socket, &msg, MSG_NOSIGNAL); + if (res < 0) + { + log_error("sendmsg failed: ", strerror(errno)); + this->on_connection_close(strerror(errno)); + this->close(); + } + else + { + // remove all the strings that were successfully sent. + for (auto it = this->out_buf.begin(); + it != this->out_buf.end();) + { + if (static_cast<size_t>(res) >= (*it).size()) + { + res -= (*it).size(); + it = this->out_buf.erase(it); + } + else + { + // If one string has partially been sent, we use substr to + // crop it + if (res > 0) + (*it) = (*it).substr(res, std::string::npos); + break; + } + } + if (this->out_buf.empty()) + this->poller->stop_watching_send_events(this); + } +} + +void TCPSocketHandler::close() +{ + TimedEventsManager::instance().cancel("connection_timeout"s + + std::to_string(this->socket)); + if (this->connected || this->connecting) + this->poller->remove_socket_handler(this->get_socket()); + if (this->socket != -1) + { + ::close(this->socket); + this->socket = -1; + } + this->connected = false; + this->connecting = false; + this->in_buf.clear(); + this->out_buf.clear(); + this->port.clear(); + this->resolver.clear(); +} + +void TCPSocketHandler::display_resolved_ip(struct addrinfo* rp) const +{ + if (rp->ai_family == AF_INET) + log_debug("Trying IPv4 address ", addr_to_string(rp)); + else if (rp->ai_family == AF_INET6) + log_debug("Trying IPv6 address ", addr_to_string(rp)); +} + +void TCPSocketHandler::send_data(std::string&& data) +{ +#ifdef BOTAN_FOUND + if (this->use_tls) + try { + this->tls_send(std::move(data)); + } catch (const Botan::TLS::TLS_Exception& e) { + this->on_connection_close("TLS error: "s + e.what()); + this->close(); + return ; + } + else +#endif + this->raw_send(std::move(data)); +} + +void TCPSocketHandler::raw_send(std::string&& data) +{ + if (data.empty()) + return ; + this->out_buf.emplace_back(std::move(data)); + if (this->connected) + this->poller->watch_send_events(this); +} + +void TCPSocketHandler::send_pending_data() +{ + if (this->connected && !this->out_buf.empty()) + this->poller->watch_send_events(this); +} + +bool TCPSocketHandler::is_connected() const +{ + return this->connected; +} + +bool TCPSocketHandler::is_connecting() const +{ + return this->connecting || this->resolver.is_resolving(); +} + +void* TCPSocketHandler::get_receive_buffer(const size_t) const +{ + return nullptr; +} + +#ifdef BOTAN_FOUND +void TCPSocketHandler::start_tls() +{ + Botan::TLS::Server_Information server_info(this->address, "irc", std::stoul(this->port)); + this->tls = std::make_unique<Botan::TLS::Client>( + std::bind(&TCPSocketHandler::tls_output_fn, this, ph::_1, ph::_2), + std::bind(&TCPSocketHandler::tls_data_cb, this, ph::_1, ph::_2), + std::bind(&TCPSocketHandler::tls_alert_cb, this, ph::_1, ph::_2, ph::_3), + std::bind(&TCPSocketHandler::tls_handshake_cb, this, ph::_1), + session_manager, this->credential_manager, policy, + rng, server_info, Botan::TLS::Protocol_Version::latest_tls_version()); +} + +void TCPSocketHandler::tls_recv() +{ + static constexpr size_t buf_size = 4096; + char recv_buf[buf_size]; + + const ssize_t size = this->do_recv(recv_buf, buf_size); + if (size > 0) + { + const bool was_active = this->tls->is_active(); + try { + this->tls->received_data(reinterpret_cast<const Botan::byte*>(recv_buf), + static_cast<size_t>(size)); + } catch (const Botan::TLS::TLS_Exception& e) { + // May happen if the server sends malformed TLS data (buggy server, + // or more probably we are just connected to a server that sends + // plain-text) + this->on_connection_close("TLS error: "s + e.what()); + this->close(); + return ; + } + if (!was_active && this->tls->is_active()) + this->on_tls_activated(); + } +} + +void TCPSocketHandler::tls_send(std::string&& data) +{ + // We may not be connected yet, or the tls session has + // not yet been negociated + if (this->tls && this->tls->is_active()) + { + const bool was_active = this->tls->is_active(); + if (!this->pre_buf.empty()) + { + this->tls->send(reinterpret_cast<const Botan::byte*>(this->pre_buf.data()), + this->pre_buf.size()); + this->pre_buf = ""; + } + if (!data.empty()) + this->tls->send(reinterpret_cast<const Botan::byte*>(data.data()), + data.size()); + if (!was_active && this->tls->is_active()) + this->on_tls_activated(); + } + else + this->pre_buf += data; +} + +void TCPSocketHandler::tls_data_cb(const Botan::byte* data, size_t size) +{ + this->in_buf += std::string(reinterpret_cast<const char*>(data), + size); + if (!this->in_buf.empty()) + this->parse_in_buffer(size); +} + +void TCPSocketHandler::tls_output_fn(const Botan::byte* data, size_t size) +{ + this->raw_send(std::string(reinterpret_cast<const char*>(data), size)); +} + +void TCPSocketHandler::tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t) +{ + log_debug("tls_alert: ", alert.type_string()); +} + +bool TCPSocketHandler::tls_handshake_cb(const Botan::TLS::Session& session) +{ + log_debug("Handshake with ", session.server_info().hostname(), " complete.", + " Version: ", session.version().to_string(), + " using ", session.ciphersuite().to_string()); + if (!session.session_id().empty()) + log_debug("Session ID ", Botan::hex_encode(session.session_id())); + if (!session.session_ticket().empty()) + log_debug("Session ticket ", Botan::hex_encode(session.session_ticket())); + return true; +} + +void TCPSocketHandler::on_tls_activated() +{ + this->send_data({}); +} + +#endif // BOTAN_FOUND diff --git a/louloulibs/network/tcp_socket_handler.hpp b/louloulibs/network/tcp_socket_handler.hpp new file mode 100644 index 0000000..b0ba493 --- /dev/null +++ b/louloulibs/network/tcp_socket_handler.hpp @@ -0,0 +1,274 @@ +#pragma once + + +#include "louloulibs.h" + +#include <network/socket_handler.hpp> +#include <network/resolver.hpp> + +#include <network/credentials_manager.hpp> + +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netdb.h> + +#include <vector> +#include <memory> +#include <string> +#include <list> + +/** + * An interface, with a series of callbacks that should be implemented in + * subclasses that deal with a socket. These callbacks are called on various events + * (read/write/timeout, etc) when they are notified to a poller + * (select/poll/epoll etc) + */ +class TCPSocketHandler: public SocketHandler +{ +protected: + ~TCPSocketHandler(); +public: + explicit TCPSocketHandler(std::shared_ptr<Poller> poller); + TCPSocketHandler(const TCPSocketHandler&) = delete; + TCPSocketHandler(TCPSocketHandler&&) = delete; + TCPSocketHandler& operator=(const TCPSocketHandler&) = delete; + TCPSocketHandler& operator=(TCPSocketHandler&&) = delete; + + /** + * Connect to the remote server, and call on_connected() if this + * succeeds. If tls is true, we set use_tls to true and will also call + * start_tls() when the connection succeeds. + */ + void connect(const std::string& address, const std::string& port, const bool tls); + void connect() override final; + /** + * Reads raw data from the socket. And pass it to parse_in_buffer() + * If we are using TLS on this connection, we call tls_recv() + */ + void on_recv() override final; + /** + * Write as much data from out_buf as possible, in the socket. + */ + void on_send() override final; + /** + * Add the given data to out_buf and tell our poller that we want to be + * notified when a send event is ready. + * + * This can be overriden if we want to modify the data before sending + * it. For example if we want to encrypt it. + */ + void send_data(std::string&& data); + /** + * Watch the socket for send events, if our out buffer is not empty. + */ + void send_pending_data(); + /** + * Close the connection, remove us from the poller + */ + void close(); + /** + * Called by a TimedEvent, when the connection did not succeed or fail + * after a given time. + */ + void on_connection_timeout(); + /** + * Called when the connection is successful. + */ + virtual void on_connected() = 0; + /** + * Called when the connection fails. Not when it is closed later, just at + * the connect() call. + */ + virtual void on_connection_failed(const std::string& reason) = 0; + /** + * Called when we detect a disconnection from the remote host. + */ + virtual void on_connection_close(const std::string& error) = 0; + /** + * Handle/consume (some of) the data received so far. The data to handle + * may be in the in_buf buffer, or somewhere else, depending on what + * get_receive_buffer() returned. If some data is used from in_buf, it + * should be truncated, only the unused data should be left untouched. + * + * The size argument is the size of the last chunk of data that was added to the buffer. + */ + virtual void parse_in_buffer(const size_t size) = 0; +#ifdef BOTAN_FOUND + /** + * Tell whether the credential manager should cancel the connection when the + * certificate is invalid. + */ + virtual bool abort_on_invalid_cert() const + { + return true; + } +#endif + bool is_connected() const override final; + bool is_connecting() const; + +private: + /** + * Initialize the socket with the parameters contained in the given + * addrinfo structure. + */ + void init_socket(const struct addrinfo* rp); + /** + * Reads from the socket into the provided buffer. If an error occurs + * (read returns <= 0), the handling of the error is done here (close the + * connection, log a message, etc). + * + * Returns the value returned by ::recv(), so the buffer should not be + * used if it’s not positive. + */ + ssize_t do_recv(void* recv_buf, const size_t buf_size); + /** + * Reads data from the socket and calls parse_in_buffer with it. + */ + void plain_recv(); + /** + * Mark the given data as ready to be sent, as-is, on the socket, as soon + * as we can. + */ + void raw_send(std::string&& data); + +#ifdef BOTAN_FOUND + /** + * Create the TLS::Client object, with all the callbacks etc. This must be + * called only when we know we are able to send TLS-encrypted data over + * the socket. + */ + void start_tls(); + /** + * An additional step to pass the data into our tls object to decrypt it + * before passing it to parse_in_buffer. + */ + void tls_recv(); + /** + * Pass the data to the tls object in order to encrypt it. The tls object + * will then call raw_send as a callback whenever data as been encrypted + * and can be sent on the socket. + */ + void tls_send(std::string&& data); + /** + * Called by the tls object that some data has been decrypt. We call + * parse_in_buffer() to handle that unencrypted data. + */ + void tls_data_cb(const Botan::byte* data, size_t size); + /** + * Called by the tls object to indicate that some data has been encrypted + * and is now ready to be sent on the socket as is. + */ + void tls_output_fn(const Botan::byte* data, size_t size); + /** + * Called by the tls object to indicate that a TLS alert has been + * received. We don’t use it, we just log some message, at the moment. + */ + void tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte*, size_t); + /** + * Called by the tls object at the end of the TLS handshake. We don't do + * anything here appart from logging the TLS session information. + */ + bool tls_handshake_cb(const Botan::TLS::Session& session); + /** + * Called whenever the tls session goes from inactive to active. This + * means that the handshake has just been successfully done, and we can + * now proceed to send any available data into our tls object. + */ + void on_tls_activated(); +#endif // BOTAN_FOUND + /** + * Where data is added, when we want to send something to the client. + */ + std::vector<std::string> out_buf; + /** + * DNS resolver + */ + Resolver resolver; + /** + * Keep the details of the addrinfo returned by the resolver that + * triggered a EINPROGRESS error when connect()ing to it, to reuse it + * directly when connect() is called again. + */ + struct addrinfo addrinfo; + struct sockaddr_in6 ai_addr; + socklen_t ai_addrlen; + +protected: + /** + * Where data read from the socket is added until we can extract a full + * and meaningful “message” from it. + * + * TODO: something more efficient than a string. + */ + std::string in_buf; + /** + * Whether we are using TLS on this connection or not. + */ + bool use_tls; + /** + * Provide a buffer in which data can be directly received. This can be + * used to avoid copying data into in_buf before using it. If no buffer + * needs to be provided, nullptr is returned (the default implementation + * does that), in that case our internal in_buf will be used to save the + * data until it can be used by parse_in_buffer(). + */ + virtual void* get_receive_buffer(const size_t size) const; + /** + * Hostname we are connected/connecting to + */ + std::string address; + /** + * Port we are connected/connecting to + */ + std::string port; + + bool connected; + bool connecting; + + bool hostname_resolution_failed; + + /** + * Address to bind the socket to, before calling connect(). + * If empty, it’s equivalent to binding to INADDR_ANY. + */ + std::string bind_addr; + +private: + /** + * Display the resolved IP, just for information purpose. + */ + void display_resolved_ip(struct addrinfo* rp) const; + +#ifdef BOTAN_FOUND + /** + * Botan stuff to manipulate a TLS session. + */ + static Botan::AutoSeeded_RNG rng; + static Botan::TLS::Policy policy; + static Botan::TLS::Session_Manager_In_Memory session_manager; +protected: + BasicCredentialsManager credential_manager; +private: + /** + * We use a unique_ptr because we may not want to create the object at + * all. The Botan::TLS::Client object generates a handshake message and + * calls the output_fn callback with it as soon as it is created. + * Therefore, we do not want to create it if we do not intend to send any + * TLS-encrypted message. We create the object only when needed (for + * example after we have negociated a TLS session using a STARTTLS + * message, or stuf like that). + * + * See start_tls for the method where this object is created. + */ + std::unique_ptr<Botan::TLS::Client> tls; + /** + * An additional buffer to keep data that the user wants to send, but + * cannot because the handshake is not done. + */ + std::string pre_buf; +#endif // BOTAN_FOUND +}; + + + diff --git a/louloulibs/utils/encoding.cpp b/louloulibs/utils/encoding.cpp new file mode 100644 index 0000000..507f38a --- /dev/null +++ b/louloulibs/utils/encoding.cpp @@ -0,0 +1,258 @@ +#include <utils/encoding.hpp> + +#include <utils/scopeguard.hpp> + +#include <stdexcept> + +#include <assert.h> +#include <string.h> +#include <iconv.h> + +#include <map> +#include <bitset> + +/** + * The UTF-8-encoded character used as a place holder when a character conversion fails. + * This is U+FFFD � "replacement character" + */ +static const char* invalid_char = "\xef\xbf\xbd"; +static const size_t invalid_char_len = 3; + +namespace utils +{ + /** + * Based on http://en.wikipedia.org/wiki/UTF-8#Description + */ + std::size_t get_next_codepoint_size(const unsigned char c) + { + if ((c & 0b11111000) == 0b11110000) // 4 bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + return 4; + else if ((c & 0b11110000) == 0b11100000) // 3 bytes: 1110xxx 10xxxxxx 10xxxxxx + return 3; + else if ((c & 0b11100000) == 0b11000000) // 2 bytes: 110xxxxx 10xxxxxx + return 2; + return 1; // 1 byte: 0xxxxxxx + } + + bool is_valid_utf8(const char* s) + { + if (!s) + return false; + + const unsigned char* str = reinterpret_cast<const unsigned char*>(s); + + while (*str) + { + const auto codepoint_size = get_next_codepoint_size(str[0]); + if (codepoint_size == 4) + { + if (!str[1] || !str[2] || !str[3] + || ((str[1] & 0b11000000) != 0b10000000) + || ((str[2] & 0b11000000) != 0b10000000) + || ((str[3] & 0b11000000) != 0b10000000)) + return false; + } + else if (codepoint_size == 3) + { + if (!str[1] || !str[2] + || ((str[1] & 0b11000000) != 0b10000000) + || ((str[2] & 0b11000000) != 0b10000000)) + return false; + } + else if (codepoint_size == 2) + { + if (!str[1] || + ((str[1] & 0b11000000) != 0b10000000)) + return false; + } + else if ((str[0] & 0b10000000) != 0) + return false; + str += codepoint_size; + } + return true; + } + + std::string remove_invalid_xml_chars(const std::string& original) + { + // The given string MUST be a valid utf-8 string + unsigned char* res = new unsigned char[original.size()]; + ScopeGuard sg([&res]() { delete[] res;}); + + // pointer where we write valid chars + unsigned char* r = res; + + const unsigned char* str = reinterpret_cast<const unsigned char*>(original.c_str()); + std::bitset<20> codepoint; + + while (*str) + { + // 4 bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + if ((str[0] & 0b11111000) == 0b11110000) + { + codepoint = ((str[0] & 0b00000111) << 18); + codepoint |= ((str[1] & 0b00111111) << 12); + codepoint |= ((str[2] & 0b00111111) << 6 ); + codepoint |= ((str[3] & 0b00111111) << 0 ); + if (codepoint.to_ulong() <= 0x10FFFF) + { + ::memcpy(r, str, 4); + r += 4; + } + str += 4; + } + // 3 bytes: 1110xxx 10xxxxxx 10xxxxxx + else if ((str[0] & 0b11110000) == 0b11100000) + { + codepoint = ((str[0] & 0b00001111) << 12); + codepoint |= ((str[1] & 0b00111111) << 6); + codepoint |= ((str[2] & 0b00111111) << 0 ); + if (codepoint.to_ulong() <= 0xD7FF || + (codepoint.to_ulong() >= 0xE000 && codepoint.to_ulong() <= 0xFFFD)) + { + ::memcpy(r, str, 3); + r += 3; + } + str += 3; + } + // 2 bytes: 110xxxxx 10xxxxxx + else if (((str[0]) & 0b11100000) == 0b11000000) + { + // All 2 bytes char are valid, don't even bother calculating + // the codepoint + ::memcpy(r, str, 2); + r += 2; + str += 2; + } + // 1 byte: 0xxxxxxx + else if ((str[0] & 0b10000000) == 0) + { + codepoint = ((str[0] & 0b01111111)); + if (codepoint.to_ulong() == 0x09 || + codepoint.to_ulong() == 0x0A || + codepoint.to_ulong() == 0x0D || + codepoint.to_ulong() >= 0x20) + { + ::memcpy(r, str, 1); + r += 1; + } + str += 1; + } + else + throw std::runtime_error("Invalid UTF-8 passed to remove_invalid_xml_chars"); + } + return std::string(reinterpret_cast<char*>(res), r-res); + } + + std::string convert_to_utf8(const std::string& str, const char* charset) + { + std::string res; + + const iconv_t cd = iconv_open("UTF-8", charset); + if (cd == (iconv_t)-1) + throw std::runtime_error("Cannot convert into UTF-8"); + + // Make sure cd is always closed when we leave this function + ScopeGuard sg([&]{ iconv_close(cd); }); + + size_t inbytesleft = str.size(); + + // iconv will not attempt to modify this buffer, but some plateform + // require a char** anyway +#ifdef ICONV_SECOND_ARGUMENT_IS_CONST + const char* inbuf_ptr = str.c_str(); +#else + char* inbuf_ptr = const_cast<char*>(str.c_str()); +#endif + + size_t outbytesleft = str.size() * 4; + char* outbuf = new char[outbytesleft]; + char* outbuf_ptr = outbuf; + + // Make sure outbuf is always deleted when we leave this function + sg.add_callback([&]{ delete[] outbuf; }); + + bool done = false; + while (done == false) + { + size_t error = iconv(cd, &inbuf_ptr, &inbytesleft, &outbuf_ptr, &outbytesleft); + if ((size_t)-1 == error) + { + switch (errno) + { + case EILSEQ: + // Invalid byte found. Insert a placeholder instead of the + // converted character, jump one byte and continue + memcpy(outbuf_ptr, invalid_char, invalid_char_len); + outbuf_ptr += invalid_char_len; + inbytesleft--; + inbuf_ptr++; + break; + case EINVAL: + // A multibyte sequence is not terminated, but we can't + // provide any more data, so we just add a placeholder to + // indicate that the character is not properly converted, + // and we stop the conversion + memcpy(outbuf_ptr, invalid_char, invalid_char_len); + outbuf_ptr += invalid_char_len; + outbuf_ptr++; + done = true; + break; + case E2BIG: + // This should never happen + done = true; + break; + default: + // This should happen even neverer + done = true; + break; + } + } + else + { + // The conversion finished without any error, stop converting + done = true; + } + } + // Terminate the converted buffer, and copy that buffer it into the + // string we return + *outbuf_ptr = '\0'; + res = outbuf; + return res; + } + +} + +namespace xep0106 +{ + static const std::map<const char, const std::string> encode_map = { + {' ', "\\20"}, + {'"', "\\22"}, + {'&', "\\26"}, + {'\'',"\\27"}, + {'/', "\\2f"}, + {':', "\\3a"}, + {'<', "\\3c"}, + {'>', "\\3e"}, + {'@', "\\40"}, + }; + + void decode(std::string& s) + { + std::string::size_type pos; + for (const auto& pair: encode_map) + while ((pos = s.find(pair.second)) != std::string::npos) + s.replace(pos, pair.second.size(), + 1, pair.first); + } + + void encode(std::string& s) + { + std::string::size_type pos; + while ((pos = s.find_first_of(" \"&'/:<>@")) != std::string::npos) + { + auto it = encode_map.find(s[pos]); + assert(it != encode_map.end()); + s.replace(pos, 1, it->second); + } + } +} diff --git a/louloulibs/utils/encoding.hpp b/louloulibs/utils/encoding.hpp new file mode 100644 index 0000000..586edd8 --- /dev/null +++ b/louloulibs/utils/encoding.hpp @@ -0,0 +1,43 @@ +#pragma once + + +#include <string> + +namespace utils +{ + /** + * Return the size, in bytes, of the next UTF-8 codepoint, based on + * the given char. + */ + std::size_t get_next_codepoint_size(const unsigned char c); + /** + * Returns true if the given null-terminated string is valid utf-8. + * + * Based on http://en.wikipedia.org/wiki/UTF-8#Description + */ + bool is_valid_utf8(const char* s); + /** + * Remove all invalid codepoints from the given utf-8-encoded string. + * The value returned is a copy of the string, without the removed chars. + * + * See http://www.w3.org/TR/xml/#charsets for the list of valid characters + * in XML. + */ + std::string remove_invalid_xml_chars(const std::string& original); + /** + * Convert the given string (encoded is "encoding") into valid utf-8. + * If some decoding fails, insert an utf-8 placeholder character instead. + */ + std::string convert_to_utf8(const std::string& str, const char* encoding); +} + +namespace xep0106 +{ + /** + * Decode and encode inplace. + */ + void decode(std::string&); + void encode(std::string&); +} + + diff --git a/louloulibs/utils/revstr.cpp b/louloulibs/utils/revstr.cpp new file mode 100644 index 0000000..87fd801 --- /dev/null +++ b/louloulibs/utils/revstr.cpp @@ -0,0 +1,9 @@ +#include <utils/revstr.hpp> + +namespace utils +{ + std::string revstr(const std::string& original) + { + return {original.rbegin(), original.rend()}; + } +} diff --git a/louloulibs/utils/revstr.hpp b/louloulibs/utils/revstr.hpp new file mode 100644 index 0000000..8e521ea --- /dev/null +++ b/louloulibs/utils/revstr.hpp @@ -0,0 +1,11 @@ +#pragma once + + +#include <string> + +namespace utils +{ + std::string revstr(const std::string& original); +} + + diff --git a/louloulibs/utils/scopeguard.hpp b/louloulibs/utils/scopeguard.hpp new file mode 100644 index 0000000..ee1e2ef --- /dev/null +++ b/louloulibs/utils/scopeguard.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include <functional> +#include <vector> + +/** + * A class to be used to make sure some functions are called when the scope + * is left, because they will be called in the ScopeGuard's destructor. It + * can for example be used to delete some pointer whenever any exception is + * called. Example: + + * { + * ScopeGuard scope; + * int* number = new int(2); + * scope.add_callback([number]() { delete number; }); + * // Do some other stuff with the number. But these stuff might throw an exception: + * throw std::runtime_error("Some error not caught here, but in our caller"); + * return true; + * } + + * In this example, our pointer will always be deleted, even when the + * exception is thrown. If we want the functions to be called only when the + * scope is left because of an unexpected exception, we can use + * ScopeGuard::disable(); + */ + +namespace utils +{ + +class ScopeGuard +{ +public: + /** + * The constructor can take a callback. But additional callbacks can be + * added later with add_callback() + */ + explicit ScopeGuard(std::function<void()>&& func): + enabled(true) + { + this->add_callback(std::move(func)); + } + + ScopeGuard(const ScopeGuard&) = delete; + ScopeGuard& operator=(ScopeGuard&&) = delete; + ScopeGuard(ScopeGuard&&) = delete; + ScopeGuard& operator=(const ScopeGuard&) = delete; + + /** + * default constructor, the scope guard is enabled but empty, use + * add_callback() + */ + explicit ScopeGuard(): + enabled(true) + { + } + /** + * Call all callbacks in the desctructor, unless it has been disabled. + */ + ~ScopeGuard() + { + if (this->enabled) + for (auto& func: this->callbacks) + func(); + } + /** + * Add a callback to be called in our destructor, one scope guard can be + * used for more than one task, if needed. + */ + void add_callback(std::function<void()>&& func) + { + this->callbacks.emplace_back(std::move(func)); + } + /** + * Disable that scope guard, nothing will be done when the scope is + * exited. + */ + void disable() + { + this->enabled = false; + } + +private: + bool enabled; + std::vector<std::function<void()>> callbacks; + +}; + +} + diff --git a/louloulibs/utils/sha1.cpp b/louloulibs/utils/sha1.cpp new file mode 100644 index 0000000..76476df --- /dev/null +++ b/louloulibs/utils/sha1.cpp @@ -0,0 +1,154 @@ +/* This code is public-domain - it is based on libcrypt + * placed in the public domain by Wei Dai and other contributors. + */ + +#include "sha1.hpp" + +#define SHA1_K0 0x5a827999 +#define SHA1_K20 0x6ed9eba1 +#define SHA1_K40 0x8f1bbcdc +#define SHA1_K60 0xca62c1d6 + +const uint8_t sha1InitState[] = { + 0x01,0x23,0x45,0x67, // H0 + 0x89,0xab,0xcd,0xef, // H1 + 0xfe,0xdc,0xba,0x98, // H2 + 0x76,0x54,0x32,0x10, // H3 + 0xf0,0xe1,0xd2,0xc3 // H4 +}; + +void sha1_init(sha1nfo *s) { + memcpy(s->state.b,sha1InitState,HASH_LENGTH); + s->byteCount = 0; + s->bufferOffset = 0; +} + +uint32_t sha1_rol32(uint32_t number, uint8_t bits) { + return ((number << bits) | (number >> (32-bits))); +} + +void sha1_hashBlock(sha1nfo *s) { + uint8_t i; + uint32_t a,b,c,d,e,t; + + a=s->state.w[0]; + b=s->state.w[1]; + c=s->state.w[2]; + d=s->state.w[3]; + e=s->state.w[4]; + for (i=0; i<80; i++) { + if (i>=16) { + t = s->buffer.w[(i+13)&15] ^ s->buffer.w[(i+8)&15] ^ s->buffer.w[(i+2)&15] ^ s->buffer.w[i&15]; + s->buffer.w[i&15] = sha1_rol32(t,1); + } + if (i<20) { + t = (d ^ (b & (c ^ d))) + SHA1_K0; + } else if (i<40) { + t = (b ^ c ^ d) + SHA1_K20; + } else if (i<60) { + t = ((b & c) | (d & (b | c))) + SHA1_K40; + } else { + t = (b ^ c ^ d) + SHA1_K60; + } + t+=sha1_rol32(a,5) + e + s->buffer.w[i&15]; + e=d; + d=c; + c=sha1_rol32(b,30); + b=a; + a=t; + } + s->state.w[0] += a; + s->state.w[1] += b; + s->state.w[2] += c; + s->state.w[3] += d; + s->state.w[4] += e; +} + +void sha1_addUncounted(sha1nfo *s, uint8_t data) { + s->buffer.b[s->bufferOffset ^ 3] = data; + s->bufferOffset++; + if (s->bufferOffset == BLOCK_LENGTH) { + sha1_hashBlock(s); + s->bufferOffset = 0; + } +} + +void sha1_writebyte(sha1nfo *s, uint8_t data) { + ++s->byteCount; + sha1_addUncounted(s, data); +} + +void sha1_write(sha1nfo *s, const char *data, size_t len) { + for (;len--;) sha1_writebyte(s, (uint8_t) *data++); +} + +void sha1_pad(sha1nfo *s) { + // Implement SHA-1 padding (fips180-2 §5.1.1) + + // Pad with 0x80 followed by 0x00 until the end of the block + sha1_addUncounted(s, 0x80); + while (s->bufferOffset != 56) sha1_addUncounted(s, 0x00); + + // Append length in the last 8 bytes + sha1_addUncounted(s, 0); // We're only using 32 bit lengths + sha1_addUncounted(s, 0); // But SHA-1 supports 64 bit lengths + sha1_addUncounted(s, 0); // So zero pad the top bits + sha1_addUncounted(s, s->byteCount >> 29); // Shifting to multiply by 8 + sha1_addUncounted(s, s->byteCount >> 21); // as SHA-1 supports bitstreams as well as + sha1_addUncounted(s, s->byteCount >> 13); // byte. + sha1_addUncounted(s, s->byteCount >> 5); + sha1_addUncounted(s, s->byteCount << 3); +} + +uint8_t* sha1_result(sha1nfo *s) { + int i; + // Pad to complete the last block + sha1_pad(s); + + // Swap byte order back + for (i=0; i<5; i++) { + uint32_t a,b; + a=s->state.w[i]; + b=a<<24; + b|=(a<<8) & 0x00ff0000; + b|=(a>>8) & 0x0000ff00; + b|=a>>24; + s->state.w[i]=b; + } + + // Return pointer to hash (20 characters) + return s->state.b; +} + +#define HMAC_IPAD 0x36 +#define HMAC_OPAD 0x5c + +void sha1_initHmac(sha1nfo *s, const uint8_t* key, int keyLength) { + uint8_t i; + memset(s->keyBuffer, 0, BLOCK_LENGTH); + if (keyLength > BLOCK_LENGTH) { + // Hash long keys + sha1_init(s); + for (;keyLength--;) sha1_writebyte(s, *key++); + memcpy(s->keyBuffer, sha1_result(s), HASH_LENGTH); + } else { + // Block length keys are used as is + memcpy(s->keyBuffer, key, keyLength); + } + // Start inner hash + sha1_init(s); + for (i=0; i<BLOCK_LENGTH; i++) { + sha1_writebyte(s, s->keyBuffer[i] ^ HMAC_IPAD); + } +} + +uint8_t* sha1_resultHmac(sha1nfo *s) { + uint8_t i; + // Complete inner hash + memcpy(s->innerHash,sha1_result(s),HASH_LENGTH); + // Calculate outer hash + sha1_init(s); + for (i=0; i<BLOCK_LENGTH; i++) sha1_writebyte(s, s->keyBuffer[i] ^ HMAC_OPAD); + for (i=0; i<HASH_LENGTH; i++) sha1_writebyte(s, s->innerHash[i]); + return sha1_result(s); +} diff --git a/louloulibs/utils/sha1.hpp b/louloulibs/utils/sha1.hpp new file mode 100644 index 0000000..d02de75 --- /dev/null +++ b/louloulibs/utils/sha1.hpp @@ -0,0 +1,35 @@ +/* This code is public-domain - it is based on libcrypt + * placed in the public domain by Wei Dai and other contributors. + */ + +#include <stdint.h> +#include <string.h> + +#define HASH_LENGTH 20 +#define BLOCK_LENGTH 64 + +union _buffer { + uint8_t b[BLOCK_LENGTH]; + uint32_t w[BLOCK_LENGTH/4]; +}; + +union _state { + uint8_t b[HASH_LENGTH]; + uint32_t w[HASH_LENGTH/4]; +}; + +typedef struct sha1nfo { + union _buffer buffer; + uint8_t bufferOffset; + union _state state; + uint32_t byteCount; + uint8_t keyBuffer[BLOCK_LENGTH]; + uint8_t innerHash[HASH_LENGTH]; +} sha1nfo; + +void sha1_init(sha1nfo *s); +void sha1_writebyte(sha1nfo *s, uint8_t data); +void sha1_write(sha1nfo *s, const char *data, size_t len); +uint8_t* sha1_result(sha1nfo *s); +void sha1_initHmac(sha1nfo *s, const uint8_t* key, int keyLength); +uint8_t* sha1_resultHmac(sha1nfo *s); diff --git a/louloulibs/utils/split.cpp b/louloulibs/utils/split.cpp new file mode 100644 index 0000000..80f8dae --- /dev/null +++ b/louloulibs/utils/split.cpp @@ -0,0 +1,19 @@ +#include <utils/split.hpp> +#include <sstream> + +namespace utils +{ + std::vector<std::string> split(const std::string& s, const char delim, const bool allow_empty) + { + std::vector<std::string> ret; + std::stringstream ss(s); + std::string item; + while (std::getline(ss, item, delim)) + { + if (item.empty() && !allow_empty) + continue ; + ret.emplace_back(std::move(item)); + } + return ret; + } +} diff --git a/louloulibs/utils/split.hpp b/louloulibs/utils/split.hpp new file mode 100644 index 0000000..3755ef8 --- /dev/null +++ b/louloulibs/utils/split.hpp @@ -0,0 +1,12 @@ +#pragma once + + +#include <string> +#include <vector> + +namespace utils +{ + std::vector<std::string> split(const std::string &s, const char delim, const bool allow_empty=true); +} + + diff --git a/louloulibs/utils/string.cpp b/louloulibs/utils/string.cpp new file mode 100644 index 0000000..635e71a --- /dev/null +++ b/louloulibs/utils/string.cpp @@ -0,0 +1,28 @@ +#include <utils/string.hpp> +#include <utils/encoding.hpp> + +bool to_bool(const std::string& val) +{ + return (val == "1" || val == "true"); +} + +std::vector<std::string> cut(const std::string& val, const std::size_t size) +{ + std::vector<std::string> res; + std::string::size_type pos = 0; + while (pos < val.size()) + { + // Get the number of chars, <= size, that contain only whole + // UTF-8 codepoints. + std::size_t s = 0; + auto codepoint_size = utils::get_next_codepoint_size(val[pos + s]); + while (s + codepoint_size <= size && pos + s < val.size()) + { + s += codepoint_size; + codepoint_size = utils::get_next_codepoint_size(val[pos + s]); + } + res.emplace_back(val.substr(pos, s)); + pos += s; + } + return res; +} diff --git a/louloulibs/utils/string.hpp b/louloulibs/utils/string.hpp new file mode 100644 index 0000000..84ba101 --- /dev/null +++ b/louloulibs/utils/string.hpp @@ -0,0 +1,10 @@ +#pragma once + + +#include <vector> +#include <string> + +bool to_bool(const std::string& val); +std::vector<std::string> cut(const std::string& val, const std::size_t size); + + diff --git a/louloulibs/utils/timed_events.cpp b/louloulibs/utils/timed_events.cpp new file mode 100644 index 0000000..68d009c --- /dev/null +++ b/louloulibs/utils/timed_events.cpp @@ -0,0 +1,49 @@ +#include <utils/timed_events.hpp> + +TimedEvent::TimedEvent(std::chrono::steady_clock::time_point&& time_point, + std::function<void()> callback, const std::string& name): + time_point(std::move(time_point)), + callback(callback), + repeat(false), + repeat_delay(0), + name(name) +{ +} + +TimedEvent::TimedEvent(std::chrono::milliseconds&& duration, + std::function<void()> callback, const std::string& name): + time_point(std::chrono::steady_clock::now() + duration), + callback(callback), + repeat(true), + repeat_delay(std::move(duration)), + name(name) +{ +} + +bool TimedEvent::is_after(const TimedEvent& other) const +{ + return this->is_after(other.time_point); +} + +bool TimedEvent::is_after(const std::chrono::steady_clock::time_point& time_point) const +{ + return this->time_point > time_point; +} + +std::chrono::milliseconds TimedEvent::get_timeout() const +{ + auto now = std::chrono::steady_clock::now(); + if (now > this->time_point) + return std::chrono::milliseconds(0); + return std::chrono::duration_cast<std::chrono::milliseconds>(this->time_point - now); +} + +void TimedEvent::execute() const +{ + this->callback(); +} + +const std::string& TimedEvent::get_name() const +{ + return this->name; +} diff --git a/louloulibs/utils/timed_events.hpp b/louloulibs/utils/timed_events.hpp new file mode 100644 index 0000000..6e28206 --- /dev/null +++ b/louloulibs/utils/timed_events.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include <functional> +#include <string> +#include <chrono> +#include <vector> + +using namespace std::literals::chrono_literals; + +namespace utils { +static constexpr std::chrono::milliseconds no_timeout = std::chrono::milliseconds(-1); +} + +class TimedEventsManager; + +/** + * A callback with an associated date. + */ + +class TimedEvent +{ + friend class TimedEventsManager; +public: + /** + * An event the occurs only once, at the given time_point + */ + explicit TimedEvent(std::chrono::steady_clock::time_point&& time_point, + std::function<void()> callback, const std::string& name=""); + explicit TimedEvent(std::chrono::milliseconds&& duration, + std::function<void()> callback, const std::string& name=""); + + explicit TimedEvent(TimedEvent&&) = default; + TimedEvent& operator=(TimedEvent&&) = default; + ~TimedEvent() = default; + + TimedEvent(const TimedEvent&) = delete; + TimedEvent& operator=(const TimedEvent&) = delete; + + /** + * Whether or not this event happens after the other one. + */ + bool is_after(const TimedEvent& other) const; + bool is_after(const std::chrono::steady_clock::time_point& time_point) const; + /** + * Return the duration difference between now and the event time point. + * If the difference would be negative (i.e. the event is expired), the + * returned value is 0 instead. The value cannot then be negative. + */ + std::chrono::milliseconds get_timeout() const; + void execute() const; + const std::string& get_name() const; + +private: + /** + * The next time point at which the event is executed. + */ + std::chrono::steady_clock::time_point time_point; + /** + * The function to execute. + */ + std::function<void()> callback; + /** + * Whether or not this events repeats itself until it is destroyed. + */ + bool repeat; + /** + * This value is added to the time_point each time the event is executed, + * if repeat is true. Otherwise it is ignored. + */ + std::chrono::milliseconds repeat_delay; + /** + * A name that is used to identify that event. If you want to find your + * event (for example if you want to cancel it), the name should be + * unique. + */ + std::string name; +}; + +/** + * A class managing a list of TimedEvents. + * They are sorted, new events can be added, removed, fetch, etc. + */ + +class TimedEventsManager +{ +public: + ~TimedEventsManager() = default; + + TimedEventsManager(const TimedEventsManager&) = delete; + TimedEventsManager(TimedEventsManager&&) = delete; + TimedEventsManager& operator=(const TimedEventsManager&) = delete; + TimedEventsManager& operator=(TimedEventsManager&&) = delete; + + /** + * Return the unique instance of this class + */ + static TimedEventsManager& instance(); + /** + * Add an event to the list of managed events. The list is sorted after + * this call. + */ + void add_event(TimedEvent&& event); + /** + * Returns the duration, in milliseconds, between now and the next + * available event. If the event is already expired (the duration is + * negative), 0 is returned instead (as in “it's not too late, execute it + * now”) + * Returns a negative value if no event is available. + */ + std::chrono::milliseconds get_timeout() const; + /** + * Execute all the expired events (if their expiration time is exactly + * now, or before now). The event is then removed from the list. If the + * event does repeat, its expiration time is updated and it is reinserted + * in the list at the correct position. + * Returns the number of executed events. + */ + std::size_t execute_expired_events(); + /** + * Remove (and thus cancel) all the timed events with the given name. + * Returns the number of canceled events. + */ + std::size_t cancel(const std::string& name); + /** + * Return the number of managed events. + */ + std::size_t size() const; + +private: + std::vector<TimedEvent> events; + explicit TimedEventsManager() = default; +}; diff --git a/louloulibs/utils/timed_events_manager.cpp b/louloulibs/utils/timed_events_manager.cpp new file mode 100644 index 0000000..67d61fe --- /dev/null +++ b/louloulibs/utils/timed_events_manager.cpp @@ -0,0 +1,73 @@ +#include <utils/timed_events.hpp> + +TimedEventsManager& TimedEventsManager::instance() +{ + static TimedEventsManager inst; + return inst; +} + +void TimedEventsManager::add_event(TimedEvent&& event) +{ + for (auto it = this->events.begin(); it != this->events.end(); ++it) + { + if (it->is_after(event)) + { + this->events.emplace(it, std::move(event)); + return; + } + } + this->events.emplace_back(std::move(event)); +} + +std::chrono::milliseconds TimedEventsManager::get_timeout() const +{ + if (this->events.empty()) + return utils::no_timeout; + return this->events.front().get_timeout(); +} + +std::size_t TimedEventsManager::execute_expired_events() +{ + std::size_t count = 0; + const auto now = std::chrono::steady_clock::now(); + for (auto it = this->events.begin(); it != this->events.end();) + { + if (!it->is_after(now)) + { + TimedEvent copy(std::move(*it)); + it = this->events.erase(it); + ++count; + copy.execute(); + if (copy.repeat) + { + copy.time_point += copy.repeat_delay; + this->add_event(std::move(copy)); + } + continue; + } + else + break; + } + return count; +} + +std::size_t TimedEventsManager::cancel(const std::string& name) +{ + std::size_t res = 0; + for (auto it = this->events.begin(); it != this->events.end();) + { + if (it->get_name() == name) + { + it = this->events.erase(it); + res++; + } + else + ++it; + } + return res; +} + +std::size_t TimedEventsManager::size() const +{ + return this->events.size(); +} diff --git a/louloulibs/utils/tolower.cpp b/louloulibs/utils/tolower.cpp new file mode 100644 index 0000000..3e518bd --- /dev/null +++ b/louloulibs/utils/tolower.cpp @@ -0,0 +1,13 @@ +#include <utils/tolower.hpp> + +namespace utils +{ + std::string tolower(const std::string& original) + { + std::string res; + res.reserve(original.size()); + for (const char c: original) + res += static_cast<char>(std::tolower(c)); + return res; + } +} diff --git a/louloulibs/utils/tolower.hpp b/louloulibs/utils/tolower.hpp new file mode 100644 index 0000000..650e05d --- /dev/null +++ b/louloulibs/utils/tolower.hpp @@ -0,0 +1,11 @@ +#pragma once + + +#include <string> + +namespace utils +{ + std::string tolower(const std::string& original); +} + + diff --git a/louloulibs/utils/xdg.cpp b/louloulibs/utils/xdg.cpp new file mode 100644 index 0000000..48212a1 --- /dev/null +++ b/louloulibs/utils/xdg.cpp @@ -0,0 +1,29 @@ +#include <utils/xdg.hpp> +#include <cstdlib> + +#include "louloulibs.h" + +std::string xdg_path(const std::string& filename, const char* env_var) +{ + const char* xdg_home = ::getenv(env_var); + if (xdg_home && xdg_home[0] == '/') + return std::string{xdg_home} + "/" PROJECT_NAME "/" + filename; + else + { + const char* home = ::getenv("HOME"); + if (home) + return std::string{home} + "/" ".config" "/" PROJECT_NAME "/" + filename; + else + return filename; + } +} + +std::string xdg_config_path(const std::string& filename) +{ + return xdg_path(filename, "XDG_CONFIG_HOME"); +} + +std::string xdg_data_path(const std::string& filename) +{ + return xdg_path(filename, "XDG_DATA_HOME"); +} diff --git a/louloulibs/utils/xdg.hpp b/louloulibs/utils/xdg.hpp new file mode 100644 index 0000000..56e11da --- /dev/null +++ b/louloulibs/utils/xdg.hpp @@ -0,0 +1,14 @@ +#pragma once + + +#include <string> + +/** + * Returns a path for the given filename, according to the XDG base + * directory specification, see + * http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html + */ +std::string xdg_config_path(const std::string& filename); +std::string xdg_data_path(const std::string& filename); + + diff --git a/louloulibs/xmpp/adhoc_command.cpp b/louloulibs/xmpp/adhoc_command.cpp new file mode 100644 index 0000000..99701d7 --- /dev/null +++ b/louloulibs/xmpp/adhoc_command.cpp @@ -0,0 +1,89 @@ +#include <xmpp/adhoc_command.hpp> +#include <xmpp/xmpp_component.hpp> +#include <utils/reload.hpp> + +using namespace std::string_literals; + +AdhocCommand::AdhocCommand(std::vector<AdhocStep>&& callbacks, const std::string& name, const bool admin_only): + name(name), + callbacks(std::move(callbacks)), + admin_only(admin_only) +{ +} + +bool AdhocCommand::is_admin_only() const +{ + return this->admin_only; +} + +void PingStep1(XmppComponent&, AdhocSession&, XmlNode& command_node) +{ + XmlNode note("note"); + note["type"] = "info"; + note.set_inner("Pong"); + command_node.add_child(std::move(note)); +} + +void HelloStep1(XmppComponent&, AdhocSession&, XmlNode& command_node) +{ + XmlNode x("jabber:x:data:x"); + x["type"] = "form"; + XmlNode title("title"); + title.set_inner("Configure your name."); + x.add_child(std::move(title)); + XmlNode instructions("instructions"); + instructions.set_inner("Please provide your name."); + x.add_child(std::move(instructions)); + XmlNode name_field("field"); + name_field["var"] = "name"; + name_field["type"] = "text-single"; + name_field["label"] = "Your name"; + XmlNode required("required"); + name_field.add_child(std::move(required)); + x.add_child(std::move(name_field)); + command_node.add_child(std::move(x)); +} + +void HelloStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node) +{ + // Find out if the name was provided in the form. + if (const XmlNode* x = command_node.get_child("x", "jabber:x:data")) + { + const XmlNode* name_field = nullptr; + for (const XmlNode* field: x->get_children("field", "jabber:x:data")) + if (field->get_tag("var") == "name") + { + name_field = field; + break; + } + if (name_field) + { + if (const XmlNode* value = name_field->get_child("value", "jabber:x:data")) + { + XmlNode note("note"); + note["type"] = "info"; + note.set_inner("Hello "s + value->get_inner() + "!"s); + command_node.delete_all_children(); + command_node.add_child(std::move(note)); + return; + } + } + } + command_node.delete_all_children(); + XmlNode error(ADHOC_NS":error"); + error["type"] = "modify"; + XmlNode condition(STANZA_NS":bad-request"); + error.add_child(std::move(condition)); + command_node.add_child(std::move(error)); + session.terminate(); +} + +void Reload(XmppComponent&, AdhocSession&, XmlNode& command_node) +{ + ::reload_process(); + command_node.delete_all_children(); + XmlNode note("note"); + note["type"] = "info"; + note.set_inner("Configuration reloaded."); + command_node.add_child(std::move(note)); +} diff --git a/louloulibs/xmpp/adhoc_command.hpp b/louloulibs/xmpp/adhoc_command.hpp new file mode 100644 index 0000000..7c4de47 --- /dev/null +++ b/louloulibs/xmpp/adhoc_command.hpp @@ -0,0 +1,44 @@ +#pragma once + +/** + * Describe an ad-hoc command. + * + * Can only have zero or one step for now. When execution is requested, it + * can return a result immediately, or provide a form to be filled, and + * provide a result once the filled form is received. + */ + +#include <xmpp/adhoc_session.hpp> + +#include <functional> +#include <string> + +class AdhocCommand +{ + friend class AdhocSession; +public: + AdhocCommand(std::vector<AdhocStep>&& callback, const std::string& name, const bool admin_only); + ~AdhocCommand() = default; + AdhocCommand(const AdhocCommand&) = default; + AdhocCommand(AdhocCommand&&) = default; + AdhocCommand& operator=(AdhocCommand&&) = delete; + AdhocCommand& operator=(const AdhocCommand&) = delete; + + const std::string name; + + bool is_admin_only() const; + +private: + /** + * A command may have one or more steps. Each step is a different + * callback, inserting things into a <command/> XmlNode and calling + * methods of an AdhocSession. + */ + std::vector<AdhocStep> callbacks; + const bool admin_only; +}; + +void PingStep1(XmppComponent&, AdhocSession& session, XmlNode& command_node); +void HelloStep1(XmppComponent&, AdhocSession& session, XmlNode& command_node); +void HelloStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node); +void Reload(XmppComponent&, AdhocSession& session, XmlNode& command_node); diff --git a/louloulibs/xmpp/adhoc_commands_handler.cpp b/louloulibs/xmpp/adhoc_commands_handler.cpp new file mode 100644 index 0000000..17c4e67 --- /dev/null +++ b/louloulibs/xmpp/adhoc_commands_handler.cpp @@ -0,0 +1,123 @@ +#include <xmpp/adhoc_commands_handler.hpp> +#include <xmpp/xmpp_component.hpp> + +#include <utils/timed_events.hpp> +#include <logger/logger.hpp> +#include <config/config.hpp> +#include <xmpp/jid.hpp> + +#include <iostream> + +using namespace std::string_literals; + +const std::map<const std::string, const AdhocCommand>& AdhocCommandsHandler::get_commands() const +{ + return this->commands; +} + +std::map<const std::string, const AdhocCommand>& AdhocCommandsHandler::get_commands() +{ + return this->commands; +} + +XmlNode AdhocCommandsHandler::handle_request(const std::string& executor_jid, const std::string& to, XmlNode command_node) +{ + std::string action = command_node.get_tag("action"); + if (action.empty()) + action = "execute"; + command_node.del_tag("action"); + + Jid jid(executor_jid); + + const std::string node = command_node.get_tag("node"); + auto command_it = this->commands.find(node); + if (command_it == this->commands.end()) + { + XmlNode error(ADHOC_NS":error"); + error["type"] = "cancel"; + XmlNode condition(STANZA_NS":item-not-found"); + error.add_child(std::move(condition)); + command_node.add_child(std::move(error)); + } + else if (command_it->second.is_admin_only() && + Config::get("admin", "") != jid.local + "@" + jid.domain) + { + XmlNode error(ADHOC_NS":error"); + error["type"] = "cancel"; + XmlNode condition(STANZA_NS":forbidden"); + error.add_child(std::move(condition)); + command_node.add_child(std::move(error)); + } + else + { + std::string sessionid = command_node.get_tag("sessionid"); + if (sessionid.empty()) + { // create a new session, with a new id + sessionid = XmppComponent::next_id(); + command_node["sessionid"] = sessionid; + this->sessions.emplace(std::piecewise_construct, + std::forward_as_tuple(sessionid, executor_jid), + std::forward_as_tuple(command_it->second, executor_jid, to)); + TimedEventsManager::instance().add_event(TimedEvent(std::chrono::steady_clock::now() + 3600s, + std::bind(&AdhocCommandsHandler::remove_session, this, sessionid, executor_jid), + "adhocsession"s + sessionid + executor_jid)); + } + auto session_it = this->sessions.find(std::make_pair(sessionid, executor_jid)); + if (session_it == this->sessions.end()) + { + XmlNode error(ADHOC_NS":error"); + error["type"] = "modify"; + XmlNode condition(STANZA_NS":bad-request"); + error.add_child(std::move(condition)); + command_node.add_child(std::move(error)); + } + else if (action == "execute" || action == "next" || action == "complete") + { + // execute the step + AdhocSession& session = session_it->second; + const AdhocStep& step = session.get_next_step(); + step(this->xmpp_component, session, command_node); + if (session.remaining_steps() == 0 || + session.is_terminated()) + { + this->sessions.erase(session_it); + command_node["status"] = "completed"; + TimedEventsManager::instance().cancel("adhocsession"s + sessionid + executor_jid); + } + else + { + command_node["status"] = "executing"; + XmlNode actions("actions"); + XmlNode next("next"); + actions.add_child(std::move(next)); + command_node.add_child(std::move(actions)); + } + } + else if (action == "cancel") + { + this->sessions.erase(session_it); + command_node["status"] = "canceled"; + TimedEventsManager::instance().cancel("adhocsession"s + sessionid + executor_jid); + } + else // unsupported action + { + XmlNode error(ADHOC_NS":error"); + error["type"] = "modify"; + XmlNode condition(STANZA_NS":bad-request"); + error.add_child(std::move(condition)); + command_node.add_child(std::move(error)); + } + } + return command_node; +} + +void AdhocCommandsHandler::remove_session(const std::string& session_id, const std::string& initiator_jid) +{ + auto session_it = this->sessions.find(std::make_pair(session_id, initiator_jid)); + if (session_it != this->sessions.end()) + { + this->sessions.erase(session_it); + return ; + } + log_error("Tried to remove ad-hoc session for [", session_id, ", ", initiator_jid, "] but none found"); +} diff --git a/louloulibs/xmpp/adhoc_commands_handler.hpp b/louloulibs/xmpp/adhoc_commands_handler.hpp new file mode 100644 index 0000000..91eb5bd --- /dev/null +++ b/louloulibs/xmpp/adhoc_commands_handler.hpp @@ -0,0 +1,71 @@ +#pragma once + +/** + * Manage a list of available AdhocCommands and the list of ongoing + * AdhocCommandSessions. + */ + +#include <xmpp/adhoc_command.hpp> +#include <xmpp/xmpp_stanza.hpp> + +#include <utility> +#include <string> +#include <map> + +class AdhocCommandsHandler +{ +public: + explicit AdhocCommandsHandler(XmppComponent& xmpp_component): + xmpp_component(xmpp_component), + commands{} + { } + ~AdhocCommandsHandler() = default; + + AdhocCommandsHandler(const AdhocCommandsHandler&) = delete; + AdhocCommandsHandler(AdhocCommandsHandler&&) = delete; + AdhocCommandsHandler& operator=(const AdhocCommandsHandler&) = delete; + AdhocCommandsHandler& operator=(AdhocCommandsHandler&&) = delete; + + /** + * Returns the list of available commands. + */ + const std::map<const std::string, const AdhocCommand>& get_commands() const; + /** + * This one can be used to add new commands. + */ + std::map<const std::string, const AdhocCommand>& get_commands(); + /** + * Find the requested command, create a new session or use an existing + * one, and process the request (provide a new form, an error, or a + * result). + * + * Returns a (moved) XmlNode that will be inserted in the iq response. It + * should be a <command/> node containing one or more useful children. If + * it contains an <error/> node, the iq response will have an error type. + * + * Takes a copy of the <command/> node so we can actually edit it and use + * it as our return value. + */ + XmlNode handle_request(const std::string& executor_jid, const std::string& to, XmlNode command_node); + /** + * Remove the session from the list. This is done to avoid filling the + * memory with waiting session (for example due to a client that starts + * multi-steps command but never finishes them). + */ + void remove_session(const std::string& session_id, const std::string& initiator_jid); +private: + /** + * To access basically anything in the gateway. + */ + XmppComponent& xmpp_component; + /** + * The list of all available commands. + */ + std::map<const std::string, const AdhocCommand> commands; + /** + * The list of all currently on-going commands. + * + * Of the form: {{session_id, owner_jid}, session}. + */ + std::map<std::pair<const std::string, const std::string>, AdhocSession> sessions; +}; diff --git a/louloulibs/xmpp/adhoc_session.cpp b/louloulibs/xmpp/adhoc_session.cpp new file mode 100644 index 0000000..dda4bea --- /dev/null +++ b/louloulibs/xmpp/adhoc_session.cpp @@ -0,0 +1,35 @@ +#include <xmpp/adhoc_session.hpp> +#include <xmpp/adhoc_command.hpp> + +#include <assert.h> + +AdhocSession::AdhocSession(const AdhocCommand& command, const std::string& owner_jid, + const std::string& to_jid): + command(command), + owner_jid(owner_jid), + to_jid(to_jid), + current_step(0), + terminated(false) +{ +} + +const AdhocStep& AdhocSession::get_next_step() +{ + assert(this->current_step < this->command.callbacks.size()); + return this->command.callbacks[this->current_step++]; +} + +size_t AdhocSession::remaining_steps() const +{ + return this->command.callbacks.size() - this->current_step; +} + +bool AdhocSession::is_terminated() const +{ + return this->terminated; +} + +void AdhocSession::terminate() +{ + this->terminated = true; +} diff --git a/louloulibs/xmpp/adhoc_session.hpp b/louloulibs/xmpp/adhoc_session.hpp new file mode 100644 index 0000000..0de8d13 --- /dev/null +++ b/louloulibs/xmpp/adhoc_session.hpp @@ -0,0 +1,88 @@ +#pragma once + +#include <xmpp/xmpp_stanza.hpp> + +#include <functional> +#include <string> +#include <map> + +class XmppComponent; + +class AdhocCommand; +class AdhocSession; + +/** + * A function executed as an ad-hoc command step. It takes a <command/> + * XmlNode and modifies it accordingly (inserting for example an <error/> + * node, or a data form…). + */ +using AdhocStep = std::function<void(XmppComponent&, AdhocSession&, XmlNode&)>; + +class AdhocSession +{ +public: + explicit AdhocSession(const AdhocCommand& command, const std::string& owner_jid, + const std::string& to_jid); + ~AdhocSession() = default; + + AdhocSession(const AdhocSession&) = delete; + AdhocSession(AdhocSession&&) = delete; + AdhocSession& operator=(const AdhocSession&) = delete; + AdhocSession& operator=(AdhocSession&&) = delete; + + /** + * Return the function to be executed, found in our AdhocCommand, for the + * current_step. And increment the current_step. + */ + const AdhocStep& get_next_step(); + /** + * Return the number of remaining steps. + */ + size_t remaining_steps() const; + /** + * This may be modified by an AdhocStep, to indicate that this session + * should no longer exist, because we encountered an error, and we can't + * execute any more step of it. + */ + void terminate(); + bool is_terminated() const; + std::string get_target_jid() const + { + return this->to_jid; + } + std::string get_owner_jid() const + { + return this->owner_jid; + } + +private: + /** + * A reference of the command concerned by this session. Used for example + * to get the next step of that command, things like that. + */ + const AdhocCommand& command; + /** + * The full JID of the XMPP user that created this session by executing + * the first step of a command. Only that JID must be allowed to access + * this session. + */ + const std::string& owner_jid; + /** + * The 'to' attribute in the request stanza. This is the target of the current session. + */ + const std::string& to_jid; + /** + * The current step we are at. It starts at zero. It is used to index the + * associated AdhocCommand::callbacks vector. + */ + size_t current_step; + bool terminated; + +public: + /** + * A map to store various things that we may want to remember between two + * steps of the same session. A step can insert any value associated to + * any key in there. + */ + std::map<std::string, std::string> vars; +}; diff --git a/louloulibs/xmpp/body.hpp b/louloulibs/xmpp/body.hpp new file mode 100644 index 0000000..068d1a4 --- /dev/null +++ b/louloulibs/xmpp/body.hpp @@ -0,0 +1,12 @@ +#pragma once + + +namespace Xmpp +{ +// Contains: +// - an XMPP-valid UTF-8 body +// - an XML node representing the XHTML-IM body, or null + using body = std::tuple<const std::string, std::unique_ptr<XmlNode>>; +} + + diff --git a/louloulibs/xmpp/jid.cpp b/louloulibs/xmpp/jid.cpp new file mode 100644 index 0000000..7b62f3e --- /dev/null +++ b/louloulibs/xmpp/jid.cpp @@ -0,0 +1,99 @@ +#include <xmpp/jid.hpp> +#include <algorithm> +#include <cstring> +#include <map> + +#include <louloulibs.h> +#ifdef LIBIDN_FOUND + #include <stringprep.h> +#endif + +#include <logger/logger.hpp> + +Jid::Jid(const std::string& jid) +{ + std::string::size_type slash = jid.find('/'); + if (slash != std::string::npos) + { + this->resource = jid.substr(slash + 1); + } + + std::string::size_type at = jid.find('@'); + if (at != std::string::npos && at < slash) + { + this->local = jid.substr(0, at); + at++; + } + else + at = 0; + + this->domain = jid.substr(at, slash - at); +} + +static constexpr size_t max_jid_part_len = 1023; + +std::string jidprep(const std::string& original) +{ +#ifdef LIBIDN_FOUND + using CacheType = std::map<std::string, std::string>; + static CacheType cache; + std::pair<CacheType::iterator, bool> cached = cache.insert({original, {}}); + if (std::get<1>(cached) == false) + { // Insertion failed: the result is already in the cache, return it + return std::get<0>(cached)->second; + } + + const std::string error_msg("Failed to convert " + original + " into a valid JID:"); + Jid jid(original); + + char local[max_jid_part_len] = {}; + memcpy(local, jid.local.data(), std::min(max_jid_part_len, jid.local.size())); + Stringprep_rc rc = static_cast<Stringprep_rc>(::stringprep(local, max_jid_part_len, + static_cast<Stringprep_profile_flags>(0), stringprep_xmpp_nodeprep)); + if (rc != STRINGPREP_OK) + { + log_error(error_msg + stringprep_strerror(rc)); + return ""; + } + + char domain[max_jid_part_len] = {}; + memcpy(domain, jid.domain.data(), std::min(max_jid_part_len, jid.domain.size())); + rc = static_cast<Stringprep_rc>(::stringprep(domain, max_jid_part_len, + static_cast<Stringprep_profile_flags>(0), stringprep_nameprep)); + if (rc != STRINGPREP_OK) + { + log_error(error_msg + stringprep_strerror(rc)); + return ""; + } + std::replace_if(std::begin(domain), domain + ::strlen(domain), + [](const char c) -> bool + { + return !((c >= 'a' && c <= 'z') || c == '-' || + (c >= '0' && c <= '9') || c == '.'); + }, '-'); + + // If there is no resource, stop here + if (jid.resource.empty()) + { + std::get<0>(cached)->second = std::string(local) + "@" + domain; + return std::get<0>(cached)->second; + } + + // Otherwise, also process the resource part + char resource[max_jid_part_len] = {}; + memcpy(resource, jid.resource.data(), std::min(max_jid_part_len, jid.resource.size())); + rc = static_cast<Stringprep_rc>(::stringprep(resource, max_jid_part_len, + static_cast<Stringprep_profile_flags>(0), stringprep_xmpp_resourceprep)); + if (rc != STRINGPREP_OK) + { + log_error(error_msg + stringprep_strerror(rc)); + return ""; + } + std::get<0>(cached)->second = std::string(local) + "@" + domain + "/" + resource; + return std::get<0>(cached)->second; + +#else + (void)original; + return ""; +#endif +} diff --git a/louloulibs/xmpp/jid.hpp b/louloulibs/xmpp/jid.hpp new file mode 100644 index 0000000..08327ef --- /dev/null +++ b/louloulibs/xmpp/jid.hpp @@ -0,0 +1,44 @@ +#pragma once + + +#include <string> + +/** + * Parse a JID into its different subart + */ +class Jid +{ +public: + explicit Jid(const std::string& jid); + + Jid(const Jid&) = delete; + Jid(Jid&&) = delete; + Jid& operator=(const Jid&) = delete; + Jid& operator=(Jid&&) = delete; + + std::string domain; + std::string local; + std::string resource; + + std::string bare() const + { + return this->local + "@" + this->domain; + } + std::string full() const + { + return this->local + "@" + this->domain + "/" + this->resource; + } +}; + +/** + * Prepare the given UTF-8 string according to the XMPP node stringprep + * identifier profile. This is used to send properly-formed JID to the XMPP + * server. + * + * If the stringprep library is not found, we return an empty string. When + * this function is used, the result must always be checked for an empty + * value, and if this is the case it must not be used as a JID. + */ +std::string jidprep(const std::string& original); + + diff --git a/louloulibs/xmpp/roster.cpp b/louloulibs/xmpp/roster.cpp new file mode 100644 index 0000000..a14a384 --- /dev/null +++ b/louloulibs/xmpp/roster.cpp @@ -0,0 +1,21 @@ +#include <xmpp/roster.hpp> + +RosterItem::RosterItem(const std::string& jid, const std::string& name, + std::vector<std::string>& groups): + jid(jid), + name(name), + groups(groups) +{ +} + +RosterItem::RosterItem(const std::string& jid, const std::string& name): + jid(jid), + name(name), + groups{} +{ +} + +void Roster::clear() +{ + this->items.clear(); +} diff --git a/louloulibs/xmpp/roster.hpp b/louloulibs/xmpp/roster.hpp new file mode 100644 index 0000000..aa1b449 --- /dev/null +++ b/louloulibs/xmpp/roster.hpp @@ -0,0 +1,71 @@ +#pragma once + + +#include <algorithm> +#include <string> +#include <vector> + +class RosterItem +{ +public: + RosterItem(const std::string& jid, const std::string& name, + std::vector<std::string>& groups); + RosterItem(const std::string& jid, const std::string& name); + RosterItem() = default; + ~RosterItem() = default; + RosterItem(const RosterItem&) = default; + RosterItem(RosterItem&&) = default; + RosterItem& operator=(const RosterItem&) = default; + RosterItem& operator=(RosterItem&&) = default; + + std::string jid; + std::string name; + std::vector<std::string> groups; + +private: +}; + +/** + * Keep track of the last known stat of a JID's roster + */ +class Roster +{ +public: + Roster() = default; + ~Roster() = default; + + void clear(); + + template <typename... ArgsType> + RosterItem* add_item(ArgsType&&... args) + { + this->items.emplace_back(std::forward<ArgsType>(args)...); + auto it = this->items.end() - 1; + return &*it; + } + RosterItem* get_item(const std::string& jid) + { + auto it = std::find_if(this->items.begin(), this->items.end(), + [this, &jid](const auto& item) + { + return item.jid == jid; + }); + if (it != this->items.end()) + return &*it; + return nullptr; + } + const std::vector<RosterItem>& get_items() const + { + return this->items; + } + +private: + std::vector<RosterItem> items; + + Roster(const Roster&) = delete; + Roster(Roster&&) = delete; + Roster& operator=(const Roster&) = delete; + Roster& operator=(Roster&&) = delete; +}; + + diff --git a/louloulibs/xmpp/xmpp_component.cpp b/louloulibs/xmpp/xmpp_component.cpp new file mode 100644 index 0000000..e87cdf7 --- /dev/null +++ b/louloulibs/xmpp/xmpp_component.cpp @@ -0,0 +1,664 @@ +#include <utils/timed_events.hpp> +#include <utils/scopeguard.hpp> +#include <utils/tolower.hpp> +#include <logger/logger.hpp> + +#include <xmpp/xmpp_component.hpp> +#include <config/config.hpp> +#include <xmpp/jid.hpp> +#include <utils/sha1.hpp> + +#include <stdexcept> +#include <iostream> +#include <set> + +#include <stdio.h> + +#include <uuid.h> + +#include <louloulibs.h> +#ifdef SYSTEMD_FOUND +# include <systemd/sd-daemon.h> +#endif + +using namespace std::string_literals; + +static std::set<std::string> kickable_errors{ + "gone", + "internal-server-error", + "item-not-found", + "jid-malformed", + "recipient-unavailable", + "redirect", + "remote-server-not-found", + "remote-server-timeout", + "service-unavailable", + "malformed-error" + }; + +XmppComponent::XmppComponent(std::shared_ptr<Poller> poller, const std::string& hostname, const std::string& secret): + TCPSocketHandler(poller), + ever_auth(false), + first_connection_try(true), + secret(secret), + authenticated(false), + doc_open(false), + served_hostname(hostname), + stanza_handlers{}, + adhoc_commands_handler(*this) +{ + this->parser.add_stream_open_callback(std::bind(&XmppComponent::on_remote_stream_open, this, + std::placeholders::_1)); + this->parser.add_stanza_callback(std::bind(&XmppComponent::on_stanza, this, + std::placeholders::_1)); + this->parser.add_stream_close_callback(std::bind(&XmppComponent::on_remote_stream_close, this, + std::placeholders::_1)); + this->stanza_handlers.emplace("handshake", + std::bind(&XmppComponent::handle_handshake, this,std::placeholders::_1)); + this->stanza_handlers.emplace("error", + std::bind(&XmppComponent::handle_error, this,std::placeholders::_1)); +} + +void XmppComponent::start() +{ + this->connect(Config::get("xmpp_server_ip", "127.0.0.1"), Config::get("port", "5347"), false); +} + +bool XmppComponent::is_document_open() const +{ + return this->doc_open; +} + +void XmppComponent::send_stanza(const Stanza& stanza) +{ + std::string str = stanza.to_string(); + log_debug("XMPP SENDING: ", str); + this->send_data(std::move(str)); +} + +void XmppComponent::on_connection_failed(const std::string& reason) +{ + this->first_connection_try = false; + log_error("Failed to connect to the XMPP server: ", reason); +#ifdef SYSTEMD_FOUND + sd_notifyf(0, "STATUS=Failed to connect to the XMPP server: %s", reason.data()); +#endif +} + +void XmppComponent::on_connected() +{ + log_info("connected to XMPP server"); + this->first_connection_try = true; + auto data = "<stream:stream to='"s + this->served_hostname + \ + "' xmlns:stream='http://etherx.jabber.org/streams' xmlns='" COMPONENT_NS "'>"; + log_debug("XMPP SENDING: ", data); + this->send_data(std::move(data)); + this->doc_open = true; + // We may have some pending data to send: this happens when we try to send + // some data before we are actually connected. We send that data right now, if any + this->send_pending_data(); +} + +void XmppComponent::on_connection_close(const std::string& error) +{ + if (error.empty()) + log_info("XMPP server closed connection"); + else + log_info("XMPP server closed connection: ", error); +} + +void XmppComponent::parse_in_buffer(const size_t size) +{ + if (!this->in_buf.empty()) + { // This may happen if the parser could not allocate enough space for + // us. We try to feed it the data that was read into our in_buf + // instead. If this fails again we are in trouble. + this->parser.feed(this->in_buf.data(), this->in_buf.size(), false); + this->in_buf.clear(); + } + else + { // Just tell the parser to parse the data that was placed into the + // buffer it provided to us with GetBuffer + this->parser.parse(size, false); + } +} + +void XmppComponent::on_remote_stream_open(const XmlNode& node) +{ + log_debug("XMPP RECEIVING: ", node.to_string()); + this->stream_id = node.get_tag("id"); + if (this->stream_id.empty()) + { + log_error("Error: no attribute 'id' found"); + this->send_stream_error("bad-format", "missing 'id' attribute"); + this->close_document(); + return ; + } + + // Try to authenticate + char digest[HASH_LENGTH * 2 + 1]; + sha1nfo sha1; + sha1_init(&sha1); + sha1_write(&sha1, this->stream_id.data(), this->stream_id.size()); + sha1_write(&sha1, this->secret.data(), this->secret.size()); + const uint8_t* result = sha1_result(&sha1); + for (int i=0; i < HASH_LENGTH; i++) + sprintf(digest + (i*2), "%02x", result[i]); + digest[HASH_LENGTH * 2] = '\0'; + + auto data = "<handshake xmlns='" COMPONENT_NS "'>"s + digest + "</handshake>"; + log_debug("XMPP SENDING: ", data); + this->send_data(std::move(data)); +} + +void XmppComponent::on_remote_stream_close(const XmlNode& node) +{ + log_debug("XMPP RECEIVING: ", node.to_string()); + this->doc_open = false; +} + +void XmppComponent::reset() +{ + this->parser.reset(); +} + +void XmppComponent::on_stanza(const Stanza& stanza) +{ + log_debug("XMPP RECEIVING: ", stanza.to_string()); + std::function<void(const Stanza&)> handler; + try + { + handler = this->stanza_handlers.at(stanza.get_name()); + } + catch (const std::out_of_range& exception) + { + log_warning("No handler for stanza of type ", stanza.get_name()); + return; + } + handler(stanza); +} + +void XmppComponent::send_stream_error(const std::string& name, const std::string& explanation) +{ + XmlNode node("stream:error", nullptr); + XmlNode error(name, nullptr); + error["xmlns"] = STREAM_NS; + if (!explanation.empty()) + error.set_inner(explanation); + node.add_child(std::move(error)); + this->send_stanza(node); +} + +void XmppComponent::send_stanza_error(const std::string& kind, const std::string& to, const std::string& from, + const std::string& id, const std::string& error_type, + const std::string& defined_condition, const std::string& text, + const bool fulljid) +{ + Stanza node(kind); + if (!to.empty()) + node["to"] = to; + if (!from.empty()) + { + if (fulljid) + node["from"] = from; + else + node["from"] = from + "@" + this->served_hostname; + } + if (!id.empty()) + node["id"] = id; + node["type"] = "error"; + XmlNode error("error"); + error["type"] = error_type; + XmlNode inner_error(defined_condition); + inner_error["xmlns"] = STANZA_NS; + error.add_child(std::move(inner_error)); + if (!text.empty()) + { + XmlNode text_node("text"); + text_node["xmlns"] = STANZA_NS; + text_node.set_inner(text); + error.add_child(std::move(text_node)); + } + node.add_child(std::move(error)); + this->send_stanza(node); +} + +void XmppComponent::close_document() +{ + log_debug("XMPP SENDING: </stream:stream>"); + this->send_data("</stream:stream>"); + this->doc_open = false; +} + +void XmppComponent::handle_handshake(const Stanza& stanza) +{ + (void)stanza; + this->authenticated = true; + this->ever_auth = true; + log_info("Authenticated with the XMPP server"); +#ifdef SYSTEMD_FOUND + sd_notify(0, "READY=1"); + // Install an event that sends a keepalive to systemd. If biboumi crashes + // or hangs for too long, systemd will restart it. + uint64_t usec; + if (sd_watchdog_enabled(0, &usec) > 0) + { + TimedEventsManager::instance().add_event(TimedEvent( + std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::microseconds(usec / 2)), + []() { sd_notify(0, "WATCHDOG=1"); })); + } +#endif + this->after_handshake(); +} + +void XmppComponent::handle_error(const Stanza& stanza) +{ + const XmlNode* text = stanza.get_child("text", STREAMS_NS); + std::string error_message("Unspecified error"); + if (text) + error_message = text->get_inner(); + log_error("Stream error received from the XMPP server: ", error_message); +#ifdef SYSTEMD_FOUND + if (!this->ever_auth) + sd_notifyf(0, "STATUS=Failed to authenticate to the XMPP server: %s", error_message.data()); +#endif + +} + +void* XmppComponent::get_receive_buffer(const size_t size) const +{ + return this->parser.get_buffer(size); +} + +void XmppComponent::send_message(const std::string& from, Xmpp::body&& body, const std::string& to, const std::string& type, const bool fulljid) +{ + XmlNode node("message"); + node["to"] = to; + if (fulljid) + node["from"] = from; + else + node["from"] = from + "@" + this->served_hostname; + if (!type.empty()) + node["type"] = type; + XmlNode body_node("body"); + body_node.set_inner(std::get<0>(body)); + node.add_child(std::move(body_node)); + if (std::get<1>(body)) + { + XmlNode html("html"); + html["xmlns"] = XHTMLIM_NS; + // Pass the ownership of the pointer to this xmlnode + html.add_child(std::move(std::get<1>(body))); + node.add_child(std::move(html)); + } + this->send_stanza(node); +} + +void XmppComponent::send_user_join(const std::string& from, + const std::string& nick, + const std::string& realjid, + const std::string& affiliation, + const std::string& role, + const std::string& to, + const bool self) +{ + XmlNode node("presence"); + node["to"] = to; + node["from"] = from + "@" + this->served_hostname + "/" + nick; + + XmlNode x("x"); + x["xmlns"] = MUC_USER_NS; + + XmlNode item("item"); + if (!affiliation.empty()) + item["affiliation"] = affiliation; + if (!role.empty()) + item["role"] = role; + if (!realjid.empty()) + { + const std::string preped_jid = jidprep(realjid); + if (!preped_jid.empty()) + item["jid"] = preped_jid; + } + x.add_child(std::move(item)); + + if (self) + { + XmlNode status("status"); + status["code"] = "110"; + x.add_child(std::move(status)); + } + node.add_child(std::move(x)); + this->send_stanza(node); +} + +void XmppComponent::send_invalid_room_error(const std::string& muc_name, + const std::string& nick, + const std::string& to) +{ + Stanza presence("presence"); + if (!muc_name.empty()) + presence["from"] = muc_name + "@" + this->served_hostname + "/" + nick; + else + presence["from"] = this->served_hostname; + presence["to"] = to; + presence["type"] = "error"; + XmlNode x("x"); + x["xmlns"] = MUC_NS; + presence.add_child(std::move(x)); + XmlNode error("error"); + error["by"] = muc_name + "@" + this->served_hostname; + error["type"] = "cancel"; + XmlNode item_not_found("item-not-found"); + item_not_found["xmlns"] = STANZA_NS; + error.add_child(std::move(item_not_found)); + XmlNode text("text"); + text["xmlns"] = STANZA_NS; + text["xml:lang"] = "en"; + text.set_inner(muc_name + + " is not a valid IRC channel name. A correct room jid is of the form: #<chan>%<server>@" + + this->served_hostname); + error.add_child(std::move(text)); + presence.add_child(std::move(error)); + this->send_stanza(presence); +} + +void XmppComponent::send_invalid_user_error(const std::string& user_name, const std::string& to) +{ + Stanza message("message"); + message["from"] = user_name + "@" + this->served_hostname; + message["to"] = to; + message["type"] = "error"; + XmlNode x("x"); + x["xmlns"] = MUC_NS; + message.add_child(std::move(x)); + XmlNode error("error"); + error["type"] = "cancel"; + XmlNode item_not_found("item-not-found"); + item_not_found["xmlns"] = STANZA_NS; + error.add_child(std::move(item_not_found)); + XmlNode text("text"); + text["xmlns"] = STANZA_NS; + text["xml:lang"] = "en"; + text.set_inner(user_name + + " is not a valid IRC user name. A correct user jid is of the form: <nick>!<server>@" + + this->served_hostname); + error.add_child(std::move(text)); + message.add_child(std::move(error)); + this->send_stanza(message); +} + +void XmppComponent::send_topic(const std::string& from, Xmpp::body&& topic, const std::string& to, const std::string& who) +{ + XmlNode message("message"); + message["to"] = to; + if (who.empty()) + message["from"] = from + "@" + this->served_hostname; + else + message["from"] = from + "@" + this->served_hostname + "/" + who; + message["type"] = "groupchat"; + XmlNode subject("subject"); + subject.set_inner(std::get<0>(topic)); + message.add_child(std::move(subject)); + this->send_stanza(message); +} + +void XmppComponent::send_muc_message(const std::string& muc_name, const std::string& nick, Xmpp::body&& xmpp_body, const std::string& jid_to) +{ + Stanza message("message"); + message["to"] = jid_to; + if (!nick.empty()) + message["from"] = muc_name + "@" + this->served_hostname + "/" + nick; + else // Message from the room itself + message["from"] = muc_name + "@" + this->served_hostname; + message["type"] = "groupchat"; + XmlNode body("body"); + body.set_inner(std::get<0>(xmpp_body)); + message.add_child(std::move(body)); + if (std::get<1>(xmpp_body)) + { + XmlNode html("html"); + html["xmlns"] = XHTMLIM_NS; + // Pass the ownership of the pointer to this xmlnode + html.add_child(std::move(std::get<1>(xmpp_body))); + message.add_child(std::move(html)); + } + this->send_stanza(message); +} + +void XmppComponent::send_muc_leave(const std::string& muc_name, std::string&& nick, Xmpp::body&& message, const std::string& jid_to, const bool self) +{ + Stanza presence("presence"); + presence["to"] = jid_to; + presence["from"] = muc_name + "@" + this->served_hostname + "/" + nick; + presence["type"] = "unavailable"; + const std::string message_str = std::get<0>(message); + XmlNode x("x"); + x["xmlns"] = MUC_USER_NS; + if (self) + { + XmlNode status("status"); + status["code"] = "110"; + x.add_child(std::move(status)); + } + presence.add_child(std::move(x)); + if (!message_str.empty()) + { + XmlNode status("status"); + status.set_inner(message_str); + presence.add_child(std::move(status)); + } + this->send_stanza(presence); +} + +void XmppComponent::send_nick_change(const std::string& muc_name, + const std::string& old_nick, + const std::string& new_nick, + const std::string& affiliation, + const std::string& role, + const std::string& jid_to, + const bool self) +{ + Stanza presence("presence"); + presence["to"] = jid_to; + presence["from"] = muc_name + "@" + this->served_hostname + "/" + old_nick; + presence["type"] = "unavailable"; + XmlNode x("x"); + x["xmlns"] = MUC_USER_NS; + XmlNode item("item"); + item["nick"] = new_nick; + x.add_child(std::move(item)); + XmlNode status("status"); + status["code"] = "303"; + x.add_child(std::move(status)); + if (self) + { + XmlNode status2("status"); + status2["code"] = "110"; + x.add_child(std::move(status2)); + } + presence.add_child(std::move(x)); + this->send_stanza(presence); + + this->send_user_join(muc_name, new_nick, "", affiliation, role, jid_to, self); +} + +void XmppComponent::kick_user(const std::string& muc_name, + const std::string& target, + const std::string& txt, + const std::string& author, + const std::string& jid_to) +{ + Stanza presence("presence"); + presence["from"] = muc_name + "@" + this->served_hostname + "/" + target; + presence["to"] = jid_to; + presence["type"] = "unavailable"; + XmlNode x("x"); + x["xmlns"] = MUC_USER_NS; + XmlNode item("item"); + item["affiliation"] = "none"; + item["role"] = "none"; + XmlNode actor("actor"); + actor["nick"] = author; + actor["jid"] = author; // backward compatibility with old clients + item.add_child(std::move(actor)); + XmlNode reason("reason"); + reason.set_inner(txt); + item.add_child(std::move(reason)); + x.add_child(std::move(item)); + XmlNode status("status"); + status["code"] = "307"; + x.add_child(std::move(status)); + presence.add_child(std::move(x)); + this->send_stanza(presence); +} + +void XmppComponent::send_presence_error(const std::string& muc_name, + const std::string& nickname, + const std::string& jid_to, + const std::string& type, + const std::string& condition, + const std::string& error_code, + const std::string& /* text */) +{ + Stanza presence("presence"); + presence["from"] = muc_name + "@" + this->served_hostname + "/" + nickname; + presence["to"] = jid_to; + presence["type"] = "error"; + XmlNode x("x"); + x["xmlns"] = MUC_NS; + presence.add_child(std::move(x)); + XmlNode error("error"); + error["by"] = muc_name + "@" + this->served_hostname; + error["type"] = type; + if (!error_code.empty()) + error["code"] = error_code; + XmlNode subnode(condition); + subnode["xmlns"] = STANZA_NS; + error.add_child(std::move(subnode)); + presence.add_child(std::move(error)); + this->send_stanza(presence); +} + +void XmppComponent::send_affiliation_role_change(const std::string& muc_name, + const std::string& target, + const std::string& affiliation, + const std::string& role, + const std::string& jid_to) +{ + Stanza presence("presence"); + presence["from"] = muc_name + "@" + this->served_hostname + "/" + target; + presence["to"] = jid_to; + XmlNode x("x"); + x["xmlns"] = MUC_USER_NS; + XmlNode item("item"); + item["affiliation"] = affiliation; + item["role"] = role; + x.add_child(std::move(item)); + presence.add_child(std::move(x)); + this->send_stanza(presence); +} + +void XmppComponent::send_version(const std::string& id, const std::string& jid_to, const std::string& jid_from, + const std::string& version) +{ + Stanza iq("iq"); + iq["type"] = "result"; + iq["id"] = id; + iq["to"] = jid_to; + iq["from"] = jid_from; + XmlNode query("query"); + query["xmlns"] = VERSION_NS; + if (version.empty()) + { + XmlNode name("name"); + name.set_inner("biboumi"); + query.add_child(std::move(name)); + XmlNode version("version"); + version.set_inner(SOFTWARE_VERSION); + query.add_child(std::move(version)); + XmlNode os("os"); + os.set_inner(SYSTEM_NAME); + query.add_child(std::move(os)); + } + else + { + XmlNode name("name"); + name.set_inner(version); + query.add_child(std::move(name)); + } + iq.add_child(std::move(query)); + this->send_stanza(iq); +} + +void XmppComponent::send_adhoc_commands_list(const std::string& id, const std::string& requester_jid, + const std::string& from_jid, + const bool with_admin_only, const AdhocCommandsHandler& adhoc_handler) +{ + Stanza iq("iq"); + iq["type"] = "result"; + iq["id"] = id; + iq["to"] = requester_jid; + iq["from"] = from_jid; + XmlNode query("query"); + query["xmlns"] = DISCO_ITEMS_NS; + query["node"] = ADHOC_NS; + for (const auto& kv: adhoc_handler.get_commands()) + { + if (kv.second.is_admin_only() && !with_admin_only) + continue; + XmlNode item("item"); + item["jid"] = from_jid; + item["node"] = kv.first; + item["name"] = kv.second.name; + query.add_child(std::move(item)); + } + iq.add_child(std::move(query)); + this->send_stanza(iq); +} + +void XmppComponent::send_iq_version_request(const std::string& from, + const std::string& jid_to) +{ + Stanza iq("iq"); + iq["type"] = "get"; + iq["id"] = "version_"s + XmppComponent::next_id(); + iq["from"] = from + "@" + this->served_hostname; + iq["to"] = jid_to; + XmlNode query("query"); + query["xmlns"] = VERSION_NS; + iq.add_child(std::move(query)); + this->send_stanza(iq); +} + +void XmppComponent::send_iq_result_full_jid(const std::string& id, const std::string& to_jid, const std::string& from_full_jid) +{ + Stanza iq("iq"); + iq["from"] = from_full_jid; + iq["to"] = to_jid; + iq["id"] = id; + iq["type"] = "result"; + this->send_stanza(iq); +} + +void XmppComponent::send_iq_result(const std::string& id, const std::string& to_jid, const std::string& from_local_part) +{ + Stanza iq("iq"); + if (!from_local_part.empty()) + iq["from"] = from_local_part + "@" + this->served_hostname; + else + iq["from"] = this->served_hostname; + iq["to"] = to_jid; + iq["id"] = id; + iq["type"] = "result"; + this->send_stanza(iq); +} + +std::string XmppComponent::next_id() +{ + char uuid_str[37]; + uuid_t uuid; + uuid_generate(uuid); + uuid_unparse(uuid, uuid_str); + return uuid_str; +} diff --git a/louloulibs/xmpp/xmpp_component.hpp b/louloulibs/xmpp/xmpp_component.hpp new file mode 100644 index 0000000..5fc6d2e --- /dev/null +++ b/louloulibs/xmpp/xmpp_component.hpp @@ -0,0 +1,243 @@ +#pragma once + + +#include <xmpp/adhoc_commands_handler.hpp> +#include <network/tcp_socket_handler.hpp> +#include <xmpp/xmpp_parser.hpp> +#include <xmpp/body.hpp> + +#include <unordered_map> +#include <memory> +#include <string> +#include <map> + +#define STREAM_NS "http://etherx.jabber.org/streams" +#define COMPONENT_NS "jabber:component:accept" +#define MUC_NS "http://jabber.org/protocol/muc" +#define MUC_USER_NS MUC_NS"#user" +#define MUC_ADMIN_NS MUC_NS"#admin" +#define DISCO_NS "http://jabber.org/protocol/disco" +#define DISCO_ITEMS_NS DISCO_NS"#items" +#define DISCO_INFO_NS DISCO_NS"#info" +#define XHTMLIM_NS "http://jabber.org/protocol/xhtml-im" +#define STANZA_NS "urn:ietf:params:xml:ns:xmpp-stanzas" +#define STREAMS_NS "urn:ietf:params:xml:ns:xmpp-streams" +#define VERSION_NS "jabber:iq:version" +#define ADHOC_NS "http://jabber.org/protocol/commands" +#define PING_NS "urn:xmpp:ping" + +/** + * An XMPP component, communicating with an XMPP server using the protocole + * described in XEP-0114: Jabber Component Protocol + * + * TODO: implement XEP-0225: Component Connections + */ +class XmppComponent: public TCPSocketHandler +{ +public: + explicit XmppComponent(std::shared_ptr<Poller> poller, const std::string& hostname, const std::string& secret); + virtual ~XmppComponent() = default; + + XmppComponent(const XmppComponent&) = delete; + XmppComponent(XmppComponent&&) = delete; + XmppComponent& operator=(const XmppComponent&) = delete; + XmppComponent& operator=(XmppComponent&&) = delete; + + void on_connection_failed(const std::string& reason) override final; + void on_connected() override final; + void on_connection_close(const std::string& error) override final; + void parse_in_buffer(const size_t size) override final; + + /** + * Returns a unique id, to be used in the 'id' element of our iq stanzas. + */ + static std::string next_id(); + bool is_document_open() const; + /** + * Connect to the XMPP server. + */ + void start(); + /** + * Reset the component so we can use the component on a new XMPP stream + */ + void reset(); + /** + * Serialize the stanza and add it to the out_buf to be sent to the + * server. + */ + void send_stanza(const Stanza& stanza); + /** + * Handle the opening of the remote stream + */ + void on_remote_stream_open(const XmlNode& node); + /** + * Handle the closing of the remote stream + */ + void on_remote_stream_close(const XmlNode& node); + /** + * Handle received stanzas + */ + void on_stanza(const Stanza& stanza); + /** + * Send an error stanza. Message being the name of the element inside the + * stanza, and explanation being a short human-readable sentence + * describing the error. + */ + void send_stream_error(const std::string& message, const std::string& explanation); + /** + * Send error stanza, described in http://xmpp.org/rfcs/rfc6120.html#stanzas-error + */ + void send_stanza_error(const std::string& kind, const std::string& to, const std::string& from, + const std::string& id, const std::string& error_type, + const std::string& defined_condition, const std::string& text, + const bool fulljid=true); + /** + * Send the closing signal for our document (not closing the connection though). + */ + void close_document(); + /** + * Send a message from from@served_hostname, with the given body + * + * If fulljid is false, the provided 'from' doesn't contain the + * server-part of the JID and must be added. + */ + void send_message(const std::string& from, Xmpp::body&& body, + const std::string& to, const std::string& type, + const bool fulljid=false); + /** + * Send a join from a new participant + */ + void send_user_join(const std::string& from, + const std::string& nick, + const std::string& realjid, + const std::string& affiliation, + const std::string& role, + const std::string& to, + const bool self); + /** + * Send an error to indicate that the user tried to join an invalid room + */ + void send_invalid_room_error(const std::string& muc_jid, + const std::string& nick, + const std::string& to); + /** + * Send an error to indicate that the user tried to send a message to an + * invalid user. + */ + void send_invalid_user_error(const std::string& user_name, + const std::string& to); + /** + * Send the MUC topic to the user + */ + void send_topic(const std::string& from, Xmpp::body&& xmpp_topic, const std::string& to, const std::string& who); + /** + * Send a (non-private) message to the MUC + */ + void send_muc_message(const std::string& muc_name, const std::string& nick, Xmpp::body&& body, const std::string& jid_to); + /** + * Send an unavailable presence for this nick + */ + void send_muc_leave(const std::string& muc_name, std::string&& nick, Xmpp::body&& message, const std::string& jid_to, const bool self); + /** + * Indicate that a participant changed his nick + */ + void send_nick_change(const std::string& muc_name, + const std::string& old_nick, + const std::string& new_nick, + const std::string& affiliation, + const std::string& role, + const std::string& jid_to, + const bool self); + /** + * An user is kicked from a room + */ + void kick_user(const std::string& muc_name, + const std::string& target, + const std::string& reason, + const std::string& author, + const std::string& jid_to); + /** + * Send a generic presence error + */ + void send_presence_error(const std::string& muc_name, + const std::string& nickname, + const std::string& jid_to, + const std::string& type, + const std::string& condition, + const std::string& error_code, + const std::string& text); + /** + * Send a presence from the MUC indicating a change in the role and/or + * affiliation of a participant + */ + void send_affiliation_role_change(const std::string& muc_name, + const std::string& target, + const std::string& affiliation, + const std::string& role, + const std::string& jid_to); + /** + * Send a result IQ with the gateway disco informations. + */ + void send_self_disco_info(const std::string& id, const std::string& jid_to); + /** + * Send a result IQ with the given version, or the gateway version if the + * passed string is empty. + */ + void send_version(const std::string& id, const std::string& jid_to, const std::string& jid_from, + const std::string& version=""); + /** + * Send the list of all available ad-hoc commands to that JID. The list is + * different depending on what JID made the request. + */ + void send_adhoc_commands_list(const std::string& id, const std::string& requester_jid, const std::string& from_jid, + const bool with_admin_only, const AdhocCommandsHandler& adhoc_handler); + /** + * Send an iq version request + */ + void send_iq_version_request(const std::string& from, + const std::string& jid_to); + /** + * Send an empty iq of type result + */ + void send_iq_result(const std::string& id, const std::string& to_jid, const std::string& from); + void send_iq_result_full_jid(const std::string& id, const std::string& to_jid, + const std::string& from_full_jid); + + void handle_handshake(const Stanza& stanza); + void handle_error(const Stanza& stanza); + + virtual void after_handshake() {} + + /** + * Whether or not we ever succeeded our authentication to the XMPP server + */ + bool ever_auth; + /** + * Whether or not this is the first consecutive try on connecting to the + * XMPP server. We use this to delay the connection attempt for a few + * seconds, if it is not the first try. + */ + bool first_connection_try; + +private: + /** + * Return a buffer provided by the XML parser, to read data directly into + * it, and avoiding some unnecessary copy. + */ + void* get_receive_buffer(const size_t size) const override final; + XmppParser parser; + std::string stream_id; + std::string secret; + bool authenticated; + /** + * Whether or not OUR XMPP document is open + */ + bool doc_open; +protected: + std::string served_hostname; + + std::unordered_map<std::string, std::function<void(const Stanza&)>> stanza_handlers; + AdhocCommandsHandler adhoc_commands_handler; +}; + + diff --git a/louloulibs/xmpp/xmpp_parser.cpp b/louloulibs/xmpp/xmpp_parser.cpp new file mode 100644 index 0000000..0488be9 --- /dev/null +++ b/louloulibs/xmpp/xmpp_parser.cpp @@ -0,0 +1,172 @@ +#include <xmpp/xmpp_parser.hpp> +#include <xmpp/xmpp_stanza.hpp> + +#include <logger/logger.hpp> + +/** + * Expat handlers. Called by the Expat library, never by ourself. + * They just forward the call to the XmppParser corresponding methods. + */ + +static void start_element_handler(void* user_data, const XML_Char* name, const XML_Char** atts) +{ + static_cast<XmppParser*>(user_data)->start_element(name, atts); +} + +static void end_element_handler(void* user_data, const XML_Char* name) +{ + static_cast<XmppParser*>(user_data)->end_element(name); +} + +static void character_data_handler(void *user_data, const XML_Char *s, int len) +{ + static_cast<XmppParser*>(user_data)->char_data(s, len); +} + +/** + * XmppParser class + */ + +XmppParser::XmppParser(): + level(0), + current_node(nullptr), + root(nullptr) +{ + this->init_xml_parser(); +} + +void XmppParser::init_xml_parser() +{ + // Create the expat parser + this->parser = XML_ParserCreateNS("UTF-8", ':'); + XML_SetUserData(this->parser, static_cast<void*>(this)); + + // Install Expat handlers + XML_SetElementHandler(this->parser, &start_element_handler, &end_element_handler); + XML_SetCharacterDataHandler(this->parser, &character_data_handler); +} + +XmppParser::~XmppParser() +{ + XML_ParserFree(this->parser); +} + +int XmppParser::feed(const char* data, const int len, const bool is_final) +{ + int res = XML_Parse(this->parser, data, len, is_final); + if (res == XML_STATUS_ERROR && + (XML_GetErrorCode(this->parser) != XML_ERROR_FINISHED)) + log_error("Xml_Parse encountered an error: ", + XML_ErrorString(XML_GetErrorCode(this->parser))); + return res; +} + +int XmppParser::parse(const int len, const bool is_final) +{ + int res = XML_ParseBuffer(this->parser, len, is_final); + if (res == XML_STATUS_ERROR) + log_error("Xml_Parsebuffer encountered an error: ", + XML_ErrorString(XML_GetErrorCode(this->parser))); + return res; +} + +void XmppParser::reset() +{ + XML_ParserFree(this->parser); + this->init_xml_parser(); + this->current_node = nullptr; + this->root.reset(nullptr); + this->level = 0; +} + +void* XmppParser::get_buffer(const size_t size) const +{ + return XML_GetBuffer(this->parser, static_cast<int>(size)); +} + +void XmppParser::start_element(const XML_Char* name, const XML_Char** attribute) +{ + this->level++; + + auto new_node = std::make_unique<XmlNode>(name, this->current_node); + auto new_node_ptr = new_node.get(); + if (this->current_node) + this->current_node->add_child(std::move(new_node)); + else + this->root = std::move(new_node); + this->current_node = new_node_ptr; + for (size_t i = 0; attribute[i]; i += 2) + this->current_node->set_attribute(attribute[i], attribute[i+1]); + if (this->level == 1) + this->stream_open_event(*this->current_node); +} + +void XmppParser::end_element(const XML_Char*) +{ + this->level--; + if (this->level == 0) + { // End of the whole stream + this->stream_close_event(*this->current_node); + this->current_node = nullptr; + this->root.reset(); + } + else + { + auto parent = this->current_node->get_parent(); + if (this->level == 1) + { // End of a stanza + this->stanza_event(*this->current_node); + // Note: deleting all the children of our parent deletes ourself, + // so current_node is an invalid pointer after this line + parent->delete_all_children(); + } + this->current_node = parent; + } +} + +void XmppParser::char_data(const XML_Char* data, const size_t len) +{ + if (this->current_node->has_children()) + this->current_node->get_last_child()->add_to_tail({data, len}); + else + this->current_node->add_to_inner({data, len}); +} + +void XmppParser::stanza_event(const Stanza& stanza) const +{ + for (const auto& callback: this->stanza_callbacks) + { + try { + callback(stanza); + } catch (const std::exception& e) { + log_error("Unhandled exception: ", e.what()); + } + } +} + +void XmppParser::stream_open_event(const XmlNode& node) const +{ + for (const auto& callback: this->stream_open_callbacks) + callback(node); +} + +void XmppParser::stream_close_event(const XmlNode& node) const +{ + for (const auto& callback: this->stream_close_callbacks) + callback(node); +} + +void XmppParser::add_stanza_callback(std::function<void(const Stanza&)>&& callback) +{ + this->stanza_callbacks.emplace_back(std::move(callback)); +} + +void XmppParser::add_stream_open_callback(std::function<void(const XmlNode&)>&& callback) +{ + this->stream_open_callbacks.emplace_back(std::move(callback)); +} + +void XmppParser::add_stream_close_callback(std::function<void(const XmlNode&)>&& callback) +{ + this->stream_close_callbacks.emplace_back(std::move(callback)); +} diff --git a/louloulibs/xmpp/xmpp_parser.hpp b/louloulibs/xmpp/xmpp_parser.hpp new file mode 100644 index 0000000..9d67228 --- /dev/null +++ b/louloulibs/xmpp/xmpp_parser.hpp @@ -0,0 +1,133 @@ +#pragma once + + +#include <xmpp/xmpp_stanza.hpp> + +#include <functional> + +#include <expat.h> + +/** + * A SAX XML parser that builds XML nodes and spawns events when a complete + * stanza is received (an element of level 2), or when the document is + * opened/closed (an element of level 1) + * + * After a stanza_event has been spawned, we delete the whole stanza. This + * means that even with a very long document (in XMPP the document is + * potentially infinite), the memory is never exhausted as long as each + * stanza is reasonnably short. + * + * The element names generated by expat contain the namespace of the + * element, a colon (':') and then the actual name of the element. To get + * an element "x" with a namespace of "http://jabber.org/protocol/muc", you + * just look for an XmlNode named "http://jabber.org/protocol/muc:x" + * + * TODO: enforce the size-limit for the stanza (limit the number of childs + * it can contain). For example forbid the parser going further than level + * 20 (arbitrary number here), and each XML node to have more than 15 childs + * (arbitrary number again). + */ +class XmppParser +{ +public: + explicit XmppParser(); + ~XmppParser(); + XmppParser(const XmppParser&) = delete; + XmppParser& operator=(const XmppParser&) = delete; + XmppParser(XmppParser&&) = delete; + XmppParser& operator=(XmppParser&&) = delete; + +public: + /** + * Feed the parser with some XML data + */ + int feed(const char* data, const int len, const bool is_final); + /** + * Parse the data placed in the parser buffer + */ + int parse(const int size, const bool is_final); + /** + * Reset the parser, so it can be used from scratch afterward + */ + void reset(); + /** + * Get a buffer provided by the xml parser. + */ + void* get_buffer(const size_t size) const; + /** + * Add one callback for the various events that this parser can spawn. + */ + void add_stanza_callback(std::function<void(const Stanza&)>&& callback); + void add_stream_open_callback(std::function<void(const XmlNode&)>&& callback); + void add_stream_close_callback(std::function<void(const XmlNode&)>&& callback); + + /** + * Called when a new XML element has been opened. We instanciate a new + * XmlNode and set it as our current node. The parent of this new node is + * the previous "current" node. We have all the element's attributes in + * this event. + * + * We spawn a stream_event with this node if this is a level-1 element. + */ + void start_element(const XML_Char* name, const XML_Char** attribute); + /** + * Called when an XML element has been closed. We close the current_node, + * set our current_node as the parent of the current_node, and if that was + * a level-2 element we spawn a stanza_event with this node. + * + * And we then delete the stanza (and everything under it, its children, + * attribute, etc). + */ + void end_element(const XML_Char* name); + /** + * Some inner or tail data has been parsed + */ + void char_data(const XML_Char* data, const size_t len); + /** + * Calls all the stanza_callbacks one by one. + */ + void stanza_event(const Stanza& stanza) const; + /** + * Calls all the stream_open_callbacks one by one. Note: the passed node is not + * closed yet. + */ + void stream_open_event(const XmlNode& node) const; + /** + * Calls all the stream_close_callbacks one by one. + */ + void stream_close_event(const XmlNode& node) const; + +private: + /** + * Init the XML parser and install the callbacks + */ + void init_xml_parser(); + + /** + * Expat structure. + */ + XML_Parser parser; + /** + * The current depth in the XML document + */ + size_t level; + /** + * The deepest XML node opened but not yet closed (to which we are adding + * new children, inner or tail) + */ + XmlNode* current_node; + /** + * The root node has no parent, so we keep it here: the XmppParser object + * is its owner. + */ + std::unique_ptr<XmlNode> root; + /** + * A list of callbacks to be called on an *_event, receiving the + * concerned Stanza/XmlNode. + */ + std::vector<std::function<void(const Stanza&)>> stanza_callbacks; + std::vector<std::function<void(const XmlNode&)>> stream_open_callbacks; + std::vector<std::function<void(const XmlNode&)>> stream_close_callbacks; +}; + + diff --git a/louloulibs/xmpp/xmpp_stanza.cpp b/louloulibs/xmpp/xmpp_stanza.cpp new file mode 100644 index 0000000..ac6ce9b --- /dev/null +++ b/louloulibs/xmpp/xmpp_stanza.cpp @@ -0,0 +1,229 @@ +#include <xmpp/xmpp_stanza.hpp> + +#include <utils/encoding.hpp> +#include <utils/split.hpp> + +#include <stdexcept> +#include <iostream> +#include <sstream> + +#include <string.h> + +std::string xml_escape(const std::string& data) +{ + std::string res; + res.reserve(data.size()); + for (size_t pos = 0; pos != data.size(); ++pos) + { + switch(data[pos]) + { + case '&': + res += "&"; + break; + case '<': + res += "<"; + break; + case '>': + res += ">"; + break; + case '\"': + res += """; + break; + case '\'': + res += "'"; + break; + default: + res += data[pos]; + break; + } + } + return res; +} + +std::string sanitize(const std::string& data, const std::string& encoding) +{ + if (utils::is_valid_utf8(data.data())) + return xml_escape(utils::remove_invalid_xml_chars(data)); + else + return xml_escape(utils::remove_invalid_xml_chars(utils::convert_to_utf8(data, encoding.data()))); +} + +XmlNode::XmlNode(const std::string& name, XmlNode* parent): + parent(parent) +{ + // split the namespace and the name + auto n = name.rfind(":"); + if (n == std::string::npos) + this->name = name; + else + { + this->name = name.substr(n+1); + this->attributes["xmlns"] = name.substr(0, n); + } +} + +XmlNode::XmlNode(const std::string& name): + XmlNode(name, nullptr) +{ +} + +void XmlNode::delete_all_children() +{ + this->children.clear(); +} + +void XmlNode::set_attribute(const std::string& name, const std::string& value) +{ + this->attributes[name] = value; +} + +void XmlNode::set_tail(const std::string& data) +{ + this->tail = data; +} + +void XmlNode::add_to_tail(const std::string& data) +{ + this->tail += data; +} + +void XmlNode::set_inner(const std::string& data) +{ + this->inner = data; +} + +void XmlNode::add_to_inner(const std::string& data) +{ + this->inner += data; +} + +std::string XmlNode::get_inner() const +{ + return this->inner; +} + +std::string XmlNode::get_tail() const +{ + return this->tail; +} + +const XmlNode* XmlNode::get_child(const std::string& name, const std::string& xmlns) const +{ + for (const auto& child: this->children) + { + if (child->name == name && child->get_tag("xmlns") == xmlns) + return child.get(); + } + return nullptr; +} + +std::vector<const XmlNode*> XmlNode::get_children(const std::string& name, const std::string& xmlns) const +{ + std::vector<const XmlNode*> res; + for (const auto& child: this->children) + { + if (child->name == name && child->get_tag("xmlns") == xmlns) + res.push_back(child.get()); + } + return res; +} + +XmlNode* XmlNode::add_child(std::unique_ptr<XmlNode> child) +{ + child->parent = this; + auto ret = child.get(); + this->children.push_back(std::move(child)); + return ret; +} + +XmlNode* XmlNode::add_child(XmlNode&& child) +{ + auto new_node = std::make_unique<XmlNode>(std::move(child)); + return this->add_child(std::move(new_node)); +} + +XmlNode* XmlNode::add_child(const XmlNode& child) +{ + auto new_node = std::make_unique<XmlNode>(child); + return this->add_child(std::move(new_node)); +} + +XmlNode* XmlNode::get_last_child() const +{ + return this->children.back().get(); +} + +XmlNode* XmlNode::get_parent() const +{ + return this->parent; +} + +void XmlNode::set_name(const std::string& name) +{ + this->name = name; +} + +void XmlNode::set_name(std::string&& name) +{ + this->name = std::move(name); +} + +const std::string XmlNode::get_name() const +{ + return this->name; +} + +std::string XmlNode::to_string() const +{ + std::ostringstream res; + res << "<" << this->name; + for (const auto& it: this->attributes) + res << " " << it.first << "='" << sanitize(it.second) + "'"; + if (!this->has_children() && this->inner.empty()) + res << "/>"; + else + { + res << ">" + sanitize(this->inner); + for (const auto& child: this->children) + res << child->to_string(); + res << "</" << this->get_name() << ">"; + } + res << sanitize(this->tail); + return res.str(); +} + +bool XmlNode::has_children() const +{ + return !this->children.empty(); +} + +const std::string& XmlNode::get_tag(const std::string& name) const +{ + try + { + const auto& value = this->attributes.at(name); + return value; + } + catch (const std::out_of_range& e) + { + static const std::string def{}; + return def; + } +} + +bool XmlNode::del_tag(const std::string& name) +{ + if (this->attributes.erase(name) != 0) + return true; + return false; +} + +std::string& XmlNode::operator[](const std::string& name) +{ + return this->attributes[name]; +} + +std::ostream& operator<<(std::ostream& os, const XmlNode& node) +{ + return os << node.to_string(); +} diff --git a/louloulibs/xmpp/xmpp_stanza.hpp b/louloulibs/xmpp/xmpp_stanza.hpp new file mode 100644 index 0000000..4ca758e --- /dev/null +++ b/louloulibs/xmpp/xmpp_stanza.hpp @@ -0,0 +1,146 @@ +#pragma once + + +#include <map> +#include <string> +#include <vector> +#include <memory> + +std::string xml_escape(const std::string& data); +std::string xml_unescape(const std::string& data); +std::string sanitize(const std::string& data, const std::string& encoding = "ISO-8859-1"); + +/** + * Represent an XML node. It has + * - A parent XML node (in the case of the first-level nodes, the parent is + nullptr) + * - zero, one or more children XML nodes + * - A name + * - A map of attributes + * - inner data (text inside the node) + * - tail data (text just after the node) + */ +class XmlNode +{ +public: + explicit XmlNode(const std::string& name, XmlNode* parent); + explicit XmlNode(const std::string& name); + /** + * The copy constructor does not copy the parent attribute. The children + * nodes are all copied recursively. + */ + XmlNode(const XmlNode& node): + name(node.name), + parent(nullptr), + attributes(node.attributes), + children{}, + inner(node.inner), + tail(node.tail) + { + for (const auto& child: node.children) + this->add_child(std::make_unique<XmlNode>(*child)); + } + + XmlNode(XmlNode&& node) = default; + XmlNode& operator=(const XmlNode&) = delete; + XmlNode& operator=(XmlNode&&) = delete; + + ~XmlNode() = default; + + void delete_all_children(); + void set_attribute(const std::string& name, const std::string& value); + /** + * Set the content of the tail, that is the text just after this node + */ + void set_tail(const std::string& data); + /** + * Append the given data to the content of the tail. This exists because + * the expat library may provide the complete text of an element in more + * than one call + */ + void add_to_tail(const std::string& data); + /** + * Set the content of the inner, that is the text inside this node. + */ + void set_inner(const std::string& data); + /** + * Append the given data to the content of the inner. For the reason + * described in add_to_tail comment. + */ + void add_to_inner(const std::string& data); + /** + * Get the content of inner + */ + std::string get_inner() const; + /** + * Get the content of the tail + */ + std::string get_tail() const; + /** + * Get a pointer to the first child element with that name and that xml namespace + */ + const XmlNode* get_child(const std::string& name, const std::string& xmlns) const; + /** + * Get a vector of all the children that have that name and that xml namespace. + */ + std::vector<const XmlNode*> get_children(const std::string& name, const std::string& xmlns) const; + /** + * Add a node child to this node. Assign this node to the child’s parent. + * Returns a pointer to the newly added child. + */ + XmlNode* add_child(std::unique_ptr<XmlNode> child); + XmlNode* add_child(XmlNode&& child); + XmlNode* add_child(const XmlNode& child); + /** + * Returns the last of the children. If the node doesn't have any child, + * the behaviour is undefined. The user should make sure this is the case + * by calling has_children() for example. + */ + XmlNode* get_last_child() const; + XmlNode* get_parent() const; + void set_name(const std::string& name); + void set_name(std::string&& name); + const std::string get_name() const; + /** + * Serialize the stanza into a string + */ + std::string to_string() const; + /** + * Whether or not this node has at least one child (if not, this is a leaf + * node) + */ + bool has_children() const; + /** + * Gets the value for the given attribute, returns an empty string if the + * node as no such attribute. + */ + const std::string& get_tag(const std::string& name) const; + /** + * Remove the attribute of the node. Does nothing if that attribute is not + * present. Returns true if the tag was removed, false if it was absent. + */ + bool del_tag(const std::string& name); + /** + * Use this to set an attribute's value, like node["id"] = "12"; + */ + std::string& operator[](const std::string& name); + +private: + std::string name; + XmlNode* parent; + std::map<std::string, std::string> attributes; + std::vector<std::unique_ptr<XmlNode>> children; + std::string inner; + std::string tail; +}; + +std::ostream& operator<<(std::ostream& os, const XmlNode& node); + +/** + * An XMPP stanza is just an XML node of level 2 in the XMPP document (the + * level 1 ones are the <stream::stream/>, and the ones above 2 are just the + * content of the stanzas) + */ +using Stanza = XmlNode; + + |