diff options
Diffstat (limited to 'src/database')
-rw-r--r-- | src/database/column.hpp | 9 | ||||
-rw-r--r-- | src/database/count_query.hpp | 17 | ||||
-rw-r--r-- | src/database/database.cpp | 81 | ||||
-rw-r--r-- | src/database/database.hpp | 73 | ||||
-rw-r--r-- | src/database/engine.hpp | 41 | ||||
-rw-r--r-- | src/database/index.hpp | 38 | ||||
-rw-r--r-- | src/database/insert_query.hpp | 108 | ||||
-rw-r--r-- | src/database/postgresql_engine.cpp | 91 | ||||
-rw-r--r-- | src/database/postgresql_engine.hpp | 48 | ||||
-rw-r--r-- | src/database/postgresql_statement.hpp | 123 | ||||
-rw-r--r-- | src/database/query.cpp | 26 | ||||
-rw-r--r-- | src/database/query.hpp | 62 | ||||
-rw-r--r-- | src/database/row.hpp | 108 | ||||
-rw-r--r-- | src/database/select_query.hpp | 31 | ||||
-rw-r--r-- | src/database/sqlite3_engine.cpp | 101 | ||||
-rw-r--r-- | src/database/sqlite3_engine.hpp | 47 | ||||
-rw-r--r-- | src/database/sqlite3_statement.hpp | 92 | ||||
-rw-r--r-- | src/database/statement.hpp | 46 | ||||
-rw-r--r-- | src/database/table.cpp | 23 | ||||
-rw-r--r-- | src/database/table.hpp | 80 | ||||
-rw-r--r-- | src/database/type_to_sql.cpp | 9 | ||||
-rw-r--r-- | src/database/type_to_sql.hpp | 16 | ||||
-rw-r--r-- | src/database/update_query.hpp | 104 |
23 files changed, 1046 insertions, 328 deletions
diff --git a/src/database/column.hpp b/src/database/column.hpp index 111f9ca..1f16bcf 100644 --- a/src/database/column.hpp +++ b/src/database/column.hpp @@ -13,5 +13,10 @@ struct Column T value{}; }; -struct Id: Column<std::size_t> { static constexpr auto name = "id_"; - static constexpr auto options = "PRIMARY KEY AUTOINCREMENT"; }; +struct Id: Column<std::size_t> { + static constexpr std::size_t unset_value = static_cast<std::size_t>(-1); + static constexpr auto name = "id_"; + static constexpr auto options = "PRIMARY KEY"; + + Id(): Column<std::size_t>(-1) {} +}; diff --git a/src/database/count_query.hpp b/src/database/count_query.hpp index 0dde63c..118ce44 100644 --- a/src/database/count_query.hpp +++ b/src/database/count_query.hpp @@ -2,11 +2,10 @@ #include <database/query.hpp> #include <database/table.hpp> +#include <database/statement.hpp> #include <string> -#include <sqlite3.h> - struct CountQuery: public Query { CountQuery(std::string name): @@ -15,20 +14,20 @@ struct CountQuery: public Query this->body += std::move(name); } - int64_t execute(sqlite3* db) + int64_t execute(DatabaseEngine& db) { - auto statement = this->prepare(db); +#ifdef DEBUG_SQL_QUERIES + const auto timer = this->log_and_time(); +#endif + auto statement = db.prepare(this->body); int64_t res = 0; - if (sqlite3_step(statement.get()) == SQLITE_ROW) - res = sqlite3_column_int64(statement.get(), 0); + if (statement->step() != StepResult::Error) + res = statement->get_column_int64(0); else { log_error("Count request didn’t return a result"); return 0; } - if (sqlite3_step(statement.get()) != SQLITE_DONE) - log_warning("Count request returned more than one result."); - return res; } }; diff --git a/src/database/database.cpp b/src/database/database.cpp index 85c675e..3622963 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -6,40 +6,54 @@ #include <utils/get_first_non_empty.hpp> #include <utils/time.hpp> -#include <sqlite3.h> +#include <config/config.hpp> +#include <database/sqlite3_engine.hpp> +#include <database/postgresql_engine.hpp> -sqlite3* Database::db; -Database::MucLogLineTable Database::muc_log_lines("MucLogLine_"); -Database::GlobalOptionsTable Database::global_options("GlobalOptions_"); -Database::IrcServerOptionsTable Database::irc_server_options("IrcServerOptions_"); -Database::IrcChannelOptionsTable Database::irc_channel_options("IrcChannelOptions_"); +#include <database/engine.hpp> +#include <database/index.hpp> + +#include <memory> + +std::unique_ptr<DatabaseEngine> Database::db; +Database::MucLogLineTable Database::muc_log_lines("muclogline_"); +Database::GlobalOptionsTable Database::global_options("globaloptions_"); +Database::IrcServerOptionsTable Database::irc_server_options("ircserveroptions_"); +Database::IrcChannelOptionsTable Database::irc_channel_options("ircchanneloptions_"); Database::RosterTable Database::roster("roster"); +std::map<Database::CacheKey, Database::EncodingIn::real_type> Database::encoding_in_cache{}; +Database::GlobalPersistent::GlobalPersistent(): + Column<bool>{Config::get_bool("persistent_by_default", false)} +{} void Database::open(const std::string& filename) { // Try to open the specified database. // Close and replace the previous database pointer if it succeeded. If it did // not, just leave things untouched - sqlite3* new_db; - auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); - if (res != SQLITE_OK) - { - log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(Database::db)); - throw std::runtime_error(""); - } - Database::close(); - Database::db = new_db; - Database::muc_log_lines.create(Database::db); - Database::muc_log_lines.upgrade(Database::db); - Database::global_options.create(Database::db); - Database::global_options.upgrade(Database::db); - Database::irc_server_options.create(Database::db); - Database::irc_server_options.upgrade(Database::db); - Database::irc_channel_options.create(Database::db); - Database::irc_channel_options.upgrade(Database::db); - Database::roster.create(Database::db); - Database::roster.upgrade(Database::db); + std::unique_ptr<DatabaseEngine> new_db; + static const auto psql_prefix = "postgresql://"s; + static const auto psql_prefix2 = "postgres://"s; + if ((filename.substr(0, psql_prefix.size()) == psql_prefix) || + (filename.substr(0, psql_prefix2.size()) == psql_prefix2)) + new_db = PostgresqlEngine::open(filename); + else + new_db = Sqlite3Engine::open(filename); + if (!new_db) + return; + Database::db = std::move(new_db); + Database::muc_log_lines.create(*Database::db); + Database::muc_log_lines.upgrade(*Database::db); + Database::global_options.create(*Database::db); + Database::global_options.upgrade(*Database::db); + Database::irc_server_options.create(*Database::db); + Database::irc_server_options.upgrade(*Database::db); + Database::irc_channel_options.create(*Database::db); + Database::irc_channel_options.upgrade(*Database::db); + Database::roster.create(*Database::db); + Database::roster.upgrade(*Database::db); + create_index<Database::Owner, Database::IrcChanName, Database::IrcServerName>(*Database::db, "archive_index", Database::muc_log_lines.get_name()); } @@ -49,7 +63,7 @@ Database::GlobalOptions Database::get_global_options(const std::string& owner) request.where() << Owner{} << "=" << owner; Database::GlobalOptions options{Database::global_options.get_name()}; - auto result = request.execute(Database::db); + auto result = request.execute(*Database::db); if (result.size() == 1) options = result.front(); else @@ -63,7 +77,7 @@ Database::IrcServerOptions Database::get_irc_server_options(const std::string& o request.where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server; Database::IrcServerOptions options{Database::irc_server_options.get_name()}; - auto result = request.execute(Database::db); + auto result = request.execute(*Database::db); if (result.size() == 1) options = result.front(); else @@ -81,7 +95,7 @@ Database::IrcChannelOptions Database::get_irc_channel_options(const std::string& " and " << Server{} << "=" << server <<\ " and " << Channel{} << "=" << channel; Database::IrcChannelOptions options{Database::irc_channel_options.get_name()}; - auto result = request.execute(Database::db); + auto result = request.execute(*Database::db); if (result.size() == 1) options = result.front(); else @@ -176,7 +190,7 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne if (limit >= 0) request.limit() << limit; - auto result = request.execute(Database::db); + auto result = request.execute(*Database::db); return {result.crbegin(), result.crend()}; } @@ -197,7 +211,7 @@ void Database::delete_roster_item(const std::string& local, const std::string& r query << " WHERE " << Database::RemoteJid{} << "=" << remote << \ " AND " << Database::LocalJid{} << "=" << local; - query.execute(Database::db); +// query.execute(*Database::db); } bool Database::has_roster_item(const std::string& local, const std::string& remote) @@ -206,7 +220,7 @@ bool Database::has_roster_item(const std::string& local, const std::string& remo query.where() << Database::LocalJid{} << "=" << local << \ " and " << Database::RemoteJid{} << "=" << remote; - auto res = query.execute(Database::db); + auto res = query.execute(*Database::db); return !res.empty(); } @@ -216,19 +230,18 @@ std::vector<Database::RosterItem> Database::get_contact_list(const std::string& auto query = Database::roster.select(); query.where() << Database::LocalJid{} << "=" << local; - return query.execute(Database::db); + return query.execute(*Database::db); } std::vector<Database::RosterItem> Database::get_full_roster() { auto query = Database::roster.select(); - return query.execute(Database::db); + return query.execute(*Database::db); } void Database::close() { - sqlite3_close_v2(Database::db); Database::db = nullptr; } diff --git a/src/database/database.hpp b/src/database/database.hpp index c00c938..ec44543 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -7,12 +7,15 @@ #include <database/column.hpp> #include <database/count_query.hpp> +#include <database/engine.hpp> + #include <utils/optional_bool.hpp> #include <chrono> #include <string> #include <memory> +#include <map> class Database @@ -24,11 +27,11 @@ class Database struct Owner: Column<std::string> { static constexpr auto name = "owner_"; }; - struct IrcChanName: Column<std::string> { static constexpr auto name = "ircChanName_"; }; + struct IrcChanName: Column<std::string> { static constexpr auto name = "ircchanname_"; }; struct Channel: Column<std::string> { static constexpr auto name = "channel_"; }; - struct IrcServerName: Column<std::string> { static constexpr auto name = "ircServerName_"; }; + struct IrcServerName: Column<std::string> { static constexpr auto name = "ircservername_"; }; struct Server: Column<std::string> { static constexpr auto name = "server_"; }; @@ -43,35 +46,38 @@ class Database struct Ports: Column<std::string> { static constexpr auto name = "ports_"; Ports(): Column<std::string>("6667") {} }; - struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsPorts_"; + struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsports_"; TlsPorts(): Column<std::string>("6697;6670") {} }; struct Username: Column<std::string> { static constexpr auto name = "username_"; }; struct Realname: Column<std::string> { static constexpr auto name = "realname_"; }; - struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterConnectionCommand_"; }; + struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterconnectioncommand_"; }; - struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedFingerprint_"; }; + struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedfingerprint_"; }; - struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingOut_"; }; + struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingout_"; }; - struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingIn_"; }; + struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingin_"; }; - struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxHistoryLength_"; + struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxhistorylength_"; MaxHistoryLength(): Column<int>(20) {} }; - struct RecordHistory: Column<bool> { static constexpr auto name = "recordHistory_"; + struct RecordHistory: Column<bool> { static constexpr auto name = "recordhistory_"; RecordHistory(): Column<bool>(true) {}}; - struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordHistory_"; }; + struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordhistory_"; }; - struct VerifyCert: Column<bool> { static constexpr auto name = "verifyCert_"; + struct VerifyCert: Column<bool> { static constexpr auto name = "verifycert_"; VerifyCert(): Column<bool>(true) {} }; struct Persistent: Column<bool> { static constexpr auto name = "persistent_"; Persistent(): Column<bool>(false) {} }; + struct GlobalPersistent: Column<bool> { static constexpr auto name = "persistent_"; + GlobalPersistent(); }; + struct LocalJid: Column<std::string> { static constexpr auto name = "local"; }; struct RemoteJid: Column<std::string> { static constexpr auto name = "remote"; }; @@ -80,7 +86,7 @@ class Database using MucLogLineTable = Table<Id, Uuid, Owner, IrcChanName, IrcServerName, Date, Body, Nick>; using MucLogLine = MucLogLineTable::RowType; - using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, Persistent>; + using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, GlobalPersistent>; using GlobalOptions = GlobalOptionsTable::RowType; using IrcServerOptionsTable = Table<Id, Owner, Server, Pass, AfterConnectionCommand, TlsPorts, Ports, Username, Realname, VerifyCert, TrustedFingerprint, EncodingOut, EncodingIn, MaxHistoryLength>; @@ -130,7 +136,7 @@ class Database static int64_t count(const TableType& table) { CountQuery query{table.get_name()}; - return query.execute(Database::db); + return query.execute(*Database::db); } static MucLogLineTable muc_log_lines; @@ -138,9 +144,48 @@ class Database static IrcServerOptionsTable irc_server_options; static IrcChannelOptionsTable irc_channel_options; static RosterTable roster; - static sqlite3* db; + static std::unique_ptr<DatabaseEngine> db; + + /** + * Some caches, to avoid doing very frequent query requests for a few options. + */ + using CacheKey = std::tuple<std::string, std::string, std::string>; + + static EncodingIn::real_type get_encoding_in(const std::string& owner, + const std::string& server, + const std::string& channel) + { + CacheKey channel_key{owner, server, channel}; + auto it = Database::encoding_in_cache.find(channel_key); + if (it == Database::encoding_in_cache.end()) + { + auto options = Database::get_irc_channel_options_with_server_default(owner, server, channel); + EncodingIn::real_type result = options.col<Database::EncodingIn>(); + if (result.empty()) + result = "ISO-8859-1"; + it = Database::encoding_in_cache.insert(std::make_pair(channel_key, result)).first; + } + return it->second; + } + static void invalidate_encoding_in_cache(const std::string& owner, + const std::string& server, + const std::string& channel) + { + CacheKey channel_key{owner, server, channel}; + Database::encoding_in_cache.erase(channel_key); + } + static void invalidate_encoding_in_cache() + { + Database::encoding_in_cache.clear(); + } + + static auto raw_exec(const std::string& query) + { + Database::db->raw_exec(query); + } private: static std::string gen_uuid(); + static std::map<CacheKey, EncodingIn::real_type> encoding_in_cache; }; #endif /* USE_DATABASE */ diff --git a/src/database/engine.hpp b/src/database/engine.hpp new file mode 100644 index 0000000..41dccf5 --- /dev/null +++ b/src/database/engine.hpp @@ -0,0 +1,41 @@ +#pragma once + +/** + * Interface to provide non-portable behaviour, specific to each + * database engine we want to support. + * + * Everything else (all portable stuf) should go outside of this class. + */ + +#include <database/statement.hpp> + +#include <memory> +#include <string> +#include <vector> +#include <tuple> +#include <set> + +class DatabaseEngine +{ + public: + + DatabaseEngine() = default; + virtual ~DatabaseEngine() = default; + + DatabaseEngine(const DatabaseEngine&) = delete; + DatabaseEngine& operator=(const DatabaseEngine&) = delete; + DatabaseEngine(DatabaseEngine&&) = delete; + DatabaseEngine& operator=(DatabaseEngine&&) = delete; + + virtual std::set<std::string> get_all_columns_from_table(const std::string& table_name) = 0; + virtual std::tuple<bool, std::string> raw_exec(const std::string& query) = 0; + virtual std::unique_ptr<Statement> prepare(const std::string& query) = 0; + virtual void extract_last_insert_rowid(Statement& statement) = 0; + virtual std::string get_returning_id_sql_string(const std::string&) + { + return {}; + } + virtual std::string id_column_type() = 0; + + int64_t last_inserted_rowid{-1}; +}; diff --git a/src/database/index.hpp b/src/database/index.hpp new file mode 100644 index 0000000..30766ab --- /dev/null +++ b/src/database/index.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include <database/engine.hpp> + +#include <string> +#include <tuple> + +namespace +{ +template <std::size_t N=0, typename... T> +typename std::enable_if<N == sizeof...(T), void>::type +add_column_name(std::string&) +{ } + +template <std::size_t N=0, typename... T> +typename std::enable_if<N < sizeof...(T), void>::type +add_column_name(std::string& out) +{ + using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<std::tuple<T...>>()))>::type; + out += ColumnType::name; + if (N != sizeof...(T) - 1) + out += ","; + add_column_name<N+1, T...>(out); +} +} + +template <typename... Columns> +void create_index(DatabaseEngine& db, const std::string& name, const std::string& table) +{ + std::string query{"CREATE INDEX IF NOT EXISTS "}; + query += name + " ON " + table + "("; + add_column_name<0, Columns...>(query); + query += ")"; + + auto result = db.raw_exec(query); + if (std::get<0>(result) == false) + log_error("Error executing query: ", std::get<1>(result)); +} diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp index 2ece69d..9726424 100644 --- a/src/database/insert_query.hpp +++ b/src/database/insert_query.hpp @@ -10,64 +10,64 @@ #include <string> #include <tuple> -#include <sqlite3.h> - -template <int N, typename ColumnType, typename... T> -typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type -actual_bind(Statement& statement, std::vector<std::string>& params, const std::tuple<T...>&) +template <std::size_t N=0, typename... T> +typename std::enable_if<N < sizeof...(T), void>::type +update_autoincrement_id(std::tuple<T...>& columns, Statement& statement) { - const auto value = params.front(); - params.erase(params.begin()); - if (sqlite3_bind_text(statement.get(), N + 1, value.data(), static_cast<int>(value.size()), SQLITE_TRANSIENT) != SQLITE_OK) - log_error("Failed to bind ", value, " to param ", N); + using ColumnType = typename std::decay<decltype(std::get<N>(columns))>::type; + if (std::is_same<ColumnType, Id>::value) + auto&& column = std::get<Id>(columns); + update_autoincrement_id<N+1>(columns, statement); } -template <int N, typename ColumnType, typename... T> -typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type -actual_bind(Statement& statement, std::vector<std::string>&, const std::tuple<T...>& columns) -{ - auto&& column = std::get<Id>(columns); - if (column.value != 0) - { - if (sqlite3_bind_int64(statement.get(), N + 1, static_cast<sqlite3_int64>(column.value)) != SQLITE_OK) - log_error("Failed to bind ", column.value, " to id."); - } - else if (sqlite3_bind_null(statement.get(), N + 1) != SQLITE_OK) - log_error("Failed to bind NULL to param ", N); -} +template <std::size_t N=0, typename... T> +typename std::enable_if<N == sizeof...(T), void>::type +update_autoincrement_id(std::tuple<T...>&, Statement& statement) +{} struct InsertQuery: public Query { - InsertQuery(const std::string& name): - Query("INSERT OR REPLACE INTO ") + template <typename... T> + InsertQuery(const std::string& name, const std::tuple<T...>& columns): + Query("INSERT INTO ") { this->body += name; + this->insert_col_names(columns); + this->insert_values(columns); } template <typename... T> - void execute(const std::tuple<T...>& columns, sqlite3* db) + void execute(DatabaseEngine& db, std::tuple<T...>& columns) { - auto statement = this->prepare(db); - { - this->bind_param(columns, statement); - if (sqlite3_step(statement.get()) != SQLITE_DONE) - log_error("Failed to execute query: ", sqlite3_errmsg(db)); - } +#ifdef DEBUG_SQL_QUERIES + const auto timer = this->log_and_time(); +#endif + + auto statement = db.prepare(this->body); + this->bind_param(columns, *statement); + + if (statement->step() != StepResult::Error) + db.extract_last_insert_rowid(*statement); + else + log_error("Failed to extract the rowid from the last INSERT"); } template <int N=0, typename... T> typename std::enable_if<N < sizeof...(T), void>::type - bind_param(const std::tuple<T...>& columns, Statement& statement) + bind_param(const std::tuple<T...>& columns, Statement& statement, int index=1) { - using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type; + auto&& column = std::get<N>(columns); + using ColumnType = std::decay_t<decltype(column)>; - actual_bind<N, ColumnType>(statement, this->params, columns); - this->bind_param<N+1>(columns, statement); + if (!std::is_same<ColumnType, Id>::value) + actual_bind(statement, column.value, index++); + + this->bind_param<N+1>(columns, statement, index); } template <int N=0, typename... T> typename std::enable_if<N == sizeof...(T), void>::type - bind_param(const std::tuple<T...>&, Statement&) + bind_param(const std::tuple<T...>&, Statement&, int) {} template <typename... T> @@ -80,18 +80,21 @@ struct InsertQuery: public Query template <int N=0, typename... T> typename std::enable_if<N < sizeof...(T), void>::type - insert_value(const std::tuple<T...>& columns) + insert_value(const std::tuple<T...>& columns, int index=1) { - this->body += "?"; - if (N != sizeof...(T) - 1) - this->body += ","; - this->body += " "; - add_param(*this, std::get<N>(columns)); - this->insert_value<N+1>(columns); + using ColumnType = std::decay_t<decltype(std::get<N>(columns))>; + + if (!std::is_same<ColumnType, Id>::value) + { + this->body += "$" + std::to_string(index++); + if (N != sizeof...(T) - 1) + this->body += ", "; + } + this->insert_value<N+1>(columns, index); } template <int N=0, typename... T> typename std::enable_if<N == sizeof...(T), void>::type - insert_value(const std::tuple<T...>&) + insert_value(const std::tuple<T...>&, const int) { } template <typename... T> @@ -99,27 +102,28 @@ struct InsertQuery: public Query { this->body += " ("; this->insert_col_name(columns); - this->body += ")\n"; + this->body += ")"; } template <int N=0, typename... T> typename std::enable_if<N < sizeof...(T), void>::type insert_col_name(const std::tuple<T...>& columns) { - using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type; + using ColumnType = std::decay_t<decltype(std::get<N>(columns))>; - this->body += ColumnType::name; + if (!std::is_same<ColumnType, Id>::value) + { + this->body += ColumnType::name; - if (N < (sizeof...(T) - 1)) - this->body += ", "; + if (N < (sizeof...(T) - 1)) + this->body += ", "; + } this->insert_col_name<N+1>(columns); } + template <int N=0, typename... T> typename std::enable_if<N == sizeof...(T), void>::type insert_col_name(const std::tuple<T...>&) {} - - - private: }; diff --git a/src/database/postgresql_engine.cpp b/src/database/postgresql_engine.cpp new file mode 100644 index 0000000..984a959 --- /dev/null +++ b/src/database/postgresql_engine.cpp @@ -0,0 +1,91 @@ +#include <biboumi.h> +#ifdef PQ_FOUND + +#include <utils/scopeguard.hpp> + +#include <database/query.hpp> + +#include <database/postgresql_engine.hpp> + +#include <database/postgresql_statement.hpp> + +#include <logger/logger.hpp> + +PostgresqlEngine::PostgresqlEngine(PGconn*const conn): + conn(conn) +{} + +PostgresqlEngine::~PostgresqlEngine() +{ + PQfinish(this->conn); +} + +std::unique_ptr<DatabaseEngine> PostgresqlEngine::open(const std::string& conninfo) +{ + PGconn* con = PQconnectdb(conninfo.data()); + + if (!con) + { + log_error("Failed to allocate a Postgresql connection"); + throw std::runtime_error(""); + } + const auto status = PQstatus(con); + if (status != CONNECTION_OK) + { + const char* errmsg = PQerrorMessage(con); + log_error("Postgresql connection failed: ", errmsg); + throw std::runtime_error("failed to open connection."); + } + return std::make_unique<PostgresqlEngine>(con); +} + +std::set<std::string> PostgresqlEngine::get_all_columns_from_table(const std::string& table_name) +{ + const auto query = "SELECT column_name from information_schema.columns where table_name='" + table_name + "'"; + auto statement = this->prepare(query); + std::set<std::string> columns; + + while (statement->step() == StepResult::Row) + columns.insert(statement->get_column_text(0)); + + return columns; +} + +std::tuple<bool, std::string> PostgresqlEngine::raw_exec(const std::string& query) +{ +#ifdef DEBUG_SQL_QUERIES + log_debug("SQL QUERY: ", query); + const auto timer = make_sql_timer(); +#endif + PGresult* res = PQexec(this->conn, query.data()); + auto sg = utils::make_scope_guard([res](){ + PQclear(res); + }); + + auto res_status = PQresultStatus(res); + if (res_status != PGRES_COMMAND_OK) + return std::make_tuple(false, PQresultErrorMessage(res)); + return std::make_tuple(true, std::string{}); +} + +std::unique_ptr<Statement> PostgresqlEngine::prepare(const std::string& query) +{ + return std::make_unique<PostgresqlStatement>(query, this->conn); +} + +void PostgresqlEngine::extract_last_insert_rowid(Statement& statement) +{ + this->last_inserted_rowid = statement.get_column_int64(0); +} + +std::string PostgresqlEngine::get_returning_id_sql_string(const std::string& col_name) +{ + return " RETURNING " + col_name; +} + +std::string PostgresqlEngine::id_column_type() +{ + return "SERIAL"; +} + +#endif diff --git a/src/database/postgresql_engine.hpp b/src/database/postgresql_engine.hpp new file mode 100644 index 0000000..fe4fb53 --- /dev/null +++ b/src/database/postgresql_engine.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include <biboumi.h> +#include <string> +#include <stdexcept> +#include <memory> + +#include <database/statement.hpp> +#include <database/engine.hpp> + +#include <tuple> +#include <set> + +#ifdef PQ_FOUND + +#include <libpq-fe.h> + +class PostgresqlEngine: public DatabaseEngine +{ + public: + PostgresqlEngine(PGconn*const conn); + + ~PostgresqlEngine(); + + static std::unique_ptr<DatabaseEngine> open(const std::string& string); + + std::set<std::string> get_all_columns_from_table(const std::string& table_name) override final; + std::tuple<bool, std::string> raw_exec(const std::string& query) override final; + std::unique_ptr<Statement> prepare(const std::string& query) override; + void extract_last_insert_rowid(Statement& statement) override; + std::string get_returning_id_sql_string(const std::string& col_name) override; + std::string id_column_type() override; +private: + PGconn* const conn; +}; + +#else + +class PostgresqlEngine +{ +public: + static std::unique_ptr<DatabaseEngine> open(const std::string& string) + { + throw std::runtime_error("Cannot open postgresql database "s + string + ": biboumi is not compiled with libpq."); + } +}; + +#endif diff --git a/src/database/postgresql_statement.hpp b/src/database/postgresql_statement.hpp new file mode 100644 index 0000000..571c8f1 --- /dev/null +++ b/src/database/postgresql_statement.hpp @@ -0,0 +1,123 @@ +#pragma once + +#include <database/statement.hpp> + +#include <logger/logger.hpp> + +#include <libpq-fe.h> + +class PostgresqlStatement: public Statement +{ + public: + PostgresqlStatement(std::string body, PGconn*const conn): + body(std::move(body)), + conn(conn) + {} + ~PostgresqlStatement() + { + PQclear(this->result); + this->result = nullptr; + } + PostgresqlStatement(const PostgresqlStatement&) = delete; + PostgresqlStatement& operator=(const PostgresqlStatement&) = delete; + PostgresqlStatement(PostgresqlStatement&& other) = delete; + PostgresqlStatement& operator=(PostgresqlStatement&& other) = delete; + + StepResult step() override final + { + if (!this->executed) + { + this->current_tuple = 0; + this->executed = true; + if (!this->execute()) + return StepResult::Error; + } + else + { + this->current_tuple++; + } + if (this->current_tuple < PQntuples(this->result)) + return StepResult::Row; + return StepResult::Done; + } + + int64_t get_column_int64(const int col) override + { + const char* result = PQgetvalue(this->result, this->current_tuple, col); + std::istringstream iss; + iss.str(result); + int64_t res; + iss >> res; + return res; + } + std::string get_column_text(const int col) override + { + const char* result = PQgetvalue(this->result, this->current_tuple, col); + return result; + } + int get_column_int(const int col) override + { + const char* result = PQgetvalue(this->result, this->current_tuple, col); + std::istringstream iss; + iss.str(result); + int res; + iss >> res; + return res; + } + + void bind(std::vector<std::string> params) override + { + + this->params = std::move(params); + } + + bool bind_text(const int, const std::string& data) override + { + this->params.push_back(data); + return true; + } + bool bind_int64(const int, const std::int64_t value) override + { + this->params.push_back(std::to_string(value)); + return true; + } + bool bind_null(const int) override + { + this->params.push_back("NULL"); + return true; + } + + private: + +private: + bool execute() + { + std::vector<const char*> params; + params.reserve(this->params.size()); + + for (const auto& param: this->params) + params.push_back(param.data()); + const int param_size = static_cast<int>(this->params.size()); + this->result = PQexecParams(this->conn, this->body.data(), + param_size, + nullptr, + params.data(), + nullptr, + nullptr, + 0); + const auto status = PQresultStatus(this->result); + if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK) + { + log_error("Failed to execute command: ", PQresultErrorMessage(this->result)); + return false; + } + return true; + } + + bool executed{false}; + std::string body; + PGconn*const conn; + std::vector<std::string> params; + PGresult* result{nullptr}; + int current_tuple{0}; +}; diff --git a/src/database/query.cpp b/src/database/query.cpp index ba63a92..d27dc59 100644 --- a/src/database/query.cpp +++ b/src/database/query.cpp @@ -1,9 +1,26 @@ #include <database/query.hpp> #include <database/column.hpp> -template <> -void add_param<Id>(Query&, const Id&) -{} +void actual_bind(Statement& statement, const std::string& value, int index) +{ + statement.bind_text(index, value); +} + +void actual_bind(Statement& statement, const std::int64_t value, int index) +{ + statement.bind_int64(index, value); +} + +void actual_bind(Statement& statement, const OptionalBool& value, int index) +{ + if (!value.is_set) + statement.bind_int64(index, 0); + else if (value.value) + statement.bind_int64(index, 1); + else + statement.bind_int64(index, -1); +} + void actual_add_param(Query& query, const std::string& val) { @@ -28,7 +45,8 @@ Query& operator<<(Query& query, const char* str) Query& operator<<(Query& query, const std::string& str) { - query.body += "?"; + query.body += "$" + std::to_string(query.current_param); + query.current_param++; actual_add_param(query, str); return query; } diff --git a/src/database/query.hpp b/src/database/query.hpp index 6e1db12..8434944 100644 --- a/src/database/query.hpp +++ b/src/database/query.hpp @@ -1,5 +1,7 @@ #pragma once +#include <biboumi.h> + #include <utils/optional_bool.hpp> #include <database/statement.hpp> #include <database/column.hpp> @@ -9,54 +11,53 @@ #include <vector> #include <string> -#include <sqlite3.h> +void actual_bind(Statement& statement, const std::string& value, int index); +void actual_bind(Statement& statement, const std::int64_t value, int index); +void actual_bind(Statement& statement, const OptionalBool& value, int index); + +#ifdef DEBUG_SQL_QUERIES +#include <utils/scopetimer.hpp> + +inline auto make_sql_timer() +{ + return make_scope_timer([](const std::chrono::steady_clock::duration& elapsed) + { + const auto seconds = std::chrono::duration_cast<std::chrono::seconds>(elapsed); + const auto rest = elapsed - seconds; + log_debug("Query executed in ", seconds.count(), ".", rest.count(), "s."); + }); +} +#endif struct Query { std::string body; std::vector<std::string> params; + int current_param{1}; Query(std::string str): body(std::move(str)) {} - Statement prepare(sqlite3* db) +#ifdef DEBUG_SQL_QUERIES + auto log_and_time() { - sqlite3_stmt* stmt; - auto res = sqlite3_prepare(db, this->body.data(), static_cast<int>(this->body.size()) + 1, - &stmt, nullptr); - if (res != SQLITE_OK) - { - log_error("Error preparing statement: ", sqlite3_errmsg(db)); - return nullptr; - } - Statement statement(stmt); - int i = 1; - for (const std::string& param: this->params) - { - if (sqlite3_bind_text(statement.get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK) - log_error("Failed to bind ", param, " to param ", i); - i++; - } - - return statement; - } - - void execute(sqlite3* db) - { - auto statement = this->prepare(db); - while (sqlite3_step(statement.get()) != SQLITE_DONE) - ; + std::ostringstream os; + os << this->body << "; "; + for (const auto& param: this->params) + os << "'" << param << "' "; + log_debug("SQL QUERY: ", os.str()); + return make_sql_timer(); } +#endif }; template <typename ColumnType> void add_param(Query& query, const ColumnType& column) { + std::cout << "add_param<ColumnType>" << std::endl; actual_add_param(query, column.value); } -template <> -void add_param<Id>(Query& query, const Id& column); template <typename T> void actual_add_param(Query& query, const T& val) @@ -81,7 +82,8 @@ template <typename Integer> typename std::enable_if<std::is_integral<Integer>::value, Query&>::type operator<<(Query& query, const Integer& i) { - query.body += "?"; + query.body += "$" + std::to_string(query.current_param++); actual_add_param(query, i); return query; } + diff --git a/src/database/row.hpp b/src/database/row.hpp index 2b50874..4dc98be 100644 --- a/src/database/row.hpp +++ b/src/database/row.hpp @@ -1,72 +1,72 @@ #pragma once #include <database/insert_query.hpp> +#include <database/update_query.hpp> #include <logger/logger.hpp> -#include <type_traits> - -#include <sqlite3.h> +#include <utils/is_one_of.hpp> -template <typename ColumnType, typename... T> -typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type -update_id(std::tuple<T...>&, sqlite3*) -{} +#include <type_traits> -template <typename ColumnType, typename... T> -typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type -update_id(std::tuple<T...>& columns, sqlite3* db) +template <typename... T> +struct Row { - auto&& column = std::get<ColumnType>(columns); - auto res = sqlite3_last_insert_rowid(db); - column.value = static_cast<Id::real_type>(res); -} + Row(std::string name): + table_name(std::move(name)) + {} -template <std::size_t N=0, typename... T> -typename std::enable_if<N < sizeof...(T), void>::type -update_autoincrement_id(std::tuple<T...>& columns, sqlite3* db) -{ - using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type; - update_id<ColumnType>(columns, db); - update_autoincrement_id<N+1>(columns, db); -} + template <typename Type> + typename Type::real_type& col() + { + auto&& col = std::get<Type>(this->columns); + return col.value; + } -template <std::size_t N=0, typename... T> -typename std::enable_if<N == sizeof...(T), void>::type -update_autoincrement_id(std::tuple<T...>&, sqlite3*) -{} + template <typename Type> + const auto& col() const + { + auto&& col = std::get<Type>(this->columns); + return col.value; + } -template <typename... T> -struct Row -{ - Row(std::string name): - table_name(std::move(name)) - {} + template <bool Coucou=true> + void save(std::unique_ptr<DatabaseEngine>& db, typename std::enable_if<!is_one_of<Id, T...> && Coucou>::type* = nullptr) + { + this->insert(*db); + } - template <typename Type> - typename Type::real_type& col() - { - auto&& col = std::get<Type>(this->columns); - return col.value; - } + template <bool Coucou=true> + void save(std::unique_ptr<DatabaseEngine>& db, typename std::enable_if<is_one_of<Id, T...> && Coucou>::type* = nullptr) + { + const Id& id = std::get<Id>(this->columns); + if (id.value == Id::unset_value) + { + this->insert(*db); + if (db->last_inserted_rowid >= 0) + std::get<Id>(this->columns).value = static_cast<Id::real_type>(db->last_inserted_rowid); + } + else + this->update(*db); + } - template <typename Type> - const auto& col() const - { - auto&& col = std::get<Type>(this->columns); - return col.value; - } + private: + void insert(DatabaseEngine& db) + { + InsertQuery query(this->table_name, this->columns); + // Ugly workaround for non portable stuff + query.body += db.get_returning_id_sql_string(Id::name); + query.execute(db, this->columns); + } - void save(sqlite3* db) - { - InsertQuery query(this->table_name); - query.insert_col_names(this->columns); - query.insert_values(this->columns); + void update(DatabaseEngine& db) + { + UpdateQuery query(this->table_name, this->columns); - query.execute(this->columns, db); + query.execute(db, this->columns); + } - update_autoincrement_id(this->columns, db); - } +public: + std::tuple<T...> columns; + std::string table_name; - std::tuple<T...> columns; - std::string table_name; }; diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp index 872001c..5a17f38 100644 --- a/src/database/select_query.hpp +++ b/src/database/select_query.hpp @@ -1,5 +1,7 @@ #pragma once +#include <database/engine.hpp> + #include <database/statement.hpp> #include <database/query.hpp> #include <logger/logger.hpp> @@ -10,32 +12,27 @@ #include <vector> #include <string> -#include <sqlite3.h> - using namespace std::string_literals; template <typename T> -typename std::enable_if<std::is_integral<T>::value, sqlite3_int64>::type +typename std::enable_if<std::is_integral<T>::value, std::int64_t>::type extract_row_value(Statement& statement, const int i) { - return sqlite3_column_int64(statement.get(), i); + return statement.get_column_int64(i); } template <typename T> typename std::enable_if<std::is_same<std::string, T>::value, T>::type extract_row_value(Statement& statement, const int i) { - const auto size = sqlite3_column_bytes(statement.get(), i); - const unsigned char* str = sqlite3_column_text(statement.get(), i); - std::string result(reinterpret_cast<const char*>(str), static_cast<std::size_t>(size)); - return result; + return statement.get_column_text(i); } template <typename T> typename std::enable_if<std::is_same<OptionalBool, T>::value, T>::type extract_row_value(Statement& statement, const int i) { - const auto integer = sqlite3_column_int(statement.get(), i); + const auto integer = statement.get_column_int(i); OptionalBool result; if (integer > 0) result.set_value(true); @@ -109,16 +106,24 @@ struct SelectQuery: public Query return *this; } - auto execute(sqlite3* db) + auto execute(DatabaseEngine& db) { - auto statement = this->prepare(db); std::vector<Row<T...>> rows; - while (sqlite3_step(statement.get()) == SQLITE_ROW) + +#ifdef DEBUG_SQL_QUERIES + const auto timer = this->log_and_time(); +#endif + + auto statement = db.prepare(this->body); + statement->bind(std::move(this->params)); + + while (statement->step() == StepResult::Row) { Row<T...> row(this->table_name); - extract_row_values(row, statement); + extract_row_values(row, *statement); rows.push_back(row); } + return rows; } diff --git a/src/database/sqlite3_engine.cpp b/src/database/sqlite3_engine.cpp new file mode 100644 index 0000000..ae4a146 --- /dev/null +++ b/src/database/sqlite3_engine.cpp @@ -0,0 +1,101 @@ +#include <biboumi.h> + +#ifdef SQLITE3_FOUND + +#include <database/sqlite3_engine.hpp> + +#include <database/sqlite3_statement.hpp> + +#include <database/query.hpp> + +#include <utils/tolower.hpp> +#include <logger/logger.hpp> +#include <vector> + +Sqlite3Engine::Sqlite3Engine(sqlite3* db): + db(db) +{ +} + +Sqlite3Engine::~Sqlite3Engine() +{ + sqlite3_close(this->db); +} + +std::set<std::string> Sqlite3Engine::get_all_columns_from_table(const std::string& table_name) +{ + std::set<std::string> result; + char* errmsg; + std::string query{"PRAGMA table_info(" + table_name + ")"}; + int res = sqlite3_exec(this->db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int { + constexpr int name_column = 1; + std::set<std::string>* result = static_cast<std::set<std::string>*>(param); + if (name_column < columns_nb) + result->insert(utils::tolower(columns[name_column])); + return 0; + }, &result, &errmsg); + + if (res != SQLITE_OK) + { + log_error("Error executing ", query, ": ", errmsg); + sqlite3_free(errmsg); + } + + return result; +} + +std::unique_ptr<DatabaseEngine> Sqlite3Engine::open(const std::string& filename) +{ + sqlite3* new_db; + auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); + if (res != SQLITE_OK) + { + log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(new_db)); + sqlite3_close(new_db); + throw std::runtime_error(""); + } + return std::make_unique<Sqlite3Engine>(new_db); +} + +std::tuple<bool, std::string> Sqlite3Engine::raw_exec(const std::string& query) +{ +#ifdef DEBUG_SQL_QUERIES + log_debug("SQL QUERY: ", query); + const auto timer = make_sql_timer(); +#endif + + char* error; + const auto result = sqlite3_exec(db, query.data(), nullptr, nullptr, &error); + if (result != SQLITE_OK) + { + std::string err_msg(error); + sqlite3_free(error); + return std::make_tuple(false, err_msg); + } + return std::make_tuple(true, std::string{}); +} + +std::unique_ptr<Statement> Sqlite3Engine::prepare(const std::string& query) +{ + sqlite3_stmt* stmt; + auto res = sqlite3_prepare(db, query.data(), static_cast<int>(query.size()) + 1, + &stmt, nullptr); + if (res != SQLITE_OK) + { + log_error("Error preparing statement: ", sqlite3_errmsg(db)); + return nullptr; + } + return std::make_unique<Sqlite3Statement>(stmt); +} + +void Sqlite3Engine::extract_last_insert_rowid(Statement&) +{ + this->last_inserted_rowid = sqlite3_last_insert_rowid(this->db); +} + +std::string Sqlite3Engine::id_column_type() +{ + return "INTEGER PRIMARY KEY AUTOINCREMENT"; +} + +#endif diff --git a/src/database/sqlite3_engine.hpp b/src/database/sqlite3_engine.hpp new file mode 100644 index 0000000..5b8176c --- /dev/null +++ b/src/database/sqlite3_engine.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include <database/engine.hpp> + +#include <database/statement.hpp> + +#include <memory> +#include <string> +#include <tuple> +#include <set> + +#include <biboumi.h> + +#ifdef SQLITE3_FOUND + +#include <sqlite3.h> + +class Sqlite3Engine: public DatabaseEngine +{ + public: + Sqlite3Engine(sqlite3* db); + + ~Sqlite3Engine(); + + static std::unique_ptr<DatabaseEngine> open(const std::string& string); + + std::set<std::string> get_all_columns_from_table(const std::string& table_name) override final; + std::tuple<bool, std::string> raw_exec(const std::string& query) override final; + std::unique_ptr<Statement> prepare(const std::string& query) override; + void extract_last_insert_rowid(Statement& statement) override; + std::string id_column_type() override; +private: + sqlite3* const db; +}; + +#else + +class Sqlite3Engine +{ +public: + static std::unique_ptr<DatabaseEngine> open(const std::string& string) + { + throw std::runtime_error("Cannot open sqlite3 database "s + string + ": biboumi is not compiled with sqlite3 lib."); + } +}; + +#endif diff --git a/src/database/sqlite3_statement.hpp b/src/database/sqlite3_statement.hpp new file mode 100644 index 0000000..7738fa6 --- /dev/null +++ b/src/database/sqlite3_statement.hpp @@ -0,0 +1,92 @@ +#pragma once + +#include <database/statement.hpp> + +#include <logger/logger.hpp> + +#include <sqlite3.h> + +class Sqlite3Statement: public Statement +{ + public: + Sqlite3Statement(sqlite3_stmt* stmt): + stmt(stmt) {} + ~Sqlite3Statement() + { + sqlite3_finalize(this->stmt); + } + + StepResult step() override final + { + auto res = sqlite3_step(this->get()); + if (res == SQLITE_ROW) + return StepResult::Row; + else if (res == SQLITE_DONE) + return StepResult::Done; + else + return StepResult::Error; + } + + void bind(std::vector<std::string> params) override + { + int i = 1; + for (const std::string& param: params) + { + if (sqlite3_bind_text(this->get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK) + log_error("Failed to bind ", param, " to param ", i); + i++; + } + } + + int64_t get_column_int64(const int col) override + { + return sqlite3_column_int64(this->get(), col); + } + + std::string get_column_text(const int col) override + { + const auto size = sqlite3_column_bytes(this->get(), col); + const unsigned char* str = sqlite3_column_text(this->get(), col); + std::string result(reinterpret_cast<const char*>(str), static_cast<std::size_t>(size)); + return result; + } + + bool bind_text(const int pos, const std::string& data) override + { + return sqlite3_bind_text(this->get(), pos, data.data(), static_cast<int>(data.size()), SQLITE_TRANSIENT) == SQLITE_OK; + } + bool bind_int64(const int pos, const std::int64_t value) override + { + return sqlite3_bind_int64(this->get(), pos, static_cast<sqlite3_int64>(value)) == SQLITE_OK; + } + bool bind_null(const int pos) override + { + return sqlite3_bind_null(this->get(), pos) == SQLITE_OK; + } + int get_column_int(const int col) override + { + return sqlite3_column_int(this->get(), col); + } + + Sqlite3Statement(const Sqlite3Statement&) = delete; + Sqlite3Statement& operator=(const Sqlite3Statement&) = delete; + Sqlite3Statement(Sqlite3Statement&& other): + stmt(other.stmt) + { + other.stmt = nullptr; + } + Sqlite3Statement& operator=(Sqlite3Statement&& other) + { + this->stmt = other.stmt; + other.stmt = nullptr; + return *this; + } + sqlite3_stmt* get() + { + return this->stmt; + } + + private: + sqlite3_stmt* stmt; + int last_step_result{SQLITE_OK}; +}; diff --git a/src/database/statement.hpp b/src/database/statement.hpp index 87cd70f..4a61928 100644 --- a/src/database/statement.hpp +++ b/src/database/statement.hpp @@ -1,35 +1,29 @@ #pragma once -#include <sqlite3.h> +#include <cstdint> +#include <string> +#include <vector> + +enum class StepResult +{ + Row, + Done, + Error, +}; class Statement { public: - Statement(sqlite3_stmt* stmt): - stmt(stmt) {} - ~Statement() - { - sqlite3_finalize(this->stmt); - } + virtual ~Statement() = default; + virtual StepResult step() = 0; + + virtual void bind(std::vector<std::string> params) = 0; - Statement(const Statement&) = delete; - Statement& operator=(const Statement&) = delete; - Statement(Statement&& other): - stmt(other.stmt) - { - other.stmt = nullptr; - } - Statement& operator=(Statement&& other) - { - this->stmt = other.stmt; - other.stmt = nullptr; - return *this; - } - sqlite3_stmt* get() - { - return this->stmt; - } + virtual std::int64_t get_column_int64(const int col) = 0; + virtual std::string get_column_text(const int col) = 0; + virtual int get_column_int(const int col) = 0; - private: - sqlite3_stmt* stmt; + virtual bool bind_text(const int pos, const std::string& data) = 0; + virtual bool bind_int64(const int pos, const std::int64_t value) = 0; + virtual bool bind_null(const int pos) = 0; }; diff --git a/src/database/table.cpp b/src/database/table.cpp deleted file mode 100644 index 9224d79..0000000 --- a/src/database/table.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include <database/table.hpp> - -std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name) -{ - std::set<std::string> result; - char* errmsg; - std::string query{"PRAGMA table_info(" + table_name + ")"}; - int res = sqlite3_exec(db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int { - constexpr int name_column = 1; - std::set<std::string>* result = static_cast<std::set<std::string>*>(param); - if (name_column < columns_nb) - result->insert(columns[name_column]); - return 0; - }, &result, &errmsg); - - if (res != SQLITE_OK) - { - log_error("Error executing ", query, ": ", errmsg); - sqlite3_free(errmsg); - } - - return result; -} diff --git a/src/database/table.hpp b/src/database/table.hpp index 0060211..680e7cc 100644 --- a/src/database/table.hpp +++ b/src/database/table.hpp @@ -1,7 +1,8 @@ #pragma once +#include <database/engine.hpp> + #include <database/select_query.hpp> -#include <database/type_to_sql.hpp> #include <database/row.hpp> #include <algorithm> @@ -10,23 +11,27 @@ using namespace std::string_literals; -std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name); +template <typename T> +std::string ToSQLType(DatabaseEngine& db) +{ + if (std::is_same<T, Id>::value) + return db.id_column_type(); + else if (std::is_same<typename T::real_type, std::string>::value) + return "TEXT"; + else + return "INTEGER"; +} template <typename ColumnType> -void add_column_to_table(sqlite3* db, const std::string& table_name) +void add_column_to_table(DatabaseEngine& db, const std::string& table_name) { const std::string name = ColumnType::name; - std::string query{"ALTER TABLE " + table_name + " ADD " + ColumnType::name + " " + TypeToSQLType<typename ColumnType::real_type>::type}; - char* error; - const auto result = sqlite3_exec(db, query.data(), nullptr, nullptr, &error); - if (result != SQLITE_OK) - { - log_error("Error adding column ", name, " to table ", table_name, ": ", error); - sqlite3_free(error); - } + std::string query{"ALTER TABLE " + table_name + " ADD " + ColumnType::name + " " + ToSQLType<ColumnType>(db)}; + auto res = db.raw_exec(query); + if (std::get<0>(res) == false) + log_error("Error adding column ", name, " to table ", table_name, ": ", std::get<1>(res)); } - template <typename ColumnType, decltype(ColumnType::options) = nullptr> void append_option(std::string& s) { @@ -50,27 +55,23 @@ class Table name(std::move(name)) {} - void upgrade(sqlite3* db) + void upgrade(DatabaseEngine& db) { - const auto existing_columns = get_all_columns_from_table(db, this->name); + const auto existing_columns = db.get_all_columns_from_table(this->name); add_column_if_not_exists(db, existing_columns); } - void create(sqlite3* db) + void create(DatabaseEngine& db) { - std::string res{"CREATE TABLE IF NOT EXISTS "}; - res += this->name; - res += " (\n"; - this->add_column_create(res); - res += ")"; - - char* error; - const auto result = sqlite3_exec(db, res.data(), nullptr, nullptr, &error); - if (result != SQLITE_OK) - { - log_error("Error executing query: ", error); - sqlite3_free(error); - } + std::string query{"CREATE TABLE IF NOT EXISTS "}; + query += this->name; + query += " ("; + this->add_column_create(db, query); + query += ")"; + + auto result = db.raw_exec(query); + if (std::get<0>(result) == false) + log_error("Error executing query: ", std::get<1>(result)); } RowType row() @@ -78,7 +79,7 @@ class Table return {this->name}; } - SelectQuery<T...> select() + auto select() { SelectQuery<T...> select(this->name); return select; @@ -93,39 +94,34 @@ class Table template <std::size_t N=0> typename std::enable_if<N < sizeof...(T), void>::type - add_column_if_not_exists(sqlite3* db, const std::set<std::string>& existing_columns) + add_column_if_not_exists(DatabaseEngine& db, const std::set<std::string>& existing_columns) { using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type; - if (existing_columns.count(ColumnType::name) != 1) - { - add_column_to_table<ColumnType>(db, this->name); - } + if (existing_columns.count(ColumnType::name) == 0) + add_column_to_table<ColumnType>(db, this->name); add_column_if_not_exists<N+1>(db, existing_columns); } template <std::size_t N=0> typename std::enable_if<N == sizeof...(T), void>::type - add_column_if_not_exists(sqlite3*, const std::set<std::string>&) + add_column_if_not_exists(DatabaseEngine&, const std::set<std::string>&) {} template <std::size_t N=0> typename std::enable_if<N < sizeof...(T), void>::type - add_column_create(std::string& str) + add_column_create(DatabaseEngine& db, std::string& str) { using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type; - using RealType = typename ColumnType::real_type; str += ColumnType::name; str += " "; - str += TypeToSQLType<RealType>::type; - append_option<ColumnType>(str); + str += ToSQLType<ColumnType>(db); if (N != sizeof...(T) - 1) str += ","; - str += "\n"; - add_column_create<N+1>(str); + add_column_create<N+1>(db, str); } template <std::size_t N=0> typename std::enable_if<N == sizeof...(T), void>::type - add_column_create(std::string&) + add_column_create(DatabaseEngine&, std::string&) { } const std::string name; diff --git a/src/database/type_to_sql.cpp b/src/database/type_to_sql.cpp deleted file mode 100644 index bcd9daa..0000000 --- a/src/database/type_to_sql.cpp +++ /dev/null @@ -1,9 +0,0 @@ -#include <database/type_to_sql.hpp> - -template <> const std::string TypeToSQLType<int>::type = "INTEGER"; -template <> const std::string TypeToSQLType<std::size_t>::type = "INTEGER"; -template <> const std::string TypeToSQLType<long>::type = "INTEGER"; -template <> const std::string TypeToSQLType<long long>::type = "INTEGER"; -template <> const std::string TypeToSQLType<bool>::type = "INTEGER"; -template <> const std::string TypeToSQLType<std::string>::type = "TEXT"; -template <> const std::string TypeToSQLType<OptionalBool>::type = "INTEGER";
\ No newline at end of file diff --git a/src/database/type_to_sql.hpp b/src/database/type_to_sql.hpp deleted file mode 100644 index ba806ab..0000000 --- a/src/database/type_to_sql.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include <utils/optional_bool.hpp> - -#include <string> - -template <typename T> -struct TypeToSQLType { static const std::string type; }; - -template <> const std::string TypeToSQLType<int>::type; -template <> const std::string TypeToSQLType<std::size_t>::type; -template <> const std::string TypeToSQLType<long>::type; -template <> const std::string TypeToSQLType<long long>::type; -template <> const std::string TypeToSQLType<bool>::type; -template <> const std::string TypeToSQLType<std::string>::type; -template <> const std::string TypeToSQLType<OptionalBool>::type;
\ No newline at end of file diff --git a/src/database/update_query.hpp b/src/database/update_query.hpp new file mode 100644 index 0000000..a29ac3f --- /dev/null +++ b/src/database/update_query.hpp @@ -0,0 +1,104 @@ +#pragma once + +#include <database/query.hpp> +#include <database/engine.hpp> + +using namespace std::string_literals; + +template <class T, class... Tuple> +struct Index; + +template <class T, class... Types> +struct Index<T, std::tuple<T, Types...>> +{ + static const std::size_t value = 0; +}; + +template <class T, class U, class... Types> +struct Index<T, std::tuple<U, Types...>> +{ + static const std::size_t value = Index<T, std::tuple<Types...>>::value + 1; +}; + +struct UpdateQuery: public Query +{ + template <typename... T> + UpdateQuery(const std::string& name, const std::tuple<T...>& columns): + Query("UPDATE ") + { + this->body += name; + this->insert_col_names_and_values(columns); + } + + template <typename... T> + void insert_col_names_and_values(const std::tuple<T...>& columns) + { + this->body += " SET "; + this->insert_col_name_and_value(columns); + this->body += " WHERE "s + Id::name + "=$" + std::to_string(this->current_param); + } + + template <int N=0, typename... T> + typename std::enable_if<N < sizeof...(T), void>::type + insert_col_name_and_value(const std::tuple<T...>& columns) + { + using ColumnType = std::decay_t<decltype(std::get<N>(columns))>; + + if (!std::is_same<ColumnType, Id>::value) + { + this->body += ColumnType::name + "=$"s + std::to_string(this->current_param); + this->current_param++; + + if (N < (sizeof...(T) - 1)) + this->body += ", "; + } + + this->insert_col_name_and_value<N+1>(columns); + } + template <int N=0, typename... T> + typename std::enable_if<N == sizeof...(T), void>::type + insert_col_name_and_value(const std::tuple<T...>&) + {} + + + template <typename... T> + void execute(DatabaseEngine& db, const std::tuple<T...>& columns) + { +#ifdef DEBUG_SQL_QUERIES + const auto timer = this->log_and_time(); +#endif + + auto statement = db.prepare(this->body); + this->bind_param(columns, *statement); + this->bind_id(columns, *statement); + + statement->step(); + } + + template <int N=0, typename... T> + typename std::enable_if<N < sizeof...(T), void>::type + bind_param(const std::tuple<T...>& columns, Statement& statement, int index=1) + { + auto&& column = std::get<N>(columns); + using ColumnType = std::decay_t<decltype(column)>; + + if (!std::is_same<ColumnType, Id>::value) + actual_bind(statement, column.value, index++); + + this->bind_param<N+1>(columns, statement, index); + } + + template <int N=0, typename... T> + typename std::enable_if<N == sizeof...(T), void>::type + bind_param(const std::tuple<T...>&, Statement&, int) + {} + + template <typename... T> + void bind_id(const std::tuple<T...>& columns, Statement& statement) + { + static constexpr auto index = Index<Id, std::tuple<T...>>::value; + auto&& value = std::get<index>(columns); + + actual_bind(statement, value.value, sizeof...(T)); + } +}; |