summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/biboumi.h.cmake4
-rw-r--r--src/bridge/bridge.cpp58
-rw-r--r--src/bridge/bridge.hpp7
-rw-r--r--src/bridge/colors.hpp12
-rw-r--r--src/bridge/history_limit.hpp8
-rw-r--r--src/config/config.cpp8
-rw-r--r--src/config/config.hpp1
-rw-r--r--src/database/column.hpp9
-rw-r--r--src/database/count_query.hpp17
-rw-r--r--src/database/database.cpp81
-rw-r--r--src/database/database.hpp73
-rw-r--r--src/database/engine.hpp41
-rw-r--r--src/database/index.hpp38
-rw-r--r--src/database/insert_query.hpp108
-rw-r--r--src/database/postgresql_engine.cpp91
-rw-r--r--src/database/postgresql_engine.hpp48
-rw-r--r--src/database/postgresql_statement.hpp123
-rw-r--r--src/database/query.cpp26
-rw-r--r--src/database/query.hpp62
-rw-r--r--src/database/row.hpp108
-rw-r--r--src/database/select_query.hpp31
-rw-r--r--src/database/sqlite3_engine.cpp101
-rw-r--r--src/database/sqlite3_engine.hpp47
-rw-r--r--src/database/sqlite3_statement.hpp92
-rw-r--r--src/database/statement.hpp46
-rw-r--r--src/database/table.cpp23
-rw-r--r--src/database/table.hpp80
-rw-r--r--src/database/type_to_sql.cpp9
-rw-r--r--src/database/type_to_sql.hpp16
-rw-r--r--src/database/update_query.hpp104
-rw-r--r--src/identd/identd_socket.cpp4
-rw-r--r--src/irc/irc_client.cpp45
-rw-r--r--src/irc/irc_client.hpp7
-rw-r--r--src/logger/logger.cpp23
-rw-r--r--src/logger/logger.hpp87
-rw-r--r--src/main.cpp3
-rw-r--r--src/network/credentials_manager.cpp1
-rw-r--r--src/network/credentials_manager.hpp3
-rw-r--r--src/network/tcp_socket_handler.cpp6
-rw-r--r--src/network/tcp_socket_handler.hpp1
-rw-r--r--src/network/tls_policy.cpp2
-rw-r--r--src/utils/is_one_of.hpp17
-rw-r--r--src/utils/optional_bool.cpp8
-rw-r--r--src/utils/optional_bool.hpp4
-rw-r--r--src/utils/scopetimer.hpp17
-rw-r--r--src/utils/time.cpp3
-rw-r--r--src/utils/time.hpp5
-rw-r--r--src/xmpp/adhoc_commands_handler.cpp2
-rw-r--r--src/xmpp/biboumi_adhoc_commands.cpp10
-rw-r--r--src/xmpp/biboumi_component.cpp101
-rw-r--r--src/xmpp/body.hpp4
-rw-r--r--src/xmpp/xmpp_component.cpp23
-rw-r--r--src/xmpp/xmpp_component.hpp11
53 files changed, 1407 insertions, 452 deletions
diff --git a/src/biboumi.h.cmake b/src/biboumi.h.cmake
index 1ad9a40..fa99cd4 100644
--- a/src/biboumi.h.cmake
+++ b/src/biboumi.h.cmake
@@ -6,7 +6,11 @@
#cmakedefine BOTAN_FOUND
#cmakedefine GCRYPT_FOUND
#cmakedefine UDNS_FOUND
+#cmakedefine PQ_FOUND
+#cmakedefine SQLITE3_FOUND
#cmakedefine SOFTWARE_VERSION "${SOFTWARE_VERSION}"
#cmakedefine PROJECT_NAME "${PROJECT_NAME}"
#cmakedefine HAS_GET_TIME
#cmakedefine HAS_PUT_TIME
+#cmakedefine DEBUG_SQL_QUERIES
+
diff --git a/src/bridge/bridge.cpp b/src/bridge/bridge.cpp
index e0cb36d..54bee84 100644
--- a/src/bridge/bridge.cpp
+++ b/src/bridge/bridge.cpp
@@ -22,15 +22,12 @@ static std::string in_encoding_for(const Bridge& bridge, const Iid& iid)
{
#ifdef USE_DATABASE
const auto jid = bridge.get_bare_jid();
- auto options = Database::get_irc_channel_options_with_server_default(jid, iid.get_server(), iid.get_local());
- auto result = options.col<Database::EncodingIn>();
- if (!result.empty())
- return result;
+ return Database::get_encoding_in(jid, iid.get_server(), iid.get_local());
#else
(void)bridge;
(void)iid;
-#endif
return {"ISO-8859-1"};
+#endif
}
Bridge::Bridge(std::string user_jid, BiboumiComponent& xmpp, std::shared_ptr<Poller>& poller):
@@ -170,10 +167,11 @@ IrcClient* Bridge::find_irc_client(const std::string& hostname) const
}
bool Bridge::join_irc_channel(const Iid& iid, const std::string& nickname, const std::string& password,
- const std::string& resource)
+ const std::string& resource, HistoryLimit history_limit)
{
const auto& hostname = iid.get_server();
IrcClient* irc = this->make_irc_client(hostname, nickname);
+ irc->history_limit = history_limit;
this->add_resource_to_server(hostname, resource);
auto res_in_chan = this->is_resource_in_chan(ChannelKey{iid.get_local(), hostname}, resource);
if (!res_in_chan)
@@ -437,7 +435,7 @@ void Bridge::leave_irc_channel(Iid&& iid, const std::string& status_message, con
bool persistent = false;
#ifdef USE_DATABASE
const auto goptions = Database::get_global_options(this->user_jid);
- if (goptions.col<Database::Persistent>())
+ if (goptions.col<Database::GlobalPersistent>())
persistent = true;
else
{
@@ -470,9 +468,9 @@ void Bridge::leave_irc_channel(Iid&& iid, const std::string& status_message, con
"Biboumi note: " + std::to_string(resources - 1) + " resources are still in this channel.",
true, true, resource);
this->remove_resource_from_chan(key, resource);
+ }
if (this->number_of_channels_the_resource_is_in(iid.get_server(), resource) == 0)
this->remove_resource_from_server(iid.get_server(), resource);
- }
}
@@ -864,7 +862,8 @@ void Bridge::send_message(const Iid& iid, const std::string& nick, const std::st
const auto chan_name = Iid(Jid(it->second).local, {}).get_local();
for (const auto& resource: this->resources_in_chan[ChannelKey{chan_name, iid.get_server()}])
this->xmpp.send_message(it->second, this->make_xmpp_body(body, encoding),
- this->user_jid + "/" + resource, "chat", true, true);
+ this->user_jid + "/"
+ + resource, "chat", true, true, true);
}
else
{
@@ -896,7 +895,19 @@ void Bridge::send_muc_leave(const Iid& iid, const std::string& nick,
this->xmpp.send_muc_leave(std::to_string(iid), nick, this->make_xmpp_body(message),
this->user_jid + "/" + res, self, user_requested);
if (self)
- this->remove_all_resources_from_chan(iid.to_tuple());
+ {
+ // Copy the resources currently in that channel
+ const auto resources_in_chan = this->resources_in_chan[iid.to_tuple()];
+
+ this->remove_all_resources_from_chan(iid.to_tuple());
+
+ // Now, for each resource that was in that channel, remove it from the server if it’s
+ // not in any other channel
+ for (const auto& r: resources_in_chan)
+ if (this->number_of_channels_the_resource_is_in(iid.get_server(), r) == 0)
+ this->remove_resource_from_server(iid.get_server(), r);
+
+ }
}
IrcClient* irc = this->find_irc_client(iid.get_server());
@@ -935,7 +946,10 @@ void Bridge::send_xmpp_message(const std::string& from, const std::string& autho
const auto encoding = in_encoding_for(*this, {from, this});
for (const auto& resource: this->resources_in_server[from])
{
- this->xmpp.send_message(from, this->make_xmpp_body(body, encoding), this->user_jid + "/" + resource, "chat", false, false);
+ if (Config::get("fixed_irc_server", "").empty())
+ this->xmpp.send_message(from, this->make_xmpp_body(body, encoding), this->user_jid + "/" + resource, "chat", false, true);
+ else
+ this->xmpp.send_message("", this->make_xmpp_body(body, encoding), this->user_jid + "/" + resource, "chat", false, true);
}
}
@@ -950,7 +964,7 @@ void Bridge::send_user_join(const std::string& hostname, const std::string& chan
const Iid iid(chan_name, hostname, Iid::Type::Channel);
this->send_xmpp_invitation(iid, "");
}
- else
+ else
{
for (const auto& resource: resources)
this->send_user_join(hostname, chan_name, user, user_mode, self, resource);
@@ -993,17 +1007,20 @@ void Bridge::send_topic(const std::string& hostname, const std::string& chan_nam
}
-void Bridge::send_room_history(const std::string& hostname, const std::string& chan_name)
+void Bridge::send_room_history(const std::string& hostname, const std::string& chan_name, const HistoryLimit& history_limit)
{
for (const auto& resource: this->resources_in_chan[ChannelKey{chan_name, hostname}])
- this->send_room_history(hostname, chan_name, resource);
+ this->send_room_history(hostname, chan_name, resource, history_limit);
}
-void Bridge::send_room_history(const std::string& hostname, std::string chan_name, const std::string& resource)
+void Bridge::send_room_history(const std::string& hostname, std::string chan_name, const std::string& resource, const HistoryLimit& history_limit)
{
#ifdef USE_DATABASE
const auto coptions = Database::get_irc_channel_options_with_server_and_global_default(this->user_jid, hostname, chan_name);
- const auto lines = Database::get_muc_logs(this->user_jid, chan_name, hostname, coptions.col<Database::MaxHistoryLength>());
+ auto limit = coptions.col<Database::MaxHistoryLength>();
+ if (history_limit.stanzas >= 0 && history_limit.stanzas < limit)
+ limit = history_limit.stanzas;
+ const auto lines = Database::get_muc_logs(this->user_jid, chan_name, hostname, limit, history_limit.since);
chan_name.append(utils::empty_if_fixed_server("%" + hostname));
for (const auto& line: lines)
{
@@ -1015,6 +1032,7 @@ void Bridge::send_room_history(const std::string& hostname, std::string chan_nam
(void)hostname;
(void)chan_name;
(void)resource;
+ (void)history_limit;
#endif
}
@@ -1232,9 +1250,13 @@ std::size_t Bridge::number_of_channels_the_resource_is_in(const std::string& irc
std::size_t res = 0;
for (auto pair: this->resources_in_chan)
{
- if (std::get<0>(pair.first) == irc_hostname && pair.second.count(resource) != 0)
+ if (std::get<1>(pair.first) == irc_hostname && pair.second.count(resource) != 0)
res++;
}
+
+ IrcClient* irc = this->find_irc_client(irc_hostname);
+ if (irc && (irc->get_dummy_channel().joined || irc->get_dummy_channel().joining))
+ res++;
return res;
}
@@ -1257,7 +1279,7 @@ void Bridge::generate_channel_join_for_resource(const Iid& iid, const std::strin
this->send_user_join(iid.get_server(), iid.get_encoded_local(),
self, self->get_most_significant_mode(irc->get_sorted_user_modes()),
true, resource);
- this->send_room_history(iid.get_server(), iid.get_local(), resource);
+ this->send_room_history(iid.get_server(), iid.get_local(), resource, irc->history_limit);
this->send_topic(iid.get_server(), iid.get_encoded_local(), channel->topic, channel->topic_author, resource);
}
diff --git a/src/bridge/bridge.hpp b/src/bridge/bridge.hpp
index c10631b..c2f0233 100644
--- a/src/bridge/bridge.hpp
+++ b/src/bridge/bridge.hpp
@@ -2,6 +2,7 @@
#include <bridge/result_set_management.hpp>
#include <bridge/list_element.hpp>
+#include <bridge/history_limit.hpp>
#include <irc/irc_message.hpp>
#include <irc/irc_client.hpp>
@@ -74,7 +75,7 @@ public:
* Try to join an irc_channel, does nothing and return true if the channel
* was already joined.
*/
- bool join_irc_channel(const Iid& iid, const std::string& nickname, const std::string& password, const std::string& resource);
+ bool join_irc_channel(const Iid& iid, const std::string& nickname, const std::string& password, const std::string& resource, HistoryLimit history_limit);
void send_channel_message(const Iid& iid, const std::string& body);
void send_private_message(const Iid& iid, const std::string& body, const std::string& type="PRIVMSG");
@@ -156,8 +157,8 @@ public:
/**
* Send the MUC history to the user
*/
- void send_room_history(const std::string& hostname, const std::string& chan_name);
- void send_room_history(const std::string& hostname, std::string chan_name, const std::string& resource);
+ void send_room_history(const std::string& hostname, const std::string& chan_name, const HistoryLimit& history_limit);
+ void send_room_history(const std::string& hostname, std::string chan_name, const std::string& resource, const HistoryLimit& history_limit);
/**
* Send a MUC message from some participant
*/
diff --git a/src/bridge/colors.hpp b/src/bridge/colors.hpp
index dceed74..25b085a 100644
--- a/src/bridge/colors.hpp
+++ b/src/bridge/colors.hpp
@@ -6,20 +6,12 @@
* vice versa.
*/
+#include <xmpp/body.hpp>
+
#include <string>
#include <memory>
#include <tuple>
-class XmlNode;
-
-namespace Xmpp
-{
-// Contains:
-// - an XMPP-valid UTF-8 body
-// - an XML node representing the XHTML-IM body, or null
- using body = std::tuple<const std::string, std::unique_ptr<XmlNode>>;
-}
-
#define IRC_FORMAT_BOLD_CHAR '\x02' // done
#define IRC_FORMAT_COLOR_CHAR '\x03' // done
#define IRC_FORMAT_RESET_CHAR '\x0F' // done
diff --git a/src/bridge/history_limit.hpp b/src/bridge/history_limit.hpp
new file mode 100644
index 0000000..9c75256
--- /dev/null
+++ b/src/bridge/history_limit.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+// Default values means no limit
+struct HistoryLimit
+{
+ int stanzas{-1};
+ std::string since{};
+};
diff --git a/src/config/config.cpp b/src/config/config.cpp
index 0f3d639..412b170 100644
--- a/src/config/config.cpp
+++ b/src/config/config.cpp
@@ -23,6 +23,14 @@ std::string Config::get(const std::string& option, const std::string& def)
return it->second;
}
+bool Config::get_bool(const std::string& option, const bool def)
+{
+ auto res = Config::get(option, "");
+ if (res.empty())
+ return def;
+ return res == "true";
+}
+
int Config::get_int(const std::string& option, const int& def)
{
std::string res = Config::get(option, "");
diff --git a/src/config/config.hpp b/src/config/config.hpp
index 2ba38cc..c5ef15d 100644
--- a/src/config/config.hpp
+++ b/src/config/config.hpp
@@ -44,6 +44,7 @@ public:
* the second argument as the default.
*/
static int get_int(const std::string&, const int&);
+ static bool get_bool(const std::string&, const bool);
/**
* Set a value for the given option. And write all the config
* in the file from which it was read if save is true.
diff --git a/src/database/column.hpp b/src/database/column.hpp
index 111f9ca..1f16bcf 100644
--- a/src/database/column.hpp
+++ b/src/database/column.hpp
@@ -13,5 +13,10 @@ struct Column
T value{};
};
-struct Id: Column<std::size_t> { static constexpr auto name = "id_";
- static constexpr auto options = "PRIMARY KEY AUTOINCREMENT"; };
+struct Id: Column<std::size_t> {
+ static constexpr std::size_t unset_value = static_cast<std::size_t>(-1);
+ static constexpr auto name = "id_";
+ static constexpr auto options = "PRIMARY KEY";
+
+ Id(): Column<std::size_t>(-1) {}
+};
diff --git a/src/database/count_query.hpp b/src/database/count_query.hpp
index 0dde63c..118ce44 100644
--- a/src/database/count_query.hpp
+++ b/src/database/count_query.hpp
@@ -2,11 +2,10 @@
#include <database/query.hpp>
#include <database/table.hpp>
+#include <database/statement.hpp>
#include <string>
-#include <sqlite3.h>
-
struct CountQuery: public Query
{
CountQuery(std::string name):
@@ -15,20 +14,20 @@ struct CountQuery: public Query
this->body += std::move(name);
}
- int64_t execute(sqlite3* db)
+ int64_t execute(DatabaseEngine& db)
{
- auto statement = this->prepare(db);
+#ifdef DEBUG_SQL_QUERIES
+ const auto timer = this->log_and_time();
+#endif
+ auto statement = db.prepare(this->body);
int64_t res = 0;
- if (sqlite3_step(statement.get()) == SQLITE_ROW)
- res = sqlite3_column_int64(statement.get(), 0);
+ if (statement->step() != StepResult::Error)
+ res = statement->get_column_int64(0);
else
{
log_error("Count request didn’t return a result");
return 0;
}
- if (sqlite3_step(statement.get()) != SQLITE_DONE)
- log_warning("Count request returned more than one result.");
-
return res;
}
};
diff --git a/src/database/database.cpp b/src/database/database.cpp
index 85c675e..3622963 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -6,40 +6,54 @@
#include <utils/get_first_non_empty.hpp>
#include <utils/time.hpp>
-#include <sqlite3.h>
+#include <config/config.hpp>
+#include <database/sqlite3_engine.hpp>
+#include <database/postgresql_engine.hpp>
-sqlite3* Database::db;
-Database::MucLogLineTable Database::muc_log_lines("MucLogLine_");
-Database::GlobalOptionsTable Database::global_options("GlobalOptions_");
-Database::IrcServerOptionsTable Database::irc_server_options("IrcServerOptions_");
-Database::IrcChannelOptionsTable Database::irc_channel_options("IrcChannelOptions_");
+#include <database/engine.hpp>
+#include <database/index.hpp>
+
+#include <memory>
+
+std::unique_ptr<DatabaseEngine> Database::db;
+Database::MucLogLineTable Database::muc_log_lines("muclogline_");
+Database::GlobalOptionsTable Database::global_options("globaloptions_");
+Database::IrcServerOptionsTable Database::irc_server_options("ircserveroptions_");
+Database::IrcChannelOptionsTable Database::irc_channel_options("ircchanneloptions_");
Database::RosterTable Database::roster("roster");
+std::map<Database::CacheKey, Database::EncodingIn::real_type> Database::encoding_in_cache{};
+Database::GlobalPersistent::GlobalPersistent():
+ Column<bool>{Config::get_bool("persistent_by_default", false)}
+{}
void Database::open(const std::string& filename)
{
// Try to open the specified database.
// Close and replace the previous database pointer if it succeeded. If it did
// not, just leave things untouched
- sqlite3* new_db;
- auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr);
- if (res != SQLITE_OK)
- {
- log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(Database::db));
- throw std::runtime_error("");
- }
- Database::close();
- Database::db = new_db;
- Database::muc_log_lines.create(Database::db);
- Database::muc_log_lines.upgrade(Database::db);
- Database::global_options.create(Database::db);
- Database::global_options.upgrade(Database::db);
- Database::irc_server_options.create(Database::db);
- Database::irc_server_options.upgrade(Database::db);
- Database::irc_channel_options.create(Database::db);
- Database::irc_channel_options.upgrade(Database::db);
- Database::roster.create(Database::db);
- Database::roster.upgrade(Database::db);
+ std::unique_ptr<DatabaseEngine> new_db;
+ static const auto psql_prefix = "postgresql://"s;
+ static const auto psql_prefix2 = "postgres://"s;
+ if ((filename.substr(0, psql_prefix.size()) == psql_prefix) ||
+ (filename.substr(0, psql_prefix2.size()) == psql_prefix2))
+ new_db = PostgresqlEngine::open(filename);
+ else
+ new_db = Sqlite3Engine::open(filename);
+ if (!new_db)
+ return;
+ Database::db = std::move(new_db);
+ Database::muc_log_lines.create(*Database::db);
+ Database::muc_log_lines.upgrade(*Database::db);
+ Database::global_options.create(*Database::db);
+ Database::global_options.upgrade(*Database::db);
+ Database::irc_server_options.create(*Database::db);
+ Database::irc_server_options.upgrade(*Database::db);
+ Database::irc_channel_options.create(*Database::db);
+ Database::irc_channel_options.upgrade(*Database::db);
+ Database::roster.create(*Database::db);
+ Database::roster.upgrade(*Database::db);
+ create_index<Database::Owner, Database::IrcChanName, Database::IrcServerName>(*Database::db, "archive_index", Database::muc_log_lines.get_name());
}
@@ -49,7 +63,7 @@ Database::GlobalOptions Database::get_global_options(const std::string& owner)
request.where() << Owner{} << "=" << owner;
Database::GlobalOptions options{Database::global_options.get_name()};
- auto result = request.execute(Database::db);
+ auto result = request.execute(*Database::db);
if (result.size() == 1)
options = result.front();
else
@@ -63,7 +77,7 @@ Database::IrcServerOptions Database::get_irc_server_options(const std::string& o
request.where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server;
Database::IrcServerOptions options{Database::irc_server_options.get_name()};
- auto result = request.execute(Database::db);
+ auto result = request.execute(*Database::db);
if (result.size() == 1)
options = result.front();
else
@@ -81,7 +95,7 @@ Database::IrcChannelOptions Database::get_irc_channel_options(const std::string&
" and " << Server{} << "=" << server <<\
" and " << Channel{} << "=" << channel;
Database::IrcChannelOptions options{Database::irc_channel_options.get_name()};
- auto result = request.execute(Database::db);
+ auto result = request.execute(*Database::db);
if (result.size() == 1)
options = result.front();
else
@@ -176,7 +190,7 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne
if (limit >= 0)
request.limit() << limit;
- auto result = request.execute(Database::db);
+ auto result = request.execute(*Database::db);
return {result.crbegin(), result.crend()};
}
@@ -197,7 +211,7 @@ void Database::delete_roster_item(const std::string& local, const std::string& r
query << " WHERE " << Database::RemoteJid{} << "=" << remote << \
" AND " << Database::LocalJid{} << "=" << local;
- query.execute(Database::db);
+// query.execute(*Database::db);
}
bool Database::has_roster_item(const std::string& local, const std::string& remote)
@@ -206,7 +220,7 @@ bool Database::has_roster_item(const std::string& local, const std::string& remo
query.where() << Database::LocalJid{} << "=" << local << \
" and " << Database::RemoteJid{} << "=" << remote;
- auto res = query.execute(Database::db);
+ auto res = query.execute(*Database::db);
return !res.empty();
}
@@ -216,19 +230,18 @@ std::vector<Database::RosterItem> Database::get_contact_list(const std::string&
auto query = Database::roster.select();
query.where() << Database::LocalJid{} << "=" << local;
- return query.execute(Database::db);
+ return query.execute(*Database::db);
}
std::vector<Database::RosterItem> Database::get_full_roster()
{
auto query = Database::roster.select();
- return query.execute(Database::db);
+ return query.execute(*Database::db);
}
void Database::close()
{
- sqlite3_close_v2(Database::db);
Database::db = nullptr;
}
diff --git a/src/database/database.hpp b/src/database/database.hpp
index c00c938..ec44543 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -7,12 +7,15 @@
#include <database/column.hpp>
#include <database/count_query.hpp>
+#include <database/engine.hpp>
+
#include <utils/optional_bool.hpp>
#include <chrono>
#include <string>
#include <memory>
+#include <map>
class Database
@@ -24,11 +27,11 @@ class Database
struct Owner: Column<std::string> { static constexpr auto name = "owner_"; };
- struct IrcChanName: Column<std::string> { static constexpr auto name = "ircChanName_"; };
+ struct IrcChanName: Column<std::string> { static constexpr auto name = "ircchanname_"; };
struct Channel: Column<std::string> { static constexpr auto name = "channel_"; };
- struct IrcServerName: Column<std::string> { static constexpr auto name = "ircServerName_"; };
+ struct IrcServerName: Column<std::string> { static constexpr auto name = "ircservername_"; };
struct Server: Column<std::string> { static constexpr auto name = "server_"; };
@@ -43,35 +46,38 @@ class Database
struct Ports: Column<std::string> { static constexpr auto name = "ports_";
Ports(): Column<std::string>("6667") {} };
- struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsPorts_";
+ struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsports_";
TlsPorts(): Column<std::string>("6697;6670") {} };
struct Username: Column<std::string> { static constexpr auto name = "username_"; };
struct Realname: Column<std::string> { static constexpr auto name = "realname_"; };
- struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterConnectionCommand_"; };
+ struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterconnectioncommand_"; };
- struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedFingerprint_"; };
+ struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedfingerprint_"; };
- struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingOut_"; };
+ struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingout_"; };
- struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingIn_"; };
+ struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingin_"; };
- struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxHistoryLength_";
+ struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxhistorylength_";
MaxHistoryLength(): Column<int>(20) {} };
- struct RecordHistory: Column<bool> { static constexpr auto name = "recordHistory_";
+ struct RecordHistory: Column<bool> { static constexpr auto name = "recordhistory_";
RecordHistory(): Column<bool>(true) {}};
- struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordHistory_"; };
+ struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordhistory_"; };
- struct VerifyCert: Column<bool> { static constexpr auto name = "verifyCert_";
+ struct VerifyCert: Column<bool> { static constexpr auto name = "verifycert_";
VerifyCert(): Column<bool>(true) {} };
struct Persistent: Column<bool> { static constexpr auto name = "persistent_";
Persistent(): Column<bool>(false) {} };
+ struct GlobalPersistent: Column<bool> { static constexpr auto name = "persistent_";
+ GlobalPersistent(); };
+
struct LocalJid: Column<std::string> { static constexpr auto name = "local"; };
struct RemoteJid: Column<std::string> { static constexpr auto name = "remote"; };
@@ -80,7 +86,7 @@ class Database
using MucLogLineTable = Table<Id, Uuid, Owner, IrcChanName, IrcServerName, Date, Body, Nick>;
using MucLogLine = MucLogLineTable::RowType;
- using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, Persistent>;
+ using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, GlobalPersistent>;
using GlobalOptions = GlobalOptionsTable::RowType;
using IrcServerOptionsTable = Table<Id, Owner, Server, Pass, AfterConnectionCommand, TlsPorts, Ports, Username, Realname, VerifyCert, TrustedFingerprint, EncodingOut, EncodingIn, MaxHistoryLength>;
@@ -130,7 +136,7 @@ class Database
static int64_t count(const TableType& table)
{
CountQuery query{table.get_name()};
- return query.execute(Database::db);
+ return query.execute(*Database::db);
}
static MucLogLineTable muc_log_lines;
@@ -138,9 +144,48 @@ class Database
static IrcServerOptionsTable irc_server_options;
static IrcChannelOptionsTable irc_channel_options;
static RosterTable roster;
- static sqlite3* db;
+ static std::unique_ptr<DatabaseEngine> db;
+
+ /**
+ * Some caches, to avoid doing very frequent query requests for a few options.
+ */
+ using CacheKey = std::tuple<std::string, std::string, std::string>;
+
+ static EncodingIn::real_type get_encoding_in(const std::string& owner,
+ const std::string& server,
+ const std::string& channel)
+ {
+ CacheKey channel_key{owner, server, channel};
+ auto it = Database::encoding_in_cache.find(channel_key);
+ if (it == Database::encoding_in_cache.end())
+ {
+ auto options = Database::get_irc_channel_options_with_server_default(owner, server, channel);
+ EncodingIn::real_type result = options.col<Database::EncodingIn>();
+ if (result.empty())
+ result = "ISO-8859-1";
+ it = Database::encoding_in_cache.insert(std::make_pair(channel_key, result)).first;
+ }
+ return it->second;
+ }
+ static void invalidate_encoding_in_cache(const std::string& owner,
+ const std::string& server,
+ const std::string& channel)
+ {
+ CacheKey channel_key{owner, server, channel};
+ Database::encoding_in_cache.erase(channel_key);
+ }
+ static void invalidate_encoding_in_cache()
+ {
+ Database::encoding_in_cache.clear();
+ }
+
+ static auto raw_exec(const std::string& query)
+ {
+ Database::db->raw_exec(query);
+ }
private:
static std::string gen_uuid();
+ static std::map<CacheKey, EncodingIn::real_type> encoding_in_cache;
};
#endif /* USE_DATABASE */
diff --git a/src/database/engine.hpp b/src/database/engine.hpp
new file mode 100644
index 0000000..41dccf5
--- /dev/null
+++ b/src/database/engine.hpp
@@ -0,0 +1,41 @@
+#pragma once
+
+/**
+ * Interface to provide non-portable behaviour, specific to each
+ * database engine we want to support.
+ *
+ * Everything else (all portable stuf) should go outside of this class.
+ */
+
+#include <database/statement.hpp>
+
+#include <memory>
+#include <string>
+#include <vector>
+#include <tuple>
+#include <set>
+
+class DatabaseEngine
+{
+ public:
+
+ DatabaseEngine() = default;
+ virtual ~DatabaseEngine() = default;
+
+ DatabaseEngine(const DatabaseEngine&) = delete;
+ DatabaseEngine& operator=(const DatabaseEngine&) = delete;
+ DatabaseEngine(DatabaseEngine&&) = delete;
+ DatabaseEngine& operator=(DatabaseEngine&&) = delete;
+
+ virtual std::set<std::string> get_all_columns_from_table(const std::string& table_name) = 0;
+ virtual std::tuple<bool, std::string> raw_exec(const std::string& query) = 0;
+ virtual std::unique_ptr<Statement> prepare(const std::string& query) = 0;
+ virtual void extract_last_insert_rowid(Statement& statement) = 0;
+ virtual std::string get_returning_id_sql_string(const std::string&)
+ {
+ return {};
+ }
+ virtual std::string id_column_type() = 0;
+
+ int64_t last_inserted_rowid{-1};
+};
diff --git a/src/database/index.hpp b/src/database/index.hpp
new file mode 100644
index 0000000..30766ab
--- /dev/null
+++ b/src/database/index.hpp
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <database/engine.hpp>
+
+#include <string>
+#include <tuple>
+
+namespace
+{
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N == sizeof...(T), void>::type
+add_column_name(std::string&)
+{ }
+
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N < sizeof...(T), void>::type
+add_column_name(std::string& out)
+{
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<std::tuple<T...>>()))>::type;
+ out += ColumnType::name;
+ if (N != sizeof...(T) - 1)
+ out += ",";
+ add_column_name<N+1, T...>(out);
+}
+}
+
+template <typename... Columns>
+void create_index(DatabaseEngine& db, const std::string& name, const std::string& table)
+{
+ std::string query{"CREATE INDEX IF NOT EXISTS "};
+ query += name + " ON " + table + "(";
+ add_column_name<0, Columns...>(query);
+ query += ")";
+
+ auto result = db.raw_exec(query);
+ if (std::get<0>(result) == false)
+ log_error("Error executing query: ", std::get<1>(result));
+}
diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp
index 2ece69d..9726424 100644
--- a/src/database/insert_query.hpp
+++ b/src/database/insert_query.hpp
@@ -10,64 +10,64 @@
#include <string>
#include <tuple>
-#include <sqlite3.h>
-
-template <int N, typename ColumnType, typename... T>
-typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
-actual_bind(Statement& statement, std::vector<std::string>& params, const std::tuple<T...>&)
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N < sizeof...(T), void>::type
+update_autoincrement_id(std::tuple<T...>& columns, Statement& statement)
{
- const auto value = params.front();
- params.erase(params.begin());
- if (sqlite3_bind_text(statement.get(), N + 1, value.data(), static_cast<int>(value.size()), SQLITE_TRANSIENT) != SQLITE_OK)
- log_error("Failed to bind ", value, " to param ", N);
+ using ColumnType = typename std::decay<decltype(std::get<N>(columns))>::type;
+ if (std::is_same<ColumnType, Id>::value)
+ auto&& column = std::get<Id>(columns);
+ update_autoincrement_id<N+1>(columns, statement);
}
-template <int N, typename ColumnType, typename... T>
-typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
-actual_bind(Statement& statement, std::vector<std::string>&, const std::tuple<T...>& columns)
-{
- auto&& column = std::get<Id>(columns);
- if (column.value != 0)
- {
- if (sqlite3_bind_int64(statement.get(), N + 1, static_cast<sqlite3_int64>(column.value)) != SQLITE_OK)
- log_error("Failed to bind ", column.value, " to id.");
- }
- else if (sqlite3_bind_null(statement.get(), N + 1) != SQLITE_OK)
- log_error("Failed to bind NULL to param ", N);
-}
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N == sizeof...(T), void>::type
+update_autoincrement_id(std::tuple<T...>&, Statement& statement)
+{}
struct InsertQuery: public Query
{
- InsertQuery(const std::string& name):
- Query("INSERT OR REPLACE INTO ")
+ template <typename... T>
+ InsertQuery(const std::string& name, const std::tuple<T...>& columns):
+ Query("INSERT INTO ")
{
this->body += name;
+ this->insert_col_names(columns);
+ this->insert_values(columns);
}
template <typename... T>
- void execute(const std::tuple<T...>& columns, sqlite3* db)
+ void execute(DatabaseEngine& db, std::tuple<T...>& columns)
{
- auto statement = this->prepare(db);
- {
- this->bind_param(columns, statement);
- if (sqlite3_step(statement.get()) != SQLITE_DONE)
- log_error("Failed to execute query: ", sqlite3_errmsg(db));
- }
+#ifdef DEBUG_SQL_QUERIES
+ const auto timer = this->log_and_time();
+#endif
+
+ auto statement = db.prepare(this->body);
+ this->bind_param(columns, *statement);
+
+ if (statement->step() != StepResult::Error)
+ db.extract_last_insert_rowid(*statement);
+ else
+ log_error("Failed to extract the rowid from the last INSERT");
}
template <int N=0, typename... T>
typename std::enable_if<N < sizeof...(T), void>::type
- bind_param(const std::tuple<T...>& columns, Statement& statement)
+ bind_param(const std::tuple<T...>& columns, Statement& statement, int index=1)
{
- using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
+ auto&& column = std::get<N>(columns);
+ using ColumnType = std::decay_t<decltype(column)>;
- actual_bind<N, ColumnType>(statement, this->params, columns);
- this->bind_param<N+1>(columns, statement);
+ if (!std::is_same<ColumnType, Id>::value)
+ actual_bind(statement, column.value, index++);
+
+ this->bind_param<N+1>(columns, statement, index);
}
template <int N=0, typename... T>
typename std::enable_if<N == sizeof...(T), void>::type
- bind_param(const std::tuple<T...>&, Statement&)
+ bind_param(const std::tuple<T...>&, Statement&, int)
{}
template <typename... T>
@@ -80,18 +80,21 @@ struct InsertQuery: public Query
template <int N=0, typename... T>
typename std::enable_if<N < sizeof...(T), void>::type
- insert_value(const std::tuple<T...>& columns)
+ insert_value(const std::tuple<T...>& columns, int index=1)
{
- this->body += "?";
- if (N != sizeof...(T) - 1)
- this->body += ",";
- this->body += " ";
- add_param(*this, std::get<N>(columns));
- this->insert_value<N+1>(columns);
+ using ColumnType = std::decay_t<decltype(std::get<N>(columns))>;
+
+ if (!std::is_same<ColumnType, Id>::value)
+ {
+ this->body += "$" + std::to_string(index++);
+ if (N != sizeof...(T) - 1)
+ this->body += ", ";
+ }
+ this->insert_value<N+1>(columns, index);
}
template <int N=0, typename... T>
typename std::enable_if<N == sizeof...(T), void>::type
- insert_value(const std::tuple<T...>&)
+ insert_value(const std::tuple<T...>&, const int)
{ }
template <typename... T>
@@ -99,27 +102,28 @@ struct InsertQuery: public Query
{
this->body += " (";
this->insert_col_name(columns);
- this->body += ")\n";
+ this->body += ")";
}
template <int N=0, typename... T>
typename std::enable_if<N < sizeof...(T), void>::type
insert_col_name(const std::tuple<T...>& columns)
{
- using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
+ using ColumnType = std::decay_t<decltype(std::get<N>(columns))>;
- this->body += ColumnType::name;
+ if (!std::is_same<ColumnType, Id>::value)
+ {
+ this->body += ColumnType::name;
- if (N < (sizeof...(T) - 1))
- this->body += ", ";
+ if (N < (sizeof...(T) - 1))
+ this->body += ", ";
+ }
this->insert_col_name<N+1>(columns);
}
+
template <int N=0, typename... T>
typename std::enable_if<N == sizeof...(T), void>::type
insert_col_name(const std::tuple<T...>&)
{}
-
-
- private:
};
diff --git a/src/database/postgresql_engine.cpp b/src/database/postgresql_engine.cpp
new file mode 100644
index 0000000..984a959
--- /dev/null
+++ b/src/database/postgresql_engine.cpp
@@ -0,0 +1,91 @@
+#include <biboumi.h>
+#ifdef PQ_FOUND
+
+#include <utils/scopeguard.hpp>
+
+#include <database/query.hpp>
+
+#include <database/postgresql_engine.hpp>
+
+#include <database/postgresql_statement.hpp>
+
+#include <logger/logger.hpp>
+
+PostgresqlEngine::PostgresqlEngine(PGconn*const conn):
+ conn(conn)
+{}
+
+PostgresqlEngine::~PostgresqlEngine()
+{
+ PQfinish(this->conn);
+}
+
+std::unique_ptr<DatabaseEngine> PostgresqlEngine::open(const std::string& conninfo)
+{
+ PGconn* con = PQconnectdb(conninfo.data());
+
+ if (!con)
+ {
+ log_error("Failed to allocate a Postgresql connection");
+ throw std::runtime_error("");
+ }
+ const auto status = PQstatus(con);
+ if (status != CONNECTION_OK)
+ {
+ const char* errmsg = PQerrorMessage(con);
+ log_error("Postgresql connection failed: ", errmsg);
+ throw std::runtime_error("failed to open connection.");
+ }
+ return std::make_unique<PostgresqlEngine>(con);
+}
+
+std::set<std::string> PostgresqlEngine::get_all_columns_from_table(const std::string& table_name)
+{
+ const auto query = "SELECT column_name from information_schema.columns where table_name='" + table_name + "'";
+ auto statement = this->prepare(query);
+ std::set<std::string> columns;
+
+ while (statement->step() == StepResult::Row)
+ columns.insert(statement->get_column_text(0));
+
+ return columns;
+}
+
+std::tuple<bool, std::string> PostgresqlEngine::raw_exec(const std::string& query)
+{
+#ifdef DEBUG_SQL_QUERIES
+ log_debug("SQL QUERY: ", query);
+ const auto timer = make_sql_timer();
+#endif
+ PGresult* res = PQexec(this->conn, query.data());
+ auto sg = utils::make_scope_guard([res](){
+ PQclear(res);
+ });
+
+ auto res_status = PQresultStatus(res);
+ if (res_status != PGRES_COMMAND_OK)
+ return std::make_tuple(false, PQresultErrorMessage(res));
+ return std::make_tuple(true, std::string{});
+}
+
+std::unique_ptr<Statement> PostgresqlEngine::prepare(const std::string& query)
+{
+ return std::make_unique<PostgresqlStatement>(query, this->conn);
+}
+
+void PostgresqlEngine::extract_last_insert_rowid(Statement& statement)
+{
+ this->last_inserted_rowid = statement.get_column_int64(0);
+}
+
+std::string PostgresqlEngine::get_returning_id_sql_string(const std::string& col_name)
+{
+ return " RETURNING " + col_name;
+}
+
+std::string PostgresqlEngine::id_column_type()
+{
+ return "SERIAL";
+}
+
+#endif
diff --git a/src/database/postgresql_engine.hpp b/src/database/postgresql_engine.hpp
new file mode 100644
index 0000000..fe4fb53
--- /dev/null
+++ b/src/database/postgresql_engine.hpp
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <biboumi.h>
+#include <string>
+#include <stdexcept>
+#include <memory>
+
+#include <database/statement.hpp>
+#include <database/engine.hpp>
+
+#include <tuple>
+#include <set>
+
+#ifdef PQ_FOUND
+
+#include <libpq-fe.h>
+
+class PostgresqlEngine: public DatabaseEngine
+{
+ public:
+ PostgresqlEngine(PGconn*const conn);
+
+ ~PostgresqlEngine();
+
+ static std::unique_ptr<DatabaseEngine> open(const std::string& string);
+
+ std::set<std::string> get_all_columns_from_table(const std::string& table_name) override final;
+ std::tuple<bool, std::string> raw_exec(const std::string& query) override final;
+ std::unique_ptr<Statement> prepare(const std::string& query) override;
+ void extract_last_insert_rowid(Statement& statement) override;
+ std::string get_returning_id_sql_string(const std::string& col_name) override;
+ std::string id_column_type() override;
+private:
+ PGconn* const conn;
+};
+
+#else
+
+class PostgresqlEngine
+{
+public:
+ static std::unique_ptr<DatabaseEngine> open(const std::string& string)
+ {
+ throw std::runtime_error("Cannot open postgresql database "s + string + ": biboumi is not compiled with libpq.");
+ }
+};
+
+#endif
diff --git a/src/database/postgresql_statement.hpp b/src/database/postgresql_statement.hpp
new file mode 100644
index 0000000..571c8f1
--- /dev/null
+++ b/src/database/postgresql_statement.hpp
@@ -0,0 +1,123 @@
+#pragma once
+
+#include <database/statement.hpp>
+
+#include <logger/logger.hpp>
+
+#include <libpq-fe.h>
+
+class PostgresqlStatement: public Statement
+{
+ public:
+ PostgresqlStatement(std::string body, PGconn*const conn):
+ body(std::move(body)),
+ conn(conn)
+ {}
+ ~PostgresqlStatement()
+ {
+ PQclear(this->result);
+ this->result = nullptr;
+ }
+ PostgresqlStatement(const PostgresqlStatement&) = delete;
+ PostgresqlStatement& operator=(const PostgresqlStatement&) = delete;
+ PostgresqlStatement(PostgresqlStatement&& other) = delete;
+ PostgresqlStatement& operator=(PostgresqlStatement&& other) = delete;
+
+ StepResult step() override final
+ {
+ if (!this->executed)
+ {
+ this->current_tuple = 0;
+ this->executed = true;
+ if (!this->execute())
+ return StepResult::Error;
+ }
+ else
+ {
+ this->current_tuple++;
+ }
+ if (this->current_tuple < PQntuples(this->result))
+ return StepResult::Row;
+ return StepResult::Done;
+ }
+
+ int64_t get_column_int64(const int col) override
+ {
+ const char* result = PQgetvalue(this->result, this->current_tuple, col);
+ std::istringstream iss;
+ iss.str(result);
+ int64_t res;
+ iss >> res;
+ return res;
+ }
+ std::string get_column_text(const int col) override
+ {
+ const char* result = PQgetvalue(this->result, this->current_tuple, col);
+ return result;
+ }
+ int get_column_int(const int col) override
+ {
+ const char* result = PQgetvalue(this->result, this->current_tuple, col);
+ std::istringstream iss;
+ iss.str(result);
+ int res;
+ iss >> res;
+ return res;
+ }
+
+ void bind(std::vector<std::string> params) override
+ {
+
+ this->params = std::move(params);
+ }
+
+ bool bind_text(const int, const std::string& data) override
+ {
+ this->params.push_back(data);
+ return true;
+ }
+ bool bind_int64(const int, const std::int64_t value) override
+ {
+ this->params.push_back(std::to_string(value));
+ return true;
+ }
+ bool bind_null(const int) override
+ {
+ this->params.push_back("NULL");
+ return true;
+ }
+
+ private:
+
+private:
+ bool execute()
+ {
+ std::vector<const char*> params;
+ params.reserve(this->params.size());
+
+ for (const auto& param: this->params)
+ params.push_back(param.data());
+ const int param_size = static_cast<int>(this->params.size());
+ this->result = PQexecParams(this->conn, this->body.data(),
+ param_size,
+ nullptr,
+ params.data(),
+ nullptr,
+ nullptr,
+ 0);
+ const auto status = PQresultStatus(this->result);
+ if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK)
+ {
+ log_error("Failed to execute command: ", PQresultErrorMessage(this->result));
+ return false;
+ }
+ return true;
+ }
+
+ bool executed{false};
+ std::string body;
+ PGconn*const conn;
+ std::vector<std::string> params;
+ PGresult* result{nullptr};
+ int current_tuple{0};
+};
diff --git a/src/database/query.cpp b/src/database/query.cpp
index ba63a92..d27dc59 100644
--- a/src/database/query.cpp
+++ b/src/database/query.cpp
@@ -1,9 +1,26 @@
#include <database/query.hpp>
#include <database/column.hpp>
-template <>
-void add_param<Id>(Query&, const Id&)
-{}
+void actual_bind(Statement& statement, const std::string& value, int index)
+{
+ statement.bind_text(index, value);
+}
+
+void actual_bind(Statement& statement, const std::int64_t value, int index)
+{
+ statement.bind_int64(index, value);
+}
+
+void actual_bind(Statement& statement, const OptionalBool& value, int index)
+{
+ if (!value.is_set)
+ statement.bind_int64(index, 0);
+ else if (value.value)
+ statement.bind_int64(index, 1);
+ else
+ statement.bind_int64(index, -1);
+}
+
void actual_add_param(Query& query, const std::string& val)
{
@@ -28,7 +45,8 @@ Query& operator<<(Query& query, const char* str)
Query& operator<<(Query& query, const std::string& str)
{
- query.body += "?";
+ query.body += "$" + std::to_string(query.current_param);
+ query.current_param++;
actual_add_param(query, str);
return query;
}
diff --git a/src/database/query.hpp b/src/database/query.hpp
index 6e1db12..8434944 100644
--- a/src/database/query.hpp
+++ b/src/database/query.hpp
@@ -1,5 +1,7 @@
#pragma once
+#include <biboumi.h>
+
#include <utils/optional_bool.hpp>
#include <database/statement.hpp>
#include <database/column.hpp>
@@ -9,54 +11,53 @@
#include <vector>
#include <string>
-#include <sqlite3.h>
+void actual_bind(Statement& statement, const std::string& value, int index);
+void actual_bind(Statement& statement, const std::int64_t value, int index);
+void actual_bind(Statement& statement, const OptionalBool& value, int index);
+
+#ifdef DEBUG_SQL_QUERIES
+#include <utils/scopetimer.hpp>
+
+inline auto make_sql_timer()
+{
+ return make_scope_timer([](const std::chrono::steady_clock::duration& elapsed)
+ {
+ const auto seconds = std::chrono::duration_cast<std::chrono::seconds>(elapsed);
+ const auto rest = elapsed - seconds;
+ log_debug("Query executed in ", seconds.count(), ".", rest.count(), "s.");
+ });
+}
+#endif
struct Query
{
std::string body;
std::vector<std::string> params;
+ int current_param{1};
Query(std::string str):
body(std::move(str))
{}
- Statement prepare(sqlite3* db)
+#ifdef DEBUG_SQL_QUERIES
+ auto log_and_time()
{
- sqlite3_stmt* stmt;
- auto res = sqlite3_prepare(db, this->body.data(), static_cast<int>(this->body.size()) + 1,
- &stmt, nullptr);
- if (res != SQLITE_OK)
- {
- log_error("Error preparing statement: ", sqlite3_errmsg(db));
- return nullptr;
- }
- Statement statement(stmt);
- int i = 1;
- for (const std::string& param: this->params)
- {
- if (sqlite3_bind_text(statement.get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK)
- log_error("Failed to bind ", param, " to param ", i);
- i++;
- }
-
- return statement;
- }
-
- void execute(sqlite3* db)
- {
- auto statement = this->prepare(db);
- while (sqlite3_step(statement.get()) != SQLITE_DONE)
- ;
+ std::ostringstream os;
+ os << this->body << "; ";
+ for (const auto& param: this->params)
+ os << "'" << param << "' ";
+ log_debug("SQL QUERY: ", os.str());
+ return make_sql_timer();
}
+#endif
};
template <typename ColumnType>
void add_param(Query& query, const ColumnType& column)
{
+ std::cout << "add_param<ColumnType>" << std::endl;
actual_add_param(query, column.value);
}
-template <>
-void add_param<Id>(Query& query, const Id& column);
template <typename T>
void actual_add_param(Query& query, const T& val)
@@ -81,7 +82,8 @@ template <typename Integer>
typename std::enable_if<std::is_integral<Integer>::value, Query&>::type
operator<<(Query& query, const Integer& i)
{
- query.body += "?";
+ query.body += "$" + std::to_string(query.current_param++);
actual_add_param(query, i);
return query;
}
+
diff --git a/src/database/row.hpp b/src/database/row.hpp
index 2b50874..4dc98be 100644
--- a/src/database/row.hpp
+++ b/src/database/row.hpp
@@ -1,72 +1,72 @@
#pragma once
#include <database/insert_query.hpp>
+#include <database/update_query.hpp>
#include <logger/logger.hpp>
-#include <type_traits>
-
-#include <sqlite3.h>
+#include <utils/is_one_of.hpp>
-template <typename ColumnType, typename... T>
-typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
-update_id(std::tuple<T...>&, sqlite3*)
-{}
+#include <type_traits>
-template <typename ColumnType, typename... T>
-typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
-update_id(std::tuple<T...>& columns, sqlite3* db)
+template <typename... T>
+struct Row
{
- auto&& column = std::get<ColumnType>(columns);
- auto res = sqlite3_last_insert_rowid(db);
- column.value = static_cast<Id::real_type>(res);
-}
+ Row(std::string name):
+ table_name(std::move(name))
+ {}
-template <std::size_t N=0, typename... T>
-typename std::enable_if<N < sizeof...(T), void>::type
-update_autoincrement_id(std::tuple<T...>& columns, sqlite3* db)
-{
- using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
- update_id<ColumnType>(columns, db);
- update_autoincrement_id<N+1>(columns, db);
-}
+ template <typename Type>
+ typename Type::real_type& col()
+ {
+ auto&& col = std::get<Type>(this->columns);
+ return col.value;
+ }
-template <std::size_t N=0, typename... T>
-typename std::enable_if<N == sizeof...(T), void>::type
-update_autoincrement_id(std::tuple<T...>&, sqlite3*)
-{}
+ template <typename Type>
+ const auto& col() const
+ {
+ auto&& col = std::get<Type>(this->columns);
+ return col.value;
+ }
-template <typename... T>
-struct Row
-{
- Row(std::string name):
- table_name(std::move(name))
- {}
+ template <bool Coucou=true>
+ void save(std::unique_ptr<DatabaseEngine>& db, typename std::enable_if<!is_one_of<Id, T...> && Coucou>::type* = nullptr)
+ {
+ this->insert(*db);
+ }
- template <typename Type>
- typename Type::real_type& col()
- {
- auto&& col = std::get<Type>(this->columns);
- return col.value;
- }
+ template <bool Coucou=true>
+ void save(std::unique_ptr<DatabaseEngine>& db, typename std::enable_if<is_one_of<Id, T...> && Coucou>::type* = nullptr)
+ {
+ const Id& id = std::get<Id>(this->columns);
+ if (id.value == Id::unset_value)
+ {
+ this->insert(*db);
+ if (db->last_inserted_rowid >= 0)
+ std::get<Id>(this->columns).value = static_cast<Id::real_type>(db->last_inserted_rowid);
+ }
+ else
+ this->update(*db);
+ }
- template <typename Type>
- const auto& col() const
- {
- auto&& col = std::get<Type>(this->columns);
- return col.value;
- }
+ private:
+ void insert(DatabaseEngine& db)
+ {
+ InsertQuery query(this->table_name, this->columns);
+ // Ugly workaround for non portable stuff
+ query.body += db.get_returning_id_sql_string(Id::name);
+ query.execute(db, this->columns);
+ }
- void save(sqlite3* db)
- {
- InsertQuery query(this->table_name);
- query.insert_col_names(this->columns);
- query.insert_values(this->columns);
+ void update(DatabaseEngine& db)
+ {
+ UpdateQuery query(this->table_name, this->columns);
- query.execute(this->columns, db);
+ query.execute(db, this->columns);
+ }
- update_autoincrement_id(this->columns, db);
- }
+public:
+ std::tuple<T...> columns;
+ std::string table_name;
- std::tuple<T...> columns;
- std::string table_name;
};
diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp
index 872001c..5a17f38 100644
--- a/src/database/select_query.hpp
+++ b/src/database/select_query.hpp
@@ -1,5 +1,7 @@
#pragma once
+#include <database/engine.hpp>
+
#include <database/statement.hpp>
#include <database/query.hpp>
#include <logger/logger.hpp>
@@ -10,32 +12,27 @@
#include <vector>
#include <string>
-#include <sqlite3.h>
-
using namespace std::string_literals;
template <typename T>
-typename std::enable_if<std::is_integral<T>::value, sqlite3_int64>::type
+typename std::enable_if<std::is_integral<T>::value, std::int64_t>::type
extract_row_value(Statement& statement, const int i)
{
- return sqlite3_column_int64(statement.get(), i);
+ return statement.get_column_int64(i);
}
template <typename T>
typename std::enable_if<std::is_same<std::string, T>::value, T>::type
extract_row_value(Statement& statement, const int i)
{
- const auto size = sqlite3_column_bytes(statement.get(), i);
- const unsigned char* str = sqlite3_column_text(statement.get(), i);
- std::string result(reinterpret_cast<const char*>(str), static_cast<std::size_t>(size));
- return result;
+ return statement.get_column_text(i);
}
template <typename T>
typename std::enable_if<std::is_same<OptionalBool, T>::value, T>::type
extract_row_value(Statement& statement, const int i)
{
- const auto integer = sqlite3_column_int(statement.get(), i);
+ const auto integer = statement.get_column_int(i);
OptionalBool result;
if (integer > 0)
result.set_value(true);
@@ -109,16 +106,24 @@ struct SelectQuery: public Query
return *this;
}
- auto execute(sqlite3* db)
+ auto execute(DatabaseEngine& db)
{
- auto statement = this->prepare(db);
std::vector<Row<T...>> rows;
- while (sqlite3_step(statement.get()) == SQLITE_ROW)
+
+#ifdef DEBUG_SQL_QUERIES
+ const auto timer = this->log_and_time();
+#endif
+
+ auto statement = db.prepare(this->body);
+ statement->bind(std::move(this->params));
+
+ while (statement->step() == StepResult::Row)
{
Row<T...> row(this->table_name);
- extract_row_values(row, statement);
+ extract_row_values(row, *statement);
rows.push_back(row);
}
+
return rows;
}
diff --git a/src/database/sqlite3_engine.cpp b/src/database/sqlite3_engine.cpp
new file mode 100644
index 0000000..ae4a146
--- /dev/null
+++ b/src/database/sqlite3_engine.cpp
@@ -0,0 +1,101 @@
+#include <biboumi.h>
+
+#ifdef SQLITE3_FOUND
+
+#include <database/sqlite3_engine.hpp>
+
+#include <database/sqlite3_statement.hpp>
+
+#include <database/query.hpp>
+
+#include <utils/tolower.hpp>
+#include <logger/logger.hpp>
+#include <vector>
+
+Sqlite3Engine::Sqlite3Engine(sqlite3* db):
+ db(db)
+{
+}
+
+Sqlite3Engine::~Sqlite3Engine()
+{
+ sqlite3_close(this->db);
+}
+
+std::set<std::string> Sqlite3Engine::get_all_columns_from_table(const std::string& table_name)
+{
+ std::set<std::string> result;
+ char* errmsg;
+ std::string query{"PRAGMA table_info(" + table_name + ")"};
+ int res = sqlite3_exec(this->db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int {
+ constexpr int name_column = 1;
+ std::set<std::string>* result = static_cast<std::set<std::string>*>(param);
+ if (name_column < columns_nb)
+ result->insert(utils::tolower(columns[name_column]));
+ return 0;
+ }, &result, &errmsg);
+
+ if (res != SQLITE_OK)
+ {
+ log_error("Error executing ", query, ": ", errmsg);
+ sqlite3_free(errmsg);
+ }
+
+ return result;
+}
+
+std::unique_ptr<DatabaseEngine> Sqlite3Engine::open(const std::string& filename)
+{
+ sqlite3* new_db;
+ auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr);
+ if (res != SQLITE_OK)
+ {
+ log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(new_db));
+ sqlite3_close(new_db);
+ throw std::runtime_error("");
+ }
+ return std::make_unique<Sqlite3Engine>(new_db);
+}
+
+std::tuple<bool, std::string> Sqlite3Engine::raw_exec(const std::string& query)
+{
+#ifdef DEBUG_SQL_QUERIES
+ log_debug("SQL QUERY: ", query);
+ const auto timer = make_sql_timer();
+#endif
+
+ char* error;
+ const auto result = sqlite3_exec(db, query.data(), nullptr, nullptr, &error);
+ if (result != SQLITE_OK)
+ {
+ std::string err_msg(error);
+ sqlite3_free(error);
+ return std::make_tuple(false, err_msg);
+ }
+ return std::make_tuple(true, std::string{});
+}
+
+std::unique_ptr<Statement> Sqlite3Engine::prepare(const std::string& query)
+{
+ sqlite3_stmt* stmt;
+ auto res = sqlite3_prepare(db, query.data(), static_cast<int>(query.size()) + 1,
+ &stmt, nullptr);
+ if (res != SQLITE_OK)
+ {
+ log_error("Error preparing statement: ", sqlite3_errmsg(db));
+ return nullptr;
+ }
+ return std::make_unique<Sqlite3Statement>(stmt);
+}
+
+void Sqlite3Engine::extract_last_insert_rowid(Statement&)
+{
+ this->last_inserted_rowid = sqlite3_last_insert_rowid(this->db);
+}
+
+std::string Sqlite3Engine::id_column_type()
+{
+ return "INTEGER PRIMARY KEY AUTOINCREMENT";
+}
+
+#endif
diff --git a/src/database/sqlite3_engine.hpp b/src/database/sqlite3_engine.hpp
new file mode 100644
index 0000000..5b8176c
--- /dev/null
+++ b/src/database/sqlite3_engine.hpp
@@ -0,0 +1,47 @@
+#pragma once
+
+#include <database/engine.hpp>
+
+#include <database/statement.hpp>
+
+#include <memory>
+#include <string>
+#include <tuple>
+#include <set>
+
+#include <biboumi.h>
+
+#ifdef SQLITE3_FOUND
+
+#include <sqlite3.h>
+
+class Sqlite3Engine: public DatabaseEngine
+{
+ public:
+ Sqlite3Engine(sqlite3* db);
+
+ ~Sqlite3Engine();
+
+ static std::unique_ptr<DatabaseEngine> open(const std::string& string);
+
+ std::set<std::string> get_all_columns_from_table(const std::string& table_name) override final;
+ std::tuple<bool, std::string> raw_exec(const std::string& query) override final;
+ std::unique_ptr<Statement> prepare(const std::string& query) override;
+ void extract_last_insert_rowid(Statement& statement) override;
+ std::string id_column_type() override;
+private:
+ sqlite3* const db;
+};
+
+#else
+
+class Sqlite3Engine
+{
+public:
+ static std::unique_ptr<DatabaseEngine> open(const std::string& string)
+ {
+ throw std::runtime_error("Cannot open sqlite3 database "s + string + ": biboumi is not compiled with sqlite3 lib.");
+ }
+};
+
+#endif
diff --git a/src/database/sqlite3_statement.hpp b/src/database/sqlite3_statement.hpp
new file mode 100644
index 0000000..7738fa6
--- /dev/null
+++ b/src/database/sqlite3_statement.hpp
@@ -0,0 +1,92 @@
+#pragma once
+
+#include <database/statement.hpp>
+
+#include <logger/logger.hpp>
+
+#include <sqlite3.h>
+
+class Sqlite3Statement: public Statement
+{
+ public:
+ Sqlite3Statement(sqlite3_stmt* stmt):
+ stmt(stmt) {}
+ ~Sqlite3Statement()
+ {
+ sqlite3_finalize(this->stmt);
+ }
+
+ StepResult step() override final
+ {
+ auto res = sqlite3_step(this->get());
+ if (res == SQLITE_ROW)
+ return StepResult::Row;
+ else if (res == SQLITE_DONE)
+ return StepResult::Done;
+ else
+ return StepResult::Error;
+ }
+
+ void bind(std::vector<std::string> params) override
+ {
+ int i = 1;
+ for (const std::string& param: params)
+ {
+ if (sqlite3_bind_text(this->get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK)
+ log_error("Failed to bind ", param, " to param ", i);
+ i++;
+ }
+ }
+
+ int64_t get_column_int64(const int col) override
+ {
+ return sqlite3_column_int64(this->get(), col);
+ }
+
+ std::string get_column_text(const int col) override
+ {
+ const auto size = sqlite3_column_bytes(this->get(), col);
+ const unsigned char* str = sqlite3_column_text(this->get(), col);
+ std::string result(reinterpret_cast<const char*>(str), static_cast<std::size_t>(size));
+ return result;
+ }
+
+ bool bind_text(const int pos, const std::string& data) override
+ {
+ return sqlite3_bind_text(this->get(), pos, data.data(), static_cast<int>(data.size()), SQLITE_TRANSIENT) == SQLITE_OK;
+ }
+ bool bind_int64(const int pos, const std::int64_t value) override
+ {
+ return sqlite3_bind_int64(this->get(), pos, static_cast<sqlite3_int64>(value)) == SQLITE_OK;
+ }
+ bool bind_null(const int pos) override
+ {
+ return sqlite3_bind_null(this->get(), pos) == SQLITE_OK;
+ }
+ int get_column_int(const int col) override
+ {
+ return sqlite3_column_int(this->get(), col);
+ }
+
+ Sqlite3Statement(const Sqlite3Statement&) = delete;
+ Sqlite3Statement& operator=(const Sqlite3Statement&) = delete;
+ Sqlite3Statement(Sqlite3Statement&& other):
+ stmt(other.stmt)
+ {
+ other.stmt = nullptr;
+ }
+ Sqlite3Statement& operator=(Sqlite3Statement&& other)
+ {
+ this->stmt = other.stmt;
+ other.stmt = nullptr;
+ return *this;
+ }
+ sqlite3_stmt* get()
+ {
+ return this->stmt;
+ }
+
+ private:
+ sqlite3_stmt* stmt;
+ int last_step_result{SQLITE_OK};
+};
diff --git a/src/database/statement.hpp b/src/database/statement.hpp
index 87cd70f..4a61928 100644
--- a/src/database/statement.hpp
+++ b/src/database/statement.hpp
@@ -1,35 +1,29 @@
#pragma once
-#include <sqlite3.h>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+enum class StepResult
+{
+ Row,
+ Done,
+ Error,
+};
class Statement
{
public:
- Statement(sqlite3_stmt* stmt):
- stmt(stmt) {}
- ~Statement()
- {
- sqlite3_finalize(this->stmt);
- }
+ virtual ~Statement() = default;
+ virtual StepResult step() = 0;
+
+ virtual void bind(std::vector<std::string> params) = 0;
- Statement(const Statement&) = delete;
- Statement& operator=(const Statement&) = delete;
- Statement(Statement&& other):
- stmt(other.stmt)
- {
- other.stmt = nullptr;
- }
- Statement& operator=(Statement&& other)
- {
- this->stmt = other.stmt;
- other.stmt = nullptr;
- return *this;
- }
- sqlite3_stmt* get()
- {
- return this->stmt;
- }
+ virtual std::int64_t get_column_int64(const int col) = 0;
+ virtual std::string get_column_text(const int col) = 0;
+ virtual int get_column_int(const int col) = 0;
- private:
- sqlite3_stmt* stmt;
+ virtual bool bind_text(const int pos, const std::string& data) = 0;
+ virtual bool bind_int64(const int pos, const std::int64_t value) = 0;
+ virtual bool bind_null(const int pos) = 0;
};
diff --git a/src/database/table.cpp b/src/database/table.cpp
deleted file mode 100644
index 9224d79..0000000
--- a/src/database/table.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-#include <database/table.hpp>
-
-std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name)
-{
- std::set<std::string> result;
- char* errmsg;
- std::string query{"PRAGMA table_info(" + table_name + ")"};
- int res = sqlite3_exec(db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int {
- constexpr int name_column = 1;
- std::set<std::string>* result = static_cast<std::set<std::string>*>(param);
- if (name_column < columns_nb)
- result->insert(columns[name_column]);
- return 0;
- }, &result, &errmsg);
-
- if (res != SQLITE_OK)
- {
- log_error("Error executing ", query, ": ", errmsg);
- sqlite3_free(errmsg);
- }
-
- return result;
-}
diff --git a/src/database/table.hpp b/src/database/table.hpp
index 0060211..680e7cc 100644
--- a/src/database/table.hpp
+++ b/src/database/table.hpp
@@ -1,7 +1,8 @@
#pragma once
+#include <database/engine.hpp>
+
#include <database/select_query.hpp>
-#include <database/type_to_sql.hpp>
#include <database/row.hpp>
#include <algorithm>
@@ -10,23 +11,27 @@
using namespace std::string_literals;
-std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name);
+template <typename T>
+std::string ToSQLType(DatabaseEngine& db)
+{
+ if (std::is_same<T, Id>::value)
+ return db.id_column_type();
+ else if (std::is_same<typename T::real_type, std::string>::value)
+ return "TEXT";
+ else
+ return "INTEGER";
+}
template <typename ColumnType>
-void add_column_to_table(sqlite3* db, const std::string& table_name)
+void add_column_to_table(DatabaseEngine& db, const std::string& table_name)
{
const std::string name = ColumnType::name;
- std::string query{"ALTER TABLE " + table_name + " ADD " + ColumnType::name + " " + TypeToSQLType<typename ColumnType::real_type>::type};
- char* error;
- const auto result = sqlite3_exec(db, query.data(), nullptr, nullptr, &error);
- if (result != SQLITE_OK)
- {
- log_error("Error adding column ", name, " to table ", table_name, ": ", error);
- sqlite3_free(error);
- }
+ std::string query{"ALTER TABLE " + table_name + " ADD " + ColumnType::name + " " + ToSQLType<ColumnType>(db)};
+ auto res = db.raw_exec(query);
+ if (std::get<0>(res) == false)
+ log_error("Error adding column ", name, " to table ", table_name, ": ", std::get<1>(res));
}
-
template <typename ColumnType, decltype(ColumnType::options) = nullptr>
void append_option(std::string& s)
{
@@ -50,27 +55,23 @@ class Table
name(std::move(name))
{}
- void upgrade(sqlite3* db)
+ void upgrade(DatabaseEngine& db)
{
- const auto existing_columns = get_all_columns_from_table(db, this->name);
+ const auto existing_columns = db.get_all_columns_from_table(this->name);
add_column_if_not_exists(db, existing_columns);
}
- void create(sqlite3* db)
+ void create(DatabaseEngine& db)
{
- std::string res{"CREATE TABLE IF NOT EXISTS "};
- res += this->name;
- res += " (\n";
- this->add_column_create(res);
- res += ")";
-
- char* error;
- const auto result = sqlite3_exec(db, res.data(), nullptr, nullptr, &error);
- if (result != SQLITE_OK)
- {
- log_error("Error executing query: ", error);
- sqlite3_free(error);
- }
+ std::string query{"CREATE TABLE IF NOT EXISTS "};
+ query += this->name;
+ query += " (";
+ this->add_column_create(db, query);
+ query += ")";
+
+ auto result = db.raw_exec(query);
+ if (std::get<0>(result) == false)
+ log_error("Error executing query: ", std::get<1>(result));
}
RowType row()
@@ -78,7 +79,7 @@ class Table
return {this->name};
}
- SelectQuery<T...> select()
+ auto select()
{
SelectQuery<T...> select(this->name);
return select;
@@ -93,39 +94,34 @@ class Table
template <std::size_t N=0>
typename std::enable_if<N < sizeof...(T), void>::type
- add_column_if_not_exists(sqlite3* db, const std::set<std::string>& existing_columns)
+ add_column_if_not_exists(DatabaseEngine& db, const std::set<std::string>& existing_columns)
{
using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type;
- if (existing_columns.count(ColumnType::name) != 1)
- {
- add_column_to_table<ColumnType>(db, this->name);
- }
+ if (existing_columns.count(ColumnType::name) == 0)
+ add_column_to_table<ColumnType>(db, this->name);
add_column_if_not_exists<N+1>(db, existing_columns);
}
template <std::size_t N=0>
typename std::enable_if<N == sizeof...(T), void>::type
- add_column_if_not_exists(sqlite3*, const std::set<std::string>&)
+ add_column_if_not_exists(DatabaseEngine&, const std::set<std::string>&)
{}
template <std::size_t N=0>
typename std::enable_if<N < sizeof...(T), void>::type
- add_column_create(std::string& str)
+ add_column_create(DatabaseEngine& db, std::string& str)
{
using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type;
- using RealType = typename ColumnType::real_type;
str += ColumnType::name;
str += " ";
- str += TypeToSQLType<RealType>::type;
- append_option<ColumnType>(str);
+ str += ToSQLType<ColumnType>(db);
if (N != sizeof...(T) - 1)
str += ",";
- str += "\n";
- add_column_create<N+1>(str);
+ add_column_create<N+1>(db, str);
}
template <std::size_t N=0>
typename std::enable_if<N == sizeof...(T), void>::type
- add_column_create(std::string&)
+ add_column_create(DatabaseEngine&, std::string&)
{ }
const std::string name;
diff --git a/src/database/type_to_sql.cpp b/src/database/type_to_sql.cpp
deleted file mode 100644
index bcd9daa..0000000
--- a/src/database/type_to_sql.cpp
+++ /dev/null
@@ -1,9 +0,0 @@
-#include <database/type_to_sql.hpp>
-
-template <> const std::string TypeToSQLType<int>::type = "INTEGER";
-template <> const std::string TypeToSQLType<std::size_t>::type = "INTEGER";
-template <> const std::string TypeToSQLType<long>::type = "INTEGER";
-template <> const std::string TypeToSQLType<long long>::type = "INTEGER";
-template <> const std::string TypeToSQLType<bool>::type = "INTEGER";
-template <> const std::string TypeToSQLType<std::string>::type = "TEXT";
-template <> const std::string TypeToSQLType<OptionalBool>::type = "INTEGER"; \ No newline at end of file
diff --git a/src/database/type_to_sql.hpp b/src/database/type_to_sql.hpp
deleted file mode 100644
index ba806ab..0000000
--- a/src/database/type_to_sql.hpp
+++ /dev/null
@@ -1,16 +0,0 @@
-#pragma once
-
-#include <utils/optional_bool.hpp>
-
-#include <string>
-
-template <typename T>
-struct TypeToSQLType { static const std::string type; };
-
-template <> const std::string TypeToSQLType<int>::type;
-template <> const std::string TypeToSQLType<std::size_t>::type;
-template <> const std::string TypeToSQLType<long>::type;
-template <> const std::string TypeToSQLType<long long>::type;
-template <> const std::string TypeToSQLType<bool>::type;
-template <> const std::string TypeToSQLType<std::string>::type;
-template <> const std::string TypeToSQLType<OptionalBool>::type; \ No newline at end of file
diff --git a/src/database/update_query.hpp b/src/database/update_query.hpp
new file mode 100644
index 0000000..a29ac3f
--- /dev/null
+++ b/src/database/update_query.hpp
@@ -0,0 +1,104 @@
+#pragma once
+
+#include <database/query.hpp>
+#include <database/engine.hpp>
+
+using namespace std::string_literals;
+
+template <class T, class... Tuple>
+struct Index;
+
+template <class T, class... Types>
+struct Index<T, std::tuple<T, Types...>>
+{
+ static const std::size_t value = 0;
+};
+
+template <class T, class U, class... Types>
+struct Index<T, std::tuple<U, Types...>>
+{
+ static const std::size_t value = Index<T, std::tuple<Types...>>::value + 1;
+};
+
+struct UpdateQuery: public Query
+{
+ template <typename... T>
+ UpdateQuery(const std::string& name, const std::tuple<T...>& columns):
+ Query("UPDATE ")
+ {
+ this->body += name;
+ this->insert_col_names_and_values(columns);
+ }
+
+ template <typename... T>
+ void insert_col_names_and_values(const std::tuple<T...>& columns)
+ {
+ this->body += " SET ";
+ this->insert_col_name_and_value(columns);
+ this->body += " WHERE "s + Id::name + "=$" + std::to_string(this->current_param);
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ insert_col_name_and_value(const std::tuple<T...>& columns)
+ {
+ using ColumnType = std::decay_t<decltype(std::get<N>(columns))>;
+
+ if (!std::is_same<ColumnType, Id>::value)
+ {
+ this->body += ColumnType::name + "=$"s + std::to_string(this->current_param);
+ this->current_param++;
+
+ if (N < (sizeof...(T) - 1))
+ this->body += ", ";
+ }
+
+ this->insert_col_name_and_value<N+1>(columns);
+ }
+ template <int N=0, typename... T>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ insert_col_name_and_value(const std::tuple<T...>&)
+ {}
+
+
+ template <typename... T>
+ void execute(DatabaseEngine& db, const std::tuple<T...>& columns)
+ {
+#ifdef DEBUG_SQL_QUERIES
+ const auto timer = this->log_and_time();
+#endif
+
+ auto statement = db.prepare(this->body);
+ this->bind_param(columns, *statement);
+ this->bind_id(columns, *statement);
+
+ statement->step();
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ bind_param(const std::tuple<T...>& columns, Statement& statement, int index=1)
+ {
+ auto&& column = std::get<N>(columns);
+ using ColumnType = std::decay_t<decltype(column)>;
+
+ if (!std::is_same<ColumnType, Id>::value)
+ actual_bind(statement, column.value, index++);
+
+ this->bind_param<N+1>(columns, statement, index);
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ bind_param(const std::tuple<T...>&, Statement&, int)
+ {}
+
+ template <typename... T>
+ void bind_id(const std::tuple<T...>& columns, Statement& statement)
+ {
+ static constexpr auto index = Index<Id, std::tuple<T...>>::value;
+ auto&& value = std::get<index>(columns);
+
+ actual_bind(statement, value.value, sizeof...(T));
+ }
+};
diff --git a/src/identd/identd_socket.cpp b/src/identd/identd_socket.cpp
index b85257c..92cd80b 100644
--- a/src/identd/identd_socket.cpp
+++ b/src/identd/identd_socket.cpp
@@ -50,14 +50,14 @@ std::string IdentdSocket::generate_answer(const BiboumiComponent& biboumi, uint1
if (pair.second->match_port_pairt(local, remote))
{
std::ostringstream os;
- os << local << " , " << remote << " : USERID : OTHER : " << hash_jid(bridge->get_bare_jid());
+ os << local << " , " << remote << " : USERID : OTHER : " << hash_jid(bridge->get_bare_jid()) << "\r\n";
log_debug("Identd, sending: ", os.str());
return os.str();
}
}
}
std::ostringstream os;
- os << local << " , " << remote << " ERROR : NO-USER";
+ os << local << " , " << remote << " ERROR : NO-USER" << "\r\n";
log_debug("Identd, sending: ", os.str());
return os.str();
}
diff --git a/src/irc/irc_client.cpp b/src/irc/irc_client.cpp
index 46dbdbe..40078d9 100644
--- a/src/irc/irc_client.cpp
+++ b/src/irc/irc_client.cpp
@@ -483,12 +483,16 @@ bool IrcClient::send_channel_message(const std::string& chan_name, const std::st
}
// The max size is 512, taking into account the whole message, not just
// the text we send.
- // This includes our own nick, username and host (because this will be
- // added by the server into our message), in addition to the basic
- // components of the message we send (command name, chan name, \r\n et)
+ // This includes our own nick, constants for username and host (because these
+ // are notoriously hard to know what the server will use), in addition to the basic
+ // components of the message we send (command name, chan name, \r\n etc.)
// : + NICK + ! + USER + @ + HOST + <space> + PRIVMSG + <space> + CHAN + <space> + : + \r\n
+ // 63 is the maximum hostname length defined by the protocol. 10 seems to be
+ // the username limit.
+ constexpr auto max_username_size = 10;
+ constexpr auto max_hostname_size = 63;
const auto line_size = 512 -
- this->current_nick.size() - this->username.size() - this->own_host.size() -
+ this->current_nick.size() - max_username_size - max_hostname_size -
::strlen(":!@ PRIVMSG ") - chan_name.length() - ::strlen(" :\r\n");
const auto lines = cut(body, line_size);
for (const auto& line: lines)
@@ -784,7 +788,7 @@ void IrcClient::on_channel_completely_joined(const IrcMessage& message)
channel->joined = true;
this->bridge.send_user_join(this->hostname, chan_name, channel->get_self(),
channel->get_self()->get_most_significant_mode(this->sorted_user_modes), true);
- this->bridge.send_room_history(this->hostname, chan_name);
+ this->bridge.send_room_history(this->hostname, chan_name, this->history_limit);
this->bridge.send_topic(this->hostname, chan_name, channel->topic, channel->topic_author);
}
@@ -1017,19 +1021,17 @@ void IrcClient::on_quit(const IrcMessage& message)
const std::string& chan_name = pair.first;
IrcChannel* channel = pair.second.get();
const IrcUser* user = channel->find_user(message.prefix);
+ if (!user)
+ continue;
bool self = false;
if (user == channel->get_self())
self = true;
- if (user)
- {
- std::string nick = user->nick;
- channel->remove_user(user);
- Iid iid;
- iid.set_local(chan_name);
- iid.set_server(this->hostname);
- iid.type = Iid::Type::Channel;
- this->bridge.send_muc_leave(iid, std::move(nick), txt, self, false);
- }
+ Iid iid;
+ iid.set_local(chan_name);
+ iid.set_server(this->hostname);
+ iid.type = Iid::Type::Channel;
+ this->bridge.send_muc_leave(iid, user->nick, txt, self, false);
+ channel->remove_user(user);
}
}
@@ -1073,12 +1075,18 @@ void IrcClient::on_nick(const IrcMessage& message)
void IrcClient::on_kick(const IrcMessage& message)
{
const std::string chan_name = utils::tolower(message.arguments[0]);
- const std::string target = message.arguments[1];
+ const std::string target_nick = message.arguments[1];
const std::string reason = message.arguments[2];
IrcChannel* channel = this->get_channel(chan_name);
if (!channel->joined)
return ;
- const bool self = channel->get_self()->nick == target;
+ const IrcUser* target = channel->find_user(target_nick);
+ if (!target)
+ {
+ log_warning("Received a KICK command from a nick absent from the channel.");
+ return;
+ }
+ const bool self = channel->get_self() == target;
if (self)
channel->joined = false;
IrcUser author(message.prefix);
@@ -1086,7 +1094,8 @@ void IrcClient::on_kick(const IrcMessage& message)
iid.set_local(chan_name);
iid.set_server(this->hostname);
iid.type = Iid::Type::Channel;
- this->bridge.kick_muc_user(std::move(iid), target, reason, author.nick, self);
+ this->bridge.kick_muc_user(std::move(iid), target_nick, reason, author.nick, self);
+ channel->remove_user(target);
}
void IrcClient::on_invite(const IrcMessage& message)
diff --git a/src/irc/irc_client.hpp b/src/irc/irc_client.hpp
index aec6cd9..de5c520 100644
--- a/src/irc/irc_client.hpp
+++ b/src/irc/irc_client.hpp
@@ -5,6 +5,8 @@
#include <irc/irc_channel.hpp>
#include <irc/iid.hpp>
+#include <bridge/history_limit.hpp>
+
#include <network/tcp_client_socket_handler.hpp>
#include <network/resolver.hpp>
@@ -296,6 +298,11 @@ public:
const std::vector<char>& get_sorted_user_modes() const { return this->sorted_user_modes; }
std::set<char> get_chantypes() const { return this->chantypes; }
+
+ /**
+ * Store the history limit that the client asked when joining this room.
+ */
+ HistoryLimit history_limit;
private:
/**
* The hostname of the server we are connected to.
diff --git a/src/logger/logger.cpp b/src/logger/logger.cpp
index 92a3d9b..482cb18 100644
--- a/src/logger/logger.cpp
+++ b/src/logger/logger.cpp
@@ -1,12 +1,35 @@
#include <logger/logger.hpp>
#include <config/config.hpp>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
Logger::Logger(const int log_level):
log_level(log_level),
stream(std::cout.rdbuf()),
null_buffer{},
null_stream{&null_buffer}
{
+#ifdef SYSTEMD_FOUND
+ if (!this->use_stdout())
+ return;
+
+ // See https://www.freedesktop.org/software/systemd/man/systemd.exec.html#%24JOURNAL_STREAM
+ const char* journal_stream = ::getenv("JOURNAL_STREAM");
+ if (journal_stream == nullptr)
+ return;
+
+ struct stat s{};
+ const int res = ::fstat(STDOUT_FILENO, &s);
+ if (res == -1)
+ return;
+
+ const auto stdout_stream = std::to_string(s.st_dev) + ":" + std::to_string(s.st_ino);
+
+ if (stdout_stream == journal_stream)
+ this->use_systemd = true;
+#endif
}
Logger::Logger(const int log_level, const std::string& log_file):
diff --git a/src/logger/logger.hpp b/src/logger/logger.hpp
index ff6a82b..315fc11 100644
--- a/src/logger/logger.hpp
+++ b/src/logger/logger.hpp
@@ -9,8 +9,10 @@
*/
#include <memory>
+#include <string>
#include <iostream>
#include <fstream>
+#include <sstream>
#define debug_lvl 0
#define info_lvl 1
@@ -19,12 +21,18 @@
#include "biboumi.h"
#ifdef SYSTEMD_FOUND
+#define SD_JOURNAL_SUPPRESS_LOCATION
# include <systemd/sd-daemon.h>
+# include <systemd/sd-journal.h>
#else
# define SD_DEBUG "[DEBUG]: "
# define SD_INFO "[INFO]: "
# define SD_WARNING "[WARNING]: "
# define SD_ERR "[ERROR]: "
+# define LOG_ERR 3
+# define LOG_WARNING 4
+# define LOG_INFO 6
+# define LOG_DEBUG 7
#endif
// Macro defined to get the filename instead of the full path. But if it is
@@ -57,8 +65,17 @@ public:
Logger(Logger&&) = delete;
Logger& operator=(Logger&&) = delete;
-private:
+#ifdef SYSTEMD_FOUND
+ bool use_stdout() const
+ {
+ return this->stream.rdbuf() == std::cout.rdbuf();
+ }
+
+ bool use_systemd{false};
+#endif
+
const int log_level;
+private:
std::ofstream ofstream{};
std::ostream stream;
@@ -66,8 +83,6 @@ private:
std::ostream null_stream;
};
-#define WHERE __FILENAME__, ":", __LINE__, ":\t"
-
namespace logging_details
{
template <typename T>
@@ -84,45 +99,41 @@ namespace logging_details
}
template <typename... U>
- void log_debug(U&&... args)
- {
- auto& os = Logger::instance()->get_stream(debug_lvl);
- os << SD_DEBUG;
- log(os, std::forward<U>(args)...);
- }
-
- template <typename... U>
- void log_info(U&&... args)
+ void do_logging(const int level, int syslog_level, const char* src_file, int line, U&&... args)
{
- auto& os = Logger::instance()->get_stream(info_lvl);
- os << SD_INFO;
- log(os, std::forward<U>(args)...);
- }
-
- template <typename... U>
- void log_warning(U&&... args)
- {
- auto& os = Logger::instance()->get_stream(warning_lvl);
- os << SD_WARNING;
- log(os, std::forward<U>(args)...);
- }
-
- template <typename... U>
- void log_error(U&&... args)
- {
- auto& os = Logger::instance()->get_stream(error_lvl);
- os << SD_ERR;
- log(os, std::forward<U>(args)...);
+ #ifdef SYSTEMD_FOUND
+ if (Logger::instance()->use_systemd)
+ {
+ (void)level;
+ if (level >= Logger::instance()->log_level)
+ {
+ std::ostringstream os;
+ log(os, std::forward<U>(args)...);
+ sd_journal_send("MESSAGE=%s", os.str().data(),
+ "PRIORITY=%i", syslog_level,
+ "CODE_FILE=%s", src_file,
+ "CODE_LINE=%i", line,
+ nullptr);
+ }
+ }
+ else
+ {
+ #endif
+ (void)syslog_level;
+ static const char* priority_names[] = {"DEBUG", "INFO", "WARNING", "ERROR"};
+ auto& os = Logger::instance()->get_stream(level);
+ os << '[' << priority_names[level] << "]: " << src_file << ':' << line << ":\t";
+ log(os, std::forward<U>(args)...);
+#ifdef SYSTEMD_FOUND
+ }
+#endif
}
}
-#define log_info(...) logging_details::log_info(WHERE, __VA_ARGS__)
-
-#define log_warning(...) logging_details::log_warning(WHERE, __VA_ARGS__)
-
-#define log_error(...) logging_details::log_error(WHERE, __VA_ARGS__)
-
-#define log_debug(...) logging_details::log_debug(WHERE, __VA_ARGS__)
+#define log_debug(...) logging_details::do_logging(debug_lvl, LOG_DEBUG, __FILENAME__, __LINE__, __VA_ARGS__)
+#define log_info(...) logging_details::do_logging(info_lvl, LOG_INFO, __FILENAME__, __LINE__, __VA_ARGS__)
+#define log_warning(...) logging_details::do_logging(warning_lvl, LOG_WARNING, __FILENAME__, __LINE__, __VA_ARGS__)
+#define log_error(...) logging_details::do_logging(error_lvl, LOG_ERR, __FILENAME__, __LINE__, __VA_ARGS__)
diff --git a/src/main.cpp b/src/main.cpp
index 5725584..c877e43 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -88,7 +88,8 @@ int main(int ac, char** av)
#ifdef USE_DATABASE
try {
open_database();
- } catch (...) {
+ } catch (const std::exception& e) {
+ log_error(e.what());
return 1;
}
#endif
diff --git a/src/network/credentials_manager.cpp b/src/network/credentials_manager.cpp
index 7f07cef..b25f442 100644
--- a/src/network/credentials_manager.cpp
+++ b/src/network/credentials_manager.cpp
@@ -5,6 +5,7 @@
#include <network/credentials_manager.hpp>
#include <logger/logger.hpp>
#include <botan/tls_exceptn.h>
+#include <botan/data_src.h>
#include <config/config.hpp>
/**
diff --git a/src/network/credentials_manager.hpp b/src/network/credentials_manager.hpp
index aa4732a..3a37bdc 100644
--- a/src/network/credentials_manager.hpp
+++ b/src/network/credentials_manager.hpp
@@ -4,7 +4,8 @@
#ifdef BOTAN_FOUND
-#include <botan/botan.h>
+#include <botan/credentials_manager.h>
+#include <botan/certstor.h>
#include <botan/tls_client.h>
class TCPSocketHandler;
diff --git a/src/network/tcp_socket_handler.cpp b/src/network/tcp_socket_handler.cpp
index 6239162..642cf03 100644
--- a/src/network/tcp_socket_handler.cpp
+++ b/src/network/tcp_socket_handler.cpp
@@ -12,7 +12,9 @@
#include <cstring>
#ifdef BOTAN_FOUND
+# include <botan/version.h>
# include <botan/hex.h>
+# include <botan/auto_rng.h>
# include <botan/tls_exceptn.h>
# include <config/config.hpp>
# include <utils/dirname.hpp>
@@ -27,6 +29,10 @@ namespace
Botan::TLS::Session_Manager_In_Memory& get_session_manager()
{
static Botan::TLS::Session_Manager_In_Memory session_manager{get_rng()};
+#if BOTAN_VERSION_CODE < BOTAN_VERSION_CODE_FOR(2,4,0)
+ // workaround for https://github.com/randombit/botan/issues/1276
+ session_manager.remove_all();
+#endif
return session_manager;
}
}
diff --git a/src/network/tcp_socket_handler.hpp b/src/network/tcp_socket_handler.hpp
index 5cef739..c598641 100644
--- a/src/network/tcp_socket_handler.hpp
+++ b/src/network/tcp_socket_handler.hpp
@@ -21,7 +21,6 @@
#ifdef BOTAN_FOUND
# include <botan/types.h>
-# include <botan/botan.h>
# include <botan/tls_session_manager.h>
# include <network/tls_policy.hpp>
diff --git a/src/network/tls_policy.cpp b/src/network/tls_policy.cpp
index 5439397..b88eb88 100644
--- a/src/network/tls_policy.cpp
+++ b/src/network/tls_policy.cpp
@@ -8,6 +8,8 @@
#include <network/tls_policy.hpp>
#include <logger/logger.hpp>
+#include <botan/parsing.h>
+#include <botan/exceptn.h>
bool BiboumiTLSPolicy::load(const std::string& filename)
{
diff --git a/src/utils/is_one_of.hpp b/src/utils/is_one_of.hpp
new file mode 100644
index 0000000..4d6770e
--- /dev/null
+++ b/src/utils/is_one_of.hpp
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <type_traits>
+
+template <typename...>
+struct is_one_of_implem {
+ static constexpr bool value = false;
+};
+
+template <typename F, typename S, typename... T>
+struct is_one_of_implem<F, S, T...> {
+ static constexpr bool value =
+ std::is_same<F, S>::value || is_one_of_implem<F, T...>::value;
+};
+
+template<typename... T>
+constexpr bool is_one_of = is_one_of_implem<T...>::value;
diff --git a/src/utils/optional_bool.cpp b/src/utils/optional_bool.cpp
new file mode 100644
index 0000000..56fdca2
--- /dev/null
+++ b/src/utils/optional_bool.cpp
@@ -0,0 +1,8 @@
+#include <utils/optional_bool.hpp>
+
+
+std::ostream& operator<<(std::ostream& os, const OptionalBool& o)
+{
+ os << o.to_string();
+ return os;
+}
diff --git a/src/utils/optional_bool.hpp b/src/utils/optional_bool.hpp
index 59bbbab..867aca2 100644
--- a/src/utils/optional_bool.hpp
+++ b/src/utils/optional_bool.hpp
@@ -20,7 +20,7 @@ struct OptionalBool
this->is_set = false;
}
- std::string to_string()
+ std::string to_string() const
{
if (this->is_set == false)
return "unset";
@@ -33,3 +33,5 @@ struct OptionalBool
bool is_set{false};
bool value{false};
};
+
+std::ostream& operator<<(std::ostream& os, const OptionalBool& o);
diff --git a/src/utils/scopetimer.hpp b/src/utils/scopetimer.hpp
new file mode 100644
index 0000000..7d3db9b
--- /dev/null
+++ b/src/utils/scopetimer.hpp
@@ -0,0 +1,17 @@
+#include <utils/scopeguard.hpp>
+
+#include <chrono>
+
+#include <logger/logger.hpp>
+
+template <typename Callback>
+auto make_scope_timer(Callback cb)
+{
+ const auto start_time = std::chrono::steady_clock::now();
+ return utils::make_scope_guard([start_time, cb = std::move(cb)]()
+ {
+ const auto now = std::chrono::steady_clock::now();
+ const auto elapsed = now - start_time;
+ cb(elapsed);
+ });
+}
diff --git a/src/utils/time.cpp b/src/utils/time.cpp
index bc2c18d..71306fd 100644
--- a/src/utils/time.cpp
+++ b/src/utils/time.cpp
@@ -9,9 +9,10 @@
namespace utils
{
-std::string to_string(const std::time_t& timestamp)
+std::string to_string(const std::chrono::system_clock::time_point::rep& time)
{
constexpr std::size_t stamp_size = 21;
+ const std::time_t timestamp = static_cast<std::time_t>(time);
char date_buf[stamp_size];
if (std::strftime(date_buf, stamp_size, "%FT%TZ", std::gmtime(&timestamp)) != stamp_size - 1)
return "";
diff --git a/src/utils/time.hpp b/src/utils/time.hpp
index c71cd9c..4b19634 100644
--- a/src/utils/time.hpp
+++ b/src/utils/time.hpp
@@ -2,9 +2,10 @@
#include <ctime>
#include <string>
+#include <chrono>
namespace utils
{
-std::string to_string(const std::time_t& timestamp);
+std::string to_string(const std::chrono::system_clock::time_point::rep& timestamp);
std::time_t parse_datetime(const std::string& stamp);
-} \ No newline at end of file
+}
diff --git a/src/xmpp/adhoc_commands_handler.cpp b/src/xmpp/adhoc_commands_handler.cpp
index e4dcd5c..bb48781 100644
--- a/src/xmpp/adhoc_commands_handler.cpp
+++ b/src/xmpp/adhoc_commands_handler.cpp
@@ -83,7 +83,7 @@ XmlNode AdhocCommandsHandler::handle_request(const std::string& executor_jid, co
XmlSubNode next(actions, "next");
}
}
- else if (action == "cancel")
+ else if (session_it != this->sessions.end() && action == "cancel")
{
this->sessions.erase(session_it);
command_node["status"] = "canceled";
diff --git a/src/xmpp/biboumi_adhoc_commands.cpp b/src/xmpp/biboumi_adhoc_commands.cpp
index 60af506..bcdac39 100644
--- a/src/xmpp/biboumi_adhoc_commands.cpp
+++ b/src/xmpp/biboumi_adhoc_commands.cpp
@@ -159,7 +159,7 @@ void ConfigureGlobalStep1(XmppComponent&, AdhocSession& session, XmlNode& comman
{
XmlSubNode value(persistent, "value");
value.set_name("value");
- if (options.col<Database::Persistent>())
+ if (options.col<Database::GlobalPersistent>())
value.set_inner("true");
else
value.set_inner("false");
@@ -193,7 +193,7 @@ void ConfigureGlobalStep2(XmppComponent& xmpp_component, AdhocSession& session,
}
else if (field->get_tag("var") == "persistent" &&
value)
- options.col<Database::Persistent>() = to_bool(value->get_inner());
+ options.col<Database::GlobalPersistent>() = to_bool(value->get_inner());
}
options.save(Database::db);
@@ -409,7 +409,7 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com
else if (field->get_tag("var") == "pass" && value)
options.col<Database::Pass>() = value->get_inner();
- else if (field->get_tag("var") == "after_connect_command")
+ else if (field->get_tag("var") == "after_connect_command" && value)
options.col<Database::AfterConnectionCommand>() = value->get_inner();
else if (field->get_tag("var") == "username" && value)
@@ -430,7 +430,7 @@ void ConfigureIrcServerStep2(XmppComponent&, AdhocSession& session, XmlNode& com
options.col<Database::EncodingIn>() = value->get_inner();
}
-
+ Database::invalidate_encoding_in_cache();
options.save(Database::db);
command_node.delete_all_children();
@@ -599,7 +599,7 @@ bool handle_irc_channel_configuration_form(XmppComponent& xmpp_component, const
}
}
-
+ Database::invalidate_encoding_in_cache(requester.bare(), iid.get_server(), iid.get_local());
options.save(Database::db);
}
return true;
diff --git a/src/xmpp/biboumi_component.cpp b/src/xmpp/biboumi_component.cpp
index 0e1d270..481ebb9 100644
--- a/src/xmpp/biboumi_component.cpp
+++ b/src/xmpp/biboumi_component.cpp
@@ -7,6 +7,7 @@
#include <xmpp/adhoc_command.hpp>
#include <xmpp/biboumi_adhoc_commands.hpp>
#include <bridge/list_element.hpp>
+#include <utils/encoding.hpp>
#include <config/config.hpp>
#include <utils/time.hpp>
#include <xmpp/jid.hpp>
@@ -24,6 +25,7 @@
#include <database/database.hpp>
#include <bridge/result_set_management.hpp>
+#include <bridge/history_limit.hpp>
using namespace std::string_literals;
@@ -155,8 +157,32 @@ void BiboumiComponent::handle_presence(const Stanza& stanza)
bridge->send_irc_nick_change(iid, to.resource, from.resource);
const XmlNode* x = stanza.get_child("x", MUC_NS);
const XmlNode* password = x ? x->get_child("password", MUC_NS): nullptr;
+ const XmlNode* history = x ? x->get_child("history", MUC_NS): nullptr;
+ HistoryLimit history_limit;
+ if (history)
+ {
+ const auto seconds = history->get_tag("seconds");
+ if (!seconds.empty())
+ {
+ const auto now = std::chrono::system_clock::now();
+ std::time_t timestamp = std::chrono::system_clock::to_time_t(now);
+ int int_seconds = std::atoi(seconds.data());
+ timestamp -= int_seconds;
+ history_limit.since = utils::to_string(timestamp);
+ }
+ const auto since = history->get_tag("since");
+ if (!since.empty())
+ history_limit.since = since;
+ const auto maxstanzas = history->get_tag("maxstanzas");
+ if (!maxstanzas.empty())
+ history_limit.stanzas = std::atoi(maxstanzas.data());
+ // Ignore any other value, because this is too complex to implement,
+ // so I won’t do it.
+ if (history->get_tag("maxchars") == "0")
+ history_limit.stanzas = 0;
+ }
bridge->join_irc_channel(iid, to.resource, password ? password->get_inner(): "",
- from.resource);
+ from.resource, history_limit);
}
else if (type == "unavailable")
{
@@ -281,6 +307,7 @@ void BiboumiComponent::handle_message(const Stanza& stanza)
{
if (body && !body->get_inner().empty())
{
+ const auto fixed_irc_server = Config::get("fixed_irc_server", "");
// a message for nick!server
if (iid.type == Iid::Type::User && !iid.get_local().empty())
{
@@ -296,9 +323,11 @@ void BiboumiComponent::handle_message(const Stanza& stanza)
bridge->set_preferred_from_jid(user_iid.get_local(), to_str);
}
else if (iid.type == Iid::Type::Server)
+ bridge->send_raw_message(iid.get_server(), body->get_inner());
+ else if (iid.type == Iid::Type::None && !fixed_irc_server.empty())
{ // Message sent to the server JID
// Convert the message body into a raw IRC message
- bridge->send_raw_message(iid.get_server(), body->get_inner());
+ bridge->send_raw_message(fixed_irc_server, body->get_inner());
}
}
}
@@ -408,7 +437,7 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
// Depending on the 'to' jid in the request, we use one adhoc
// command handler or an other
- Iid iid(to.local, {});
+ Iid iid(to.local, {'#', '&'});
AdhocCommandsHandler* adhoc_handler;
if (to.local.empty())
adhoc_handler = &this->adhoc_commands_handler;
@@ -416,8 +445,13 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
{
if (iid.type == Iid::Type::Server)
adhoc_handler = &this->irc_server_adhoc_commands_handler;
- else
+ else if (iid.type == Iid::Type::Channel && to.resource.empty())
adhoc_handler = &this->irc_channel_adhoc_commands_handler;
+ else
+ {
+ error_name = "feature-not-implemented";
+ return;
+ }
}
// Execute the command, if any, and get a result XmlNode that we
// insert in our response
@@ -467,7 +501,7 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
stanza_error.disable();
}
}
- else if (iid.type == Iid::Type::Channel)
+ else if (iid.type == Iid::Type::Channel && to.resource.empty())
{
if (node.empty())
{
@@ -526,7 +560,7 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
this->irc_server_adhoc_commands_handler);
stanza_error.disable();
}
- else if (iid.type == Iid::Type::Channel)
+ else if (iid.type == Iid::Type::Channel && to.resource.empty())
{ // Get the channel's adhoc commands
this->send_adhoc_commands_list(id, from, to_str,
(Config::get("admin", "") ==
@@ -534,6 +568,8 @@ void BiboumiComponent::handle_iq(const Stanza& stanza)
this->irc_channel_adhoc_commands_handler);
stanza_error.disable();
}
+ else // “to” is a MUC user, not the room itself
+ error_name = "feature-not-implemented";
}
else if (node.empty() && iid.type == Iid::Type::Server)
{ // Disco on an IRC server: get the list of channels
@@ -685,19 +721,46 @@ bool BiboumiComponent::handle_mam_request(const Stanza& stanza)
if (max)
limit = std::atoi(max->get_inner().data());
}
- // If the archive is really big, and the client didn’t specify any
- // limit, we avoid flooding it: we set an arbitrary max limit.
- if (limit == -1 && start.empty() && end.empty())
+ // Do send more than 100 messages, even if the client asked for more,
+ // or if it didn’t specify any limit.
+ // 101 is just a trick to know if there are more available messages.
+ // If our query returns 101 message, we know it’s incomplete, but we
+ // still send only 100
+ if ((limit == -1 && start.empty() && end.empty())
+ || limit > 100)
+ limit = 101;
+ auto lines = Database::get_muc_logs(from.bare(), iid.get_local(), iid.get_server(), limit, start, end);
+ bool complete = true;
+ if (lines.size() > 100)
{
- limit = 100;
+ complete = false;
+ lines.erase(lines.begin(), std::prev(lines.end(), 100));
}
- const auto lines = Database::get_muc_logs(from.bare(), iid.get_local(), iid.get_server(), limit, start, end);
for (const Database::MucLogLine& line: lines)
{
if (!line.col<Database::Nick>().empty())
this->send_archived_message(line, to.full(), from.full(), query_id);
}
- this->send_iq_result_full_jid(id, from.full(), to.full());
+ {
+ auto fin_ptr = std::make_unique<XmlNode>("fin");
+ {
+ XmlNode& fin = *(fin_ptr.get());
+ fin["xmlns"] = MAM_NS;
+ if (complete)
+ fin["complete"] = "true";
+ XmlSubNode set(fin, "set");
+ set["xmlns"] = RSM_NS;
+ if (!lines.empty())
+ {
+ XmlSubNode first(set, "first");
+ first["index"] = "0";
+ first.set_inner(lines[0].col<Database::Uuid>());
+ XmlSubNode last(set, "last");
+ last.set_inner(lines[lines.size() - 1].col<Database::Uuid>());
+ }
+ }
+ this->send_iq_result_full_jid(id, from.full(), to.full(), std::move(fin_ptr));
+ }
return true;
}
return false;
@@ -739,7 +802,7 @@ bool BiboumiComponent::handle_room_configuration_form_request(const std::string&
{
Iid iid(to.local, {'#', '&'});
- if (iid.type != Iid::Type::Channel)
+ if (iid.type != Iid::Type::Channel || !to.resource.empty())
return false;
Stanza iq("iq");
@@ -761,7 +824,7 @@ bool BiboumiComponent::handle_room_configuration_form(const XmlNode& query, cons
{
Iid iid(to.local, {'#', '&'});
- if (iid.type != Iid::Type::Channel)
+ if (iid.type != Iid::Type::Channel || !to.resource.empty())
return false;
Jid requester(from);
@@ -958,7 +1021,9 @@ void BiboumiComponent::send_iq_room_list_result(const std::string& id, const std
for (auto it = begin; it != end; ++it)
{
XmlSubNode item(query, "item");
- item["jid"] = it->channel + "@" + this->served_hostname;
+ std::string channel_name = it->channel;
+ xep0106::encode(channel_name);
+ item["jid"] = channel_name + "@" + this->served_hostname;
}
if ((rs_info.max >= 0 || !rs_info.after.empty() || !rs_info.before.empty()))
@@ -1052,6 +1117,9 @@ void BiboumiComponent::on_irc_client_connected(const std::string& irc_hostname,
const auto local_jid = irc_hostname + "@" + this->served_hostname;
if (Database::has_roster_item(local_jid, jid))
this->send_presence_to_contact(local_jid, jid, "");
+#else
+ (void)irc_hostname;
+ (void)jid;
#endif
}
@@ -1061,6 +1129,9 @@ void BiboumiComponent::on_irc_client_disconnected(const std::string& irc_hostnam
const auto local_jid = irc_hostname + "@" + this->served_hostname;
if (Database::has_roster_item(local_jid, jid))
this->send_presence_to_contact(irc_hostname + "@" + this->served_hostname, jid, "unavailable");
+#else
+ (void)irc_hostname;
+ (void)jid;
#endif
}
diff --git a/src/xmpp/body.hpp b/src/xmpp/body.hpp
index 068d1a4..f693cdd 100644
--- a/src/xmpp/body.hpp
+++ b/src/xmpp/body.hpp
@@ -1,5 +1,9 @@
#pragma once
+#include <tuple>
+#include <memory>
+
+class XmlNode;
namespace Xmpp
{
diff --git a/src/xmpp/xmpp_component.cpp b/src/xmpp/xmpp_component.cpp
index 42a5392..9be9e34 100644
--- a/src/xmpp/xmpp_component.cpp
+++ b/src/xmpp/xmpp_component.cpp
@@ -269,7 +269,8 @@ void* XmppComponent::get_receive_buffer(const size_t size) const
}
void XmppComponent::send_message(const std::string& from, Xmpp::body&& body, const std::string& to,
- const std::string& type, const bool fulljid, const bool nocopy)
+ const std::string& type, const bool fulljid, const bool nocopy,
+ const bool muc_private)
{
Stanza message("message");
{
@@ -277,7 +278,12 @@ void XmppComponent::send_message(const std::string& from, Xmpp::body&& body, con
if (fulljid)
message["from"] = from;
else
- message["from"] = from + "@" + this->served_hostname;
+ {
+ if (!from.empty())
+ message["from"] = from + "@" + this->served_hostname;
+ else
+ message["from"] = this->served_hostname;
+ }
if (!type.empty())
message["type"] = type;
XmlSubNode body_node(message, "body");
@@ -296,6 +302,11 @@ void XmppComponent::send_message(const std::string& from, Xmpp::body&& body, con
XmlSubNode nocopy(message, "no-copy");
nocopy["xmlns"] = "urn:xmpp:hints";
}
+ if (muc_private)
+ {
+ XmlSubNode x(message, "x");
+ x["xmlns"] = MUC_USER_NS;
+ }
}
this->send_stanza(message);
}
@@ -387,7 +398,8 @@ void XmppComponent::send_muc_message(const std::string& muc_name, const std::str
this->send_stanza(message);
}
-void XmppComponent::send_history_message(const std::string& muc_name, const std::string& nick, const std::string& body_txt, const std::string& jid_to, std::time_t timestamp)
+#ifdef USE_DATABASE
+void XmppComponent::send_history_message(const std::string& muc_name, const std::string& nick, const std::string& body_txt, const std::string& jid_to, Database::time_point::rep timestamp)
{
Stanza message("message");
message["to"] = jid_to;
@@ -410,6 +422,7 @@ void XmppComponent::send_history_message(const std::string& muc_name, const std:
this->send_stanza(message);
}
+#endif
void XmppComponent::send_muc_leave(const std::string& muc_name, const std::string& nick, Xmpp::body&& message,
const std::string& jid_to, const bool self, const bool user_requested)
@@ -629,13 +642,15 @@ void XmppComponent::send_iq_version_request(const std::string& from,
this->send_stanza(iq);
}
-void XmppComponent::send_iq_result_full_jid(const std::string& id, const std::string& to_jid, const std::string& from_full_jid)
+void XmppComponent::send_iq_result_full_jid(const std::string& id, const std::string& to_jid, const std::string& from_full_jid, std::unique_ptr<XmlNode> inner)
{
Stanza iq("iq");
iq["from"] = from_full_jid;
iq["to"] = to_jid;
iq["id"] = id;
iq["type"] = "result";
+ if (inner)
+ iq.add_child(std::move(inner));
this->send_stanza(iq);
}
diff --git a/src/xmpp/xmpp_component.hpp b/src/xmpp/xmpp_component.hpp
index 22d5c48..1daa6fb 100644
--- a/src/xmpp/xmpp_component.hpp
+++ b/src/xmpp/xmpp_component.hpp
@@ -1,8 +1,10 @@
#pragma once
+#include "biboumi.h"
#include <xmpp/adhoc_commands_handler.hpp>
#include <network/tcp_client_socket_handler.hpp>
+#include <database/database.hpp>
#include <xmpp/xmpp_parser.hpp>
#include <xmpp/body.hpp>
@@ -112,7 +114,8 @@ public:
* server-part of the JID and must be added.
*/
void send_message(const std::string& from, Xmpp::body&& body, const std::string& to,
- const std::string& type, const bool fulljid, const bool nocopy=false);
+ const std::string& type, const bool fulljid, const bool nocopy=false,
+ const bool muc_private=false);
/**
* Send a join from a new participant
*/
@@ -132,11 +135,13 @@ public:
*/
void send_muc_message(const std::string& muc_name, const std::string& nick, Xmpp::body&& body, const std::string& jid_to,
std::string uuid);
+#ifdef USE_DATABASE
/**
* Send a message, with a <delay/> element, part of a MUC history
*/
void send_history_message(const std::string& muc_name, const std::string& nick, const std::string& body,
- const std::string& jid_to, const std::time_t timestamp);
+ const std::string& jid_to, Database::time_point::rep timestamp);
+#endif
/**
* Send an unavailable presence for this nick
*/
@@ -202,7 +207,7 @@ public:
*/
void send_iq_result(const std::string& id, const std::string& to_jid, const std::string& from);
void send_iq_result_full_jid(const std::string& id, const std::string& to_jid,
- const std::string& from_full_jid);
+ const std::string& from_full_jid, std::unique_ptr<XmlNode> inner=nullptr);
void handle_handshake(const Stanza& stanza);
void handle_error(const Stanza& stanza);