summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/biboumi.h.cmake12
-rw-r--r--src/bridge/bridge.cpp218
-rw-r--r--src/bridge/bridge.hpp24
-rw-r--r--src/bridge/colors.cpp2
-rw-r--r--src/bridge/colors.hpp2
-rw-r--r--src/config/config.cpp127
-rw-r--r--src/config/config.hpp93
-rw-r--r--src/database/column.hpp17
-rw-r--r--src/database/count_query.hpp35
-rw-r--r--src/database/database.cpp223
-rw-r--r--src/database/database.hpp162
-rw-r--r--src/database/insert_query.hpp129
-rw-r--r--src/database/query.cpp34
-rw-r--r--src/database/query.hpp90
-rw-r--r--src/database/row.hpp75
-rw-r--r--src/database/select_query.hpp127
-rw-r--r--src/database/statement.hpp35
-rw-r--r--src/database/table.cpp25
-rw-r--r--src/database/table.hpp127
-rw-r--r--src/database/type_to_sql.cpp9
-rw-r--r--src/database/type_to_sql.hpp16
-rw-r--r--src/identd/identd_server.hpp39
-rw-r--r--src/identd/identd_socket.cpp63
-rw-r--r--src/identd/identd_socket.hpp36
-rw-r--r--src/irc/iid.cpp12
-rw-r--r--src/irc/iid.hpp3
-rw-r--r--src/irc/irc_channel.cpp24
-rw-r--r--src/irc/irc_channel.hpp5
-rw-r--r--src/irc/irc_client.cpp235
-rw-r--r--src/irc/irc_client.hpp17
-rw-r--r--src/irc/irc_message.cpp6
-rw-r--r--src/irc/irc_user.cpp2
-rw-r--r--src/irc/irc_user.hpp2
-rw-r--r--src/logger/logger.cpp42
-rw-r--r--src/logger/logger.hpp128
-rw-r--r--src/main.cpp75
-rw-r--r--src/network/credentials_manager.cpp136
-rw-r--r--src/network/credentials_manager.hpp55
-rw-r--r--src/network/dns_handler.cpp46
-rw-r--r--src/network/dns_handler.hpp37
-rw-r--r--src/network/dns_socket_handler.cpp43
-rw-r--r--src/network/dns_socket_handler.hpp33
-rw-r--r--src/network/poller.cpp238
-rw-r--r--src/network/poller.hpp98
-rw-r--r--src/network/resolver.cpp280
-rw-r--r--src/network/resolver.hpp122
-rw-r--r--src/network/socket_handler.hpp42
-rw-r--r--src/network/tcp_client_socket_handler.cpp261
-rw-r--r--src/network/tcp_client_socket_handler.hpp82
-rw-r--r--src/network/tcp_server_socket.hpp69
-rw-r--r--src/network/tcp_socket_handler.cpp360
-rw-r--r--src/network/tcp_socket_handler.hpp238
-rw-r--r--src/network/tls_policy.cpp48
-rw-r--r--src/network/tls_policy.hpp28
-rw-r--r--src/utils/dirname.cpp16
-rw-r--r--src/utils/dirname.hpp6
-rw-r--r--src/utils/encoding.cpp254
-rw-r--r--src/utils/encoding.hpp43
-rw-r--r--src/utils/get_first_non_empty.cpp11
-rw-r--r--src/utils/get_first_non_empty.hpp20
-rw-r--r--src/utils/optional_bool.hpp35
-rw-r--r--src/utils/reload.cpp2
-rw-r--r--src/utils/revstr.cpp9
-rw-r--r--src/utils/revstr.hpp11
-rw-r--r--src/utils/scopeguard.hpp98
-rw-r--r--src/utils/sha1.cpp40
-rw-r--r--src/utils/sha1.hpp5
-rw-r--r--src/utils/split.cpp19
-rw-r--r--src/utils/split.hpp12
-rw-r--r--src/utils/string.cpp28
-rw-r--r--src/utils/string.hpp8
-rw-r--r--src/utils/system.cpp21
-rw-r--r--src/utils/system.hpp8
-rw-r--r--src/utils/time.cpp80
-rw-r--r--src/utils/time.hpp10
-rw-r--r--src/utils/timed_events.cpp48
-rw-r--r--src/utils/timed_events.hpp137
-rw-r--r--src/utils/timed_events_manager.cpp87
-rw-r--r--src/utils/tolower.cpp13
-rw-r--r--src/utils/tolower.hpp11
-rw-r--r--src/utils/xdg.cpp29
-rw-r--r--src/utils/xdg.hpp12
-rw-r--r--src/xmpp/adhoc_command.cpp81
-rw-r--r--src/xmpp/adhoc_command.hpp44
-rw-r--r--src/xmpp/adhoc_commands_handler.cpp111
-rw-r--r--src/xmpp/adhoc_commands_handler.hpp71
-rw-r--r--src/xmpp/adhoc_session.cpp35
-rw-r--r--src/xmpp/adhoc_session.hpp88
-rw-r--r--src/xmpp/auth.cpp8
-rw-r--r--src/xmpp/auth.hpp6
-rw-r--r--src/xmpp/biboumi_adhoc_commands.cpp694
-rw-r--r--src/xmpp/biboumi_adhoc_commands.hpp3
-rw-r--r--src/xmpp/biboumi_component.cpp437
-rw-r--r--src/xmpp/biboumi_component.hpp19
-rw-r--r--src/xmpp/body.hpp12
-rw-r--r--src/xmpp/jid.cpp152
-rw-r--r--src/xmpp/jid.hpp49
-rw-r--r--src/xmpp/xmpp_component.cpp684
-rw-r--r--src/xmpp/xmpp_component.hpp248
-rw-r--r--src/xmpp/xmpp_parser.cpp172
-rw-r--r--src/xmpp/xmpp_parser.hpp133
-rw-r--r--src/xmpp/xmpp_stanza.cpp229
-rw-r--r--src/xmpp/xmpp_stanza.hpp160
103 files changed, 8159 insertions, 837 deletions
diff --git a/src/biboumi.h.cmake b/src/biboumi.h.cmake
new file mode 100644
index 0000000..1ad9a40
--- /dev/null
+++ b/src/biboumi.h.cmake
@@ -0,0 +1,12 @@
+#cmakedefine USE_DATABASE
+#cmakedefine ICONV_SECOND_ARGUMENT_IS_CONST
+#cmakedefine LIBIDN_FOUND
+#cmakedefine SYSTEMD_FOUND
+#cmakedefine POLLER ${POLLER}
+#cmakedefine BOTAN_FOUND
+#cmakedefine GCRYPT_FOUND
+#cmakedefine UDNS_FOUND
+#cmakedefine SOFTWARE_VERSION "${SOFTWARE_VERSION}"
+#cmakedefine PROJECT_NAME "${PROJECT_NAME}"
+#cmakedefine HAS_GET_TIME
+#cmakedefine HAS_PUT_TIME
diff --git a/src/bridge/bridge.cpp b/src/bridge/bridge.cpp
index a0ecc6e..81ca147 100644
--- a/src/bridge/bridge.cpp
+++ b/src/bridge/bridge.cpp
@@ -1,4 +1,5 @@
#include <bridge/bridge.hpp>
+#include <utility>
#include <xmpp/biboumi_component.hpp>
#include <network/poller.hpp>
#include <utils/empty_if_fixed_server.hpp>
@@ -22,20 +23,24 @@ static std::string in_encoding_for(const Bridge& bridge, const Iid& iid)
#ifdef USE_DATABASE
const auto jid = bridge.get_bare_jid();
auto options = Database::get_irc_channel_options_with_server_default(jid, iid.get_server(), iid.get_local());
- return options.encodingIn.value();
+ auto result = options.col<Database::EncodingIn>();
+ if (!result.empty())
+ return result;
#else
- return {"ISO-8859-1"};
+ (void)bridge;
+ (void)iid;
#endif
+ return {"ISO-8859-1"};
}
-Bridge::Bridge(const std::string& user_jid, BiboumiComponent& xmpp, std::shared_ptr<Poller> poller):
- user_jid(user_jid),
+Bridge::Bridge(std::string user_jid, BiboumiComponent& xmpp, std::shared_ptr<Poller>& poller):
+ user_jid(std::move(user_jid)),
xmpp(xmpp),
poller(poller)
{
#ifdef USE_DATABASE
const auto options = Database::get_global_options(this->user_jid);
- this->set_record_history(options.recordHistory.value());
+ this->set_record_history(options.col<Database::RecordHistory>());
#endif
}
@@ -58,10 +63,10 @@ static std::tuple<std::string, std::string> get_role_affiliation_from_irc_mode(c
void Bridge::shutdown(const std::string& exit_message)
{
- for (auto it = this->irc_clients.begin(); it != this->irc_clients.end(); ++it)
+ for (auto& pair: this->irc_clients)
{
- it->second->send_quit_command(exit_message);
- it->second->leave_dummy_channel(exit_message);
+ pair.second->send_quit_command(exit_message);
+ pair.second->leave_dummy_channel(exit_message, {});
}
}
@@ -167,7 +172,7 @@ IrcClient* Bridge::find_irc_client(const std::string& hostname) const
bool Bridge::join_irc_channel(const Iid& iid, const std::string& nickname, const std::string& password,
const std::string& resource)
{
- const auto hostname = iid.get_server();
+ const auto& hostname = iid.get_server();
IrcClient* irc = this->make_irc_client(hostname, nickname);
this->add_resource_to_server(hostname, resource);
auto res_in_chan = this->is_resource_in_chan(ChannelKey{iid.get_local(), hostname}, resource);
@@ -251,21 +256,24 @@ void Bridge::send_channel_message(const Iid& iid, const std::string& body)
else
irc->send_channel_message(iid.get_local(), line);
+ std::string uuid;
#ifdef USE_DATABASE
const auto xmpp_body = this->make_xmpp_body(line);
if (this->record_history)
- Database::store_muc_message(this->get_bare_jid(), iid, std::chrono::system_clock::now(),
+ uuid = Database::store_muc_message(this->get_bare_jid(), iid.get_local(), iid.get_server(), std::chrono::system_clock::now(),
std::get<0>(xmpp_body), irc->get_own_nick());
#endif
for (const auto& resource: this->resources_in_chan[iid.to_tuple()])
- this->xmpp.send_muc_message(std::to_string(iid), irc->get_own_nick(),
- this->make_xmpp_body(line), this->user_jid + "/" + resource);
+ this->xmpp.send_muc_message(std::to_string(iid), irc->get_own_nick(), this->make_xmpp_body(line),
+ this->user_jid + "/" + resource, uuid);
}
}
-void Bridge::forward_affiliation_role_change(const Iid& iid, const std::string& nick,
+void Bridge::forward_affiliation_role_change(const Iid& iid, const std::string& from,
+ const std::string& nick,
const std::string& affiliation,
- const std::string& role)
+ const std::string& role,
+ const std::string& id)
{
IrcClient* irc = this->get_irc_client(iid.get_server());
IrcChannel* chan = irc->get_channel(iid.get_local());
@@ -273,7 +281,11 @@ void Bridge::forward_affiliation_role_change(const Iid& iid, const std::string&
return;
IrcUser* user = chan->find_user(nick);
if (!user)
- return;
+ {
+ this->xmpp.send_stanza_error("iq", from, std::to_string(iid), id, "cancel",
+ "item-not-found", "no such nick", false);
+ return;
+ }
// For each affiliation or role, we have a “maximal” mode that we want to
// set. We must remove any superior mode at the same time. For example if
// the user already has +o mode, and we set its affiliation to member, we
@@ -325,6 +337,56 @@ void Bridge::forward_affiliation_role_change(const Iid& iid, const std::string&
std::vector<std::string> args(nb, nick);
args.insert(args.begin(), modes);
irc->send_mode_command(iid.get_local(), args);
+
+ irc_responder_callback_t cb = [this, iid, irc, id, from, nick](const std::string& irc_hostname, const IrcMessage& message) -> bool
+ {
+ if (irc_hostname != iid.get_server())
+ return false;
+
+ if (message.command == "MODE" && message.arguments.size() >= 2)
+ {
+ const std::string& chan_name = message.arguments[0];
+ if (chan_name != iid.get_local())
+ return false;
+ const std::string actor_nick = IrcUser{message.prefix}.nick;
+ if (!irc || irc->get_own_nick() != actor_nick)
+ return false;
+
+ this->xmpp.send_iq_result(id, from, std::to_string(iid));
+ }
+ else if (message.command == "401" && message.arguments.size() >= 2)
+ {
+ const std::string target_later = message.arguments[1];
+ if (target_later != nick)
+ return false;
+ std::string error_message = "No such nick";
+ if (message.arguments.size() >= 3)
+ error_message = message.arguments[2];
+ this->xmpp.send_stanza_error("iq", from, std::to_string(iid), id, "cancel", "item-not-found",
+ error_message, false);
+ }
+ else if (message.command == "482" && message.arguments.size() >= 2)
+ {
+ const std::string chan_name_later = utils::tolower(message.arguments[1]);
+ if (chan_name_later != iid.get_local())
+ return false;
+ std::string error_message = "You're not channel operator";
+ if (message.arguments.size() >= 3)
+ error_message = message.arguments[2];
+ this->xmpp.send_stanza_error("iq", from, std::to_string(iid), id, "cancel", "not-allowed",
+ error_message, false);
+ }
+ else if (message.command == "472" && message.arguments.size() >= 2)
+ {
+ std::string error_message = "Unknown mode: "s + message.arguments[1];
+ if (message.arguments.size() >= 3)
+ error_message = message.arguments[2];
+ this->xmpp.send_stanza_error("iq", from, std::to_string(iid), id, "cancel", "not-allowed",
+ error_message, false);
+ }
+ return true;
+ };
+ this->add_waiting_irc(std::move(cb));
}
void Bridge::send_private_message(const Iid& iid, const std::string& body, const std::string& type)
@@ -364,37 +426,60 @@ void Bridge::leave_irc_channel(Iid&& iid, const std::string& status_message, con
if (!this->is_resource_in_chan(key, resource))
return ;
+ IrcChannel* channel = irc->get_channel(iid.get_local());
+
const auto resources = this->number_of_resources_in_chan(key);
if (resources == 1)
{
// Do not send a PART message if we actually are not in that channel
// or if we already sent a PART but we are just waiting for the
// acknowledgment from the server
- IrcChannel* channel = irc->get_channel(iid.get_local());
- if (channel->joined && !channel->parting)
- irc->send_part_command(iid.get_local(), status_message);
+ bool persistent = false;
+#ifdef USE_DATABASE
+ const auto goptions = Database::get_global_options(this->user_jid);
+ if (goptions.col<Database::Persistent>())
+ persistent = true;
+ else
+ {
+ const auto coptions = Database::get_irc_channel_options_with_server_default(this->user_jid, iid.get_server(), iid.get_local());
+ persistent = coptions.col<Database::Persistent>();
+ }
+#endif
+ if (channel->joined && !channel->parting && !persistent)
+ {
+ const auto& chan_name = iid.get_local();
+ if (chan_name.empty())
+ irc->leave_dummy_channel(status_message, resource);
+ else
+ irc->send_part_command(iid.get_local(), status_message);
+ }
+ else if (channel->joined)
+ {
+ this->send_muc_leave(iid, channel->get_self()->nick, "", true, resource);
+ }
// Since there are no resources left in that channel, we don't
// want to receive private messages using this room's JID
this->remove_all_preferred_from_jid_of_room(iid.get_local());
}
else
{
- IrcChannel* chan = irc->get_channel(iid.get_local());
- if (chan)
- {
- auto nick = chan->get_self()->nick;
- this->remove_resource_from_chan(key, resource);
- this->send_muc_leave(std::move(iid), std::move(nick),
- "Biboumi note: "s + std::to_string(resources - 1) + " resources are still in this channel.",
- true, resource);
- if (this->number_of_channels_the_resource_is_in(iid.get_server(), resource) == 0)
- this->remove_resource_from_server(iid.get_server(), resource);
- }
+ if (channel && channel->joined)
+ this->send_muc_leave(iid, channel->get_self()->nick,
+ "Biboumi note: "s + std::to_string(resources - 1) + " resources are still in this channel.",
+ true, resource);
+ this->remove_resource_from_chan(key, resource);
+ if (this->number_of_channels_the_resource_is_in(iid.get_server(), resource) == 0)
+ this->remove_resource_from_server(iid.get_server(), resource);
}
+
}
-void Bridge::send_irc_nick_change(const Iid& iid, const std::string& new_nick)
+void Bridge::send_irc_nick_change(const Iid& iid, const std::string& new_nick, const std::string& requesting_resource)
{
+ // We don’t change the nick if the presence was sent to a channel the resource is not in.
+ auto res_in_chan = this->is_resource_in_chan(ChannelKey{iid.get_local(), iid.get_server()}, requesting_resource);
+ if (!res_in_chan)
+ return;
IrcClient* irc = this->get_irc_client(iid.get_server());
irc->send_nick_command(new_nick);
}
@@ -402,7 +487,7 @@ void Bridge::send_irc_nick_change(const Iid& iid, const std::string& new_nick)
void Bridge::send_irc_channel_list_request(const Iid& iid, const std::string& iq_id, const std::string& to_jid,
ResultSetInfo rs_info)
{
- auto& list = channel_list_cache[iid.get_server()];
+ auto& list = this->channel_list_cache[iid.get_server()];
// We fetch the list from the IRC server only if we have a complete
// cached list that needs to be invalidated (that is, when the request
@@ -425,7 +510,7 @@ void Bridge::send_irc_channel_list_request(const Iid& iid, const std::string& iq
if (irc_hostname != iid.get_server())
return false;
- auto& list = channel_list_cache[iid.get_server()];
+ auto& list = this->channel_list_cache[iid.get_server()];
if (message.command == "263" || message.command == "RPL_TRYAGAIN" || message.command == "ERR_TOOMANYMATCHES"
|| message.command == "ERR_NOSUCHSERVER")
@@ -483,7 +568,6 @@ void Bridge::send_irc_channel_list_request(const Iid& iid, const std::string& iq
{
auto& list = channel_list_cache[iid.get_server()];
const auto res = this->send_matching_channel_list(list, rs_info, iq_id, to_jid, std::to_string(iid));
- log_debug("We added a new channel in our list, can we send the result? ", std::boolalpha, res);
return res;
}
else if (message.command == "323" || message.command == "RPL_LISTEND")
@@ -503,7 +587,7 @@ bool Bridge::send_matching_channel_list(const ChannelList& channel_list, const R
const std::string& id, const std::string& to_jid, const std::string& from)
{
auto begin = channel_list.channels.begin();
- auto end = channel_list.channels.begin();
+ auto end = channel_list.channels.end();
if (channel_list.complete)
{
begin = std::find_if(channel_list.channels.begin(), channel_list.channels.end(), [this, &rs_info](const ListElement& element)
@@ -605,9 +689,12 @@ void Bridge::send_irc_kick(const Iid& iid, const std::string& target, const std:
this->add_waiting_irc(std::move(cb));
}
-void Bridge::set_channel_topic(const Iid& iid, const std::string& subject)
+void Bridge::set_channel_topic(const Iid& iid, std::string subject)
{
IrcClient* irc = this->get_irc_client(iid.get_server());
+ std::string::size_type pos{0};
+ while ((pos = subject.find('\n', pos)) != std::string::npos)
+ subject[pos] = ' ';
irc->send_topic_command(iid.get_local(), subject);
}
@@ -665,9 +752,10 @@ void Bridge::send_irc_participant_ping_request(const Iid& iid, const std::string
const std::string& iq_id, const std::string& to_jid,
const std::string& from_jid)
{
+ Jid from(to_jid);
IrcClient* irc = this->get_irc_client(iid.get_server());
IrcChannel* chan = irc->get_channel(iid.get_local());
- if (!chan->joined)
+ if (!chan->joined || !this->is_resource_in_chan(iid.to_tuple(), from.resource))
{
this->xmpp.send_stanza_error("iq", to_jid, from_jid, iq_id, "cancel", "not-allowed",
"", true);
@@ -756,13 +844,13 @@ void Bridge::send_message(const Iid& iid, const std::string& nick, const std::st
#ifdef USE_DATABASE
const auto xmpp_body = this->make_xmpp_body(body, encoding);
if (!nick.empty() && this->record_history)
- Database::store_muc_message(this->get_bare_jid(), iid, std::chrono::system_clock::now(),
+ Database::store_muc_message(this->get_bare_jid(), iid.get_local(), iid.get_server(), std::chrono::system_clock::now(),
std::get<0>(xmpp_body), nick);
#endif
for (const auto& resource: this->resources_in_chan[iid.to_tuple()])
{
- this->xmpp.send_muc_message(std::to_string(iid), nick,
- this->make_xmpp_body(body, encoding), this->user_jid + "/" + resource);
+ this->xmpp.send_muc_message(std::to_string(iid), nick, this->make_xmpp_body(body, encoding),
+ this->user_jid + "/" + resource, {});
}
}
@@ -793,18 +881,24 @@ void Bridge::send_presence_error(const Iid& iid, const std::string& nick,
this->xmpp.send_presence_error(std::to_string(iid), nick, this->user_jid, type, condition, error_code, text);
}
-void Bridge::send_muc_leave(Iid&& iid, std::string&& nick, const std::string& message, const bool self,
+void Bridge::send_muc_leave(const Iid& iid, const std::string& nick,
+ const std::string& message, const bool self,
const std::string& resource)
{
if (!resource.empty())
- this->xmpp.send_muc_leave(std::to_string(iid), std::move(nick), this->make_xmpp_body(message),
+ this->xmpp.send_muc_leave(std::to_string(iid), nick, this->make_xmpp_body(message),
this->user_jid + "/" + resource, self);
else
- for (const auto& res: this->resources_in_chan[iid.to_tuple()])
- this->xmpp.send_muc_leave(std::to_string(iid), std::move(nick), this->make_xmpp_body(message),
- this->user_jid + "/" + res, self);
+ {
+ for (const auto &res: this->resources_in_chan[iid.to_tuple()])
+ this->xmpp.send_muc_leave(std::to_string(iid), nick, this->make_xmpp_body(message),
+ this->user_jid + "/" + res, self);
+ if (self)
+ this->remove_all_resources_from_chan(iid.to_tuple());
+
+ }
IrcClient* irc = this->find_irc_client(iid.get_server());
- if (irc && irc->number_of_joined_channels() == 0)
+ if (self && irc && irc->number_of_joined_channels() == 0)
irc->send_quit_command("");
}
@@ -903,18 +997,22 @@ void Bridge::send_room_history(const std::string& hostname, const std::string& c
this->send_room_history(hostname, chan_name, resource);
}
-void Bridge::send_room_history(const std::string& hostname, const std::string& chan_name, const std::string& resource)
+void Bridge::send_room_history(const std::string& hostname, std::string chan_name, const std::string& resource)
{
#ifdef USE_DATABASE
const auto coptions = Database::get_irc_channel_options_with_server_and_global_default(this->user_jid, hostname, chan_name);
- const auto lines = Database::get_muc_logs(this->user_jid, chan_name, hostname, coptions.maxHistoryLength.value());
+ const auto lines = Database::get_muc_logs(this->user_jid, chan_name, hostname, coptions.col<Database::MaxHistoryLength>());
+ chan_name.append(utils::empty_if_fixed_server("%" + hostname));
for (const auto& line: lines)
{
- const auto seconds = line.date.value().timeStamp();
- this->xmpp.send_history_message(chan_name + utils::empty_if_fixed_server("%" + hostname), line.nick.value(),
- line.body.value(),
+ const auto seconds = line.col<Database::Date>();
+ this->xmpp.send_history_message(chan_name, line.col<Database::Nick>(), line.col<Database::Body>(),
this->user_jid + "/" + resource, seconds);
}
+#else
+ (void)hostname;
+ (void)chan_name;
+ (void)resource;
#endif
}
@@ -960,7 +1058,7 @@ void Bridge::send_iq_version_request(const std::string& nick, const std::string&
{
const auto resources = this->resources_in_server[hostname];
if (resources.begin() != resources.end())
- this->xmpp.send_iq_version_request(utils::tolower(nick) + "%" + utils::empty_if_fixed_server(hostname),
+ this->xmpp.send_iq_version_request(utils::tolower(nick) + utils::empty_if_fixed_server("%" + hostname),
this->user_jid + "/" + *resources.begin());
}
@@ -968,12 +1066,12 @@ void Bridge::send_xmpp_ping_request(const std::string& nick, const std::string&
const std::string& id)
{
// Use revstr because the forwarded ping to target XMPP user must not be
- // the same that the request iq, but we also need to get it back easily
+ // the same as the request iq, but we also need to get it back easily
// (revstr again)
// Forward to the first resource (arbitrary, based on the “order” of the std::set) only
const auto resources = this->resources_in_server[hostname];
if (resources.begin() != resources.end())
- this->xmpp.send_ping_request(utils::tolower(nick) + "%" + utils::empty_if_fixed_server(hostname),
+ this->xmpp.send_ping_request(utils::tolower(nick) + utils::empty_if_fixed_server("%" + hostname),
this->user_jid + "/" + *resources.begin(), utils::revstr(id));
}
@@ -1033,6 +1131,11 @@ std::unordered_map<std::string, std::shared_ptr<IrcClient>>& Bridge::get_irc_cli
return this->irc_clients;
}
+const std::unordered_map<std::string, std::shared_ptr<IrcClient>>& Bridge::get_irc_clients() const
+{
+ return this->irc_clients;
+}
+
std::set<char> Bridge::get_chantypes(const std::string& hostname) const
{
IrcClient* irc = this->find_irc_client(hostname);
@@ -1070,6 +1173,11 @@ bool Bridge::is_resource_in_chan(const Bridge::ChannelKey& channel, const std::s
return false;
}
+void Bridge::remove_all_resources_from_chan(const Bridge::ChannelKey& channel)
+{
+ this->resources_in_chan.erase(channel);
+}
+
void Bridge::add_resource_to_server(const Bridge::IrcHostname& irc_hostname, const std::string& resource)
{
auto it = this->resources_in_server.find(irc_hostname);
@@ -1099,9 +1207,9 @@ bool Bridge::is_resource_in_server(const Bridge::IrcHostname& irc_hostname, cons
return false;
}
-std::size_t Bridge::number_of_resources_in_chan(const Bridge::ChannelKey& channel_key) const
+std::size_t Bridge::number_of_resources_in_chan(const Bridge::ChannelKey& channel) const
{
- auto it = this->resources_in_chan.find(channel_key);
+ auto it = this->resources_in_chan.find(channel);
if (it == this->resources_in_chan.end())
return 0;
return it->second.size();
diff --git a/src/bridge/bridge.hpp b/src/bridge/bridge.hpp
index 18ebfeb..033291c 100644
--- a/src/bridge/bridge.hpp
+++ b/src/bridge/bridge.hpp
@@ -38,7 +38,7 @@ using irc_responder_callback_t = std::function<bool(const std::string& irc_hostn
class Bridge
{
public:
- explicit Bridge(const std::string& user_jid, BiboumiComponent& xmpp, std::shared_ptr<Poller> poller);
+ explicit Bridge(std::string user_jid, BiboumiComponent& xmpp, std::shared_ptr<Poller>& poller);
~Bridge() = default;
Bridge(const Bridge&) = delete;
@@ -80,10 +80,10 @@ public:
void send_private_message(const Iid& iid, const std::string& body, const std::string& type="PRIVMSG");
void send_raw_message(const std::string& hostname, const std::string& body);
void leave_irc_channel(Iid&& iid, const std::string& status_message, const std::string& resource);
- void send_irc_nick_change(const Iid& iid, const std::string& new_nick);
+ void send_irc_nick_change(const Iid& iid, const std::string& new_nick, const std::string& requesting_resource);
void send_irc_kick(const Iid& iid, const std::string& target, const std::string& reason,
const std::string& iq_id, const std::string& to_jid);
- void set_channel_topic(const Iid& iid, const std::string& subject);
+ void set_channel_topic(const Iid& iid, std::string subject);
void send_xmpp_version_to_irc(const Iid& iid, const std::string& name, const std::string& version,
const std::string& os);
void send_irc_ping_result(const Iid& iid, const std::string& id);
@@ -103,8 +103,8 @@ public:
bool send_matching_channel_list(const ChannelList& channel_list,
const ResultSetInfo& rs_info, const std::string& id, const std::string& to_jid,
const std::string& from);
- void forward_affiliation_role_change(const Iid& iid, const std::string& nick,
- const std::string& affiliation, const std::string& role);
+ void forward_affiliation_role_change(const Iid& iid, const std::string& from, const std::string& nick,
+ const std::string& affiliation, const std::string& role, const std::string& id);
/**
* Directly send a CTCP PING request to the IRC user
*/
@@ -157,7 +157,7 @@ public:
* Send the MUC history to the user
*/
void send_room_history(const std::string& hostname, const std::string& chan_name);
- void send_room_history(const std::string& hostname, const std::string& chan_name, const std::string& resource);
+ void send_room_history(const std::string& hostname, std::string chan_name, const std::string& resource);
/**
* Send a MUC message from some participant
*/
@@ -169,7 +169,7 @@ public:
/**
* Send an unavailable presence from this participant
*/
- void send_muc_leave(Iid&& iid, std::string&& nick, const std::string& message, const bool self, const std::string& resource="");
+ void send_muc_leave(const Iid& iid, const std::string& nick, const std::string& message, const bool self, const std::string& resource = "");
/**
* Send presences to indicate that an user old_nick (ourself if self ==
* true) changed his nick to new_nick. The user_mode is needed because
@@ -231,6 +231,7 @@ public:
*/
void trigger_on_irc_message(const std::string& irc_hostname, const IrcMessage& message);
std::unordered_map<std::string, std::shared_ptr<IrcClient>>& get_irc_clients();
+ const std::unordered_map<std::string, std::shared_ptr<IrcClient>>& get_irc_clients() const;
std::set<char> get_chantypes(const std::string& hostname) const;
#ifdef USE_DATABASE
void set_record_history(const bool val);
@@ -302,10 +303,11 @@ private:
/**
* Manage which resource is in which channel
*/
- void add_resource_to_chan(const ChannelKey& channel_key, const std::string& resource);
- void remove_resource_from_chan(const ChannelKey& channel_key, const std::string& resource);
- bool is_resource_in_chan(const ChannelKey& channel_key, const std::string& resource) const;
- std::size_t number_of_resources_in_chan(const ChannelKey& channel_key) const;
+ void add_resource_to_chan(const ChannelKey& channel, const std::string& resource);
+ void remove_resource_from_chan(const ChannelKey& channel, const std::string& resource);
+ bool is_resource_in_chan(const ChannelKey& channel, const std::string& resource) const;
+ void remove_all_resources_from_chan(const ChannelKey& channel);
+ std::size_t number_of_resources_in_chan(const ChannelKey& channel) const;
void add_resource_to_server(const IrcHostname& irc_hostname, const std::string& resource);
void remove_resource_from_server(const IrcHostname& irc_hostname, const std::string& resource);
diff --git a/src/bridge/colors.cpp b/src/bridge/colors.cpp
index 66f51ee..7662425 100644
--- a/src/bridge/colors.cpp
+++ b/src/bridge/colors.cpp
@@ -4,7 +4,7 @@
#include <algorithm>
#include <iostream>
-#include <string.h>
+#include <cstring>
using namespace std::string_literals;
diff --git a/src/bridge/colors.hpp b/src/bridge/colors.hpp
index e2c8a87..dceed74 100644
--- a/src/bridge/colors.hpp
+++ b/src/bridge/colors.hpp
@@ -51,6 +51,6 @@ static const char irc_format_char[] = {
* Returns the body cleaned from any IRC formatting (but without any xhtml),
* and the body as XHTML-IM
*/
-Xmpp::body irc_format_to_xhtmlim(const std::string& str);
+Xmpp::body irc_format_to_xhtmlim(const std::string& s);
diff --git a/src/config/config.cpp b/src/config/config.cpp
new file mode 100644
index 0000000..0f3d639
--- /dev/null
+++ b/src/config/config.cpp
@@ -0,0 +1,127 @@
+#include <config/config.hpp>
+#include <utils/tolower.hpp>
+
+#include <iostream>
+#include <cstring>
+
+#include <cstdlib>
+
+using namespace std::string_literals;
+
+extern char** environ;
+
+std::string Config::filename{};
+std::map<std::string, std::string> Config::values{};
+std::vector<t_config_changed_callback> Config::callbacks{};
+
+std::string Config::get(const std::string& option, const std::string& def)
+{
+ auto it = Config::values.find(option);
+
+ if (it == Config::values.end())
+ return def;
+ return it->second;
+}
+
+int Config::get_int(const std::string& option, const int& def)
+{
+ std::string res = Config::get(option, "");
+ if (!res.empty())
+ return std::atoi(res.c_str());
+ else
+ return def;
+}
+
+void Config::set(const std::string& option, const std::string& value, bool save)
+{
+ Config::values[option] = value;
+ if (save)
+ {
+ Config::save_to_file();
+ Config::trigger_configuration_change();
+ }
+}
+
+void Config::connect(const t_config_changed_callback& callback)
+{
+ Config::callbacks.push_back(callback);
+}
+
+void Config::clear()
+{
+ Config::values.clear();
+}
+
+/**
+ * Private methods
+ */
+void Config::trigger_configuration_change()
+{
+ std::vector<t_config_changed_callback>::iterator it;
+ for (it = Config::callbacks.begin(); it < Config::callbacks.end(); ++it)
+ (*it)();
+}
+
+bool Config::read_conf(const std::string& name)
+{
+ if (!name.empty())
+ Config::filename = name;
+
+ std::ifstream file(Config::filename.data());
+ if (!file.is_open())
+ {
+ std::cerr << "Error while opening file " << filename << " for reading: " << strerror(errno) << std::endl;
+ return false;
+ }
+
+ Config::clear();
+
+ auto parse_line = [](const std::string& line, const bool env)
+ {
+ static const auto env_option_prefix = "BIBOUMI_"s;
+
+ if (line == "" || line[0] == '#')
+ return;
+ size_t pos = line.find('=');
+ if (pos == std::string::npos)
+ return;
+ std::string option = line.substr(0, pos);
+ std::string value = line.substr(pos+1);
+ if (env)
+ {
+ auto a = option.substr(0, env_option_prefix.size());
+ if (a == env_option_prefix)
+ option = utils::tolower(option.substr(env_option_prefix.size()));
+ else
+ return;
+ }
+ Config::values[option] = value;
+ };
+
+ std::string line;
+ while (file.good())
+ {
+ std::getline(file, line);
+ parse_line(line, false);
+ }
+
+ char** env_line = environ;
+ while (*env_line)
+ {
+ parse_line(*env_line, true);
+ env_line++;
+ }
+ return true;
+}
+
+void Config::save_to_file()
+{
+ std::ofstream file(Config::filename.data());
+ if (file.fail())
+ {
+ std::cerr << "Could not save config file." << std::endl;
+ return ;
+ }
+ for (const auto& it: Config::values)
+ file << it.first << "=" << it.second << '\n';
+}
diff --git a/src/config/config.hpp b/src/config/config.hpp
new file mode 100644
index 0000000..2ba38cc
--- /dev/null
+++ b/src/config/config.hpp
@@ -0,0 +1,93 @@
+/**
+ * Read the config file and save all the values in a map.
+ * Also, a singleton.
+ *
+ * Use Config::filename = "bla" to set the filename you want to use.
+ *
+ * If you want to exit if the file does not exist when it is open for
+ * reading, set Config::file_must_exist = true.
+ *
+ * Config::get() can then be used to access the values in the conf.
+ *
+ * Use Config::close() when you're done getting/setting value. This will
+ * save the config into the file.
+ */
+
+#pragma once
+
+#include <functional>
+#include <fstream>
+#include <memory>
+#include <vector>
+#include <string>
+#include <map>
+
+typedef std::function<void()> t_config_changed_callback;
+
+class Config
+{
+public:
+ Config() = default;
+ ~Config() = default;
+ Config(const Config&) = delete;
+ Config& operator=(const Config&) = delete;
+ Config(Config&&) = delete;
+ Config& operator=(Config&&) = delete;
+
+ /**
+ * returns a value from the config. If it doesn’t exist, use
+ * the second argument as the default.
+ */
+ static std::string get(const std::string&, const std::string&);
+ /**
+ * returns a value from the config. If it doesn’t exist, use
+ * the second argument as the default.
+ */
+ static int get_int(const std::string&, const int&);
+ /**
+ * Set a value for the given option. And write all the config
+ * in the file from which it was read if save is true.
+ */
+ static void set(const std::string&, const std::string&, bool save = false);
+ /**
+ * Adds a function to a list. This function will be called whenever a
+ * configuration change occurs (when set() is called, or when the initial
+ * conf is read)
+ */
+ static void connect(const t_config_changed_callback&);
+ /**
+ * Destroy the instance, forcing it to be recreated (with potentially
+ * different parameters) the next time it’s needed.
+ */
+ static void clear();
+ /**
+ * Read the configuration file at the given path.
+ */
+ static bool read_conf(const std::string& name="");
+ /**
+ * Get the filename
+ */
+ static const std::string& get_filename()
+ { return Config::filename; }
+
+private:
+ /**
+ * Set the value of the filename to use, before calling any method.
+ */
+ static std::string filename;
+ /**
+ * Write all the config values into the configuration file
+ */
+ static void save_to_file();
+ /**
+ * Call all the callbacks previously registered using connect().
+ * This is used to notify any class that a configuration change occured.
+ */
+ static void trigger_configuration_change();
+
+ static std::map<std::string, std::string> values;
+ static std::vector<t_config_changed_callback> callbacks;
+
+};
+
+
diff --git a/src/database/column.hpp b/src/database/column.hpp
new file mode 100644
index 0000000..111f9ca
--- /dev/null
+++ b/src/database/column.hpp
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <cstddef>
+
+template <typename T>
+struct Column
+{
+ Column(T default_value):
+ value{default_value} {}
+ Column():
+ value{} {}
+ using real_type = T;
+ T value{};
+};
+
+struct Id: Column<std::size_t> { static constexpr auto name = "id_";
+ static constexpr auto options = "PRIMARY KEY AUTOINCREMENT"; };
diff --git a/src/database/count_query.hpp b/src/database/count_query.hpp
new file mode 100644
index 0000000..b7bbf51
--- /dev/null
+++ b/src/database/count_query.hpp
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <database/query.hpp>
+#include <database/table.hpp>
+
+#include <string>
+
+#include <sqlite3.h>
+
+struct CountQuery: public Query
+{
+ CountQuery(std::string name):
+ Query("SELECT count(*) FROM ")
+ {
+ this->body += std::move(name);
+ }
+
+ int64_t execute(sqlite3* db)
+ {
+ auto statement = this->prepare(db);
+ int64_t res = 0;
+ if (sqlite3_step(statement.get()) == SQLITE_ROW)
+ res = sqlite3_column_int64(statement.get(), 0);
+ else
+ {
+ log_error("Count request didn’t return a result");
+ return 0;
+ }
+ if (sqlite3_step(statement.get()) != SQLITE_DONE)
+ log_warning("Count request returned more than one result.");
+
+ log_debug("Returning count: ", res);
+ return res;
+ }
+};
diff --git a/src/database/database.cpp b/src/database/database.cpp
index f7d309b..92f7682 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -2,171 +2,185 @@
#ifdef USE_DATABASE
#include <database/database.hpp>
-#include <logger/logger.hpp>
-#include <irc/iid.hpp>
#include <uuid/uuid.h>
#include <utils/get_first_non_empty.hpp>
#include <utils/time.hpp>
-using namespace std::string_literals;
+#include <sqlite3.h>
-std::unique_ptr<db::BibouDB> Database::db;
+sqlite3* Database::db;
+Database::MucLogLineTable Database::muc_log_lines("MucLogLine_");
+Database::GlobalOptionsTable Database::global_options("GlobalOptions_");
+Database::IrcServerOptionsTable Database::irc_server_options("IrcServerOptions_");
+Database::IrcChannelOptionsTable Database::irc_channel_options("IrcChannelOptions_");
-void Database::open(const std::string& filename, const std::string& db_type)
+void Database::open(const std::string& filename)
{
- try
+ // Try to open the specified database.
+ // Close and replace the previous database pointer if it succeeded. If it did
+ // not, just leave things untouched
+ sqlite3* new_db;
+ auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr);
+ if (res != SQLITE_OK)
{
- auto new_db = std::make_unique<db::BibouDB>(db_type,
- "database="s + filename);
- if (new_db->needsUpgrade())
- new_db->upgrade();
- Database::db.reset(new_db.release());
- } catch (const litesql::DatabaseError& e) {
- log_error("Failed to open database ", filename, ". ", e.what());
- throw;
+ log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(Database::db));
+ throw std::runtime_error("");
}
+ Database::close();
+ Database::db = new_db;
+ Database::muc_log_lines.create(Database::db);
+ Database::muc_log_lines.upgrade(Database::db);
+ Database::global_options.create(Database::db);
+ Database::global_options.upgrade(Database::db);
+ Database::irc_server_options.create(Database::db);
+ Database::irc_server_options.upgrade(Database::db);
+ Database::irc_channel_options.create(Database::db);
+ Database::irc_channel_options.upgrade(Database::db);
}
-void Database::set_verbose(const bool val)
-{
- Database::db->verbose = val;
-}
-db::GlobalOptions Database::get_global_options(const std::string& owner)
+Database::GlobalOptions Database::get_global_options(const std::string& owner)
{
- try {
- auto options = litesql::select<db::GlobalOptions>(*Database::db,
- db::GlobalOptions::Owner == owner).one();
- return options;
- } catch (const litesql::NotFound& e) {
- db::GlobalOptions options(*Database::db);
- options.owner = owner;
- return options;
- }
+ auto request = Database::global_options.select();
+ request.where() << Owner{} << "=" << owner;
+
+ Database::GlobalOptions options{Database::global_options.get_name()};
+ auto result = request.execute(Database::db);
+ if (result.size() == 1)
+ options = result.front();
+ else
+ options.col<Owner>() = owner;
+ return options;
}
-db::IrcServerOptions Database::get_irc_server_options(const std::string& owner,
- const std::string& server)
+Database::IrcServerOptions Database::get_irc_server_options(const std::string& owner, const std::string& server)
{
- try {
- auto options = litesql::select<db::IrcServerOptions>(*Database::db,
- db::IrcServerOptions::Owner == owner &&
- db::IrcServerOptions::Server == server).one();
- return options;
- } catch (const litesql::NotFound& e) {
- db::IrcServerOptions options(*Database::db);
- options.owner = owner;
- options.server = server;
- // options.update();
- return options;
- }
+ auto request = Database::irc_server_options.select();
+ request.where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server;
+
+ Database::IrcServerOptions options{Database::irc_server_options.get_name()};
+ auto result = request.execute(Database::db);
+ if (result.size() == 1)
+ options = result.front();
+ else
+ {
+ options.col<Owner>() = owner;
+ options.col<Server>() = server;
+ }
+ return options;
}
-db::IrcChannelOptions Database::get_irc_channel_options(const std::string& owner,
- const std::string& server,
- const std::string& channel)
+Database::IrcChannelOptions Database::get_irc_channel_options(const std::string& owner, const std::string& server, const std::string& channel)
{
- try {
- auto options = litesql::select<db::IrcChannelOptions>(*Database::db,
- db::IrcChannelOptions::Owner == owner &&
- db::IrcChannelOptions::Server == server &&
- db::IrcChannelOptions::Channel == channel).one();
- return options;
- } catch (const litesql::NotFound& e) {
- db::IrcChannelOptions options(*Database::db);
- options.owner = owner;
- options.server = server;
- options.channel = channel;
- return options;
- }
+ auto request = Database::irc_channel_options.select();
+ request.where() << Owner{} << "=" << owner <<\
+ " and " << Server{} << "=" << server <<\
+ " and " << Channel{} << "=" << channel;
+ Database::IrcChannelOptions options{Database::irc_channel_options.get_name()};
+ auto result = request.execute(Database::db);
+ if (result.size() == 1)
+ options = result.front();
+ else
+ {
+ options.col<Owner>() = owner;
+ options.col<Server>() = server;
+ options.col<Channel>() = channel;
+ }
+ return options;
}
-db::IrcChannelOptions Database::get_irc_channel_options_with_server_default(const std::string& owner,
- const std::string& server,
- const std::string& channel)
+Database::IrcChannelOptions Database::get_irc_channel_options_with_server_default(const std::string& owner, const std::string& server,
+ const std::string& channel)
{
auto coptions = Database::get_irc_channel_options(owner, server, channel);
auto soptions = Database::get_irc_server_options(owner, server);
- coptions.encodingIn = get_first_non_empty(coptions.encodingIn.value(),
- soptions.encodingIn.value());
- coptions.encodingOut = get_first_non_empty(coptions.encodingOut.value(),
- soptions.encodingOut.value());
+ coptions.col<EncodingIn>() = get_first_non_empty(coptions.col<EncodingIn>(),
+ soptions.col<EncodingIn>());
+ coptions.col<EncodingOut>() = get_first_non_empty(coptions.col<EncodingOut>(),
+ soptions.col<EncodingOut>());
- coptions.maxHistoryLength = get_first_non_empty(coptions.maxHistoryLength.value(),
- soptions.maxHistoryLength.value());
+ coptions.col<MaxHistoryLength>() = get_first_non_empty(coptions.col<MaxHistoryLength>(),
+ soptions.col<MaxHistoryLength>());
return coptions;
}
-db::IrcChannelOptions Database::get_irc_channel_options_with_server_and_global_default(const std::string& owner,
- const std::string& server,
- const std::string& channel)
+Database::IrcChannelOptions Database::get_irc_channel_options_with_server_and_global_default(const std::string& owner, const std::string& server, const std::string& channel)
{
auto coptions = Database::get_irc_channel_options(owner, server, channel);
auto soptions = Database::get_irc_server_options(owner, server);
auto goptions = Database::get_global_options(owner);
- coptions.encodingIn = get_first_non_empty(coptions.encodingIn.value(),
- soptions.encodingIn.value());
- coptions.encodingOut = get_first_non_empty(coptions.encodingOut.value(),
- soptions.encodingOut.value());
+ coptions.col<EncodingIn>() = get_first_non_empty(coptions.col<EncodingIn>(),
+ soptions.col<EncodingIn>());
- coptions.maxHistoryLength = get_first_non_empty(coptions.maxHistoryLength.value(),
- soptions.maxHistoryLength.value(),
- goptions.maxHistoryLength.value());
+ coptions.col<EncodingOut>() = get_first_non_empty(coptions.col<EncodingOut>(),
+ soptions.col<EncodingOut>());
+
+ coptions.col<MaxHistoryLength>() = get_first_non_empty(coptions.col<MaxHistoryLength>(),
+ soptions.col<MaxHistoryLength>(),
+ goptions.col<MaxHistoryLength>());
return coptions;
}
-void Database::store_muc_message(const std::string& owner, const Iid& iid,
- Database::time_point date,
- const std::string& body,
- const std::string& nick)
+std::string Database::store_muc_message(const std::string& owner, const std::string& chan_name,
+ const std::string& server_name, Database::time_point date,
+ const std::string& body, const std::string& nick)
{
- db::MucLogLine line(*Database::db);
+ auto line = Database::muc_log_lines.row();
+
+ auto uuid = Database::gen_uuid();
- line.uuid = Database::gen_uuid();
- line.owner = owner;
- line.ircChanName = iid.get_local();
- line.ircServerName = iid.get_server();
- line.date = date.time_since_epoch().count() / 1'000'000'000;
- line.body = body;
- line.nick = nick;
+ line.col<Uuid>() = uuid;
+ line.col<Owner>() = owner;
+ line.col<IrcChanName>() = chan_name;
+ line.col<IrcServerName>() = server_name;
+ line.col<Date>() = std::chrono::duration_cast<std::chrono::seconds>(date.time_since_epoch()).count();
+ line.col<Body>() = body;
+ line.col<Nick>() = nick;
- line.update();
+ line.save(Database::db);
+
+ return uuid;
}
-std::vector<db::MucLogLine> Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server,
+std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server,
int limit, const std::string& start, const std::string& end)
{
- auto request = litesql::select<db::MucLogLine>(*Database::db,
- db::MucLogLine::Owner == owner &&
- db::MucLogLine::IrcChanName == chan_name &&
- db::MucLogLine::IrcServerName == server);
- request.orderBy(db::MucLogLine::Id, false);
+ auto request = Database::muc_log_lines.select();
+ request.where() << Database::Owner{} << "=" << owner << \
+ " and " << Database::IrcChanName{} << "=" << chan_name << \
+ " and " << Database::IrcServerName{} << "=" << server;
- if (limit >= 0)
- request.limit(limit);
if (!start.empty())
{
const auto start_time = utils::parse_datetime(start);
if (start_time != -1)
- request.where(db::MucLogLine::Date >= start_time);
+ request << " and " << Database::Date{} << ">=" << start_time;
}
if (!end.empty())
{
const auto end_time = utils::parse_datetime(end);
if (end_time != -1)
- request.where(db::MucLogLine::Date <= end_time);
+ request << " and " << Database::Date{} << "<=" << end_time;
}
- const auto& res = request.all();
- return {res.crbegin(), res.crend()};
+
+ request.order_by() << Id{} << " DESC ";
+
+ if (limit >= 0)
+ request.limit() << limit;
+
+ auto result = request.execute(Database::db);
+
+ return {result.crbegin(), result.crend()};
}
void Database::close()
{
- Database::db.reset(nullptr);
+ sqlite3_close_v2(Database::db);
+ Database::db = nullptr;
}
std::string Database::gen_uuid()
@@ -178,5 +192,4 @@ std::string Database::gen_uuid()
return uuid_str;
}
-
-#endif
+#endif \ No newline at end of file
diff --git a/src/database/database.hpp b/src/database/database.hpp
index 6823574..8364abc 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -1,22 +1,112 @@
#pragma once
-
#include <biboumi.h>
#ifdef USE_DATABASE
-#include "biboudb.hpp"
+#include <database/table.hpp>
+#include <database/column.hpp>
+#include <database/count_query.hpp>
-#include <memory>
+#include <utils/optional_bool.hpp>
-#include <litesql.hpp>
#include <chrono>
+#include <string>
+
+#include <memory>
-class Iid;
class Database
{
-public:
+ public:
using time_point = std::chrono::system_clock::time_point;
+
+ struct Uuid: Column<std::string> { static constexpr auto name = "uuid_";
+ static constexpr auto options = ""; };
+
+ struct Owner: Column<std::string> { static constexpr auto name = "owner_";
+ static constexpr auto options = ""; };
+
+ struct IrcChanName: Column<std::string> { static constexpr auto name = "ircChanName_";
+ static constexpr auto options = ""; };
+
+ struct Channel: Column<std::string> { static constexpr auto name = "channel_";
+ static constexpr auto options = ""; };
+
+ struct IrcServerName: Column<std::string> { static constexpr auto name = "ircServerName_";
+ static constexpr auto options = ""; };
+
+ struct Server: Column<std::string> { static constexpr auto name = "server_";
+ static constexpr auto options = ""; };
+
+ struct Date: Column<time_point::rep> { static constexpr auto name = "date_";
+ static constexpr auto options = ""; };
+
+ struct Body: Column<std::string> { static constexpr auto name = "body_";
+ static constexpr auto options = ""; };
+
+ struct Nick: Column<std::string> { static constexpr auto name = "nick_";
+ static constexpr auto options = ""; };
+
+ struct Pass: Column<std::string> { static constexpr auto name = "pass_";
+ static constexpr auto options = ""; };
+
+ struct Ports: Column<std::string> { static constexpr auto name = "ports_";
+ static constexpr auto options = "";
+ Ports(): Column<std::string>("6667") {} };
+
+ struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsPorts_";
+ static constexpr auto options = "";
+ TlsPorts(): Column<std::string>("6697;6670") {} };
+
+ struct Username: Column<std::string> { static constexpr auto name = "username_";
+ static constexpr auto options = ""; };
+
+ struct Realname: Column<std::string> { static constexpr auto name = "realname_";
+ static constexpr auto options = ""; };
+
+ struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterConnectionCommand_";
+ static constexpr auto options = ""; };
+
+ struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedFingerprint_";
+ static constexpr auto options = ""; };
+
+ struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingOut_";
+ static constexpr auto options = ""; };
+
+ struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingIn_";
+ static constexpr auto options = ""; };
+
+ struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxHistoryLength_";
+ static constexpr auto options = "";
+ MaxHistoryLength(): Column<int>(20) {} };
+
+ struct RecordHistory: Column<bool> { static constexpr auto name = "recordHistory_";
+ static constexpr auto options = "";
+ RecordHistory(): Column<bool>(true) {}};
+
+ struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordHistory_";
+ static constexpr auto options = ""; };
+
+ struct VerifyCert: Column<bool> { static constexpr auto name = "verifyCert_";
+ static constexpr auto options = "";
+ VerifyCert(): Column<bool>(true) {} };
+
+ struct Persistent: Column<bool> { static constexpr auto name = "persistent_";
+ static constexpr auto options = "";
+ Persistent(): Column<bool>(false) {} };
+
+ using MucLogLineTable = Table<Id, Uuid, Owner, IrcChanName, IrcServerName, Date, Body, Nick>;
+ using MucLogLine = MucLogLineTable::RowType;
+
+ using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, Persistent>;
+ using GlobalOptions = GlobalOptionsTable::RowType;
+
+ using IrcServerOptionsTable = Table<Id, Owner, Server, Pass, AfterConnectionCommand, TlsPorts, Ports, Username, Realname, VerifyCert, TrustedFingerprint, EncodingOut, EncodingIn, MaxHistoryLength>;
+ using IrcServerOptions = IrcServerOptionsTable::RowType;
+
+ using IrcChannelOptionsTable = Table<Id, Owner, Server, Channel, EncodingOut, EncodingIn, MaxHistoryLength, Persistent, RecordHistoryOptional>;
+ using IrcChannelOptions = IrcChannelOptionsTable::RowType;
+
Database() = default;
~Database() = default;
@@ -25,42 +115,40 @@ public:
Database& operator=(const Database&) = delete;
Database& operator=(Database&&) = delete;
- static void set_verbose(const bool val);
-
- template<typename PersistentType>
- static size_t count()
- {
- return litesql::select<PersistentType>(*Database::db).count();
- }
- /**
- * Return the object from the db. Create it beforehand (with all default
- * values) if it is not already present.
- */
- static db::GlobalOptions get_global_options(const std::string& owner);
- static db::IrcServerOptions get_irc_server_options(const std::string& owner,
+ static GlobalOptions get_global_options(const std::string& owner);
+ static IrcServerOptions get_irc_server_options(const std::string& owner,
const std::string& server);
- static db::IrcChannelOptions get_irc_channel_options(const std::string& owner,
- const std::string& server,
- const std::string& channel);
- static db::IrcChannelOptions get_irc_channel_options_with_server_default(const std::string& owner,
- const std::string& server,
- const std::string& channel);
- static db::IrcChannelOptions get_irc_channel_options_with_server_and_global_default(const std::string& owner,
- const std::string& server,
- const std::string& channel);
- static std::vector<db::MucLogLine> get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server,
- int limit=-1, const std::string& before="", const std::string& after="");
- static void store_muc_message(const std::string& owner, const Iid& iid,
- time_point date, const std::string& body, const std::string& nick);
+ static IrcChannelOptions get_irc_channel_options(const std::string& owner,
+ const std::string& server,
+ const std::string& channel);
+ static IrcChannelOptions get_irc_channel_options_with_server_default(const std::string& owner,
+ const std::string& server,
+ const std::string& channel);
+ static IrcChannelOptions get_irc_channel_options_with_server_and_global_default(const std::string& owner,
+ const std::string& server,
+ const std::string& channel);
+ static std::vector<MucLogLine> get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server,
+ int limit=-1, const std::string& start="", const std::string& end="");
+ static std::string store_muc_message(const std::string& owner, const std::string& chan_name, const std::string& server_name,
+ time_point date, const std::string& body, const std::string& nick);
static void close();
- static void open(const std::string& filename, const std::string& db_type="sqlite3");
+ static void open(const std::string& filename);
+ template <typename TableType>
+ static int64_t count(const TableType& table)
+ {
+ CountQuery query{table.get_name()};
+ return query.execute(Database::db);
+ }
+
+ static MucLogLineTable muc_log_lines;
+ static GlobalOptionsTable global_options;
+ static IrcServerOptionsTable irc_server_options;
+ static IrcChannelOptionsTable irc_channel_options;
+ static sqlite3* db;
-private:
+ private:
static std::string gen_uuid();
- static std::unique_ptr<db::BibouDB> db;
};
#endif /* USE_DATABASE */
-
-
diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp
new file mode 100644
index 0000000..9e410ce
--- /dev/null
+++ b/src/database/insert_query.hpp
@@ -0,0 +1,129 @@
+#pragma once
+
+#include <database/statement.hpp>
+#include <database/column.hpp>
+#include <database/query.hpp>
+#include <logger/logger.hpp>
+
+#include <type_traits>
+#include <vector>
+#include <string>
+#include <tuple>
+
+#include <sqlite3.h>
+
+template <int N, typename ColumnType, typename... T>
+typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
+actual_bind(Statement& statement, std::vector<std::string>& params, const std::tuple<T...>&)
+{
+ const auto value = params.front();
+ params.erase(params.begin());
+ if (sqlite3_bind_text(statement.get(), N + 1, value.data(), static_cast<int>(value.size()), SQLITE_TRANSIENT) != SQLITE_OK)
+ log_error("Failed to bind ", value, " to param ", N);
+ else
+ log_debug("Bound (not id) [", value, "] to ", N);
+}
+
+template <int N, typename ColumnType, typename... T>
+typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
+actual_bind(Statement& statement, std::vector<std::string>&, const std::tuple<T...>& columns)
+{
+ auto&& column = std::get<Id>(columns);
+ if (column.value != 0)
+ {
+ if (sqlite3_bind_int64(statement.get(), N + 1, static_cast<sqlite3_int64>(column.value)) != SQLITE_OK)
+ log_error("Failed to bind ", column.value, " to id.");
+ }
+ else if (sqlite3_bind_null(statement.get(), N + 1) != SQLITE_OK)
+ log_error("Failed to bind NULL to param ", N);
+ else
+ log_debug("Bound NULL to ", N);
+}
+
+struct InsertQuery: public Query
+{
+ InsertQuery(const std::string& name):
+ Query("INSERT OR REPLACE INTO ")
+ {
+ this->body += name;
+ }
+
+ template <typename... T>
+ void execute(const std::tuple<T...>& columns, sqlite3* db)
+ {
+ auto statement = this->prepare(db);
+ {
+ this->bind_param(columns, statement);
+ if (sqlite3_step(statement.get()) != SQLITE_DONE)
+ log_error("Failed to execute query: ", sqlite3_errmsg(db));
+ }
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ bind_param(const std::tuple<T...>& columns, Statement& statement)
+ {
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
+
+ actual_bind<N, ColumnType>(statement, this->params, columns);
+ this->bind_param<N+1>(columns, statement);
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ bind_param(const std::tuple<T...>&, Statement&)
+ {}
+
+ template <typename... T>
+ void insert_values(const std::tuple<T...>& columns)
+ {
+ this->body += "VALUES (";
+ this->insert_value(columns);
+ this->body += ")";
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ insert_value(const std::tuple<T...>& columns)
+ {
+ this->body += "?";
+ if (N != sizeof...(T) - 1)
+ this->body += ",";
+ this->body += " ";
+ add_param(*this, std::get<N>(columns));
+ this->insert_value<N+1>(columns);
+ }
+ template <int N=0, typename... T>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ insert_value(const std::tuple<T...>&)
+ { }
+
+ template <typename... T>
+ void insert_col_names(const std::tuple<T...>& columns)
+ {
+ this->body += " (";
+ this->insert_col_name(columns);
+ this->body += ")\n";
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ insert_col_name(const std::tuple<T...>& columns)
+ {
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
+
+ this->body += ColumnType::name;
+
+ if (N < (sizeof...(T) - 1))
+ this->body += ", ";
+
+ this->insert_col_name<N+1>(columns);
+ }
+ template <int N=0, typename... T>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ insert_col_name(const std::tuple<T...>&)
+ {}
+
+
+ private:
+};
diff --git a/src/database/query.cpp b/src/database/query.cpp
new file mode 100644
index 0000000..ba63a92
--- /dev/null
+++ b/src/database/query.cpp
@@ -0,0 +1,34 @@
+#include <database/query.hpp>
+#include <database/column.hpp>
+
+template <>
+void add_param<Id>(Query&, const Id&)
+{}
+
+void actual_add_param(Query& query, const std::string& val)
+{
+ query.params.push_back(val);
+}
+
+void actual_add_param(Query& query, const OptionalBool& val)
+{
+ if (!val.is_set)
+ query.params.push_back("0");
+ else if (val.value)
+ query.params.push_back("1");
+ else
+ query.params.push_back("-1");
+}
+
+Query& operator<<(Query& query, const char* str)
+{
+ query.body += str;
+ return query;
+}
+
+Query& operator<<(Query& query, const std::string& str)
+{
+ query.body += "?";
+ actual_add_param(query, str);
+ return query;
+}
diff --git a/src/database/query.hpp b/src/database/query.hpp
new file mode 100644
index 0000000..f103fe9
--- /dev/null
+++ b/src/database/query.hpp
@@ -0,0 +1,90 @@
+#pragma once
+
+#include <utils/optional_bool.hpp>
+#include <database/statement.hpp>
+#include <database/column.hpp>
+
+#include <logger/logger.hpp>
+
+#include <vector>
+#include <string>
+
+#include <sqlite3.h>
+
+struct Query
+{
+ std::string body;
+ std::vector<std::string> params;
+
+ Query(std::string str):
+ body(std::move(str))
+ {}
+
+ Statement prepare(sqlite3* db)
+ {
+ sqlite3_stmt* stmt;
+ log_debug(this->body);
+ auto res = sqlite3_prepare(db, this->body.data(), static_cast<int>(this->body.size()) + 1,
+ &stmt, nullptr);
+ if (res != SQLITE_OK)
+ {
+ log_error("Error preparing statement: ", sqlite3_errmsg(db));
+ return nullptr;
+ }
+ Statement statement(stmt);
+ int i = 1;
+ for (const std::string& param: this->params)
+ {
+ if (sqlite3_bind_text(statement.get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK)
+ log_debug("Failed to bind ", param, " to param ", i);
+ else
+ log_debug("Bound ", param, " to ", i);
+ i++;
+ }
+
+ return statement;
+ }
+
+ void execute(sqlite3* db)
+ {
+ auto statement = this->prepare(db);
+ while (sqlite3_step(statement.get()) != SQLITE_DONE)
+ ;
+ }
+};
+
+template <typename ColumnType>
+void add_param(Query& query, const ColumnType& column)
+{
+ actual_add_param(query, column.value);
+}
+template <>
+void add_param<Id>(Query& query, const Id& column);
+
+template <typename T>
+void actual_add_param(Query& query, const T& val)
+{
+ query.params.push_back(std::to_string(val));
+}
+
+void actual_add_param(Query& query, const std::string& val);
+void actual_add_param(Query& query, const OptionalBool& val);
+
+template <typename T>
+typename std::enable_if<!std::is_integral<T>::value, Query&>::type
+operator<<(Query& query, const T&)
+{
+ query.body += T::name;
+ return query;
+}
+
+Query& operator<<(Query& query, const char* str);
+Query& operator<<(Query& query, const std::string& str);
+template <typename Integer>
+typename std::enable_if<std::is_integral<Integer>::value, Query&>::type
+operator<<(Query& query, const Integer& i)
+{
+ query.body += "?";
+ actual_add_param(query, i);
+ return query;
+}
diff --git a/src/database/row.hpp b/src/database/row.hpp
new file mode 100644
index 0000000..e7a58c4
--- /dev/null
+++ b/src/database/row.hpp
@@ -0,0 +1,75 @@
+#pragma once
+
+#include <database/insert_query.hpp>
+#include <logger/logger.hpp>
+
+#include <type_traits>
+
+#include <sqlite3.h>
+
+template <typename ColumnType, typename... T>
+typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
+update_id(std::tuple<T...>&, sqlite3*)
+{}
+
+template <typename ColumnType, typename... T>
+typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
+update_id(std::tuple<T...>& columns, sqlite3* db)
+{
+ auto&& column = std::get<ColumnType>(columns);
+ log_debug("Found an autoincrement col.");
+ auto res = sqlite3_last_insert_rowid(db);
+ log_debug("Value is now: ", res);
+ column.value = static_cast<Id::real_type>(res);
+}
+
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N < sizeof...(T), void>::type
+update_autoincrement_id(std::tuple<T...>& columns, sqlite3* db)
+{
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
+ update_id<ColumnType>(columns, db);
+ update_autoincrement_id<N+1>(columns, db);
+}
+
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N == sizeof...(T), void>::type
+update_autoincrement_id(std::tuple<T...>&, sqlite3*)
+{}
+
+template <typename... T>
+struct Row
+{
+ Row(std::string name):
+ table_name(std::move(name))
+ {}
+
+ template <typename Type>
+ auto& col()
+ {
+ auto&& col = std::get<Type>(this->columns);
+ return col.value;
+ }
+
+ template <typename Type>
+ const auto& col() const
+ {
+ auto&& col = std::get<Type>(this->columns);
+ return col.value;
+ }
+
+ void save(sqlite3* db)
+ {
+ InsertQuery query(this->table_name);
+ query.insert_col_names(this->columns);
+ query.insert_values(this->columns);
+ log_debug(query.body);
+
+ query.execute(this->columns, db);
+
+ update_autoincrement_id(this->columns, db);
+ }
+
+ std::tuple<T...> columns;
+ std::string table_name;
+};
diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp
new file mode 100644
index 0000000..f4d71af
--- /dev/null
+++ b/src/database/select_query.hpp
@@ -0,0 +1,127 @@
+#pragma once
+
+#include <database/statement.hpp>
+#include <database/query.hpp>
+#include <logger/logger.hpp>
+#include <database/row.hpp>
+
+#include <utils/optional_bool.hpp>
+
+#include <vector>
+#include <string>
+
+#include <sqlite3.h>
+
+using namespace std::string_literals;
+
+template <typename T>
+typename std::enable_if<std::is_integral<T>::value, sqlite3_int64>::type
+extract_row_value(Statement& statement, const int i)
+{
+ return sqlite3_column_int64(statement.get(), i);
+}
+
+template <typename T>
+typename std::enable_if<std::is_same<std::string, T>::value, T>::type
+extract_row_value(Statement& statement, const int i)
+{
+ const auto size = sqlite3_column_bytes(statement.get(), i);
+ const unsigned char* str = sqlite3_column_text(statement.get(), i);
+ std::string result(reinterpret_cast<const char*>(str), static_cast<std::size_t>(size));
+ return result;
+}
+
+template <typename T>
+typename std::enable_if<std::is_same<OptionalBool, T>::value, T>::type
+extract_row_value(Statement& statement, const int i)
+{
+ const auto integer = sqlite3_column_int(statement.get(), i);
+ OptionalBool result;
+ if (integer > 0)
+ result.set_value(true);
+ else if (integer < 0)
+ result.set_value(false);
+ return result;
+}
+
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N < sizeof...(T), void>::type
+extract_row_values(Row<T...>& row, Statement& statement)
+{
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(row.columns))>::type;
+
+ auto&& column = std::get<N>(row.columns);
+ column.value = static_cast<decltype(column.value)>(extract_row_value<typename ColumnType::real_type>(statement, N));
+
+ extract_row_values<N+1>(row, statement);
+}
+
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N == sizeof...(T), void>::type
+extract_row_values(Row<T...>&, Statement&)
+{}
+
+template <typename... T>
+struct SelectQuery: public Query
+{
+ SelectQuery(std::string table_name):
+ Query("SELECT"),
+ table_name(table_name)
+ {
+ this->insert_col_name();
+ this->body += " from " + this->table_name;
+ }
+
+ template <std::size_t N=0>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ insert_col_name()
+ {
+ using ColumnsType = std::tuple<T...>;
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnsType>()))>::type;
+
+ this->body += " "s + ColumnType::name;
+
+ if (N < (sizeof...(T) - 1))
+ this->body += ", ";
+
+ this->insert_col_name<N+1>();
+ }
+ template <std::size_t N=0>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ insert_col_name()
+ {}
+
+ SelectQuery& where()
+ {
+ this->body += " WHERE ";
+ return *this;
+ };
+
+ SelectQuery& order_by()
+ {
+ this->body += " ORDER BY ";
+ return *this;
+ }
+
+ SelectQuery& limit()
+ {
+ this->body += " LIMIT ";
+ return *this;
+ }
+
+ auto execute(sqlite3* db)
+ {
+ auto statement = this->prepare(db);
+ std::vector<Row<T...>> rows;
+ while (sqlite3_step(statement.get()) == SQLITE_ROW)
+ {
+ Row<T...> row(this->table_name);
+ extract_row_values(row, statement);
+ rows.push_back(row);
+ }
+ return rows;
+ }
+
+ const std::string table_name;
+};
+
diff --git a/src/database/statement.hpp b/src/database/statement.hpp
new file mode 100644
index 0000000..87cd70f
--- /dev/null
+++ b/src/database/statement.hpp
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <sqlite3.h>
+
+class Statement
+{
+ public:
+ Statement(sqlite3_stmt* stmt):
+ stmt(stmt) {}
+ ~Statement()
+ {
+ sqlite3_finalize(this->stmt);
+ }
+
+ Statement(const Statement&) = delete;
+ Statement& operator=(const Statement&) = delete;
+ Statement(Statement&& other):
+ stmt(other.stmt)
+ {
+ other.stmt = nullptr;
+ }
+ Statement& operator=(Statement&& other)
+ {
+ this->stmt = other.stmt;
+ other.stmt = nullptr;
+ return *this;
+ }
+ sqlite3_stmt* get()
+ {
+ return this->stmt;
+ }
+
+ private:
+ sqlite3_stmt* stmt;
+};
diff --git a/src/database/table.cpp b/src/database/table.cpp
new file mode 100644
index 0000000..5929f33
--- /dev/null
+++ b/src/database/table.cpp
@@ -0,0 +1,25 @@
+#include <database/table.hpp>
+
+std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name)
+{
+ std::set<std::string> result;
+ char* errmsg;
+ std::string query{"PRAGMA table_info("s + table_name + ")"};
+ log_debug(query);
+ int res = sqlite3_exec(db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int {
+ constexpr int name_column = 1;
+ std::set<std::string>* result = static_cast<std::set<std::string>*>(param);
+ log_debug("Table has column ", columns[name_column]);
+ if (name_column < columns_nb)
+ result->insert(columns[name_column]);
+ return 0;
+ }, &result, &errmsg);
+
+ if (res != SQLITE_OK)
+ {
+ log_error("Error executing ", query, ": ", errmsg);
+ sqlite3_free(errmsg);
+ }
+
+ return result;
+}
diff --git a/src/database/table.hpp b/src/database/table.hpp
new file mode 100644
index 0000000..411ac6a
--- /dev/null
+++ b/src/database/table.hpp
@@ -0,0 +1,127 @@
+#pragma once
+
+#include <database/select_query.hpp>
+#include <database/type_to_sql.hpp>
+#include <logger/logger.hpp>
+#include <database/row.hpp>
+
+#include <algorithm>
+#include <string>
+#include <set>
+
+using namespace std::string_literals;
+
+std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name);
+
+template <typename ColumnType>
+void add_column_to_table(sqlite3* db, const std::string& table_name)
+{
+ const std::string name = ColumnType::name;
+ std::string query{"ALTER TABLE "s + table_name + " ADD " + ColumnType::name + " " + TypeToSQLType<typename ColumnType::real_type>::type};
+ log_debug(query);
+ char* error;
+ const auto result = sqlite3_exec(db, query.data(), nullptr, nullptr, &error);
+ if (result != SQLITE_OK)
+ {
+ log_error("Error adding column ", name, " to table ", table_name, ": ", error);
+ sqlite3_free(error);
+ }
+}
+
+template <typename... T>
+class Table
+{
+ static_assert(sizeof...(T) > 0, "Table cannot be empty");
+ using ColumnTypes = std::tuple<T...>;
+
+ public:
+ using RowType = Row<T...>;
+
+ Table(std::string name):
+ name(std::move(name))
+ {}
+
+ void upgrade(sqlite3* db)
+ {
+ const auto existing_columns = get_all_columns_from_table(db, this->name);
+ add_column_if_not_exists(db, existing_columns);
+ }
+
+ void create(sqlite3* db)
+ {
+ std::string res{"CREATE TABLE IF NOT EXISTS "};
+ res += this->name;
+ res += " (\n";
+ this->add_column_create(res);
+ res += ")";
+
+ log_debug(res);
+
+ char* error;
+ const auto result = sqlite3_exec(db, res.data(), nullptr, nullptr, &error);
+ log_debug("result: ", +result);
+ if (result != SQLITE_OK)
+ {
+ log_error("Error executing query: ", error);
+ sqlite3_free(error);
+ }
+ }
+
+ RowType row()
+ {
+ return {this->name};
+ }
+
+ SelectQuery<T...> select()
+ {
+ SelectQuery<T...> select(this->name);
+ return select;
+ }
+
+ const std::string& get_name() const
+ {
+ return this->name;
+ }
+
+ private:
+
+ template <std::size_t N=0>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ add_column_if_not_exists(sqlite3* db, const std::set<std::string>& existing_columns)
+ {
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type;
+ if (existing_columns.count(ColumnType::name) != 1)
+ {
+ add_column_to_table<ColumnType>(db, this->name);
+ }
+ add_column_if_not_exists<N+1>(db, existing_columns);
+ }
+ template <std::size_t N=0>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ add_column_if_not_exists(sqlite3*, const std::set<std::string>&)
+ {}
+
+ template <std::size_t N=0>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ add_column_create(std::string& str)
+ {
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type;
+ using RealType = typename ColumnType::real_type;
+ str += ColumnType::name;
+ str += " ";
+ str += TypeToSQLType<RealType>::type;
+ str += " "s + ColumnType::options;
+ if (N != sizeof...(T) - 1)
+ str += ",";
+ str += "\n";
+
+ add_column_create<N+1>(str);
+ }
+
+ template <std::size_t N=0>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ add_column_create(std::string&)
+ { }
+
+ const std::string name;
+};
diff --git a/src/database/type_to_sql.cpp b/src/database/type_to_sql.cpp
new file mode 100644
index 0000000..bcd9daa
--- /dev/null
+++ b/src/database/type_to_sql.cpp
@@ -0,0 +1,9 @@
+#include <database/type_to_sql.hpp>
+
+template <> const std::string TypeToSQLType<int>::type = "INTEGER";
+template <> const std::string TypeToSQLType<std::size_t>::type = "INTEGER";
+template <> const std::string TypeToSQLType<long>::type = "INTEGER";
+template <> const std::string TypeToSQLType<long long>::type = "INTEGER";
+template <> const std::string TypeToSQLType<bool>::type = "INTEGER";
+template <> const std::string TypeToSQLType<std::string>::type = "TEXT";
+template <> const std::string TypeToSQLType<OptionalBool>::type = "INTEGER"; \ No newline at end of file
diff --git a/src/database/type_to_sql.hpp b/src/database/type_to_sql.hpp
new file mode 100644
index 0000000..ba806ab
--- /dev/null
+++ b/src/database/type_to_sql.hpp
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <utils/optional_bool.hpp>
+
+#include <string>
+
+template <typename T>
+struct TypeToSQLType { static const std::string type; };
+
+template <> const std::string TypeToSQLType<int>::type;
+template <> const std::string TypeToSQLType<std::size_t>::type;
+template <> const std::string TypeToSQLType<long>::type;
+template <> const std::string TypeToSQLType<long long>::type;
+template <> const std::string TypeToSQLType<bool>::type;
+template <> const std::string TypeToSQLType<std::string>::type;
+template <> const std::string TypeToSQLType<OptionalBool>::type; \ No newline at end of file
diff --git a/src/identd/identd_server.hpp b/src/identd/identd_server.hpp
new file mode 100644
index 0000000..b1c8ec8
--- /dev/null
+++ b/src/identd/identd_server.hpp
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <network/tcp_server_socket.hpp>
+#include <identd/identd_socket.hpp>
+#include <algorithm>
+#include <unistd.h>
+
+class BiboumiComponent;
+
+class IdentdServer: public TcpSocketServer<IdentdSocket>
+{
+ public:
+ IdentdServer(const BiboumiComponent& biboumi_component, std::shared_ptr<Poller>& poller, const uint16_t port):
+ TcpSocketServer<IdentdSocket>(poller, port),
+ biboumi_component(biboumi_component)
+ {}
+
+ const BiboumiComponent& get_biboumi_component() const
+ {
+ return this->biboumi_component;
+ }
+ void shutdown()
+ {
+ if (this->poller->is_managing_socket(this->socket))
+ this->poller->remove_socket_handler(this->socket);
+ ::close(this->socket);
+ }
+ void clean()
+ {
+ this->sockets.erase(std::remove_if(this->sockets.begin(), this->sockets.end(),
+ [](const std::unique_ptr<IdentdSocket>& socket)
+ {
+ return socket->get_socket() == -1;
+ }),
+ this->sockets.end());
+ }
+ private:
+ const BiboumiComponent& biboumi_component;
+};
diff --git a/src/identd/identd_socket.cpp b/src/identd/identd_socket.cpp
new file mode 100644
index 0000000..b85257c
--- /dev/null
+++ b/src/identd/identd_socket.cpp
@@ -0,0 +1,63 @@
+#include <identd/identd_socket.hpp>
+#include <identd/identd_server.hpp>
+#include <xmpp/biboumi_component.hpp>
+#include <sstream>
+#include <iomanip>
+
+#include <utils/sha1.hpp>
+
+#include <logger/logger.hpp>
+
+IdentdSocket::IdentdSocket(std::shared_ptr<Poller>& poller, const socket_t socket, TcpSocketServer<IdentdSocket>& server):
+ TCPSocketHandler(poller),
+ server(dynamic_cast<IdentdServer&>(server))
+{
+ this->socket = socket;
+}
+
+void IdentdSocket::parse_in_buffer(const std::size_t)
+{
+ while (true)
+ {
+ const auto line_end = this->in_buf.find('\n');
+ if (line_end == std::string::npos)
+ break;
+ std::istringstream line(this->in_buf.substr(0, line_end));
+ this->consume_in_buffer(line_end + 1);
+
+ uint16_t local_port;
+ uint16_t remote_port;
+ char sep;
+ line >> local_port >> sep >> remote_port;
+ const auto& xmpp = this->server.get_biboumi_component();
+ auto response = this->generate_answer(xmpp, local_port, remote_port);
+
+ this->send_data(std::move(response));
+ }
+}
+
+static std::string hash_jid(const std::string& jid)
+{
+ return sha1(jid);
+}
+
+std::string IdentdSocket::generate_answer(const BiboumiComponent& biboumi, uint16_t local, uint16_t remote)
+{
+ for (const Bridge* bridge: biboumi.get_bridges())
+ {
+ for (const auto& pair: bridge->get_irc_clients())
+ {
+ if (pair.second->match_port_pairt(local, remote))
+ {
+ std::ostringstream os;
+ os << local << " , " << remote << " : USERID : OTHER : " << hash_jid(bridge->get_bare_jid());
+ log_debug("Identd, sending: ", os.str());
+ return os.str();
+ }
+ }
+ }
+ std::ostringstream os;
+ os << local << " , " << remote << " ERROR : NO-USER";
+ log_debug("Identd, sending: ", os.str());
+ return os.str();
+}
diff --git a/src/identd/identd_socket.hpp b/src/identd/identd_socket.hpp
new file mode 100644
index 0000000..a386d80
--- /dev/null
+++ b/src/identd/identd_socket.hpp
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <network/socket_handler.hpp>
+
+#include <network/tcp_socket_handler.hpp>
+
+#include <logger/logger.hpp>
+#include <xmpp/biboumi_component.hpp>
+
+class XmppComponent;
+class IdentdSocket;
+class IdentdServer;
+template <typename T>
+class TcpSocketServer;
+
+class IdentdSocket: public TCPSocketHandler
+{
+ public:
+ IdentdSocket(std::shared_ptr<Poller>& poller, const socket_t socket, TcpSocketServer<IdentdSocket>& server);
+ ~IdentdSocket() = default;
+ std::string generate_answer(const BiboumiComponent& biboumi, uint16_t local, uint16_t remote);
+
+ void parse_in_buffer(const std::size_t size) override final;
+
+ bool is_connected() const override final
+ {
+ return true;
+ }
+ bool is_connecting() const override final
+ {
+ return false;
+ }
+
+ private:
+ IdentdServer& server;
+};
diff --git a/src/irc/iid.cpp b/src/irc/iid.cpp
index d442013..a63a1c3 100644
--- a/src/irc/iid.cpp
+++ b/src/irc/iid.cpp
@@ -1,3 +1,4 @@
+#include <utility>
#include <utils/tolower.hpp>
#include <config/config.hpp>
#include <bridge/bridge.hpp>
@@ -7,10 +8,10 @@
constexpr char Iid::separator[];
-Iid::Iid(const std::string& local, const std::string& server, Iid::Type type):
+Iid::Iid(std::string local, std::string server, Iid::Type type):
type(type),
- local(local),
- server(server)
+ local(std::move(local)),
+ server(std::move(server))
{
}
@@ -34,9 +35,10 @@ Iid::Iid(const std::string& iid, const Bridge *bridge)
void Iid::set_type(const std::set<char>& chantypes)
{
+ if (this->local.empty() && this->server.empty())
+ this->type = Iid::Type::None;
if (this->local.empty())
return;
-
if (chantypes.count(this->local[0]) == 1)
this->type = Iid::Type::Channel;
else
@@ -105,6 +107,8 @@ namespace std {
{
if (iid.type == Iid::Type::Server)
return iid.get_server();
+ else if (iid.get_local().empty() && iid.get_server().empty())
+ return {};
else
return iid.get_encoded_local() + iid.separator + iid.get_server();
}
diff --git a/src/irc/iid.hpp b/src/irc/iid.hpp
index 44861c1..89f4797 100644
--- a/src/irc/iid.hpp
+++ b/src/irc/iid.hpp
@@ -53,12 +53,13 @@ public:
Channel,
User,
Server,
+ None,
};
static constexpr char separator[]{"%"};
Iid(const std::string& iid, const std::set<char>& chantypes);
Iid(const std::string& iid, const std::initializer_list<char>& chantypes);
Iid(const std::string& iid, const Bridge* bridge);
- Iid(const std::string& local, const std::string& server, Type type);
+ Iid(std::string local, std::string server, Type type);
Iid() = default;
Iid(const Iid&) = default;
diff --git a/src/irc/irc_channel.cpp b/src/irc/irc_channel.cpp
index 40d7f54..53043c7 100644
--- a/src/irc/irc_channel.cpp
+++ b/src/irc/irc_channel.cpp
@@ -1,21 +1,25 @@
#include <irc/irc_channel.hpp>
#include <algorithm>
-void IrcChannel::set_self(const std::string& name)
+void IrcChannel::set_self(IrcUser* user)
{
- this->self = std::make_unique<IrcUser>(name);
+ this->self = user;
}
IrcUser* IrcChannel::add_user(const std::string& name,
const std::map<char, char>& prefix_to_mode)
{
- this->users.emplace_back(std::make_unique<IrcUser>(name, prefix_to_mode));
+ auto new_user = std::make_unique<IrcUser>(name, prefix_to_mode);
+ auto old_user = this->find_user(new_user->nick);
+ if (old_user)
+ return old_user;
+ this->users.emplace_back(std::move(new_user));
return this->users.back().get();
}
IrcUser* IrcChannel::get_self() const
{
- return this->self.get();
+ return this->self;
}
IrcUser* IrcChannel::find_user(const std::string& name) const
@@ -32,19 +36,27 @@ IrcUser* IrcChannel::find_user(const std::string& name) const
void IrcChannel::remove_user(const IrcUser* user)
{
const auto nick = user->nick;
+ const bool is_self = (user == this->self);
const auto it = std::find_if(this->users.begin(), this->users.end(),
[nick](const std::unique_ptr<IrcUser>& u)
{
return nick == u->nick;
});
if (it != this->users.end())
- this->users.erase(it);
+ {
+ this->users.erase(it);
+ if (is_self)
+ {
+ this->self = nullptr;
+ this->joined = false;
+ }
+ }
}
void IrcChannel::remove_all_users()
{
this->users.clear();
- this->self.reset();
+ this->self = nullptr;
}
DummyIrcChannel::DummyIrcChannel():
diff --git a/src/irc/irc_channel.hpp b/src/irc/irc_channel.hpp
index 7c269b9..8f85edb 100644
--- a/src/irc/irc_channel.hpp
+++ b/src/irc/irc_channel.hpp
@@ -27,7 +27,7 @@ public:
bool parting{false};
std::string topic{};
std::string topic_author{};
- void set_self(const std::string& name);
+ void set_self(IrcUser* user);
IrcUser* get_self() const;
IrcUser* add_user(const std::string& name,
const std::map<char, char>& prefix_to_mode);
@@ -38,7 +38,8 @@ public:
{ return this->users; }
protected:
- std::unique_ptr<IrcUser> self{};
+ // Pointer to one IrcUser stored in users
+ IrcUser* self{nullptr};
std::vector<std::unique_ptr<IrcUser>> users{};
};
diff --git a/src/irc/irc_client.cpp b/src/irc/irc_client.cpp
index de6b089..bacb89e 100644
--- a/src/irc/irc_client.cpp
+++ b/src/irc/irc_client.cpp
@@ -1,3 +1,4 @@
+#include <utility>
#include <utils/timed_events.hpp>
#include <database/database.hpp>
#include <irc/irc_message.hpp>
@@ -14,13 +15,13 @@
#include <sstream>
#include <iostream>
#include <stdexcept>
+#include <algorithm>
#include <cstring>
#include <chrono>
#include <string>
#include "biboumi.h"
-#include "louloulibs.h"
using namespace std::string_literals;
using namespace std::chrono_literals;
@@ -61,11 +62,14 @@ static const std::unordered_map<std::string,
{"333", {&IrcClient::on_topic_who_time_received, {4, 0}}},
{"RPL_TOPICWHOTIME", {&IrcClient::on_topic_who_time_received, {4, 0}}},
{"366", {&IrcClient::on_channel_completely_joined, {2, 0}}},
+ {"367", {&IrcClient::on_banlist, {3, 0}}},
+ {"368", {&IrcClient::on_banlist_end, {3, 0}}},
{"396", {&IrcClient::on_own_host_received, {2, 0}}},
{"432", {&IrcClient::on_erroneous_nickname, {2, 0}}},
{"433", {&IrcClient::on_nickname_conflict, {2, 0}}},
{"438", {&IrcClient::on_nickname_change_too_fast, {2, 0}}},
{"443", {&IrcClient::on_useronchannel, {3, 0}}},
+ {"475", {&IrcClient::on_channel_bad_key, {3, 0}}},
{"ERR_USERONCHANNEL", {&IrcClient::on_useronchannel, {3, 0}}},
{"001", {&IrcClient::on_welcome_message, {1, 0}}},
{"PART", {&IrcClient::on_part, {1, 0}}},
@@ -113,7 +117,6 @@ static const std::unordered_map<std::string,
{"472", {&IrcClient::on_generic_error, {2, 0}}},
{"473", {&IrcClient::on_generic_error, {2, 0}}},
{"474", {&IrcClient::on_generic_error, {2, 0}}},
- {"475", {&IrcClient::on_generic_error, {2, 0}}},
{"476", {&IrcClient::on_generic_error, {2, 0}}},
{"477", {&IrcClient::on_generic_error, {2, 0}}},
{"481", {&IrcClient::on_generic_error, {2, 0}}},
@@ -127,16 +130,16 @@ static const std::unordered_map<std::string,
{"502", {&IrcClient::on_generic_error, {2, 0}}},
};
-IrcClient::IrcClient(std::shared_ptr<Poller> poller, const std::string& hostname,
- const std::string& nickname, const std::string& username,
- const std::string& realname, const std::string& user_hostname,
+IrcClient::IrcClient(std::shared_ptr<Poller>& poller, std::string hostname,
+ std::string nickname, std::string username,
+ std::string realname, std::string user_hostname,
Bridge& bridge):
- TCPSocketHandler(poller),
- hostname(hostname),
- user_hostname(user_hostname),
- username(username),
- realname(realname),
- current_nick(nickname),
+ TCPClientSocketHandler(poller),
+ hostname(std::move(hostname)),
+ user_hostname(std::move(user_hostname)),
+ username(std::move(username)),
+ realname(std::move(realname)),
+ current_nick(std::move(nickname)),
bridge(bridge),
welcomed(false),
chanmodes({"", "", "", ""}),
@@ -153,11 +156,11 @@ IrcClient::IrcClient(std::shared_ptr<Poller> poller, const std::string& hostname
#ifdef USE_DATABASE
auto options = Database::get_irc_server_options(this->bridge.get_bare_jid(),
this->get_hostname());
- std::vector<std::string> ports = utils::split(options.ports, ';', false);
+ std::vector<std::string> ports = utils::split(options.col<Database::Ports>(), ';', false);
for (auto it = ports.rbegin(); it != ports.rend(); ++it)
this->ports_to_try.emplace(*it, false);
# ifdef BOTAN_FOUND
- ports = utils::split(options.tlsPorts, ';', false);
+ ports = utils::split(options.col<Database::TlsPorts>(), ';', false);
for (auto it = ports.rbegin(); it != ports.rend(); ++it)
this->ports_to_try.emplace(*it, true);
# endif // BOTAN_FOUND
@@ -201,7 +204,7 @@ void IrcClient::start()
# ifdef USE_DATABASE
auto options = Database::get_irc_server_options(this->bridge.get_bare_jid(),
this->get_hostname());
- this->credential_manager.set_trusted_fingerprint(options.trustedFingerprint);
+ this->credential_manager.set_trusted_fingerprint(options.col<Database::TrustedFingerprint>());
# endif
#endif
this->connect(this->hostname, port, tls);
@@ -272,8 +275,8 @@ void IrcClient::on_connected()
#ifdef USE_DATABASE
auto options = Database::get_irc_server_options(this->bridge.get_bare_jid(),
this->get_hostname());
- if (!options.pass.value().empty())
- this->send_pass_command(options.pass.value());
+ if (!options.col<Database::Pass>().empty())
+ this->send_pass_command(options.col<Database::Pass>());
#endif
this->send_nick_command(this->current_nick);
@@ -281,10 +284,10 @@ void IrcClient::on_connected()
#ifdef USE_DATABASE
if (Config::get("realname_customization", "true") == "true")
{
- if (!options.username.value().empty())
- this->username = options.username.value();
- if (!options.realname.value().empty())
- this->realname = options.realname.value();
+ if (!options.col<Database::Username>().empty())
+ this->username = options.col<Database::Username>();
+ if (!options.col<Database::Realname>().empty())
+ this->realname = options.col<Database::Realname>();
this->send_user_command(username, realname);
}
else
@@ -343,7 +346,7 @@ void IrcClient::parse_in_buffer(const size_t)
if (pos == std::string::npos)
break ;
IrcMessage message(this->in_buf.substr(0, pos));
- this->in_buf = this->in_buf.substr(pos + 2, std::string::npos);
+ this->consume_in_buffer(pos + 2);
log_debug("IRC RECEIVING: (", this->get_hostname(), ") ", message);
// Call the standard callback (if any), associated with the command
@@ -386,10 +389,10 @@ void IrcClient::send_message(IrcMessage&& message)
std::string res;
if (!message.prefix.empty())
res += ":" + std::move(message.prefix) + " ";
- res += std::move(message.command);
+ res += message.command;
for (const std::string& arg: message.arguments)
{
- if (arg.find(" ") != std::string::npos ||
+ if (arg.find(' ') != std::string::npos ||
(!arg.empty() && arg[0] == ':'))
{
res += " :" + arg;
@@ -455,7 +458,12 @@ void IrcClient::send_quit_command(const std::string& reason)
void IrcClient::send_join_command(const std::string& chan_name, const std::string& password)
{
if (this->welcomed == false)
- this->channels_to_join.emplace_back(chan_name, password);
+ {
+ const auto it = std::find_if(begin(this->channels_to_join), end(this->channels_to_join),
+ [&chan_name](const auto& pair) { return std::get<0>(pair) == chan_name; });
+ if (it == end(this->channels_to_join))
+ this->channels_to_join.emplace_back(chan_name, password);
+ }
else if (password.empty())
this->send_message(IrcMessage("JOIN", {chan_name}));
else
@@ -501,15 +509,7 @@ void IrcClient::send_private_message(const std::string& username, const std::str
void IrcClient::send_part_command(const std::string& chan_name, const std::string& status_message)
{
- IrcChannel* channel = this->get_channel(chan_name);
- if (channel->joined == true)
- {
- if (chan_name.empty())
- this->leave_dummy_channel(status_message);
- else
- this->send_message(IrcMessage("PART", {chan_name, status_message}));
- channel->parting = true;
- }
+ this->send_message(IrcMessage("PART", {chan_name, status_message}));
}
void IrcClient::send_mode_command(const std::string& chan_name, const std::vector<std::string>& arguments)
@@ -546,9 +546,18 @@ void IrcClient::forward_server_message(const IrcMessage& message)
void IrcClient::on_notice(const IrcMessage& message)
{
std::string from = message.prefix;
- const std::string to = message.arguments[0];
+ std::string to = message.arguments[0];
const std::string body = message.arguments[1];
+ // Handle notices starting with [#channame] as if they were sent to that channel
+ if (body.size() > 3 && body[0] == '[')
+ {
+ const auto chan_prefix = body[1];
+ auto end = body.find(']');
+ if (this->chantypes.find(chan_prefix) != this->chantypes.end() && end != std::string::npos)
+ to = body.substr(1, end - 1);
+ }
+
if (!body.empty() && body[0] == '\01' && body[body.size() - 1] == '\01')
// Do not forward the notice to the user if it's a CTCP command
return ;
@@ -635,15 +644,18 @@ void IrcClient::set_and_forward_user_list(const IrcMessage& message)
std::vector<std::string> nicks = utils::split(message.arguments[3], ' ');
for (const std::string& nick: nicks)
{
- const IrcUser* user = channel->add_user(nick, this->prefix_to_mode);
- if (user->nick != channel->get_self()->nick)
+ // Just create this dummy user to parse and get its modes
+ IrcUser tmp_user{nick, this->prefix_to_mode};
+ // Does this concern ourself
+ if (channel->get_self() && channel->find_user(tmp_user.nick) == channel->get_self())
{
- this->bridge.send_user_join(this->hostname, chan_name, user, user->get_most_significant_mode(this->sorted_user_modes), false);
+ // We now know our own modes, that’s all.
+ channel->get_self()->modes = tmp_user.modes;
}
else
- {
- // we now know the modes of self, so copy the modes into self
- channel->get_self()->modes = user->modes;
+ { // Otherwise this is a new user
+ const IrcUser *user = channel->add_user(nick, this->prefix_to_mode);
+ this->bridge.send_user_join(this->hostname, chan_name, user, user->get_most_significant_mode(this->sorted_user_modes), false);
}
}
}
@@ -657,13 +669,11 @@ void IrcClient::on_channel_join(const IrcMessage& message)
else
channel = this->get_channel(chan_name);
const std::string nick = message.prefix;
+ IrcUser* user = channel->add_user(nick, this->prefix_to_mode);
if (channel->joined == false)
- channel->set_self(nick);
+ channel->set_self(user);
else
- {
- const IrcUser* user = channel->add_user(nick, this->prefix_to_mode);
- this->bridge.send_user_join(this->hostname, chan_name, user, user->get_most_significant_mode(this->sorted_user_modes), false);
- }
+ this->bridge.send_user_join(this->hostname, chan_name, user, user->get_most_significant_mode(this->sorted_user_modes), false);
}
void IrcClient::on_channel_message(const IrcMessage& message)
@@ -776,6 +786,43 @@ void IrcClient::on_channel_completely_joined(const IrcMessage& message)
this->bridge.send_topic(this->hostname, chan_name, channel->topic, channel->topic_author);
}
+void IrcClient::on_banlist(const IrcMessage& message)
+{
+ const std::string chan_name = utils::tolower(message.arguments[1]);
+ IrcChannel* channel = this->get_channel(chan_name);
+ if (channel->joined)
+ {
+ Iid iid;
+ iid.set_local(chan_name);
+ iid.set_server(this->hostname);
+ iid.type = Iid::Type::Channel;
+ std::string body{message.arguments[2] + " banned"};
+ if (message.arguments.size() >= 4)
+ {
+ IrcUser by(message.arguments[3], this->prefix_to_mode);
+ body += " by " + by.nick;
+ }
+ if (message.arguments.size() >= 5)
+ body += " on " + message.arguments[4];
+
+ this->bridge.send_message(iid, "", body, true);
+ }
+}
+
+void IrcClient::on_banlist_end(const IrcMessage& message)
+{
+ const std::string chan_name = utils::tolower(message.arguments[1]);
+ IrcChannel* channel = this->get_channel(chan_name);
+ if (channel->joined)
+ {
+ Iid iid;
+ iid.set_local(chan_name);
+ iid.set_server(this->hostname);
+ iid.type = Iid::Type::Channel;
+ this->bridge.send_message(iid, "", message.arguments[2], true);
+ }
+}
+
void IrcClient::on_own_host_received(const IrcMessage& message)
{
this->own_host = message.arguments[1];
@@ -799,10 +846,10 @@ void IrcClient::on_nickname_conflict(const IrcMessage& message)
{
const std::string nickname = message.arguments[1];
this->on_generic_error(message);
- for (auto it = this->channels.begin(); it != this->channels.end(); ++it)
+ for (const auto& pair: this->channels)
{
Iid iid;
- iid.set_local(it->first);
+ iid.set_local(pair.first);
iid.set_server(this->hostname);
iid.type = Iid::Type::Channel;
this->bridge.send_nickname_conflict_error(iid, nickname);
@@ -816,10 +863,10 @@ void IrcClient::on_nickname_change_too_fast(const IrcMessage& message)
if (message.arguments.size() >= 3)
txt = message.arguments[2];
this->on_generic_error(message);
- for (auto it = this->channels.begin(); it != this->channels.end(); ++it)
+ for (const auto& pair: this->channels)
{
Iid iid;
- iid.set_local(it->first);
+ iid.set_local(pair.first);
iid.set_server(this->hostname);
iid.type = Iid::Type::Channel;
this->bridge.send_presence_error(iid, nickname,
@@ -847,14 +894,53 @@ void IrcClient::on_welcome_message(const IrcMessage& message)
#ifdef USE_DATABASE
auto options = Database::get_irc_server_options(this->bridge.get_bare_jid(),
this->get_hostname());
- if (!options.afterConnectionCommand.value().empty())
- this->send_raw(options.afterConnectionCommand.value());
+ if (!options.col<Database::AfterConnectionCommand>().empty())
+ this->send_raw(options.col<Database::AfterConnectionCommand>());
#endif
// Install a repeated events to regularly send a PING
TimedEventsManager::instance().add_event(TimedEvent(240s, std::bind(&IrcClient::send_ping_command, this),
"PING"s + this->hostname + this->bridge.get_jid()));
+ std::string channels{};
+ std::string channels_with_key{};
+ std::string keys{};
+
for (const auto& tuple: this->channels_to_join)
- this->send_join_command(std::get<0>(tuple), std::get<1>(tuple));
+ {
+ const auto& chan = std::get<0>(tuple);
+ const auto& key = std::get<1>(tuple);
+ if (chan.empty())
+ continue;
+ if (!key.empty())
+ {
+ if (keys.size() + channels_with_key.size() >= 300)
+ { // Arbitrary size, to make sure we never send more than 512
+ this->send_join_command(channels_with_key, keys);
+ channels_with_key.clear();
+ keys.clear();
+ }
+ if (!keys.empty())
+ keys += ",";
+ keys += key;
+ if (!channels_with_key.empty())
+ channels_with_key += ",";
+ channels_with_key += chan;
+ }
+ else
+ {
+ if (channels.size() >= 300)
+ { // Arbitrary size, to make sure we never send more than 512
+ this->send_join_command(channels, {});
+ channels.clear();
+ }
+ if (!channels.empty())
+ channels += ",";
+ channels += chan;
+ }
+ }
+ if (!channels.empty())
+ this->send_join_command(channels, {});
+ if (!channels_with_key.empty())
+ this->send_join_command(channels_with_key, keys);
this->channels_to_join.clear();
// Indicate that the dummy channel is joined as well, if needed
if (this->dummy_channel.joining)
@@ -883,20 +969,19 @@ void IrcClient::on_part(const IrcMessage& message)
if (user)
{
std::string nick = user->nick;
+ bool self = channel->get_self() && channel->get_self()->nick == nick;
channel->remove_user(user);
Iid iid;
iid.set_local(chan_name);
iid.set_server(this->hostname);
iid.type = Iid::Type::Channel;
- bool self = channel->get_self()->nick == nick;
if (self)
{
- channel->joined = false;
this->channels.erase(utils::tolower(chan_name));
// channel pointer is now invalid
channel = nullptr;
}
- this->bridge.send_muc_leave(std::move(iid), std::move(nick), std::move(txt), self);
+ this->bridge.send_muc_leave(iid, std::move(nick), txt, self);
}
}
@@ -904,17 +989,17 @@ void IrcClient::on_error(const IrcMessage& message)
{
const std::string leave_message = message.arguments[0];
// The user is out of all the channels
- for (auto it = this->channels.begin(); it != this->channels.end(); ++it)
+ for (const auto& pair: this->channels)
{
Iid iid;
- iid.set_local(it->first);
+ iid.set_local(pair.first);
iid.set_server(this->hostname);
iid.type = Iid::Type::Channel;
- IrcChannel* channel = it->second.get();
+ IrcChannel* channel = pair.second.get();
if (!channel->joined)
continue;
std::string own_nick = channel->get_self()->nick;
- this->bridge.send_muc_leave(std::move(iid), std::move(own_nick), leave_message, true);
+ this->bridge.send_muc_leave(iid, std::move(own_nick), leave_message, true);
}
this->channels.clear();
this->send_gateway_message("ERROR: "s + leave_message);
@@ -925,10 +1010,10 @@ void IrcClient::on_quit(const IrcMessage& message)
std::string txt;
if (message.arguments.size() >= 1)
txt = message.arguments[0];
- for (auto it = this->channels.begin(); it != this->channels.end(); ++it)
+ for (const auto& pair: this->channels)
{
- const std::string chan_name = it->first;
- IrcChannel* channel = it->second.get();
+ const std::string& chan_name = pair.first;
+ IrcChannel* channel = pair.second.get();
const IrcUser* user = channel->find_user(message.prefix);
if (user)
{
@@ -938,7 +1023,7 @@ void IrcClient::on_quit(const IrcMessage& message)
iid.set_local(chan_name);
iid.set_server(this->hostname);
iid.type = Iid::Type::Channel;
- this->bridge.send_muc_leave(std::move(iid), std::move(nick), txt, false);
+ this->bridge.send_muc_leave(iid, std::move(nick), txt, false);
}
}
}
@@ -974,9 +1059,9 @@ void IrcClient::on_nick(const IrcMessage& message)
{
change_nick_func("", &this->get_dummy_channel());
}
- for (auto it = this->channels.begin(); it != this->channels.end(); ++it)
+ for (const auto& pair: this->channels)
{
- change_nick_func(it->first, it->second.get());
+ change_nick_func(pair.first, pair.second.get());
}
}
@@ -1019,6 +1104,18 @@ void IrcClient::on_mode(const IrcMessage& message)
this->on_user_mode(message);
}
+void IrcClient::on_channel_bad_key(const IrcMessage& message)
+{
+ this->on_generic_error(message);
+ const std::string& nickname = message.arguments[0];
+ const std::string& channel = message.arguments[1];
+ std::string text;
+ if (message.arguments.size() > 2)
+ text = message.arguments[2];
+
+ this->bridge.send_presence_error({channel, this->hostname, Iid::Type::Channel}, nickname, "auth", "not-authorized", "", text);
+}
+
void IrcClient::on_channel_mode(const IrcMessage& message)
{
// For now, just transmit the modes so the user can know what happens
@@ -1075,7 +1172,7 @@ void IrcClient::on_channel_mode(const IrcMessage& message)
{
// That mode can also be of type B if it is present in the
// prefix_to_mode map
- for (const std::pair<char, char>& pair: this->prefix_to_mode)
+ for (const auto& pair: this->prefix_to_mode)
if (pair.second == c)
{
type = 1;
@@ -1148,14 +1245,14 @@ DummyIrcChannel& IrcClient::get_dummy_channel()
return this->dummy_channel;
}
-void IrcClient::leave_dummy_channel(const std::string& exit_message)
+void IrcClient::leave_dummy_channel(const std::string& exit_message, const std::string& resource)
{
if (!this->dummy_channel.joined)
return;
this->dummy_channel.joined = false;
this->dummy_channel.joining = false;
this->dummy_channel.remove_all_users();
- this->bridge.send_muc_leave(Iid("%"s + this->hostname, this->chantypes), std::string(this->current_nick), exit_message, true);
+ this->bridge.send_muc_leave(Iid("%"s + this->hostname, this->chantypes), std::string(this->current_nick), exit_message, true, resource);
}
#ifdef BOTAN_FOUND
@@ -1163,7 +1260,7 @@ bool IrcClient::abort_on_invalid_cert() const
{
#ifdef USE_DATABASE
auto options = Database::get_irc_server_options(this->bridge.get_bare_jid(), this->hostname);
- return options.verifyCert.value();
+ return options.col<Database::VerifyCert>();
#endif
return true;
}
diff --git a/src/irc/irc_client.hpp b/src/irc/irc_client.hpp
index 1b4d892..aec6cd9 100644
--- a/src/irc/irc_client.hpp
+++ b/src/irc/irc_client.hpp
@@ -5,7 +5,7 @@
#include <irc/irc_channel.hpp>
#include <irc/iid.hpp>
-#include <network/tcp_socket_handler.hpp>
+#include <network/tcp_client_socket_handler.hpp>
#include <network/resolver.hpp>
#include <unordered_map>
@@ -23,12 +23,12 @@ class Bridge;
* Represent one IRC client, i.e. an endpoint connected to a single IRC
* server, through a TCP socket, receiving and sending commands to it.
*/
-class IrcClient: public TCPSocketHandler
+class IrcClient: public TCPClientSocketHandler
{
public:
- explicit IrcClient(std::shared_ptr<Poller> poller, const std::string& hostname,
- const std::string& nickname, const std::string& username,
- const std::string& realname, const std::string& user_hostname,
+ explicit IrcClient(std::shared_ptr<Poller>& poller, std::string hostname,
+ std::string nickname, std::string username,
+ std::string realname, std::string user_hostname,
Bridge& bridge);
~IrcClient();
@@ -52,7 +52,7 @@ public:
/**
* Close the connection, remove us from the poller
*/
- void on_connection_close(const std::string& error) override final;
+ void on_connection_close(const std::string& error_msg) override final;
/**
* Parse the data we have received so far and try to get one or more
* complete messages from it.
@@ -222,6 +222,8 @@ public:
* received etc), send the self presence and topic to the XMPP user.
*/
void on_channel_completely_joined(const IrcMessage& message);
+ void on_banlist(const IrcMessage& message);
+ void on_banlist_end(const IrcMessage& message);
/**
* Save our own host, as reported by the server
*/
@@ -257,6 +259,7 @@ public:
void on_nick(const IrcMessage& message);
void on_kick(const IrcMessage& message);
void on_mode(const IrcMessage& message);
+ void on_channel_bad_key(const IrcMessage& message);
/**
* A mode towards our own user is received (note, that is different from a
* channel mode towards or own nick, see
@@ -282,7 +285,7 @@ public:
* Leave the dummy channel: forward a message to the user to indicate that
* he left it, and mark it as not joined.
*/
- void leave_dummy_channel(const std::string& exit_message);
+ void leave_dummy_channel(const std::string& exit_message, const std::string& resource);
const std::string& get_hostname() const { return this->hostname; }
std::string get_nick() const { return this->current_nick; }
diff --git a/src/irc/irc_message.cpp b/src/irc/irc_message.cpp
index 966a47c..14fdb0e 100644
--- a/src/irc/irc_message.cpp
+++ b/src/irc/irc_message.cpp
@@ -8,12 +8,12 @@ IrcMessage::IrcMessage(std::string&& line)
// optional prefix
if (line[0] == ':')
{
- pos = line.find(" ");
+ pos = line.find(' ');
this->prefix = line.substr(1, pos - 1);
line = line.substr(pos + 1, std::string::npos);
}
// command
- pos = line.find(" ");
+ pos = line.find(' ');
this->command = line.substr(0, pos);
line = line.substr(pos + 1, std::string::npos);
// arguments
@@ -24,7 +24,7 @@ IrcMessage::IrcMessage(std::string&& line)
this->arguments.emplace_back(line.substr(1, std::string::npos));
break ;
}
- pos = line.find(" ");
+ pos = line.find(' ');
this->arguments.emplace_back(line.substr(0, pos));
line = line.substr(pos + 1, std::string::npos);
} while (pos != std::string::npos);
diff --git a/src/irc/irc_user.cpp b/src/irc/irc_user.cpp
index 9fa3612..139015e 100644
--- a/src/irc/irc_user.cpp
+++ b/src/irc/irc_user.cpp
@@ -21,7 +21,7 @@ IrcUser::IrcUser(const std::string& name,
name_begin++;
}
- const std::string::size_type sep = name.find("!", name_begin);
+ const std::string::size_type sep = name.find('!', name_begin);
if (sep == std::string::npos)
this->nick = name.substr(name_begin);
else
diff --git a/src/irc/irc_user.hpp b/src/irc/irc_user.hpp
index c84030e..a4291d4 100644
--- a/src/irc/irc_user.hpp
+++ b/src/irc/irc_user.hpp
@@ -23,7 +23,7 @@ public:
void add_mode(const char mode);
void remove_mode(const char mode);
- char get_most_significant_mode(const std::vector<char>& sorted_user_modes) const;
+ char get_most_significant_mode(const std::vector<char>& modes) const;
std::string nick;
std::string host;
diff --git a/src/logger/logger.cpp b/src/logger/logger.cpp
new file mode 100644
index 0000000..92a3d9b
--- /dev/null
+++ b/src/logger/logger.cpp
@@ -0,0 +1,42 @@
+#include <logger/logger.hpp>
+#include <config/config.hpp>
+
+Logger::Logger(const int log_level):
+ log_level(log_level),
+ stream(std::cout.rdbuf()),
+ null_buffer{},
+ null_stream{&null_buffer}
+{
+}
+
+Logger::Logger(const int log_level, const std::string& log_file):
+ log_level(log_level),
+ ofstream(log_file.data(), std::ios_base::app),
+ stream(ofstream.rdbuf()),
+ null_buffer{},
+ null_stream{&null_buffer}
+{
+}
+
+std::unique_ptr<Logger>& Logger::instance()
+{
+ static std::unique_ptr<Logger> instance;
+
+ if (!instance)
+ {
+ const std::string log_file = Config::get("log_file", "");
+ const int log_level = Config::get_int("log_level", 0);
+ if (log_file.empty())
+ instance = std::make_unique<Logger>(log_level);
+ else
+ instance = std::make_unique<Logger>(log_level, log_file);
+ }
+ return instance;
+}
+
+std::ostream& Logger::get_stream(const int lvl)
+{
+ if (lvl >= this->log_level)
+ return this->stream;
+ return this->null_stream;
+}
diff --git a/src/logger/logger.hpp b/src/logger/logger.hpp
new file mode 100644
index 0000000..ff6a82b
--- /dev/null
+++ b/src/logger/logger.hpp
@@ -0,0 +1,128 @@
+#pragma once
+
+
+/**
+ * Singleton used in logger macros to write into files or stdout, with
+ * various levels of severity.
+ * Only the macros should be used.
+ * @class Logger
+ */
+
+#include <memory>
+#include <iostream>
+#include <fstream>
+
+#define debug_lvl 0
+#define info_lvl 1
+#define warning_lvl 2
+#define error_lvl 3
+
+#include "biboumi.h"
+#ifdef SYSTEMD_FOUND
+# include <systemd/sd-daemon.h>
+#else
+# define SD_DEBUG "[DEBUG]: "
+# define SD_INFO "[INFO]: "
+# define SD_WARNING "[WARNING]: "
+# define SD_ERR "[ERROR]: "
+#endif
+
+// Macro defined to get the filename instead of the full path. But if it is
+// not properly defined by the build system, we fallback to __FILE__
+#ifndef __FILENAME__
+# define __FILENAME__ __FILE__
+#endif
+
+
+/**
+ * A buffer, used to construct an ostream that does nothing
+ * when we output data in it
+ */
+class NullBuffer: public std::streambuf
+{
+ public:
+ int overflow(int c) { return c; }
+};
+
+class Logger
+{
+public:
+ static std::unique_ptr<Logger>& instance();
+ std::ostream& get_stream(const int);
+ Logger(const int log_level, const std::string& log_file);
+ Logger(const int log_level);
+
+ Logger(const Logger&) = delete;
+ Logger& operator=(const Logger&) = delete;
+ Logger(Logger&&) = delete;
+ Logger& operator=(Logger&&) = delete;
+
+private:
+ const int log_level;
+ std::ofstream ofstream{};
+ std::ostream stream;
+
+ NullBuffer null_buffer;
+ std::ostream null_stream;
+};
+
+#define WHERE __FILENAME__, ":", __LINE__, ":\t"
+
+namespace logging_details
+{
+ template <typename T>
+ void log(std::ostream& os, const T& arg)
+ {
+ os << arg << std::endl;
+ }
+
+ template <typename T, typename... U>
+ void log(std::ostream& os, const T& first, U&&... rest)
+ {
+ os << first;
+ log(os, std::forward<U>(rest)...);
+ }
+
+ template <typename... U>
+ void log_debug(U&&... args)
+ {
+ auto& os = Logger::instance()->get_stream(debug_lvl);
+ os << SD_DEBUG;
+ log(os, std::forward<U>(args)...);
+ }
+
+ template <typename... U>
+ void log_info(U&&... args)
+ {
+ auto& os = Logger::instance()->get_stream(info_lvl);
+ os << SD_INFO;
+ log(os, std::forward<U>(args)...);
+ }
+
+ template <typename... U>
+ void log_warning(U&&... args)
+ {
+ auto& os = Logger::instance()->get_stream(warning_lvl);
+ os << SD_WARNING;
+ log(os, std::forward<U>(args)...);
+ }
+
+ template <typename... U>
+ void log_error(U&&... args)
+ {
+ auto& os = Logger::instance()->get_stream(error_lvl);
+ os << SD_ERR;
+ log(os, std::forward<U>(args)...);
+ }
+}
+
+#define log_info(...) logging_details::log_info(WHERE, __VA_ARGS__)
+
+#define log_warning(...) logging_details::log_warning(WHERE, __VA_ARGS__)
+
+#define log_error(...) logging_details::log_error(WHERE, __VA_ARGS__)
+
+#define log_debug(...) logging_details::log_debug(WHERE, __VA_ARGS__)
+
+
+
diff --git a/src/main.cpp b/src/main.cpp
index 488032d..5725584 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -6,15 +6,14 @@
#include <utils/xdg.hpp>
#include <utils/reload.hpp>
-#ifdef CARES_FOUND
+#ifdef UDNS_FOUND
# include <network/dns_handler.hpp>
#endif
#include <atomic>
-#include <signal.h>
-#ifdef USE_DATABASE
-# include <litesql.hpp>
-#endif
+#include <csignal>
+
+#include <identd/identd_server.hpp>
// A flag set by the SIGINT signal handler.
static std::atomic<bool> stop(false);
@@ -89,7 +88,7 @@ int main(int ac, char** av)
#ifdef USE_DATABASE
try {
open_database();
- } catch (const litesql::DatabaseError&) {
+ } catch (...) {
return 1;
}
#endif
@@ -97,7 +96,7 @@ int main(int ac, char** av)
// Block the signals we want to manage. They will be unblocked only during
// the epoll_pwait or ppoll calls. This avoids some race conditions,
// explained in man 2 pselect on linux
- sigset_t mask;
+ sigset_t mask{};
sigemptyset(&mask);
sigaddset(&mask, SIGINT);
sigaddset(&mask, SIGTERM);
@@ -126,13 +125,17 @@ int main(int ac, char** av)
sigaction(SIGUSR2, &on_sigusr, nullptr);
auto p = std::make_shared<Poller>();
+
+#ifdef UDNS_FOUND
+ DNSHandler dns_handler(p);
+#endif
+
auto xmpp_component =
std::make_shared<BiboumiComponent>(p, hostname, password);
xmpp_component->start();
-#ifdef CARES_FOUND
- DNSHandler::instance.watch_dns_sockets(p);
-#endif
+ IdentdServer identd(*xmpp_component, p, static_cast<uint16_t>(Config::get_int("identd_port", 113)));
+
auto timeout = TimedEventsManager::instance().get_timeout();
while (p->poll(timeout) != -1)
{
@@ -140,6 +143,7 @@ int main(int ac, char** av)
// Check for empty irc_clients (not connected, or with no joined
// channel) and remove them
xmpp_component->clean();
+ identd.clean();
if (stop)
{
log_info("Signal received, exiting...");
@@ -149,6 +153,10 @@ int main(int ac, char** av)
exiting = true;
stop.store(false);
xmpp_component->shutdown();
+#ifdef UDNS_FOUND
+ dns_handler.destroy();
+#endif
+ identd.shutdown();
// Cancel the timer for a potential reconnection
TimedEventsManager::instance().cancel("XMPP reconnection");
}
@@ -162,26 +170,36 @@ int main(int ac, char** av)
// happened because we sent something invalid to it and it decided to
// close the connection. This is a bug that should be fixed, but we
// still reconnect automatically instead of dropping everything
- if (!exiting && xmpp_component->ever_auth &&
+ if (!exiting &&
!xmpp_component->is_connected() &&
!xmpp_component->is_connecting())
{
- if (xmpp_component->first_connection_try == true)
- { // immediately re-try to connect
- xmpp_component->reset();
- xmpp_component->start();
- }
+ if (xmpp_component->ever_auth)
+ {
+ static const std::string reconnect_name{"XMPP reconnection"};
+ if (xmpp_component->first_connection_try == true)
+ { // immediately re-try to connect
+ xmpp_component->reset();
+ xmpp_component->start();
+ }
+ else if (!TimedEventsManager::instance().find_event(reconnect_name))
+ { // Re-connecting failed, we now try only each few seconds
+ auto reconnect_later = [xmpp_component]()
+ {
+ xmpp_component->reset();
+ xmpp_component->start();
+ };
+ TimedEvent event(std::chrono::steady_clock::now() + 2s, reconnect_later, reconnect_name);
+ TimedEventsManager::instance().add_event(std::move(event));
+ }
+ }
else
- { // Re-connecting failed, we now try only each few seconds
- auto reconnect_later = [xmpp_component]()
{
- xmpp_component->reset();
- xmpp_component->start();
- };
- TimedEvent event(std::chrono::steady_clock::now() + 2s,
- reconnect_later, "XMPP reconnection");
- TimedEventsManager::instance().add_event(std::move(event));
- }
+#ifdef UDNS_FOUND
+ dns_handler.destroy();
+#endif
+ identd.shutdown();
+ }
}
// If the only existing connection is the one to the XMPP component:
// close the XMPP stream.
@@ -189,18 +207,11 @@ int main(int ac, char** av)
xmpp_component->close();
if (exiting && p->size() == 1 && xmpp_component->is_document_open())
xmpp_component->close_document();
-#ifdef CARES_FOUND
- if (!exiting)
- DNSHandler::instance.watch_dns_sockets(p);
-#endif
if (exiting) // If we are exiting, do not wait for any timed event
timeout = utils::no_timeout;
else
timeout = TimedEventsManager::instance().get_timeout();
}
-#ifdef CARES_FOUND
- DNSHandler::instance.destroy();
-#endif
if (!xmpp_component->ever_auth)
return 1; // To signal that the process did not properly start
log_info("All connections cleanly closed, have a nice day.");
diff --git a/src/network/credentials_manager.cpp b/src/network/credentials_manager.cpp
new file mode 100644
index 0000000..f93a366
--- /dev/null
+++ b/src/network/credentials_manager.cpp
@@ -0,0 +1,136 @@
+#include "biboumi.h"
+
+#ifdef BOTAN_FOUND
+#include <network/tcp_socket_handler.hpp>
+#include <network/credentials_manager.hpp>
+#include <logger/logger.hpp>
+#include <botan/tls_exceptn.h>
+#include <config/config.hpp>
+
+/**
+ * TODO find a standard way to find that out.
+ */
+static const std::vector<std::string> default_cert_files = {
+ "/etc/ssl/certs/ca-bundle.crt",
+ "/etc/pki/tls/certs/ca-bundle.crt",
+ "/etc/ssl/certs/ca-certificates.crt",
+ "/etc/ca-certificates/extracted/tls-ca-bundle.pem"
+};
+
+Botan::Certificate_Store_In_Memory BasicCredentialsManager::certificate_store;
+bool BasicCredentialsManager::certs_loaded = false;
+
+BasicCredentialsManager::BasicCredentialsManager(const TCPSocketHandler* const socket_handler):
+ Botan::Credentials_Manager(),
+ socket_handler(socket_handler),
+ trusted_fingerprint{}
+{
+ BasicCredentialsManager::load_certs();
+}
+
+void BasicCredentialsManager::set_trusted_fingerprint(const std::string& fingerprint)
+{
+ this->trusted_fingerprint = fingerprint;
+}
+
+const std::string& BasicCredentialsManager::get_trusted_fingerprint() const
+{
+ return this->trusted_fingerprint;
+}
+
+void check_tls_certificate(const std::vector<Botan::X509_Certificate>& certs,
+ const std::string& hostname, const std::string& trusted_fingerprint,
+ const std::exception_ptr& exc)
+{
+
+ if (!trusted_fingerprint.empty() && !certs.empty() &&
+ trusted_fingerprint == certs[0].fingerprint() &&
+ certs[0].matches_dns_name(hostname))
+ // We trust the certificate, based on the trusted fingerprint and
+ // the fact that the hostname matches
+ return;
+
+ if (exc)
+ std::rethrow_exception(exc);
+}
+
+#if BOTAN_VERSION_CODE < BOTAN_VERSION_CODE_FOR(1,11,34)
+void BasicCredentialsManager::verify_certificate_chain(const std::string& type,
+ const std::string& purported_hostname,
+ const std::vector<Botan::X509_Certificate>& certs)
+{
+ log_debug("Checking remote certificate (", type, ") for hostname ", purported_hostname);
+ try
+ {
+ Botan::Credentials_Manager::verify_certificate_chain(type, purported_hostname, certs);
+ log_debug("Certificate is valid");
+ }
+ catch (const std::exception& tls_exception)
+ {
+ log_warning("TLS certificate check failed: ", tls_exception.what());
+ std::exception_ptr exception_ptr{};
+ if (this->socket_handler->abort_on_invalid_cert())
+ exception_ptr = std::current_exception();
+
+ check_tls_certificate(certs, purported_hostname, this->trusted_fingerprint, exception_ptr);
+ }
+}
+#endif
+
+bool BasicCredentialsManager::try_to_open_one_ca_bundle(const std::vector<std::string>& paths)
+{
+ for (const auto& path: paths)
+ {
+ try
+ {
+ Botan::DataSource_Stream bundle(path);
+ log_debug("Using ca bundle: ", path);
+ while (!bundle.end_of_data() && bundle.check_available(27))
+ {
+ // TODO: remove this work-around for Botan 1.11.29
+ // https://github.com/randombit/botan/issues/438#issuecomment-192866796
+ // Note that every certificate that fails to be transcoded into latin-1
+ // will be ignored. As a result, some TLS connection may be refused
+ // because the certificate is signed by an issuer that was ignored.
+ try {
+ Botan::X509_Certificate cert(bundle);
+ BasicCredentialsManager::certificate_store.add_certificate(cert);
+ } catch (const Botan::Decoding_Error& error) {
+ continue;
+ }
+ }
+ // Only use the first file that can successfully be read.
+ return true;
+ }
+ catch (const Botan::Stream_IO_Error& e)
+ {
+ log_debug(e.what());
+ }
+ }
+ return false;
+}
+
+void BasicCredentialsManager::load_certs()
+{
+ // Only load the certificates the first time
+ if (BasicCredentialsManager::certs_loaded)
+ return;
+ const std::string conf_path = Config::get("ca_file", "");
+ std::vector<std::string> paths;
+ if (conf_path.empty())
+ paths = default_cert_files;
+ else
+ paths.push_back(conf_path);
+
+ if (BasicCredentialsManager::try_to_open_one_ca_bundle(paths))
+ BasicCredentialsManager::certs_loaded = true;
+ else
+ log_warning("The CA could not be loaded, TLS negociation will probably fail.");
+}
+
+std::vector<Botan::Certificate_Store*> BasicCredentialsManager::trusted_certificate_authorities(const std::string&, const std::string&)
+{
+ return {&this->certificate_store};
+}
+
+#endif
diff --git a/src/network/credentials_manager.hpp b/src/network/credentials_manager.hpp
new file mode 100644
index 0000000..e7c247d
--- /dev/null
+++ b/src/network/credentials_manager.hpp
@@ -0,0 +1,55 @@
+#pragma once
+
+#include "biboumi.h"
+
+#ifdef BOTAN_FOUND
+
+#include <botan/botan.h>
+#include <botan/tls_client.h>
+
+class TCPSocketHandler;
+
+/**
+ * If the given cert isn’t valid, based on the given hostname
+ * and fingerprint, then throws the exception if it’s non-empty.
+ *
+ * Must be called after the standard (from Botan) way of
+ * checking the certificate, if we want to also accept certificates based
+ * on a trusted fingerprint.
+ */
+void check_tls_certificate(const std::vector<Botan::X509_Certificate>& certs,
+ const std::string& hostname, const std::string& trusted_fingerprint,
+ const std::exception_ptr& exc);
+
+class BasicCredentialsManager: public Botan::Credentials_Manager
+{
+public:
+ BasicCredentialsManager(const TCPSocketHandler* const socket_handler);
+
+ BasicCredentialsManager(BasicCredentialsManager&&) = delete;
+ BasicCredentialsManager(const BasicCredentialsManager&) = delete;
+ BasicCredentialsManager& operator=(const BasicCredentialsManager&) = delete;
+ BasicCredentialsManager& operator=(BasicCredentialsManager&&) = delete;
+
+#if BOTAN_VERSION_CODE < BOTAN_VERSION_CODE_FOR(1,11,34)
+ void verify_certificate_chain(const std::string& type,
+ const std::string& purported_hostname,
+ const std::vector<Botan::X509_Certificate>&) override final;
+#endif
+ std::vector<Botan::Certificate_Store*> trusted_certificate_authorities(const std::string& type,
+ const std::string& context) override final;
+ void set_trusted_fingerprint(const std::string& fingerprint);
+ const std::string& get_trusted_fingerprint() const;
+
+private:
+ const TCPSocketHandler* const socket_handler;
+
+ static bool try_to_open_one_ca_bundle(const std::vector<std::string>& paths);
+ static void load_certs();
+ static Botan::Certificate_Store_In_Memory certificate_store;
+ static bool certs_loaded;
+ std::string trusted_fingerprint;
+};
+
+#endif //BOTAN_FOUND
+
diff --git a/src/network/dns_handler.cpp b/src/network/dns_handler.cpp
new file mode 100644
index 0000000..7f0c96a
--- /dev/null
+++ b/src/network/dns_handler.cpp
@@ -0,0 +1,46 @@
+#include <biboumi.h>
+#ifdef UDNS_FOUND
+
+#include <network/dns_socket_handler.hpp>
+#include <network/dns_handler.hpp>
+#include <network/poller.hpp>
+
+#include <utils/timed_events.hpp>
+
+#include <udns.h>
+#include <cerrno>
+#include <cstring>
+
+class Resolver;
+
+using namespace std::string_literals;
+
+std::unique_ptr<DNSSocketHandler> DNSHandler::socket_handler{};
+
+DNSHandler::DNSHandler(std::shared_ptr<Poller>& poller)
+{
+ dns_init(nullptr, 0);
+ const auto socket = dns_open(nullptr);
+ if (socket == -1)
+ throw std::runtime_error("Failed to initialize udns socket: "s + strerror(errno));
+
+ DNSHandler::socket_handler = std::make_unique<DNSSocketHandler>(poller, socket);
+}
+
+void DNSHandler::destroy()
+{
+ DNSHandler::socket_handler.reset(nullptr);
+ dns_close(nullptr);
+}
+
+void DNSHandler::watch()
+{
+ DNSHandler::socket_handler->watch();
+}
+
+void DNSHandler::unwatch()
+{
+ DNSHandler::socket_handler->unwatch();
+}
+
+#endif /* UDNS_FOUND */
diff --git a/src/network/dns_handler.hpp b/src/network/dns_handler.hpp
new file mode 100644
index 0000000..c694452
--- /dev/null
+++ b/src/network/dns_handler.hpp
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <biboumi.h>
+#ifdef UDNS_FOUND
+
+class Poller;
+
+#include <network/dns_socket_handler.hpp>
+
+#include <string>
+#include <vector>
+#include <memory>
+
+class DNSHandler
+{
+public:
+ explicit DNSHandler(std::shared_ptr<Poller>& poller);
+ ~DNSHandler() = default;
+
+ DNSHandler(const DNSHandler&) = delete;
+ DNSHandler(DNSHandler&&) = delete;
+ DNSHandler& operator=(const DNSHandler&) = delete;
+ DNSHandler& operator=(DNSHandler&&) = delete;
+
+ void destroy();
+
+ static void watch();
+ static void unwatch();
+
+private:
+ /**
+ * Manager for the socket returned by udns, that we need to watch with the poller
+ */
+ static std::unique_ptr<DNSSocketHandler> socket_handler;
+};
+
+#endif /* UDNS_FOUND */
diff --git a/src/network/dns_socket_handler.cpp b/src/network/dns_socket_handler.cpp
new file mode 100644
index 0000000..5c286c4
--- /dev/null
+++ b/src/network/dns_socket_handler.cpp
@@ -0,0 +1,43 @@
+#include <biboumi.h>
+#ifdef UDNS_FOUND
+
+#include <network/dns_socket_handler.hpp>
+#include <network/dns_handler.hpp>
+#include <network/poller.hpp>
+
+#include <udns.h>
+
+DNSSocketHandler::DNSSocketHandler(std::shared_ptr<Poller>& poller,
+ const socket_t socket):
+ SocketHandler(poller, socket)
+{
+ poller->add_socket_handler(this);
+}
+
+DNSSocketHandler::~DNSSocketHandler()
+{
+ this->unwatch();
+}
+
+void DNSSocketHandler::on_recv()
+{
+ dns_ioevent(nullptr, 0);
+}
+
+bool DNSSocketHandler::is_connected() const
+{
+ return true;
+}
+
+void DNSSocketHandler::unwatch()
+{
+ if (this->poller->is_managing_socket(this->socket))
+ this->poller->remove_socket_handler(this->socket);
+}
+
+void DNSSocketHandler::watch()
+{
+ this->poller->add_socket_handler(this);
+}
+
+#endif /* UDNS_FOUND */
diff --git a/src/network/dns_socket_handler.hpp b/src/network/dns_socket_handler.hpp
new file mode 100644
index 0000000..6e83e87
--- /dev/null
+++ b/src/network/dns_socket_handler.hpp
@@ -0,0 +1,33 @@
+#pragma once
+
+#include <biboumi.h>
+#ifdef UDNS_FOUND
+
+#include <network/socket_handler.hpp>
+
+/**
+ * Manage the UDP socket provided by udns, we do not create, open or close the
+ * socket ourself: this is done by udns. We only watch it for readability
+ */
+class DNSSocketHandler: public SocketHandler
+{
+public:
+ explicit DNSSocketHandler(std::shared_ptr<Poller>& poller, const socket_t socket);
+ ~DNSSocketHandler();
+ DNSSocketHandler(const DNSSocketHandler&) = delete;
+ DNSSocketHandler(DNSSocketHandler&&) = delete;
+ DNSSocketHandler& operator=(const DNSSocketHandler&) = delete;
+ DNSSocketHandler& operator=(DNSSocketHandler&&) = delete;
+
+ void on_recv() override final;
+
+ /**
+ * Always true, see the comment for connect()
+ */
+ bool is_connected() const override final;
+
+ void watch();
+ void unwatch();
+};
+
+#endif // UDNS_FOUND
diff --git a/src/network/poller.cpp b/src/network/poller.cpp
new file mode 100644
index 0000000..0f02cc5
--- /dev/null
+++ b/src/network/poller.cpp
@@ -0,0 +1,238 @@
+#include <network/poller.hpp>
+#include <logger/logger.hpp>
+#include <utils/timed_events.hpp>
+
+#include <cassert>
+#include <cerrno>
+#include <stdio.h>
+#include <signal.h>
+#include <unistd.h>
+
+#include <cstring>
+#include <iostream>
+#include <stdexcept>
+
+Poller::Poller()
+{
+#if POLLER == POLL
+ this->nfds = 0;
+#elif POLLER == EPOLL
+ this->epfd = ::epoll_create1(0);
+ if (this->epfd == -1)
+ {
+ log_error("epoll failed: ", strerror(errno));
+ throw std::runtime_error("Could not create epoll instance");
+ }
+#endif
+}
+
+Poller::~Poller()
+{
+#if POLLER == EPOLL
+ if (this->epfd > 0)
+ ::close(this->epfd);
+#endif
+}
+
+void Poller::add_socket_handler(SocketHandler* socket_handler)
+{
+ // Don't do anything if the socket is already managed
+ const auto it = this->socket_handlers.find(socket_handler->get_socket());
+ if (it != this->socket_handlers.end())
+ return ;
+
+ this->socket_handlers.emplace(socket_handler->get_socket(), socket_handler);
+
+ // We always watch all sockets for receive events
+#if POLLER == POLL
+ this->fds[this->nfds].fd = socket_handler->get_socket();
+ this->fds[this->nfds].events = POLLIN;
+ this->nfds++;
+#endif
+#if POLLER == EPOLL
+ struct epoll_event event = {EPOLLIN, {socket_handler}};
+ const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_ADD, socket_handler->get_socket(), &event);
+ if (res == -1)
+ {
+ log_error("epoll_ctl failed: ", strerror(errno));
+ throw std::runtime_error("Could not add socket to epoll");
+ }
+#endif
+}
+
+void Poller::remove_socket_handler(const socket_t socket)
+{
+ const auto it = this->socket_handlers.find(socket);
+ if (it == this->socket_handlers.end())
+ throw std::runtime_error("Trying to remove a SocketHandler that is not managed");
+ this->socket_handlers.erase(it);
+
+#if POLLER == POLL
+ for (size_t i = 0; i < this->nfds; i++)
+ {
+ if (this->fds[i].fd == socket)
+ {
+ // Move all subsequent pollfd by one on the left, erasing the
+ // value of the one we remove
+ for (size_t j = i; j < this->nfds - 1; ++j)
+ {
+ this->fds[j].fd = this->fds[j+1].fd;
+ this->fds[j].events= this->fds[j+1].events;
+ }
+ this->nfds--;
+ }
+ }
+#elif POLLER == EPOLL
+ const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_DEL, socket, nullptr);
+ if (res == -1)
+ {
+ log_error("epoll_ctl failed: ", strerror(errno));
+ throw std::runtime_error("Could not remove socket from epoll");
+ }
+#endif
+}
+
+void Poller::watch_send_events(SocketHandler* socket_handler)
+{
+#if POLLER == POLL
+ for (size_t i = 0; i < this->nfds; ++i)
+ {
+ if (this->fds[i].fd == socket_handler->get_socket())
+ {
+ this->fds[i].events = POLLIN|POLLOUT;
+ return;
+ }
+ }
+ throw std::runtime_error("Cannot watch a non-registered socket for send events");
+#elif POLLER == EPOLL
+ struct epoll_event event = {EPOLLIN|EPOLLOUT, {socket_handler}};
+ const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_MOD, socket_handler->get_socket(), &event);
+ if (res == -1)
+ {
+ log_error("epoll_ctl failed: ", strerror(errno));
+ throw std::runtime_error("Could not modify socket flags in epoll");
+ }
+#endif
+}
+
+void Poller::stop_watching_send_events(SocketHandler* socket_handler)
+{
+#if POLLER == POLL
+ for (size_t i = 0; i <= this->nfds; ++i)
+ {
+ if (this->fds[i].fd == socket_handler->get_socket())
+ {
+ this->fds[i].events = POLLIN;
+ return;
+ }
+ }
+ throw std::runtime_error("Cannot watch a non-registered socket for send events");
+#elif POLLER == EPOLL
+ struct epoll_event event = {EPOLLIN, {socket_handler}};
+ const int res = ::epoll_ctl(this->epfd, EPOLL_CTL_MOD, socket_handler->get_socket(), &event);
+ if (res == -1)
+ {
+ log_error("epoll_ctl failed: ", strerror(errno));
+ throw std::runtime_error("Could not modify socket flags in epoll");
+ }
+#endif
+}
+
+int Poller::poll(const std::chrono::milliseconds& timeout)
+{
+ if (this->socket_handlers.empty() && timeout == utils::no_timeout)
+ return -1;
+#if POLLER == POLL
+ // Convert our nice timeout into this ugly struct
+ struct timespec timeout_ts;
+ struct timespec* timeout_tsp;
+ if (timeout > 0s)
+ {
+ auto seconds = std::chrono::duration_cast<std::chrono::seconds>(timeout);
+ timeout_ts.tv_sec = seconds.count();
+ timeout_ts.tv_nsec = std::chrono::duration_cast<std::chrono::nanoseconds>(timeout - seconds).count();
+ timeout_tsp = &timeout_ts;
+ }
+ else
+ timeout_tsp = nullptr;
+
+ // Unblock all signals, only during the ppoll call
+ sigset_t empty_signal_set;
+ sigemptyset(&empty_signal_set);
+ int nb_events = ::ppoll(this->fds, this->nfds, timeout_tsp,
+ &empty_signal_set);
+ if (nb_events < 0)
+ {
+ if (errno == EINTR)
+ return true;
+ log_error("poll failed: ", strerror(errno));
+ throw std::runtime_error("Poll failed");
+ }
+ // We cannot possibly have more ready events than the number of fds we are
+ // watching
+ assert(static_cast<unsigned int>(nb_events) <= this->nfds);
+ for (size_t i = 0; i < this->nfds && nb_events != 0; ++i)
+ {
+ auto socket_handler = this->socket_handlers.at(this->fds[i].fd);
+ if (this->fds[i].revents == 0)
+ continue;
+ else if (this->fds[i].revents & POLLIN && socket_handler->is_connected())
+ {
+ socket_handler->on_recv();
+ nb_events--;
+ }
+ else if (this->fds[i].revents & POLLOUT && socket_handler->is_connected())
+ {
+ socket_handler->on_send();
+ nb_events--;
+ }
+ else if (this->fds[i].revents & POLLOUT ||
+ this->fds[i].revents & POLLIN)
+ {
+ socket_handler->connect();
+ nb_events--;
+ }
+ }
+ return 1;
+#elif POLLER == EPOLL
+ static const size_t max_events = 12;
+ struct epoll_event revents[max_events];
+ // Unblock all signals, only during the epoll_pwait call
+ sigset_t empty_signal_set{};
+ sigemptyset(&empty_signal_set);
+
+ int real_timeout = std::numeric_limits<int>::max();
+ if (timeout.count() < real_timeout) // Just avoid any potential int overflow
+ real_timeout = static_cast<int>(timeout.count());
+ const int nb_events = ::epoll_pwait(this->epfd, revents, max_events, real_timeout,
+ &empty_signal_set);
+ if (nb_events == -1)
+ {
+ if (errno == EINTR)
+ return 0;
+ log_error("epoll wait: ", strerror(errno));
+ throw std::runtime_error("Epoll_wait failed");
+ }
+ for (int i = 0; i < nb_events; ++i)
+ {
+ auto socket_handler = static_cast<SocketHandler*>(revents[i].data.ptr);
+ if (revents[i].events & EPOLLIN && socket_handler->is_connected())
+ socket_handler->on_recv();
+ else if (revents[i].events & EPOLLOUT && socket_handler->is_connected())
+ socket_handler->on_send();
+ else if (revents[i].events & EPOLLOUT)
+ socket_handler->connect();
+ }
+ return nb_events;
+#endif
+}
+
+size_t Poller::size() const
+{
+ return this->socket_handlers.size();
+}
+
+bool Poller::is_managing_socket(const socket_t socket) const
+{
+ return (this->socket_handlers.find(socket) != this->socket_handlers.end());
+}
diff --git a/src/network/poller.hpp b/src/network/poller.hpp
new file mode 100644
index 0000000..3cc2710
--- /dev/null
+++ b/src/network/poller.hpp
@@ -0,0 +1,98 @@
+#pragma once
+
+
+#include <network/socket_handler.hpp>
+
+#include <unordered_map>
+#include <memory>
+#include <chrono>
+
+#define POLL 1
+#define EPOLL 2
+#define KQUEUE 3
+#include <biboumi.h>
+#ifndef POLLER
+ #define POLLER POLL
+#endif
+
+#if POLLER == POLL
+ #include <poll.h>
+ #define MAX_POLL_FD_NUMBER 4096
+#elif POLLER == EPOLL
+ #include <sys/epoll.h>
+#else
+ #error Invalid POLLER value
+#endif
+
+/**
+ * We pass some SocketHandlers to this Poller, which uses
+ * poll/epoll/kqueue/select etc to wait for events on these SocketHandlers,
+ * and call the callbacks when event occurs.
+ *
+ * TODO: support these pollers:
+ * - kqueue(2)
+ */
+
+class Poller
+{
+public:
+ explicit Poller();
+ ~Poller();
+ Poller(const Poller&) = delete;
+ Poller(Poller&&) = delete;
+ Poller& operator=(const Poller&) = delete;
+ Poller& operator=(Poller&&) = delete;
+ /**
+ * Add a SocketHandler to be monitored by this Poller. All receive events
+ * are always automatically watched.
+ */
+ void add_socket_handler(SocketHandler* socket_handler);
+ /**
+ * Remove (and stop managing) a SocketHandler, designated by the given socket_t.
+ */
+ void remove_socket_handler(const socket_t socket);
+ /**
+ * Signal the poller that he needs to watch for send events for the given
+ * SocketHandler.
+ */
+ void watch_send_events(SocketHandler* socket_handler);
+ /**
+ * Signal the poller that he needs to stop watching for send events for
+ * this SocketHandler.
+ */
+ void stop_watching_send_events(SocketHandler* socket_handler);
+ /**
+ * Wait for all watched events, and call the SocketHandlers' callbacks
+ * when one is ready. Returns if nothing happened before the provided
+ * timeout. If the timeout is 0, it waits forever. If there is no
+ * watched event, returns -1 immediately, ignoring the timeout value.
+ * Otherwise, returns the number of event handled. If 0 is returned this
+ * means that we were interrupted by a signal, or the timeout occured.
+ */
+ int poll(const std::chrono::milliseconds& timeout);
+ /**
+ * Returns the number of SocketHandlers managed by the poller.
+ */
+ size_t size() const;
+ /**
+ * Whether the given socket is managed by the poller
+ */
+ bool is_managing_socket(const socket_t socket) const;
+
+private:
+ /**
+ * A "list" of all the SocketHandlers that we manage, indexed by socket,
+ * because that's what is returned by select/poll/etc when an event
+ * occures.
+ */
+ std::unordered_map<socket_t, SocketHandler*> socket_handlers;
+
+#if POLLER == POLL
+ struct pollfd fds[MAX_POLL_FD_NUMBER];
+ nfds_t nfds;
+#elif POLLER == EPOLL
+ int epfd;
+#endif
+};
+
+
diff --git a/src/network/resolver.cpp b/src/network/resolver.cpp
new file mode 100644
index 0000000..ef54ba2
--- /dev/null
+++ b/src/network/resolver.cpp
@@ -0,0 +1,280 @@
+#include <network/dns_handler.hpp>
+#include <utils/timed_events.hpp>
+#include <network/resolver.hpp>
+#include <cstring>
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#ifdef UDNS_FOUND
+# include <udns.h>
+#endif
+
+#include <fstream>
+#include <cstdlib>
+#include <sstream>
+#include <chrono>
+#include <map>
+
+using namespace std::string_literals;
+
+#ifdef UDNS_FOUND
+static std::map<int, std::string> dns_error_messages {
+ {DNS_E_TEMPFAIL, "Timeout while contacting DNS servers"},
+ {DNS_E_PROTOCOL, "Misformatted DNS reply"},
+ {DNS_E_NXDOMAIN, "Domain name not found"},
+ {DNS_E_NOMEM, "Out of memory"},
+ {DNS_E_BADQUERY, "Misformatted domain name"}
+};
+#endif
+
+Resolver::Resolver():
+#ifdef UDNS_FOUND
+ resolved4(false),
+ resolved6(false),
+ resolving(false),
+ port{},
+#endif
+ resolved(false),
+ error_msg{}
+{
+}
+
+void Resolver::resolve(const std::string& hostname, const std::string& port,
+ SuccessCallbackType success_cb, ErrorCallbackType error_cb)
+{
+ this->error_cb = std::move(error_cb);
+ this->success_cb = std::move(success_cb);
+#ifdef UDNS_FOUND
+ this->port = port;
+#endif
+
+ this->start_resolving(hostname, port);
+}
+
+int Resolver::call_getaddrinfo(const char *name, const char* port, int flags)
+{
+ struct addrinfo hints{};
+ hints.ai_flags = flags;
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = 0;
+
+ struct addrinfo* addr_res = nullptr;
+ const int res = ::getaddrinfo(name, port,
+ &hints, &addr_res);
+
+ if (res == 0 && addr_res)
+ {
+ if (!this->addr)
+ this->addr.reset(addr_res);
+ else
+ { // Append this result at the end of the linked list
+ struct addrinfo *rp = this->addr.get();
+ while (rp->ai_next)
+ rp = rp->ai_next;
+ rp->ai_next = addr_res;
+ }
+ }
+
+ return res;
+}
+
+#ifdef UDNS_FOUND
+void Resolver::start_resolving(const std::string& hostname, const std::string& port)
+{
+ this->resolving = true;
+ this->resolved = false;
+ this->resolved4 = false;
+ this->resolved6 = false;
+
+ this->error_msg.clear();
+ this->addr.reset(nullptr);
+
+ // We first try to use it as an IP address directly. We tell getaddrinfo
+ // to NOT use any DNS resolution.
+ if (this->call_getaddrinfo(hostname.data(), port.data(), AI_NUMERICHOST) == 0)
+ {
+ this->on_resolved();
+ return;
+ }
+
+ // Then we look into /etc/hosts to translate the given hostname
+ const auto hosts = this->look_in_etc_hosts(hostname);
+ if (!hosts.empty())
+ {
+ for (const auto &host: hosts)
+ this->call_getaddrinfo(host.data(), port.data(), AI_NUMERICHOST);
+ this->on_resolved();
+ return;
+ }
+
+ // And finally, we try a DNS resolution
+ auto hostname6_resolved = [](dns_ctx*, dns_rr_a6* result, void* data)
+ {
+ auto resolver = static_cast<Resolver*>(data);
+ resolver->on_hostname6_resolved(result);
+ resolver->after_resolved();
+ std::free(result);
+ };
+
+ auto hostname4_resolved = [](dns_ctx*, dns_rr_a4* result, void* data)
+ {
+ auto resolver = static_cast<Resolver*>(data);
+ resolver->on_hostname4_resolved(result);
+ resolver->after_resolved();
+ std::free(result);
+ };
+
+ DNSHandler::watch();
+ auto res = dns_submit_a4(nullptr, hostname.data(), 0, hostname4_resolved, this);
+ if (!res)
+ this->on_hostname4_resolved(nullptr);
+ res = dns_submit_a6(nullptr, hostname.data(), 0, hostname6_resolved, this);
+ if (!res)
+ this->on_hostname6_resolved(nullptr);
+
+ this->start_timer();
+}
+
+void Resolver::start_timer()
+{
+ const auto timeout = dns_timeouts(nullptr, -1, 0);
+ if (timeout < 0)
+ return;
+ TimedEvent event(std::chrono::steady_clock::now() + std::chrono::seconds(timeout), [this]() { this->start_timer(); }, "DNS");
+ TimedEventsManager::instance().add_event(std::move(event));
+}
+
+std::vector<std::string> Resolver::look_in_etc_hosts(const std::string &hostname)
+{
+ std::ifstream hosts("/etc/hosts");
+ std::string line;
+
+ std::vector<std::string> results;
+ while (std::getline(hosts, line))
+ {
+ if (line.empty())
+ continue;
+
+ std::string ip;
+ std::istringstream line_stream(line);
+ line_stream >> ip;
+ if (ip.empty() || ip[0] == '#')
+ continue;
+
+ std::string host;
+ while (line_stream >> host && !host.empty() && host[0] != '#')
+ {
+ if (hostname == host)
+ {
+ results.push_back(ip);
+ break;
+ }
+ }
+ }
+ return results;
+}
+
+void Resolver::on_hostname4_resolved(dns_rr_a4 *result)
+{
+ this->resolved4 = true;
+
+ const auto status = dns_status(nullptr);
+
+ if (status >= 0 && result)
+ {
+ char buf[INET6_ADDRSTRLEN];
+
+ for (auto i = 0; i < result->dnsa4_nrr; ++i)
+ {
+ inet_ntop(AF_INET, &result->dnsa4_addr[i], buf, sizeof(buf));
+ this->call_getaddrinfo(buf, this->port.data(), AI_NUMERICHOST);
+ }
+ }
+ else
+ {
+ const auto error = dns_error_messages.find(status);
+ if (error != end(dns_error_messages))
+ this->error_msg = error->second;
+ }
+}
+
+void Resolver::on_hostname6_resolved(dns_rr_a6 *result)
+{
+ this->resolved6 = true;
+
+ const auto status = dns_status(nullptr);
+
+ if (status >= 0 && result)
+ {
+ char buf[INET6_ADDRSTRLEN];
+ for (auto i = 0; i < result->dnsa6_nrr; ++i)
+ {
+ inet_ntop(AF_INET6, &result->dnsa6_addr[i], buf, sizeof(buf));
+ this->call_getaddrinfo(buf, this->port.data(), AI_NUMERICHOST);
+ }
+ }
+}
+
+void Resolver::after_resolved()
+{
+ if (dns_active(nullptr) == 0)
+ DNSHandler::unwatch();
+
+ if (this->resolved6 && this->resolved4)
+ this->on_resolved();
+}
+
+void Resolver::on_resolved()
+{
+ this->resolved = true;
+ this->resolving = false;
+ if (!this->addr)
+ {
+ if (this->error_cb)
+ this->error_cb(this->error_msg.data());
+ }
+ else
+ {
+ if (this->success_cb)
+ this->success_cb(this->addr.get());
+ }
+}
+
+#else // ifdef UDNS_FOUND
+
+void Resolver::start_resolving(const std::string& hostname, const std::string& port)
+{
+ // If the resolution fails, the addr will be unset
+ this->addr.reset(nullptr);
+
+ const auto res = this->call_getaddrinfo(hostname.data(), port.data(), 0);
+
+ this->resolved = true;
+
+ if (res != 0)
+ {
+ this->error_msg = gai_strerror(res);
+ if (this->error_cb)
+ this->error_cb(this->error_msg.data());
+ }
+ else
+ {
+ if (this->success_cb)
+ this->success_cb(this->addr.get());
+ }
+}
+#endif // ifdef UDNS_FOUND
+
+std::string addr_to_string(const struct addrinfo* rp)
+{
+ char buf[INET6_ADDRSTRLEN];
+ if (rp->ai_family == AF_INET)
+ return ::inet_ntop(rp->ai_family,
+ &reinterpret_cast<sockaddr_in*>(rp->ai_addr)->sin_addr,
+ buf, sizeof(buf));
+ else if (rp->ai_family == AF_INET6)
+ return ::inet_ntop(rp->ai_family,
+ &reinterpret_cast<sockaddr_in6*>(rp->ai_addr)->sin6_addr,
+ buf, sizeof(buf));
+ return {};
+}
diff --git a/src/network/resolver.hpp b/src/network/resolver.hpp
new file mode 100644
index 0000000..f65ff86
--- /dev/null
+++ b/src/network/resolver.hpp
@@ -0,0 +1,122 @@
+#pragma once
+
+#include "biboumi.h"
+
+#include <functional>
+#include <vector>
+#include <memory>
+#include <string>
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netdb.h>
+#ifdef UDNS_FOUND
+# include <udns.h>
+#endif
+
+class AddrinfoDeleter
+{
+ public:
+ void operator()(struct addrinfo* addr)
+ {
+ freeaddrinfo(addr);
+ }
+};
+
+
+class Resolver
+{
+public:
+
+ using ErrorCallbackType = std::function<void(const char*)>;
+ using SuccessCallbackType = std::function<void(const struct addrinfo*)>;
+
+ Resolver();
+ ~Resolver() = default;
+ Resolver(const Resolver&) = delete;
+ Resolver(Resolver&&) = delete;
+ Resolver& operator=(const Resolver&) = delete;
+ Resolver& operator=(Resolver&&) = delete;
+
+ bool is_resolving() const
+ {
+#ifdef UDNS_FOUND
+ return this->resolving;
+#else
+ return false;
+#endif
+ }
+
+ bool is_resolved() const
+ {
+ return this->resolved;
+ }
+
+ const auto& get_result() const
+ {
+ return this->addr;
+ }
+ std::string get_error_message() const
+ {
+ return this->error_msg;
+ }
+
+ void clear()
+ {
+#ifdef UDNS_FOUND
+ this->resolved6 = false;
+ this->resolved4 = false;
+ this->resolving = false;
+ this->port.clear();
+#endif
+ this->resolved = false;
+ this->addr.reset();
+ this->error_msg.clear();
+ }
+
+ void resolve(const std::string& hostname, const std::string& port,
+ SuccessCallbackType success_cb, ErrorCallbackType error_cb);
+
+private:
+ void start_resolving(const std::string& hostname, const std::string& port);
+ std::vector<std::string> look_in_etc_hosts(const std::string& hostname);
+ /**
+ * Call getaddrinfo() on the given hostname or IP, and append the result
+ * to our internal addrinfo list. Return getaddrinfo()’s return value.
+ */
+ int call_getaddrinfo(const char* name, const char* port, int flags);
+
+#ifdef UDNS_FOUND
+ void on_hostname4_resolved(dns_rr_a4 *result);
+ void on_hostname6_resolved(dns_rr_a6 *result);
+ /**
+ * Called after one record (4 or 6) has been resolved.
+ */
+ void after_resolved();
+
+ void start_timer();
+
+ void on_resolved();
+
+ bool resolved4;
+ bool resolved6;
+
+ bool resolving;
+
+ std::string port;
+
+#endif
+ /**
+ * Tells if we finished the resolution process. It doesn't indicate if it
+ * was successful (it is true even if the result is an error).
+ */
+ bool resolved;
+ std::string error_msg;
+
+ std::unique_ptr<struct addrinfo, AddrinfoDeleter> addr;
+
+ ErrorCallbackType error_cb;
+ SuccessCallbackType success_cb;
+};
+
+std::string addr_to_string(const struct addrinfo* rp);
diff --git a/src/network/socket_handler.hpp b/src/network/socket_handler.hpp
new file mode 100644
index 0000000..181a6c0
--- /dev/null
+++ b/src/network/socket_handler.hpp
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <biboumi.h>
+#include <memory>
+
+class Poller;
+
+using socket_t = int;
+
+class SocketHandler
+{
+public:
+ explicit SocketHandler(std::shared_ptr<Poller>& poller, const socket_t socket):
+ poller(poller),
+ socket(socket)
+ {}
+ virtual ~SocketHandler() = default;
+ SocketHandler(const SocketHandler&) = delete;
+ SocketHandler(SocketHandler&&) = delete;
+ SocketHandler& operator=(const SocketHandler&) = delete;
+ SocketHandler& operator=(SocketHandler&&) = delete;
+
+ virtual void on_recv() {}
+ virtual void on_send() {}
+ virtual void connect() {}
+ virtual bool is_connected() const = 0;
+
+ socket_t get_socket() const
+ { return this->socket; }
+
+protected:
+ /**
+ * A pointer to the poller that manages us, because we need to communicate
+ * with it.
+ */
+ std::shared_ptr<Poller> poller;
+ /**
+ * The handled socket.
+ */
+ socket_t socket;
+};
+
diff --git a/src/network/tcp_client_socket_handler.cpp b/src/network/tcp_client_socket_handler.cpp
new file mode 100644
index 0000000..35f2446
--- /dev/null
+++ b/src/network/tcp_client_socket_handler.cpp
@@ -0,0 +1,261 @@
+#include <network/tcp_client_socket_handler.hpp>
+#include <utils/timed_events.hpp>
+#include <utils/scopeguard.hpp>
+#include <network/poller.hpp>
+
+#include <logger/logger.hpp>
+
+#include <cstring>
+#include <unistd.h>
+#include <fcntl.h>
+
+using namespace std::string_literals;
+
+TCPClientSocketHandler::TCPClientSocketHandler(std::shared_ptr<Poller>& poller):
+ TCPSocketHandler(poller),
+ hostname_resolution_failed(false),
+ connected(false),
+ connecting(false)
+{}
+
+TCPClientSocketHandler::~TCPClientSocketHandler()
+{
+ this->close();
+}
+
+void TCPClientSocketHandler::init_socket(const struct addrinfo* rp)
+{
+ if (this->socket != -1)
+ ::close(this->socket);
+ if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
+ throw std::runtime_error("Could not create socket: "s + std::strerror(errno));
+ // Bind the socket to a specific address, if specified
+ if (!this->bind_addr.empty())
+ {
+ // Convert the address from string format to a sockaddr that can be
+ // used in bind()
+ struct addrinfo* result;
+ struct addrinfo hints{};
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_flags = AI_NUMERICHOST;
+ hints.ai_family = AF_UNSPEC;
+ int err = ::getaddrinfo(this->bind_addr.data(), nullptr, &hints, &result);
+ if (err != 0 || !result)
+ log_error("Failed to bind socket to ", this->bind_addr, ": ",
+ gai_strerror(err));
+ else
+ {
+ utils::ScopeGuard sg([result](){ freeaddrinfo(result); });
+ struct addrinfo* rp;
+ for (rp = result; rp; rp = rp->ai_next)
+ {
+ if ((::bind(this->socket,
+ reinterpret_cast<const struct sockaddr*>(rp->ai_addr),
+ rp->ai_addrlen)) == 0)
+ break;
+ }
+ if (!rp)
+ log_error("Failed to bind socket to ", this->bind_addr, ": ",
+ strerror(errno));
+ else
+ log_info("Socket successfully bound to ", this->bind_addr);
+ }
+ }
+ int optval = 1;
+ if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1)
+ log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno));
+ // Set the socket on non-blocking mode. This is useful to receive a EAGAIN
+ // error when connect() would block, to not block the whole process if a
+ // remote is not responsive.
+ const int existing_flags = ::fcntl(this->socket, F_GETFL, 0);
+ if ((existing_flags == -1) ||
+ (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1))
+ throw std::runtime_error("Could not initialize socket: "s + std::strerror(errno));
+}
+
+void TCPClientSocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
+{
+ this->address = address;
+ this->port = port;
+ this->use_tls = tls;
+
+ struct addrinfo* addr_res;
+
+ if (!this->connecting)
+ {
+ // Get the addrinfo from getaddrinfo (or using udns), only if
+ // this is the first call of this function.
+ if (!this->resolver.is_resolved())
+ {
+ log_info("Trying to connect to ", address, ":", port);
+ // Start the asynchronous process of resolving the hostname. Once
+ // the addresses have been found and `resolved` has been set to true
+ // (but connecting will still be false), TCPClientSocketHandler::connect()
+ // needs to be called, again.
+ this->resolver.resolve(address, port,
+ [this](const struct addrinfo*)
+ {
+ log_debug("Resolution success, calling connect() again");
+ this->connect();
+ },
+ [this](const char*)
+ {
+ log_debug("Resolution failed, calling connect() again");
+ this->connect();
+ });
+ return;
+ }
+ else
+ {
+ // The DNS resolver resolved the hostname and the available addresses
+ // where saved in the addrinfo linked list. Now, just use
+ // this list to try to connect.
+ addr_res = this->resolver.get_result().get();
+ if (!addr_res)
+ {
+ this->hostname_resolution_failed = true;
+ const auto msg = this->resolver.get_error_message();
+ this->close();
+ this->on_connection_failed(msg);
+ return ;
+ }
+ }
+ }
+ else
+ { // This function is called again, use the saved addrinfo structure,
+ // instead of re-doing the whole getaddrinfo process.
+ addr_res = &this->addrinfo;
+ }
+
+ for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
+ {
+ if (!this->connecting)
+ {
+ try {
+ this->init_socket(rp);
+ }
+ catch (const std::runtime_error& error) {
+ log_error("Failed to init socket: ", error.what());
+ break;
+ }
+ }
+
+ this->display_resolved_ip(rp);
+
+ if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0
+ || errno == EISCONN)
+ {
+ log_info("Connection success.");
+ TimedEventsManager::instance().cancel("connection_timeout"s +
+ std::to_string(this->socket));
+ this->poller->add_socket_handler(this);
+ this->connected = true;
+ this->connecting = false;
+#ifdef BOTAN_FOUND
+ if (this->use_tls)
+ this->start_tls(this->address, this->port);
+#endif
+ this->connection_date = std::chrono::system_clock::now();
+
+ // Get our local TCP port and store it
+ this->local_port = static_cast<uint16_t>(-1);
+ if (rp->ai_family == AF_INET6)
+ {
+ struct sockaddr_in6 a{};
+ socklen_t l = sizeof(a);
+ if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1)
+ this->local_port = ntohs(a.sin6_port);
+ }
+ else if (rp->ai_family == AF_INET)
+ {
+ struct sockaddr_in a{};
+ socklen_t l = sizeof(a);
+ if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1)
+ this->local_port = ntohs(a.sin_port);
+ }
+
+ log_debug("Local port: ", this->local_port, ", and remote port: ", this->port);
+
+ this->on_connected();
+ return ;
+ }
+ else if (errno == EINPROGRESS || errno == EALREADY)
+ { // retry this process later, when the socket
+ // is ready to be written on.
+ this->connecting = true;
+ this->poller->add_socket_handler(this);
+ this->poller->watch_send_events(this);
+ // Save the addrinfo structure, to use it on the next call
+ this->ai_addrlen = rp->ai_addrlen;
+ memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
+ memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
+ this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr);
+ this->addrinfo.ai_next = nullptr;
+ // If the connection has not succeeded or failed in 5s, we consider
+ // it to have failed
+ TimedEventsManager::instance().add_event(
+ TimedEvent(std::chrono::steady_clock::now() + 5s,
+ std::bind(&TCPClientSocketHandler::on_connection_timeout, this),
+ "connection_timeout"s + std::to_string(this->socket)));
+ return ;
+ }
+ log_info("Connection failed:", std::strerror(errno));
+ }
+ log_error("All connection attempts failed.");
+ this->close();
+ this->on_connection_failed(std::strerror(errno));
+ return ;
+}
+
+void TCPClientSocketHandler::on_connection_timeout()
+{
+ this->close();
+ this->on_connection_failed("connection timed out");
+}
+
+void TCPClientSocketHandler::connect()
+{
+ this->connect(this->address, this->port, this->use_tls);
+}
+
+void TCPClientSocketHandler::close()
+{
+ TimedEventsManager::instance().cancel("connection_timeout"s +
+ std::to_string(this->socket));
+
+ TCPSocketHandler::close();
+
+ this->connected = false;
+ this->connecting = false;
+ this->port.clear();
+ this->resolver.clear();
+}
+
+void TCPClientSocketHandler::display_resolved_ip(struct addrinfo* rp) const
+{
+ if (rp->ai_family == AF_INET)
+ log_debug("Trying IPv4 address ", addr_to_string(rp));
+ else if (rp->ai_family == AF_INET6)
+ log_debug("Trying IPv6 address ", addr_to_string(rp));
+}
+
+bool TCPClientSocketHandler::is_connected() const
+{
+ return this->connected;
+}
+
+bool TCPClientSocketHandler::is_connecting() const
+{
+ return this->connecting || this->resolver.is_resolving();
+}
+
+std::string TCPClientSocketHandler::get_port() const
+{
+ return this->port;
+}
+
+bool TCPClientSocketHandler::match_port_pairt(const uint16_t local, const uint16_t remote) const
+{
+ const auto remote_port = static_cast<uint16_t>(std::stoi(this->port));
+ return this->is_connected() && local == this->local_port && remote == remote_port;
+}
diff --git a/src/network/tcp_client_socket_handler.hpp b/src/network/tcp_client_socket_handler.hpp
new file mode 100644
index 0000000..74caca9
--- /dev/null
+++ b/src/network/tcp_client_socket_handler.hpp
@@ -0,0 +1,82 @@
+#pragma once
+
+#include <network/tcp_socket_handler.hpp>
+
+class TCPClientSocketHandler: public TCPSocketHandler
+{
+ public:
+ TCPClientSocketHandler(std::shared_ptr<Poller>& poller);
+ ~TCPClientSocketHandler();
+ /**
+ * Connect to the remote server, and call on_connected() if this
+ * succeeds. If tls is true, we set use_tls to true and will also call
+ * start_tls() when the connection succeeds.
+ */
+ void connect(const std::string& address, const std::string& port, const bool tls);
+ void connect() override final;
+ /**
+ * Called by a TimedEvent, when the connection did not succeed or fail
+ * after a given time.
+ */
+ void on_connection_timeout();
+ /**
+ * Called when the connection is successful.
+ */
+ virtual void on_connected() = 0;
+ bool is_connected() const override;
+ bool is_connecting() const override;
+
+ std::string get_port() const;
+
+ void close() override final;
+ std::chrono::system_clock::time_point connection_date;
+
+ /**
+ * Whether or not this connection is using the two given TCP ports.
+ */
+ bool match_port_pairt(const uint16_t local, const uint16_t remote) const;
+
+ protected:
+ bool hostname_resolution_failed;
+ /**
+ * Address to bind the socket to, before calling connect().
+ * If empty, it’s equivalent to binding to INADDR_ANY.
+ */
+ std::string bind_addr;
+ /**
+ * Display the resolved IP, just for information purpose.
+ */
+ void display_resolved_ip(struct addrinfo* rp) const;
+ private:
+ /**
+ * Initialize the socket with the parameters contained in the given
+ * addrinfo structure.
+ */
+ void init_socket(const struct addrinfo* rp);
+ /**
+ * DNS resolver
+ */
+ Resolver resolver;
+ /**
+ * Keep the details of the addrinfo returned by the resolver that
+ * triggered a EINPROGRESS error when connect()ing to it, to reuse it
+ * directly when connect() is called again.
+ */
+ struct addrinfo addrinfo{};
+ struct sockaddr_in6 ai_addr{};
+ socklen_t ai_addrlen{};
+
+ /**
+ * Hostname we are connected/connecting to
+ */
+ std::string address;
+ /**
+ * Port we are connected/connecting to
+ */
+ std::string port;
+
+ uint16_t local_port{};
+
+ bool connected;
+ bool connecting;
+};
diff --git a/src/network/tcp_server_socket.hpp b/src/network/tcp_server_socket.hpp
new file mode 100644
index 0000000..652b773
--- /dev/null
+++ b/src/network/tcp_server_socket.hpp
@@ -0,0 +1,69 @@
+#pragma once
+
+#include <network/socket_handler.hpp>
+#include <network/poller.hpp>
+#include <logger/logger.hpp>
+
+#include <string>
+
+#include <arpa/inet.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/ip.h>
+
+#include <cstring>
+
+template <typename RemoteSocketType>
+class TcpSocketServer: public SocketHandler
+{
+ public:
+ TcpSocketServer(std::shared_ptr<Poller>& poller, const uint16_t port):
+ SocketHandler(poller, -1)
+ {
+ if ((this->socket = ::socket(AF_INET6, SOCK_STREAM, 0)) == -1)
+ throw std::runtime_error(std::string{"Could not create socket: "} + std::strerror(errno));
+
+ int opt = 1;
+ if (::setsockopt(this->socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1)
+ throw std::runtime_error(std::string{"Failed to set socket option: "} + std::strerror(errno));
+
+ struct sockaddr_in6 addr{};
+ addr.sin6_family = AF_INET6;
+ addr.sin6_port = htons(port);
+ addr.sin6_addr = IN6ADDR_ANY_INIT;
+ if ((::bind(this->socket, (const struct sockaddr*)&addr, sizeof(addr))) == -1)
+ { // If we can’t listen on this port, we just give up, but this is not fatal.
+ log_warning("Failed to bind on port ", std::to_string(port), ": ", std::strerror(errno));
+ return;
+ }
+
+ if ((::listen(this->socket, 10)) == -1)
+ throw std::runtime_error("listen() failed");
+
+ this->accept();
+ }
+ ~TcpSocketServer() = default;
+
+ void on_recv() override
+ {
+ // Accept a RemoteSocketType
+ int socket = ::accept(this->socket, nullptr, nullptr);
+
+ auto client = std::make_unique<RemoteSocketType>(poller, socket, *this);
+ this->poller->add_socket_handler(client.get());
+ this->sockets.push_back(std::move(client));
+ }
+
+ protected:
+ std::vector<std::unique_ptr<RemoteSocketType>> sockets;
+
+ private:
+ void accept()
+ {
+ this->poller->add_socket_handler(this);
+ }
+ bool is_connected() const override
+ {
+ return true;
+ }
+};
diff --git a/src/network/tcp_socket_handler.cpp b/src/network/tcp_socket_handler.cpp
new file mode 100644
index 0000000..1049375
--- /dev/null
+++ b/src/network/tcp_socket_handler.cpp
@@ -0,0 +1,360 @@
+#include <network/tcp_socket_handler.hpp>
+#include <network/dns_handler.hpp>
+
+#include <network/poller.hpp>
+
+#include <logger/logger.hpp>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <stdexcept>
+#include <unistd.h>
+#include <cerrno>
+#include <cstring>
+
+#ifdef BOTAN_FOUND
+# include <botan/hex.h>
+# include <botan/tls_exceptn.h>
+# include <config/config.hpp>
+# include <utils/dirname.hpp>
+
+namespace
+{
+ Botan::AutoSeeded_RNG& get_rng()
+ {
+ static Botan::AutoSeeded_RNG rng{};
+ return rng;
+ }
+ Botan::TLS::Session_Manager_In_Memory& get_session_manager()
+ {
+ static Botan::TLS::Session_Manager_In_Memory session_manager{get_rng()};
+ return session_manager;
+ }
+}
+#endif
+
+#ifndef UIO_FASTIOV
+# define UIO_FASTIOV 8
+#endif
+
+using namespace std::string_literals;
+using namespace std::chrono_literals;
+
+
+TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller>& poller):
+ SocketHandler(poller, -1),
+ use_tls(false)
+#ifdef BOTAN_FOUND
+ ,credential_manager(this)
+#endif
+{}
+
+TCPSocketHandler::~TCPSocketHandler()
+{
+ if (this->poller->is_managing_socket(this->get_socket()))
+ this->poller->remove_socket_handler(this->get_socket());
+ if (this->socket != -1)
+ {
+ ::close(this->socket);
+ this->socket = -1;
+ }
+}
+
+void TCPSocketHandler::on_recv()
+{
+#ifdef BOTAN_FOUND
+ if (this->use_tls)
+ this->tls_recv();
+ else
+#endif
+ this->plain_recv();
+}
+
+void TCPSocketHandler::plain_recv()
+{
+ static constexpr size_t buf_size = 4096;
+ char buf[buf_size];
+ void* recv_buf = this->get_receive_buffer(buf_size);
+
+ if (recv_buf == nullptr)
+ recv_buf = buf;
+
+ const ssize_t size = this->do_recv(recv_buf, buf_size);
+
+ if (size > 0)
+ {
+ if (buf == recv_buf)
+ {
+ // data needs to be placed in the in_buf string, because no buffer
+ // was provided to receive that data directly. The in_buf buffer
+ // will be handled in parse_in_buffer()
+ this->in_buf += std::string(buf, size);
+ }
+ this->parse_in_buffer(size);
+ }
+}
+
+ssize_t TCPSocketHandler::do_recv(void* recv_buf, const size_t buf_size)
+{
+ ssize_t size = ::recv(this->socket, recv_buf, buf_size, 0);
+ if (0 == size)
+ {
+ this->on_connection_close("");
+ this->close();
+ }
+ else if (-1 == size)
+ {
+ if (this->is_connecting())
+ log_warning("Error connecting: ", strerror(errno));
+ else
+ log_warning("Error while reading from socket: ", strerror(errno));
+ // Remember if we were connecting, or already connected when this
+ // happened, because close() sets this->connecting to false
+ const auto were_connecting = this->is_connecting();
+ this->close();
+ if (were_connecting)
+ this->on_connection_failed(strerror(errno));
+ else
+ this->on_connection_close(strerror(errno));
+ }
+ return size;
+}
+
+void TCPSocketHandler::on_send()
+{
+ struct iovec msg_iov[UIO_FASTIOV] = {};
+ struct msghdr msg{};
+ msg.msg_iov = msg_iov;
+ msg.msg_iovlen = 0;
+ for (const std::string& s: this->out_buf)
+ {
+ // unconsting the content of s is ok, sendmsg will never modify it
+ msg_iov[msg.msg_iovlen].iov_base = const_cast<char*>(s.data());
+ msg_iov[msg.msg_iovlen].iov_len = s.size();
+ msg.msg_iovlen++;
+ if (msg.msg_iovlen == UIO_FASTIOV)
+ break;
+ }
+ ssize_t res = ::sendmsg(this->socket, &msg, MSG_NOSIGNAL);
+ if (res < 0)
+ {
+ log_error("sendmsg failed: ", strerror(errno));
+ this->on_connection_close(strerror(errno));
+ this->close();
+ }
+ else
+ {
+ // remove all the strings that were successfully sent.
+ auto it = this->out_buf.begin();
+ while (it != this->out_buf.end())
+ {
+ if (static_cast<size_t>(res) >= it->size())
+ {
+ res -= it->size();
+ ++it;
+ }
+ else
+ {
+ // If one string has partially been sent, we use substr to
+ // crop it
+ if (res > 0)
+ *it = it->substr(res, std::string::npos);
+ break;
+ }
+ }
+ this->out_buf.erase(this->out_buf.begin(), it);
+ if (this->out_buf.empty())
+ this->poller->stop_watching_send_events(this);
+ }
+}
+
+void TCPSocketHandler::close()
+{
+ if (this->is_connected() || this->is_connecting())
+ this->poller->remove_socket_handler(this->get_socket());
+ if (this->socket != -1)
+ {
+ ::close(this->socket);
+ this->socket = -1;
+ }
+ this->in_buf.clear();
+ this->out_buf.clear();
+}
+
+void TCPSocketHandler::send_data(std::string&& data)
+{
+#ifdef BOTAN_FOUND
+ if (this->use_tls)
+ try {
+ this->tls_send(std::move(data));
+ } catch (const Botan::TLS::TLS_Exception& e) {
+ this->on_connection_close("TLS error: "s + e.what());
+ this->close();
+ return ;
+ }
+ else
+#endif
+ this->raw_send(std::move(data));
+}
+
+void TCPSocketHandler::raw_send(std::string&& data)
+{
+ if (data.empty())
+ return ;
+ this->out_buf.emplace_back(std::move(data));
+ if (this->is_connected())
+ this->poller->watch_send_events(this);
+}
+
+void TCPSocketHandler::send_pending_data()
+{
+ if (this->is_connected() && !this->out_buf.empty())
+ this->poller->watch_send_events(this);
+}
+
+bool TCPSocketHandler::is_using_tls() const
+{
+ return this->use_tls;
+}
+
+void* TCPSocketHandler::get_receive_buffer(const size_t) const
+{
+ return nullptr;
+}
+
+void TCPSocketHandler::consume_in_buffer(const std::size_t size)
+{
+ this->in_buf = this->in_buf.substr(size, std::string::npos);
+}
+
+#ifdef BOTAN_FOUND
+void TCPSocketHandler::start_tls(const std::string& address, const std::string& port_string)
+{
+ auto port = std::min(std::stoul(port_string), static_cast<unsigned long>(std::numeric_limits<uint16_t>::max()));
+ Botan::TLS::Server_Information server_info(address, "irc", static_cast<uint16_t>(port));
+ auto policy_directory = Config::get("policy_directory", utils::dirname(Config::get_filename()));
+ if (!policy_directory.empty() && policy_directory[policy_directory.size()-1] != '/')
+ policy_directory += '/';
+ this->policy.load(policy_directory + "policy.txt");
+ this->policy.load(policy_directory + address + ".policy.txt");
+ this->tls = std::make_unique<Botan::TLS::Client>(
+# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32)
+ *this,
+# else
+ [this](const Botan::byte* data, size_t size) { this->tls_emit_data(data, size); },
+ [this](const Botan::byte* data, size_t size) { this->tls_record_received(0, data, size); },
+ [this](Botan::TLS::Alert alert, const Botan::byte*, size_t) { this->tls_alert(alert); },
+ [this](const Botan::TLS::Session& session) { return this->tls_session_established(session); },
+# endif
+ get_session_manager(), this->credential_manager, this->policy,
+ get_rng(), server_info, Botan::TLS::Protocol_Version::latest_tls_version());
+}
+
+void TCPSocketHandler::tls_recv()
+{
+ static constexpr size_t buf_size = 4096;
+ Botan::byte recv_buf[buf_size];
+
+ const ssize_t size = this->do_recv(recv_buf, buf_size);
+ if (size > 0)
+ {
+ const bool was_active = this->tls->is_active();
+ try {
+ this->tls->received_data(recv_buf, static_cast<size_t>(size));
+ } catch (const Botan::TLS::TLS_Exception& e) {
+ // May happen if the server sends malformed TLS data (buggy server,
+ // or more probably we are just connected to a server that sends
+ // plain-text)
+ this->on_connection_close("TLS error: "s + e.what());
+ this->close();
+ return ;
+ }
+ if (!was_active && this->tls->is_active())
+ this->on_tls_activated();
+ }
+}
+
+void TCPSocketHandler::tls_send(std::string&& data)
+{
+ // We may not be connected yet, or the tls session has
+ // not yet been negociated
+ if (this->tls && this->tls->is_active())
+ {
+ const bool was_active = this->tls->is_active();
+ if (!this->pre_buf.empty())
+ {
+ this->tls->send(this->pre_buf.data(), this->pre_buf.size());
+ this->pre_buf.clear();
+ }
+ if (!data.empty())
+ this->tls->send(reinterpret_cast<const Botan::byte*>(data.data()),
+ data.size());
+ if (!was_active && this->tls->is_active())
+ this->on_tls_activated();
+ }
+ else
+ this->pre_buf.insert(this->pre_buf.end(),
+ std::make_move_iterator(data.begin()),
+ std::make_move_iterator(data.end()));
+}
+
+void TCPSocketHandler::tls_record_received(uint64_t, const Botan::byte *data, size_t size)
+{
+ this->in_buf += std::string(reinterpret_cast<const char*>(data),
+ size);
+ if (!this->in_buf.empty())
+ this->parse_in_buffer(size);
+}
+
+void TCPSocketHandler::tls_emit_data(const Botan::byte *data, size_t size)
+{
+ this->raw_send(std::string(reinterpret_cast<const char*>(data), size));
+}
+
+void TCPSocketHandler::tls_alert(Botan::TLS::Alert alert)
+{
+ log_debug("tls_alert: ", alert.type_string());
+}
+
+bool TCPSocketHandler::tls_session_established(const Botan::TLS::Session& session)
+{
+ log_debug("Handshake with ", session.server_info().hostname(), " complete.",
+ " Version: ", session.version().to_string(),
+ " using ", session.ciphersuite().to_string());
+ if (!session.session_id().empty())
+ log_debug("Session ID ", Botan::hex_encode(session.session_id()));
+ if (!session.session_ticket().empty())
+ log_debug("Session ticket ", Botan::hex_encode(session.session_ticket()));
+ return true;
+}
+
+#if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,34)
+void TCPSocketHandler::tls_verify_cert_chain(const std::vector<Botan::X509_Certificate>& cert_chain,
+ const std::vector<std::shared_ptr<const Botan::OCSP::Response>>& ocsp_responses,
+ const std::vector<Botan::Certificate_Store*>& trusted_roots,
+ Botan::Usage_Type usage, const std::string& hostname,
+ const Botan::TLS::Policy& policy)
+{
+ log_debug("Checking remote certificate for hostname ", hostname);
+ try
+ {
+ Botan::TLS::Callbacks::tls_verify_cert_chain(cert_chain, ocsp_responses, trusted_roots, usage, hostname, policy);
+ log_debug("Certificate is valid");
+ }
+ catch (const std::exception& tls_exception)
+ {
+ log_warning("TLS certificate check failed: ", tls_exception.what());
+ std::exception_ptr exception_ptr{};
+ if (this->abort_on_invalid_cert())
+ exception_ptr = std::current_exception();
+
+ check_tls_certificate(cert_chain, hostname, this->credential_manager.get_trusted_fingerprint(), exception_ptr);
+ }
+}
+#endif
+
+void TCPSocketHandler::on_tls_activated()
+{
+ this->send_data({});
+}
+
+#endif // BOTAN_FOUND
diff --git a/src/network/tcp_socket_handler.hpp b/src/network/tcp_socket_handler.hpp
new file mode 100644
index 0000000..f68698e
--- /dev/null
+++ b/src/network/tcp_socket_handler.hpp
@@ -0,0 +1,238 @@
+#pragma once
+
+#include "biboumi.h"
+
+#include <network/socket_handler.hpp>
+#include <network/resolver.hpp>
+
+#include <network/credentials_manager.hpp>
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netdb.h>
+
+#include <chrono>
+#include <vector>
+#include <memory>
+#include <string>
+#include <list>
+
+#ifdef BOTAN_FOUND
+
+# include <botan/types.h>
+# include <botan/botan.h>
+# include <botan/tls_session_manager.h>
+# include <network/tls_policy.hpp>
+
+# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32)
+# define BOTAN_TLS_CALLBACKS_OVERRIDE override final
+# else
+# define BOTAN_TLS_CALLBACKS_OVERRIDE
+# endif
+#endif
+
+/**
+ * Does all the read/write, buffering etc. With optional tls.
+ * But doesn’t do any connect() or accept() or anything else.
+ */
+class TCPSocketHandler: public SocketHandler
+#ifdef BOTAN_FOUND
+# if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,32)
+ ,public Botan::TLS::Callbacks
+# endif
+#endif
+{
+protected:
+ ~TCPSocketHandler();
+public:
+ explicit TCPSocketHandler(std::shared_ptr<Poller>& poller);
+ TCPSocketHandler(const TCPSocketHandler&) = delete;
+ TCPSocketHandler(TCPSocketHandler&&) = delete;
+ TCPSocketHandler& operator=(const TCPSocketHandler&) = delete;
+ TCPSocketHandler& operator=(TCPSocketHandler&&) = delete;
+
+ /**
+ * Reads raw data from the socket. And pass it to parse_in_buffer()
+ * If we are using TLS on this connection, we call tls_recv()
+ */
+ void on_recv() override final;
+ /**
+ * Write as much data from out_buf as possible, in the socket.
+ */
+ void on_send() override final;
+ /**
+ * Add the given data to out_buf and tell our poller that we want to be
+ * notified when a send event is ready.
+ *
+ * This can be overriden if we want to modify the data before sending
+ * it. For example if we want to encrypt it.
+ */
+ void send_data(std::string&& data);
+ /**
+ * Watch the socket for send events, if our out buffer is not empty.
+ */
+ void send_pending_data();
+ /**
+ * Close the connection, remove us from the poller
+ */
+ virtual void close();
+ /**
+ * Handle/consume (some of) the data received so far. The data to handle
+ * may be in the in_buf buffer, or somewhere else, depending on what
+ * get_receive_buffer() returned. If some data is used from in_buf, it
+ * should be truncated, only the unused data should be left untouched.
+ *
+ * The size argument is the size of the last chunk of data that was added to the buffer.
+ *
+ * The function should call consume_in_buffer, with the size that was consumed by the
+ * “parsing”, and thus to be removed from the input buffer.
+ */
+ virtual void parse_in_buffer(const size_t size) = 0;
+#ifdef BOTAN_FOUND
+ /**
+ * Tell whether the credential manager should cancel the connection when the
+ * certificate is invalid.
+ */
+ virtual bool abort_on_invalid_cert() const
+ {
+ return true;
+ }
+#endif
+ bool is_using_tls() const;
+
+private:
+ /**
+ * Reads from the socket into the provided buffer. If an error occurs
+ * (read returns <= 0), the handling of the error is done here (close the
+ * connection, log a message, etc).
+ *
+ * Returns the value returned by ::recv(), so the buffer should not be
+ * used if it’s not positive.
+ */
+ ssize_t do_recv(void* recv_buf, const size_t buf_size);
+ /**
+ * Reads data from the socket and calls parse_in_buffer with it.
+ */
+ void plain_recv();
+ /**
+ * Mark the given data as ready to be sent, as-is, on the socket, as soon
+ * as we can.
+ */
+ void raw_send(std::string&& data);
+
+ protected:
+ virtual bool is_connecting() const = 0;
+#ifdef BOTAN_FOUND
+ /**
+ * Create the TLS::Client object, with all the callbacks etc. This must be
+ * called only when we know we are able to send TLS-encrypted data over
+ * the socket.
+ */
+ void start_tls(const std::string& address, const std::string& port);
+ private:
+ /**
+ * An additional step to pass the data into our tls object to decrypt it
+ * before passing it to parse_in_buffer.
+ */
+ void tls_recv();
+ /**
+ * Pass the data to the tls object in order to encrypt it. The tls object
+ * will then call raw_send as a callback whenever data as been encrypted
+ * and can be sent on the socket.
+ */
+ void tls_send(std::string&& data);
+ /**
+ * Called by the tls object that some data has been decrypt. We call
+ * parse_in_buffer() to handle that unencrypted data.
+ */
+ void tls_record_received(uint64_t rec_no, const Botan::byte* data, size_t size) BOTAN_TLS_CALLBACKS_OVERRIDE;
+ /**
+ * Called by the tls object to indicate that some data has been encrypted
+ * and is now ready to be sent on the socket as is.
+ */
+ void tls_emit_data(const Botan::byte* data, size_t size) BOTAN_TLS_CALLBACKS_OVERRIDE;
+ /**
+ * Called by the tls object to indicate that a TLS alert has been
+ * received. We don’t use it, we just log some message, at the moment.
+ */
+ void tls_alert(Botan::TLS::Alert alert) BOTAN_TLS_CALLBACKS_OVERRIDE;
+ /**
+ * Called by the tls object at the end of the TLS handshake. We don't do
+ * anything here appart from logging the TLS session information.
+ */
+ bool tls_session_established(const Botan::TLS::Session& session) BOTAN_TLS_CALLBACKS_OVERRIDE;
+
+#if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,11,34)
+ void tls_verify_cert_chain(const std::vector<Botan::X509_Certificate>& cert_chain,
+ const std::vector<std::shared_ptr<const Botan::OCSP::Response>>& ocsp_responses,
+ const std::vector<Botan::Certificate_Store*>& trusted_roots,
+ Botan::Usage_Type usage,
+ const std::string& hostname,
+ const Botan::TLS::Policy& policy) BOTAN_TLS_CALLBACKS_OVERRIDE;
+#endif
+ /**
+ * Called whenever the tls session goes from inactive to active. This
+ * means that the handshake has just been successfully done, and we can
+ * now proceed to send any available data into our tls object.
+ */
+ void on_tls_activated();
+#endif // BOTAN_FOUND
+ /**
+ * Where data is added, when we want to send something to the client.
+ */
+ std::vector<std::string> out_buf;
+protected:
+ /**
+ * Whether we are using TLS on this connection or not.
+ */
+ bool use_tls;
+ /**
+ * Where data read from the socket is added until we can extract a full
+ * and meaningful “message” from it.
+ *
+ * TODO: something more efficient than a string.
+ */
+ std::string in_buf;
+ /**
+ * Remove the given “size” first bytes from our in_buf.
+ */
+ void consume_in_buffer(const std::size_t size);
+ /**
+ * Provide a buffer in which data can be directly received. This can be
+ * used to avoid copying data into in_buf before using it. If no buffer
+ * needs to be provided, nullptr is returned (the default implementation
+ * does that), in that case our internal in_buf will be used to save the
+ * data until it can be used by parse_in_buffer().
+ */
+ virtual void* get_receive_buffer(const size_t size) const;
+ /**
+ * Called when we detect a disconnection from the remote host.
+ */
+ virtual void on_connection_close(const std::string&) {}
+ virtual void on_connection_failed(const std::string&) {}
+
+#ifdef BOTAN_FOUND
+protected:
+ BasicCredentialsManager credential_manager;
+private:
+ BiboumiTLSPolicy policy;
+ /**
+ * We use a unique_ptr because we may not want to create the object at
+ * all. The Botan::TLS::Client object generates a handshake message and
+ * calls the output_fn callback with it as soon as it is created.
+ * Therefore, we do not want to create it if we do not intend to send any
+ * TLS-encrypted message. We create the object only when needed (for
+ * example after we have negociated a TLS session using a STARTTLS
+ * message, or stuf like that).
+ *
+ * See start_tls for the method where this object is created.
+ */
+ std::unique_ptr<Botan::TLS::Client> tls;
+ /**
+ * An additional buffer to keep data that the user wants to send, but
+ * cannot because the handshake is not done.
+ */
+ std::vector<Botan::byte> pre_buf;
+#endif // BOTAN_FOUND
+};
diff --git a/src/network/tls_policy.cpp b/src/network/tls_policy.cpp
new file mode 100644
index 0000000..5439397
--- /dev/null
+++ b/src/network/tls_policy.cpp
@@ -0,0 +1,48 @@
+#include "biboumi.h"
+
+#ifdef BOTAN_FOUND
+
+#include <fstream>
+
+#include <utils/tolower.hpp>
+
+#include <network/tls_policy.hpp>
+#include <logger/logger.hpp>
+
+bool BiboumiTLSPolicy::load(const std::string& filename)
+{
+ std::ifstream is(filename.data());
+ if (is)
+ {
+ try {
+ this->load(is);
+ log_info("Successfully loaded policy file: ", filename);
+ return true;
+ } catch (const Botan::Exception& e) {
+ log_error("Failed to parse policy_file ", filename, ": ", e.what());
+ return false;
+ }
+ }
+ log_info("Could not open policy file: ", filename);
+ return false;
+}
+
+void BiboumiTLSPolicy::load(std::istream& is)
+{
+ const auto dict = Botan::read_cfg(is);
+ for (const auto& pair: dict)
+ {
+ // Workaround for options that are not overridden in Botan::TLS::Text_Policy
+ if (pair.first == "require_cert_revocation_info")
+ this->req_cert_revocation_info = !(pair.second == "0" || utils::tolower(pair.second) == "false");
+ else
+ this->set(pair.first, pair.second);
+ }
+}
+
+bool BiboumiTLSPolicy::require_cert_revocation_info() const
+{
+ return this->req_cert_revocation_info;
+}
+
+#endif
diff --git a/src/network/tls_policy.hpp b/src/network/tls_policy.hpp
new file mode 100644
index 0000000..29fd2b3
--- /dev/null
+++ b/src/network/tls_policy.hpp
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "biboumi.h"
+
+#ifdef BOTAN_FOUND
+
+#include <botan/tls_policy.h>
+
+class BiboumiTLSPolicy: public Botan::TLS::Text_Policy
+{
+public:
+ BiboumiTLSPolicy():
+ Botan::TLS::Text_Policy({})
+ {}
+ bool load(const std::string& filename);
+ void load(std::istream& iss);
+
+ BiboumiTLSPolicy(const BiboumiTLSPolicy &) = delete;
+ BiboumiTLSPolicy(BiboumiTLSPolicy &&) = delete;
+ BiboumiTLSPolicy &operator=(const BiboumiTLSPolicy &) = delete;
+ BiboumiTLSPolicy &operator=(BiboumiTLSPolicy &&) = delete;
+
+ bool require_cert_revocation_info() const override;
+protected:
+ bool req_cert_revocation_info{true};
+};
+
+#endif
diff --git a/src/utils/dirname.cpp b/src/utils/dirname.cpp
new file mode 100644
index 0000000..71c9c38
--- /dev/null
+++ b/src/utils/dirname.cpp
@@ -0,0 +1,16 @@
+#include <utils/dirname.hpp>
+
+namespace utils
+{
+ std::string dirname(const std::string filename)
+ {
+ if (filename.empty())
+ return "./";
+ if (filename == ".." || filename == ".")
+ return filename;
+ auto pos = filename.rfind('/');
+ if (pos == std::string::npos)
+ return "./";
+ return filename.substr(0, pos + 1);
+ }
+}
diff --git a/src/utils/dirname.hpp b/src/utils/dirname.hpp
new file mode 100644
index 0000000..c1df81b
--- /dev/null
+++ b/src/utils/dirname.hpp
@@ -0,0 +1,6 @@
+#include <string>
+
+namespace utils
+{
+std::string dirname(const std::string filename);
+}
diff --git a/src/utils/encoding.cpp b/src/utils/encoding.cpp
new file mode 100644
index 0000000..cff0039
--- /dev/null
+++ b/src/utils/encoding.cpp
@@ -0,0 +1,254 @@
+#include <utils/encoding.hpp>
+
+#include <utils/scopeguard.hpp>
+
+#include <stdexcept>
+
+#include <cassert>
+#include <string.h>
+#include <iconv.h>
+#include <cerrno>
+
+#include <map>
+#include <bitset>
+
+/**
+ * The UTF-8-encoded character used as a place holder when a character conversion fails.
+ * This is U+FFFD � "replacement character"
+ */
+static const char* invalid_char = "\xef\xbf\xbd";
+static const size_t invalid_char_len = 3;
+
+namespace utils
+{
+ /**
+ * Based on http://en.wikipedia.org/wiki/UTF-8#Description
+ */
+ std::size_t get_next_codepoint_size(const unsigned char c)
+ {
+ if ((c & 0b11111000) == 0b11110000) // 4 bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+ return 4;
+ else if ((c & 0b11110000) == 0b11100000) // 3 bytes: 1110xxx 10xxxxxx 10xxxxxx
+ return 3;
+ else if ((c & 0b11100000) == 0b11000000) // 2 bytes: 110xxxxx 10xxxxxx
+ return 2;
+ return 1; // 1 byte: 0xxxxxxx
+ }
+
+ bool is_valid_utf8(const char* s)
+ {
+ if (!s)
+ return false;
+
+ const unsigned char* str = reinterpret_cast<const unsigned char*>(s);
+
+ while (*str)
+ {
+ const auto codepoint_size = get_next_codepoint_size(str[0]);
+ if (codepoint_size == 4)
+ {
+ if (!str[1] || !str[2] || !str[3]
+ || ((str[1] & 0b11000000) != 0b10000000)
+ || ((str[2] & 0b11000000) != 0b10000000)
+ || ((str[3] & 0b11000000) != 0b10000000))
+ return false;
+ }
+ else if (codepoint_size == 3)
+ {
+ if (!str[1] || !str[2]
+ || ((str[1] & 0b11000000) != 0b10000000)
+ || ((str[2] & 0b11000000) != 0b10000000))
+ return false;
+ }
+ else if (codepoint_size == 2)
+ {
+ if (!str[1] ||
+ ((str[1] & 0b11000000) != 0b10000000))
+ return false;
+ }
+ else if ((str[0] & 0b10000000) != 0)
+ return false;
+ str += codepoint_size;
+ }
+ return true;
+ }
+
+ std::string remove_invalid_xml_chars(const std::string& original)
+ {
+ // The given string MUST be a valid utf-8 string
+ std::vector<char> res(original.size(), '\0');
+
+ // pointer where we write valid chars
+ char* r = res.data();
+
+ const char* str = original.c_str();
+ std::bitset<20> codepoint;
+
+ while (*str)
+ {
+ // 4 bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+ if ((str[0] & 0b11111000) == 0b11110000)
+ {
+ codepoint = ((str[0] & 0b00000111) << 18);
+ codepoint |= ((str[1] & 0b00111111) << 12);
+ codepoint |= ((str[2] & 0b00111111) << 6 );
+ codepoint |= ((str[3] & 0b00111111) << 0 );
+ if (codepoint.to_ulong() <= 0x10FFFF)
+ {
+ ::memcpy(r, str, 4);
+ r += 4;
+ }
+ str += 4;
+ }
+ // 3 bytes: 1110xxx 10xxxxxx 10xxxxxx
+ else if ((str[0] & 0b11110000) == 0b11100000)
+ {
+ codepoint = ((str[0] & 0b00001111) << 12);
+ codepoint |= ((str[1] & 0b00111111) << 6);
+ codepoint |= ((str[2] & 0b00111111) << 0 );
+ if (codepoint.to_ulong() <= 0xD7FF ||
+ (codepoint.to_ulong() >= 0xE000 && codepoint.to_ulong() <= 0xFFFD))
+ {
+ ::memcpy(r, str, 3);
+ r += 3;
+ }
+ str += 3;
+ }
+ // 2 bytes: 110xxxxx 10xxxxxx
+ else if (((str[0]) & 0b11100000) == 0b11000000)
+ {
+ // All 2 bytes char are valid, don't even bother calculating
+ // the codepoint
+ ::memcpy(r, str, 2);
+ r += 2;
+ str += 2;
+ }
+ // 1 byte: 0xxxxxxx
+ else if ((str[0] & 0b10000000) == 0)
+ {
+ codepoint = ((str[0] & 0b01111111));
+ if (codepoint.to_ulong() == 0x09 ||
+ codepoint.to_ulong() == 0x0A ||
+ codepoint.to_ulong() == 0x0D ||
+ codepoint.to_ulong() >= 0x20)
+ {
+ ::memcpy(r, str, 1);
+ r += 1;
+ }
+ str += 1;
+ }
+ else
+ throw std::runtime_error("Invalid UTF-8 passed to remove_invalid_xml_chars");
+ }
+ return {res.data(), static_cast<size_t>(r - res.data())};
+ }
+
+ std::string convert_to_utf8(const std::string& str, const char* charset)
+ {
+ std::string res;
+
+ const iconv_t cd = iconv_open("UTF-8", charset);
+ if (cd == (iconv_t)-1)
+ throw std::runtime_error("Cannot convert into UTF-8");
+
+ // Make sure cd is always closed when we leave this function
+ const auto sg = utils::make_scope_guard([&cd](){ iconv_close(cd); });
+
+ size_t inbytesleft = str.size();
+
+ // iconv will not attempt to modify this buffer, but some plateform
+ // require a char** anyway
+#ifdef ICONV_SECOND_ARGUMENT_IS_CONST
+ const char* inbuf_ptr = str.c_str();
+#else
+ char* inbuf_ptr = const_cast<char*>(str.c_str());
+#endif
+
+ size_t outbytesleft = str.size() * 4;
+ char* outbuf = new char[outbytesleft];
+ char* outbuf_ptr = outbuf;
+
+ // Make sure outbuf is always deleted when we leave this function
+ const auto sg2 = utils::make_scope_guard([outbuf](){ delete[] outbuf; });
+
+ bool done = false;
+ while (done == false)
+ {
+ size_t error = iconv(cd, &inbuf_ptr, &inbytesleft, &outbuf_ptr, &outbytesleft);
+ if ((size_t)-1 == error)
+ {
+ switch (errno)
+ {
+ case EILSEQ:
+ // Invalid byte found. Insert a placeholder instead of the
+ // converted character, jump one byte and continue
+ memcpy(outbuf_ptr, invalid_char, invalid_char_len);
+ outbuf_ptr += invalid_char_len;
+ inbytesleft--;
+ inbuf_ptr++;
+ break;
+ case EINVAL:
+ // A multibyte sequence is not terminated, but we can't
+ // provide any more data, so we just add a placeholder to
+ // indicate that the character is not properly converted,
+ // and we stop the conversion
+ memcpy(outbuf_ptr, invalid_char, invalid_char_len);
+ outbuf_ptr += invalid_char_len;
+ outbuf_ptr++;
+ done = true;
+ break;
+ case E2BIG: // This should never happen
+ default: // This should happen even neverer
+ done = true;
+ break;
+ }
+ }
+ else
+ {
+ // The conversion finished without any error, stop converting
+ done = true;
+ }
+ }
+ // Terminate the converted buffer, and copy that buffer it into the
+ // string we return
+ *outbuf_ptr = '\0';
+ res = outbuf;
+ return res;
+ }
+
+}
+
+namespace xep0106
+{
+ static const std::map<const char, const std::string> encode_map = {
+ {' ', "\\20"},
+ {'"', "\\22"},
+ {'&', "\\26"},
+ {'\'',"\\27"},
+ {'/', "\\2f"},
+ {':', "\\3a"},
+ {'<', "\\3c"},
+ {'>', "\\3e"},
+ {'@', "\\40"},
+ };
+
+ void decode(std::string& s)
+ {
+ std::string::size_type pos;
+ for (const auto& pair: encode_map)
+ while ((pos = s.find(pair.second)) != std::string::npos)
+ s.replace(pos, pair.second.size(),
+ 1, pair.first);
+ }
+
+ void encode(std::string& s)
+ {
+ std::string::size_type pos;
+ while ((pos = s.find_first_of(" \"&'/:<>@")) != std::string::npos)
+ {
+ auto it = encode_map.find(s[pos]);
+ assert(it != encode_map.end());
+ s.replace(pos, 1, it->second);
+ }
+ }
+}
diff --git a/src/utils/encoding.hpp b/src/utils/encoding.hpp
new file mode 100644
index 0000000..b707a0c
--- /dev/null
+++ b/src/utils/encoding.hpp
@@ -0,0 +1,43 @@
+#pragma once
+
+
+#include <string>
+
+namespace utils
+{
+ /**
+ * Return the size, in bytes, of the next UTF-8 codepoint, based on
+ * the given char.
+ */
+ std::size_t get_next_codepoint_size(const unsigned char c);
+ /**
+ * Returns true if the given null-terminated string is valid utf-8.
+ *
+ * Based on http://en.wikipedia.org/wiki/UTF-8#Description
+ */
+ bool is_valid_utf8(const char* s);
+ /**
+ * Remove all invalid codepoints from the given utf-8-encoded string.
+ * The value returned is a copy of the string, without the removed chars.
+ *
+ * See http://www.w3.org/TR/xml/#charsets for the list of valid characters
+ * in XML.
+ */
+ std::string remove_invalid_xml_chars(const std::string& original);
+ /**
+ * Convert the given string (encoded is "encoding") into valid utf-8.
+ * If some decoding fails, insert an utf-8 placeholder character instead.
+ */
+ std::string convert_to_utf8(const std::string& str, const char* charset);
+}
+
+namespace xep0106
+{
+ /**
+ * Decode and encode inplace.
+ */
+ void decode(std::string&);
+ void encode(std::string&);
+}
+
+
diff --git a/src/utils/get_first_non_empty.cpp b/src/utils/get_first_non_empty.cpp
new file mode 100644
index 0000000..5b3bedb
--- /dev/null
+++ b/src/utils/get_first_non_empty.cpp
@@ -0,0 +1,11 @@
+#include <utils/get_first_non_empty.hpp>
+
+bool is_empty(const std::string& val)
+{
+ return val.empty();
+}
+
+bool is_empty(const int& val)
+{
+ return val == 0;
+}
diff --git a/src/utils/get_first_non_empty.hpp b/src/utils/get_first_non_empty.hpp
new file mode 100644
index 0000000..a38f5fb
--- /dev/null
+++ b/src/utils/get_first_non_empty.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <string>
+
+bool is_empty(const std::string& val);
+bool is_empty(const int& val);
+
+template <typename T>
+T get_first_non_empty(T&& last)
+{
+ return last;
+}
+
+template <typename T, typename... Args>
+T get_first_non_empty(T&& first, Args&&... args)
+{
+ if (!is_empty(first))
+ return first;
+ return get_first_non_empty(std::forward<Args>(args)...);
+}
diff --git a/src/utils/optional_bool.hpp b/src/utils/optional_bool.hpp
new file mode 100644
index 0000000..59bbbab
--- /dev/null
+++ b/src/utils/optional_bool.hpp
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <string>
+
+struct OptionalBool
+{
+ OptionalBool() = default;
+
+ OptionalBool(bool value):
+ is_set(true), value(value) {}
+
+ void set_value(bool value)
+ {
+ this->is_set = true;
+ this->value = value;
+ }
+
+ void unset()
+ {
+ this->is_set = false;
+ }
+
+ std::string to_string()
+ {
+ if (this->is_set == false)
+ return "unset";
+ else if (this->value)
+ return "true";
+ else
+ return "false";
+ }
+
+ bool is_set{false};
+ bool value{false};
+};
diff --git a/src/utils/reload.cpp b/src/utils/reload.cpp
index 348c5b5..fdca9bc 100644
--- a/src/utils/reload.cpp
+++ b/src/utils/reload.cpp
@@ -26,7 +26,7 @@ void reload_process()
#ifdef USE_DATABASE
try {
open_database();
- } catch (const litesql::DatabaseError&) {
+ } catch (...) {
log_warning("Re-using the previous database.");
}
#endif
diff --git a/src/utils/revstr.cpp b/src/utils/revstr.cpp
new file mode 100644
index 0000000..87fd801
--- /dev/null
+++ b/src/utils/revstr.cpp
@@ -0,0 +1,9 @@
+#include <utils/revstr.hpp>
+
+namespace utils
+{
+ std::string revstr(const std::string& original)
+ {
+ return {original.rbegin(), original.rend()};
+ }
+}
diff --git a/src/utils/revstr.hpp b/src/utils/revstr.hpp
new file mode 100644
index 0000000..8e521ea
--- /dev/null
+++ b/src/utils/revstr.hpp
@@ -0,0 +1,11 @@
+#pragma once
+
+
+#include <string>
+
+namespace utils
+{
+ std::string revstr(const std::string& original);
+}
+
+
diff --git a/src/utils/scopeguard.hpp b/src/utils/scopeguard.hpp
new file mode 100644
index 0000000..e697fc3
--- /dev/null
+++ b/src/utils/scopeguard.hpp
@@ -0,0 +1,98 @@
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+/**
+ * A class to be used to make sure some functions are called when the scope
+ * is left, because they will be called in the ScopeGuard's destructor. It
+ * can for example be used to delete some pointer whenever any exception is
+ * called. Example:
+
+ * {
+ * ScopeGuard scope;
+ * int* number = new int(2);
+ * scope.add_callback([number]() { delete number; });
+ * // Do some other stuff with the number. But these stuff might throw an exception:
+ * throw std::runtime_error("Some error not caught here, but in our caller");
+ * return true;
+ * }
+
+ * In this example, our pointer will always be deleted, even when the
+ * exception is thrown. If we want the functions to be called only when the
+ * scope is left because of an unexpected exception, we can use
+ * ScopeGuard::disable();
+ */
+
+namespace utils
+{
+
+class ScopeGuard
+{
+public:
+ /**
+ * The constructor can take a callback. But additional callbacks can be
+ * added later with add_callback()
+ */
+ explicit ScopeGuard(std::function<void()>&& func):
+ enabled(true)
+ {
+ this->add_callback(std::move(func));
+ }
+
+ ScopeGuard(const ScopeGuard&) = delete;
+ ScopeGuard& operator=(ScopeGuard&&) = delete;
+ ScopeGuard(ScopeGuard&&) = delete;
+ ScopeGuard& operator=(const ScopeGuard&) = delete;
+
+ /**
+ * default constructor, the scope guard is enabled but empty, use
+ * add_callback()
+ */
+ explicit ScopeGuard():
+ enabled(true)
+ {
+ }
+ /**
+ * Call all callbacks in the desctructor, unless it has been disabled.
+ */
+ ~ScopeGuard()
+ {
+ if (this->enabled)
+ for (auto& func: this->callbacks)
+ func();
+ }
+ /**
+ * Add a callback to be called in our destructor, one scope guard can be
+ * used for more than one task, if needed.
+ */
+ void add_callback(std::function<void()>&& func)
+ {
+ this->callbacks.emplace_back(std::move(func));
+ }
+ /**
+ * Disable that scope guard, nothing will be done when the scope is
+ * exited.
+ */
+ void disable()
+ {
+ this->enabled = false;
+ }
+
+private:
+ bool enabled;
+ std::vector<std::function<void()>> callbacks;
+
+};
+
+template<typename F>
+auto make_scope_guard(F f)
+{
+ static struct Empty {} empty;
+ auto deleter = [f = std::move(f)](Empty*) { f(); };
+ return std::unique_ptr<Empty, decltype(deleter)>{&empty, std::move(deleter)};
+}
+
+}
+
diff --git a/src/utils/sha1.cpp b/src/utils/sha1.cpp
new file mode 100644
index 0000000..2e6efc2
--- /dev/null
+++ b/src/utils/sha1.cpp
@@ -0,0 +1,40 @@
+#include <utils/sha1.hpp>
+
+#include <biboumi.h>
+
+#ifdef BOTAN_FOUND
+# include <botan/version.h>
+# include <botan/hash.h>
+# include <botan/hex.h>
+# include <botan/exceptn.h>
+#endif
+#ifdef GCRYPT_FOUND
+# include <gcrypt.h>
+# include <vector>
+# include <iomanip>
+# include <sstream>
+#endif
+
+std::string sha1(const std::string& input)
+{
+#ifdef BOTAN_FOUND
+# if BOTAN_VERSION_CODE < BOTAN_VERSION_CODE_FOR(1,11,34)
+ auto sha1 = Botan::HashFunction::create("SHA-1");
+ if (!sha1)
+ throw Botan::Algorithm_Not_Found("SHA-1");
+# else
+ auto sha1 = Botan::HashFunction::create_or_throw("SHA-1");
+# endif
+ sha1->update(input);
+ return Botan::hex_encode(sha1->final(), false);
+#endif
+#ifdef GCRYPT_FOUND
+ const auto hash_length = gcry_md_get_algo_dlen(GCRY_MD_SHA1);
+ std::vector<uint8_t> output(hash_length, {});
+ gcry_md_hash_buffer(GCRY_MD_SHA1, output.data(), input.data(), input.size());
+ std::ostringstream digest;
+ for (std::size_t i = 0; i < hash_length; i++)
+ digest << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(output[i]);
+ return digest.str();
+#endif
+}
diff --git a/src/utils/sha1.hpp b/src/utils/sha1.hpp
new file mode 100644
index 0000000..6c551ac
--- /dev/null
+++ b/src/utils/sha1.hpp
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <string>
+
+std::string sha1(const std::string& input);
diff --git a/src/utils/split.cpp b/src/utils/split.cpp
new file mode 100644
index 0000000..80f8dae
--- /dev/null
+++ b/src/utils/split.cpp
@@ -0,0 +1,19 @@
+#include <utils/split.hpp>
+#include <sstream>
+
+namespace utils
+{
+ std::vector<std::string> split(const std::string& s, const char delim, const bool allow_empty)
+ {
+ std::vector<std::string> ret;
+ std::stringstream ss(s);
+ std::string item;
+ while (std::getline(ss, item, delim))
+ {
+ if (item.empty() && !allow_empty)
+ continue ;
+ ret.emplace_back(std::move(item));
+ }
+ return ret;
+ }
+}
diff --git a/src/utils/split.hpp b/src/utils/split.hpp
new file mode 100644
index 0000000..3755ef8
--- /dev/null
+++ b/src/utils/split.hpp
@@ -0,0 +1,12 @@
+#pragma once
+
+
+#include <string>
+#include <vector>
+
+namespace utils
+{
+ std::vector<std::string> split(const std::string &s, const char delim, const bool allow_empty=true);
+}
+
+
diff --git a/src/utils/string.cpp b/src/utils/string.cpp
new file mode 100644
index 0000000..635e71a
--- /dev/null
+++ b/src/utils/string.cpp
@@ -0,0 +1,28 @@
+#include <utils/string.hpp>
+#include <utils/encoding.hpp>
+
+bool to_bool(const std::string& val)
+{
+ return (val == "1" || val == "true");
+}
+
+std::vector<std::string> cut(const std::string& val, const std::size_t size)
+{
+ std::vector<std::string> res;
+ std::string::size_type pos = 0;
+ while (pos < val.size())
+ {
+ // Get the number of chars, <= size, that contain only whole
+ // UTF-8 codepoints.
+ std::size_t s = 0;
+ auto codepoint_size = utils::get_next_codepoint_size(val[pos + s]);
+ while (s + codepoint_size <= size && pos + s < val.size())
+ {
+ s += codepoint_size;
+ codepoint_size = utils::get_next_codepoint_size(val[pos + s]);
+ }
+ res.emplace_back(val.substr(pos, s));
+ pos += s;
+ }
+ return res;
+}
diff --git a/src/utils/string.hpp b/src/utils/string.hpp
new file mode 100644
index 0000000..071ce2c
--- /dev/null
+++ b/src/utils/string.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+
+#include <vector>
+#include <string>
+
+bool to_bool(const std::string& val);
+std::vector<std::string> cut(const std::string& val, const std::size_t size);
diff --git a/src/utils/system.cpp b/src/utils/system.cpp
new file mode 100644
index 0000000..d821dec
--- /dev/null
+++ b/src/utils/system.cpp
@@ -0,0 +1,21 @@
+#include <logger/logger.hpp>
+#include <utils/system.hpp>
+#include <sys/utsname.h>
+#include <cstring>
+
+using namespace std::string_literals;
+
+namespace utils
+{
+std::string get_system_name()
+{
+ struct utsname uts{};
+ const int res = ::uname(&uts);
+ if (res == -1)
+ {
+ log_error("uname failed: ", std::strerror(errno));
+ return "Unknown";
+ }
+ return uts.sysname + " "s + uts.release;
+}
+} \ No newline at end of file
diff --git a/src/utils/system.hpp b/src/utils/system.hpp
new file mode 100644
index 0000000..7ea1677
--- /dev/null
+++ b/src/utils/system.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#include <string>
+
+namespace utils
+{
+std::string get_system_name();
+} \ No newline at end of file
diff --git a/src/utils/time.cpp b/src/utils/time.cpp
new file mode 100644
index 0000000..bc2c18d
--- /dev/null
+++ b/src/utils/time.cpp
@@ -0,0 +1,80 @@
+#include <utils/time.hpp>
+#include <ctime>
+
+#include <sstream>
+#include <iomanip>
+#include <locale>
+
+#include "biboumi.h"
+
+namespace utils
+{
+std::string to_string(const std::time_t& timestamp)
+{
+ constexpr std::size_t stamp_size = 21;
+ char date_buf[stamp_size];
+ if (std::strftime(date_buf, stamp_size, "%FT%TZ", std::gmtime(&timestamp)) != stamp_size - 1)
+ return "";
+ return {std::begin(date_buf), std::end(date_buf) - 1};
+}
+
+std::time_t parse_datetime(const std::string& stamp)
+{
+ static const char* format = "%Y-%m-%dT%H:%M:%S";
+ std::tm t = {};
+#ifdef HAS_GET_TIME
+ std::istringstream ss(stamp);
+ ss.imbue(std::locale("C"));
+
+ std::string remainings;
+ ss >> std::get_time(&t, format) >> remainings;
+ if (ss.fail())
+ return -1;
+#else
+ /* Y - m - d T H : M : S */
+ constexpr std::size_t stamp_size_without_tz = 4 + 1 + 2 + 1 + 2 + 1 + 2 + 1 + 2 + 1 + 2;
+ if (!strptime(stamp.data(), format, &t)) {
+ return -1;
+ }
+ const std::string remainings(stamp.data() + stamp_size_without_tz);
+#endif
+
+ if (remainings.empty())
+ return -1;
+
+ std::string timezone;
+ // Skip optional fractions of seconds
+ if (remainings[0] == '.')
+ {
+ const auto pos = remainings.find_first_not_of(".0123456789");
+ timezone = remainings.substr(pos);
+ }
+ else
+ timezone = std::move(remainings);
+
+ if (timezone.compare(0, 1, "Z") != 0)
+ {
+ std::stringstream tz_ss;
+ tz_ss << timezone;
+ int multiplier = -1;
+ char prefix;
+ int hours;
+ char sep;
+ int minutes;
+ tz_ss >> prefix >> hours >> sep >> minutes;
+ if (tz_ss.fail())
+ return -1;
+ if (prefix == '-')
+ multiplier = +1;
+ else if (prefix != '+')
+ return -1;
+
+ t.tm_hour += multiplier * hours;
+ t.tm_min += multiplier * minutes;
+ }
+ return ::timegm(&t);
+}
+
+}
+
+
diff --git a/src/utils/time.hpp b/src/utils/time.hpp
new file mode 100644
index 0000000..c71cd9c
--- /dev/null
+++ b/src/utils/time.hpp
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <ctime>
+#include <string>
+
+namespace utils
+{
+std::string to_string(const std::time_t& timestamp);
+std::time_t parse_datetime(const std::string& stamp);
+} \ No newline at end of file
diff --git a/src/utils/timed_events.cpp b/src/utils/timed_events.cpp
new file mode 100644
index 0000000..26ded82
--- /dev/null
+++ b/src/utils/timed_events.cpp
@@ -0,0 +1,48 @@
+#include <utility>
+#include <utils/timed_events.hpp>
+
+TimedEvent::TimedEvent(std::chrono::steady_clock::time_point&& time_point,
+ std::function<void()> callback, std::string name):
+ time_point(time_point),
+ callback(std::move(callback)),
+ repeat(false),
+ repeat_delay(0),
+ name(std::move(name))
+{
+}
+
+TimedEvent::TimedEvent(std::chrono::milliseconds&& duration,
+ std::function<void()> callback, std::string name):
+ time_point(std::chrono::steady_clock::now() + duration),
+ callback(std::move(callback)),
+ repeat(true),
+ repeat_delay(duration),
+ name(std::move(name))
+{
+}
+
+bool TimedEvent::is_after(const TimedEvent& other) const
+{
+ return this->is_after(other.time_point);
+}
+
+bool TimedEvent::is_after(const std::chrono::steady_clock::time_point& time_point) const
+{
+ return this->time_point > time_point;
+}
+
+std::chrono::milliseconds TimedEvent::get_timeout() const
+{
+ auto diff = std::chrono::duration_cast<std::chrono::milliseconds>(this->time_point - std::chrono::steady_clock::now());
+ return std::max(diff, 0ms);
+}
+
+void TimedEvent::execute() const
+{
+ this->callback();
+}
+
+const std::string& TimedEvent::get_name() const
+{
+ return this->name;
+}
diff --git a/src/utils/timed_events.hpp b/src/utils/timed_events.hpp
new file mode 100644
index 0000000..fa0fc50
--- /dev/null
+++ b/src/utils/timed_events.hpp
@@ -0,0 +1,137 @@
+#pragma once
+
+#include <functional>
+#include <string>
+#include <chrono>
+#include <vector>
+
+using namespace std::literals::chrono_literals;
+
+namespace utils {
+static constexpr std::chrono::milliseconds no_timeout = std::chrono::milliseconds(-1);
+}
+
+class TimedEventsManager;
+
+/**
+ * A callback with an associated date.
+ */
+
+class TimedEvent
+{
+ friend class TimedEventsManager;
+public:
+ /**
+ * An event the occurs only once, at the given time_point
+ */
+ explicit TimedEvent(std::chrono::steady_clock::time_point&& time_point,
+ std::function<void()> callback, std::string name="");
+ explicit TimedEvent(std::chrono::milliseconds&& duration,
+ std::function<void()> callback, std::string name="");
+
+ explicit TimedEvent(TimedEvent&&) = default;
+ TimedEvent& operator=(TimedEvent&&) = default;
+ ~TimedEvent() = default;
+
+ TimedEvent(const TimedEvent&) = delete;
+ TimedEvent& operator=(const TimedEvent&) = delete;
+
+ /**
+ * Whether or not this event happens after the other one.
+ */
+ bool is_after(const TimedEvent& other) const;
+ bool is_after(const std::chrono::steady_clock::time_point& time_point) const;
+ /**
+ * Return the duration difference between now and the event time point.
+ * If the difference would be negative (i.e. the event is expired), the
+ * returned value is 0 instead. The value cannot then be negative.
+ */
+ std::chrono::milliseconds get_timeout() const;
+ void execute() const;
+ const std::string& get_name() const;
+
+private:
+ /**
+ * The next time point at which the event is executed.
+ */
+ std::chrono::steady_clock::time_point time_point;
+ /**
+ * The function to execute.
+ */
+ std::function<void()> callback;
+ /**
+ * Whether or not this events repeats itself until it is destroyed.
+ */
+ bool repeat;
+ /**
+ * This value is added to the time_point each time the event is executed,
+ * if repeat is true. Otherwise it is ignored.
+ */
+ std::chrono::milliseconds repeat_delay;
+ /**
+ * A name that is used to identify that event. If you want to find your
+ * event (for example if you want to cancel it), the name should be
+ * unique.
+ */
+ std::string name;
+};
+
+/**
+ * A class managing a list of TimedEvents.
+ * They are sorted, new events can be added, removed, fetch, etc.
+ */
+
+class TimedEventsManager
+{
+public:
+ ~TimedEventsManager() = default;
+
+ TimedEventsManager(const TimedEventsManager&) = delete;
+ TimedEventsManager(TimedEventsManager&&) = delete;
+ TimedEventsManager& operator=(const TimedEventsManager&) = delete;
+ TimedEventsManager& operator=(TimedEventsManager&&) = delete;
+
+ /**
+ * Return the unique instance of this class
+ */
+ static TimedEventsManager& instance();
+ /**
+ * Add an event to the list of managed events. The list is sorted after
+ * this call.
+ */
+ void add_event(TimedEvent&& event);
+ /**
+ * Returns the duration, in milliseconds, between now and the next
+ * available event. If the event is already expired (the duration is
+ * negative), 0 is returned instead (as in “it's not too late, execute it
+ * now”)
+ * Returns a negative value if no event is available.
+ */
+ std::chrono::milliseconds get_timeout() const;
+ /**
+ * Execute all the expired events (if their expiration time is exactly
+ * now, or before now). The event is then removed from the list. If the
+ * event does repeat, its expiration time is updated and it is reinserted
+ * in the list at the correct position.
+ * Returns the number of executed events.
+ */
+ std::size_t execute_expired_events();
+ /**
+ * Remove (and thus cancel) all the timed events with the given name.
+ * Returns the number of canceled events.
+ */
+ std::size_t cancel(const std::string& name);
+ /**
+ * Return the number of managed events.
+ */
+ std::size_t size() const;
+ /**
+ * Return a pointer to the first event with the given name. If none
+ * is found, returns nullptr.
+ */
+ const TimedEvent* find_event(const std::string& name) const;
+
+private:
+ std::vector<TimedEvent> events;
+ explicit TimedEventsManager() = default;
+};
diff --git a/src/utils/timed_events_manager.cpp b/src/utils/timed_events_manager.cpp
new file mode 100644
index 0000000..75e6338
--- /dev/null
+++ b/src/utils/timed_events_manager.cpp
@@ -0,0 +1,87 @@
+#include <utils/timed_events.hpp>
+
+#include <algorithm>
+
+TimedEventsManager& TimedEventsManager::instance()
+{
+ static TimedEventsManager inst;
+ return inst;
+}
+
+void TimedEventsManager::add_event(TimedEvent&& event)
+{
+ for (auto it = this->events.begin(); it != this->events.end(); ++it)
+ {
+ if (it->is_after(event))
+ {
+ this->events.emplace(it, std::move(event));
+ return;
+ }
+ }
+ this->events.emplace_back(std::move(event));
+}
+
+std::chrono::milliseconds TimedEventsManager::get_timeout() const
+{
+ if (this->events.empty())
+ return utils::no_timeout;
+ return this->events.front().get_timeout();
+}
+
+std::size_t TimedEventsManager::execute_expired_events()
+{
+ std::size_t count = 0;
+ const auto now = std::chrono::steady_clock::now();
+ for (auto it = this->events.begin(); it != this->events.end();)
+ {
+ if (!it->is_after(now))
+ {
+ TimedEvent copy(std::move(*it));
+ it = this->events.erase(it);
+ ++count;
+ copy.execute();
+ if (copy.repeat)
+ {
+ copy.time_point += copy.repeat_delay;
+ this->add_event(std::move(copy));
+ }
+ continue;
+ }
+ else
+ break;
+ }
+ return count;
+}
+
+std::size_t TimedEventsManager::cancel(const std::string& name)
+{
+ std::size_t res = 0;
+ for (auto it = this->events.begin(); it != this->events.end();)
+ {
+ if (it->get_name() == name)
+ {
+ it = this->events.erase(it);
+ res++;
+ }
+ else
+ ++it;
+ }
+ return res;
+}
+
+
+
+std::size_t TimedEventsManager::size() const
+{
+ return this->events.size();
+}
+
+const TimedEvent* TimedEventsManager::find_event(const std::string& name) const
+{
+ const auto it = std::find_if(this->events.begin(), this->events.end(), [&name](const TimedEvent& o) {
+ return o.get_name() == name;
+ });
+ if (it == this->events.end())
+ return nullptr;
+ return &*it;
+}
diff --git a/src/utils/tolower.cpp b/src/utils/tolower.cpp
new file mode 100644
index 0000000..3e518bd
--- /dev/null
+++ b/src/utils/tolower.cpp
@@ -0,0 +1,13 @@
+#include <utils/tolower.hpp>
+
+namespace utils
+{
+ std::string tolower(const std::string& original)
+ {
+ std::string res;
+ res.reserve(original.size());
+ for (const char c: original)
+ res += static_cast<char>(std::tolower(c));
+ return res;
+ }
+}
diff --git a/src/utils/tolower.hpp b/src/utils/tolower.hpp
new file mode 100644
index 0000000..650e05d
--- /dev/null
+++ b/src/utils/tolower.hpp
@@ -0,0 +1,11 @@
+#pragma once
+
+
+#include <string>
+
+namespace utils
+{
+ std::string tolower(const std::string& original);
+}
+
+
diff --git a/src/utils/xdg.cpp b/src/utils/xdg.cpp
new file mode 100644
index 0000000..b0fa7be
--- /dev/null
+++ b/src/utils/xdg.cpp
@@ -0,0 +1,29 @@
+#include <utils/xdg.hpp>
+#include <cstdlib>
+
+#include "biboumi.h"
+
+std::string xdg_path(const std::string& filename, const char* env_var)
+{
+ const char* xdg_home = ::getenv(env_var);
+ if (xdg_home && xdg_home[0] == '/')
+ return std::string{xdg_home} + "/" PROJECT_NAME "/" + filename;
+ else
+ {
+ const char* home = ::getenv("HOME");
+ if (home)
+ return std::string{home} + "/" ".config" "/" PROJECT_NAME "/" + filename;
+ else
+ return filename;
+ }
+}
+
+std::string xdg_config_path(const std::string& filename)
+{
+ return xdg_path(filename, "XDG_CONFIG_HOME");
+}
+
+std::string xdg_data_path(const std::string& filename)
+{
+ return xdg_path(filename, "XDG_DATA_HOME");
+}
diff --git a/src/utils/xdg.hpp b/src/utils/xdg.hpp
new file mode 100644
index 0000000..7be6922
--- /dev/null
+++ b/src/utils/xdg.hpp
@@ -0,0 +1,12 @@
+#pragma once
+
+
+#include <string>
+
+/**
+ * Returns a path for the given filename, according to the XDG base
+ * directory specification, see
+ * http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
+ */
+std::string xdg_config_path(const std::string& filename);
+std::string xdg_data_path(const std::string& filename);
diff --git a/src/xmpp/adhoc_command.cpp b/src/xmpp/adhoc_command.cpp
new file mode 100644
index 0000000..e02bf35
--- /dev/null
+++ b/src/xmpp/adhoc_command.cpp
@@ -0,0 +1,81 @@
+#include <utility>
+#include <xmpp/adhoc_command.hpp>
+#include <xmpp/xmpp_component.hpp>
+#include <utils/reload.hpp>
+
+using namespace std::string_literals;
+
+AdhocCommand::AdhocCommand(std::vector<AdhocStep>&& callbacks, std::string name, const bool admin_only):
+ name(std::move(name)),
+ callbacks(std::move(callbacks)),
+ admin_only(admin_only)
+{
+}
+
+bool AdhocCommand::is_admin_only() const
+{
+ return this->admin_only;
+}
+
+void PingStep1(XmppComponent&, AdhocSession&, XmlNode& command_node)
+{
+ XmlSubNode note(command_node, "note");
+ note["type"] = "info";
+ note.set_inner("Pong");
+}
+
+void HelloStep1(XmppComponent&, AdhocSession&, XmlNode& command_node)
+{
+ XmlSubNode x(command_node, "jabber:x:data:x");
+ x["type"] = "form";
+ XmlSubNode title(x, "title");
+ title.set_inner("Configure your name.");
+ XmlSubNode instructions(x, "instructions");
+ instructions.set_inner("Please provide your name.");
+ XmlSubNode name_field(x, "field");
+ name_field["var"] = "name";
+ name_field["type"] = "text-single";
+ name_field["label"] = "Your name";
+ XmlSubNode required(name_field, "required");
+}
+
+void HelloStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node)
+{
+ // Find out if the name was provided in the form.
+ if (const XmlNode* x = command_node.get_child("x", "jabber:x:data"))
+ {
+ const XmlNode* name_field = nullptr;
+ for (const XmlNode* field: x->get_children("field", "jabber:x:data"))
+ if (field->get_tag("var") == "name")
+ {
+ name_field = field;
+ break;
+ }
+ if (name_field)
+ {
+ if (const XmlNode* value = name_field->get_child("value", "jabber:x:data"))
+ {
+ const std::string value_str = value->get_inner();
+ command_node.delete_all_children();
+ XmlSubNode note(command_node, "note");
+ note["type"] = "info";
+ note.set_inner("Hello "s + value_str + "!"s);
+ return;
+ }
+ }
+ }
+ command_node.delete_all_children();
+ XmlSubNode error(command_node, ADHOC_NS":error");
+ error["type"] = "modify";
+ XmlSubNode condition(error, STANZA_NS":bad-request");
+ session.terminate();
+}
+
+void Reload(XmppComponent&, AdhocSession&, XmlNode& command_node)
+{
+ ::reload_process();
+ command_node.delete_all_children();
+ XmlSubNode note(command_node, "note");
+ note["type"] = "info";
+ note.set_inner("Configuration reloaded.");
+}
diff --git a/src/xmpp/adhoc_command.hpp b/src/xmpp/adhoc_command.hpp
new file mode 100644
index 0000000..c00d9e6
--- /dev/null
+++ b/src/xmpp/adhoc_command.hpp
@@ -0,0 +1,44 @@
+#pragma once
+
+/**
+ * Describe an ad-hoc command.
+ *
+ * Can only have zero or one step for now. When execution is requested, it
+ * can return a result immediately, or provide a form to be filled, and
+ * provide a result once the filled form is received.
+ */
+
+#include <xmpp/adhoc_session.hpp>
+
+#include <functional>
+#include <string>
+
+class AdhocCommand
+{
+ friend class AdhocSession;
+public:
+ AdhocCommand(std::vector<AdhocStep>&& callbacks, std::string name, const bool admin_only);
+ ~AdhocCommand() = default;
+ AdhocCommand(const AdhocCommand&) = default;
+ AdhocCommand(AdhocCommand&&) = default;
+ AdhocCommand& operator=(AdhocCommand&&) = delete;
+ AdhocCommand& operator=(const AdhocCommand&) = delete;
+
+ const std::string name;
+
+ bool is_admin_only() const;
+
+private:
+ /**
+ * A command may have one or more steps. Each step is a different
+ * callback, inserting things into a <command/> XmlNode and calling
+ * methods of an AdhocSession.
+ */
+ std::vector<AdhocStep> callbacks;
+ const bool admin_only;
+};
+
+void PingStep1(XmppComponent&, AdhocSession& session, XmlNode& command_node);
+void HelloStep1(XmppComponent&, AdhocSession& session, XmlNode& command_node);
+void HelloStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node);
+void Reload(XmppComponent&, AdhocSession& session, XmlNode& command_node);
diff --git a/src/xmpp/adhoc_commands_handler.cpp b/src/xmpp/adhoc_commands_handler.cpp
new file mode 100644
index 0000000..040d0ff
--- /dev/null
+++ b/src/xmpp/adhoc_commands_handler.cpp
@@ -0,0 +1,111 @@
+#include <xmpp/adhoc_commands_handler.hpp>
+#include <xmpp/xmpp_component.hpp>
+
+#include <utils/timed_events.hpp>
+#include <logger/logger.hpp>
+#include <config/config.hpp>
+#include <xmpp/jid.hpp>
+
+#include <iostream>
+
+using namespace std::string_literals;
+
+const std::map<const std::string, const AdhocCommand>& AdhocCommandsHandler::get_commands() const
+{
+ return this->commands;
+}
+
+void AdhocCommandsHandler::add_command(std::string name, AdhocCommand command)
+{
+ const auto found = this->commands.find(name);
+ if (found != this->commands.end())
+ throw std::runtime_error("Trying to add an ad-hoc command that already exist: "s + name);
+ this->commands.emplace(std::make_pair(std::move(name), std::move(command)));
+}
+
+XmlNode AdhocCommandsHandler::handle_request(const std::string& executor_jid, const std::string& to, XmlNode command_node)
+{
+ std::string action = command_node.get_tag("action");
+ if (action.empty())
+ action = "execute";
+ command_node.del_tag("action");
+
+ Jid jid(executor_jid);
+
+ const std::string node = command_node.get_tag("node");
+ auto command_it = this->commands.find(node);
+ if (command_it == this->commands.end())
+ {
+ XmlSubNode error(command_node, ADHOC_NS":error");
+ error["type"] = "cancel";
+ XmlSubNode condition(error, STANZA_NS":item-not-found");
+ }
+ else if (command_it->second.is_admin_only() &&
+ Config::get("admin", "") != jid.local + "@" + jid.domain)
+ {
+ XmlSubNode error(command_node, ADHOC_NS":error");
+ error["type"] = "cancel";
+ XmlSubNode condition(error, STANZA_NS":forbidden");
+ }
+ else
+ {
+ std::string sessionid = command_node.get_tag("sessionid");
+ if (sessionid.empty())
+ { // create a new session, with a new id
+ sessionid = XmppComponent::next_id();
+ command_node["sessionid"] = sessionid;
+ this->sessions.emplace(std::piecewise_construct,
+ std::forward_as_tuple(sessionid, executor_jid),
+ std::forward_as_tuple(command_it->second, executor_jid, to));
+ TimedEventsManager::instance().add_event(TimedEvent(std::chrono::steady_clock::now() + 3600s,
+ std::bind(&AdhocCommandsHandler::remove_session, this, sessionid, executor_jid),
+ "adhocsession"s + sessionid + executor_jid));
+ }
+ auto session_it = this->sessions.find(std::make_pair(sessionid, executor_jid));
+ if ((session_it != this->sessions.end()) &&
+ (action == "execute" || action == "next" || action == "complete"))
+ {
+ // execute the step
+ AdhocSession& session = session_it->second;
+ const AdhocStep& step = session.get_next_step();
+ step(this->xmpp_component, session, command_node);
+ if (session.remaining_steps() == 0 ||
+ session.is_terminated())
+ {
+ this->sessions.erase(session_it);
+ command_node["status"] = "completed";
+ TimedEventsManager::instance().cancel("adhocsession"s + sessionid + executor_jid);
+ }
+ else
+ {
+ command_node["status"] = "executing";
+ XmlSubNode actions(command_node, "actions");
+ XmlSubNode next(actions, "next");
+ }
+ }
+ else if (action == "cancel")
+ {
+ this->sessions.erase(session_it);
+ command_node["status"] = "canceled";
+ TimedEventsManager::instance().cancel("adhocsession"s + sessionid + executor_jid);
+ }
+ else // unsupported action
+ {
+ XmlSubNode error(command_node, ADHOC_NS":error");
+ error["type"] = "modify";
+ XmlSubNode condition(error, STANZA_NS":bad-request");
+ }
+ }
+ return command_node;
+}
+
+void AdhocCommandsHandler::remove_session(const std::string& session_id, const std::string& initiator_jid)
+{
+ auto session_it = this->sessions.find(std::make_pair(session_id, initiator_jid));
+ if (session_it != this->sessions.end())
+ {
+ this->sessions.erase(session_it);
+ return ;
+ }
+ log_error("Tried to remove ad-hoc session for [", session_id, ", ", initiator_jid, "] but none found");
+}
diff --git a/src/xmpp/adhoc_commands_handler.hpp b/src/xmpp/adhoc_commands_handler.hpp
new file mode 100644
index 0000000..e37d913
--- /dev/null
+++ b/src/xmpp/adhoc_commands_handler.hpp
@@ -0,0 +1,71 @@
+#pragma once
+
+/**
+ * Manage a list of available AdhocCommands and the list of ongoing
+ * AdhocCommandSessions.
+ */
+
+#include <xmpp/adhoc_command.hpp>
+#include <xmpp/xmpp_stanza.hpp>
+
+#include <utility>
+#include <string>
+#include <map>
+
+class AdhocCommandsHandler
+{
+public:
+ explicit AdhocCommandsHandler(XmppComponent& xmpp_component):
+ xmpp_component(xmpp_component),
+ commands{}
+ { }
+ ~AdhocCommandsHandler() = default;
+
+ AdhocCommandsHandler(const AdhocCommandsHandler&) = delete;
+ AdhocCommandsHandler(AdhocCommandsHandler&&) = delete;
+ AdhocCommandsHandler& operator=(const AdhocCommandsHandler&) = delete;
+ AdhocCommandsHandler& operator=(AdhocCommandsHandler&&) = delete;
+
+ /**
+ * Returns the list of available commands.
+ */
+ const std::map<const std::string, const AdhocCommand>& get_commands() const;
+ /**
+ * Add a command into the list, associated with the given name
+ */
+ void add_command(std::string name, AdhocCommand command);
+ /**
+ * Find the requested command, create a new session or use an existing
+ * one, and process the request (provide a new form, an error, or a
+ * result).
+ *
+ * Returns a (moved) XmlNode that will be inserted in the iq response. It
+ * should be a <command/> node containing one or more useful children. If
+ * it contains an <error/> node, the iq response will have an error type.
+ *
+ * Takes a copy of the <command/> node so we can actually edit it and use
+ * it as our return value.
+ */
+ XmlNode handle_request(const std::string& executor_jid, const std::string& to, XmlNode command_node);
+ /**
+ * Remove the session from the list. This is done to avoid filling the
+ * memory with waiting session (for example due to a client that starts
+ * multi-steps command but never finishes them).
+ */
+ void remove_session(const std::string& session_id, const std::string& initiator_jid);
+private:
+ /**
+ * To access basically anything in the gateway.
+ */
+ XmppComponent& xmpp_component;
+ /**
+ * The list of all available commands.
+ */
+ std::map<const std::string, const AdhocCommand> commands;
+ /**
+ * The list of all currently on-going commands.
+ *
+ * Of the form: {{session_id, owner_jid}, session}.
+ */
+ std::map<std::pair<const std::string, const std::string>, AdhocSession> sessions;
+};
diff --git a/src/xmpp/adhoc_session.cpp b/src/xmpp/adhoc_session.cpp
new file mode 100644
index 0000000..e2d6c0e
--- /dev/null
+++ b/src/xmpp/adhoc_session.cpp
@@ -0,0 +1,35 @@
+#include <xmpp/adhoc_session.hpp>
+#include <xmpp/adhoc_command.hpp>
+
+#include <cassert>
+
+AdhocSession::AdhocSession(const AdhocCommand& command, const std::string& owner_jid,
+ const std::string& to_jid):
+ command(command),
+ owner_jid(owner_jid),
+ to_jid(to_jid),
+ current_step(0),
+ terminated(false)
+{
+}
+
+const AdhocStep& AdhocSession::get_next_step()
+{
+ assert(this->current_step < this->command.callbacks.size());
+ return this->command.callbacks[this->current_step++];
+}
+
+size_t AdhocSession::remaining_steps() const
+{
+ return this->command.callbacks.size() - this->current_step;
+}
+
+bool AdhocSession::is_terminated() const
+{
+ return this->terminated;
+}
+
+void AdhocSession::terminate()
+{
+ this->terminated = true;
+}
diff --git a/src/xmpp/adhoc_session.hpp b/src/xmpp/adhoc_session.hpp
new file mode 100644
index 0000000..0de8d13
--- /dev/null
+++ b/src/xmpp/adhoc_session.hpp
@@ -0,0 +1,88 @@
+#pragma once
+
+#include <xmpp/xmpp_stanza.hpp>
+
+#include <functional>
+#include <string>
+#include <map>
+
+class XmppComponent;
+
+class AdhocCommand;
+class AdhocSession;
+
+/**
+ * A function executed as an ad-hoc command step. It takes a <command/>
+ * XmlNode and modifies it accordingly (inserting for example an <error/>
+ * node, or a data form…).
+ */
+using AdhocStep = std::function<void(XmppComponent&, AdhocSession&, XmlNode&)>;
+
+class AdhocSession
+{
+public:
+ explicit AdhocSession(const AdhocCommand& command, const std::string& owner_jid,
+ const std::string& to_jid);
+ ~AdhocSession() = default;
+
+ AdhocSession(const AdhocSession&) = delete;
+ AdhocSession(AdhocSession&&) = delete;
+ AdhocSession& operator=(const AdhocSession&) = delete;
+ AdhocSession& operator=(AdhocSession&&) = delete;
+
+ /**
+ * Return the function to be executed, found in our AdhocCommand, for the
+ * current_step. And increment the current_step.
+ */
+ const AdhocStep& get_next_step();
+ /**
+ * Return the number of remaining steps.
+ */
+ size_t remaining_steps() const;
+ /**
+ * This may be modified by an AdhocStep, to indicate that this session
+ * should no longer exist, because we encountered an error, and we can't
+ * execute any more step of it.
+ */
+ void terminate();
+ bool is_terminated() const;
+ std::string get_target_jid() const
+ {
+ return this->to_jid;
+ }
+ std::string get_owner_jid() const
+ {
+ return this->owner_jid;
+ }
+
+private:
+ /**
+ * A reference of the command concerned by this session. Used for example
+ * to get the next step of that command, things like that.
+ */
+ const AdhocCommand& command;
+ /**
+ * The full JID of the XMPP user that created this session by executing
+ * the first step of a command. Only that JID must be allowed to access
+ * this session.
+ */
+ const std::string& owner_jid;
+ /**
+ * The 'to' attribute in the request stanza. This is the target of the current session.
+ */
+ const std::string& to_jid;
+ /**
+ * The current step we are at. It starts at zero. It is used to index the
+ * associated AdhocCommand::callbacks vector.
+ */
+ size_t current_step;
+ bool terminated;
+
+public:
+ /**
+ * A map to store various things that we may want to remember between two
+ * steps of the same session. A step can insert any value associated to
+ * any key in there.
+ */
+ std::map<std::string, std::string> vars;
+};
diff --git a/src/xmpp/auth.cpp b/src/xmpp/auth.cpp
new file mode 100644
index 0000000..8a34a4e
--- /dev/null
+++ b/src/xmpp/auth.cpp
@@ -0,0 +1,8 @@
+#include <xmpp/auth.hpp>
+
+#include <utils/sha1.hpp>
+
+std::string get_handshake_digest(const std::string& stream_id, const std::string& secret)
+{
+ return sha1(stream_id + secret);
+}
diff --git a/src/xmpp/auth.hpp b/src/xmpp/auth.hpp
new file mode 100644
index 0000000..34a2116
--- /dev/null
+++ b/src/xmpp/auth.hpp
@@ -0,0 +1,6 @@
+#pragma once
+
+#include <string>
+
+std::string get_handshake_digest(const std::string& stream_id, const std::string& secret);
+
diff --git a/src/xmpp/biboumi_adhoc_commands.cpp b/src/xmpp/biboumi_adhoc_commands.cpp
index 003b901..4129517 100644
--- a/src/xmpp/biboumi_adhoc_commands.cpp
+++ b/src/xmpp/biboumi_adhoc_commands.cpp
@@ -7,6 +7,7 @@
#include <utils/split.hpp>
#include <xmpp/jid.hpp>
#include <algorithm>
+#include <sstream>
#include <iomanip>
#include <biboumi.h>
@@ -23,47 +24,38 @@ using namespace std::string_literals;
void DisconnectUserStep1(XmppComponent& xmpp_component, AdhocSession&, XmlNode& command_node)
{
- auto& biboumi_component = static_cast<BiboumiComponent&>(xmpp_component);
+ auto& biboumi_component = dynamic_cast<BiboumiComponent&>(xmpp_component);
- XmlNode x("jabber:x:data:x");
+ XmlSubNode x(command_node, "jabber:x:data:x");
x["type"] = "form";
- XmlNode title("title");
+ XmlSubNode title(x, "title");
title.set_inner("Disconnect a user from the gateway");
- x.add_child(std::move(title));
- XmlNode instructions("instructions");
+ XmlSubNode instructions(x, "instructions");
instructions.set_inner("Choose a user JID and a quit message");
- x.add_child(std::move(instructions));
- XmlNode jids_field("field");
+ XmlSubNode jids_field(x, "field");
jids_field["var"] = "jids";
jids_field["type"] = "list-multi";
jids_field["label"] = "The JIDs to disconnect";
- XmlNode required("required");
- jids_field.add_child(std::move(required));
+ XmlSubNode required(jids_field, "required");
for (Bridge* bridge: biboumi_component.get_bridges())
{
- XmlNode option("option");
+ XmlSubNode option(jids_field, "option");
option["label"] = bridge->get_jid();
- XmlNode value("value");
+ XmlSubNode value(option, "value");
value.set_inner(bridge->get_jid());
- option.add_child(std::move(value));
- jids_field.add_child(std::move(option));
}
- x.add_child(std::move(jids_field));
- XmlNode message_field("field");
+ XmlSubNode message_field(x, "field");
message_field["var"] = "quit-message";
message_field["type"] = "text-single";
message_field["label"] = "Quit message";
- XmlNode message_value("value");
+ XmlSubNode message_value(message_field, "value");
message_value.set_inner("Disconnected by admin");
- message_field.add_child(std::move(message_value));
- x.add_child(std::move(message_field));
- command_node.add_child(std::move(x));
}
void DisconnectUserStep2(XmppComponent& xmpp_component, AdhocSession& session, XmlNode& command_node)
{
- auto& biboumi_component = static_cast<BiboumiComponent&>(xmpp_component);
+ auto& biboumi_component = dynamic_cast<BiboumiComponent&>(xmpp_component);
// Find out if the jids, and the quit message are provided in the form.
std::string quit_message;
@@ -97,7 +89,7 @@ void DisconnectUserStep2(XmppComponent& xmpp_component, AdhocSession& session, X
}
command_node.delete_all_children();
- XmlNode note("note");
+ XmlSubNode note(command_node, "note");
note["type"] = "info";
if (num == 0)
note.set_inner("No user were disconnected.");
@@ -105,15 +97,12 @@ void DisconnectUserStep2(XmppComponent& xmpp_component, AdhocSession& session, X
note.set_inner("1 user has been disconnected.");
else
note.set_inner(std::to_string(num) + " users have been disconnected.");
- command_node.add_child(std::move(note));
return;
}
}
- XmlNode error(ADHOC_NS":error");
+ XmlSubNode error(command_node, ADHOC_NS":error");
error["type"] = "modify";
- XmlNode condition(STANZA_NS":bad-request");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
+ XmlSubNode condition(error, STANZA_NS":bad-request");
session.terminate();
}
@@ -126,48 +115,61 @@ void ConfigureGlobalStep1(XmppComponent&, AdhocSession& session, XmlNode& comman
auto options = Database::get_global_options(owner.bare());
- XmlNode x("jabber:x:data:x");
+ XmlSubNode x(command_node, "jabber:x:data:x");
x["type"] = "form";
- XmlNode title("title");
+ XmlSubNode title(x, "title");
title.set_inner("Configure some global default settings.");
- x.add_child(std::move(title));
- XmlNode instructions("instructions");
+ XmlSubNode instructions(x, "instructions");
instructions.set_inner("Edit the form, to configure your global settings for the component.");
- x.add_child(std::move(instructions));
-
- XmlNode required("required");
-
- XmlNode max_histo_length("field");
- max_histo_length["var"] = "max_history_length";
- max_histo_length["type"] = "text-single";
- max_histo_length["label"] = "Max history length";
- max_histo_length["desc"] = "The maximum number of lines in the history that the server sends when joining a channel";
-
- XmlNode value("value");
- value.set_inner(std::to_string(options.maxHistoryLength.value()));
- max_histo_length.add_child(std::move(value));
- x.add_child(std::move(max_histo_length));
-
- XmlNode record_history("field");
- record_history["var"] = "record_history";
- record_history["type"] = "boolean";
- record_history["label"] = "Record history";
- record_history["desc"] = "Whether to save the messages into the database, or not";
-
- value.set_name("value");
- if (options.recordHistory.value())
- value.set_inner("true");
- else
- value.set_inner("false");
- record_history.add_child(std::move(value));
- x.add_child(std::move(record_history));
- command_node.add_child(std::move(x));
+ {
+ XmlSubNode max_histo_length(x, "field");
+ max_histo_length["var"] = "max_history_length";
+ max_histo_length["type"] = "text-single";
+ max_histo_length["label"] = "Max history length";
+ max_histo_length["desc"] = "The maximum number of lines in the history that the server sends when joining a channel";
+ {
+ XmlSubNode value(max_histo_length, "value");
+ value.set_inner(std::to_string(options.col<Database::MaxHistoryLength>()));
+ }
+ }
+
+ {
+ XmlSubNode record_history(x, "field");
+ record_history["var"] = "record_history";
+ record_history["type"] = "boolean";
+ record_history["label"] = "Record history";
+ record_history["desc"] = "Whether to save the messages into the database, or not";
+ {
+ XmlSubNode value(record_history, "value");
+ value.set_name("value");
+ if (options.col<Database::RecordHistory>())
+ value.set_inner("true");
+ else
+ value.set_inner("false");
+ }
+ }
+
+ {
+ XmlSubNode persistent(x, "field");
+ persistent["var"] = "persistent";
+ persistent["type"] = "boolean";
+ persistent["label"] = "Make all channels persistent";
+ persistent["desc"] = "If true, all channels will be persistent";
+ {
+ XmlSubNode value(persistent, "value");
+ value.set_name("value");
+ if (options.col<Database::Persistent>())
+ value.set_inner("true");
+ else
+ value.set_inner("false");
+ }
+ }
}
void ConfigureGlobalStep2(XmppComponent& xmpp_component, AdhocSession& session, XmlNode& command_node)
{
- BiboumiComponent& biboumi_component = static_cast<BiboumiComponent&>(xmpp_component);
+ auto& biboumi_component = dynamic_cast<BiboumiComponent&>(xmpp_component);
const XmlNode* x = command_node.get_child("x", "jabber:x:data");
if (x)
@@ -180,31 +182,31 @@ void ConfigureGlobalStep2(XmppComponent& xmpp_component, AdhocSession& session,
if (field->get_tag("var") == "max_history_length" &&
value && !value->get_inner().empty())
- options.maxHistoryLength = value->get_inner();
+ options.col<Database::MaxHistoryLength>() = atoi(value->get_inner().data());
else if (field->get_tag("var") == "record_history" &&
value && !value->get_inner().empty())
{
- options.recordHistory = to_bool(value->get_inner());
+ options.col<Database::RecordHistory>() = to_bool(value->get_inner());
Bridge* bridge = biboumi_component.find_user_bridge(owner.bare());
if (bridge)
- bridge->set_record_history(options.recordHistory.value());
+ bridge->set_record_history(options.col<Database::RecordHistory>());
}
+ else if (field->get_tag("var") == "persistent" &&
+ value)
+ options.col<Database::Persistent>() = to_bool(value->get_inner());
}
- options.update();
+ options.save(Database::db);
command_node.delete_all_children();
- XmlNode note("note");
+ XmlSubNode note(command_node, "note");
note["type"] = "info";
note.set_inner("Configuration successfully applied.");
- command_node.add_child(std::move(note));
return;
}
- XmlNode error(ADHOC_NS":error");
+ XmlSubNode error(command_node, ADHOC_NS":error");
error["type"] = "modify";
- XmlNode condition(STANZA_NS":bad-request");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
+ XmlSubNode condition(error, STANZA_NS":bad-request");
session.terminate();
}
@@ -218,161 +220,143 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com
auto options = Database::get_irc_server_options(owner.local + "@" + owner.domain,
server_domain);
- XmlNode x("jabber:x:data:x");
+ XmlSubNode x(command_node, "jabber:x:data:x");
x["type"] = "form";
- XmlNode title("title");
+ XmlSubNode title(x, "title");
title.set_inner("Configure the IRC server "s + server_domain);
- x.add_child(std::move(title));
- XmlNode instructions("instructions");
+ XmlSubNode instructions(x, "instructions");
instructions.set_inner("Edit the form, to configure the settings of the IRC server "s + server_domain);
- x.add_child(std::move(instructions));
-
- XmlNode required("required");
- XmlNode ports("field");
- ports["var"] = "ports";
- ports["type"] = "text-multi";
- ports["label"] = "Ports";
- ports["desc"] = "List of ports to try, without TLS. Defaults: 6667.";
- auto vals = utils::split(options.ports.value(), ';', false);
- for (const auto& val: vals)
- {
- XmlNode ports_value("value");
- ports_value.set_inner(val);
- ports.add_child(std::move(ports_value));
- }
- ports.add_child(required);
- x.add_child(std::move(ports));
+ {
+ XmlSubNode ports(x, "field");
+ ports["var"] = "ports";
+ ports["type"] = "text-multi";
+ ports["label"] = "Ports";
+ ports["desc"] = "List of ports to try, without TLS. Defaults: 6667.";
+ for (const auto& val: utils::split(options.col<Database::Ports>(), ';', false))
+ {
+ XmlSubNode ports_value(ports, "value");
+ ports_value.set_inner(val);
+ }
+ }
#ifdef BOTAN_FOUND
- XmlNode tls_ports("field");
- tls_ports["var"] = "tls_ports";
- tls_ports["type"] = "text-multi";
- tls_ports["label"] = "TLS ports";
- tls_ports["desc"] = "List of ports to try, with TLS. Defaults: 6697, 6670.";
- vals = utils::split(options.tlsPorts.value(), ';', false);
- for (const auto& val: vals)
- {
- XmlNode tls_ports_value("value");
- tls_ports_value.set_inner(val);
- tls_ports.add_child(std::move(tls_ports_value));
- }
- tls_ports.add_child(required);
- x.add_child(std::move(tls_ports));
-
- XmlNode verify_cert("field");
- verify_cert["var"] = "verify_cert";
- verify_cert["type"] = "boolean";
- verify_cert["label"] = "Verify certificate";
- verify_cert["desc"] = "Whether or not to abort the connection if the server’s TLS certificate is invalid";
- XmlNode verify_cert_value("value");
- if (options.verifyCert.value())
- verify_cert_value.set_inner("true");
- else
- verify_cert_value.set_inner("false");
- verify_cert.add_child(std::move(verify_cert_value));
- x.add_child(std::move(verify_cert));
-
- XmlNode fingerprint("field");
- fingerprint["var"] = "fingerprint";
- fingerprint["type"] = "text-single";
- fingerprint["label"] = "SHA-1 fingerprint of the TLS certificate to trust.";
- if (!options.trustedFingerprint.value().empty())
- {
- XmlNode fingerprint_value("value");
- fingerprint_value.set_inner(options.trustedFingerprint.value());
- fingerprint.add_child(std::move(fingerprint_value));
- }
- fingerprint.add_child(required);
- x.add_child(std::move(fingerprint));
+ {
+ XmlSubNode tls_ports(x, "field");
+ tls_ports["var"] = "tls_ports";
+ tls_ports["type"] = "text-multi";
+ tls_ports["label"] = "TLS ports";
+ tls_ports["desc"] = "List of ports to try, with TLS. Defaults: 6697, 6670.";
+ for (const auto& val: utils::split(options.col<Database::TlsPorts>(), ';', false))
+ {
+ XmlSubNode tls_ports_value(tls_ports, "value");
+ tls_ports_value.set_inner(val);
+ }
+ }
+
+ {
+ XmlSubNode verify_cert(x, "field");
+ verify_cert["var"] = "verify_cert";
+ verify_cert["type"] = "boolean";
+ verify_cert["label"] = "Verify certificate";
+ verify_cert["desc"] = "Whether or not to abort the connection if the server’s TLS certificate is invalid";
+ XmlSubNode verify_cert_value(verify_cert, "value");
+ if (options.col<Database::VerifyCert>())
+ verify_cert_value.set_inner("true");
+ else
+ verify_cert_value.set_inner("false");
+ }
+
+ {
+ XmlSubNode fingerprint(x, "field");
+ fingerprint["var"] = "fingerprint";
+ fingerprint["type"] = "text-single";
+ fingerprint["label"] = "SHA-1 fingerprint of the TLS certificate to trust.";
+ if (!options.col<Database::TrustedFingerprint>().empty())
+ {
+ XmlSubNode fingerprint_value(fingerprint, "value");
+ fingerprint_value.set_inner(options.col<Database::TrustedFingerprint>());
+ }
+ }
#endif
-
- XmlNode pass("field");
- pass["var"] = "pass";
- pass["type"] = "text-private";
- pass["label"] = "Server password (to be used in a PASS command when connecting)";
- if (!options.pass.value().empty())
- {
- XmlNode pass_value("value");
- pass_value.set_inner(options.pass.value());
- pass.add_child(std::move(pass_value));
- }
- pass.add_child(required);
- x.add_child(std::move(pass));
-
- XmlNode after_cnt_cmd("field");
- after_cnt_cmd["var"] = "after_connect_command";
- after_cnt_cmd["type"] = "text-single";
- after_cnt_cmd["desc"] = "Custom IRC command sent after the connection is established with the server.";
- after_cnt_cmd["label"] = "After-connection IRC command";
- if (!options.afterConnectionCommand.value().empty())
- {
- XmlNode after_cnt_cmd_value("value");
- after_cnt_cmd_value.set_inner(options.afterConnectionCommand.value());
- after_cnt_cmd.add_child(std::move(after_cnt_cmd_value));
- }
- after_cnt_cmd.add_child(required);
- x.add_child(std::move(after_cnt_cmd));
+ {
+ XmlSubNode pass(x, "field");
+ pass["var"] = "pass";
+ pass["type"] = "text-private";
+ pass["label"] = "Server password";
+ pass["desc"] = "Will be used in a PASS command when connecting";
+ if (!options.col<Database::Pass>().empty())
+ {
+ XmlSubNode pass_value(pass, "value");
+ pass_value.set_inner(options.col<Database::Pass>());
+ }
+ }
+
+ {
+ XmlSubNode after_cnt_cmd(x, "field");
+ after_cnt_cmd["var"] = "after_connect_command";
+ after_cnt_cmd["type"] = "text-single";
+ after_cnt_cmd["desc"] = "Custom IRC command sent after the connection is established with the server.";
+ after_cnt_cmd["label"] = "After-connection IRC command";
+ if (!options.col<Database::AfterConnectionCommand>().empty())
+ {
+ XmlSubNode after_cnt_cmd_value(after_cnt_cmd, "value");
+ after_cnt_cmd_value.set_inner(options.col<Database::AfterConnectionCommand>());
+ }
+ }
if (Config::get("realname_customization", "true") == "true")
{
- XmlNode username("field");
- username["var"] = "username";
- username["type"] = "text-single";
- username["label"] = "Username";
- if (!options.username.value().empty())
- {
- XmlNode username_value("value");
- username_value.set_inner(options.username.value());
- username.add_child(std::move(username_value));
- }
- username.add_child(required);
- x.add_child(std::move(username));
-
- XmlNode realname("field");
- realname["var"] = "realname";
- realname["type"] = "text-single";
- realname["label"] = "Realname";
- if (!options.realname.value().empty())
- {
- XmlNode realname_value("value");
- realname_value.set_inner(options.realname.value());
- realname.add_child(std::move(realname_value));
- }
- realname.add_child(required);
- x.add_child(std::move(realname));
+ {
+ XmlSubNode username(x, "field");
+ username["var"] = "username";
+ username["type"] = "text-single";
+ username["label"] = "Username";
+ if (!options.col<Database::Username>().empty())
+ {
+ XmlSubNode username_value(username, "value");
+ username_value.set_inner(options.col<Database::Username>());
+ }
+ }
+
+ {
+ XmlSubNode realname(x, "field");
+ realname["var"] = "realname";
+ realname["type"] = "text-single";
+ realname["label"] = "Realname";
+ if (!options.col<Database::Realname>().empty())
+ {
+ XmlSubNode realname_value(realname, "value");
+ realname_value.set_inner(options.col<Database::Realname>());
+ }
+ }
}
- XmlNode encoding_out("field");
+ {
+ XmlSubNode encoding_out(x, "field");
encoding_out["var"] = "encoding_out";
encoding_out["type"] = "text-single";
encoding_out["desc"] = "The encoding used when sending messages to the IRC server.";
encoding_out["label"] = "Out encoding";
- if (!options.encodingOut.value().empty())
- {
- XmlNode encoding_out_value("value");
- encoding_out_value.set_inner(options.encodingOut.value());
- encoding_out.add_child(std::move(encoding_out_value));
- }
- encoding_out.add_child(required);
- x.add_child(std::move(encoding_out));
-
- XmlNode encoding_in("field");
- encoding_in["var"] = "encoding_in";
- encoding_in["type"] = "text-single";
- encoding_in["desc"] = "The encoding used to decode message received from the IRC server.";
- encoding_in["label"] = "In encoding";
- if (!options.encodingIn.value().empty())
+ if (!options.col<Database::EncodingOut>().empty())
{
- XmlNode encoding_in_value("value");
- encoding_in_value.set_inner(options.encodingIn.value());
- encoding_in.add_child(std::move(encoding_in_value));
+ XmlSubNode encoding_out_value(encoding_out, "value");
+ encoding_out_value.set_inner(options.col<Database::EncodingOut>());
}
- encoding_in.add_child(required);
- x.add_child(std::move(encoding_in));
-
-
- command_node.add_child(std::move(x));
+ }
+
+ {
+ XmlSubNode encoding_in(x, "field");
+ encoding_in["var"] = "encoding_in";
+ encoding_in["type"] = "text-single";
+ encoding_in["desc"] = "The encoding used to decode message received from the IRC server.";
+ encoding_in["label"] = "In encoding";
+ if (!options.col<Database::EncodingIn>().empty())
+ {
+ XmlSubNode encoding_in_value(encoding_in, "value");
+ encoding_in_value.set_inner(options.col<Database::EncodingIn>());
+ }
+ }
}
void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node)
@@ -396,7 +380,7 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com
std::string ports;
for (const auto& val: values)
ports += val->get_inner() + ";";
- options.ports = ports;
+ options.col<Database::Ports>() = ports;
}
#ifdef BOTAN_FOUND
@@ -405,31 +389,31 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com
std::string ports;
for (const auto& val: values)
ports += val->get_inner() + ";";
- options.tlsPorts = ports;
+ options.col<Database::TlsPorts>() = ports;
}
else if (field->get_tag("var") == "verify_cert" && value
&& !value->get_inner().empty())
{
auto val = to_bool(value->get_inner());
- options.verifyCert = val;
+ options.col<Database::VerifyCert>() = val;
}
else if (field->get_tag("var") == "fingerprint" && value &&
!value->get_inner().empty())
{
- options.trustedFingerprint = value->get_inner();
+ options.col<Database::TrustedFingerprint>() = value->get_inner();
}
#endif // BOTAN_FOUND
else if (field->get_tag("var") == "pass" &&
value && !value->get_inner().empty())
- options.pass = value->get_inner();
+ options.col<Database::Pass>() = value->get_inner();
else if (field->get_tag("var") == "after_connect_command" &&
value && !value->get_inner().empty())
- options.afterConnectionCommand = value->get_inner();
+ options.col<Database::AfterConnectionCommand>() = value->get_inner();
else if (field->get_tag("var") == "username" &&
value && !value->get_inner().empty())
@@ -437,37 +421,34 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com
auto username = value->get_inner();
// The username must not contain spaces
std::replace(username.begin(), username.end(), ' ', '_');
- options.username = username;
+ options.col<Database::Username>() = username;
}
else if (field->get_tag("var") == "realname" &&
value && !value->get_inner().empty())
- options.realname = value->get_inner();
+ options.col<Database::Realname>() = value->get_inner();
else if (field->get_tag("var") == "encoding_out" &&
value && !value->get_inner().empty())
- options.encodingOut = value->get_inner();
+ options.col<Database::EncodingOut>() = value->get_inner();
else if (field->get_tag("var") == "encoding_in" &&
value && !value->get_inner().empty())
- options.encodingIn = value->get_inner();
+ options.col<Database::EncodingIn>() = value->get_inner();
}
- options.update();
+ options.save(Database::db);
command_node.delete_all_children();
- XmlNode note("note");
+ XmlSubNode note(command_node, "note");
note["type"] = "info";
note.set_inner("Configuration successfully applied.");
- command_node.add_child(std::move(note));
return;
}
- XmlNode error(ADHOC_NS":error");
+ XmlSubNode error(command_node, ADHOC_NS":error");
error["type"] = "modify";
- XmlNode condition(STANZA_NS":bad-request");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
+ XmlSubNode condition(error, STANZA_NS":bad-request");
session.terminate();
}
@@ -475,90 +456,165 @@ void ConfigureIrcChannelStep1(XmppComponent&, AdhocSession& session, XmlNode& co
{
const Jid owner(session.get_owner_jid());
const Jid target(session.get_target_jid());
+
+ insert_irc_channel_configuration_form(command_node, owner, target);
+}
+
+void insert_irc_channel_configuration_form(XmlNode& node, const Jid& requester, const Jid& target)
+{
const Iid iid(target.local, {});
- auto options = Database::get_irc_channel_options_with_server_default(owner.local + "@" + owner.domain,
- iid.get_server(), iid.get_local());
- XmlNode x("jabber:x:data:x");
+ auto options = Database::get_irc_channel_options_with_server_default(requester.local + "@" + requester.domain,
+ iid.get_server(), iid.get_local());
+ XmlSubNode x(node, "jabber:x:data:x");
x["type"] = "form";
- XmlNode title("title");
+ XmlSubNode title(x, "title");
title.set_inner("Configure the IRC channel "s + iid.get_local() + " on server "s + iid.get_server());
- x.add_child(std::move(title));
- XmlNode instructions("instructions");
+ XmlSubNode instructions(x, "instructions");
instructions.set_inner("Edit the form, to configure the settings of the IRC channel "s + iid.get_local());
- x.add_child(std::move(instructions));
-
- XmlNode required("required");
- XmlNode encoding_out("field");
- encoding_out["var"] = "encoding_out";
- encoding_out["type"] = "text-single";
- encoding_out["desc"] = "The encoding used when sending messages to the IRC server. Defaults to the server's “out encoding” if unset for the channel";
- encoding_out["label"] = "Out encoding";
- if (!options.encodingOut.value().empty())
+ {
+ XmlSubNode record_history(x, "field");
+ record_history["var"] = "record_history";
+ record_history["type"] = "list-single";
+ record_history["label"] = "Record history for this channel";
+ record_history["desc"] = "If unset, the value is the one configured globally";
{
- XmlNode encoding_out_value("value");
- encoding_out_value.set_inner(options.encodingOut.value());
- encoding_out.add_child(std::move(encoding_out_value));
+ // Value selected by default
+ XmlSubNode value(record_history, "value");
+ value.set_inner(options.col<Database::RecordHistoryOptional>().to_string());
}
- encoding_out.add_child(required);
- x.add_child(std::move(encoding_out));
-
- XmlNode encoding_in("field");
- encoding_in["var"] = "encoding_in";
- encoding_in["type"] = "text-single";
- encoding_in["desc"] = "The encoding used to decode message received from the IRC server. Defaults to the server's “in encoding” if unset for the channel";
- encoding_in["label"] = "In encoding";
- if (!options.encodingIn.value().empty())
+ // All three possible values
+ for (const auto& val: {"unset", "true", "false"})
+ {
+ XmlSubNode option(record_history, "option");
+ option["label"] = val;
+ XmlSubNode value(option, "value");
+ value.set_inner(val);
+ }
+ }
+
+ {
+ XmlSubNode encoding_out(x, "field");
+ encoding_out["var"] = "encoding_out";
+ encoding_out["type"] = "text-single";
+ encoding_out["desc"] = "The encoding used when sending messages to the IRC server. Defaults to the server's “out encoding” if unset for the channel";
+ encoding_out["label"] = "Out encoding";
+ if (!options.col<Database::EncodingOut>().empty())
+ {
+ XmlSubNode encoding_out_value(encoding_out, "value");
+ encoding_out_value.set_inner(options.col<Database::EncodingOut>());
+ }
+ }
+
+ {
+ XmlSubNode encoding_in(x, "field");
+ encoding_in["var"] = "encoding_in";
+ encoding_in["type"] = "text-single";
+ encoding_in["desc"] = "The encoding used to decode message received from the IRC server. Defaults to the server's “in encoding” if unset for the channel";
+ encoding_in["label"] = "In encoding";
+ if (!options.col<Database::EncodingIn>().empty())
+ {
+ XmlSubNode encoding_in_value(encoding_in, "value");
+ encoding_in_value.set_inner(options.col<Database::EncodingIn>());
+ }
+ }
+
+ {
+ XmlSubNode persistent(x, "field");
+ persistent["var"] = "persistent";
+ persistent["type"] = "boolean";
+ persistent["desc"] = "If set to true, when all XMPP clients have left this channel, biboumi will stay idle in it, without sending a PART command.";
+ persistent["label"] = "Persistent";
{
- XmlNode encoding_in_value("value");
- encoding_in_value.set_inner(options.encodingIn.value());
- encoding_in.add_child(std::move(encoding_in_value));
+ XmlSubNode value(persistent, "value");
+ value.set_name("value");
+ if (options.col<Database::Persistent>())
+ value.set_inner("true");
+ else
+ value.set_inner("false");
}
- encoding_in.add_child(required);
- x.add_child(std::move(encoding_in));
+ }
+}
+
+void ConfigureIrcChannelStep2(XmppComponent& xmpp_component, AdhocSession& session, XmlNode& command_node)
+{
+ const Jid owner(session.get_owner_jid());
+ const Jid target(session.get_target_jid());
- command_node.add_child(std::move(x));
+ if (handle_irc_channel_configuration_form(xmpp_component, command_node, owner, target))
+ {
+ command_node.delete_all_children();
+ XmlSubNode note(command_node, "note");
+ note["type"] = "info";
+ note.set_inner("Configuration successfully applied.");
+ }
+ else
+ {
+ XmlSubNode error(command_node, ADHOC_NS":error");
+ error["type"] = "modify";
+ XmlSubNode condition(error, STANZA_NS":bad-request");
+ session.terminate();
+ }
}
-void ConfigureIrcChannelStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node)
+bool handle_irc_channel_configuration_form(XmppComponent& xmpp_component, const XmlNode& node, const Jid& requester, const Jid& target)
{
- const XmlNode* x = command_node.get_child("x", "jabber:x:data");
+ const XmlNode* x = node.get_child("x", "jabber:x:data");
if (x)
{
- const Jid owner(session.get_owner_jid());
- const Jid target(session.get_target_jid());
- const Iid iid(target.local, {});
- auto options = Database::get_irc_channel_options(owner.local + "@" + owner.domain,
- iid.get_server(), iid.get_local());
- for (const XmlNode* field: x->get_children("field", "jabber:x:data"))
+ if (x->get_tag("type") == "submit")
{
- const XmlNode* value = field->get_child("value", "jabber:x:data");
+ const Iid iid(target.local, {});
+ auto options = Database::get_irc_channel_options(requester.bare(),
+ iid.get_server(), iid.get_local());
+ for (const XmlNode *field: x->get_children("field", "jabber:x:data"))
+ {
+ const XmlNode *value = field->get_child("value", "jabber:x:data");
- if (field->get_tag("var") == "encoding_out" &&
- value && !value->get_inner().empty())
- options.encodingOut = value->get_inner();
+ if (field->get_tag("var") == "encoding_out" &&
+ value && !value->get_inner().empty())
+ options.col<Database::EncodingOut>() = value->get_inner();
- else if (field->get_tag("var") == "encoding_in" &&
- value && !value->get_inner().empty())
- options.encodingIn = value->get_inner();
- }
+ else if (field->get_tag("var") == "encoding_in" &&
+ value && !value->get_inner().empty())
+ options.col<Database::EncodingIn>() = value->get_inner();
- options.update();
+ else if (field->get_tag("var") == "persistent" &&
+ value)
+ options.col<Database::Persistent>() = to_bool(value->get_inner());
+ else if (field->get_tag("var") == "record_history" &&
+ value && !value->get_inner().empty())
+ {
+ OptionalBool& database_value = options.col<Database::RecordHistoryOptional>();
+ if (value->get_inner() == "true")
+ database_value.set_value(true);
+ else if (value->get_inner() == "false")
+ database_value.set_value(false);
+ else
+ database_value.unset();
+ auto& biboumi_component = dynamic_cast<BiboumiComponent&>(xmpp_component);
+ Bridge* bridge = biboumi_component.find_user_bridge(requester.bare());
+ if (bridge)
+ {
+ if (database_value.is_set)
+ bridge->set_record_history(database_value.value);
+ else
+ { // It is unset, we need to fetch the Global option, to
+ // know if it’s enabled or not
+ auto g_options = Database::get_global_options(requester.bare());
+ bridge->set_record_history(g_options.col<Database::RecordHistory>());
+ }
+ }
+ }
- command_node.delete_all_children();
- XmlNode note("note");
- note["type"] = "info";
- note.set_inner("Configuration successfully applied.");
- command_node.add_child(std::move(note));
- return;
+ }
+
+ options.save(Database::db);
+ }
+ return true;
}
- XmlNode error(ADHOC_NS":error");
- error["type"] = "modify";
- XmlNode condition(STANZA_NS":bad-request");
- error.add_child(std::move(condition));
- command_node.add_child(std::move(error));
- session.terminate();
+ return false;
}
#endif // USE_DATABASE
@@ -573,33 +629,26 @@ void DisconnectUserFromServerStep1(XmppComponent& xmpp_component, AdhocSession&
}
else
{ // Send a form to select the user to disconnect
- auto& biboumi_component = static_cast<BiboumiComponent&>(xmpp_component);
+ auto& biboumi_component = dynamic_cast<BiboumiComponent&>(xmpp_component);
- XmlNode x("jabber:x:data:x");
+ XmlSubNode x(command_node, "jabber:x:data:x");
x["type"] = "form";
- XmlNode title("title");
+ XmlSubNode title(x, "title");
title.set_inner("Disconnect a user from selected IRC servers");
- x.add_child(std::move(title));
- XmlNode instructions("instructions");
+ XmlSubNode instructions(x, "instructions");
instructions.set_inner("Choose a user JID");
- x.add_child(std::move(instructions));
- XmlNode jids_field("field");
+ XmlSubNode jids_field(x, "field");
jids_field["var"] = "jid";
jids_field["type"] = "list-single";
jids_field["label"] = "The JID to disconnect";
- XmlNode required("required");
- jids_field.add_child(std::move(required));
+ XmlSubNode required(jids_field, "required");
for (Bridge* bridge: biboumi_component.get_bridges())
{
- XmlNode option("option");
+ XmlSubNode option(jids_field, "option");
option["label"] = bridge->get_jid();
- XmlNode value("value");
+ XmlSubNode value(option, "value");
value.set_inner(bridge->get_jid());
- option.add_child(std::move(value));
- jids_field.add_child(std::move(option));
}
- x.add_child(std::move(jids_field));
- command_node.add_child(std::move(x));
}
}
@@ -625,55 +674,44 @@ void DisconnectUserFromServerStep2(XmppComponent& xmpp_component, AdhocSession&
// Send a data form to let the user choose which server to disconnect the
// user from
command_node.delete_all_children();
- auto& biboumi_component = static_cast<BiboumiComponent&>(xmpp_component);
+ auto& biboumi_component = dynamic_cast<BiboumiComponent&>(xmpp_component);
- XmlNode x("jabber:x:data:x");
+ XmlSubNode x(command_node, "jabber:x:data:x");
x["type"] = "form";
- XmlNode title("title");
+ XmlSubNode title(x, "title");
title.set_inner("Disconnect a user from selected IRC servers");
- x.add_child(std::move(title));
- XmlNode instructions("instructions");
+ XmlSubNode instructions(x, "instructions");
instructions.set_inner("Choose one or more servers to disconnect this JID from");
- x.add_child(std::move(instructions));
- XmlNode jids_field("field");
+ XmlSubNode jids_field(x, "field");
jids_field["var"] = "irc-servers";
jids_field["type"] = "list-multi";
jids_field["label"] = "The servers to disconnect from";
- XmlNode required("required");
- jids_field.add_child(std::move(required));
+ XmlSubNode required(jids_field, "required");
Bridge* bridge = biboumi_component.find_user_bridge(jid_to_disconnect);
if (!bridge || bridge->get_irc_clients().empty())
{
- XmlNode note("note");
+ XmlSubNode note(command_node, "note");
note["type"] = "info";
note.set_inner("User "s + jid_to_disconnect + " is not connected to any IRC server.");
- command_node.add_child(std::move(note));
session.terminate();
return ;
}
for (const auto& pair: bridge->get_irc_clients())
{
- XmlNode option("option");
+ XmlSubNode option(jids_field, "option");
option["label"] = pair.first;
- XmlNode value("value");
+ XmlSubNode value(option, "value");
value.set_inner(pair.first);
- option.add_child(std::move(value));
- jids_field.add_child(std::move(option));
}
- x.add_child(std::move(jids_field));
- XmlNode message_field("field");
+ XmlSubNode message_field(x, "field");
message_field["var"] = "quit-message";
message_field["type"] = "text-single";
message_field["label"] = "Quit message";
- XmlNode message_value("value");
+ XmlSubNode message_value(message_field, "value");
message_value.set_inner("Killed by admin");
- message_field.add_child(std::move(message_value));
- x.add_child(std::move(message_field));
-
- command_node.add_child(std::move(x));
}
void DisconnectUserFromServerStep3(XmppComponent& xmpp_component, AdhocSession& session, XmlNode& command_node)
@@ -701,7 +739,7 @@ void DisconnectUserFromServerStep3(XmppComponent& xmpp_component, AdhocSession&
}
}
- auto& biboumi_component = static_cast<BiboumiComponent&>(xmpp_component);
+ auto& biboumi_component = dynamic_cast<BiboumiComponent&>(xmpp_component);
Bridge* bridge = biboumi_component.find_user_bridge(jid_to_disconnect);
auto& clients = bridge->get_irc_clients();
@@ -718,19 +756,18 @@ void DisconnectUserFromServerStep3(XmppComponent& xmpp_component, AdhocSession&
}
}
command_node.delete_all_children();
- XmlNode note("note");
+ XmlSubNode note(command_node, "note");
note["type"] = "info";
std::string msg = jid_to_disconnect + " was disconnected from " + std::to_string(number) + " IRC server";
if (number > 1)
msg += "s";
msg += ".";
note.set_inner(msg);
- command_node.add_child(std::move(note));
}
void GetIrcConnectionInfoStep1(XmppComponent& component, AdhocSession& session, XmlNode& command_node)
{
- BiboumiComponent& biboumi_component = static_cast<BiboumiComponent&>(component);
+ auto& biboumi_component = dynamic_cast<BiboumiComponent&>(component);
const Jid owner(session.get_owner_jid());
const Jid target(session.get_target_jid());
@@ -741,10 +778,9 @@ void GetIrcConnectionInfoStep1(XmppComponent& component, AdhocSession& session,
utils::ScopeGuard sg([&message, &command_node]()
{
command_node.delete_all_children();
- XmlNode note("note");
+ XmlSubNode note(command_node, "note");
note["type"] = "info";
note.set_inner(message);
- command_node.add_child(std::move(note));
});
Bridge* bridge = biboumi_component.get_user_bridge(owner.bare());
diff --git a/src/xmpp/biboumi_adhoc_commands.hpp b/src/xmpp/biboumi_adhoc_commands.hpp
index b5fce61..cb6acb9 100644
--- a/src/xmpp/biboumi_adhoc_commands.hpp
+++ b/src/xmpp/biboumi_adhoc_commands.hpp
@@ -4,6 +4,7 @@
#include <xmpp/adhoc_command.hpp>
#include <xmpp/adhoc_session.hpp>
#include <xmpp/xmpp_stanza.hpp>
+#include <xmpp/jid.hpp>
class XmppComponent;
@@ -17,7 +18,9 @@ void ConfigureIrcServerStep1(XmppComponent&, AdhocSession& session, XmlNode& com
void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node);
void ConfigureIrcChannelStep1(XmppComponent&, AdhocSession& session, XmlNode& command_node);
+void insert_irc_channel_configuration_form(XmlNode& node, const Jid& requester, const Jid& target);
void ConfigureIrcChannelStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node);
+bool handle_irc_channel_configuration_form(XmppComponent&, const XmlNode& node, const Jid& requester, const Jid& target);
void DisconnectUserFromServerStep1(XmppComponent&, AdhocSession& session, XmlNode& command_node);
void DisconnectUserFromServerStep2(XmppComponent&, AdhocSession& session, XmlNode& command_node);
diff --git a/src/xmpp/biboumi_component.cpp b/src/xmpp/biboumi_component.cpp
index d6782e2..32f3968 100644
--- a/src/xmpp/biboumi_component.cpp
+++ b/src/xmpp/biboumi_component.cpp
@@ -8,7 +8,6 @@
#include <xmpp/biboumi_adhoc_commands.hpp>
#include <bridge/list_element.hpp>
#include <config/config.hpp>
-#include <utils/sha1.hpp>
#include <utils/time.hpp>
#include <xmpp/jid.hpp>
@@ -17,11 +16,8 @@
#include <cstdlib>
-#include <louloulibs.h>
#include <biboumi.h>
-#include <uuid/uuid.h>
-
#ifdef SYSTEMD_FOUND
# include <systemd/sd-daemon.h>
#endif
@@ -45,7 +41,7 @@ static std::set<std::string> kickable_errors{
};
-BiboumiComponent::BiboumiComponent(std::shared_ptr<Poller> poller, const std::string& hostname, const std::string& secret):
+BiboumiComponent::BiboumiComponent(std::shared_ptr<Poller>& poller, const std::string& hostname, const std::string& secret):
XmppComponent(poller, hostname, secret),
irc_server_adhoc_commands_handler(*this),
irc_channel_adhoc_commands_handler(*this)
@@ -85,10 +81,8 @@ BiboumiComponent::BiboumiComponent(std::shared_ptr<Poller> poller, const std::st
void BiboumiComponent::shutdown()
{
- for (auto it = this->bridges.begin(); it != this->bridges.end(); ++it)
- {
- it->second->shutdown("Gateway shutdown");
- }
+ for (auto& pair: this->bridges)
+ pair.second->shutdown("Gateway shutdown");
}
void BiboumiComponent::clean()
@@ -137,7 +131,7 @@ void BiboumiComponent::handle_presence(const Stanza& stanza)
// stanza_error.disable() call at the end of the function.
std::string error_type("cancel");
std::string error_name("internal-server-error");
- utils::ScopeGuard stanza_error([&](){
+ utils::ScopeGuard stanza_error([this, &from_str, &to_str, &id, &error_type, &error_name](){
this->send_stanza_error("presence", from_str, to_str, id,
error_type, error_name, "");
});
@@ -150,7 +144,7 @@ void BiboumiComponent::handle_presence(const Stanza& stanza)
{
const std::string own_nick = bridge->get_own_nick(iid);
if (!own_nick.empty() && own_nick != to.resource)
- bridge->send_irc_nick_change(iid, to.resource);
+ bridge->send_irc_nick_change(iid, to.resource, from.resource);
const XmlNode* x = stanza.get_child("x", MUC_NS);
const XmlNode* password = x ? x->get_child("password", MUC_NS): nullptr;
bridge->join_irc_channel(iid, to.resource, password ? password->get_inner(): "",
@@ -162,6 +156,15 @@ void BiboumiComponent::handle_presence(const Stanza& stanza)
bridge->leave_irc_channel(std::move(iid), status ? status->get_inner() : "", from.resource);
}
}
+ else if (iid.type == Iid::Type::Server || iid.type == Iid::Type::None)
+ {
+ if (type == "subscribe")
+ { // Auto-accept any subscription request for an IRC server
+ this->accept_subscription(to_str, from.bare());
+ this->ask_subscription(to_str, from.bare());
+ }
+
+ }
else
{
// A user wants to join an invalid IRC channel, return a presence error to him/her
@@ -171,10 +174,11 @@ void BiboumiComponent::handle_presence(const Stanza& stanza)
}
catch (const IRCNotConnected& ex)
{
- this->send_stanza_error("presence", from_str, to_str, id,
- "cancel", "remote-server-not-found",
- "Not connected to IRC server "s + ex.hostname,
- true);
+ if (type != "unavailable")
+ this->send_stanza_error("presence", from_str, to_str, id,
+ "cancel", "remote-server-not-found",
+ "Not connected to IRC server "s + ex.hostname,
+ true);
}
stanza_error.disable();
}
@@ -197,7 +201,7 @@ void BiboumiComponent::handle_message(const Stanza& stanza)
std::string error_type("cancel");
std::string error_name("internal-server-error");
- utils::ScopeGuard stanza_error([&](){
+ utils::ScopeGuard stanza_error([this, &from_str, &to_str, &id, &error_type, &error_name](){
this->send_stanza_error("message", from_str, to_str, id,
error_type, error_name, "");
});
@@ -280,7 +284,7 @@ void BiboumiComponent::handle_message(const Stanza& stanza)
}
// We MUST return an iq, whatever happens, except if the type is
-// "result".
+// "result" or "error".
// To do this, we use a scopeguard. If an exception is raised somewhere, an
// iq of type error "internal-server-error" is sent. If we handle the
// request properly (by calling a function that registers an iq to be sent
@@ -316,7 +320,7 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
// the scopeguard.
std::string error_type("cancel");
std::string error_name("internal-server-error");
- utils::ScopeGuard stanza_error([&](){
+ utils::ScopeGuard stanza_error([this, &from, &to_str, &id, &error_type, &error_name](){
this->send_stanza_error("iq", from, to_str, id,
error_type, error_name, "");
});
@@ -344,7 +348,7 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
bridge->send_irc_kick(iid, nick, reason, id, from);
}
else
- bridge->forward_affiliation_role_change(iid, nick, affiliation, role);
+ bridge->forward_affiliation_role_change(iid, from, nick, affiliation, role, id);
stanza_error.disable();
}
}
@@ -386,6 +390,11 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
if (this->handle_mam_request(stanza))
stanza_error.disable();
}
+ else if ((query = stanza.get_child("query", MUC_OWNER_NS)))
+ {
+ if (this->handle_room_configuration_form(*query, from, to, id))
+ stanza_error.disable();
+ }
#endif
}
else if (type == "get")
@@ -414,7 +423,12 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
}
else if (iid.type == Iid::Type::Channel)
{
- if (node == MUC_TRAFFIC_NS)
+ if (node.empty())
+ {
+ this->send_irc_channel_disco_info(id, from, to_str);
+ stanza_error.disable();
+ }
+ else if (node == MUC_TRAFFIC_NS)
{
this->send_irc_channel_muc_traffic_info(id, from, to_str);
stanza_error.disable();
@@ -492,6 +506,8 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
rs_info.max = std::atoi(max->get_inner().data());
}
+ if (rs_info.max == -1)
+ rs_info.max = 100;
bridge->send_irc_channel_list_request(iid, id, from, std::move(rs_info));
stanza_error.disable();
}
@@ -516,6 +532,13 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
}
stanza_error.disable();
}
+#ifdef USE_DATABASE
+ else if ((query = stanza.get_child("query", MUC_OWNER_NS)))
+ {
+ if (this->handle_room_configuration_form_request(from, to, id))
+ stanza_error.disable();
+ }
+#endif
}
else if (type == "result")
{
@@ -548,6 +571,16 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
}
}
}
+ else if (type == "error")
+ {
+ stanza_error.disable();
+ const auto it = this->waiting_iq.find(id);
+ if (it != this->waiting_iq.end())
+ {
+ it->second(bridge, stanza);
+ this->waiting_iq.erase(it);
+ }
+ }
}
catch (const IRCNotConnected& ex)
{
@@ -570,13 +603,11 @@ bool BiboumiComponent::handle_mam_request(const Stanza& stanza)
Jid to(stanza.get_tag("to"));
const XmlNode* query = stanza.get_child("query", MAM_NS);
- std::string query_id;
- if (query)
- query_id = query->get_tag("queryid");
Iid iid(to.local, {'#', '&'});
- if (iid.type == Iid::Type::Channel && to.resource.empty())
+ if (query && iid.type == Iid::Type::Channel && to.resource.empty())
{
+ const std::string query_id = query->get_tag("queryid");
std::string start;
std::string end;
const XmlNode* x = query->get_child("x", DATAFORM_NS);
@@ -600,10 +631,24 @@ bool BiboumiComponent::handle_mam_request(const Stanza& stanza)
}
}
}
- const auto lines = Database::get_muc_logs(from.bare(), iid.get_local(), iid.get_server(), -1, start, end);
- for (const db::MucLogLine& line: lines)
+ const XmlNode* set = query->get_child("set", RSM_NS);
+ int limit = -1;
+ if (set)
+ {
+ const XmlNode* max = set->get_child("max", RSM_NS);
+ if (max)
+ limit = std::atoi(max->get_inner().data());
+ }
+ // If the archive is really big, and the client didn’t specify any
+ // limit, we avoid flooding it: we set an arbitrary max limit.
+ if (limit == -1 && start.empty() && end.empty())
+ {
+ limit = 100;
+ }
+ const auto lines = Database::get_muc_logs(from.bare(), iid.get_local(), iid.get_server(), limit, start, end);
+ for (const Database::MucLogLine& line: lines)
{
- if (!line.nick.value().empty())
+ if (!line.col<Database::Nick>().empty())
this->send_archived_message(line, to.full(), from.full(), query_id);
}
this->send_iq_result_full_jid(id, from.full(), to.full());
@@ -612,42 +657,80 @@ bool BiboumiComponent::handle_mam_request(const Stanza& stanza)
return false;
}
-void BiboumiComponent::send_archived_message(const db::MucLogLine& log_line, const std::string& from, const std::string& to,
+void BiboumiComponent::send_archived_message(const Database::MucLogLine& log_line, const std::string& from, const std::string& to,
const std::string& queryid)
{
- Stanza message("message");
+ Stanza message("message");
+ {
message["from"] = from;
message["to"] = to;
- XmlNode result("result");
+ XmlSubNode result(message, "result");
result["xmlns"] = MAM_NS;
if (!queryid.empty())
result["queryid"] = queryid;
- result["id"] = log_line.uuid.value();
+ result["id"] = log_line.col<Database::Uuid>();
- XmlNode forwarded("forwarded");
+ XmlSubNode forwarded(result, "forwarded");
forwarded["xmlns"] = FORWARD_NS;
- XmlNode delay("delay");
+ XmlSubNode delay(forwarded, "delay");
delay["xmlns"] = DELAY_NS;
- delay["stamp"] = utils::to_string(log_line.date.value().timeStamp());
-
- forwarded.add_child(std::move(delay));
+ delay["stamp"] = utils::to_string(log_line.col<Database::Date>());
- XmlNode submessage("message");
+ XmlSubNode submessage(forwarded, "message");
submessage["xmlns"] = CLIENT_NS;
- submessage["from"] = from + "/" + log_line.nick.value();
+ submessage["from"] = from + "/" + log_line.col<Database::Nick>();
submessage["type"] = "groupchat";
- XmlNode body("body");
- body.set_inner(log_line.body.value());
- submessage.add_child(std::move(body));
+ XmlSubNode body(submessage, "body");
+ body.set_inner(log_line.col<Database::Body>());
+ }
+ this->send_stanza(message);
+}
+
+bool BiboumiComponent::handle_room_configuration_form_request(const std::string& from, const Jid& to, const std::string& id)
+{
+ Iid iid(to.local, {'#', '&'});
- forwarded.add_child(std::move(submessage));
- result.add_child(std::move(forwarded));
- message.add_child(std::move(result));
+ if (iid.type != Iid::Type::Channel)
+ return false;
- this->send_stanza(message);
+ Stanza iq("iq");
+ {
+ iq["from"] = to.full();
+ iq["to"] = from;
+ iq["id"] = id;
+ iq["type"] = "result";
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = MUC_OWNER_NS;
+ Jid requester(from);
+ insert_irc_channel_configuration_form(query, requester, to);
+ }
+ this->send_stanza(iq);
+ return true;
+}
+
+bool BiboumiComponent::handle_room_configuration_form(const XmlNode& query, const std::string &from, const Jid &to, const std::string &id)
+{
+ Iid iid(to.local, {'#', '&'});
+
+ if (iid.type != Iid::Type::Channel)
+ return false;
+
+ Jid requester(from);
+ if (!handle_irc_channel_configuration_form(*this, query, requester, to))
+ return false;
+
+ Stanza iq("iq");
+ iq["type"] = "result";
+ iq["from"] = to.full();
+ iq["to"] = from;
+ iq["id"] = id;
+
+ this->send_stanza(iq);
+
+ return true;
}
#endif
@@ -681,32 +764,31 @@ Bridge* BiboumiComponent::find_user_bridge(const std::string& full_jid)
std::vector<Bridge*> BiboumiComponent::get_bridges() const
{
std::vector<Bridge*> res;
- for (auto it = this->bridges.begin(); it != this->bridges.end(); ++it)
- res.push_back(it->second.get());
+ for (const auto& bridge: this->bridges)
+ res.push_back(bridge.second.get());
return res;
}
void BiboumiComponent::send_self_disco_info(const std::string& id, const std::string& jid_to)
{
Stanza iq("iq");
- iq["type"] = "result";
- iq["id"] = id;
- iq["to"] = jid_to;
- iq["from"] = this->served_hostname;
- XmlNode query("query");
- query["xmlns"] = DISCO_INFO_NS;
- XmlNode identity("identity");
- identity["category"] = "conference";
- identity["type"] = "irc";
- identity["name"] = "Biboumi XMPP-IRC gateway";
- query.add_child(std::move(identity));
- for (const char* ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS})
- {
- XmlNode feature("feature");
- feature["var"] = ns;
- query.add_child(std::move(feature));
- }
- iq.add_child(std::move(query));
+ {
+ iq["type"] = "result";
+ iq["id"] = id;
+ iq["to"] = jid_to;
+ iq["from"] = this->served_hostname;
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = DISCO_INFO_NS;
+ XmlSubNode identity(query, "identity");
+ identity["category"] = "conference";
+ identity["type"] = "irc";
+ identity["name"] = "Biboumi XMPP-IRC gateway";
+ for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS})
+ {
+ XmlSubNode feature(query, "feature");
+ feature["var"] = ns;
+ }
+ }
this->send_stanza(iq);
}
@@ -714,57 +796,66 @@ void BiboumiComponent::send_irc_server_disco_info(const std::string& id, const s
{
Jid from(jid_from);
Stanza iq("iq");
- iq["type"] = "result";
- iq["id"] = id;
- iq["to"] = jid_to;
- iq["from"] = jid_from;
- XmlNode query("query");
- query["xmlns"] = DISCO_INFO_NS;
- XmlNode identity("identity");
- identity["category"] = "conference";
- identity["type"] = "irc";
- identity["name"] = "IRC server "s + from.local + " over Biboumi";
- query.add_child(std::move(identity));
- for (const char* ns: {DISCO_INFO_NS, ADHOC_NS, PING_NS, VERSION_NS})
- {
- XmlNode feature("feature");
- feature["var"] = ns;
- query.add_child(std::move(feature));
- }
- iq.add_child(std::move(query));
+ {
+ iq["type"] = "result";
+ iq["id"] = id;
+ iq["to"] = jid_to;
+ iq["from"] = jid_from;
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = DISCO_INFO_NS;
+ XmlSubNode identity(query, "identity");
+ identity["category"] = "conference";
+ identity["type"] = "irc";
+ identity["name"] = "IRC server "s + from.local + " over Biboumi";
+ for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS})
+ {
+ XmlSubNode feature(query, "feature");
+ feature["var"] = ns;
+ }
+ }
this->send_stanza(iq);
}
-void BiboumiComponent::send_irc_channel_muc_traffic_info(const std::string id, const std::string& jid_from, const std::string& jid_to)
+void BiboumiComponent::send_irc_channel_muc_traffic_info(const std::string& id, const std::string& jid_to, const std::string& jid_from)
{
Stanza iq("iq");
- iq["type"] = "result";
- iq["id"] = id;
- iq["from"] = jid_from;
- iq["to"] = jid_to;
-
- XmlNode query("query");
- query["xmlns"] = DISCO_INFO_NS;
- query["node"] = MUC_TRAFFIC_NS;
- // We drop all “special” traffic (like xhtml-im, chatstates, etc), so
- // don’t include any <feature/>
- iq.add_child(std::move(query));
-
+ {
+ iq["type"] = "result";
+ iq["id"] = id;
+ iq["from"] = jid_from;
+ iq["to"] = jid_to;
+
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = DISCO_INFO_NS;
+ query["node"] = MUC_TRAFFIC_NS;
+ // We drop all “special” traffic (like xhtml-im, chatstates, etc), so
+ // don’t include any <feature/>
+ }
this->send_stanza(iq);
-
}
-void BiboumiComponent::send_iq_version_request(const std::string& from,
- const std::string& jid_to)
+void BiboumiComponent::send_irc_channel_disco_info(const std::string& id, const std::string& jid_to, const std::string& jid_from)
{
+ Jid from(jid_from);
+ Iid iid(from.local, {});
Stanza iq("iq");
- iq["type"] = "get";
- iq["id"] = "version_"s + this->next_id();
- iq["from"] = from + "@" + this->served_hostname;
- iq["to"] = jid_to;
- XmlNode query("query");
- query["xmlns"] = VERSION_NS;
- iq.add_child(std::move(query));
+ {
+ iq["type"] = "result";
+ iq["id"] = id;
+ iq["to"] = jid_to;
+ iq["from"] = jid_from;
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = DISCO_INFO_NS;
+ XmlSubNode identity(query, "identity");
+ identity["category"] = "conference";
+ identity["type"] = "irc";
+ identity["name"] = "IRC channel "s + iid.get_local() + " from server " + iid.get_server() + " over biboumi";
+ for (const char *ns: {DISCO_INFO_NS, MUC_NS, ADHOC_NS, PING_NS, MAM_NS, VERSION_NS})
+ {
+ XmlSubNode feature(query, "feature");
+ feature["var"] = ns;
+ }
+ }
this->send_stanza(iq);
}
@@ -773,13 +864,14 @@ void BiboumiComponent::send_ping_request(const std::string& from,
const std::string& id)
{
Stanza iq("iq");
- iq["type"] = "get";
- iq["id"] = id;
- iq["from"] = from + "@" + this->served_hostname;
- iq["to"] = jid_to;
- XmlNode ping("ping");
- ping["xmlns"] = PING_NS;
- iq.add_child(std::move(ping));
+ {
+ iq["type"] = "get";
+ iq["id"] = id;
+ iq["from"] = from + "@" + this->served_hostname;
+ iq["to"] = jid_to;
+ XmlSubNode ping(iq, "ping");
+ ping["xmlns"] = PING_NS;
+ }
this->send_stanza(iq);
auto result_cb = [from, id](Bridge* bridge, const Stanza& stanza)
@@ -789,8 +881,14 @@ void BiboumiComponent::send_ping_request(const std::string& from,
{
log_error("Received a corresponding ping result, but the 'to' from "
"the response mismatches the 'from' of the request");
+ return;
}
- else
+ const std::string type = stanza.get_tag("type");
+ const XmlNode* error = stanza.get_child("error", COMPONENT_NS);
+ // Check if what we receive is considered a valid response. And yes, those errors are valid responses
+ if (type == "result" ||
+ (type == "error" && error && (error->get_child("feature-not-implemented", STANZA_NS) ||
+ error->get_child("service-unavailable", STANZA_NS))))
bridge->send_irc_ping_result({from, bridge}, id);
};
this->waiting_iq[id] = result_cb;
@@ -803,48 +901,43 @@ void BiboumiComponent::send_iq_room_list_result(const std::string& id, const std
const ResultSetInfo& rs_info)
{
Stanza iq("iq");
- iq["from"] = from + "@" + this->served_hostname;
- iq["to"] = to_jid;
- iq["id"] = id;
- iq["type"] = "result";
- XmlNode query("query");
- query["xmlns"] = DISCO_ITEMS_NS;
+ {
+ iq["from"] = from + "@" + this->served_hostname;
+ iq["to"] = to_jid;
+ iq["id"] = id;
+ iq["type"] = "result";
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = DISCO_ITEMS_NS;
for (auto it = begin; it != end; ++it)
- {
- XmlNode item("item");
+ {
+ XmlSubNode item(query, "item");
item["jid"] = it->channel + "@" + this->served_hostname;
- query.add_child(std::move(item));
- }
-
- if ((rs_info.max >= 0 || !rs_info.after.empty() || !rs_info.before.empty()))
- {
- XmlNode set_node("set");
- set_node["xmlns"] = RSM_NS;
+ }
- if (begin != channel_list.channels.cend())
- {
- XmlNode first_node("first");
- first_node["index"] = std::to_string(std::distance(channel_list.channels.cbegin(), begin));
- first_node.set_inner(begin->channel + "@" + this->served_hostname);
- set_node.add_child(std::move(first_node));
- }
- if (end != channel_list.channels.cbegin())
- {
- XmlNode last_node("last");
- last_node.set_inner(std::prev(end)->channel + "@" + this->served_hostname);
- set_node.add_child(std::move(last_node));
- }
- if (channel_list.complete)
- {
- XmlNode count_node("count");
- count_node.set_inner(std::to_string(channel_list.channels.size()));
- set_node.add_child(std::move(count_node));
- }
- query.add_child(std::move(set_node));
- }
+ if ((rs_info.max >= 0 || !rs_info.after.empty() || !rs_info.before.empty()))
+ {
+ XmlSubNode set_node(query, "set");
+ set_node["xmlns"] = RSM_NS;
- iq.add_child(std::move(query));
+ if (begin != channel_list.channels.cend())
+ {
+ XmlSubNode first_node(set_node, "first");
+ first_node["index"] = std::to_string(std::distance(channel_list.channels.cbegin(), begin));
+ first_node.set_inner(begin->channel + "@" + this->served_hostname);
+ }
+ if (end != channel_list.channels.cbegin())
+ {
+ XmlSubNode last_node(set_node, "last");
+ last_node.set_inner(std::prev(end)->channel + "@" + this->served_hostname);
+ }
+ if (channel_list.complete)
+ {
+ XmlSubNode count_node(set_node, "count");
+ count_node.set_inner(std::to_string(channel_list.channels.size()));
+ }
+ }
+ }
this->send_stanza(iq);
}
@@ -853,16 +946,36 @@ void BiboumiComponent::send_invitation(const std::string& room_target,
const std::string& author_nick)
{
Stanza message("message");
- message["from"] = room_target + "@" + this->served_hostname;
- message["to"] = jid_to;
- XmlNode x("x");
- x["xmlns"] = MUC_USER_NS;
- XmlNode invite("invite");
- if (author_nick.empty())
- invite["from"] = room_target + "@" + this->served_hostname;
- else
- invite["from"] = room_target + "@" + this->served_hostname + "/" + author_nick;
- x.add_child(std::move(invite));
- message.add_child(std::move(x));
+ {
+ message["from"] = room_target + "@" + this->served_hostname;
+ message["to"] = jid_to;
+ XmlSubNode x(message, "x");
+ x["xmlns"] = MUC_USER_NS;
+ XmlSubNode invite(x, "invite");
+ if (author_nick.empty())
+ invite["from"] = room_target + "@" + this->served_hostname;
+ else
+ invite["from"] = room_target + "@" + this->served_hostname + "/" + author_nick;
+ }
this->send_stanza(message);
}
+
+void BiboumiComponent::accept_subscription(const std::string& from, const std::string& to)
+{
+ Stanza presence("presence");
+ presence["from"] = from;
+ presence["to"] = to;
+ presence["id"] = this->next_id();
+ presence["type"] = "subscribed";
+ this->send_stanza(presence);
+}
+
+void BiboumiComponent::ask_subscription(const std::string& from, const std::string& to)
+{
+ Stanza presence("presence");
+ presence["from"] = from;
+ presence["to"] = to;
+ presence["id"] = this->next_id();
+ presence["type"] = "subscribe";
+ this->send_stanza(presence);
+}
diff --git a/src/xmpp/biboumi_component.hpp b/src/xmpp/biboumi_component.hpp
index 999001f..87311f9 100644
--- a/src/xmpp/biboumi_component.hpp
+++ b/src/xmpp/biboumi_component.hpp
@@ -1,7 +1,8 @@
#pragma once
-
+#include <database/database.hpp>
#include <xmpp/xmpp_component.hpp>
+#include <xmpp/jid.hpp>
#include <bridge/bridge.hpp>
@@ -27,7 +28,7 @@ using iq_responder_callback_t = std::function<void(Bridge* bridge, const Stanza&
class BiboumiComponent: public XmppComponent
{
public:
- explicit BiboumiComponent(std::shared_ptr<Poller> poller, const std::string& hostname, const std::string& secret);
+ explicit BiboumiComponent(std::shared_ptr<Poller>& poller, const std::string& hostname, const std::string& secret);
~BiboumiComponent() = default;
BiboumiComponent(const BiboumiComponent&) = delete;
@@ -69,12 +70,8 @@ public:
* Sends the allowed namespaces in MUC message, according to
* http://xmpp.org/extensions/xep-0045.html#impl-service-traffic
*/
- void send_irc_channel_muc_traffic_info(const std::string id, const std::string& jid_from, const std::string& jid_to);
- /**
- * Send an iq version request
- */
- void send_iq_version_request(const std::string& from,
- const std::string& jid_to);
+ void send_irc_channel_muc_traffic_info(const std::string& id, const std::string& jid_to, const std::string& jid_from);
+ void send_irc_channel_disco_info(const std::string& id, const std::string& jid_to, const std::string& jid_from);
/**
* Send a ping request
*/
@@ -88,6 +85,8 @@ public:
const ChannelList& channel_list, std::vector<ListElement>::const_iterator begin,
std::vector<ListElement>::const_iterator end, const ResultSetInfo& rs_info);
void send_invitation(const std::string& room_target, const std::string& jid_to, const std::string& author_nick);
+ void accept_subscription(const std::string& from, const std::string& to);
+ void ask_subscription(const std::string& from, const std::string& to);
/**
* Handle the various stanza types
*/
@@ -97,8 +96,10 @@ public:
#ifdef USE_DATABASE
bool handle_mam_request(const Stanza& stanza);
- void send_archived_message(const db::MucLogLine& log_line, const std::string& from, const std::string& to,
+ void send_archived_message(const Database::MucLogLine& log_line, const std::string& from, const std::string& to,
const std::string& queryid);
+ bool handle_room_configuration_form_request(const std::string& from, const Jid& to, const std::string& id);
+ bool handle_room_configuration_form(const XmlNode& query, const std::string& from, const Jid& to, const std::string& id);
#endif
/**
diff --git a/src/xmpp/body.hpp b/src/xmpp/body.hpp
new file mode 100644
index 0000000..068d1a4
--- /dev/null
+++ b/src/xmpp/body.hpp
@@ -0,0 +1,12 @@
+#pragma once
+
+
+namespace Xmpp
+{
+// Contains:
+// - an XMPP-valid UTF-8 body
+// - an XML node representing the XHTML-IM body, or null
+ using body = std::tuple<const std::string, std::unique_ptr<XmlNode>>;
+}
+
+
diff --git a/src/xmpp/jid.cpp b/src/xmpp/jid.cpp
new file mode 100644
index 0000000..19d1b55
--- /dev/null
+++ b/src/xmpp/jid.cpp
@@ -0,0 +1,152 @@
+#include <xmpp/jid.hpp>
+#include <algorithm>
+#include <cstring>
+#include <map>
+
+#include <biboumi.h>
+#ifdef LIBIDN_FOUND
+ #include <stringprep.h>
+ #include <sys/types.h>
+ #include <sys/socket.h>
+ #include <netdb.h>
+ #include <utils/scopeguard.hpp>
+ #include <set>
+#endif
+
+#include <logger/logger.hpp>
+
+Jid::Jid(const std::string& jid)
+{
+ std::string::size_type slash = jid.find('/');
+ if (slash != std::string::npos)
+ {
+ this->resource = jid.substr(slash + 1);
+ }
+
+ std::string::size_type at = jid.find('@');
+ if (at != std::string::npos && at < slash)
+ {
+ this->local = jid.substr(0, at);
+ at++;
+ }
+ else
+ at = 0;
+
+ this->domain = jid.substr(at, slash - at);
+}
+
+static constexpr size_t max_jid_part_len = 1023;
+
+std::string jidprep(const std::string& original)
+{
+#ifdef LIBIDN_FOUND
+ using CacheType = std::map<std::string, std::string>;
+ static CacheType cache;
+ std::pair<CacheType::iterator, bool> cached = cache.insert({original, {}});
+ if (std::get<1>(cached) == false)
+ { // Insertion failed: the result is already in the cache, return it
+ return std::get<0>(cached)->second;
+ }
+
+ const std::string error_msg("Failed to convert " + original + " into a valid JID:");
+ Jid jid(original);
+
+ char local[max_jid_part_len] = {};
+ memcpy(local, jid.local.data(), std::min(max_jid_part_len, jid.local.size()));
+ auto rc = static_cast<Stringprep_rc>(::stringprep(local, max_jid_part_len,
+ static_cast<Stringprep_profile_flags>(0), stringprep_xmpp_nodeprep));
+ if (rc != STRINGPREP_OK)
+ {
+ log_error(error_msg + stringprep_strerror(rc));
+ return "";
+ }
+
+ char domain[max_jid_part_len] = {};
+ memcpy(domain, jid.domain.data(), std::min(max_jid_part_len, jid.domain.size()));
+
+ {
+ // Using getaddrinfo, check if the domain part is a valid IPv4 (then use
+ // it as is), or IPv6 (surround it with []), or a domain name (run
+ // nameprep)
+ struct addrinfo hints{};
+ hints.ai_flags = AI_NUMERICHOST;
+ hints.ai_family = AF_UNSPEC;
+
+ struct addrinfo* addr_res = nullptr;
+ const auto ret = ::getaddrinfo(domain, nullptr, &hints, &addr_res);
+ auto addrinfo_deleter = utils::make_scope_guard([addr_res] { if (addr_res) freeaddrinfo(addr_res); });
+ if (ret || !addr_res || (addr_res->ai_family != AF_INET && addr_res->ai_family != AF_INET6))
+ { // Not an IP, run nameprep on it
+ rc = static_cast<Stringprep_rc>(::stringprep(domain, max_jid_part_len,
+ static_cast<Stringprep_profile_flags>(0), stringprep_nameprep));
+ if (rc != STRINGPREP_OK)
+ {
+ log_error(error_msg + stringprep_strerror(rc));
+ return "";
+ }
+
+ // Make sure it contains only allowed characters
+ using std::begin;
+ using std::end;
+ char* domain_end = domain + ::strlen(domain);
+ std::replace_if(std::begin(domain), domain + ::strlen(domain),
+ [](const char c) -> bool
+ {
+ return !((c >= 'a' && c <= 'z') || c == '-' ||
+ (c >= '0' && c <= '9') || c == '.');
+ }, '-');
+ // Make sure there are no doubled - or .
+ std::set<char> special_chars{'-', '.'};
+ domain_end = std::unique(begin(domain), domain + ::strlen(domain), [&special_chars](const char& a, const char& b) -> bool
+ {
+ return special_chars.count(a) && special_chars.count(b);
+ });
+ // remove leading and trailing -. if any
+ if (domain_end != domain && special_chars.count(*(domain_end - 1)))
+ --domain_end;
+ if (domain_end != domain && special_chars.count(domain[0]))
+ {
+ std::memmove(domain, domain + 1, domain_end - domain + 1);
+ --domain_end;
+ }
+ // And if the final result is an empty string, return a dummy hostname
+ if (domain_end == domain)
+ ::strcpy(domain, "empty");
+ else
+ *domain_end = '\0';
+ }
+ else if (addr_res->ai_family == AF_INET6)
+ { // IPv6, surround it with []. The length is always enough:
+ // the longest possible IPv6 is way shorter than max_jid_part_len
+ ::memmove(domain + 1, domain, jid.domain.size());
+ domain[0] = '[';
+ domain[jid.domain.size() + 1] = ']';
+ }
+ }
+
+
+ // If there is no resource, stop here
+ if (jid.resource.empty())
+ {
+ std::get<0>(cached)->second = std::string(local) + "@" + domain;
+ return std::get<0>(cached)->second;
+ }
+
+ // Otherwise, also process the resource part
+ char resource[max_jid_part_len] = {};
+ memcpy(resource, jid.resource.data(), std::min(max_jid_part_len, jid.resource.size()));
+ rc = static_cast<Stringprep_rc>(::stringprep(resource, max_jid_part_len,
+ static_cast<Stringprep_profile_flags>(0), stringprep_xmpp_resourceprep));
+ if (rc != STRINGPREP_OK)
+ {
+ log_error(error_msg + stringprep_strerror(rc));
+ return "";
+ }
+ std::get<0>(cached)->second = std::string(local) + "@" + domain + "/" + resource;
+ return std::get<0>(cached)->second;
+
+#else
+ (void)original;
+ return "";
+#endif
+}
diff --git a/src/xmpp/jid.hpp b/src/xmpp/jid.hpp
new file mode 100644
index 0000000..85e835c
--- /dev/null
+++ b/src/xmpp/jid.hpp
@@ -0,0 +1,49 @@
+#pragma once
+
+
+#include <string>
+
+/**
+ * Parse a JID into its different subart
+ */
+class Jid
+{
+public:
+ explicit Jid(const std::string& jid);
+
+ Jid(const Jid&) = delete;
+ Jid(Jid&&) = delete;
+ Jid& operator=(const Jid&) = delete;
+ Jid& operator=(Jid&&) = delete;
+
+ std::string domain;
+ std::string local;
+ std::string resource;
+
+ std::string bare() const
+ {
+ return this->local + "@" + this->domain;
+ }
+ std::string full() const
+ {
+ std::string res = this->domain;
+ if (!this->local.empty())
+ res = this->local + "@" + this->domain;
+ if (!this->resource.empty())
+ res += "/" + this->resource;
+ return res;
+ }
+};
+
+/**
+ * Prepare the given UTF-8 string according to the XMPP node stringprep
+ * identifier profile. This is used to send properly-formed JID to the XMPP
+ * server.
+ *
+ * If the stringprep library is not found, we return an empty string. When
+ * this function is used, the result must always be checked for an empty
+ * value, and if this is the case it must not be used as a JID.
+ */
+std::string jidprep(const std::string& original);
+
+
diff --git a/src/xmpp/xmpp_component.cpp b/src/xmpp/xmpp_component.cpp
new file mode 100644
index 0000000..b138ed9
--- /dev/null
+++ b/src/xmpp/xmpp_component.cpp
@@ -0,0 +1,684 @@
+#include <utils/timed_events.hpp>
+#include <utils/scopeguard.hpp>
+#include <utils/tolower.hpp>
+#include <logger/logger.hpp>
+
+#include <xmpp/xmpp_component.hpp>
+#include <config/config.hpp>
+#include <utils/system.hpp>
+#include <utils/time.hpp>
+#include <xmpp/auth.hpp>
+#include <xmpp/jid.hpp>
+
+#include <stdexcept>
+#include <iostream>
+#include <set>
+
+#include <uuid/uuid.h>
+
+#include <cstdlib>
+#include <set>
+
+#include <biboumi.h>
+#ifdef SYSTEMD_FOUND
+# include <systemd/sd-daemon.h>
+#endif
+
+using namespace std::string_literals;
+
+static std::set<std::string> kickable_errors{
+ "gone",
+ "internal-server-error",
+ "item-not-found",
+ "jid-malformed",
+ "recipient-unavailable",
+ "redirect",
+ "remote-server-not-found",
+ "remote-server-timeout",
+ "service-unavailable",
+ "malformed-error"
+ };
+
+XmppComponent::XmppComponent(std::shared_ptr<Poller>& poller, std::string hostname, std::string secret):
+ TCPClientSocketHandler(poller),
+ ever_auth(false),
+ first_connection_try(true),
+ secret(std::move(secret)),
+ authenticated(false),
+ doc_open(false),
+ served_hostname(std::move(hostname)),
+ stanza_handlers{},
+ adhoc_commands_handler(*this)
+{
+ this->parser.add_stream_open_callback(std::bind(&XmppComponent::on_remote_stream_open, this,
+ std::placeholders::_1));
+ this->parser.add_stanza_callback(std::bind(&XmppComponent::on_stanza, this,
+ std::placeholders::_1));
+ this->parser.add_stream_close_callback(std::bind(&XmppComponent::on_remote_stream_close, this,
+ std::placeholders::_1));
+ this->stanza_handlers.emplace("handshake",
+ std::bind(&XmppComponent::handle_handshake, this,std::placeholders::_1));
+ this->stanza_handlers.emplace("error",
+ std::bind(&XmppComponent::handle_error, this,std::placeholders::_1));
+}
+
+void XmppComponent::start()
+{
+ this->connect(Config::get("xmpp_server_ip", "127.0.0.1"), Config::get("port", "5347"), false);
+}
+
+bool XmppComponent::is_document_open() const
+{
+ return this->doc_open;
+}
+
+void XmppComponent::send_stanza(const Stanza& stanza)
+{
+ std::string str = stanza.to_string();
+ log_debug("XMPP SENDING: ", str);
+ this->send_data(std::move(str));
+}
+
+void XmppComponent::on_connection_failed(const std::string& reason)
+{
+ this->first_connection_try = false;
+ log_error("Failed to connect to the XMPP server: ", reason);
+#ifdef SYSTEMD_FOUND
+ sd_notifyf(0, "STATUS=Failed to connect to the XMPP server: %s", reason.data());
+#endif
+}
+
+void XmppComponent::on_connected()
+{
+ log_info("connected to XMPP server");
+ this->first_connection_try = true;
+ auto data = "<stream:stream to='"s + this->served_hostname + \
+ "' xmlns:stream='http://etherx.jabber.org/streams' xmlns='" COMPONENT_NS "'>";
+ log_debug("XMPP SENDING: ", data);
+ this->send_data(std::move(data));
+ this->doc_open = true;
+ // We may have some pending data to send: this happens when we try to send
+ // some data before we are actually connected. We send that data right now, if any
+ this->send_pending_data();
+}
+
+void XmppComponent::on_connection_close(const std::string& error)
+{
+ if (error.empty())
+ log_info("XMPP server closed connection");
+ else
+ log_info("XMPP server closed connection: ", error);
+}
+
+void XmppComponent::parse_in_buffer(const size_t size)
+{
+ // in_buf.size, or size, cannot be bigger than our read-size (4096) so it’s safe
+ // to cast.
+
+ if (!this->in_buf.empty())
+ { // This may happen if the parser could not allocate enough space for
+ // us. We try to feed it the data that was read into our in_buf
+ // instead. If this fails again we are in trouble.
+ this->parser.feed(this->in_buf.data(), static_cast<int>(this->in_buf.size()), false);
+ this->in_buf.clear();
+ }
+ else
+ { // Just tell the parser to parse the data that was placed into the
+ // buffer it provided to us with GetBuffer
+ this->parser.parse(static_cast<int>(size), false);
+ }
+}
+
+void XmppComponent::on_remote_stream_open(const XmlNode& node)
+{
+ log_debug("XMPP RECEIVING: ", node.to_string());
+ this->stream_id = node.get_tag("id");
+ if (this->stream_id.empty())
+ {
+ log_error("Error: no attribute 'id' found");
+ this->send_stream_error("bad-format", "missing 'id' attribute");
+ this->close_document();
+ return ;
+ }
+
+ // Try to authenticate
+ auto data = "<handshake xmlns='" COMPONENT_NS "'>"s + get_handshake_digest(this->stream_id, this->secret) + "</handshake>";
+ log_debug("XMPP SENDING: ", data);
+ this->send_data(std::move(data));
+}
+
+void XmppComponent::on_remote_stream_close(const XmlNode& node)
+{
+ log_debug("XMPP RECEIVING: ", node.to_string());
+ this->doc_open = false;
+}
+
+void XmppComponent::reset()
+{
+ this->parser.reset();
+}
+
+void XmppComponent::on_stanza(const Stanza& stanza)
+{
+ log_debug("XMPP RECEIVING: ", stanza.to_string());
+ std::function<void(const Stanza&)> handler;
+ try
+ {
+ handler = this->stanza_handlers.at(stanza.get_name());
+ }
+ catch (const std::out_of_range& exception)
+ {
+ log_warning("No handler for stanza of type ", stanza.get_name());
+ return;
+ }
+ handler(stanza);
+}
+
+void XmppComponent::send_stream_error(const std::string& name, const std::string& explanation)
+{
+ Stanza node("stream:error");
+ {
+ XmlSubNode error(node, name);
+ error["xmlns"] = STREAM_NS;
+ if (!explanation.empty())
+ error.set_inner(explanation);
+ }
+ this->send_stanza(node);
+}
+
+void XmppComponent::send_stanza_error(const std::string& kind, const std::string& to, const std::string& from,
+ const std::string& id, const std::string& error_type,
+ const std::string& defined_condition, const std::string& text,
+ const bool fulljid)
+{
+ Stanza node(kind);
+ {
+ if (!to.empty())
+ node["to"] = to;
+ if (!from.empty())
+ {
+ if (fulljid)
+ node["from"] = from;
+ else
+ node["from"] = from + "@" + this->served_hostname;
+ }
+ if (!id.empty())
+ node["id"] = id;
+ node["type"] = "error";
+ {
+ XmlSubNode error(node, "error");
+ error["type"] = error_type;
+ {
+ XmlSubNode inner_error(error, defined_condition);
+ inner_error["xmlns"] = STANZA_NS;
+ }
+ if (!text.empty())
+ {
+ XmlSubNode text_node(error, "text");
+ text_node["xmlns"] = STANZA_NS;
+ text_node.set_inner(text);
+ }
+ }
+ }
+ this->send_stanza(node);
+}
+
+void XmppComponent::close_document()
+{
+ log_debug("XMPP SENDING: </stream:stream>");
+ this->send_data("</stream:stream>");
+ this->doc_open = false;
+}
+
+void XmppComponent::handle_handshake(const Stanza&)
+{
+ this->authenticated = true;
+ this->ever_auth = true;
+ log_info("Authenticated with the XMPP server");
+#ifdef SYSTEMD_FOUND
+ sd_notify(0, "READY=1");
+ // Install an event that sends a keepalive to systemd. If biboumi crashes
+ // or hangs for too long, systemd will restart it.
+ uint64_t usec;
+ if (sd_watchdog_enabled(0, &usec) > 0)
+ {
+ TimedEventsManager::instance().add_event(TimedEvent(
+ std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::microseconds(usec / 2)),
+ []() { sd_notify(0, "WATCHDOG=1"); }));
+ }
+#endif
+ this->after_handshake();
+}
+
+void XmppComponent::handle_error(const Stanza& stanza)
+{
+ const XmlNode* text = stanza.get_child("text", STREAMS_NS);
+ std::string error_message("Unspecified error");
+ if (text)
+ error_message = text->get_inner();
+ log_error("Stream error received from the XMPP server: ", error_message);
+#ifdef SYSTEMD_FOUND
+ if (!this->ever_auth)
+ sd_notifyf(0, "STATUS=Failed to authenticate to the XMPP server: %s", error_message.data());
+#endif
+
+}
+
+void* XmppComponent::get_receive_buffer(const size_t size) const
+{
+ return this->parser.get_buffer(size);
+}
+
+void XmppComponent::send_message(const std::string& from, Xmpp::body&& body, const std::string& to,
+ const std::string& type, const bool fulljid, const bool nocopy)
+{
+ Stanza message("message");
+ {
+ message["to"] = to;
+ if (fulljid)
+ message["from"] = from;
+ else
+ message["from"] = from + "@" + this->served_hostname;
+ if (!type.empty())
+ message["type"] = type;
+ XmlSubNode body_node(message, "body");
+ body_node.set_inner(std::get<0>(body));
+ if (std::get<1>(body))
+ {
+ XmlSubNode html(message, "html");
+ html["xmlns"] = XHTMLIM_NS;
+ // Pass the ownership of the pointer to this xmlnode
+ html.add_child(std::move(std::get<1>(body)));
+ }
+ if (nocopy)
+ {
+ XmlSubNode private_node(message, "private");
+ private_node["xmlns"] = "urn:xmpp:carbons:2";
+ XmlSubNode nocopy(message, "no-copy");
+ nocopy["xmlns"] = "urn:xmpp:hints";
+ }
+ }
+ this->send_stanza(message);
+}
+
+void XmppComponent::send_user_join(const std::string& from,
+ const std::string& nick,
+ const std::string& realjid,
+ const std::string& affiliation,
+ const std::string& role,
+ const std::string& to,
+ const bool self)
+{
+ Stanza presence("presence");
+ {
+ presence["to"] = to;
+ presence["from"] = from + "@" + this->served_hostname + "/" + nick;
+
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+
+ XmlSubNode item(x, "item");
+ if (!affiliation.empty())
+ item["affiliation"] = affiliation;
+ if (!role.empty())
+ item["role"] = role;
+ if (!realjid.empty())
+ {
+ const std::string preped_jid = jidprep(realjid);
+ if (!preped_jid.empty())
+ item["jid"] = preped_jid;
+ }
+
+ if (self)
+ {
+ XmlSubNode status(x, "status");
+ status["code"] = "110";
+ }
+ }
+ this->send_stanza(presence);
+}
+
+void XmppComponent::send_invalid_room_error(const std::string& muc_name,
+ const std::string& nick,
+ const std::string& to)
+{
+ Stanza presence("presence");
+ {
+ if (!muc_name.empty ())
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + nick;
+ else
+ presence["from"] = this->served_hostname;
+ presence["to"] = to;
+ presence["type"] = "error";
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_NS;
+ XmlSubNode error(presence, "error");
+ error["by"] = muc_name + "@" + this->served_hostname;
+ error["type"] = "cancel";
+ XmlSubNode item_not_found(error, "item-not-found");
+ item_not_found["xmlns"] = STANZA_NS;
+ XmlSubNode text(error, "text");
+ text["xmlns"] = STANZA_NS;
+ text["xml:lang"] = "en";
+ text.set_inner(muc_name +
+ " is not a valid IRC channel name. A correct room jid is of the form: #<chan>%<server>@" +
+ this->served_hostname);
+ }
+ this->send_stanza(presence);
+}
+
+void XmppComponent::send_topic(const std::string& from, Xmpp::body&& topic, const std::string& to, const std::string& who)
+{
+ Stanza message("message");
+ {
+ message["to"] = to;
+ if (who.empty())
+ message["from"] = from + "@" + this->served_hostname;
+ else
+ message["from"] = from + "@" + this->served_hostname + "/" + who;
+ message["type"] = "groupchat";
+ XmlSubNode subject(message, "subject");
+ subject.set_inner(std::get<0>(topic));
+ }
+ this->send_stanza(message);
+}
+
+void XmppComponent::send_muc_message(const std::string& muc_name, const std::string& nick, Xmpp::body&& xmpp_body, const std::string& jid_to, std::string uuid)
+{
+ Stanza message("message");
+ message["to"] = jid_to;
+ if (!nick.empty())
+ message["from"] = muc_name + "@" + this->served_hostname + "/" + nick;
+ else // Message from the room itself
+ message["from"] = muc_name + "@" + this->served_hostname;
+ message["type"] = "groupchat";
+
+ {
+ XmlSubNode body(message, "body");
+ body.set_inner(std::get<0>(xmpp_body));
+ }
+
+ if (std::get<1>(xmpp_body))
+ {
+ XmlSubNode html(message, "html");
+ html["xmlns"] = XHTMLIM_NS;
+ // Pass the ownership of the pointer to this xmlnode
+ html.add_child(std::move(std::get<1>(xmpp_body)));
+ }
+
+ if (!uuid.empty())
+ {
+ XmlSubNode stanza_id(message, "stanza-id");
+ stanza_id["xmlns"] = STABLE_ID_NS;
+ stanza_id["by"] = muc_name + "@" + this->served_hostname;
+ stanza_id["id"] = std::move(uuid);
+ }
+
+ this->send_stanza(message);
+}
+
+void XmppComponent::send_history_message(const std::string& muc_name, const std::string& nick, const std::string& body_txt, const std::string& jid_to, std::time_t timestamp)
+{
+ Stanza message("message");
+ message["to"] = jid_to;
+ if (!nick.empty())
+ message["from"] = muc_name + "@" + this->served_hostname + "/" + nick;
+ else
+ message["from"] = muc_name + "@" + this->served_hostname;
+ message["type"] = "groupchat";
+
+ {
+ XmlSubNode body(message, "body");
+ body.set_inner(body_txt);
+ }
+ {
+ XmlSubNode delay(message, "delay");
+ delay["xmlns"] = DELAY_NS;
+ delay["from"] = muc_name + "@" + this->served_hostname;
+ delay["stamp"] = utils::to_string(timestamp);
+ }
+
+ this->send_stanza(message);
+}
+
+void XmppComponent::send_muc_leave(const std::string& muc_name, const std::string& nick, Xmpp::body&& message, const std::string& jid_to, const bool self)
+{
+ Stanza presence("presence");
+ {
+ presence["to"] = jid_to;
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + nick;
+ presence["type"] = "unavailable";
+ const std::string& message_str = std::get<0>(message);
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+ if (self)
+ {
+ XmlSubNode status(x, "status");
+ status["code"] = "110";
+ }
+ if (!message_str.empty())
+ {
+ XmlSubNode status(presence, "status");
+ status.set_inner(message_str);
+ }
+ }
+ this->send_stanza(presence);
+}
+
+void XmppComponent::send_nick_change(const std::string& muc_name,
+ const std::string& old_nick,
+ const std::string& new_nick,
+ const std::string& affiliation,
+ const std::string& role,
+ const std::string& jid_to,
+ const bool self)
+{
+ Stanza presence("presence");
+ {
+ presence["to"] = jid_to;
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + old_nick;
+ presence["type"] = "unavailable";
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+ XmlSubNode item(x, "item");
+ item["nick"] = new_nick;
+ XmlSubNode status(x, "status");
+ status["code"] = "303";
+ if (self)
+ {
+ XmlSubNode status(x, "status");
+ status["code"] = "110";
+ }
+ }
+ this->send_stanza(presence);
+
+ this->send_user_join(muc_name, new_nick, "", affiliation, role, jid_to, self);
+}
+
+void XmppComponent::kick_user(const std::string& muc_name, const std::string& target, const std::string& txt,
+ const std::string& author, const std::string& jid_to, const bool self)
+{
+ Stanza presence("presence");
+ {
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + target;
+ presence["to"] = jid_to;
+ presence["type"] = "unavailable";
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+ XmlSubNode item(x, "item");
+ item["affiliation"] = "none";
+ item["role"] = "none";
+ XmlSubNode actor(item, "actor");
+ actor["nick"] = author;
+ actor["jid"] = author; // backward compatibility with old clients
+ XmlSubNode reason(item, "reason");
+ reason.set_inner(txt);
+ XmlSubNode status(x, "status");
+ status["code"] = "307";
+ if (self)
+ {
+ XmlSubNode status(x, "status");
+ status["code"] = "110";
+ }
+ }
+ this->send_stanza(presence);
+}
+
+void XmppComponent::send_presence_error(const std::string& muc_name,
+ const std::string& nickname,
+ const std::string& jid_to,
+ const std::string& type,
+ const std::string& condition,
+ const std::string& error_code,
+ const std::string& text)
+{
+ Stanza presence("presence");
+ {
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + nickname;
+ presence["to"] = jid_to;
+ presence["type"] = "error";
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_NS;
+ XmlSubNode error(presence, "error");
+ error["by"] = muc_name + "@" + this->served_hostname;
+ error["type"] = type;
+ if (!text.empty())
+ {
+ XmlSubNode text_node(error, "text");
+ text_node["xmlns"] = STANZA_NS;
+ text_node.set_inner(text);
+ }
+ if (!error_code.empty())
+ error["code"] = error_code;
+ XmlSubNode subnode(error, condition);
+ subnode["xmlns"] = STANZA_NS;
+ }
+ this->send_stanza(presence);
+}
+
+void XmppComponent::send_affiliation_role_change(const std::string& muc_name,
+ const std::string& target,
+ const std::string& affiliation,
+ const std::string& role,
+ const std::string& jid_to)
+{
+ Stanza presence("presence");
+ {
+ presence["from"] = muc_name + "@" + this->served_hostname + "/" + target;
+ presence["to"] = jid_to;
+ XmlSubNode x(presence, "x");
+ x["xmlns"] = MUC_USER_NS;
+ XmlSubNode item(x, "item");
+ item["affiliation"] = affiliation;
+ item["role"] = role;
+ }
+ this->send_stanza(presence);
+}
+
+void XmppComponent::send_version(const std::string& id, const std::string& jid_to, const std::string& jid_from,
+ const std::string& version)
+{
+ Stanza iq("iq");
+ iq["type"] = "result";
+ iq["id"] = id;
+ iq["to"] = jid_to;
+ iq["from"] = jid_from;
+ {
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = VERSION_NS;
+ if (version.empty())
+ {
+ {
+ XmlSubNode name(query, "name");
+ name.set_inner("biboumi");
+ }
+ {
+ XmlSubNode version(query, "version");
+ version.set_inner(SOFTWARE_VERSION);
+ }
+ {
+ XmlSubNode os(query, "os");
+ os.set_inner(utils::get_system_name());
+ }
+ }
+ else
+ {
+ XmlSubNode name(query, "name");
+ name.set_inner(version);
+ }
+ }
+ this->send_stanza(iq);
+}
+
+void XmppComponent::send_adhoc_commands_list(const std::string& id, const std::string& requester_jid,
+ const std::string& from_jid,
+ const bool with_admin_only, const AdhocCommandsHandler& adhoc_handler)
+{
+ Stanza iq("iq");
+ {
+ iq["type"] = "result";
+ iq["id"] = id;
+ iq["to"] = requester_jid;
+ iq["from"] = from_jid;
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = DISCO_ITEMS_NS;
+ query["node"] = ADHOC_NS;
+ for (const auto &kv: adhoc_handler.get_commands())
+ {
+ if (kv.second.is_admin_only() && !with_admin_only)
+ continue;
+ XmlSubNode item(query, "item");
+ item["jid"] = from_jid;
+ item["node"] = kv.first;
+ item["name"] = kv.second.name;
+ }
+ }
+ this->send_stanza(iq);
+}
+
+void XmppComponent::send_iq_version_request(const std::string& from,
+ const std::string& jid_to)
+{
+ Stanza iq("iq");
+ {
+ iq["type"] = "get";
+ iq["id"] = "version_"s + XmppComponent::next_id();
+ iq["from"] = from + "@" + this->served_hostname;
+ iq["to"] = jid_to;
+ XmlSubNode query(iq, "query");
+ query["xmlns"] = VERSION_NS;
+ }
+ this->send_stanza(iq);
+}
+
+void XmppComponent::send_iq_result_full_jid(const std::string& id, const std::string& to_jid, const std::string& from_full_jid)
+{
+ Stanza iq("iq");
+ iq["from"] = from_full_jid;
+ iq["to"] = to_jid;
+ iq["id"] = id;
+ iq["type"] = "result";
+ this->send_stanza(iq);
+}
+
+void XmppComponent::send_iq_result(const std::string& id, const std::string& to_jid, const std::string& from_local_part)
+{
+ Stanza iq("iq");
+ if (!from_local_part.empty())
+ iq["from"] = from_local_part + "@" + this->served_hostname;
+ else
+ iq["from"] = this->served_hostname;
+ iq["to"] = to_jid;
+ iq["id"] = id;
+ iq["type"] = "result";
+ this->send_stanza(iq);
+}
+
+std::string XmppComponent::next_id()
+{
+ char uuid_str[37];
+ uuid_t uuid;
+ uuid_generate(uuid);
+ uuid_unparse(uuid, uuid_str);
+ return uuid_str;
+}
diff --git a/src/xmpp/xmpp_component.hpp b/src/xmpp/xmpp_component.hpp
new file mode 100644
index 0000000..ebe3ec8
--- /dev/null
+++ b/src/xmpp/xmpp_component.hpp
@@ -0,0 +1,248 @@
+#pragma once
+
+
+#include <xmpp/adhoc_commands_handler.hpp>
+#include <network/tcp_client_socket_handler.hpp>
+#include <xmpp/xmpp_parser.hpp>
+#include <xmpp/body.hpp>
+
+#include <unordered_map>
+#include <memory>
+#include <string>
+#include <ctime>
+#include <map>
+
+#define STREAM_NS "http://etherx.jabber.org/streams"
+#define COMPONENT_NS "jabber:component:accept"
+#define MUC_NS "http://jabber.org/protocol/muc"
+#define MUC_USER_NS MUC_NS"#user"
+#define MUC_ADMIN_NS MUC_NS"#admin"
+#define MUC_OWNER_NS MUC_NS"#owner"
+#define DISCO_NS "http://jabber.org/protocol/disco"
+#define DISCO_ITEMS_NS DISCO_NS"#items"
+#define DISCO_INFO_NS DISCO_NS"#info"
+#define XHTMLIM_NS "http://jabber.org/protocol/xhtml-im"
+#define STANZA_NS "urn:ietf:params:xml:ns:xmpp-stanzas"
+#define STREAMS_NS "urn:ietf:params:xml:ns:xmpp-streams"
+#define VERSION_NS "jabber:iq:version"
+#define ADHOC_NS "http://jabber.org/protocol/commands"
+#define PING_NS "urn:xmpp:ping"
+#define DELAY_NS "urn:xmpp:delay"
+#define MAM_NS "urn:xmpp:mam:2"
+#define FORWARD_NS "urn:xmpp:forward:0"
+#define CLIENT_NS "jabber:client"
+#define DATAFORM_NS "jabber:x:data"
+#define RSM_NS "http://jabber.org/protocol/rsm"
+#define MUC_TRAFFIC_NS "http://jabber.org/protocol/muc#traffic"
+#define STABLE_ID_NS "urn:xmpp:sid:0"
+
+/**
+ * An XMPP component, communicating with an XMPP server using the protocole
+ * described in XEP-0114: Jabber Component Protocol
+ *
+ * TODO: implement XEP-0225: Component Connections
+ */
+class XmppComponent: public TCPClientSocketHandler
+{
+public:
+ explicit XmppComponent(std::shared_ptr<Poller>& poller, std::string hostname, std::string secret);
+ virtual ~XmppComponent() = default;
+
+ XmppComponent(const XmppComponent&) = delete;
+ XmppComponent(XmppComponent&&) = delete;
+ XmppComponent& operator=(const XmppComponent&) = delete;
+ XmppComponent& operator=(XmppComponent&&) = delete;
+
+ void on_connection_failed(const std::string& reason) override final;
+ void on_connected() override final;
+ void on_connection_close(const std::string& error) override final;
+ void parse_in_buffer(const size_t size) override final;
+
+ /**
+ * Returns a unique id, to be used in the 'id' element of our iq stanzas.
+ */
+ static std::string next_id();
+ bool is_document_open() const;
+ /**
+ * Connect to the XMPP server.
+ */
+ void start();
+ /**
+ * Reset the component so we can use the component on a new XMPP stream
+ */
+ void reset();
+ /**
+ * Serialize the stanza and add it to the out_buf to be sent to the
+ * server.
+ */
+ void send_stanza(const Stanza& stanza);
+ /**
+ * Handle the opening of the remote stream
+ */
+ void on_remote_stream_open(const XmlNode& node);
+ /**
+ * Handle the closing of the remote stream
+ */
+ void on_remote_stream_close(const XmlNode& node);
+ /**
+ * Handle received stanzas
+ */
+ void on_stanza(const Stanza& stanza);
+ /**
+ * Send an error stanza. Message being the name of the element inside the
+ * stanza, and explanation being a short human-readable sentence
+ * describing the error.
+ */
+ void send_stream_error(const std::string& name, const std::string& explanation);
+ /**
+ * Send error stanza, described in http://xmpp.org/rfcs/rfc6120.html#stanzas-error
+ */
+ void send_stanza_error(const std::string& kind, const std::string& to, const std::string& from,
+ const std::string& id, const std::string& error_type,
+ const std::string& defined_condition, const std::string& text,
+ const bool fulljid=true);
+ /**
+ * Send the closing signal for our document (not closing the connection though).
+ */
+ void close_document();
+ /**
+ * Send a message from from@served_hostname, with the given body
+ *
+ * If fulljid is false, the provided 'from' doesn't contain the
+ * server-part of the JID and must be added.
+ */
+ void send_message(const std::string& from, Xmpp::body&& body, const std::string& to,
+ const std::string& type, const bool fulljid, const bool nocopy=false);
+ /**
+ * Send a join from a new participant
+ */
+ void send_user_join(const std::string& from,
+ const std::string& nick,
+ const std::string& realjid,
+ const std::string& affiliation,
+ const std::string& role,
+ const std::string& to,
+ const bool self);
+ /**
+ * Send an error to indicate that the user tried to join an invalid room
+ */
+ void send_invalid_room_error(const std::string& muc_jid,
+ const std::string& nick,
+ const std::string& to);
+ /**
+ * Send the MUC topic to the user
+ */
+ void send_topic(const std::string& from, Xmpp::body&& xmpp_topic, const std::string& to, const std::string& who);
+ /**
+ * Send a (non-private) message to the MUC
+ */
+ void send_muc_message(const std::string& muc_name, const std::string& nick, Xmpp::body&& body, const std::string& jid_to,
+ std::string uuid);
+ /**
+ * Send a message, with a <delay/> element, part of a MUC history
+ */
+ void send_history_message(const std::string& muc_name, const std::string& nick, const std::string& body,
+ const std::string& jid_to, const std::time_t timestamp);
+ /**
+ * Send an unavailable presence for this nick
+ */
+ void send_muc_leave(const std::string& muc_name, const std::string& nick, Xmpp::body&& message, const std::string& jid_to, const bool self);
+ /**
+ * Indicate that a participant changed his nick
+ */
+ void send_nick_change(const std::string& muc_name,
+ const std::string& old_nick,
+ const std::string& new_nick,
+ const std::string& affiliation,
+ const std::string& role,
+ const std::string& jid_to,
+ const bool self);
+ /**
+ * An user is kicked from a room
+ */
+ void kick_user(const std::string& muc_name, const std::string& target, const std::string& reason,
+ const std::string& author, const std::string& jid_to, const bool self);
+ /**
+ * Send a generic presence error
+ */
+ void send_presence_error(const std::string& muc_name,
+ const std::string& nickname,
+ const std::string& jid_to,
+ const std::string& type,
+ const std::string& condition,
+ const std::string& error_code,
+ const std::string& text);
+ /**
+ * Send a presence from the MUC indicating a change in the role and/or
+ * affiliation of a participant
+ */
+ void send_affiliation_role_change(const std::string& muc_name,
+ const std::string& target,
+ const std::string& affiliation,
+ const std::string& role,
+ const std::string& jid_to);
+ /**
+ * Send a result IQ with the given version, or the gateway version if the
+ * passed string is empty.
+ */
+ void send_version(const std::string& id, const std::string& jid_to, const std::string& jid_from,
+ const std::string& version="");
+ /**
+ * Send the list of all available ad-hoc commands to that JID. The list is
+ * different depending on what JID made the request.
+ */
+ void send_adhoc_commands_list(const std::string& id, const std::string& requester_jid, const std::string& from_jid,
+ const bool with_admin_only, const AdhocCommandsHandler& adhoc_handler);
+ /**
+ * Send an iq version request
+ */
+ void send_iq_version_request(const std::string& from,
+ const std::string& jid_to);
+ /**
+ * Send an empty iq of type result
+ */
+ void send_iq_result(const std::string& id, const std::string& to_jid, const std::string& from);
+ void send_iq_result_full_jid(const std::string& id, const std::string& to_jid,
+ const std::string& from_full_jid);
+
+ void handle_handshake(const Stanza& stanza);
+ void handle_error(const Stanza& stanza);
+
+ virtual void after_handshake() {}
+
+ const std::string& get_served_hostname() const
+ { return this->served_hostname; }
+
+ /**
+ * Whether or not we ever succeeded our authentication to the XMPP server
+ */
+ bool ever_auth;
+ /**
+ * Whether or not this is the first consecutive try on connecting to the
+ * XMPP server. We use this to delay the connection attempt for a few
+ * seconds, if it is not the first try.
+ */
+ bool first_connection_try;
+
+private:
+ /**
+ * Return a buffer provided by the XML parser, to read data directly into
+ * it, and avoiding some unnecessary copy.
+ */
+ void* get_receive_buffer(const size_t size) const override final;
+ XmppParser parser;
+ std::string stream_id;
+ std::string secret;
+ bool authenticated;
+ /**
+ * Whether or not OUR XMPP document is open
+ */
+ bool doc_open;
+protected:
+ std::string served_hostname;
+
+ std::unordered_map<std::string, std::function<void(const Stanza&)>> stanza_handlers;
+ AdhocCommandsHandler adhoc_commands_handler;
+};
+
+
diff --git a/src/xmpp/xmpp_parser.cpp b/src/xmpp/xmpp_parser.cpp
new file mode 100644
index 0000000..0488be9
--- /dev/null
+++ b/src/xmpp/xmpp_parser.cpp
@@ -0,0 +1,172 @@
+#include <xmpp/xmpp_parser.hpp>
+#include <xmpp/xmpp_stanza.hpp>
+
+#include <logger/logger.hpp>
+
+/**
+ * Expat handlers. Called by the Expat library, never by ourself.
+ * They just forward the call to the XmppParser corresponding methods.
+ */
+
+static void start_element_handler(void* user_data, const XML_Char* name, const XML_Char** atts)
+{
+ static_cast<XmppParser*>(user_data)->start_element(name, atts);
+}
+
+static void end_element_handler(void* user_data, const XML_Char* name)
+{
+ static_cast<XmppParser*>(user_data)->end_element(name);
+}
+
+static void character_data_handler(void *user_data, const XML_Char *s, int len)
+{
+ static_cast<XmppParser*>(user_data)->char_data(s, len);
+}
+
+/**
+ * XmppParser class
+ */
+
+XmppParser::XmppParser():
+ level(0),
+ current_node(nullptr),
+ root(nullptr)
+{
+ this->init_xml_parser();
+}
+
+void XmppParser::init_xml_parser()
+{
+ // Create the expat parser
+ this->parser = XML_ParserCreateNS("UTF-8", ':');
+ XML_SetUserData(this->parser, static_cast<void*>(this));
+
+ // Install Expat handlers
+ XML_SetElementHandler(this->parser, &start_element_handler, &end_element_handler);
+ XML_SetCharacterDataHandler(this->parser, &character_data_handler);
+}
+
+XmppParser::~XmppParser()
+{
+ XML_ParserFree(this->parser);
+}
+
+int XmppParser::feed(const char* data, const int len, const bool is_final)
+{
+ int res = XML_Parse(this->parser, data, len, is_final);
+ if (res == XML_STATUS_ERROR &&
+ (XML_GetErrorCode(this->parser) != XML_ERROR_FINISHED))
+ log_error("Xml_Parse encountered an error: ",
+ XML_ErrorString(XML_GetErrorCode(this->parser)));
+ return res;
+}
+
+int XmppParser::parse(const int len, const bool is_final)
+{
+ int res = XML_ParseBuffer(this->parser, len, is_final);
+ if (res == XML_STATUS_ERROR)
+ log_error("Xml_Parsebuffer encountered an error: ",
+ XML_ErrorString(XML_GetErrorCode(this->parser)));
+ return res;
+}
+
+void XmppParser::reset()
+{
+ XML_ParserFree(this->parser);
+ this->init_xml_parser();
+ this->current_node = nullptr;
+ this->root.reset(nullptr);
+ this->level = 0;
+}
+
+void* XmppParser::get_buffer(const size_t size) const
+{
+ return XML_GetBuffer(this->parser, static_cast<int>(size));
+}
+
+void XmppParser::start_element(const XML_Char* name, const XML_Char** attribute)
+{
+ this->level++;
+
+ auto new_node = std::make_unique<XmlNode>(name, this->current_node);
+ auto new_node_ptr = new_node.get();
+ if (this->current_node)
+ this->current_node->add_child(std::move(new_node));
+ else
+ this->root = std::move(new_node);
+ this->current_node = new_node_ptr;
+ for (size_t i = 0; attribute[i]; i += 2)
+ this->current_node->set_attribute(attribute[i], attribute[i+1]);
+ if (this->level == 1)
+ this->stream_open_event(*this->current_node);
+}
+
+void XmppParser::end_element(const XML_Char*)
+{
+ this->level--;
+ if (this->level == 0)
+ { // End of the whole stream
+ this->stream_close_event(*this->current_node);
+ this->current_node = nullptr;
+ this->root.reset();
+ }
+ else
+ {
+ auto parent = this->current_node->get_parent();
+ if (this->level == 1)
+ { // End of a stanza
+ this->stanza_event(*this->current_node);
+ // Note: deleting all the children of our parent deletes ourself,
+ // so current_node is an invalid pointer after this line
+ parent->delete_all_children();
+ }
+ this->current_node = parent;
+ }
+}
+
+void XmppParser::char_data(const XML_Char* data, const size_t len)
+{
+ if (this->current_node->has_children())
+ this->current_node->get_last_child()->add_to_tail({data, len});
+ else
+ this->current_node->add_to_inner({data, len});
+}
+
+void XmppParser::stanza_event(const Stanza& stanza) const
+{
+ for (const auto& callback: this->stanza_callbacks)
+ {
+ try {
+ callback(stanza);
+ } catch (const std::exception& e) {
+ log_error("Unhandled exception: ", e.what());
+ }
+ }
+}
+
+void XmppParser::stream_open_event(const XmlNode& node) const
+{
+ for (const auto& callback: this->stream_open_callbacks)
+ callback(node);
+}
+
+void XmppParser::stream_close_event(const XmlNode& node) const
+{
+ for (const auto& callback: this->stream_close_callbacks)
+ callback(node);
+}
+
+void XmppParser::add_stanza_callback(std::function<void(const Stanza&)>&& callback)
+{
+ this->stanza_callbacks.emplace_back(std::move(callback));
+}
+
+void XmppParser::add_stream_open_callback(std::function<void(const XmlNode&)>&& callback)
+{
+ this->stream_open_callbacks.emplace_back(std::move(callback));
+}
+
+void XmppParser::add_stream_close_callback(std::function<void(const XmlNode&)>&& callback)
+{
+ this->stream_close_callbacks.emplace_back(std::move(callback));
+}
diff --git a/src/xmpp/xmpp_parser.hpp b/src/xmpp/xmpp_parser.hpp
new file mode 100644
index 0000000..ec42f9a
--- /dev/null
+++ b/src/xmpp/xmpp_parser.hpp
@@ -0,0 +1,133 @@
+#pragma once
+
+
+#include <xmpp/xmpp_stanza.hpp>
+
+#include <functional>
+
+#include <expat.h>
+
+/**
+ * A SAX XML parser that builds XML nodes and spawns events when a complete
+ * stanza is received (an element of level 2), or when the document is
+ * opened/closed (an element of level 1)
+ *
+ * After a stanza_event has been spawned, we delete the whole stanza. This
+ * means that even with a very long document (in XMPP the document is
+ * potentially infinite), the memory is never exhausted as long as each
+ * stanza is reasonnably short.
+ *
+ * The element names generated by expat contain the namespace of the
+ * element, a colon (':') and then the actual name of the element. To get
+ * an element "x" with a namespace of "http://jabber.org/protocol/muc", you
+ * just look for an XmlNode named "http://jabber.org/protocol/muc:x"
+ *
+ * TODO: enforce the size-limit for the stanza (limit the number of childs
+ * it can contain). For example forbid the parser going further than level
+ * 20 (arbitrary number here), and each XML node to have more than 15 childs
+ * (arbitrary number again).
+ */
+class XmppParser
+{
+public:
+ explicit XmppParser();
+ ~XmppParser();
+ XmppParser(const XmppParser&) = delete;
+ XmppParser& operator=(const XmppParser&) = delete;
+ XmppParser(XmppParser&&) = delete;
+ XmppParser& operator=(XmppParser&&) = delete;
+
+public:
+ /**
+ * Feed the parser with some XML data
+ */
+ int feed(const char* data, const int len, const bool is_final);
+ /**
+ * Parse the data placed in the parser buffer
+ */
+ int parse(const int size, const bool is_final);
+ /**
+ * Reset the parser, so it can be used from scratch afterward
+ */
+ void reset();
+ /**
+ * Get a buffer provided by the xml parser.
+ */
+ void* get_buffer(const size_t size) const;
+ /**
+ * Add one callback for the various events that this parser can spawn.
+ */
+ void add_stanza_callback(std::function<void(const Stanza&)>&& callback);
+ void add_stream_open_callback(std::function<void(const XmlNode&)>&& callback);
+ void add_stream_close_callback(std::function<void(const XmlNode&)>&& callback);
+
+ /**
+ * Called when a new XML element has been opened. We instanciate a new
+ * XmlNode and set it as our current node. The parent of this new node is
+ * the previous "current" node. We have all the element's attributes in
+ * this event.
+ *
+ * We spawn a stream_event with this node if this is a level-1 element.
+ */
+ void start_element(const XML_Char* name, const XML_Char** attribute);
+ /**
+ * Called when an XML element has been closed. We close the current_node,
+ * set our current_node as the parent of the current_node, and if that was
+ * a level-2 element we spawn a stanza_event with this node.
+ *
+ * And we then delete the stanza (and everything under it, its children,
+ * attribute, etc).
+ */
+ void end_element(const XML_Char* name);
+ /**
+ * Some inner or tail data has been parsed
+ */
+ void char_data(const XML_Char* data, const size_t len);
+ /**
+ * Calls all the stanza_callbacks one by one.
+ */
+ void stanza_event(const Stanza& stanza) const;
+ /**
+ * Calls all the stream_open_callbacks one by one. Note: the passed node is not
+ * closed yet.
+ */
+ void stream_open_event(const XmlNode& node) const;
+ /**
+ * Calls all the stream_close_callbacks one by one.
+ */
+ void stream_close_event(const XmlNode& node) const;
+
+private:
+ /**
+ * Init the XML parser and install the callbacks
+ */
+ void init_xml_parser();
+
+ /**
+ * Expat structure.
+ */
+ XML_Parser parser{};
+ /**
+ * The current depth in the XML document
+ */
+ size_t level;
+ /**
+ * The deepest XML node opened but not yet closed (to which we are adding
+ * new children, inner or tail)
+ */
+ XmlNode* current_node;
+ /**
+ * The root node has no parent, so we keep it here: the XmppParser object
+ * is its owner.
+ */
+ std::unique_ptr<XmlNode> root;
+ /**
+ * A list of callbacks to be called on an *_event, receiving the
+ * concerned Stanza/XmlNode.
+ */
+ std::vector<std::function<void(const Stanza&)>> stanza_callbacks;
+ std::vector<std::function<void(const XmlNode&)>> stream_open_callbacks;
+ std::vector<std::function<void(const XmlNode&)>> stream_close_callbacks;
+};
+
+
diff --git a/src/xmpp/xmpp_stanza.cpp b/src/xmpp/xmpp_stanza.cpp
new file mode 100644
index 0000000..435f333
--- /dev/null
+++ b/src/xmpp/xmpp_stanza.cpp
@@ -0,0 +1,229 @@
+#include <xmpp/xmpp_stanza.hpp>
+
+#include <utils/encoding.hpp>
+#include <utils/split.hpp>
+
+#include <stdexcept>
+#include <iostream>
+#include <sstream>
+
+#include <cstring>
+
+std::string xml_escape(const std::string& data)
+{
+ std::string res;
+ res.reserve(data.size());
+ for (size_t pos = 0; pos != data.size(); ++pos)
+ {
+ switch(data[pos])
+ {
+ case '&':
+ res += "&amp;";
+ break;
+ case '<':
+ res += "&lt;";
+ break;
+ case '>':
+ res += "&gt;";
+ break;
+ case '\"':
+ res += "&quot;";
+ break;
+ case '\'':
+ res += "&apos;";
+ break;
+ default:
+ res += data[pos];
+ break;
+ }
+ }
+ return res;
+}
+
+std::string sanitize(const std::string& data, const std::string& encoding)
+{
+ if (utils::is_valid_utf8(data.data()))
+ return xml_escape(utils::remove_invalid_xml_chars(data));
+ else
+ return xml_escape(utils::remove_invalid_xml_chars(utils::convert_to_utf8(data, encoding.data())));
+}
+
+XmlNode::XmlNode(const std::string& name, XmlNode* parent):
+ parent(parent)
+{
+ // split the namespace and the name
+ auto n = name.rfind(':');
+ if (n == std::string::npos)
+ this->name = name;
+ else
+ {
+ this->name = name.substr(n+1);
+ this->attributes["xmlns"] = name.substr(0, n);
+ }
+}
+
+XmlNode::XmlNode(const std::string& name):
+ XmlNode(name, nullptr)
+{
+}
+
+void XmlNode::delete_all_children()
+{
+ this->children.clear();
+}
+
+void XmlNode::set_attribute(const std::string& name, const std::string& value)
+{
+ this->attributes[name] = value;
+}
+
+void XmlNode::set_tail(const std::string& data)
+{
+ this->tail = data;
+}
+
+void XmlNode::add_to_tail(const std::string& data)
+{
+ this->tail += data;
+}
+
+void XmlNode::set_inner(const std::string& data)
+{
+ this->inner = data;
+}
+
+void XmlNode::add_to_inner(const std::string& data)
+{
+ this->inner += data;
+}
+
+std::string XmlNode::get_inner() const
+{
+ return this->inner;
+}
+
+std::string XmlNode::get_tail() const
+{
+ return this->tail;
+}
+
+const XmlNode* XmlNode::get_child(const std::string& name, const std::string& xmlns) const
+{
+ for (const auto& child: this->children)
+ {
+ if (child->name == name && child->get_tag("xmlns") == xmlns)
+ return child.get();
+ }
+ return nullptr;
+}
+
+std::vector<const XmlNode*> XmlNode::get_children(const std::string& name, const std::string& xmlns) const
+{
+ std::vector<const XmlNode*> res;
+ for (const auto& child: this->children)
+ {
+ if (child->name == name && child->get_tag("xmlns") == xmlns)
+ res.push_back(child.get());
+ }
+ return res;
+}
+
+XmlNode* XmlNode::add_child(std::unique_ptr<XmlNode> child)
+{
+ child->parent = this;
+ auto ret = child.get();
+ this->children.push_back(std::move(child));
+ return ret;
+}
+
+XmlNode* XmlNode::add_child(XmlNode&& child)
+{
+ auto new_node = std::make_unique<XmlNode>(std::move(child));
+ return this->add_child(std::move(new_node));
+}
+
+XmlNode* XmlNode::add_child(const XmlNode& child)
+{
+ auto new_node = std::make_unique<XmlNode>(child);
+ return this->add_child(std::move(new_node));
+}
+
+XmlNode* XmlNode::get_last_child() const
+{
+ return this->children.back().get();
+}
+
+XmlNode* XmlNode::get_parent() const
+{
+ return this->parent;
+}
+
+void XmlNode::set_name(const std::string& name)
+{
+ this->name = name;
+}
+
+void XmlNode::set_name(std::string&& name)
+{
+ this->name = std::move(name);
+}
+
+const std::string XmlNode::get_name() const
+{
+ return this->name;
+}
+
+std::string XmlNode::to_string() const
+{
+ std::ostringstream res;
+ res << "<" << this->name;
+ for (const auto& it: this->attributes)
+ res << " " << it.first << "='" << sanitize(it.second) + "'";
+ if (!this->has_children() && this->inner.empty())
+ res << "/>";
+ else
+ {
+ res << ">" + sanitize(this->inner);
+ for (const auto& child: this->children)
+ res << child->to_string();
+ res << "</" << this->get_name() << ">";
+ }
+ res << sanitize(this->tail);
+ return res.str();
+}
+
+bool XmlNode::has_children() const
+{
+ return !this->children.empty();
+}
+
+const std::string& XmlNode::get_tag(const std::string& name) const
+{
+ try
+ {
+ const auto& value = this->attributes.at(name);
+ return value;
+ }
+ catch (const std::out_of_range& e)
+ {
+ static const std::string def{};
+ return def;
+ }
+}
+
+bool XmlNode::del_tag(const std::string& name)
+{
+ if (this->attributes.erase(name) != 0)
+ return true;
+ return false;
+}
+
+std::string& XmlNode::operator[](const std::string& name)
+{
+ return this->attributes[name];
+}
+
+std::ostream& operator<<(std::ostream& os, const XmlNode& node)
+{
+ return os << node.to_string();
+}
diff --git a/src/xmpp/xmpp_stanza.hpp b/src/xmpp/xmpp_stanza.hpp
new file mode 100644
index 0000000..f4b3948
--- /dev/null
+++ b/src/xmpp/xmpp_stanza.hpp
@@ -0,0 +1,160 @@
+#pragma once
+
+
+#include <map>
+#include <string>
+#include <vector>
+#include <memory>
+
+std::string xml_escape(const std::string& data);
+std::string xml_unescape(const std::string& data);
+std::string sanitize(const std::string& data, const std::string& encoding = "ISO-8859-1");
+
+/**
+ * Represent an XML node. It has
+ * - A parent XML node (in the case of the first-level nodes, the parent is
+ nullptr)
+ * - zero, one or more children XML nodes
+ * - A name
+ * - A map of attributes
+ * - inner data (text inside the node)
+ * - tail data (text just after the node)
+ */
+class XmlNode
+{
+public:
+ explicit XmlNode(const std::string& name, XmlNode* parent);
+ explicit XmlNode(const std::string& name);
+ /**
+ * The copy constructor does not copy the parent attribute. The children
+ * nodes are all copied recursively.
+ */
+ XmlNode(const XmlNode& node):
+ name(node.name),
+ parent(nullptr),
+ attributes(node.attributes),
+ children{},
+ inner(node.inner),
+ tail(node.tail)
+ {
+ for (const auto& child: node.children)
+ this->add_child(std::make_unique<XmlNode>(*child));
+ }
+
+ XmlNode(XmlNode&& node) = default;
+ XmlNode& operator=(const XmlNode&) = delete;
+ XmlNode& operator=(XmlNode&&) = delete;
+
+ ~XmlNode() = default;
+
+ void delete_all_children();
+ void set_attribute(const std::string& name, const std::string& value);
+ /**
+ * Set the content of the tail, that is the text just after this node
+ */
+ void set_tail(const std::string& data);
+ /**
+ * Append the given data to the content of the tail. This exists because
+ * the expat library may provide the complete text of an element in more
+ * than one call
+ */
+ void add_to_tail(const std::string& data);
+ /**
+ * Set the content of the inner, that is the text inside this node.
+ */
+ void set_inner(const std::string& data);
+ /**
+ * Append the given data to the content of the inner. For the reason
+ * described in add_to_tail comment.
+ */
+ void add_to_inner(const std::string& data);
+ /**
+ * Get the content of inner
+ */
+ std::string get_inner() const;
+ /**
+ * Get the content of the tail
+ */
+ std::string get_tail() const;
+ /**
+ * Get a pointer to the first child element with that name and that xml namespace
+ */
+ const XmlNode* get_child(const std::string& name, const std::string& xmlns) const;
+ /**
+ * Get a vector of all the children that have that name and that xml namespace.
+ */
+ std::vector<const XmlNode*> get_children(const std::string& name, const std::string& xmlns) const;
+ /**
+ * Add a node child to this node. Assign this node to the child’s parent.
+ * Returns a pointer to the newly added child.
+ */
+ XmlNode* add_child(std::unique_ptr<XmlNode> child);
+ XmlNode* add_child(XmlNode&& child);
+ XmlNode* add_child(const XmlNode& child);
+ /**
+ * Returns the last of the children. If the node doesn't have any child,
+ * the behaviour is undefined. The user should make sure this is the case
+ * by calling has_children() for example.
+ */
+ XmlNode* get_last_child() const;
+ XmlNode* get_parent() const;
+ void set_name(const std::string& name);
+ void set_name(std::string&& name);
+ const std::string get_name() const;
+ /**
+ * Serialize the stanza into a string
+ */
+ std::string to_string() const;
+ /**
+ * Whether or not this node has at least one child (if not, this is a leaf
+ * node)
+ */
+ bool has_children() const;
+ /**
+ * Gets the value for the given attribute, returns an empty string if the
+ * node as no such attribute.
+ */
+ const std::string& get_tag(const std::string& name) const;
+ /**
+ * Remove the attribute of the node. Does nothing if that attribute is not
+ * present. Returns true if the tag was removed, false if it was absent.
+ */
+ bool del_tag(const std::string& name);
+ /**
+ * Use this to set an attribute's value, like node["id"] = "12";
+ */
+ std::string& operator[](const std::string& name);
+
+private:
+ std::string name;
+ XmlNode* parent;
+ std::map<std::string, std::string> attributes;
+ std::vector<std::unique_ptr<XmlNode>> children;
+ std::string inner;
+ std::string tail;
+};
+
+std::ostream& operator<<(std::ostream& os, const XmlNode& node);
+
+/**
+ * An XMPP stanza is just an XML node of level 2 in the XMPP document (the
+ * level 1 ones are the <stream::stream/>, and the ones above 2 are just the
+ * content of the stanzas)
+ */
+using Stanza = XmlNode;
+
+class XmlSubNode: public XmlNode
+{
+public:
+ XmlSubNode(XmlNode& parent_ref, const std::string& name):
+ XmlNode(name),
+ parent_to_add(parent_ref)
+ {}
+
+ ~XmlSubNode()
+ {
+ this->parent_to_add.add_child(std::move(*this));
+ }
+private:
+ XmlNode& parent_to_add;
+}; \ No newline at end of file