summaryrefslogtreecommitdiff
path: root/src/database
diff options
context:
space:
mode:
Diffstat (limited to 'src/database')
-rw-r--r--src/database/column.hpp9
-rw-r--r--src/database/count_query.hpp17
-rw-r--r--src/database/database.cpp81
-rw-r--r--src/database/database.hpp73
-rw-r--r--src/database/engine.hpp41
-rw-r--r--src/database/index.hpp38
-rw-r--r--src/database/insert_query.hpp108
-rw-r--r--src/database/postgresql_engine.cpp91
-rw-r--r--src/database/postgresql_engine.hpp48
-rw-r--r--src/database/postgresql_statement.hpp123
-rw-r--r--src/database/query.cpp26
-rw-r--r--src/database/query.hpp62
-rw-r--r--src/database/row.hpp108
-rw-r--r--src/database/select_query.hpp31
-rw-r--r--src/database/sqlite3_engine.cpp101
-rw-r--r--src/database/sqlite3_engine.hpp47
-rw-r--r--src/database/sqlite3_statement.hpp92
-rw-r--r--src/database/statement.hpp46
-rw-r--r--src/database/table.cpp23
-rw-r--r--src/database/table.hpp80
-rw-r--r--src/database/type_to_sql.cpp9
-rw-r--r--src/database/type_to_sql.hpp16
-rw-r--r--src/database/update_query.hpp104
23 files changed, 1046 insertions, 328 deletions
diff --git a/src/database/column.hpp b/src/database/column.hpp
index 111f9ca..1f16bcf 100644
--- a/src/database/column.hpp
+++ b/src/database/column.hpp
@@ -13,5 +13,10 @@ struct Column
T value{};
};
-struct Id: Column<std::size_t> { static constexpr auto name = "id_";
- static constexpr auto options = "PRIMARY KEY AUTOINCREMENT"; };
+struct Id: Column<std::size_t> {
+ static constexpr std::size_t unset_value = static_cast<std::size_t>(-1);
+ static constexpr auto name = "id_";
+ static constexpr auto options = "PRIMARY KEY";
+
+ Id(): Column<std::size_t>(-1) {}
+};
diff --git a/src/database/count_query.hpp b/src/database/count_query.hpp
index 0dde63c..118ce44 100644
--- a/src/database/count_query.hpp
+++ b/src/database/count_query.hpp
@@ -2,11 +2,10 @@
#include <database/query.hpp>
#include <database/table.hpp>
+#include <database/statement.hpp>
#include <string>
-#include <sqlite3.h>
-
struct CountQuery: public Query
{
CountQuery(std::string name):
@@ -15,20 +14,20 @@ struct CountQuery: public Query
this->body += std::move(name);
}
- int64_t execute(sqlite3* db)
+ int64_t execute(DatabaseEngine& db)
{
- auto statement = this->prepare(db);
+#ifdef DEBUG_SQL_QUERIES
+ const auto timer = this->log_and_time();
+#endif
+ auto statement = db.prepare(this->body);
int64_t res = 0;
- if (sqlite3_step(statement.get()) == SQLITE_ROW)
- res = sqlite3_column_int64(statement.get(), 0);
+ if (statement->step() != StepResult::Error)
+ res = statement->get_column_int64(0);
else
{
log_error("Count request didn’t return a result");
return 0;
}
- if (sqlite3_step(statement.get()) != SQLITE_DONE)
- log_warning("Count request returned more than one result.");
-
return res;
}
};
diff --git a/src/database/database.cpp b/src/database/database.cpp
index 85c675e..3622963 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -6,40 +6,54 @@
#include <utils/get_first_non_empty.hpp>
#include <utils/time.hpp>
-#include <sqlite3.h>
+#include <config/config.hpp>
+#include <database/sqlite3_engine.hpp>
+#include <database/postgresql_engine.hpp>
-sqlite3* Database::db;
-Database::MucLogLineTable Database::muc_log_lines("MucLogLine_");
-Database::GlobalOptionsTable Database::global_options("GlobalOptions_");
-Database::IrcServerOptionsTable Database::irc_server_options("IrcServerOptions_");
-Database::IrcChannelOptionsTable Database::irc_channel_options("IrcChannelOptions_");
+#include <database/engine.hpp>
+#include <database/index.hpp>
+
+#include <memory>
+
+std::unique_ptr<DatabaseEngine> Database::db;
+Database::MucLogLineTable Database::muc_log_lines("muclogline_");
+Database::GlobalOptionsTable Database::global_options("globaloptions_");
+Database::IrcServerOptionsTable Database::irc_server_options("ircserveroptions_");
+Database::IrcChannelOptionsTable Database::irc_channel_options("ircchanneloptions_");
Database::RosterTable Database::roster("roster");
+std::map<Database::CacheKey, Database::EncodingIn::real_type> Database::encoding_in_cache{};
+Database::GlobalPersistent::GlobalPersistent():
+ Column<bool>{Config::get_bool("persistent_by_default", false)}
+{}
void Database::open(const std::string& filename)
{
// Try to open the specified database.
// Close and replace the previous database pointer if it succeeded. If it did
// not, just leave things untouched
- sqlite3* new_db;
- auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr);
- if (res != SQLITE_OK)
- {
- log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(Database::db));
- throw std::runtime_error("");
- }
- Database::close();
- Database::db = new_db;
- Database::muc_log_lines.create(Database::db);
- Database::muc_log_lines.upgrade(Database::db);
- Database::global_options.create(Database::db);
- Database::global_options.upgrade(Database::db);
- Database::irc_server_options.create(Database::db);
- Database::irc_server_options.upgrade(Database::db);
- Database::irc_channel_options.create(Database::db);
- Database::irc_channel_options.upgrade(Database::db);
- Database::roster.create(Database::db);
- Database::roster.upgrade(Database::db);
+ std::unique_ptr<DatabaseEngine> new_db;
+ static const auto psql_prefix = "postgresql://"s;
+ static const auto psql_prefix2 = "postgres://"s;
+ if ((filename.substr(0, psql_prefix.size()) == psql_prefix) ||
+ (filename.substr(0, psql_prefix2.size()) == psql_prefix2))
+ new_db = PostgresqlEngine::open(filename);
+ else
+ new_db = Sqlite3Engine::open(filename);
+ if (!new_db)
+ return;
+ Database::db = std::move(new_db);
+ Database::muc_log_lines.create(*Database::db);
+ Database::muc_log_lines.upgrade(*Database::db);
+ Database::global_options.create(*Database::db);
+ Database::global_options.upgrade(*Database::db);
+ Database::irc_server_options.create(*Database::db);
+ Database::irc_server_options.upgrade(*Database::db);
+ Database::irc_channel_options.create(*Database::db);
+ Database::irc_channel_options.upgrade(*Database::db);
+ Database::roster.create(*Database::db);
+ Database::roster.upgrade(*Database::db);
+ create_index<Database::Owner, Database::IrcChanName, Database::IrcServerName>(*Database::db, "archive_index", Database::muc_log_lines.get_name());
}
@@ -49,7 +63,7 @@ Database::GlobalOptions Database::get_global_options(const std::string& owner)
request.where() << Owner{} << "=" << owner;
Database::GlobalOptions options{Database::global_options.get_name()};
- auto result = request.execute(Database::db);
+ auto result = request.execute(*Database::db);
if (result.size() == 1)
options = result.front();
else
@@ -63,7 +77,7 @@ Database::IrcServerOptions Database::get_irc_server_options(const std::string& o
request.where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server;
Database::IrcServerOptions options{Database::irc_server_options.get_name()};
- auto result = request.execute(Database::db);
+ auto result = request.execute(*Database::db);
if (result.size() == 1)
options = result.front();
else
@@ -81,7 +95,7 @@ Database::IrcChannelOptions Database::get_irc_channel_options(const std::string&
" and " << Server{} << "=" << server <<\
" and " << Channel{} << "=" << channel;
Database::IrcChannelOptions options{Database::irc_channel_options.get_name()};
- auto result = request.execute(Database::db);
+ auto result = request.execute(*Database::db);
if (result.size() == 1)
options = result.front();
else
@@ -176,7 +190,7 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne
if (limit >= 0)
request.limit() << limit;
- auto result = request.execute(Database::db);
+ auto result = request.execute(*Database::db);
return {result.crbegin(), result.crend()};
}
@@ -197,7 +211,7 @@ void Database::delete_roster_item(const std::string& local, const std::string& r
query << " WHERE " << Database::RemoteJid{} << "=" << remote << \
" AND " << Database::LocalJid{} << "=" << local;
- query.execute(Database::db);
+// query.execute(*Database::db);
}
bool Database::has_roster_item(const std::string& local, const std::string& remote)
@@ -206,7 +220,7 @@ bool Database::has_roster_item(const std::string& local, const std::string& remo
query.where() << Database::LocalJid{} << "=" << local << \
" and " << Database::RemoteJid{} << "=" << remote;
- auto res = query.execute(Database::db);
+ auto res = query.execute(*Database::db);
return !res.empty();
}
@@ -216,19 +230,18 @@ std::vector<Database::RosterItem> Database::get_contact_list(const std::string&
auto query = Database::roster.select();
query.where() << Database::LocalJid{} << "=" << local;
- return query.execute(Database::db);
+ return query.execute(*Database::db);
}
std::vector<Database::RosterItem> Database::get_full_roster()
{
auto query = Database::roster.select();
- return query.execute(Database::db);
+ return query.execute(*Database::db);
}
void Database::close()
{
- sqlite3_close_v2(Database::db);
Database::db = nullptr;
}
diff --git a/src/database/database.hpp b/src/database/database.hpp
index c00c938..ec44543 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -7,12 +7,15 @@
#include <database/column.hpp>
#include <database/count_query.hpp>
+#include <database/engine.hpp>
+
#include <utils/optional_bool.hpp>
#include <chrono>
#include <string>
#include <memory>
+#include <map>
class Database
@@ -24,11 +27,11 @@ class Database
struct Owner: Column<std::string> { static constexpr auto name = "owner_"; };
- struct IrcChanName: Column<std::string> { static constexpr auto name = "ircChanName_"; };
+ struct IrcChanName: Column<std::string> { static constexpr auto name = "ircchanname_"; };
struct Channel: Column<std::string> { static constexpr auto name = "channel_"; };
- struct IrcServerName: Column<std::string> { static constexpr auto name = "ircServerName_"; };
+ struct IrcServerName: Column<std::string> { static constexpr auto name = "ircservername_"; };
struct Server: Column<std::string> { static constexpr auto name = "server_"; };
@@ -43,35 +46,38 @@ class Database
struct Ports: Column<std::string> { static constexpr auto name = "ports_";
Ports(): Column<std::string>("6667") {} };
- struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsPorts_";
+ struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsports_";
TlsPorts(): Column<std::string>("6697;6670") {} };
struct Username: Column<std::string> { static constexpr auto name = "username_"; };
struct Realname: Column<std::string> { static constexpr auto name = "realname_"; };
- struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterConnectionCommand_"; };
+ struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterconnectioncommand_"; };
- struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedFingerprint_"; };
+ struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedfingerprint_"; };
- struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingOut_"; };
+ struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingout_"; };
- struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingIn_"; };
+ struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingin_"; };
- struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxHistoryLength_";
+ struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxhistorylength_";
MaxHistoryLength(): Column<int>(20) {} };
- struct RecordHistory: Column<bool> { static constexpr auto name = "recordHistory_";
+ struct RecordHistory: Column<bool> { static constexpr auto name = "recordhistory_";
RecordHistory(): Column<bool>(true) {}};
- struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordHistory_"; };
+ struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordhistory_"; };
- struct VerifyCert: Column<bool> { static constexpr auto name = "verifyCert_";
+ struct VerifyCert: Column<bool> { static constexpr auto name = "verifycert_";
VerifyCert(): Column<bool>(true) {} };
struct Persistent: Column<bool> { static constexpr auto name = "persistent_";
Persistent(): Column<bool>(false) {} };
+ struct GlobalPersistent: Column<bool> { static constexpr auto name = "persistent_";
+ GlobalPersistent(); };
+
struct LocalJid: Column<std::string> { static constexpr auto name = "local"; };
struct RemoteJid: Column<std::string> { static constexpr auto name = "remote"; };
@@ -80,7 +86,7 @@ class Database
using MucLogLineTable = Table<Id, Uuid, Owner, IrcChanName, IrcServerName, Date, Body, Nick>;
using MucLogLine = MucLogLineTable::RowType;
- using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, Persistent>;
+ using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, GlobalPersistent>;
using GlobalOptions = GlobalOptionsTable::RowType;
using IrcServerOptionsTable = Table<Id, Owner, Server, Pass, AfterConnectionCommand, TlsPorts, Ports, Username, Realname, VerifyCert, TrustedFingerprint, EncodingOut, EncodingIn, MaxHistoryLength>;
@@ -130,7 +136,7 @@ class Database
static int64_t count(const TableType& table)
{
CountQuery query{table.get_name()};
- return query.execute(Database::db);
+ return query.execute(*Database::db);
}
static MucLogLineTable muc_log_lines;
@@ -138,9 +144,48 @@ class Database
static IrcServerOptionsTable irc_server_options;
static IrcChannelOptionsTable irc_channel_options;
static RosterTable roster;
- static sqlite3* db;
+ static std::unique_ptr<DatabaseEngine> db;
+
+ /**
+ * Some caches, to avoid doing very frequent query requests for a few options.
+ */
+ using CacheKey = std::tuple<std::string, std::string, std::string>;
+
+ static EncodingIn::real_type get_encoding_in(const std::string& owner,
+ const std::string& server,
+ const std::string& channel)
+ {
+ CacheKey channel_key{owner, server, channel};
+ auto it = Database::encoding_in_cache.find(channel_key);
+ if (it == Database::encoding_in_cache.end())
+ {
+ auto options = Database::get_irc_channel_options_with_server_default(owner, server, channel);
+ EncodingIn::real_type result = options.col<Database::EncodingIn>();
+ if (result.empty())
+ result = "ISO-8859-1";
+ it = Database::encoding_in_cache.insert(std::make_pair(channel_key, result)).first;
+ }
+ return it->second;
+ }
+ static void invalidate_encoding_in_cache(const std::string& owner,
+ const std::string& server,
+ const std::string& channel)
+ {
+ CacheKey channel_key{owner, server, channel};
+ Database::encoding_in_cache.erase(channel_key);
+ }
+ static void invalidate_encoding_in_cache()
+ {
+ Database::encoding_in_cache.clear();
+ }
+
+ static auto raw_exec(const std::string& query)
+ {
+ Database::db->raw_exec(query);
+ }
private:
static std::string gen_uuid();
+ static std::map<CacheKey, EncodingIn::real_type> encoding_in_cache;
};
#endif /* USE_DATABASE */
diff --git a/src/database/engine.hpp b/src/database/engine.hpp
new file mode 100644
index 0000000..41dccf5
--- /dev/null
+++ b/src/database/engine.hpp
@@ -0,0 +1,41 @@
+#pragma once
+
+/**
+ * Interface to provide non-portable behaviour, specific to each
+ * database engine we want to support.
+ *
+ * Everything else (all portable stuf) should go outside of this class.
+ */
+
+#include <database/statement.hpp>
+
+#include <memory>
+#include <string>
+#include <vector>
+#include <tuple>
+#include <set>
+
+class DatabaseEngine
+{
+ public:
+
+ DatabaseEngine() = default;
+ virtual ~DatabaseEngine() = default;
+
+ DatabaseEngine(const DatabaseEngine&) = delete;
+ DatabaseEngine& operator=(const DatabaseEngine&) = delete;
+ DatabaseEngine(DatabaseEngine&&) = delete;
+ DatabaseEngine& operator=(DatabaseEngine&&) = delete;
+
+ virtual std::set<std::string> get_all_columns_from_table(const std::string& table_name) = 0;
+ virtual std::tuple<bool, std::string> raw_exec(const std::string& query) = 0;
+ virtual std::unique_ptr<Statement> prepare(const std::string& query) = 0;
+ virtual void extract_last_insert_rowid(Statement& statement) = 0;
+ virtual std::string get_returning_id_sql_string(const std::string&)
+ {
+ return {};
+ }
+ virtual std::string id_column_type() = 0;
+
+ int64_t last_inserted_rowid{-1};
+};
diff --git a/src/database/index.hpp b/src/database/index.hpp
new file mode 100644
index 0000000..30766ab
--- /dev/null
+++ b/src/database/index.hpp
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <database/engine.hpp>
+
+#include <string>
+#include <tuple>
+
+namespace
+{
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N == sizeof...(T), void>::type
+add_column_name(std::string&)
+{ }
+
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N < sizeof...(T), void>::type
+add_column_name(std::string& out)
+{
+ using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<std::tuple<T...>>()))>::type;
+ out += ColumnType::name;
+ if (N != sizeof...(T) - 1)
+ out += ",";
+ add_column_name<N+1, T...>(out);
+}
+}
+
+template <typename... Columns>
+void create_index(DatabaseEngine& db, const std::string& name, const std::string& table)
+{
+ std::string query{"CREATE INDEX IF NOT EXISTS "};
+ query += name + " ON " + table + "(";
+ add_column_name<0, Columns...>(query);
+ query += ")";
+
+ auto result = db.raw_exec(query);
+ if (std::get<0>(result) == false)
+ log_error("Error executing query: ", std::get<1>(result));
+}
diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp
index 2ece69d..9726424 100644
--- a/src/database/insert_query.hpp
+++ b/src/database/insert_query.hpp
@@ -10,64 +10,64 @@
#include <string>
#include <tuple>
-#include <sqlite3.h>
-
-template <int N, typename ColumnType, typename... T>
-typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
-actual_bind(Statement& statement, std::vector<std::string>& params, const std::tuple<T...>&)
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N < sizeof...(T), void>::type
+update_autoincrement_id(std::tuple<T...>& columns, Statement& statement)
{
- const auto value = params.front();
- params.erase(params.begin());
- if (sqlite3_bind_text(statement.get(), N + 1, value.data(), static_cast<int>(value.size()), SQLITE_TRANSIENT) != SQLITE_OK)
- log_error("Failed to bind ", value, " to param ", N);
+ using ColumnType = typename std::decay<decltype(std::get<N>(columns))>::type;
+ if (std::is_same<ColumnType, Id>::value)
+ auto&& column = std::get<Id>(columns);
+ update_autoincrement_id<N+1>(columns, statement);
}
-template <int N, typename ColumnType, typename... T>
-typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
-actual_bind(Statement& statement, std::vector<std::string>&, const std::tuple<T...>& columns)
-{
- auto&& column = std::get<Id>(columns);
- if (column.value != 0)
- {
- if (sqlite3_bind_int64(statement.get(), N + 1, static_cast<sqlite3_int64>(column.value)) != SQLITE_OK)
- log_error("Failed to bind ", column.value, " to id.");
- }
- else if (sqlite3_bind_null(statement.get(), N + 1) != SQLITE_OK)
- log_error("Failed to bind NULL to param ", N);
-}
+template <std::size_t N=0, typename... T>
+typename std::enable_if<N == sizeof...(T), void>::type
+update_autoincrement_id(std::tuple<T...>&, Statement& statement)
+{}
struct InsertQuery: public Query
{
- InsertQuery(const std::string& name):
- Query("INSERT OR REPLACE INTO ")
+ template <typename... T>
+ InsertQuery(const std::string& name, const std::tuple<T...>& columns):
+ Query("INSERT INTO ")
{
this->body += name;
+ this->insert_col_names(columns);
+ this->insert_values(columns);
}
template <typename... T>
- void execute(const std::tuple<T...>& columns, sqlite3* db)
+ void execute(DatabaseEngine& db, std::tuple<T...>& columns)
{
- auto statement = this->prepare(db);
- {
- this->bind_param(columns, statement);
- if (sqlite3_step(statement.get()) != SQLITE_DONE)
- log_error("Failed to execute query: ", sqlite3_errmsg(db));
- }
+#ifdef DEBUG_SQL_QUERIES
+ const auto timer = this->log_and_time();
+#endif
+
+ auto statement = db.prepare(this->body);
+ this->bind_param(columns, *statement);
+
+ if (statement->step() != StepResult::Error)
+ db.extract_last_insert_rowid(*statement);
+ else
+ log_error("Failed to extract the rowid from the last INSERT");
}
template <int N=0, typename... T>
typename std::enable_if<N < sizeof...(T), void>::type
- bind_param(const std::tuple<T...>& columns, Statement& statement)
+ bind_param(const std::tuple<T...>& columns, Statement& statement, int index=1)
{
- using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
+ auto&& column = std::get<N>(columns);
+ using ColumnType = std::decay_t<decltype(column)>;
- actual_bind<N, ColumnType>(statement, this->params, columns);
- this->bind_param<N+1>(columns, statement);
+ if (!std::is_same<ColumnType, Id>::value)
+ actual_bind(statement, column.value, index++);
+
+ this->bind_param<N+1>(columns, statement, index);
}
template <int N=0, typename... T>
typename std::enable_if<N == sizeof...(T), void>::type
- bind_param(const std::tuple<T...>&, Statement&)
+ bind_param(const std::tuple<T...>&, Statement&, int)
{}
template <typename... T>
@@ -80,18 +80,21 @@ struct InsertQuery: public Query
template <int N=0, typename... T>
typename std::enable_if<N < sizeof...(T), void>::type
- insert_value(const std::tuple<T...>& columns)
+ insert_value(const std::tuple<T...>& columns, int index=1)
{
- this->body += "?";
- if (N != sizeof...(T) - 1)
- this->body += ",";
- this->body += " ";
- add_param(*this, std::get<N>(columns));
- this->insert_value<N+1>(columns);
+ using ColumnType = std::decay_t<decltype(std::get<N>(columns))>;
+
+ if (!std::is_same<ColumnType, Id>::value)
+ {
+ this->body += "$" + std::to_string(index++);
+ if (N != sizeof...(T) - 1)
+ this->body += ", ";
+ }
+ this->insert_value<N+1>(columns, index);
}
template <int N=0, typename... T>
typename std::enable_if<N == sizeof...(T), void>::type
- insert_value(const std::tuple<T...>&)
+ insert_value(const std::tuple<T...>&, const int)
{ }
template <typename... T>
@@ -99,27 +102,28 @@ struct InsertQuery: public Query
{
this->body += " (";
this->insert_col_name(columns);
- this->body += ")\n";
+ this->body += ")";
}
template <int N=0, typename... T>
typename std::enable_if<N < sizeof...(T), void>::type
insert_col_name(const std::tuple<T...>& columns)
{
- using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
+ using ColumnType = std::decay_t<decltype(std::get<N>(columns))>;
- this->body += ColumnType::name;
+ if (!std::is_same<ColumnType, Id>::value)
+ {
+ this->body += ColumnType::name;
- if (N < (sizeof...(T) - 1))
- this->body += ", ";
+ if (N < (sizeof...(T) - 1))
+ this->body += ", ";
+ }
this->insert_col_name<N+1>(columns);
}
+
template <int N=0, typename... T>
typename std::enable_if<N == sizeof...(T), void>::type
insert_col_name(const std::tuple<T...>&)
{}
-
-
- private:
};
diff --git a/src/database/postgresql_engine.cpp b/src/database/postgresql_engine.cpp
new file mode 100644
index 0000000..984a959
--- /dev/null
+++ b/src/database/postgresql_engine.cpp
@@ -0,0 +1,91 @@
+#include <biboumi.h>
+#ifdef PQ_FOUND
+
+#include <utils/scopeguard.hpp>
+
+#include <database/query.hpp>
+
+#include <database/postgresql_engine.hpp>
+
+#include <database/postgresql_statement.hpp>
+
+#include <logger/logger.hpp>
+
+PostgresqlEngine::PostgresqlEngine(PGconn*const conn):
+ conn(conn)
+{}
+
+PostgresqlEngine::~PostgresqlEngine()
+{
+ PQfinish(this->conn);
+}
+
+std::unique_ptr<DatabaseEngine> PostgresqlEngine::open(const std::string& conninfo)
+{
+ PGconn* con = PQconnectdb(conninfo.data());
+
+ if (!con)
+ {
+ log_error("Failed to allocate a Postgresql connection");
+ throw std::runtime_error("");
+ }
+ const auto status = PQstatus(con);
+ if (status != CONNECTION_OK)
+ {
+ const char* errmsg = PQerrorMessage(con);
+ log_error("Postgresql connection failed: ", errmsg);
+ throw std::runtime_error("failed to open connection.");
+ }
+ return std::make_unique<PostgresqlEngine>(con);
+}
+
+std::set<std::string> PostgresqlEngine::get_all_columns_from_table(const std::string& table_name)
+{
+ const auto query = "SELECT column_name from information_schema.columns where table_name='" + table_name + "'";
+ auto statement = this->prepare(query);
+ std::set<std::string> columns;
+
+ while (statement->step() == StepResult::Row)
+ columns.insert(statement->get_column_text(0));
+
+ return columns;
+}
+
+std::tuple<bool, std::string> PostgresqlEngine::raw_exec(const std::string& query)
+{
+#ifdef DEBUG_SQL_QUERIES
+ log_debug("SQL QUERY: ", query);
+ const auto timer = make_sql_timer();
+#endif
+ PGresult* res = PQexec(this->conn, query.data());
+ auto sg = utils::make_scope_guard([res](){
+ PQclear(res);
+ });
+
+ auto res_status = PQresultStatus(res);
+ if (res_status != PGRES_COMMAND_OK)
+ return std::make_tuple(false, PQresultErrorMessage(res));
+ return std::make_tuple(true, std::string{});
+}
+
+std::unique_ptr<Statement> PostgresqlEngine::prepare(const std::string& query)
+{
+ return std::make_unique<PostgresqlStatement>(query, this->conn);
+}
+
+void PostgresqlEngine::extract_last_insert_rowid(Statement& statement)
+{
+ this->last_inserted_rowid = statement.get_column_int64(0);
+}
+
+std::string PostgresqlEngine::get_returning_id_sql_string(const std::string& col_name)
+{
+ return " RETURNING " + col_name;
+}
+
+std::string PostgresqlEngine::id_column_type()
+{
+ return "SERIAL";
+}
+
+#endif
diff --git a/src/database/postgresql_engine.hpp b/src/database/postgresql_engine.hpp
new file mode 100644
index 0000000..fe4fb53
--- /dev/null
+++ b/src/database/postgresql_engine.hpp
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <biboumi.h>
+#include <string>
+#include <stdexcept>
+#include <memory>
+
+#include <database/statement.hpp>
+#include <database/engine.hpp>
+
+#include <tuple>
+#include <set>
+
+#ifdef PQ_FOUND
+
+#include <libpq-fe.h>
+
+class PostgresqlEngine: public DatabaseEngine
+{
+ public:
+ PostgresqlEngine(PGconn*const conn);
+
+ ~PostgresqlEngine();
+
+ static std::unique_ptr<DatabaseEngine> open(const std::string& string);
+
+ std::set<std::string> get_all_columns_from_table(const std::string& table_name) override final;
+ std::tuple<bool, std::string> raw_exec(const std::string& query) override final;
+ std::unique_ptr<Statement> prepare(const std::string& query) override;
+ void extract_last_insert_rowid(Statement& statement) override;
+ std::string get_returning_id_sql_string(const std::string& col_name) override;
+ std::string id_column_type() override;
+private:
+ PGconn* const conn;
+};
+
+#else
+
+class PostgresqlEngine
+{
+public:
+ static std::unique_ptr<DatabaseEngine> open(const std::string& string)
+ {
+ throw std::runtime_error("Cannot open postgresql database "s + string + ": biboumi is not compiled with libpq.");
+ }
+};
+
+#endif
diff --git a/src/database/postgresql_statement.hpp b/src/database/postgresql_statement.hpp
new file mode 100644
index 0000000..571c8f1
--- /dev/null
+++ b/src/database/postgresql_statement.hpp
@@ -0,0 +1,123 @@
+#pragma once
+
+#include <database/statement.hpp>
+
+#include <logger/logger.hpp>
+
+#include <libpq-fe.h>
+
+class PostgresqlStatement: public Statement
+{
+ public:
+ PostgresqlStatement(std::string body, PGconn*const conn):
+ body(std::move(body)),
+ conn(conn)
+ {}
+ ~PostgresqlStatement()
+ {
+ PQclear(this->result);
+ this->result = nullptr;
+ }
+ PostgresqlStatement(const PostgresqlStatement&) = delete;
+ PostgresqlStatement& operator=(const PostgresqlStatement&) = delete;
+ PostgresqlStatement(PostgresqlStatement&& other) = delete;
+ PostgresqlStatement& operator=(PostgresqlStatement&& other) = delete;
+
+ StepResult step() override final
+ {
+ if (!this->executed)
+ {
+ this->current_tuple = 0;
+ this->executed = true;
+ if (!this->execute())
+ return StepResult::Error;
+ }
+ else
+ {
+ this->current_tuple++;
+ }
+ if (this->current_tuple < PQntuples(this->result))
+ return StepResult::Row;
+ return StepResult::Done;
+ }
+
+ int64_t get_column_int64(const int col) override
+ {
+ const char* result = PQgetvalue(this->result, this->current_tuple, col);
+ std::istringstream iss;
+ iss.str(result);
+ int64_t res;
+ iss >> res;
+ return res;
+ }
+ std::string get_column_text(const int col) override
+ {
+ const char* result = PQgetvalue(this->result, this->current_tuple, col);
+ return result;
+ }
+ int get_column_int(const int col) override
+ {
+ const char* result = PQgetvalue(this->result, this->current_tuple, col);
+ std::istringstream iss;
+ iss.str(result);
+ int res;
+ iss >> res;
+ return res;
+ }
+
+ void bind(std::vector<std::string> params) override
+ {
+
+ this->params = std::move(params);
+ }
+
+ bool bind_text(const int, const std::string& data) override
+ {
+ this->params.push_back(data);
+ return true;
+ }
+ bool bind_int64(const int, const std::int64_t value) override
+ {
+ this->params.push_back(std::to_string(value));
+ return true;
+ }
+ bool bind_null(const int) override
+ {
+ this->params.push_back("NULL");
+ return true;
+ }
+
+ private:
+
+private:
+ bool execute()
+ {
+ std::vector<const char*> params;
+ params.reserve(this->params.size());
+
+ for (const auto& param: this->params)
+ params.push_back(param.data());
+ const int param_size = static_cast<int>(this->params.size());
+ this->result = PQexecParams(this->conn, this->body.data(),
+ param_size,
+ nullptr,
+ params.data(),
+ nullptr,
+ nullptr,
+ 0);
+ const auto status = PQresultStatus(this->result);
+ if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK)
+ {
+ log_error("Failed to execute command: ", PQresultErrorMessage(this->result));
+ return false;
+ }
+ return true;
+ }
+
+ bool executed{false};
+ std::string body;
+ PGconn*const conn;
+ std::vector<std::string> params;
+ PGresult* result{nullptr};
+ int current_tuple{0};
+};
diff --git a/src/database/query.cpp b/src/database/query.cpp
index ba63a92..d27dc59 100644
--- a/src/database/query.cpp
+++ b/src/database/query.cpp
@@ -1,9 +1,26 @@
#include <database/query.hpp>
#include <database/column.hpp>
-template <>
-void add_param<Id>(Query&, const Id&)
-{}
+void actual_bind(Statement& statement, const std::string& value, int index)
+{
+ statement.bind_text(index, value);
+}
+
+void actual_bind(Statement& statement, const std::int64_t value, int index)
+{
+ statement.bind_int64(index, value);
+}
+
+void actual_bind(Statement& statement, const OptionalBool& value, int index)
+{
+ if (!value.is_set)
+ statement.bind_int64(index, 0);
+ else if (value.value)
+ statement.bind_int64(index, 1);
+ else
+ statement.bind_int64(index, -1);
+}
+
void actual_add_param(Query& query, const std::string& val)
{
@@ -28,7 +45,8 @@ Query& operator<<(Query& query, const char* str)
Query& operator<<(Query& query, const std::string& str)
{
- query.body += "?";
+ query.body += "$" + std::to_string(query.current_param);
+ query.current_param++;
actual_add_param(query, str);
return query;
}
diff --git a/src/database/query.hpp b/src/database/query.hpp
index 6e1db12..8434944 100644
--- a/src/database/query.hpp
+++ b/src/database/query.hpp
@@ -1,5 +1,7 @@
#pragma once
+#include <biboumi.h>
+
#include <utils/optional_bool.hpp>
#include <database/statement.hpp>
#include <database/column.hpp>
@@ -9,54 +11,53 @@
#include <vector>
#include <string>
-#include <sqlite3.h>
+void actual_bind(Statement& statement, const std::string& value, int index);
+void actual_bind(Statement& statement, const std::int64_t value, int index);
+void actual_bind(Statement& statement, const OptionalBool& value, int index);
+
+#ifdef DEBUG_SQL_QUERIES
+#include <utils/scopetimer.hpp>
+
+inline auto make_sql_timer()
+{
+ return make_scope_timer([](const std::chrono::steady_clock::duration& elapsed)
+ {
+ const auto seconds = std::chrono::duration_cast<std::chrono::seconds>(elapsed);
+ const auto rest = elapsed - seconds;
+ log_debug("Query executed in ", seconds.count(), ".", rest.count(), "s.");
+ });
+}
+#endif
struct Query
{
std::string body;
std::vector<std::string> params;
+ int current_param{1};
Query(std::string str):
body(std::move(str))
{}
- Statement prepare(sqlite3* db)
+#ifdef DEBUG_SQL_QUERIES
+ auto log_and_time()
{
- sqlite3_stmt* stmt;
- auto res = sqlite3_prepare(db, this->body.data(), static_cast<int>(this->body.size()) + 1,
- &stmt, nullptr);
- if (res != SQLITE_OK)
- {
- log_error("Error preparing statement: ", sqlite3_errmsg(db));
- return nullptr;
- }
- Statement statement(stmt);
- int i = 1;
- for (const std::string& param: this->params)
- {
- if (sqlite3_bind_text(statement.get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK)
- log_error("Failed to bind ", param, " to param ", i);
- i++;
- }
-
- return statement;
- }
-
- void execute(sqlite3* db)
- {
- auto statement = this->prepare(db);
- while (sqlite3_step(statement.get()) != SQLITE_DONE)
- ;
+ std::ostringstream os;
+ os << this->body << "; ";
+ for (const auto& param: this->params)
+ os << "'" << param << "' ";
+ log_debug("SQL QUERY: ", os.str());
+ return make_sql_timer();
}
+#endif
};
template <typename ColumnType>
void add_param(Query& query, const ColumnType& column)
{
+ std::cout << "add_param<ColumnType>" << std::endl;
actual_add_param(query, column.value);
}
-template <>
-void add_param<Id>(Query& query, const Id& column);
template <typename T>
void actual_add_param(Query& query, const T& val)
@@ -81,7 +82,8 @@ template <typename Integer>
typename std::enable_if<std::is_integral<Integer>::value, Query&>::type
operator<<(Query& query, const Integer& i)
{
- query.body += "?";
+ query.body += "$" + std::to_string(query.current_param++);
actual_add_param(query, i);
return query;
}
+
diff --git a/src/database/row.hpp b/src/database/row.hpp
index 2b50874..4dc98be 100644
--- a/src/database/row.hpp
+++ b/src/database/row.hpp
@@ -1,72 +1,72 @@
#pragma once
#include <database/insert_query.hpp>
+#include <database/update_query.hpp>
#include <logger/logger.hpp>
-#include <type_traits>
-
-#include <sqlite3.h>
+#include <utils/is_one_of.hpp>
-template <typename ColumnType, typename... T>
-typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
-update_id(std::tuple<T...>&, sqlite3*)
-{}
+#include <type_traits>
-template <typename ColumnType, typename... T>
-typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type
-update_id(std::tuple<T...>& columns, sqlite3* db)
+template <typename... T>
+struct Row
{
- auto&& column = std::get<ColumnType>(columns);
- auto res = sqlite3_last_insert_rowid(db);
- column.value = static_cast<Id::real_type>(res);
-}
+ Row(std::string name):
+ table_name(std::move(name))
+ {}
-template <std::size_t N=0, typename... T>
-typename std::enable_if<N < sizeof...(T), void>::type
-update_autoincrement_id(std::tuple<T...>& columns, sqlite3* db)
-{
- using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type;
- update_id<ColumnType>(columns, db);
- update_autoincrement_id<N+1>(columns, db);
-}
+ template <typename Type>
+ typename Type::real_type& col()
+ {
+ auto&& col = std::get<Type>(this->columns);
+ return col.value;
+ }
-template <std::size_t N=0, typename... T>
-typename std::enable_if<N == sizeof...(T), void>::type
-update_autoincrement_id(std::tuple<T...>&, sqlite3*)
-{}
+ template <typename Type>
+ const auto& col() const
+ {
+ auto&& col = std::get<Type>(this->columns);
+ return col.value;
+ }
-template <typename... T>
-struct Row
-{
- Row(std::string name):
- table_name(std::move(name))
- {}
+ template <bool Coucou=true>
+ void save(std::unique_ptr<DatabaseEngine>& db, typename std::enable_if<!is_one_of<Id, T...> && Coucou>::type* = nullptr)
+ {
+ this->insert(*db);
+ }
- template <typename Type>
- typename Type::real_type& col()
- {
- auto&& col = std::get<Type>(this->columns);
- return col.value;
- }
+ template <bool Coucou=true>
+ void save(std::unique_ptr<DatabaseEngine>& db, typename std::enable_if<is_one_of<Id, T...> && Coucou>::type* = nullptr)
+ {
+ const Id& id = std::get<Id>(this->columns);
+ if (id.value == Id::unset_value)
+ {
+ this->insert(*db);
+ if (db->last_inserted_rowid >= 0)
+ std::get<Id>(this->columns).value = static_cast<Id::real_type>(db->last_inserted_rowid);
+ }
+ else
+ this->update(*db);
+ }
- template <typename Type>
- const auto& col() const
- {
- auto&& col = std::get<Type>(this->columns);
- return col.value;
- }
+ private:
+ void insert(DatabaseEngine& db)
+ {
+ InsertQuery query(this->table_name, this->columns);
+ // Ugly workaround for non portable stuff
+ query.body += db.get_returning_id_sql_string(Id::name);
+ query.execute(db, this->columns);
+ }
- void save(sqlite3* db)
- {
- InsertQuery query(this->table_name);
- query.insert_col_names(this->columns);
- query.insert_values(this->columns);
+ void update(DatabaseEngine& db)
+ {
+ UpdateQuery query(this->table_name, this->columns);
- query.execute(this->columns, db);
+ query.execute(db, this->columns);
+ }
- update_autoincrement_id(this->columns, db);
- }
+public:
+ std::tuple<T...> columns;
+ std::string table_name;
- std::tuple<T...> columns;
- std::string table_name;
};
diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp
index 872001c..5a17f38 100644
--- a/src/database/select_query.hpp
+++ b/src/database/select_query.hpp
@@ -1,5 +1,7 @@
#pragma once
+#include <database/engine.hpp>
+
#include <database/statement.hpp>
#include <database/query.hpp>
#include <logger/logger.hpp>
@@ -10,32 +12,27 @@
#include <vector>
#include <string>
-#include <sqlite3.h>
-
using namespace std::string_literals;
template <typename T>
-typename std::enable_if<std::is_integral<T>::value, sqlite3_int64>::type
+typename std::enable_if<std::is_integral<T>::value, std::int64_t>::type
extract_row_value(Statement& statement, const int i)
{
- return sqlite3_column_int64(statement.get(), i);
+ return statement.get_column_int64(i);
}
template <typename T>
typename std::enable_if<std::is_same<std::string, T>::value, T>::type
extract_row_value(Statement& statement, const int i)
{
- const auto size = sqlite3_column_bytes(statement.get(), i);
- const unsigned char* str = sqlite3_column_text(statement.get(), i);
- std::string result(reinterpret_cast<const char*>(str), static_cast<std::size_t>(size));
- return result;
+ return statement.get_column_text(i);
}
template <typename T>
typename std::enable_if<std::is_same<OptionalBool, T>::value, T>::type
extract_row_value(Statement& statement, const int i)
{
- const auto integer = sqlite3_column_int(statement.get(), i);
+ const auto integer = statement.get_column_int(i);
OptionalBool result;
if (integer > 0)
result.set_value(true);
@@ -109,16 +106,24 @@ struct SelectQuery: public Query
return *this;
}
- auto execute(sqlite3* db)
+ auto execute(DatabaseEngine& db)
{
- auto statement = this->prepare(db);
std::vector<Row<T...>> rows;
- while (sqlite3_step(statement.get()) == SQLITE_ROW)
+
+#ifdef DEBUG_SQL_QUERIES
+ const auto timer = this->log_and_time();
+#endif
+
+ auto statement = db.prepare(this->body);
+ statement->bind(std::move(this->params));
+
+ while (statement->step() == StepResult::Row)
{
Row<T...> row(this->table_name);
- extract_row_values(row, statement);
+ extract_row_values(row, *statement);
rows.push_back(row);
}
+
return rows;
}
diff --git a/src/database/sqlite3_engine.cpp b/src/database/sqlite3_engine.cpp
new file mode 100644
index 0000000..ae4a146
--- /dev/null
+++ b/src/database/sqlite3_engine.cpp
@@ -0,0 +1,101 @@
+#include <biboumi.h>
+
+#ifdef SQLITE3_FOUND
+
+#include <database/sqlite3_engine.hpp>
+
+#include <database/sqlite3_statement.hpp>
+
+#include <database/query.hpp>
+
+#include <utils/tolower.hpp>
+#include <logger/logger.hpp>
+#include <vector>
+
+Sqlite3Engine::Sqlite3Engine(sqlite3* db):
+ db(db)
+{
+}
+
+Sqlite3Engine::~Sqlite3Engine()
+{
+ sqlite3_close(this->db);
+}
+
+std::set<std::string> Sqlite3Engine::get_all_columns_from_table(const std::string& table_name)
+{
+ std::set<std::string> result;
+ char* errmsg;
+ std::string query{"PRAGMA table_info(" + table_name + ")"};
+ int res = sqlite3_exec(this->db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int {
+ constexpr int name_column = 1;
+ std::set<std::string>* result = static_cast<std::set<std::string>*>(param);
+ if (name_column < columns_nb)
+ result->insert(utils::tolower(columns[name_column]));
+ return 0;
+ }, &result, &errmsg);
+
+ if (res != SQLITE_OK)
+ {
+ log_error("Error executing ", query, ": ", errmsg);
+ sqlite3_free(errmsg);
+ }
+
+ return result;
+}
+
+std::unique_ptr<DatabaseEngine> Sqlite3Engine::open(const std::string& filename)
+{
+ sqlite3* new_db;
+ auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr);
+ if (res != SQLITE_OK)
+ {
+ log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(new_db));
+ sqlite3_close(new_db);
+ throw std::runtime_error("");
+ }
+ return std::make_unique<Sqlite3Engine>(new_db);
+}
+
+std::tuple<bool, std::string> Sqlite3Engine::raw_exec(const std::string& query)
+{
+#ifdef DEBUG_SQL_QUERIES
+ log_debug("SQL QUERY: ", query);
+ const auto timer = make_sql_timer();
+#endif
+
+ char* error;
+ const auto result = sqlite3_exec(db, query.data(), nullptr, nullptr, &error);
+ if (result != SQLITE_OK)
+ {
+ std::string err_msg(error);
+ sqlite3_free(error);
+ return std::make_tuple(false, err_msg);
+ }
+ return std::make_tuple(true, std::string{});
+}
+
+std::unique_ptr<Statement> Sqlite3Engine::prepare(const std::string& query)
+{
+ sqlite3_stmt* stmt;
+ auto res = sqlite3_prepare(db, query.data(), static_cast<int>(query.size()) + 1,
+ &stmt, nullptr);
+ if (res != SQLITE_OK)
+ {
+ log_error("Error preparing statement: ", sqlite3_errmsg(db));
+ return nullptr;
+ }
+ return std::make_unique<Sqlite3Statement>(stmt);
+}
+
+void Sqlite3Engine::extract_last_insert_rowid(Statement&)
+{
+ this->last_inserted_rowid = sqlite3_last_insert_rowid(this->db);
+}
+
+std::string Sqlite3Engine::id_column_type()
+{
+ return "INTEGER PRIMARY KEY AUTOINCREMENT";
+}
+
+#endif
diff --git a/src/database/sqlite3_engine.hpp b/src/database/sqlite3_engine.hpp
new file mode 100644
index 0000000..5b8176c
--- /dev/null
+++ b/src/database/sqlite3_engine.hpp
@@ -0,0 +1,47 @@
+#pragma once
+
+#include <database/engine.hpp>
+
+#include <database/statement.hpp>
+
+#include <memory>
+#include <string>
+#include <tuple>
+#include <set>
+
+#include <biboumi.h>
+
+#ifdef SQLITE3_FOUND
+
+#include <sqlite3.h>
+
+class Sqlite3Engine: public DatabaseEngine
+{
+ public:
+ Sqlite3Engine(sqlite3* db);
+
+ ~Sqlite3Engine();
+
+ static std::unique_ptr<DatabaseEngine> open(const std::string& string);
+
+ std::set<std::string> get_all_columns_from_table(const std::string& table_name) override final;
+ std::tuple<bool, std::string> raw_exec(const std::string& query) override final;
+ std::unique_ptr<Statement> prepare(const std::string& query) override;
+ void extract_last_insert_rowid(Statement& statement) override;
+ std::string id_column_type() override;
+private:
+ sqlite3* const db;
+};
+
+#else
+
+class Sqlite3Engine
+{
+public:
+ static std::unique_ptr<DatabaseEngine> open(const std::string& string)
+ {
+ throw std::runtime_error("Cannot open sqlite3 database "s + string + ": biboumi is not compiled with sqlite3 lib.");
+ }
+};
+
+#endif
diff --git a/src/database/sqlite3_statement.hpp b/src/database/sqlite3_statement.hpp
new file mode 100644
index 0000000..7738fa6
--- /dev/null
+++ b/src/database/sqlite3_statement.hpp
@@ -0,0 +1,92 @@
+#pragma once
+
+#include <database/statement.hpp>
+
+#include <logger/logger.hpp>
+
+#include <sqlite3.h>
+
+class Sqlite3Statement: public Statement
+{
+ public:
+ Sqlite3Statement(sqlite3_stmt* stmt):
+ stmt(stmt) {}
+ ~Sqlite3Statement()
+ {
+ sqlite3_finalize(this->stmt);
+ }
+
+ StepResult step() override final
+ {
+ auto res = sqlite3_step(this->get());
+ if (res == SQLITE_ROW)
+ return StepResult::Row;
+ else if (res == SQLITE_DONE)
+ return StepResult::Done;
+ else
+ return StepResult::Error;
+ }
+
+ void bind(std::vector<std::string> params) override
+ {
+ int i = 1;
+ for (const std::string& param: params)
+ {
+ if (sqlite3_bind_text(this->get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK)
+ log_error("Failed to bind ", param, " to param ", i);
+ i++;
+ }
+ }
+
+ int64_t get_column_int64(const int col) override
+ {
+ return sqlite3_column_int64(this->get(), col);
+ }
+
+ std::string get_column_text(const int col) override
+ {
+ const auto size = sqlite3_column_bytes(this->get(), col);
+ const unsigned char* str = sqlite3_column_text(this->get(), col);
+ std::string result(reinterpret_cast<const char*>(str), static_cast<std::size_t>(size));
+ return result;
+ }
+
+ bool bind_text(const int pos, const std::string& data) override
+ {
+ return sqlite3_bind_text(this->get(), pos, data.data(), static_cast<int>(data.size()), SQLITE_TRANSIENT) == SQLITE_OK;
+ }
+ bool bind_int64(const int pos, const std::int64_t value) override
+ {
+ return sqlite3_bind_int64(this->get(), pos, static_cast<sqlite3_int64>(value)) == SQLITE_OK;
+ }
+ bool bind_null(const int pos) override
+ {
+ return sqlite3_bind_null(this->get(), pos) == SQLITE_OK;
+ }
+ int get_column_int(const int col) override
+ {
+ return sqlite3_column_int(this->get(), col);
+ }
+
+ Sqlite3Statement(const Sqlite3Statement&) = delete;
+ Sqlite3Statement& operator=(const Sqlite3Statement&) = delete;
+ Sqlite3Statement(Sqlite3Statement&& other):
+ stmt(other.stmt)
+ {
+ other.stmt = nullptr;
+ }
+ Sqlite3Statement& operator=(Sqlite3Statement&& other)
+ {
+ this->stmt = other.stmt;
+ other.stmt = nullptr;
+ return *this;
+ }
+ sqlite3_stmt* get()
+ {
+ return this->stmt;
+ }
+
+ private:
+ sqlite3_stmt* stmt;
+ int last_step_result{SQLITE_OK};
+};
diff --git a/src/database/statement.hpp b/src/database/statement.hpp
index 87cd70f..4a61928 100644
--- a/src/database/statement.hpp
+++ b/src/database/statement.hpp
@@ -1,35 +1,29 @@
#pragma once
-#include <sqlite3.h>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+enum class StepResult
+{
+ Row,
+ Done,
+ Error,
+};
class Statement
{
public:
- Statement(sqlite3_stmt* stmt):
- stmt(stmt) {}
- ~Statement()
- {
- sqlite3_finalize(this->stmt);
- }
+ virtual ~Statement() = default;
+ virtual StepResult step() = 0;
+
+ virtual void bind(std::vector<std::string> params) = 0;
- Statement(const Statement&) = delete;
- Statement& operator=(const Statement&) = delete;
- Statement(Statement&& other):
- stmt(other.stmt)
- {
- other.stmt = nullptr;
- }
- Statement& operator=(Statement&& other)
- {
- this->stmt = other.stmt;
- other.stmt = nullptr;
- return *this;
- }
- sqlite3_stmt* get()
- {
- return this->stmt;
- }
+ virtual std::int64_t get_column_int64(const int col) = 0;
+ virtual std::string get_column_text(const int col) = 0;
+ virtual int get_column_int(const int col) = 0;
- private:
- sqlite3_stmt* stmt;
+ virtual bool bind_text(const int pos, const std::string& data) = 0;
+ virtual bool bind_int64(const int pos, const std::int64_t value) = 0;
+ virtual bool bind_null(const int pos) = 0;
};
diff --git a/src/database/table.cpp b/src/database/table.cpp
deleted file mode 100644
index 9224d79..0000000
--- a/src/database/table.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-#include <database/table.hpp>
-
-std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name)
-{
- std::set<std::string> result;
- char* errmsg;
- std::string query{"PRAGMA table_info(" + table_name + ")"};
- int res = sqlite3_exec(db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int {
- constexpr int name_column = 1;
- std::set<std::string>* result = static_cast<std::set<std::string>*>(param);
- if (name_column < columns_nb)
- result->insert(columns[name_column]);
- return 0;
- }, &result, &errmsg);
-
- if (res != SQLITE_OK)
- {
- log_error("Error executing ", query, ": ", errmsg);
- sqlite3_free(errmsg);
- }
-
- return result;
-}
diff --git a/src/database/table.hpp b/src/database/table.hpp
index 0060211..680e7cc 100644
--- a/src/database/table.hpp
+++ b/src/database/table.hpp
@@ -1,7 +1,8 @@
#pragma once
+#include <database/engine.hpp>
+
#include <database/select_query.hpp>
-#include <database/type_to_sql.hpp>
#include <database/row.hpp>
#include <algorithm>
@@ -10,23 +11,27 @@
using namespace std::string_literals;
-std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name);
+template <typename T>
+std::string ToSQLType(DatabaseEngine& db)
+{
+ if (std::is_same<T, Id>::value)
+ return db.id_column_type();
+ else if (std::is_same<typename T::real_type, std::string>::value)
+ return "TEXT";
+ else
+ return "INTEGER";
+}
template <typename ColumnType>
-void add_column_to_table(sqlite3* db, const std::string& table_name)
+void add_column_to_table(DatabaseEngine& db, const std::string& table_name)
{
const std::string name = ColumnType::name;
- std::string query{"ALTER TABLE " + table_name + " ADD " + ColumnType::name + " " + TypeToSQLType<typename ColumnType::real_type>::type};
- char* error;
- const auto result = sqlite3_exec(db, query.data(), nullptr, nullptr, &error);
- if (result != SQLITE_OK)
- {
- log_error("Error adding column ", name, " to table ", table_name, ": ", error);
- sqlite3_free(error);
- }
+ std::string query{"ALTER TABLE " + table_name + " ADD " + ColumnType::name + " " + ToSQLType<ColumnType>(db)};
+ auto res = db.raw_exec(query);
+ if (std::get<0>(res) == false)
+ log_error("Error adding column ", name, " to table ", table_name, ": ", std::get<1>(res));
}
-
template <typename ColumnType, decltype(ColumnType::options) = nullptr>
void append_option(std::string& s)
{
@@ -50,27 +55,23 @@ class Table
name(std::move(name))
{}
- void upgrade(sqlite3* db)
+ void upgrade(DatabaseEngine& db)
{
- const auto existing_columns = get_all_columns_from_table(db, this->name);
+ const auto existing_columns = db.get_all_columns_from_table(this->name);
add_column_if_not_exists(db, existing_columns);
}
- void create(sqlite3* db)
+ void create(DatabaseEngine& db)
{
- std::string res{"CREATE TABLE IF NOT EXISTS "};
- res += this->name;
- res += " (\n";
- this->add_column_create(res);
- res += ")";
-
- char* error;
- const auto result = sqlite3_exec(db, res.data(), nullptr, nullptr, &error);
- if (result != SQLITE_OK)
- {
- log_error("Error executing query: ", error);
- sqlite3_free(error);
- }
+ std::string query{"CREATE TABLE IF NOT EXISTS "};
+ query += this->name;
+ query += " (";
+ this->add_column_create(db, query);
+ query += ")";
+
+ auto result = db.raw_exec(query);
+ if (std::get<0>(result) == false)
+ log_error("Error executing query: ", std::get<1>(result));
}
RowType row()
@@ -78,7 +79,7 @@ class Table
return {this->name};
}
- SelectQuery<T...> select()
+ auto select()
{
SelectQuery<T...> select(this->name);
return select;
@@ -93,39 +94,34 @@ class Table
template <std::size_t N=0>
typename std::enable_if<N < sizeof...(T), void>::type
- add_column_if_not_exists(sqlite3* db, const std::set<std::string>& existing_columns)
+ add_column_if_not_exists(DatabaseEngine& db, const std::set<std::string>& existing_columns)
{
using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type;
- if (existing_columns.count(ColumnType::name) != 1)
- {
- add_column_to_table<ColumnType>(db, this->name);
- }
+ if (existing_columns.count(ColumnType::name) == 0)
+ add_column_to_table<ColumnType>(db, this->name);
add_column_if_not_exists<N+1>(db, existing_columns);
}
template <std::size_t N=0>
typename std::enable_if<N == sizeof...(T), void>::type
- add_column_if_not_exists(sqlite3*, const std::set<std::string>&)
+ add_column_if_not_exists(DatabaseEngine&, const std::set<std::string>&)
{}
template <std::size_t N=0>
typename std::enable_if<N < sizeof...(T), void>::type
- add_column_create(std::string& str)
+ add_column_create(DatabaseEngine& db, std::string& str)
{
using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type;
- using RealType = typename ColumnType::real_type;
str += ColumnType::name;
str += " ";
- str += TypeToSQLType<RealType>::type;
- append_option<ColumnType>(str);
+ str += ToSQLType<ColumnType>(db);
if (N != sizeof...(T) - 1)
str += ",";
- str += "\n";
- add_column_create<N+1>(str);
+ add_column_create<N+1>(db, str);
}
template <std::size_t N=0>
typename std::enable_if<N == sizeof...(T), void>::type
- add_column_create(std::string&)
+ add_column_create(DatabaseEngine&, std::string&)
{ }
const std::string name;
diff --git a/src/database/type_to_sql.cpp b/src/database/type_to_sql.cpp
deleted file mode 100644
index bcd9daa..0000000
--- a/src/database/type_to_sql.cpp
+++ /dev/null
@@ -1,9 +0,0 @@
-#include <database/type_to_sql.hpp>
-
-template <> const std::string TypeToSQLType<int>::type = "INTEGER";
-template <> const std::string TypeToSQLType<std::size_t>::type = "INTEGER";
-template <> const std::string TypeToSQLType<long>::type = "INTEGER";
-template <> const std::string TypeToSQLType<long long>::type = "INTEGER";
-template <> const std::string TypeToSQLType<bool>::type = "INTEGER";
-template <> const std::string TypeToSQLType<std::string>::type = "TEXT";
-template <> const std::string TypeToSQLType<OptionalBool>::type = "INTEGER"; \ No newline at end of file
diff --git a/src/database/type_to_sql.hpp b/src/database/type_to_sql.hpp
deleted file mode 100644
index ba806ab..0000000
--- a/src/database/type_to_sql.hpp
+++ /dev/null
@@ -1,16 +0,0 @@
-#pragma once
-
-#include <utils/optional_bool.hpp>
-
-#include <string>
-
-template <typename T>
-struct TypeToSQLType { static const std::string type; };
-
-template <> const std::string TypeToSQLType<int>::type;
-template <> const std::string TypeToSQLType<std::size_t>::type;
-template <> const std::string TypeToSQLType<long>::type;
-template <> const std::string TypeToSQLType<long long>::type;
-template <> const std::string TypeToSQLType<bool>::type;
-template <> const std::string TypeToSQLType<std::string>::type;
-template <> const std::string TypeToSQLType<OptionalBool>::type; \ No newline at end of file
diff --git a/src/database/update_query.hpp b/src/database/update_query.hpp
new file mode 100644
index 0000000..a29ac3f
--- /dev/null
+++ b/src/database/update_query.hpp
@@ -0,0 +1,104 @@
+#pragma once
+
+#include <database/query.hpp>
+#include <database/engine.hpp>
+
+using namespace std::string_literals;
+
+template <class T, class... Tuple>
+struct Index;
+
+template <class T, class... Types>
+struct Index<T, std::tuple<T, Types...>>
+{
+ static const std::size_t value = 0;
+};
+
+template <class T, class U, class... Types>
+struct Index<T, std::tuple<U, Types...>>
+{
+ static const std::size_t value = Index<T, std::tuple<Types...>>::value + 1;
+};
+
+struct UpdateQuery: public Query
+{
+ template <typename... T>
+ UpdateQuery(const std::string& name, const std::tuple<T...>& columns):
+ Query("UPDATE ")
+ {
+ this->body += name;
+ this->insert_col_names_and_values(columns);
+ }
+
+ template <typename... T>
+ void insert_col_names_and_values(const std::tuple<T...>& columns)
+ {
+ this->body += " SET ";
+ this->insert_col_name_and_value(columns);
+ this->body += " WHERE "s + Id::name + "=$" + std::to_string(this->current_param);
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ insert_col_name_and_value(const std::tuple<T...>& columns)
+ {
+ using ColumnType = std::decay_t<decltype(std::get<N>(columns))>;
+
+ if (!std::is_same<ColumnType, Id>::value)
+ {
+ this->body += ColumnType::name + "=$"s + std::to_string(this->current_param);
+ this->current_param++;
+
+ if (N < (sizeof...(T) - 1))
+ this->body += ", ";
+ }
+
+ this->insert_col_name_and_value<N+1>(columns);
+ }
+ template <int N=0, typename... T>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ insert_col_name_and_value(const std::tuple<T...>&)
+ {}
+
+
+ template <typename... T>
+ void execute(DatabaseEngine& db, const std::tuple<T...>& columns)
+ {
+#ifdef DEBUG_SQL_QUERIES
+ const auto timer = this->log_and_time();
+#endif
+
+ auto statement = db.prepare(this->body);
+ this->bind_param(columns, *statement);
+ this->bind_id(columns, *statement);
+
+ statement->step();
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N < sizeof...(T), void>::type
+ bind_param(const std::tuple<T...>& columns, Statement& statement, int index=1)
+ {
+ auto&& column = std::get<N>(columns);
+ using ColumnType = std::decay_t<decltype(column)>;
+
+ if (!std::is_same<ColumnType, Id>::value)
+ actual_bind(statement, column.value, index++);
+
+ this->bind_param<N+1>(columns, statement, index);
+ }
+
+ template <int N=0, typename... T>
+ typename std::enable_if<N == sizeof...(T), void>::type
+ bind_param(const std::tuple<T...>&, Statement&, int)
+ {}
+
+ template <typename... T>
+ void bind_id(const std::tuple<T...>& columns, Statement& statement)
+ {
+ static constexpr auto index = Index<Id, std::tuple<T...>>::value;
+ auto&& value = std::get<index>(columns);
+
+ actual_bind(statement, value.value, sizeof...(T));
+ }
+};