diff options
Diffstat (limited to 'src')
51 files changed, 1040 insertions, 543 deletions
diff --git a/src/bridge/bridge.cpp b/src/bridge/bridge.cpp index 54bee84..cc2ef66 100644 --- a/src/bridge/bridge.cpp +++ b/src/bridge/bridge.cpp @@ -5,6 +5,7 @@ #include <utils/empty_if_fixed_server.hpp> #include <utils/encoding.hpp> #include <utils/tolower.hpp> +#include <utils/uuid.hpp> #include <logger/logger.hpp> #include <utils/revstr.hpp> #include <utils/split.hpp> @@ -62,8 +63,8 @@ void Bridge::shutdown(const std::string& exit_message) { for (auto& pair: this->irc_clients) { - pair.second->send_quit_command(exit_message); - pair.second->leave_dummy_channel(exit_message, {}); + std::unique_ptr<IrcClient>& irc = pair.second; + irc->send_quit_command(exit_message); } } @@ -103,7 +104,7 @@ const std::string& Bridge::get_jid() const std::string Bridge::get_bare_jid() const { Jid jid(this->user_jid); - return jid.local + "@" + jid.domain; + return jid.bare(); } Xmpp::body Bridge::make_xmpp_body(const std::string& str, const std::string& encoding) @@ -133,11 +134,11 @@ IrcClient* Bridge::make_irc_client(const std::string& hostname, const std::strin realname = this->get_bare_jid(); } this->irc_clients.emplace(hostname, - std::make_shared<IrcClient>(this->poller, hostname, + std::make_unique<IrcClient>(this->poller, hostname, nickname, username, realname, jid.domain, *this)); - std::shared_ptr<IrcClient> irc = this->irc_clients.at(hostname); + std::unique_ptr<IrcClient>& irc = this->irc_clients.at(hostname); return irc.get(); } } @@ -166,56 +167,36 @@ IrcClient* Bridge::find_irc_client(const std::string& hostname) const } } -bool Bridge::join_irc_channel(const Iid& iid, const std::string& nickname, const std::string& password, - const std::string& resource, HistoryLimit history_limit) +bool Bridge::join_irc_channel(const Iid& iid, std::string nickname, + const std::string& password, + const std::string& resource, + HistoryLimit history_limit, + const bool force_join) { const auto& hostname = iid.get_server(); +#ifdef USE_DATABASE + auto soptions = Database::get_irc_server_options(this->get_bare_jid(), hostname); + if (!soptions.col<Database::Nick>().empty()) + nickname = soptions.col<Database::Nick>(); +#endif IrcClient* irc = this->make_irc_client(hostname, nickname); irc->history_limit = history_limit; this->add_resource_to_server(hostname, resource); auto res_in_chan = this->is_resource_in_chan(ChannelKey{iid.get_local(), hostname}, resource); if (!res_in_chan) this->add_resource_to_chan(ChannelKey{iid.get_local(), hostname}, resource); - if (iid.get_local().empty()) - { // Join the dummy channel - if (irc->is_welcomed()) - { - if (res_in_chan) - return false; - // Immediately simulate a message coming from the IRC server saying that we - // joined the channel - if (irc->get_dummy_channel().joined) - { - this->generate_channel_join_for_resource(iid, resource); - } - else - { - const IrcMessage join_message(irc->get_nick(), "JOIN", {""}); - irc->on_channel_join(join_message); - const IrcMessage end_join_message(std::string(iid.get_server()), "366", - {irc->get_nick(), - "", "End of NAMES list"}); - irc->on_channel_completely_joined(end_join_message); - } - } - else - { - irc->get_dummy_channel().joining = true; - irc->start(); - } - return true; - } if (irc->is_channel_joined(iid.get_local()) == false) { irc->send_join_command(iid.get_local(), password); return true; - } else if (!res_in_chan) { + } else if (!res_in_chan || force_join) { + // See https://github.com/xsf/xeps/pull/499 for the force_join argument this->generate_channel_join_for_resource(iid, resource); } return false; } -void Bridge::send_channel_message(const Iid& iid, const std::string& body) +void Bridge::send_channel_message(const Iid& iid, const std::string& body, std::string id) { if (iid.get_server().empty()) { @@ -240,8 +221,26 @@ void Bridge::send_channel_message(const Iid& iid, const std::string& body) std::vector<std::string> lines = utils::split(body, '\n', true); if (lines.empty()) return ; + bool first = true; for (const std::string& line: lines) { + std::string uuid; +#ifdef USE_DATABASE + const auto xmpp_body = this->make_xmpp_body(line); + if (this->record_history) + uuid = Database::store_muc_message(this->get_bare_jid(), iid.get_local(), iid.get_server(), std::chrono::system_clock::now(), + std::get<0>(xmpp_body), irc->get_own_nick()); +#endif + if (!first || id.empty()) + id = utils::gen_uuid(); + + MessageCallback mirror_to_all_resources = [this, iid, uuid, id](const IrcClient* irc, const IrcMessage& message) { + const std::string& line = message.arguments[1]; + for (const auto& resource: this->resources_in_chan[iid.to_tuple()]) + this->xmpp.send_muc_message(std::to_string(iid), irc->get_own_nick(), this->make_xmpp_body(line), + this->user_jid + "/" + resource, uuid, id); + }; + if (line.substr(0, 5) == "/mode") { std::vector<std::string> args = utils::split(line.substr(5), ' ', false); @@ -250,20 +249,12 @@ void Bridge::send_channel_message(const Iid& iid, const std::string& body) // XMPP user, that’s not a textual message. } else if (line.substr(0, 4) == "/me ") - irc->send_channel_message(iid.get_local(), action_prefix + line.substr(4) + "\01"); + irc->send_channel_message(iid.get_local(), action_prefix + line.substr(4) + "\01", + std::move(mirror_to_all_resources)); else - irc->send_channel_message(iid.get_local(), line); + irc->send_channel_message(iid.get_local(), line, std::move(mirror_to_all_resources)); - std::string uuid; -#ifdef USE_DATABASE - const auto xmpp_body = this->make_xmpp_body(line); - if (this->record_history) - uuid = Database::store_muc_message(this->get_bare_jid(), iid.get_local(), iid.get_server(), std::chrono::system_clock::now(), - std::get<0>(xmpp_body), irc->get_own_nick()); -#endif - for (const auto& resource: this->resources_in_chan[iid.to_tuple()]) - this->xmpp.send_muc_message(std::to_string(iid), irc->get_own_nick(), this->make_xmpp_body(line), - this->user_jid + "/" + resource, uuid); + first = false; } } @@ -445,15 +436,11 @@ void Bridge::leave_irc_channel(Iid&& iid, const std::string& status_message, con #endif if (channel->joined && !channel->parting && !persistent) { - const auto& chan_name = iid.get_local(); - if (chan_name.empty()) - irc->leave_dummy_channel(status_message, resource); - else - irc->send_part_command(iid.get_local(), status_message); + irc->send_part_command(iid.get_local(), status_message); } else if (channel->joined) { - this->send_muc_leave(iid, channel->get_self()->nick, "", true, true, resource); + this->send_muc_leave(iid, *channel->get_self(), "", true, true, resource, irc); } if (persistent) this->remove_resource_from_chan(key, resource); @@ -464,14 +451,13 @@ void Bridge::leave_irc_channel(Iid&& iid, const std::string& status_message, con else { if (channel && channel->joined) - this->send_muc_leave(iid, channel->get_self()->nick, + this->send_muc_leave(iid, *channel->get_self(), "Biboumi note: " + std::to_string(resources - 1) + " resources are still in this channel.", - true, true, resource); + true, true, resource, irc); this->remove_resource_from_chan(key, resource); } - if (this->number_of_channels_the_resource_is_in(iid.get_server(), resource) == 0) - this->remove_resource_from_server(iid.get_server(), resource); - + if (this->number_of_channels_the_resource_is_in(iid.get_server(), resource) == 0) + this->remove_resource_from_server(iid.get_server(), resource); } void Bridge::send_irc_nick_change(const Iid& iid, const std::string& new_nick, const std::string& requesting_resource) @@ -836,22 +822,24 @@ void Bridge::send_irc_version_request(const std::string& irc_hostname, const std this->add_waiting_irc(std::move(cb)); } -void Bridge::send_message(const Iid& iid, const std::string& nick, const std::string& body, const bool muc) +void Bridge::send_message(const Iid& iid, const std::string& nick, const std::string& body, const bool muc, const bool log) { const auto encoding = in_encoding_for(*this, iid); + std::string uuid{}; if (muc) { #ifdef USE_DATABASE const auto xmpp_body = this->make_xmpp_body(body, encoding); - if (!nick.empty() && this->record_history) - Database::store_muc_message(this->get_bare_jid(), iid.get_local(), iid.get_server(), std::chrono::system_clock::now(), - std::get<0>(xmpp_body), nick); + if (log && this->record_history) + uuid = Database::store_muc_message(this->get_bare_jid(), iid.get_local(), iid.get_server(), std::chrono::system_clock::now(), + std::get<0>(xmpp_body), nick); +#else + (void)log; #endif for (const auto& resource: this->resources_in_chan[iid.to_tuple()]) { this->xmpp.send_muc_message(std::to_string(iid), nick, this->make_xmpp_body(body, encoding), - this->user_jid + "/" + resource, {}); - + this->user_jid + "/" + resource, uuid, utils::gen_uuid()); } } else @@ -881,19 +869,24 @@ void Bridge::send_presence_error(const Iid& iid, const std::string& nick, this->xmpp.send_presence_error(std::to_string(iid), nick, this->user_jid, type, condition, error_code, text); } -void Bridge::send_muc_leave(const Iid& iid, const std::string& nick, +void Bridge::send_muc_leave(const Iid& iid, const IrcUser& user, const std::string& message, const bool self, const bool user_requested, - const std::string& resource) + const std::string& resource, + const IrcClient* client) { + std::string affiliation; + std::string role; + std::tie(role, affiliation) = get_role_affiliation_from_irc_mode(user.get_most_significant_mode(client->get_sorted_user_modes())); + if (!resource.empty()) - this->xmpp.send_muc_leave(std::to_string(iid), nick, this->make_xmpp_body(message), - this->user_jid + "/" + resource, self, user_requested); + this->xmpp.send_muc_leave(std::to_string(iid), user.nick, this->make_xmpp_body(message), + this->user_jid + "/" + resource, self, user_requested, affiliation, role); else { for (const auto &res: this->resources_in_chan[iid.to_tuple()]) - this->xmpp.send_muc_leave(std::to_string(iid), nick, this->make_xmpp_body(message), - this->user_jid + "/" + res, self, user_requested); + this->xmpp.send_muc_leave(std::to_string(iid), user.nick, this->make_xmpp_body(message), + this->user_jid + "/" + res, self, user_requested, affiliation, role); if (self) { // Copy the resources currently in that channel @@ -906,9 +899,7 @@ void Bridge::send_muc_leave(const Iid& iid, const std::string& nick, for (const auto& r: resources_in_chan) if (this->number_of_channels_the_resource_is_in(iid.get_server(), r) == 0) this->remove_resource_from_server(iid.get_server(), r); - } - } IrcClient* irc = this->find_irc_client(iid.get_server()); if (self && irc && irc->number_of_joined_channels() == 0) @@ -1016,11 +1007,14 @@ void Bridge::send_room_history(const std::string& hostname, const std::string& c void Bridge::send_room_history(const std::string& hostname, std::string chan_name, const std::string& resource, const HistoryLimit& history_limit) { #ifdef USE_DATABASE - const auto coptions = Database::get_irc_channel_options_with_server_and_global_default(this->user_jid, hostname, chan_name); - auto limit = coptions.col<Database::MaxHistoryLength>(); + const auto goptions = Database::get_global_options(this->user_jid); + auto limit = goptions.col<Database::MaxHistoryLength>(); + if (limit < 0) + limit = 20; if (history_limit.stanzas >= 0 && history_limit.stanzas < limit) limit = history_limit.stanzas; - const auto lines = Database::get_muc_logs(this->user_jid, chan_name, hostname, limit, history_limit.since); + const auto result = Database::get_muc_logs(this->user_jid, chan_name, hostname, static_cast<std::size_t>(limit), history_limit.since, {}, Id::unset_value, Database::Paging::last); + const auto& lines = std::get<1>(result); chan_name.append(utils::empty_if_fixed_server("%" + hostname)); for (const auto& line: lines) { @@ -1156,12 +1150,12 @@ void Bridge::trigger_on_irc_message(const std::string& irc_hostname, const IrcMe } } -std::unordered_map<std::string, std::shared_ptr<IrcClient>>& Bridge::get_irc_clients() +std::unordered_map<std::string, std::unique_ptr<IrcClient>>& Bridge::get_irc_clients() { return this->irc_clients; } -const std::unordered_map<std::string, std::shared_ptr<IrcClient>>& Bridge::get_irc_clients() const +const std::unordered_map<std::string, std::unique_ptr<IrcClient>>& Bridge::get_irc_clients() const { return this->irc_clients; } @@ -1228,15 +1222,6 @@ void Bridge::remove_resource_from_server(const Bridge::IrcHostname& irc_hostname } } -bool Bridge::is_resource_in_server(const Bridge::IrcHostname& irc_hostname, const std::string& resource) const -{ - auto it = this->resources_in_server.find(irc_hostname); - if (it != this->resources_in_server.end()) - if (it->second.count(resource) == 1) - return true; - return false; -} - std::size_t Bridge::number_of_resources_in_chan(const Bridge::ChannelKey& channel) const { auto it = this->resources_in_chan.find(channel); @@ -1254,9 +1239,6 @@ std::size_t Bridge::number_of_channels_the_resource_is_in(const std::string& irc res++; } - IrcClient* irc = this->find_irc_client(irc_hostname); - if (irc && (irc->get_dummy_channel().joined || irc->get_dummy_channel().joining)) - res++; return res; } diff --git a/src/bridge/bridge.hpp b/src/bridge/bridge.hpp index c2f0233..5c547ff 100644 --- a/src/bridge/bridge.hpp +++ b/src/bridge/bridge.hpp @@ -75,9 +75,13 @@ public: * Try to join an irc_channel, does nothing and return true if the channel * was already joined. */ - bool join_irc_channel(const Iid& iid, const std::string& nickname, const std::string& password, const std::string& resource, HistoryLimit history_limit); + bool join_irc_channel(const Iid& iid, std::string nickname, + const std::string& password, + const std::string& resource, + HistoryLimit history_limit, + const bool force_join); - void send_channel_message(const Iid& iid, const std::string& body); + void send_channel_message(const Iid& iid, const std::string& body, std::string id); void send_private_message(const Iid& iid, const std::string& body, const std::string& type="PRIVMSG"); void send_raw_message(const std::string& hostname, const std::string& body); void leave_irc_channel(Iid&& iid, const std::string& status_message, const std::string& resource); @@ -162,7 +166,7 @@ public: /** * Send a MUC message from some participant */ - void send_message(const Iid& iid, const std::string& nick, const std::string& body, const bool muc); + void send_message(const Iid& iid, const std::string& nick, const std::string& body, const bool muc, const bool log=true); /** * Send a presence of type error, from a room. */ @@ -170,10 +174,11 @@ public: /** * Send an unavailable presence from this participant */ - void send_muc_leave(const Iid& iid, const std::string& nick, + void send_muc_leave(const Iid& iid, const IrcUser& nick, const std::string& message, const bool self, const bool user_requested, - const std::string& resource=""); + const std::string& resource, + const IrcClient* client); /** * Send presences to indicate that an user old_nick (ourself if self == * true) changed his nick to new_nick. The user_mode is needed because @@ -236,8 +241,8 @@ public: * iq_responder_callback_t and remove the callback from the list. */ void trigger_on_irc_message(const std::string& irc_hostname, const IrcMessage& message); - std::unordered_map<std::string, std::shared_ptr<IrcClient>>& get_irc_clients(); - const std::unordered_map<std::string, std::shared_ptr<IrcClient>>& get_irc_clients() const; + std::unordered_map<std::string, std::unique_ptr<IrcClient>>& get_irc_clients(); + const std::unordered_map<std::string, std::unique_ptr<IrcClient>>& get_irc_clients() const; std::set<char> get_chantypes(const std::string& hostname) const; #ifdef USE_DATABASE void set_record_history(const bool val); @@ -270,7 +275,7 @@ private: * One IrcClient for each IRC server we need to be connected to. * The pointer is shared by the bridge and the poller. */ - std::unordered_map<std::string, std::shared_ptr<IrcClient>> irc_clients; + std::unordered_map<std::string, std::unique_ptr<IrcClient>> irc_clients; /** * To communicate back with the XMPP component */ @@ -311,13 +316,14 @@ private: */ void add_resource_to_chan(const ChannelKey& channel, const std::string& resource); void remove_resource_from_chan(const ChannelKey& channel, const std::string& resource); +public: bool is_resource_in_chan(const ChannelKey& channel, const std::string& resource) const; +private: void remove_all_resources_from_chan(const ChannelKey& channel); std::size_t number_of_resources_in_chan(const ChannelKey& channel) const; void add_resource_to_server(const IrcHostname& irc_hostname, const std::string& resource); void remove_resource_from_server(const IrcHostname& irc_hostname, const std::string& resource); - bool is_resource_in_server(const IrcHostname& irc_hostname, const std::string& resource) const; size_t number_of_channels_the_resource_is_in(const std::string& irc_hostname, const std::string& resource) const; /** diff --git a/src/bridge/history_limit.hpp b/src/bridge/history_limit.hpp index 9c75256..93e36e1 100644 --- a/src/bridge/history_limit.hpp +++ b/src/bridge/history_limit.hpp @@ -1,5 +1,7 @@ #pragma once +#include <string> + // Default values means no limit struct HistoryLimit { diff --git a/src/config/config.cpp b/src/config/config.cpp index 412b170..2f64b9e 100644 --- a/src/config/config.cpp +++ b/src/config/config.cpp @@ -1,10 +1,12 @@ #include <config/config.hpp> #include <utils/tolower.hpp> +#include <utils/split.hpp> -#include <iostream> -#include <cstring> - +#include <algorithm> #include <cstdlib> +#include <cstring> +#include <iostream> +#include <vector> using namespace std::string_literals; @@ -40,6 +42,15 @@ int Config::get_int(const std::string& option, const int& def) return def; } +bool Config::is_in_list(const std::string& option, const std::string& value) +{ + std::string res = Config::get(option, ""); + if (res.empty()) + return false; + std::vector<std::string> list = utils::split(res, ':'); + return std::find(list.cbegin(), list.cend(), value) != list.cend(); +} + void Config::set(const std::string& option, const std::string& value, bool save) { Config::values[option] = value; diff --git a/src/config/config.hpp b/src/config/config.hpp index c5ef15d..9c28e8c 100644 --- a/src/config/config.hpp +++ b/src/config/config.hpp @@ -46,6 +46,11 @@ public: static int get_int(const std::string&, const int&); static bool get_bool(const std::string&, const bool); /** + * Returns true if value is present in a colon-separated list, otherwise + * false. + */ + static bool is_in_list(const std::string& option, const std::string& value); + /** * Set a value for the given option. And write all the config * in the file from which it was read if save is true. */ diff --git a/src/database/column.hpp b/src/database/column.hpp index 1f16bcf..837aa3f 100644 --- a/src/database/column.hpp +++ b/src/database/column.hpp @@ -9,14 +9,30 @@ struct Column value{default_value} {} Column(): value{} {} + void clear() + { + this->value = {}; + } using real_type = T; T value{}; }; -struct Id: Column<std::size_t> { +template <typename T> +struct UnclearableColumn: public Column<T> +{ + using Column<T>::Column; + void clear() + { } +}; + +struct ForeignKey: Column<std::size_t> { + static constexpr auto name = "fk_"; +}; + +struct Id: UnclearableColumn<std::size_t> { static constexpr std::size_t unset_value = static_cast<std::size_t>(-1); static constexpr auto name = "id_"; static constexpr auto options = "PRIMARY KEY"; - Id(): Column<std::size_t>(-1) {} + Id(): UnclearableColumn<std::size_t>(unset_value) {} }; diff --git a/src/database/database.cpp b/src/database/database.cpp index 3622963..9037ce1 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -1,10 +1,12 @@ #include "biboumi.h" #ifdef USE_DATABASE +#include <database/select_query.hpp> +#include <database/save.hpp> #include <database/database.hpp> -#include <uuid/uuid.h> #include <utils/get_first_non_empty.hpp> #include <utils/time.hpp> +#include <utils/uuid.hpp> #include <config/config.hpp> #include <database/sqlite3_engine.hpp> @@ -21,6 +23,7 @@ Database::GlobalOptionsTable Database::global_options("globaloptions_"); Database::IrcServerOptionsTable Database::irc_server_options("ircserveroptions_"); Database::IrcChannelOptionsTable Database::irc_channel_options("ircchanneloptions_"); Database::RosterTable Database::roster("roster"); +Database::AfterConnectionCommandsTable Database::after_connection_commands("after_connection_commands_"); std::map<Database::CacheKey, Database::EncodingIn::real_type> Database::encoding_in_cache{}; Database::GlobalPersistent::GlobalPersistent(): @@ -53,57 +56,80 @@ void Database::open(const std::string& filename) Database::irc_channel_options.upgrade(*Database::db); Database::roster.create(*Database::db); Database::roster.upgrade(*Database::db); + Database::after_connection_commands.create(*Database::db); + Database::after_connection_commands.upgrade(*Database::db); create_index<Database::Owner, Database::IrcChanName, Database::IrcServerName>(*Database::db, "archive_index", Database::muc_log_lines.get_name()); } Database::GlobalOptions Database::get_global_options(const std::string& owner) { - auto request = Database::global_options.select(); + auto request = select(Database::global_options); request.where() << Owner{} << "=" << owner; - Database::GlobalOptions options{Database::global_options.get_name()}; auto result = request.execute(*Database::db); if (result.size() == 1) - options = result.front(); - else - options.col<Owner>() = owner; + return result.front(); + Database::GlobalOptions options{Database::global_options.get_name()}; + options.col<Owner>() = owner; return options; } Database::IrcServerOptions Database::get_irc_server_options(const std::string& owner, const std::string& server) { - auto request = Database::irc_server_options.select(); + auto request = select(Database::irc_server_options); request.where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server; - Database::IrcServerOptions options{Database::irc_server_options.get_name()}; auto result = request.execute(*Database::db); if (result.size() == 1) - options = result.front(); - else + return result.front(); + Database::IrcServerOptions options{Database::irc_server_options.get_name()}; + options.col<Owner>() = owner; + options.col<Server>() = server; + return options; +} + +Database::AfterConnectionCommands Database::get_after_connection_commands(const IrcServerOptions& server_options) +{ + const auto id = server_options.col<Id>(); + if (id == Id::unset_value) + return {}; + auto request = select(Database::after_connection_commands); + request.where() << ForeignKey{} << "=" << id; + return request.execute(*Database::db); +} + +void Database::set_after_connection_commands(const Database::IrcServerOptions& server_options, Database::AfterConnectionCommands& commands) +{ + const auto id = server_options.col<Id>(); + if (id == Id::unset_value) + return ; + + Transaction transaction; + auto query = Database::after_connection_commands.del(); + query.where() << ForeignKey{} << "=" << id; + query.execute(*Database::db); + + for (auto& command: commands) { - options.col<Owner>() = owner; - options.col<Server>() = server; + command.col<ForeignKey>() = server_options.col<Id>(); + save(command, *Database::db); } - return options; } Database::IrcChannelOptions Database::get_irc_channel_options(const std::string& owner, const std::string& server, const std::string& channel) { - auto request = Database::irc_channel_options.select(); + auto request = select(Database::irc_channel_options); request.where() << Owner{} << "=" << owner <<\ " and " << Server{} << "=" << server <<\ " and " << Channel{} << "=" << channel; - Database::IrcChannelOptions options{Database::irc_channel_options.get_name()}; auto result = request.execute(*Database::db); if (result.size() == 1) - options = result.front(); - else - { - options.col<Owner>() = owner; - options.col<Server>() = server; - options.col<Channel>() = channel; - } + return result.front(); + Database::IrcChannelOptions options{Database::irc_channel_options.get_name()}; + options.col<Owner>() = owner; + options.col<Server>() = server; + options.col<Channel>() = channel; return options; } @@ -136,10 +162,6 @@ Database::IrcChannelOptions Database::get_irc_channel_options_with_server_and_gl coptions.col<EncodingOut>() = get_first_non_empty(coptions.col<EncodingOut>(), soptions.col<EncodingOut>()); - coptions.col<MaxHistoryLength>() = get_first_non_empty(coptions.col<MaxHistoryLength>(), - soptions.col<MaxHistoryLength>(), - goptions.col<MaxHistoryLength>()); - return coptions; } @@ -159,15 +181,15 @@ std::string Database::store_muc_message(const std::string& owner, const std::str line.col<Body>() = body; line.col<Nick>() = nick; - line.save(Database::db); + save(line, *Database::db); return uuid; } -std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, - int limit, const std::string& start, const std::string& end) +std::tuple<bool, std::vector<Database::MucLogLine>> Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, + std::size_t limit, const std::string& start, const std::string& end, const Id::real_type reference_record_id, Database::Paging paging) { - auto request = Database::muc_log_lines.select(); + auto request = select(Database::muc_log_lines); request.where() << Database::Owner{} << "=" << owner << \ " and " << Database::IrcChanName{} << "=" << chan_name << \ " and " << Database::IrcServerName{} << "=" << server; @@ -184,15 +206,70 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne if (end_time != -1) request << " and " << Database::Date{} << "<=" << end_time; } + if (reference_record_id != Id::unset_value) + { + request << " and " << Id{}; + if (paging == Database::Paging::first) + request << ">"; + else + request << "<"; + request << reference_record_id; + } - request.order_by() << Id{} << " DESC "; + if (paging == Database::Paging::first) + request.order_by() << Id{} << " ASC "; + else + request.order_by() << Id{} << " DESC "; - if (limit >= 0) - request.limit() << limit; + // Just a simple trick: to know whether we got the totality of the + // possible results matching this query (except for the limit), we just + // ask one more element. If we get that additional element, this means + // we don’t have everything. And then we just discard it. If we don’t + // have more, this means we have everything. + request.limit() << limit + 1; auto result = request.execute(*Database::db); + bool complete = true; - return {result.crbegin(), result.crend()}; + if (result.size() == limit + 1) + { + complete = false; + result.erase(std::prev(result.end())); + } + + if (paging == Database::Paging::first) + return std::make_tuple(complete, result); + else + return std::make_tuple(complete, std::vector<Database::MucLogLine>(result.crbegin(), result.crend())); +} + +Database::MucLogLine Database::get_muc_log(const std::string& owner, const std::string& chan_name, const std::string& server, + const std::string& uuid, const std::string& start, const std::string& end) +{ + auto request = select(Database::muc_log_lines); + request.where() << Database::Owner{} << "=" << owner << \ + " and " << Database::IrcChanName{} << "=" << chan_name << \ + " and " << Database::IrcServerName{} << "=" << server << \ + " and " << Database::Uuid{} << "=" << uuid; + + if (!start.empty()) + { + const auto start_time = utils::parse_datetime(start); + if (start_time != -1) + request << " and " << Database::Date{} << ">=" << start_time; + } + if (!end.empty()) + { + const auto end_time = utils::parse_datetime(end); + if (end_time != -1) + request << " and " << Database::Date{} << "<=" << end_time; + } + + auto result = request.execute(*Database::db); + + if (result.empty()) + throw Database::RecordNotFound{}; + return result.front(); } void Database::add_roster_item(const std::string& local, const std::string& remote) @@ -202,7 +279,7 @@ void Database::add_roster_item(const std::string& local, const std::string& remo roster_item.col<Database::LocalJid>() = local; roster_item.col<Database::RemoteJid>() = remote; - roster_item.save(Database::db); + save(roster_item, *Database::db); } void Database::delete_roster_item(const std::string& local, const std::string& remote) @@ -216,7 +293,7 @@ void Database::delete_roster_item(const std::string& local, const std::string& r bool Database::has_roster_item(const std::string& local, const std::string& remote) { - auto query = Database::roster.select(); + auto query = select(Database::roster); query.where() << Database::LocalJid{} << "=" << local << \ " and " << Database::RemoteJid{} << "=" << remote; @@ -227,7 +304,7 @@ bool Database::has_roster_item(const std::string& local, const std::string& remo std::vector<Database::RosterItem> Database::get_contact_list(const std::string& local) { - auto query = Database::roster.select(); + auto query = select(Database::roster); query.where() << Database::LocalJid{} << "=" << local; return query.execute(*Database::db); @@ -235,7 +312,7 @@ std::vector<Database::RosterItem> Database::get_contact_list(const std::string& std::vector<Database::RosterItem> Database::get_full_roster() { - auto query = Database::roster.select(); + auto query = select(Database::roster); return query.execute(*Database::db); } @@ -247,11 +324,25 @@ void Database::close() std::string Database::gen_uuid() { - char uuid_str[37]; - uuid_t uuid; - uuid_generate(uuid); - uuid_unparse(uuid, uuid_str); - return uuid_str; + return utils::gen_uuid(); +} + +Transaction::Transaction() +{ + const auto result = Database::raw_exec("BEGIN"); + if (std::get<bool>(result) == false) + log_error("Failed to create SQL transaction: ", std::get<std::string>(result)); + else + this->success = true; } +Transaction::~Transaction() +{ + if (this->success) + { + const auto result = Database::raw_exec("END"); + if (std::get<bool>(result) == false) + log_error("Failed to end SQL transaction: ", std::get<std::string>(result)); + } +} #endif diff --git a/src/database/database.hpp b/src/database/database.hpp index ec44543..4a413be 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -22,18 +22,20 @@ class Database { public: using time_point = std::chrono::system_clock::time_point; + struct RecordNotFound: public std::exception {}; + enum class Paging { first, last }; struct Uuid: Column<std::string> { static constexpr auto name = "uuid_"; }; - struct Owner: Column<std::string> { static constexpr auto name = "owner_"; }; + struct Owner: UnclearableColumn<std::string> { static constexpr auto name = "owner_"; }; - struct IrcChanName: Column<std::string> { static constexpr auto name = "ircchanname_"; }; + struct IrcChanName: UnclearableColumn<std::string> { static constexpr auto name = "ircchanname_"; }; - struct Channel: Column<std::string> { static constexpr auto name = "channel_"; }; + struct Channel: UnclearableColumn<std::string> { static constexpr auto name = "channel_"; }; - struct IrcServerName: Column<std::string> { static constexpr auto name = "ircservername_"; }; + struct IrcServerName: UnclearableColumn<std::string> { static constexpr auto name = "ircservername_"; }; - struct Server: Column<std::string> { static constexpr auto name = "server_"; }; + struct Server: UnclearableColumn<std::string> { static constexpr auto name = "server_"; }; struct Date: Column<time_point::rep> { static constexpr auto name = "date_"; }; @@ -82,6 +84,10 @@ class Database struct RemoteJid: Column<std::string> { static constexpr auto name = "remote"; }; + struct Address: Column<std::string> { static constexpr auto name = "address_"; }; + + struct ThrottleLimit: Column<long int> { static constexpr auto name = "throttlelimit_"; + ThrottleLimit(): Column<long int>(10) {} }; using MucLogLineTable = Table<Id, Uuid, Owner, IrcChanName, IrcServerName, Date, Body, Nick>; using MucLogLine = MucLogLineTable::RowType; @@ -89,7 +95,7 @@ class Database using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, GlobalPersistent>; using GlobalOptions = GlobalOptionsTable::RowType; - using IrcServerOptionsTable = Table<Id, Owner, Server, Pass, AfterConnectionCommand, TlsPorts, Ports, Username, Realname, VerifyCert, TrustedFingerprint, EncodingOut, EncodingIn, MaxHistoryLength>; + using IrcServerOptionsTable = Table<Id, Owner, Server, Pass, TlsPorts, Ports, Username, Realname, VerifyCert, TrustedFingerprint, EncodingOut, EncodingIn, MaxHistoryLength, Address, Nick, ThrottleLimit>; using IrcServerOptions = IrcServerOptionsTable::RowType; using IrcChannelOptionsTable = Table<Id, Owner, Server, Channel, EncodingOut, EncodingIn, MaxHistoryLength, Persistent, RecordHistoryOptional>; @@ -98,6 +104,9 @@ class Database using RosterTable = Table<LocalJid, RemoteJid>; using RosterItem = RosterTable::RowType; + using AfterConnectionCommandsTable = Table<Id, ForeignKey, AfterConnectionCommand>; + using AfterConnectionCommands = std::vector<AfterConnectionCommandsTable::RowType>; + Database() = default; ~Database() = default; @@ -118,8 +127,22 @@ class Database static IrcChannelOptions get_irc_channel_options_with_server_and_global_default(const std::string& owner, const std::string& server, const std::string& channel); - static std::vector<MucLogLine> get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, - int limit=-1, const std::string& start="", const std::string& end=""); + static AfterConnectionCommands get_after_connection_commands(const IrcServerOptions& server_options); + static void set_after_connection_commands(const IrcServerOptions& server_options, AfterConnectionCommands& commands); + + /** + * Get all the lines between (optional) start and end dates, with a (optional) limit. + * If after_id is set, only the records after it will be returned. + */ + static std::tuple<bool, std::vector<MucLogLine>> get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, + std::size_t limit, const std::string& start="", const std::string& end="", + const Id::real_type reference_record_id=Id::unset_value, Paging=Paging::first); + + /** + * Get just one single record matching the given uuid, between (optional) end and start. + * If it does not exist (or is not between end and start), throw a RecordNotFound exception. + */ + static MucLogLine get_muc_log(const std::string& owner, const std::string& chan_name, const std::string& server, const std::string& uuid, const std::string& start="", const std::string& end=""); static std::string store_muc_message(const std::string& owner, const std::string& chan_name, const std::string& server_name, time_point date, const std::string& body, const std::string& nick); @@ -144,6 +167,8 @@ class Database static IrcServerOptionsTable irc_server_options; static IrcChannelOptionsTable irc_channel_options; static RosterTable roster; + static AfterConnectionCommandsTable after_connection_commands; + static std::unique_ptr<DatabaseEngine> db; /** @@ -181,11 +206,20 @@ class Database static auto raw_exec(const std::string& query) { - Database::db->raw_exec(query); + return Database::db->raw_exec(query); } private: static std::string gen_uuid(); static std::map<CacheKey, EncodingIn::real_type> encoding_in_cache; }; + +class Transaction +{ +public: + Transaction(); + ~Transaction(); + bool success{false}; +}; + #endif /* USE_DATABASE */ diff --git a/src/database/delete_query.hpp b/src/database/delete_query.hpp new file mode 100644 index 0000000..dce705b --- /dev/null +++ b/src/database/delete_query.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include <database/query.hpp> +#include <database/engine.hpp> + +class DeleteQuery: public Query +{ +public: + DeleteQuery(const std::string& name): + Query("DELETE") + { + this->body += " from " + name; + } + + DeleteQuery& where() + { + this->body += " WHERE "; + return *this; + }; + + void execute(DatabaseEngine& db) + { + auto statement = db.prepare(this->body); + if (!statement) + return; +#ifdef DEBUG_SQL_QUERIES + const auto timer = this->log_and_time(); +#endif + statement->bind(std::move(this->params)); + if (statement->step() != StepResult::Done) + log_error("Failed to execute DELETE command"); + } +}; diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp index 9726424..230e873 100644 --- a/src/database/insert_query.hpp +++ b/src/database/insert_query.hpp @@ -1,10 +1,15 @@ #pragma once #include <database/statement.hpp> +#include <database/database.hpp> #include <database/column.hpp> #include <database/query.hpp> +#include <database/row.hpp> + #include <logger/logger.hpp> +#include <utils/is_one_of.hpp> + #include <type_traits> #include <vector> #include <string> @@ -22,7 +27,7 @@ update_autoincrement_id(std::tuple<T...>& columns, Statement& statement) template <std::size_t N=0, typename... T> typename std::enable_if<N == sizeof...(T), void>::type -update_autoincrement_id(std::tuple<T...>&, Statement& statement) +update_autoincrement_id(std::tuple<T...>&, Statement&) {} struct InsertQuery: public Query @@ -127,3 +132,13 @@ struct InsertQuery: public Query insert_col_name(const std::tuple<T...>&) {} }; + +template <typename... T> +void insert(Row<T...>& row, DatabaseEngine& db) +{ + InsertQuery query(row.table_name, row.columns); + // Ugly workaround for non portable stuff + if (is_one_of<Id, T...>) + query.body += db.get_returning_id_sql_string(Id::name); + query.execute(db, row.columns); +} diff --git a/src/database/postgresql_engine.cpp b/src/database/postgresql_engine.cpp index 984a959..59bc885 100644 --- a/src/database/postgresql_engine.cpp +++ b/src/database/postgresql_engine.cpp @@ -11,6 +11,8 @@ #include <logger/logger.hpp> +#include <cstring> + PostgresqlEngine::PostgresqlEngine(PGconn*const conn): conn(conn) {} @@ -20,6 +22,15 @@ PostgresqlEngine::~PostgresqlEngine() PQfinish(this->conn); } +static void logging_notice_processor(void*, const char* original) +{ + if (original && std::strlen(original) > 0) + { + std::string message{original, std::strlen(original) - 1}; + log_warning("PostgreSQL: ", message); + } +} + std::unique_ptr<DatabaseEngine> PostgresqlEngine::open(const std::string& conninfo) { PGconn* con = PQconnectdb(conninfo.data()); @@ -34,8 +45,10 @@ std::unique_ptr<DatabaseEngine> PostgresqlEngine::open(const std::string& connin { const char* errmsg = PQerrorMessage(con); log_error("Postgresql connection failed: ", errmsg); + PQfinish(con); throw std::runtime_error("failed to open connection."); } + PQsetNoticeProcessor(con, &logging_notice_processor, nullptr); return std::make_unique<PostgresqlEngine>(con); } diff --git a/src/database/postgresql_engine.hpp b/src/database/postgresql_engine.hpp index fe4fb53..1a9c249 100644 --- a/src/database/postgresql_engine.hpp +++ b/src/database/postgresql_engine.hpp @@ -36,12 +36,15 @@ private: #else +using namespace std::string_literals; + class PostgresqlEngine { public: static std::unique_ptr<DatabaseEngine> open(const std::string& string) { throw std::runtime_error("Cannot open postgresql database "s + string + ": biboumi is not compiled with libpq."); + return {}; } }; diff --git a/src/database/postgresql_statement.hpp b/src/database/postgresql_statement.hpp index 571c8f1..37e8ea0 100644 --- a/src/database/postgresql_statement.hpp +++ b/src/database/postgresql_statement.hpp @@ -6,6 +6,8 @@ #include <libpq-fe.h> +#include <cstring> + class PostgresqlStatement: public Statement { public: @@ -90,7 +92,7 @@ class PostgresqlStatement: public Statement private: private: - bool execute() + bool execute(const bool second_attempt=false) { std::vector<const char*> params; params.reserve(this->params.size()); @@ -108,8 +110,20 @@ private: const auto status = PQresultStatus(this->result); if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK) { - log_error("Failed to execute command: ", PQresultErrorMessage(this->result)); - return false; + const char* original = PQerrorMessage(this->conn); + if (original && std::strlen(original) > 0) + log_error("Failed to execute command: ", std::string{original, std::strlen(original) - 1}); + if (PQstatus(this->conn) != CONNECTION_OK && !second_attempt) + { + log_info("Trying to reconnect to PostgreSQL server and execute the query again."); + PQreset(this->conn); + return this->execute(true); + } + else + { + log_error("Givin up."); + return false; + } } return true; } diff --git a/src/database/query.cpp b/src/database/query.cpp index d27dc59..5ec8599 100644 --- a/src/database/query.cpp +++ b/src/database/query.cpp @@ -6,11 +6,6 @@ void actual_bind(Statement& statement, const std::string& value, int index) statement.bind_text(index, value); } -void actual_bind(Statement& statement, const std::int64_t value, int index) -{ - statement.bind_int64(index, value); -} - void actual_bind(Statement& statement, const OptionalBool& value, int index) { if (!value.is_set) @@ -21,7 +16,6 @@ void actual_bind(Statement& statement, const OptionalBool& value, int index) statement.bind_int64(index, -1); } - void actual_add_param(Query& query, const std::string& val) { query.params.push_back(val); diff --git a/src/database/query.hpp b/src/database/query.hpp index 8434944..ae6e946 100644 --- a/src/database/query.hpp +++ b/src/database/query.hpp @@ -12,8 +12,14 @@ #include <string> void actual_bind(Statement& statement, const std::string& value, int index); -void actual_bind(Statement& statement, const std::int64_t value, int index); void actual_bind(Statement& statement, const OptionalBool& value, int index); +template <typename T> +void actual_bind(Statement& statement, const T& value, int index) +{ + static_assert(std::is_integral<T>::value, + "Only a string, an optional-bool or an integer can be used."); + statement.bind_int64(index, static_cast<int>(value)); +} #ifdef DEBUG_SQL_QUERIES #include <utils/scopetimer.hpp> diff --git a/src/database/row.hpp b/src/database/row.hpp index 4dc98be..4004b5d 100644 --- a/src/database/row.hpp +++ b/src/database/row.hpp @@ -1,11 +1,5 @@ #pragma once -#include <database/insert_query.hpp> -#include <database/update_query.hpp> -#include <logger/logger.hpp> - -#include <utils/is_one_of.hpp> - #include <type_traits> template <typename... T> @@ -29,44 +23,24 @@ struct Row return col.value; } - template <bool Coucou=true> - void save(std::unique_ptr<DatabaseEngine>& db, typename std::enable_if<!is_one_of<Id, T...> && Coucou>::type* = nullptr) - { - this->insert(*db); - } - - template <bool Coucou=true> - void save(std::unique_ptr<DatabaseEngine>& db, typename std::enable_if<is_one_of<Id, T...> && Coucou>::type* = nullptr) + void clear() { - const Id& id = std::get<Id>(this->columns); - if (id.value == Id::unset_value) - { - this->insert(*db); - if (db->last_inserted_rowid >= 0) - std::get<Id>(this->columns).value = static_cast<Id::real_type>(db->last_inserted_rowid); - } - else - this->update(*db); + this->clear_col<0>(); } - private: - void insert(DatabaseEngine& db) - { - InsertQuery query(this->table_name, this->columns); - // Ugly workaround for non portable stuff - query.body += db.get_returning_id_sql_string(Id::name); - query.execute(db, this->columns); - } + std::tuple<T...> columns{}; + std::string table_name; - void update(DatabaseEngine& db) +private: + template <std::size_t N> + typename std::enable_if<N < sizeof...(T), void>::type + clear_col() { - UpdateQuery query(this->table_name, this->columns); - - query.execute(db, this->columns); + std::get<N>(this->columns).clear(); + this->clear_col<N+1>(); } - -public: - std::tuple<T...> columns; - std::string table_name; - + template <std::size_t N> + typename std::enable_if<N == sizeof...(T), void>::type + clear_col() + { } }; diff --git a/src/database/save.hpp b/src/database/save.hpp new file mode 100644 index 0000000..4362110 --- /dev/null +++ b/src/database/save.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include <database/update_query.hpp> +#include <database/insert_query.hpp> + +#include <database/engine.hpp> + +#include <database/row.hpp> + +#include <utils/is_one_of.hpp> + +template <typename... T, bool Coucou=true> +void save(Row<T...>& row, DatabaseEngine& db, typename std::enable_if<!is_one_of<Id, T...> && Coucou>::type* = nullptr) +{ + insert(row, db); +} + +template <typename... T, bool Coucou=true> +void save(Row<T...>& row, DatabaseEngine& db, typename std::enable_if<is_one_of<Id, T...> && Coucou>::type* = nullptr) +{ + const Id& id = std::get<Id>(row.columns); + if (id.value == Id::unset_value) + { + insert(row, db); + if (db.last_inserted_rowid >= 0) + std::get<Id>(row.columns).value = static_cast<Id::real_type>(db.last_inserted_rowid); + } + else + update(row, db); +} + diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp index 5a17f38..e372f2e 100644 --- a/src/database/select_query.hpp +++ b/src/database/select_query.hpp @@ -2,6 +2,8 @@ #include <database/engine.hpp> +#include <database/table.hpp> +#include <database/database.hpp> #include <database/statement.hpp> #include <database/query.hpp> #include <logger/logger.hpp> @@ -115,6 +117,8 @@ struct SelectQuery: public Query #endif auto statement = db.prepare(this->body); + if (!statement) + return rows; statement->bind(std::move(this->params)); while (statement->step() == StepResult::Row) @@ -130,3 +134,9 @@ struct SelectQuery: public Query const std::string table_name; }; +template <typename... T> +auto select(const Table<T...>& table) +{ + SelectQuery<T...> query(table.name); + return query; +} diff --git a/src/database/sqlite3_engine.cpp b/src/database/sqlite3_engine.cpp index ae4a146..5e3bba1 100644 --- a/src/database/sqlite3_engine.cpp +++ b/src/database/sqlite3_engine.cpp @@ -3,7 +3,6 @@ #ifdef SQLITE3_FOUND #include <database/sqlite3_engine.hpp> - #include <database/sqlite3_statement.hpp> #include <database/query.hpp> diff --git a/src/database/sqlite3_engine.hpp b/src/database/sqlite3_engine.hpp index 5b8176c..a7bfcdb 100644 --- a/src/database/sqlite3_engine.hpp +++ b/src/database/sqlite3_engine.hpp @@ -35,12 +35,15 @@ private: #else +using namespace std::string_literals; + class Sqlite3Engine { public: static std::unique_ptr<DatabaseEngine> open(const std::string& string) { throw std::runtime_error("Cannot open sqlite3 database "s + string + ": biboumi is not compiled with sqlite3 lib."); + return {}; } }; diff --git a/src/database/sqlite3_statement.hpp b/src/database/sqlite3_statement.hpp index 7738fa6..3ed60c0 100644 --- a/src/database/sqlite3_statement.hpp +++ b/src/database/sqlite3_statement.hpp @@ -88,5 +88,4 @@ class Sqlite3Statement: public Statement private: sqlite3_stmt* stmt; - int last_step_result{SQLITE_OK}; }; diff --git a/src/database/table.hpp b/src/database/table.hpp index 680e7cc..0b8bfc0 100644 --- a/src/database/table.hpp +++ b/src/database/table.hpp @@ -2,7 +2,7 @@ #include <database/engine.hpp> -#include <database/select_query.hpp> +#include <database/delete_query.hpp> #include <database/row.hpp> #include <algorithm> @@ -79,10 +79,10 @@ class Table return {this->name}; } - auto select() + auto del() { - SelectQuery<T...> select(this->name); - return select; + DeleteQuery query(this->name); + return query; } const std::string& get_name() const @@ -90,6 +90,8 @@ class Table return this->name; } + const std::string name; + private: template <std::size_t N=0> @@ -124,5 +126,4 @@ class Table add_column_create(DatabaseEngine&, std::string&) { } - const std::string name; }; diff --git a/src/database/update_query.hpp b/src/database/update_query.hpp index a29ac3f..c2b819d 100644 --- a/src/database/update_query.hpp +++ b/src/database/update_query.hpp @@ -1,7 +1,8 @@ #pragma once -#include <database/query.hpp> #include <database/engine.hpp> +#include <database/query.hpp> +#include <database/row.hpp> using namespace std::string_literals; @@ -102,3 +103,11 @@ struct UpdateQuery: public Query actual_bind(statement, value.value, sizeof...(T)); } }; + +template <typename... T> +void update(Row<T...>& row, DatabaseEngine& db) +{ + UpdateQuery query(row.table_name, row.columns); + + query.execute(db, row.columns); +} diff --git a/src/irc/irc_channel.cpp b/src/irc/irc_channel.cpp index 53043c7..2dd20fe 100644 --- a/src/irc/irc_channel.cpp +++ b/src/irc/irc_channel.cpp @@ -33,8 +33,9 @@ IrcUser* IrcChannel::find_user(const std::string& name) const return nullptr; } -void IrcChannel::remove_user(const IrcUser* user) +std::unique_ptr<IrcUser> IrcChannel::remove_user(const IrcUser* user) { + std::unique_ptr<IrcUser> result{}; const auto nick = user->nick; const bool is_self = (user == this->self); const auto it = std::find_if(this->users.begin(), this->users.end(), @@ -44,6 +45,7 @@ void IrcChannel::remove_user(const IrcUser* user) }); if (it != this->users.end()) { + result = std::move(*it); this->users.erase(it); if (is_self) { @@ -51,16 +53,5 @@ void IrcChannel::remove_user(const IrcUser* user) this->joined = false; } } -} - -void IrcChannel::remove_all_users() -{ - this->users.clear(); - this->self = nullptr; -} - -DummyIrcChannel::DummyIrcChannel(): - IrcChannel(), - joining(false) -{ + return result; } diff --git a/src/irc/irc_channel.hpp b/src/irc/irc_channel.hpp index 8f85edb..7000ada 100644 --- a/src/irc/irc_channel.hpp +++ b/src/irc/irc_channel.hpp @@ -32,8 +32,7 @@ public: IrcUser* add_user(const std::string& name, const std::map<char, char>& prefix_to_mode); IrcUser* find_user(const std::string& name) const; - void remove_user(const IrcUser* user); - void remove_all_users(); + std::unique_ptr<IrcUser> remove_user(const IrcUser* user); const std::vector<std::unique_ptr<IrcUser>>& get_users() const { return this->users; } @@ -42,33 +41,3 @@ protected: IrcUser* self{nullptr}; std::vector<std::unique_ptr<IrcUser>> users{}; }; - -/** - * A special channel that is not actually linked to any real irc - * channel. This is just a channel representing a connection to the - * server. If an user wants to maintain the connection to the server without - * having to be on any irc channel of that server, he can just join this - * dummy channel. - * It’s not actually dummy because it’s useful and it does things, but well. - */ -class DummyIrcChannel: public IrcChannel -{ -public: - explicit DummyIrcChannel(); - DummyIrcChannel(const DummyIrcChannel&) = delete; - DummyIrcChannel(DummyIrcChannel&&) = delete; - DummyIrcChannel& operator=(const DummyIrcChannel&) = delete; - DummyIrcChannel& operator=(DummyIrcChannel&&) = delete; - - /** - * This flag is at true whenever the user wants to join this channel, but - * he is not yet connected to the server. When the connection is made, we - * check that flag and if it’s true, we inform the user that he has just - * joined that channel. - * If the user is already connected to the server when he tries to join - * the channel, we don’t use that flag, we just join it immediately. - */ - bool joining; -}; - - diff --git a/src/irc/irc_client.cpp b/src/irc/irc_client.cpp index 40078d9..0b5715e 100644 --- a/src/irc/irc_client.cpp +++ b/src/irc/irc_client.cpp @@ -135,7 +135,7 @@ IrcClient::IrcClient(std::shared_ptr<Poller>& poller, std::string hostname, std::string realname, std::string user_hostname, Bridge& bridge): TCPClientSocketHandler(poller), - hostname(std::move(hostname)), + hostname(hostname), user_hostname(std::move(user_hostname)), username(std::move(username)), realname(std::move(realname)), @@ -143,16 +143,15 @@ IrcClient::IrcClient(std::shared_ptr<Poller>& poller, std::string hostname, bridge(bridge), welcomed(false), chanmodes({"", "", "", ""}), - chantypes({'#', '&'}) -{ - this->dummy_channel.topic = "This is a virtual channel provided for " - "convenience by biboumi, it is not connected " - "to any actual IRC channel of the server '" + this->hostname + - "', and sending messages in it has no effect. " - "Its main goal is to keep the connection to the IRC server " - "alive without having to join a real channel of that server. " - "To disconnect from the IRC server, leave this room and all " - "other IRC channels of that server."; + chantypes({'#', '&'}), + tokens_bucket(this->get_throttle_limit(), 1s, [this]() { + if (message_queue.empty()) + return true; + this->actual_send(std::move(this->message_queue.front())); + this->message_queue.pop_front(); + return false; + }, "TokensBucket" + this->hostname + this->bridge.get_jid()) +{ #ifdef USE_DATABASE auto options = Database::get_irc_server_options(this->bridge.get_bare_jid(), this->get_hostname()); @@ -179,6 +178,7 @@ IrcClient::~IrcClient() // This event may or may not exist (if we never got connected, it // doesn't), but it's ok TimedEventsManager::instance().cancel("PING" + this->hostname + this->bridge.get_jid()); + TimedEventsManager::instance().cancel("TokensBucket" + this->hostname + this->bridge.get_jid()); } void IrcClient::start() @@ -194,20 +194,23 @@ void IrcClient::start() bool tls; std::tie(port, tls) = this->ports_to_try.top(); this->ports_to_try.pop(); - this->bridge.send_xmpp_message(this->hostname, "", "Connecting to " + - this->hostname + ":" + port + " (" + - (tls ? "encrypted" : "not encrypted") + ")"); - this->bind_addr = Config::get("outgoing_bind", ""); + std::string address = this->hostname; -#ifdef BOTAN_FOUND -# ifdef USE_DATABASE +#ifdef USE_DATABASE auto options = Database::get_irc_server_options(this->bridge.get_bare_jid(), this->get_hostname()); +# ifdef BOTAN_FOUND this->credential_manager.set_trusted_fingerprint(options.col<Database::TrustedFingerprint>()); # endif + if (Config::get("fixed_irc_server", "").empty() && + !options.col<Database::Address>().empty()) + address = options.col<Database::Address>(); #endif - this->connect(this->hostname, port, tls); + this->bridge.send_xmpp_message(this->hostname, "", "Connecting to " + + address + ":" + port + " (" + + (tls ? "encrypted" : "not encrypted") + ")"); + this->connect(address, port, tls); } void IrcClient::on_connection_failed(const std::string& reason) @@ -315,8 +318,6 @@ void IrcClient::on_connection_close(const std::string& error_msg) IrcChannel* IrcClient::get_channel(const std::string& n) { - if (n.empty()) - return &this->dummy_channel; const std::string name = utils::tolower(n); try { @@ -324,9 +325,21 @@ IrcChannel* IrcClient::get_channel(const std::string& n) } catch (const std::out_of_range& exception) { - this->channels.emplace(name, std::make_unique<IrcChannel>()); + return this->channels.emplace(name, std::make_unique<IrcChannel>()).first->second.get(); + } +} + +const IrcChannel* IrcClient::find_channel(const std::string& n) const +{ + const std::string name = utils::tolower(n); + try + { return this->channels.at(name).get(); } + catch (const std::out_of_range& exception) + { + return nullptr; + } } bool IrcClient::is_channel_joined(const std::string& name) @@ -385,25 +398,39 @@ void IrcClient::parse_in_buffer(const size_t) } } -void IrcClient::send_message(IrcMessage&& message) +void IrcClient::actual_send(std::pair<IrcMessage, MessageCallback> message_pair) { - log_debug("IRC SENDING: (", this->get_hostname(), ") ", message); - std::string res; - if (!message.prefix.empty()) - res += ":" + std::move(message.prefix) + " "; - res += message.command; - for (const std::string& arg: message.arguments) - { - if (arg.find(' ') != std::string::npos || - (!arg.empty() && arg[0] == ':')) - { - res += " :" + arg; - break; - } - res += " " + arg; - } - res += "\r\n"; - this->send_data(std::move(res)); + const IrcMessage& message = message_pair.first; + const MessageCallback& callback = message_pair.second; + log_debug("IRC SENDING: (", this->get_hostname(), ") ", message); + std::string res; + if (!message.prefix.empty()) + res += ":" + message.prefix + " "; + res += message.command; + for (const std::string& arg: message.arguments) + { + if (arg.find(' ') != std::string::npos + || (!arg.empty() && arg[0] == ':')) + { + res += " :" + arg; + break; + } + res += " " + arg; + } + res += "\r\n"; + this->send_data(std::move(res)); + + if (callback) + callback(this, message); + } + +void IrcClient::send_message(IrcMessage message, MessageCallback callback, bool throttle) +{ + auto message_pair = std::make_pair(std::move(message), std::move(callback)); + if (this->tokens_bucket.use_token() || !throttle) + this->actual_send(std::move(message_pair)); + else + message_queue.push_back(std::move(message_pair)); } void IrcClient::send_raw(const std::string& txt) @@ -454,12 +481,12 @@ void IrcClient::send_topic_command(const std::string& chan_name, const std::stri void IrcClient::send_quit_command(const std::string& reason) { - this->send_message(IrcMessage("QUIT", {reason})); + this->send_message(IrcMessage("QUIT", {reason}), {}, false); } void IrcClient::send_join_command(const std::string& chan_name, const std::string& password) { - if (this->welcomed == false) + if (!this->welcomed) { const auto it = std::find_if(begin(this->channels_to_join), end(this->channels_to_join), [&chan_name](const auto& pair) { return std::get<0>(pair) == chan_name; }); @@ -473,10 +500,11 @@ void IrcClient::send_join_command(const std::string& chan_name, const std::strin this->start(); } -bool IrcClient::send_channel_message(const std::string& chan_name, const std::string& body) +bool IrcClient::send_channel_message(const std::string& chan_name, const std::string& body, + MessageCallback callback) { IrcChannel* channel = this->get_channel(chan_name); - if (channel->joined == false) + if (!channel->joined) { log_warning("Cannot send message to channel ", chan_name, ", it is not joined"); return false; @@ -496,7 +524,7 @@ bool IrcClient::send_channel_message(const std::string& chan_name, const std::st ::strlen(":!@ PRIVMSG ") - chan_name.length() - ::strlen(" :\r\n"); const auto lines = cut(body, line_size); for (const auto& line: lines) - this->send_message(IrcMessage("PRIVMSG", {chan_name, line})); + this->send_message(IrcMessage("PRIVMSG", {chan_name, line}), callback); return true; } @@ -544,7 +572,9 @@ void IrcClient::send_ping_command() void IrcClient::forward_server_message(const IrcMessage& message) { const std::string from = message.prefix; - const std::string body = message.arguments[1]; + std::string body; + for (auto it = std::next(message.arguments.begin()); it != message.arguments.end(); ++it) + body += *it + ' '; this->bridge.send_xmpp_message(this->hostname, from, body); } @@ -647,6 +677,11 @@ void IrcClient::set_and_forward_user_list(const IrcMessage& message) { const std::string chan_name = utils::tolower(message.arguments[2]); IrcChannel* channel = this->get_channel(chan_name); + if (channel->joined) + { + this->forward_server_message(message); + return; + } std::vector<std::string> nicks = utils::split(message.arguments[3], ' '); for (const std::string& nick: nicks) { @@ -670,10 +705,7 @@ void IrcClient::on_channel_join(const IrcMessage& message) { const std::string chan_name = utils::tolower(message.arguments[0]); IrcChannel* channel; - if (chan_name.empty()) - channel = &this->dummy_channel; - else - channel = this->get_channel(chan_name); + channel = this->get_channel(chan_name); const std::string nick = message.prefix; IrcUser* user = channel->add_user(nick, this->prefix_to_mode); if (channel->joined == false) @@ -785,6 +817,16 @@ void IrcClient::on_channel_completely_joined(const IrcMessage& message) { const std::string chan_name = utils::tolower(message.arguments[1]); IrcChannel* channel = this->get_channel(chan_name); + if (chan_name == "*" || channel->joined) + { + this->forward_server_message(message); + return; + } + if (!channel->get_self()) + { + log_error("End of NAMES list but we never received our own nick."); + return; + } channel->joined = true; this->bridge.send_user_join(this->hostname, chan_name, channel->get_self(), channel->get_self()->get_most_significant_mode(this->sorted_user_modes), true); @@ -900,8 +942,9 @@ void IrcClient::on_welcome_message(const IrcMessage& message) #ifdef USE_DATABASE auto options = Database::get_irc_server_options(this->bridge.get_bare_jid(), this->get_hostname()); - if (!options.col<Database::AfterConnectionCommand>().empty()) - this->send_raw(options.col<Database::AfterConnectionCommand>()); + const auto commands = Database::get_after_connection_commands(options); + for (const auto& command: commands) + this->send_raw(command.col<Database::AfterConnectionCommand>()); #endif // Install a repeated events to regularly send a PING TimedEventsManager::instance().add_event(TimedEvent(240s, std::bind(&IrcClient::send_ping_command, this), @@ -948,18 +991,6 @@ void IrcClient::on_welcome_message(const IrcMessage& message) if (!channels_with_key.empty()) this->send_join_command(channels_with_key, keys); this->channels_to_join.clear(); - // Indicate that the dummy channel is joined as well, if needed - if (this->dummy_channel.joining) - { - // Simulate a message coming from the IRC server saying that we joined - // the channel - const IrcMessage join_message(this->get_nick(), "JOIN", {""}); - this->on_channel_join(join_message); - const IrcMessage end_join_message(std::string(this->hostname), "366", - {this->get_nick(), - "", "End of NAMES list"}); - this->on_channel_completely_joined(end_join_message); - } } void IrcClient::on_part(const IrcMessage& message) @@ -976,18 +1007,18 @@ void IrcClient::on_part(const IrcMessage& message) { std::string nick = user->nick; bool self = channel->get_self() && channel->get_self()->nick == nick; - channel->remove_user(user); - Iid iid; - iid.set_local(chan_name); - iid.set_server(this->hostname); - iid.type = Iid::Type::Channel; + auto user_ptr = channel->remove_user(user); if (self) { this->channels.erase(utils::tolower(chan_name)); // channel pointer is now invalid channel = nullptr; } - this->bridge.send_muc_leave(iid, std::move(nick), txt, self, true); + Iid iid; + iid.set_local(chan_name); + iid.set_server(this->hostname); + iid.type = Iid::Type::Channel; + this->bridge.send_muc_leave(iid, *user_ptr, txt, self, true, {}, this); } } @@ -1004,8 +1035,7 @@ void IrcClient::on_error(const IrcMessage& message) IrcChannel* channel = pair.second.get(); if (!channel->joined) continue; - std::string own_nick = channel->get_self()->nick; - this->bridge.send_muc_leave(iid, std::move(own_nick), leave_message, true, false); + this->bridge.send_muc_leave(iid, *channel->get_self(), leave_message, true, false, {}, this); } this->channels.clear(); this->send_gateway_message("ERROR: " + leave_message); @@ -1030,7 +1060,7 @@ void IrcClient::on_quit(const IrcMessage& message) iid.set_local(chan_name); iid.set_server(this->hostname); iid.type = Iid::Type::Channel; - this->bridge.send_muc_leave(iid, user->nick, txt, self, false); + this->bridge.send_muc_leave(iid, *user, txt, self, false, {}, this); channel->remove_user(user); } } @@ -1062,10 +1092,6 @@ void IrcClient::on_nick(const IrcMessage& message) } }; - if (this->get_dummy_channel().joined) - { - change_nick_func("", &this->get_dummy_channel()); - } for (const auto& pair: this->channels) { change_nick_func(pair.first, pair.second.get()); @@ -1132,8 +1158,6 @@ void IrcClient::on_channel_bad_key(const IrcMessage& message) void IrcClient::on_channel_mode(const IrcMessage& message) { - // For now, just transmit the modes so the user can know what happens - // TODO, actually interprete the mode. Iid iid; iid.set_local(message.arguments[0]); iid.set_server(this->hostname); @@ -1151,7 +1175,7 @@ void IrcClient::on_channel_mode(const IrcMessage& message) } this->bridge.send_message(iid, "", "Mode " + iid.get_local() + " [" + mode_arguments + "] by " + user.nick, - true); + true, this->is_channel_joined(iid.get_local())); const IrcChannel* channel = this->get_channel(iid.get_local()); if (!channel) return; @@ -1224,6 +1248,11 @@ void IrcClient::on_channel_mode(const IrcMessage& message) } } +void IrcClient::set_throttle_limit(long int limit) +{ + this->tokens_bucket.set_limit(limit); +} + void IrcClient::on_user_mode(const IrcMessage& message) { this->bridge.send_xmpp_message(this->hostname, "", @@ -1248,25 +1277,7 @@ void IrcClient::on_unknown_message(const IrcMessage& message) size_t IrcClient::number_of_joined_channels() const { - if (this->dummy_channel.joined) - return this->channels.size() + 1; - else - return this->channels.size(); -} - -DummyIrcChannel& IrcClient::get_dummy_channel() -{ - return this->dummy_channel; -} - -void IrcClient::leave_dummy_channel(const std::string& exit_message, const std::string& resource) -{ - if (!this->dummy_channel.joined) - return; - this->dummy_channel.joined = false; - this->dummy_channel.joining = false; - this->dummy_channel.remove_all_users(); - this->bridge.send_muc_leave(Iid("%" + this->hostname, this->chantypes), std::string(this->current_nick), exit_message, true, true, resource); + return this->channels.size(); } #ifdef BOTAN_FOUND @@ -1279,3 +1290,12 @@ bool IrcClient::abort_on_invalid_cert() const return true; } #endif + +long int IrcClient::get_throttle_limit() const +{ +#ifdef USE_DATABASE + return Database::get_irc_server_options(this->bridge.get_bare_jid(), this->hostname).col<Database::ThrottleLimit>(); +#else + return 10; +#endif +} diff --git a/src/irc/irc_client.hpp b/src/irc/irc_client.hpp index de5c520..674f3ff 100644 --- a/src/irc/irc_client.hpp +++ b/src/irc/irc_client.hpp @@ -16,8 +16,14 @@ #include <vector> #include <string> #include <stack> +#include <deque> #include <map> #include <set> +#include <utils/tokens_bucket.hpp> + +class IrcClient; + +using MessageCallback = std::function<void(const IrcClient*, const IrcMessage&)>; class Bridge; @@ -28,7 +34,7 @@ class Bridge; class IrcClient: public TCPClientSocketHandler { public: - explicit IrcClient(std::shared_ptr<Poller>& poller, std::string hostname, + explicit IrcClient(std::shared_ptr<Poller>& poller, std::string hostname, std::string nickname, std::string username, std::string realname, std::string user_hostname, Bridge& bridge); @@ -68,6 +74,10 @@ public: */ IrcChannel* get_channel(const std::string& name); /** + * Return the channel with this name. Nullptr if it is not found + */ + const IrcChannel* find_channel(const std::string& name) const; + /** * Returns true if the channel is joined */ bool is_channel_joined(const std::string& name); @@ -80,8 +90,9 @@ public: * (actually, into our out_buf and signal the poller that we want to wach * for send events to be ready) */ - void send_message(IrcMessage&& message); + void send_message(IrcMessage message, MessageCallback callback={}, bool throttle=true); void send_raw(const std::string& txt); + void actual_send(std::pair<IrcMessage, MessageCallback> message_pair); /** * Send the PONG irc command */ @@ -110,7 +121,8 @@ public: * Send a PRIVMSG command for a channel * Return true if the message was actually sent */ - bool send_channel_message(const std::string& chan_name, const std::string& body); + bool send_channel_message(const std::string& chan_name, const std::string& body, + MessageCallback callback); /** * Send a PRIVMSG command for an user */ @@ -279,15 +291,6 @@ public: * Return the number of joined channels */ size_t number_of_joined_channels() const; - /** - * Get a reference to the unique dummy channel - */ - DummyIrcChannel& get_dummy_channel(); - /** - * Leave the dummy channel: forward a message to the user to indicate that - * he left it, and mark it as not joined. - */ - void leave_dummy_channel(const std::string& exit_message, const std::string& resource); const std::string& get_hostname() const { return this->hostname; } std::string get_nick() const { return this->current_nick; } @@ -298,7 +301,7 @@ public: const std::vector<char>& get_sorted_user_modes() const { return this->sorted_user_modes; } std::set<char> get_chantypes() const { return this->chantypes; } - + void set_throttle_limit(long int limit); /** * Store the history limit that the client asked when joining this room. */ @@ -336,14 +339,13 @@ private: */ Bridge& bridge; /** - * The list of joined channels, indexed by name + * Where messaged are stored when they are throttled. */ - std::unordered_map<std::string, std::unique_ptr<IrcChannel>> channels; + std::deque<std::pair<IrcMessage, MessageCallback>> message_queue{}; /** - * A single channel with a iid of the form "hostname" (normal channel have - * an iid of the form "chan%hostname". + * The list of joined channels, indexed by name */ - DummyIrcChannel dummy_channel; + std::unordered_map<std::string, std::unique_ptr<IrcChannel>> channels; /** * A list of chan we want to join (tuples with the channel name and the * password, if any), but we need a response 001 from the server before @@ -399,6 +401,8 @@ private: * the WebIRC protocole. */ Resolver dns_resolver; + TokensBucket tokens_bucket; + long int get_throttle_limit() const; }; diff --git a/src/irc/irc_message.hpp b/src/irc/irc_message.hpp index fe954e4..269a12a 100644 --- a/src/irc/irc_message.hpp +++ b/src/irc/irc_message.hpp @@ -14,9 +14,9 @@ public: ~IrcMessage() = default; IrcMessage(const IrcMessage&) = delete; - IrcMessage(IrcMessage&&) = delete; + IrcMessage(IrcMessage&&) = default; IrcMessage& operator=(const IrcMessage&) = delete; - IrcMessage& operator=(IrcMessage&&) = delete; + IrcMessage& operator=(IrcMessage&&) = default; std::string prefix; std::string command; diff --git a/src/main.cpp b/src/main.cpp index c877e43..2448197 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -55,45 +55,8 @@ static void sigusr_handler(int, siginfo_t*, void*) reload.store(true); } -int main(int ac, char** av) +static void setup_signals() { - if (ac > 1) - { - const std::string arg = av[1]; - if (arg.size() >= 2 && arg[0] == '-' && arg[1] == '-') - { - if (arg == "--help") - return display_help(); - else - { - std::cerr << "Unknow command line option: " << arg << std::endl; - return 1; - } - } - } - const std::string conf_filename = ac > 1 ? av[1] : xdg_config_path("biboumi.cfg"); - std::cout << "Using configuration file: " << conf_filename << std::endl; - - if (!Config::read_conf(conf_filename)) - return config_help(""); - - const std::string password = Config::get("password", ""); - if (password.empty()) - return config_help("password"); - const std::string hostname = Config::get("hostname", ""); - if (hostname.empty()) - return config_help("hostname"); - - -#ifdef USE_DATABASE - try { - open_database(); - } catch (const std::exception& e) { - log_error(e.what()); - return 1; - } -#endif - // Block the signals we want to manage. They will be unblocked only during // the epoll_pwait or ppoll calls. This avoids some race conditions, // explained in man 2 pselect on linux @@ -103,6 +66,7 @@ int main(int ac, char** av) sigaddset(&mask, SIGTERM); sigaddset(&mask, SIGUSR1); sigaddset(&mask, SIGUSR2); + sigaddset(&mask, SIGHUP); sigprocmask(SIG_BLOCK, &mask, nullptr); // Install the signals used to exit the process cleanly, or reload the @@ -113,7 +77,7 @@ int main(int ac, char** av) sigfillset(&on_sigint.sa_mask); // we want to catch that signal only once. // Sending SIGINT again will "force" an exit - on_sigint.sa_flags = SA_RESETHAND; + on_sigint.sa_flags = 0 & SA_RESETHAND; sigaction(SIGINT, &on_sigint, nullptr); sigaction(SIGTERM, &on_sigint, nullptr); @@ -124,7 +88,11 @@ int main(int ac, char** av) on_sigusr.sa_flags = 0; sigaction(SIGUSR1, &on_sigusr, nullptr); sigaction(SIGUSR2, &on_sigusr, nullptr); + sigaction(SIGHUP, &on_sigusr, nullptr); +} +static int main_loop(std::string hostname, std::string password) +{ auto p = std::make_shared<Poller>(); #ifdef UDNS_FOUND @@ -135,7 +103,9 @@ int main(int ac, char** av) std::make_shared<BiboumiComponent>(p, hostname, password); xmpp_component->start(); - IdentdServer identd(*xmpp_component, p, static_cast<uint16_t>(Config::get_int("identd_port", 113))); + std::unique_ptr<IdentdServer> identd; + if (Config::get_int("identd_port", 113) != 0) + identd = std::make_unique<IdentdServer>(*xmpp_component, p, static_cast<uint16_t>(Config::get_int("identd_port", 113))); auto timeout = TimedEventsManager::instance().get_timeout(); while (p->poll(timeout) != -1) @@ -144,7 +114,8 @@ int main(int ac, char** av) // Check for empty irc_clients (not connected, or with no joined // channel) and remove them xmpp_component->clean(); - identd.clean(); + if (identd) + identd->clean(); if (stop) { log_info("Signal received, exiting..."); @@ -157,7 +128,8 @@ int main(int ac, char** av) #ifdef UDNS_FOUND dns_handler.destroy(); #endif - identd.shutdown(); + if (identd) + identd->shutdown(); // Cancel the timer for a potential reconnection TimedEventsManager::instance().cancel("XMPP reconnection"); } @@ -199,7 +171,8 @@ int main(int ac, char** av) #ifdef UDNS_FOUND dns_handler.destroy(); #endif - identd.shutdown(); + if (identd) + identd->shutdown(); } } // If the only existing connection is the one to the XMPP component: @@ -218,3 +191,51 @@ int main(int ac, char** av) log_info("All connections cleanly closed, have a nice day."); return 0; } + +int main(int ac, char** av) +{ + if (ac > 1) + { + const std::string arg = av[1]; + if (arg.size() >= 2 && arg[0] == '-' && arg[1] == '-') + { + if (arg == "--help") + return display_help(); + else + { + std::cerr << "Unknow command line option: " << arg + << std::endl; + return 1; + } + } + } + const std::string conf_filename = + ac > 1 ? av[1]: xdg_config_path("biboumi.cfg"); + std::cout << "Using configuration file: " << conf_filename << std::endl; + + if (!Config::read_conf(conf_filename)) + return config_help(""); + + const std::string password = Config::get("password", ""); + if (password.empty()) + return config_help("password"); + const std::string hostname = Config::get("hostname", ""); + if (hostname.empty()) + return config_help("hostname"); + +#ifdef USE_DATABASE + try + { + open_database(); + } + catch (const std::exception& e) + { + log_error(e.what()); + return 1; + } +#endif + + setup_signals(); + + return main_loop(std::move(hostname), std::move(password)); +} diff --git a/src/network/credentials_manager.cpp b/src/network/credentials_manager.cpp index b25f442..89c694c 100644 --- a/src/network/credentials_manager.cpp +++ b/src/network/credentials_manager.cpp @@ -21,9 +21,8 @@ static const std::vector<std::string> default_cert_files = { Botan::Certificate_Store_In_Memory BasicCredentialsManager::certificate_store; bool BasicCredentialsManager::certs_loaded = false; -BasicCredentialsManager::BasicCredentialsManager(const TCPSocketHandler* const socket_handler): +BasicCredentialsManager::BasicCredentialsManager(): Botan::Credentials_Manager(), - socket_handler(socket_handler), trusted_fingerprint{} { BasicCredentialsManager::load_certs(); diff --git a/src/network/credentials_manager.hpp b/src/network/credentials_manager.hpp index 3a37bdc..210a628 100644 --- a/src/network/credentials_manager.hpp +++ b/src/network/credentials_manager.hpp @@ -25,7 +25,7 @@ void check_tls_certificate(const std::vector<Botan::X509_Certificate>& certs, class BasicCredentialsManager: public Botan::Credentials_Manager { public: - BasicCredentialsManager(const TCPSocketHandler* const socket_handler); + BasicCredentialsManager(); BasicCredentialsManager(BasicCredentialsManager&&) = delete; BasicCredentialsManager(const BasicCredentialsManager&) = delete; @@ -38,7 +38,6 @@ public: const std::string& get_trusted_fingerprint() const; private: - const TCPSocketHandler* const socket_handler; static bool try_to_open_one_ca_bundle(const std::vector<std::string>& paths); static void load_certs(); diff --git a/src/network/tcp_client_socket_handler.cpp b/src/network/tcp_client_socket_handler.cpp index aac13d0..9dda73d 100644 --- a/src/network/tcp_client_socket_handler.cpp +++ b/src/network/tcp_client_socket_handler.cpp @@ -146,15 +146,22 @@ void TCPClientSocketHandler::connect(const std::string& address, const std::stri || errno == EISCONN) { log_info("Connection success."); +#ifdef BOTAN_FOUND + if (this->use_tls) + try { + this->start_tls(this->address, this->port); + } catch (const Botan::Exception& e) + { + this->on_connection_failed("TLS error: "s + e.what()); + this->close(); + return ; + } +#endif TimedEventsManager::instance().cancel("connection_timeout" + std::to_string(this->socket)); this->poller->add_socket_handler(this); this->connected = true; this->connecting = false; -#ifdef BOTAN_FOUND - if (this->use_tls) - this->start_tls(this->address, this->port); -#endif this->connection_date = std::chrono::system_clock::now(); // Get our local TCP port and store it diff --git a/src/network/tcp_socket_handler.cpp b/src/network/tcp_socket_handler.cpp index 642cf03..e05caad 100644 --- a/src/network/tcp_socket_handler.cpp +++ b/src/network/tcp_socket_handler.cpp @@ -50,7 +50,7 @@ TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller>& poller): SocketHandler(poller, -1), use_tls(false) #ifdef BOTAN_FOUND - ,credential_manager(this) + ,credential_manager() #endif {} @@ -84,10 +84,11 @@ void TCPSocketHandler::plain_recv() if (recv_buf == nullptr) recv_buf = buf; - const ssize_t size = this->do_recv(recv_buf, buf_size); + const ssize_t ssize = this->do_recv(recv_buf, buf_size); - if (size > 0) + if (ssize > 0) { + auto size = static_cast<std::size_t>(ssize); if (buf == recv_buf) { // data needs to be placed in the in_buf string, because no buffer @@ -149,21 +150,22 @@ void TCPSocketHandler::on_send() } else { + auto size = static_cast<std::size_t>(res); // remove all the strings that were successfully sent. auto it = this->out_buf.begin(); while (it != this->out_buf.end()) { - if (static_cast<size_t>(res) >= it->size()) + if (size >= it->size()) { - res -= it->size(); + size -= it->size(); ++it; } else { // If one string has partially been sent, we use substr to // crop it - if (res > 0) - *it = it->substr(res, std::string::npos); + if (size > 0) + *it = it->substr(size, std::string::npos); break; } } @@ -332,6 +334,11 @@ void TCPSocketHandler::tls_verify_cert_chain(const std::vector<Botan::X509_Certi Botan::Usage_Type usage, const std::string& hostname, const Botan::TLS::Policy& policy) { + if (!this->policy.verify_certificate) + { + log_debug("Not verifying certificate due to domain policy "); + return; + } log_debug("Checking remote certificate for hostname ", hostname); try { diff --git a/src/network/tls_policy.cpp b/src/network/tls_policy.cpp index b88eb88..f32557e 100644 --- a/src/network/tls_policy.cpp +++ b/src/network/tls_policy.cpp @@ -37,6 +37,8 @@ void BiboumiTLSPolicy::load(std::istream& is) // Workaround for options that are not overridden in Botan::TLS::Text_Policy if (pair.first == "require_cert_revocation_info") this->req_cert_revocation_info = !(pair.second == "0" || utils::tolower(pair.second) == "false"); + else if (pair.first == "verify_certificate") + this->verify_certificate = !(pair.second == "0" || utils::tolower(pair.second) == "false"); else this->set(pair.first, pair.second); } diff --git a/src/network/tls_policy.hpp b/src/network/tls_policy.hpp index 29fd2b3..e915646 100644 --- a/src/network/tls_policy.hpp +++ b/src/network/tls_policy.hpp @@ -21,6 +21,7 @@ public: BiboumiTLSPolicy &operator=(BiboumiTLSPolicy &&) = delete; bool require_cert_revocation_info() const override; + bool verify_certificate{true}; protected: bool req_cert_revocation_info{true}; }; diff --git a/src/utils/dirname.cpp b/src/utils/dirname.cpp index 71c9c38..a304117 100644 --- a/src/utils/dirname.cpp +++ b/src/utils/dirname.cpp @@ -2,7 +2,7 @@ namespace utils { - std::string dirname(const std::string filename) + std::string dirname(const std::string& filename) { if (filename.empty()) return "./"; diff --git a/src/utils/dirname.hpp b/src/utils/dirname.hpp index c1df81b..73e1b57 100644 --- a/src/utils/dirname.hpp +++ b/src/utils/dirname.hpp @@ -2,5 +2,5 @@ namespace utils { -std::string dirname(const std::string filename); +std::string dirname(const std::string& filename); } diff --git a/src/utils/encoding.cpp b/src/utils/encoding.cpp index cff0039..8532292 100644 --- a/src/utils/encoding.cpp +++ b/src/utils/encoding.cpp @@ -48,16 +48,16 @@ namespace utils if (codepoint_size == 4) { if (!str[1] || !str[2] || !str[3] - || ((str[1] & 0b11000000) != 0b10000000) - || ((str[2] & 0b11000000) != 0b10000000) - || ((str[3] & 0b11000000) != 0b10000000)) + || ((str[1] & 0b11000000u) != 0b10000000u) + || ((str[2] & 0b11000000u) != 0b10000000u) + || ((str[3] & 0b11000000u) != 0b10000000u)) return false; } else if (codepoint_size == 3) { if (!str[1] || !str[2] - || ((str[1] & 0b11000000) != 0b10000000) - || ((str[2] & 0b11000000) != 0b10000000)) + || ((str[1] & 0b11000000u) != 0b10000000u) + || ((str[2] & 0b11000000u) != 0b10000000u)) return false; } else if (codepoint_size == 2) @@ -81,7 +81,7 @@ namespace utils // pointer where we write valid chars char* r = res.data(); - const char* str = original.c_str(); + const unsigned char* str = reinterpret_cast<const unsigned char*>(original.c_str()); std::bitset<20> codepoint; while (*str) @@ -89,10 +89,10 @@ namespace utils // 4 bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx if ((str[0] & 0b11111000) == 0b11110000) { - codepoint = ((str[0] & 0b00000111) << 18); - codepoint |= ((str[1] & 0b00111111) << 12); - codepoint |= ((str[2] & 0b00111111) << 6 ); - codepoint |= ((str[3] & 0b00111111) << 0 ); + codepoint = ((str[0] & 0b00000111u) << 18u); + codepoint |= ((str[1] & 0b00111111u) << 12u); + codepoint |= ((str[2] & 0b00111111u) << 6u ); + codepoint |= ((str[3] & 0b00111111u) << 0u ); if (codepoint.to_ulong() <= 0x10FFFF) { ::memcpy(r, str, 4); @@ -103,9 +103,9 @@ namespace utils // 3 bytes: 1110xxx 10xxxxxx 10xxxxxx else if ((str[0] & 0b11110000) == 0b11100000) { - codepoint = ((str[0] & 0b00001111) << 12); - codepoint |= ((str[1] & 0b00111111) << 6); - codepoint |= ((str[2] & 0b00111111) << 0 ); + codepoint = ((str[0] & 0b00001111u) << 12u); + codepoint |= ((str[1] & 0b00111111u) << 6u); + codepoint |= ((str[2] & 0b00111111u) << 0u ); if (codepoint.to_ulong() <= 0xD7FF || (codepoint.to_ulong() >= 0xE000 && codepoint.to_ulong() <= 0xFFFD)) { diff --git a/src/utils/optional_bool.hpp b/src/utils/optional_bool.hpp index 867aca2..3d00d23 100644 --- a/src/utils/optional_bool.hpp +++ b/src/utils/optional_bool.hpp @@ -6,7 +6,7 @@ struct OptionalBool { OptionalBool() = default; - OptionalBool(bool value): + explicit OptionalBool(bool value): is_set(true), value(value) {} void set_value(bool value) diff --git a/src/utils/string.cpp b/src/utils/string.cpp index 635e71a..366ec1f 100644 --- a/src/utils/string.cpp +++ b/src/utils/string.cpp @@ -15,11 +15,11 @@ std::vector<std::string> cut(const std::string& val, const std::size_t size) // Get the number of chars, <= size, that contain only whole // UTF-8 codepoints. std::size_t s = 0; - auto codepoint_size = utils::get_next_codepoint_size(val[pos + s]); + auto codepoint_size = utils::get_next_codepoint_size(static_cast<unsigned char>(val[pos + s])); while (s + codepoint_size <= size && pos + s < val.size()) { s += codepoint_size; - codepoint_size = utils::get_next_codepoint_size(val[pos + s]); + codepoint_size = utils::get_next_codepoint_size(static_cast<unsigned char>(val[pos + s])); } res.emplace_back(val.substr(pos, s)); pos += s; diff --git a/src/utils/tokens_bucket.hpp b/src/utils/tokens_bucket.hpp new file mode 100644 index 0000000..263359a --- /dev/null +++ b/src/utils/tokens_bucket.hpp @@ -0,0 +1,60 @@ +/** + * Implementation of the token bucket algorithm. + * + * It uses a repetitive TimedEvent, started at construction, to fill the + * bucket. + * + * Every n seconds, it executes the given callback. If the callback + * returns true, we add a token (if the limit is not yet reached). + * + */ + +#pragma once + +#include <utils/timed_events.hpp> +#include <logger/logger.hpp> + +class TokensBucket +{ +public: + TokensBucket(long int max_size, std::chrono::milliseconds fill_duration, std::function<bool()> callback, std::string name): + limit(max_size), + tokens(static_cast<std::size_t>(limit)), + callback(std::move(callback)) + { + log_debug("creating TokensBucket with max size: ", max_size); + TimedEvent event(std::move(fill_duration), [this]() { this->add_token(); }, std::move(name)); + TimedEventsManager::instance().add_event(std::move(event)); + } + + bool use_token() + { + if (this->limit < 0) + return true; + if (this->tokens > 0) + { + this->tokens--; + return true; + } + else + return false; + } + + void set_limit(long int limit) + { + this->limit = limit; + } + +private: + long int limit; + std::size_t tokens; + std::function<bool()> callback; + + void add_token() + { + if (this->limit < 0) + return; + if (this->callback() && this->tokens != static_cast<decltype(this->tokens)>(this->limit)) + this->tokens++; + } +}; diff --git a/src/utils/uuid.cpp b/src/utils/uuid.cpp new file mode 100644 index 0000000..23b71fe --- /dev/null +++ b/src/utils/uuid.cpp @@ -0,0 +1,14 @@ +#include <utils/uuid.hpp> +#include <uuid/uuid.h> + +namespace utils +{ +std::string gen_uuid() +{ + char uuid_str[37]; + uuid_t uuid; + uuid_generate(uuid); + uuid_unparse(uuid, uuid_str); + return uuid_str; +} +} diff --git a/src/utils/uuid.hpp b/src/utils/uuid.hpp new file mode 100644 index 0000000..d550475 --- /dev/null +++ b/src/utils/uuid.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include <string> + +namespace utils +{ +std::string gen_uuid(); +} diff --git a/src/xmpp/adhoc_commands_handler.cpp b/src/xmpp/adhoc_commands_handler.cpp index bb48781..bc4c108 100644 --- a/src/xmpp/adhoc_commands_handler.cpp +++ b/src/xmpp/adhoc_commands_handler.cpp @@ -41,7 +41,7 @@ XmlNode AdhocCommandsHandler::handle_request(const std::string& executor_jid, co XmlSubNode condition(error, STANZA_NS":item-not-found"); } else if (command_it->second.is_admin_only() && - Config::get("admin", "") != jid.local + "@" + jid.domain) + !Config::is_in_list("admin", jid.bare())) { XmlSubNode error(command_node, ADHOC_NS":error"); error["type"] = "cancel"; diff --git a/src/xmpp/biboumi_adhoc_commands.cpp b/src/xmpp/biboumi_adhoc_commands.cpp index bcdac39..7c31f36 100644 --- a/src/xmpp/biboumi_adhoc_commands.cpp +++ b/src/xmpp/biboumi_adhoc_commands.cpp @@ -14,6 +14,14 @@ #ifdef USE_DATABASE #include <database/database.hpp> +#include <database/save.hpp> + +static void set_desc(XmlSubNode& field, const char* text) +{ + XmlSubNode desc(field, "desc"); + desc.set_inner(text); +} + #endif #ifndef HAS_PUT_TIME @@ -115,6 +123,7 @@ void ConfigureGlobalStep1(XmppComponent&, AdhocSession& session, XmlNode& comman auto options = Database::get_global_options(owner.bare()); + command_node.delete_all_children(); XmlSubNode x(command_node, "jabber:x:data:x"); x["type"] = "form"; XmlSubNode title(x, "title"); @@ -127,7 +136,7 @@ void ConfigureGlobalStep1(XmppComponent&, AdhocSession& session, XmlNode& comman max_histo_length["var"] = "max_history_length"; max_histo_length["type"] = "text-single"; max_histo_length["label"] = "Max history length"; - max_histo_length["desc"] = "The maximum number of lines in the history that the server sends when joining a channel"; + set_desc(max_histo_length, "The maximum number of lines in the history that the server sends when joining a channel"); { XmlSubNode value(max_histo_length, "value"); value.set_inner(std::to_string(options.col<Database::MaxHistoryLength>())); @@ -139,7 +148,7 @@ void ConfigureGlobalStep1(XmppComponent&, AdhocSession& session, XmlNode& comman record_history["var"] = "record_history"; record_history["type"] = "boolean"; record_history["label"] = "Record history"; - record_history["desc"] = "Whether to save the messages into the database, or not"; + set_desc(record_history, "Whether to save the messages into the database, or not"); { XmlSubNode value(record_history, "value"); value.set_name("value"); @@ -155,7 +164,7 @@ void ConfigureGlobalStep1(XmppComponent&, AdhocSession& session, XmlNode& comman persistent["var"] = "persistent"; persistent["type"] = "boolean"; persistent["label"] = "Make all channels persistent"; - persistent["desc"] = "If true, all channels will be persistent"; + set_desc(persistent, "If true, all channels will be persistent"); { XmlSubNode value(persistent, "value"); value.set_name("value"); @@ -176,6 +185,7 @@ void ConfigureGlobalStep2(XmppComponent& xmpp_component, AdhocSession& session, { const Jid owner(session.get_owner_jid()); auto options = Database::get_global_options(owner.bare()); + options.clear(); for (const XmlNode* field: x->get_children("field", "jabber:x:data")) { const XmlNode* value = field->get_child("value", "jabber:x:data"); @@ -196,7 +206,7 @@ void ConfigureGlobalStep2(XmppComponent& xmpp_component, AdhocSession& session, options.col<Database::GlobalPersistent>() = to_bool(value->get_inner()); } - options.save(Database::db); + save(options, *Database::db); command_node.delete_all_children(); XmlSubNode note(command_node, "note"); @@ -219,7 +229,9 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com server_domain = target.local; auto options = Database::get_irc_server_options(owner.local + "@" + owner.domain, server_domain); + auto commands = Database::get_after_connection_commands(options); + command_node.delete_all_children(); XmlSubNode x(command_node, "jabber:x:data:x"); x["type"] = "form"; XmlSubNode title(x, "title"); @@ -227,12 +239,26 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com XmlSubNode instructions(x, "instructions"); instructions.set_inner("Edit the form, to configure the settings of the IRC server " + server_domain); + if (Config::get("fixed_irc_server", "").empty()) + { + XmlSubNode field(x, "field"); + field["var"] = "address"; + field["type"] = "text-single"; + field["label"] = "Address"; + set_desc(field, "The address (hostname or IP) to connect to."); + XmlSubNode value(field, "value"); + if (options.col<Database::Address>().empty()) + value.set_inner(server_domain); + else + value.set_inner(options.col<Database::Address>()); + } + { XmlSubNode ports(x, "field"); ports["var"] = "ports"; ports["type"] = "text-multi"; ports["label"] = "Ports"; - ports["desc"] = "List of ports to try, without TLS. Defaults: 6667."; + set_desc(ports, "List of ports to try, without TLS. Defaults: 6667."); for (const auto& val: utils::split(options.col<Database::Ports>(), ';', false)) { XmlSubNode ports_value(ports, "value"); @@ -246,7 +272,7 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com tls_ports["var"] = "tls_ports"; tls_ports["type"] = "text-multi"; tls_ports["label"] = "TLS ports"; - tls_ports["desc"] = "List of ports to try, with TLS. Defaults: 6697, 6670."; + set_desc(tls_ports, "List of ports to try, with TLS. Defaults: 6697, 6670."); for (const auto& val: utils::split(options.col<Database::TlsPorts>(), ';', false)) { XmlSubNode tls_ports_value(tls_ports, "value"); @@ -259,7 +285,7 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com verify_cert["var"] = "verify_cert"; verify_cert["type"] = "boolean"; verify_cert["label"] = "Verify certificate"; - verify_cert["desc"] = "Whether or not to abort the connection if the server’s TLS certificate is invalid"; + set_desc(verify_cert, "Whether or not to abort the connection if the server’s TLS certificate is invalid"); XmlSubNode verify_cert_value(verify_cert, "value"); if (options.col<Database::VerifyCert>()) verify_cert_value.set_inner("true"); @@ -279,12 +305,26 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com } } #endif + + { + XmlSubNode field(x, "field"); + field["var"] = "nick"; + field["type"] = "text-single"; + field["label"] = "Nickname"; + set_desc(field, "If set, will override the nickname provided in the initial presence sent to join the first server channel"); + if (!options.col<Database::Nick>().empty()) + { + XmlSubNode value(field, "value"); + value.set_inner(options.col<Database::Nick>()); + } + } + { XmlSubNode pass(x, "field"); pass["var"] = "pass"; pass["type"] = "text-private"; pass["label"] = "Server password"; - pass["desc"] = "Will be used in a PASS command when connecting"; + set_desc(pass, "Will be used in a PASS command when connecting"); if (!options.col<Database::Pass>().empty()) { XmlSubNode pass_value(pass, "value"); @@ -294,14 +334,14 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com { XmlSubNode after_cnt_cmd(x, "field"); - after_cnt_cmd["var"] = "after_connect_command"; - after_cnt_cmd["type"] = "text-single"; - after_cnt_cmd["desc"] = "Custom IRC command sent after the connection is established with the server."; - after_cnt_cmd["label"] = "After-connection IRC command"; - if (!options.col<Database::AfterConnectionCommand>().empty()) + after_cnt_cmd["var"] = "after_connect_commands"; + after_cnt_cmd["type"] = "text-multi"; + set_desc(after_cnt_cmd, "Custom IRC commands sent after the connection is established with the server."); + after_cnt_cmd["label"] = "After-connection IRC commands"; + for (const auto& command: commands) { XmlSubNode after_cnt_cmd_value(after_cnt_cmd, "value"); - after_cnt_cmd_value.set_inner(options.col<Database::AfterConnectionCommand>()); + after_cnt_cmd_value.set_inner(command.col<Database::AfterConnectionCommand>()); } } @@ -333,10 +373,19 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com } { + XmlSubNode throttle_limit(x, "field"); + throttle_limit["var"] = "throttle_limit"; + throttle_limit["type"] = "text-single"; + throttle_limit["label"] = "Throttle limit"; + XmlSubNode value(throttle_limit, "value"); + value.set_inner(std::to_string(options.col<Database::ThrottleLimit>())); + } + + { XmlSubNode encoding_out(x, "field"); encoding_out["var"] = "encoding_out"; encoding_out["type"] = "text-single"; - encoding_out["desc"] = "The encoding used when sending messages to the IRC server."; + set_desc(encoding_out, "The encoding used when sending messages to the IRC server."); encoding_out["label"] = "Out encoding"; if (!options.col<Database::EncodingOut>().empty()) { @@ -349,7 +398,7 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com XmlSubNode encoding_in(x, "field"); encoding_in["var"] = "encoding_in"; encoding_in["type"] = "text-single"; - encoding_in["desc"] = "The encoding used to decode message received from the IRC server."; + set_desc(encoding_in, "The encoding used to decode message received from the IRC server."); encoding_in["label"] = "In encoding"; if (!options.col<Database::EncodingIn>().empty()) { @@ -359,8 +408,10 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com } } -void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node) +void ConfigureIrcServerStep2(XmppComponent& xmpp_component, AdhocSession& session, XmlNode& command_node) { + auto& biboumi_component = dynamic_cast<BiboumiComponent&>(xmpp_component); + const XmlNode* x = command_node.get_child("x", "jabber:x:data"); if (x) { @@ -371,10 +422,17 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com server_domain = target.local; auto options = Database::get_irc_server_options(owner.local + "@" + owner.domain, server_domain); + options.clear(); + Database::AfterConnectionCommands commands{}; + for (const XmlNode* field: x->get_children("field", "jabber:x:data")) { const XmlNode* value = field->get_child("value", "jabber:x:data"); const std::vector<const XmlNode*> values = field->get_children("value", "jabber:x:data"); + + if (field->get_tag("var") == "address" && value && Config::get("fixed_irc_server", "").empty()) + options.col<Database::Address>() = value->get_inner(); + if (field->get_tag("var") == "ports") { std::string ports; @@ -406,11 +464,22 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com #endif // BOTAN_FOUND + else if (field->get_tag("var") == "nick" && value) + options.col<Database::Nick>() = value->get_inner(); + else if (field->get_tag("var") == "pass" && value) options.col<Database::Pass>() = value->get_inner(); - else if (field->get_tag("var") == "after_connect_command" && value) - options.col<Database::AfterConnectionCommand>() = value->get_inner(); + else if (field->get_tag("var") == "after_connect_commands") + { + commands.clear(); + for (const auto& val: values) + { + auto command = Database::after_connection_commands.row(); + command.col<Database::AfterConnectionCommand>() = val->get_inner(); + commands.push_back(std::move(command)); + } + } else if (field->get_tag("var") == "username" && value) { @@ -423,6 +492,22 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com else if (field->get_tag("var") == "realname" && value) options.col<Database::Realname>() = value->get_inner(); + else if (field->get_tag("var") == "throttle_limit" && value) + { + try { + options.col<Database::ThrottleLimit>() = std::stol(value->get_inner()); + } catch (const std::logic_error&) { + options.col<Database::ThrottleLimit>() = -1; + } + Bridge* bridge = biboumi_component.find_user_bridge(session.get_owner_jid()); + if (bridge) + { + IrcClient* client = bridge->find_irc_client(server_domain); + if (client) + client->set_throttle_limit(options.col<Database::ThrottleLimit>()); + } + } + else if (field->get_tag("var") == "encoding_out" && value) options.col<Database::EncodingOut>() = value->get_inner(); @@ -431,7 +516,8 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com } Database::invalidate_encoding_in_cache(); - options.save(Database::db); + save(options, *Database::db); + Database::set_after_connection_commands(options, commands); command_node.delete_all_children(); XmlSubNode note(command_node, "note"); @@ -459,6 +545,7 @@ void insert_irc_channel_configuration_form(XmlNode& node, const Jid& requester, auto options = Database::get_irc_channel_options_with_server_default(requester.local + "@" + requester.domain, iid.get_server(), iid.get_local()); + node.delete_all_children(); XmlSubNode x(node, "jabber:x:data:x"); x["type"] = "form"; XmlSubNode title(x, "title"); @@ -471,7 +558,7 @@ void insert_irc_channel_configuration_form(XmlNode& node, const Jid& requester, record_history["var"] = "record_history"; record_history["type"] = "list-single"; record_history["label"] = "Record history for this channel"; - record_history["desc"] = "If unset, the value is the one configured globally"; + set_desc(record_history, "If unset, the value is the one configured globally"); { // Value selected by default XmlSubNode value(record_history, "value"); @@ -491,7 +578,7 @@ void insert_irc_channel_configuration_form(XmlNode& node, const Jid& requester, XmlSubNode encoding_out(x, "field"); encoding_out["var"] = "encoding_out"; encoding_out["type"] = "text-single"; - encoding_out["desc"] = "The encoding used when sending messages to the IRC server. Defaults to the server's “out encoding” if unset for the channel"; + set_desc(encoding_out, "The encoding used when sending messages to the IRC server. Defaults to the server's “out encoding” if unset for the channel"); encoding_out["label"] = "Out encoding"; if (!options.col<Database::EncodingOut>().empty()) { @@ -504,7 +591,7 @@ void insert_irc_channel_configuration_form(XmlNode& node, const Jid& requester, XmlSubNode encoding_in(x, "field"); encoding_in["var"] = "encoding_in"; encoding_in["type"] = "text-single"; - encoding_in["desc"] = "The encoding used to decode message received from the IRC server. Defaults to the server's “in encoding” if unset for the channel"; + set_desc(encoding_in, "The encoding used to decode message received from the IRC server. Defaults to the server's “in encoding” if unset for the channel"); encoding_in["label"] = "In encoding"; if (!options.col<Database::EncodingIn>().empty()) { @@ -517,7 +604,7 @@ void insert_irc_channel_configuration_form(XmlNode& node, const Jid& requester, XmlSubNode persistent(x, "field"); persistent["var"] = "persistent"; persistent["type"] = "boolean"; - persistent["desc"] = "If set to true, when all XMPP clients have left this channel, biboumi will stay idle in it, without sending a PART command."; + set_desc(persistent, "If set to true, when all XMPP clients have left this channel, biboumi will stay idle in it, without sending a PART command."); persistent["label"] = "Persistent"; { XmlSubNode value(persistent, "value"); @@ -561,6 +648,7 @@ bool handle_irc_channel_configuration_form(XmppComponent& xmpp_component, const const Iid iid(target.local, {}); auto options = Database::get_irc_channel_options(requester.bare(), iid.get_server(), iid.get_local()); + options.clear(); for (const XmlNode *field: x->get_children("field", "jabber:x:data")) { const XmlNode *value = field->get_child("value", "jabber:x:data"); @@ -600,7 +688,7 @@ bool handle_irc_channel_configuration_form(XmppComponent& xmpp_component, const } Database::invalidate_encoding_in_cache(requester.bare(), iid.get_server(), iid.get_local()); - options.save(Database::db); + save(options, *Database::db); } return true; } @@ -611,7 +699,7 @@ bool handle_irc_channel_configuration_form(XmppComponent& xmpp_component, const void DisconnectUserFromServerStep1(XmppComponent& xmpp_component, AdhocSession& session, XmlNode& command_node) { const Jid owner(session.get_owner_jid()); - if (owner.bare() != Config::get("admin", "")) + if (!Config::is_in_list("admin", owner.bare())) { // A non-admin is not allowed to disconnect other users, only // him/herself, so we just skip this step auto next_step = session.get_next_step(); diff --git a/src/xmpp/biboumi_component.cpp b/src/xmpp/biboumi_component.cpp index 481ebb9..85617e8 100644 --- a/src/xmpp/biboumi_component.cpp +++ b/src/xmpp/biboumi_component.cpp @@ -72,11 +72,16 @@ BiboumiComponent::BiboumiComponent(std::shared_ptr<Poller>& poller, const std::s AdhocCommand configure_global_command({&ConfigureGlobalStep1, &ConfigureGlobalStep2}, "Configure a few settings", false); if (!Config::get("fixed_irc_server", "").empty()) - this->adhoc_commands_handler.add_command("configure", configure_server_command); + { + this->adhoc_commands_handler.add_command("server-configure", configure_server_command); + this->adhoc_commands_handler.add_command("global-configure", configure_global_command); + } else - this->adhoc_commands_handler.add_command("configure", configure_global_command); + { + this->adhoc_commands_handler.add_command("configure", configure_global_command); + this->irc_server_adhoc_commands_handler.add_command("configure", configure_server_command); + } - this->irc_server_adhoc_commands_handler.add_command("configure", configure_server_command); this->irc_channel_adhoc_commands_handler.add_command("configure", {{&ConfigureIrcChannelStep1, &ConfigureIrcChannelStep2}, "Configure a few settings for that IRC channel", false}); #endif } @@ -97,8 +102,8 @@ void BiboumiComponent::shutdown() void BiboumiComponent::clean() { - auto it = this->bridges.begin(); - while (it != this->bridges.end()) + auto it = std::begin(this->bridges); + while (it != std::end(this->bridges)) { it->second->clean(); if (it->second->active_clients() == 0) @@ -148,13 +153,10 @@ void BiboumiComponent::handle_presence(const Stanza& stanza) try { if (iid.type == Iid::Type::Channel && !iid.get_server().empty()) - { // presence toward a MUC that corresponds to an irc channel, or a - // dummy channel if iid.chan is empty + { // presence toward a MUC that corresponds to an irc channel if (type.empty()) { const std::string own_nick = bridge->get_own_nick(iid); - if (!own_nick.empty() && own_nick != to.resource) - bridge->send_irc_nick_change(iid, to.resource, from.resource); const XmlNode* x = stanza.get_child("x", MUC_NS); const XmlNode* password = x ? x->get_child("password", MUC_NS): nullptr; const XmlNode* history = x ? x->get_child("history", MUC_NS): nullptr; @@ -182,7 +184,9 @@ void BiboumiComponent::handle_presence(const Stanza& stanza) history_limit.stanzas = 0; } bridge->join_irc_channel(iid, to.resource, password ? password->get_inner(): "", - from.resource, history_limit); + from.resource, history_limit, x != nullptr); + if (!own_nick.empty() && own_nick != to.resource) + bridge->send_irc_nick_change(iid, to.resource, from.resource); } else if (type == "unavailable") { @@ -269,9 +273,10 @@ void BiboumiComponent::handle_message(const Stanza& stanza) std::string error_type("cancel"); std::string error_name("internal-server-error"); - utils::ScopeGuard stanza_error([this, &from_str, &to_str, &id, &error_type, &error_name](){ + std::string error_text{}; + utils::ScopeGuard stanza_error([this, &from_str, &to_str, &id, &error_type, &error_name, &error_text](){ this->send_stanza_error("message", from_str, to_str, id, - error_type, error_name, ""); + error_type, error_name, error_text); }); const XmlNode* body = stanza.get_child("body", COMPONENT_NS); @@ -280,7 +285,15 @@ void BiboumiComponent::handle_message(const Stanza& stanza) { if (body && !body->get_inner().empty()) { - bridge->send_channel_message(iid, body->get_inner()); + if (bridge->is_resource_in_chan(iid.to_tuple(), from.resource)) + bridge->send_channel_message(iid, body->get_inner(), id); + else + { + error_type = "modify"; + error_name = "not-acceptable"; + error_text = "You are not a participant in this room."; + return; + } } const XmlNode* subject = stanza.get_child("subject", COMPONENT_NS); if (subject) @@ -346,7 +359,6 @@ void BiboumiComponent::handle_message(const Stanza& stanza) this->send_invitation_from_fulljid(std::to_string(iid), invite_to, from_str); } } - } } catch (const IRCNotConnected& ex) { @@ -467,8 +479,13 @@ void BiboumiComponent::handle_iq(const Stanza& stanza) #ifdef USE_DATABASE else if ((query = stanza.get_child("query", MAM_NS))) { - if (this->handle_mam_request(stanza)) - stanza_error.disable(); + try { + if (this->handle_mam_request(stanza)) + stanza_error.disable(); + } catch (const Database::RecordNotFound& exc) { + error_name = "item-not-found"; + return; + } } else if ((query = stanza.get_child("query", MUC_OWNER_NS))) { @@ -505,7 +522,11 @@ void BiboumiComponent::handle_iq(const Stanza& stanza) { if (node.empty()) { - this->send_irc_channel_disco_info(id, from, to_str); + const IrcClient* irc_client = bridge->find_irc_client(iid.get_server()); + const IrcChannel* irc_channel{}; + if (irc_client) + irc_channel = irc_client->find_channel(iid.get_local()); + this->send_irc_channel_disco_info(id, from, to_str, irc_channel); stanza_error.disable(); } else if (node == MUC_TRAFFIC_NS) @@ -547,24 +568,21 @@ void BiboumiComponent::handle_iq(const Stanza& stanza) if (to.local.empty()) { // Get biboumi's adhoc commands this->send_adhoc_commands_list(id, from, this->served_hostname, - (Config::get("admin", "") == - from_jid.bare()), + Config::is_in_list("admin", from_jid.bare()), this->adhoc_commands_handler); stanza_error.disable(); } else if (iid.type == Iid::Type::Server) { // Get the server's adhoc commands this->send_adhoc_commands_list(id, from, to_str, - (Config::get("admin", "") == - from_jid.bare()), + Config::is_in_list("admin", from_jid.bare()), this->irc_server_adhoc_commands_handler); stanza_error.disable(); } else if (iid.type == Iid::Type::Channel && to.resource.empty()) { // Get the channel's adhoc commands this->send_adhoc_commands_list(id, from, to_str, - (Config::get("admin", "") == - from_jid.bare()), + Config::is_in_list("admin", from_jid.bare()), this->irc_channel_adhoc_commands_handler); stanza_error.disable(); } @@ -586,7 +604,6 @@ void BiboumiComponent::handle_iq(const Stanza& stanza) const XmlNode* max = set_node->get_child("max", RSM_NS); if (max) rs_info.max = std::atoi(max->get_inner().data()); - } if (rs_info.max == -1) rs_info.max = 100; @@ -715,32 +732,47 @@ bool BiboumiComponent::handle_mam_request(const Stanza& stanza) } const XmlNode* set = query->get_child("set", RSM_NS); int limit = -1; + Id::real_type reference_record_id{Id::unset_value}; + Database::Paging paging_order{Database::Paging::first}; if (set) { const XmlNode* max = set->get_child("max", RSM_NS); if (max) limit = std::atoi(max->get_inner().data()); + const XmlNode* after = set->get_child("after", RSM_NS); + if (after) + { + auto after_record = Database::get_muc_log(from.bare(), iid.get_local(), iid.get_server(), + after->get_inner(), start, end); + reference_record_id = after_record.col<Id>(); + } + const XmlNode* before = set->get_child("before", RSM_NS); + if (before) + { + paging_order = Database::Paging::last; + if (!before->get_inner().empty()) + { + auto before_record = Database::get_muc_log(from.bare(), iid.get_local(), iid.get_server(), before->get_inner(), start, end); + reference_record_id = before_record.col<Id>(); + } + } } - // Do send more than 100 messages, even if the client asked for more, + // Do not send more than 100 messages, even if the client asked for more, // or if it didn’t specify any limit. - // 101 is just a trick to know if there are more available messages. - // If our query returns 101 message, we know it’s incomplete, but we - // still send only 100 - if ((limit == -1 && start.empty() && end.empty()) - || limit > 100) - limit = 101; - auto lines = Database::get_muc_logs(from.bare(), iid.get_local(), iid.get_server(), limit, start, end); - bool complete = true; - if (lines.size() > 100) + if (limit < 0 || limit > 100) + limit = 100; + auto result = Database::get_muc_logs(from.bare(), iid.get_local(), iid.get_server(), + static_cast<std::size_t>(limit), + start, end, + reference_record_id, paging_order); + bool complete = std::get<bool>(result); + auto& lines = std::get<1>(result); + + for (const Database::MucLogLine& line: lines) { - complete = false; - lines.erase(lines.begin(), std::prev(lines.end(), 100)); + if (!line.col<Database::Nick>().empty()) + this->send_archived_message(line, to.full(), from.full(), query_id); } - for (const Database::MucLogLine& line: lines) - { - if (!line.col<Database::Nick>().empty()) - this->send_archived_message(line, to.full(), from.full(), query_id); - } { auto fin_ptr = std::make_unique<XmlNode>("fin"); { @@ -892,7 +924,7 @@ void BiboumiComponent::send_self_disco_info(const std::string& id, const std::st identity["category"] = "conference"; identity["type"] = "irc"; identity["name"] = "Biboumi XMPP-IRC gateway"; - for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS}) + for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS, STABLE_MUC_ID_NS}) { XmlSubNode feature(query, "feature"); feature["var"] = ns; @@ -916,7 +948,7 @@ void BiboumiComponent::send_irc_server_disco_info(const std::string& id, const s identity["category"] = "conference"; identity["type"] = "irc"; identity["name"] = "IRC server " + from.local + " over Biboumi"; - for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS}) + for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS, STABLE_MUC_ID_NS}) { XmlSubNode feature(query, "feature"); feature["var"] = ns; @@ -943,7 +975,8 @@ void BiboumiComponent::send_irc_channel_muc_traffic_info(const std::string& id, this->send_stanza(iq); } -void BiboumiComponent::send_irc_channel_disco_info(const std::string& id, const std::string& jid_to, const std::string& jid_from) +void BiboumiComponent::send_irc_channel_disco_info(const std::string& id, const std::string& jid_to, + const std::string& jid_from, const IrcChannel* irc_channel) { Jid from(jid_from); Iid iid(from.local, {}); @@ -958,12 +991,32 @@ void BiboumiComponent::send_irc_channel_disco_info(const std::string& id, const XmlSubNode identity(query, "identity"); identity["category"] = "conference"; identity["type"] = "irc"; - identity["name"] = "IRC channel " + iid.get_local() + " from server " + iid.get_server() + " over biboumi"; - for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS}) + identity["name"] = ""s + iid.get_local() + " on " + iid.get_server(); + for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS, STABLE_MUC_ID_NS}) { XmlSubNode feature(query, "feature"); feature["var"] = ns; } + + XmlSubNode x(query, "x"); + x["xmlns"] = DATAFORM_NS; + x["type"] = "result"; + { + XmlSubNode field(x, "field"); + field["var"] = "FORM_TYPE"; + field["type"] = "hidden"; + XmlSubNode value(field, "value"); + value.set_inner("http://jabber.org/protocol/muc#roominfo"); + } + + if (irc_channel && irc_channel->joined) + { + XmlSubNode field(x, "field"); + field["var"] = "muc#roominfo_occupants"; + field["label"] = "Number of occupants"; + XmlSubNode value(field, "value"); + value.set_inner(std::to_string(irc_channel->get_users().size())); + } } this->send_stanza(iq); } diff --git a/src/xmpp/biboumi_component.hpp b/src/xmpp/biboumi_component.hpp index caf990e..f59ed9b 100644 --- a/src/xmpp/biboumi_component.hpp +++ b/src/xmpp/biboumi_component.hpp @@ -73,7 +73,8 @@ public: * http://xmpp.org/extensions/xep-0045.html#impl-service-traffic */ void send_irc_channel_muc_traffic_info(const std::string& id, const std::string& jid_to, const std::string& jid_from); - void send_irc_channel_disco_info(const std::string& id, const std::string& jid_to, const std::string& jid_from); + void send_irc_channel_disco_info(const std::string& id, const std::string& jid_to, const std::string& jid_from, + const IrcChannel* irc_channel); /** * Send a ping request */ diff --git a/src/xmpp/jid.cpp b/src/xmpp/jid.cpp index 19d1b55..3c54fd4 100644 --- a/src/xmpp/jid.cpp +++ b/src/xmpp/jid.cpp @@ -106,7 +106,7 @@ std::string jidprep(const std::string& original) --domain_end; if (domain_end != domain && special_chars.count(domain[0])) { - std::memmove(domain, domain + 1, domain_end - domain + 1); + std::memmove(domain, domain + 1, static_cast<std::size_t>(domain_end - domain) + 1); --domain_end; } // And if the final result is an empty string, return a dummy hostname diff --git a/src/xmpp/xmpp_component.cpp b/src/xmpp/xmpp_component.cpp index 9be9e34..b3d925e 100644 --- a/src/xmpp/xmpp_component.cpp +++ b/src/xmpp/xmpp_component.cpp @@ -2,6 +2,7 @@ #include <utils/scopeguard.hpp> #include <utils/tolower.hpp> #include <logger/logger.hpp> +#include <utils/uuid.hpp> #include <xmpp/xmpp_component.hpp> #include <config/config.hpp> @@ -14,8 +15,6 @@ #include <iostream> #include <set> -#include <uuid/uuid.h> - #include <cstdlib> #include <set> @@ -364,10 +363,11 @@ void XmppComponent::send_topic(const std::string& from, Xmpp::body&& topic, cons this->send_stanza(message); } -void XmppComponent::send_muc_message(const std::string& muc_name, const std::string& nick, Xmpp::body&& xmpp_body, const std::string& jid_to, std::string uuid) +void XmppComponent::send_muc_message(const std::string& muc_name, const std::string& nick, Xmpp::body&& xmpp_body, const std::string& jid_to, std::string uuid, std::string id) { Stanza message("message"); message["to"] = jid_to; + message["id"] = std::move(id); if (!nick.empty()) message["from"] = muc_name + "@" + this->served_hostname + "/" + nick; else // Message from the room itself @@ -425,7 +425,8 @@ void XmppComponent::send_history_message(const std::string& muc_name, const std: #endif void XmppComponent::send_muc_leave(const std::string& muc_name, const std::string& nick, Xmpp::body&& message, - const std::string& jid_to, const bool self, const bool user_requested) + const std::string& jid_to, const bool self, const bool user_requested, + const std::string& affiliation, const std::string& role) { Stanza presence("presence"); { @@ -447,6 +448,9 @@ void XmppComponent::send_muc_leave(const std::string& muc_name, const std::strin status["code"] = "332"; } } + XmlSubNode item(x, "item"); + item["affiliation"] = affiliation; + item["role"] = role; if (!message_str.empty()) { XmlSubNode status(presence, "status"); @@ -669,9 +673,5 @@ void XmppComponent::send_iq_result(const std::string& id, const std::string& to_ std::string XmppComponent::next_id() { - char uuid_str[37]; - uuid_t uuid; - uuid_generate(uuid); - uuid_unparse(uuid, uuid_str); - return uuid_str; + return utils::gen_uuid(); } diff --git a/src/xmpp/xmpp_component.hpp b/src/xmpp/xmpp_component.hpp index 1daa6fb..e18da40 100644 --- a/src/xmpp/xmpp_component.hpp +++ b/src/xmpp/xmpp_component.hpp @@ -37,6 +37,7 @@ #define RSM_NS "http://jabber.org/protocol/rsm" #define MUC_TRAFFIC_NS "http://jabber.org/protocol/muc#traffic" #define STABLE_ID_NS "urn:xmpp:sid:0" +#define STABLE_MUC_ID_NS "http://jabber.org/protocol/muc#stable_id" /** * An XMPP component, communicating with an XMPP server using the protocole @@ -134,7 +135,7 @@ public: * Send a (non-private) message to the MUC */ void send_muc_message(const std::string& muc_name, const std::string& nick, Xmpp::body&& body, const std::string& jid_to, - std::string uuid); + std::string uuid, std::string id); #ifdef USE_DATABASE /** * Send a message, with a <delay/> element, part of a MUC history @@ -150,7 +151,8 @@ public: Xmpp::body&& message, const std::string& jid_to, const bool self, - const bool user_requested); + const bool user_requested, + const std::string& affiliation, const std::string& role); /** * Indicate that a participant changed his nick */ diff --git a/src/xmpp/xmpp_parser.cpp b/src/xmpp/xmpp_parser.cpp index 0488be9..781fe4c 100644 --- a/src/xmpp/xmpp_parser.cpp +++ b/src/xmpp/xmpp_parser.cpp @@ -20,7 +20,7 @@ static void end_element_handler(void* user_data, const XML_Char* name) static void character_data_handler(void *user_data, const XML_Char *s, int len) { - static_cast<XmppParser*>(user_data)->char_data(s, len); + static_cast<XmppParser*>(user_data)->char_data(s, static_cast<std::size_t>(len)); } /** |