From 50cadf3dac0d56ef8181d1800cc30f8dcb749141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?louiz=E2=80=99?= Date: Tue, 13 Jun 2017 10:38:39 +0200 Subject: Implement our own database ORM, and update the whole code to use it Entirely replace LiteSQL fix #3271 --- src/database/column.hpp | 13 +++ src/database/count_query.hpp | 35 ++++++++ src/database/database.cpp | 204 ++++++++++++++++++++---------------------- src/database/database.hpp | 153 +++++++++++++++++++++++-------- src/database/insert_query.hpp | 128 ++++++++++++++++++++++++++ src/database/query.cpp | 11 +++ src/database/query.hpp | 46 ++++++++++ src/database/row.hpp | 75 ++++++++++++++++ src/database/select_query.hpp | 153 +++++++++++++++++++++++++++++++ src/database/table.hpp | 84 +++++++++++++++++ src/database/type_to_sql.cpp | 7 ++ src/database/type_to_sql.hpp | 6 ++ 12 files changed, 770 insertions(+), 145 deletions(-) create mode 100644 src/database/column.hpp create mode 100644 src/database/count_query.hpp create mode 100644 src/database/insert_query.hpp create mode 100644 src/database/query.cpp create mode 100644 src/database/query.hpp create mode 100644 src/database/row.hpp create mode 100644 src/database/select_query.hpp create mode 100644 src/database/table.hpp create mode 100644 src/database/type_to_sql.cpp create mode 100644 src/database/type_to_sql.hpp (limited to 'src/database') diff --git a/src/database/column.hpp b/src/database/column.hpp new file mode 100644 index 0000000..e74d426 --- /dev/null +++ b/src/database/column.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include + +template +struct Column +{ + using real_type = T; + T value; +}; + +struct Id: Column { static constexpr auto name = "id_"; + static constexpr auto options = "PRIMARY KEY AUTOINCREMENT"; }; diff --git a/src/database/count_query.hpp b/src/database/count_query.hpp new file mode 100644 index 0000000..863fad1 --- /dev/null +++ b/src/database/count_query.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +#include + +#include + +struct CountQuery: public Query +{ + CountQuery(std::string name): + Query("SELECT count(*) FROM ") + { + this->body += std::move(name); + } + + std::size_t execute(sqlite3* db) + { + auto statement = this->prepare(db); + std::size_t res = 0; + if (sqlite3_step(statement) == SQLITE_ROW) + res = sqlite3_column_int64(statement, 0); + else + { + log_error("Count request didn’t return a result"); + return 0; + } + if (sqlite3_step(statement) != SQLITE_DONE) + log_warning("Count request returned more than one result."); + + log_debug("Returning count: ", res); + return res; + } +}; diff --git a/src/database/database.cpp b/src/database/database.cpp index 9f310da..1738a69 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -2,175 +2,166 @@ #ifdef USE_DATABASE #include -#include -#include #include #include #include -using namespace std::string_literals; +#include -std::unique_ptr Database::db; +sqlite3* Database::db; +Database::MucLogLineTable Database::muc_log_lines("MucLogLine_"); +Database::GlobalOptionsTable Database::global_options("GlobalOptions_"); +Database::IrcServerOptionsTable Database::irc_server_options("IrcServerOptions_"); +Database::IrcChannelOptionsTable Database::irc_channel_options("IrcChannelOptions_"); -void Database::open(const std::string& filename, const std::string& db_type) +void Database::open(const std::string& filename) { - try - { - auto new_db = std::make_unique(db_type, - "database="s + filename); - if (new_db->needsUpgrade()) - new_db->upgrade(); - Database::db = std::move(new_db); - } catch (const litesql::DatabaseError& e) { - log_error("Failed to open database ", filename, ". ", e.what()); - throw; - } + auto res = sqlite3_open_v2(filename.data(), &Database::db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); + log_debug("open: ", res); + Database::muc_log_lines.create(Database::db); + Database::global_options.create(Database::db); + Database::irc_server_options.create(Database::db); + Database::irc_channel_options.create(Database::db); } -void Database::set_verbose(const bool val) -{ - Database::db->verbose = val; -} -db::GlobalOptions Database::get_global_options(const std::string& owner) +Database::GlobalOptions Database::get_global_options(const std::string& owner) { - try { - auto options = litesql::select(*Database::db, - db::GlobalOptions::Owner == owner).one(); - return options; - } catch (const litesql::NotFound& e) { - db::GlobalOptions options(*Database::db); - options.owner = owner; - return options; - } + auto request = Database::global_options.select().where() << Owner{} << "=" << owner; + + Database::GlobalOptions options{Database::global_options.get_name()}; + auto result = request.execute(Database::db); + if (result.size() == 1) + options = result.front(); + else + options.col() = owner; + return options; } -db::IrcServerOptions Database::get_irc_server_options(const std::string& owner, - const std::string& server) +Database::IrcServerOptions Database::get_irc_server_options(const std::string& owner, const std::string& server) { - try { - auto options = litesql::select(*Database::db, - db::IrcServerOptions::Owner == owner && - db::IrcServerOptions::Server == server).one(); - return options; - } catch (const litesql::NotFound& e) { - db::IrcServerOptions options(*Database::db); - options.owner = owner; - options.server = server; - // options.update(); - return options; - } + auto request = Database::irc_server_options.select().where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server; + + Database::IrcServerOptions options{Database::irc_server_options.get_name()}; + auto result = request.execute(Database::db); + if (result.size() == 1) + options = result.front(); + else + { + options.col() = owner; + options.col() = server; + } + return options; } -db::IrcChannelOptions Database::get_irc_channel_options(const std::string& owner, - const std::string& server, - const std::string& channel) +Database::IrcChannelOptions Database::get_irc_channel_options(const std::string& owner, const std::string& server, const std::string& channel) { - try { - auto options = litesql::select(*Database::db, - db::IrcChannelOptions::Owner == owner && - db::IrcChannelOptions::Server == server && - db::IrcChannelOptions::Channel == channel).one(); - return options; - } catch (const litesql::NotFound& e) { - db::IrcChannelOptions options(*Database::db); - options.owner = owner; - options.server = server; - options.channel = channel; - return options; - } + auto request = Database::irc_channel_options.select().where() << Owner{} << "=" << owner <<\ + " and " << Server{} << "=" << server <<\ + " and " << Channel{} << "=" << channel; + Database::IrcChannelOptions options{Database::irc_channel_options.get_name()}; + auto result = request.execute(Database::db); + if (result.size() == 1) + options = result.front(); + else + { + options.col() = owner; + options.col() = server; + options.col() = channel; + } + return options; } -db::IrcChannelOptions Database::get_irc_channel_options_with_server_default(const std::string& owner, - const std::string& server, - const std::string& channel) +Database::IrcChannelOptions Database::get_irc_channel_options_with_server_default(const std::string& owner, const std::string& server, + const std::string& channel) { auto coptions = Database::get_irc_channel_options(owner, server, channel); auto soptions = Database::get_irc_server_options(owner, server); - coptions.encodingIn = get_first_non_empty(coptions.encodingIn.value(), - soptions.encodingIn.value()); - coptions.encodingOut = get_first_non_empty(coptions.encodingOut.value(), - soptions.encodingOut.value()); + coptions.col() = get_first_non_empty(coptions.col(), + soptions.col()); + coptions.col() = get_first_non_empty(coptions.col(), + soptions.col()); - coptions.maxHistoryLength = get_first_non_empty(coptions.maxHistoryLength.value(), - soptions.maxHistoryLength.value()); + coptions.col() = get_first_non_empty(coptions.col(), + soptions.col()); return coptions; } -db::IrcChannelOptions Database::get_irc_channel_options_with_server_and_global_default(const std::string& owner, - const std::string& server, - const std::string& channel) +Database::IrcChannelOptions Database::get_irc_channel_options_with_server_and_global_default(const std::string& owner, const std::string& server, const std::string& channel) { auto coptions = Database::get_irc_channel_options(owner, server, channel); auto soptions = Database::get_irc_server_options(owner, server); auto goptions = Database::get_global_options(owner); - coptions.encodingIn = get_first_non_empty(coptions.encodingIn.value(), - soptions.encodingIn.value()); - coptions.encodingOut = get_first_non_empty(coptions.encodingOut.value(), - soptions.encodingOut.value()); + coptions.col() = get_first_non_empty(coptions.col(), + soptions.col()); - coptions.maxHistoryLength = get_first_non_empty(coptions.maxHistoryLength.value(), - soptions.maxHistoryLength.value(), - goptions.maxHistoryLength.value()); + coptions.col() = get_first_non_empty(coptions.col(), + soptions.col()); + + coptions.col() = get_first_non_empty(coptions.col(), + soptions.col(), + goptions.col()); return coptions; } -std::string Database::store_muc_message(const std::string& owner, const Iid& iid, - Database::time_point date, - const std::string& body, - const std::string& nick) +std::string Database::store_muc_message(const std::string& owner, const std::string& chan_name, + const std::string& server_name, Database::time_point date, + const std::string& body, const std::string& nick) { - db::MucLogLine line(*Database::db); + auto line = Database::muc_log_lines.row(); auto uuid = Database::gen_uuid(); - line.uuid = uuid; - line.owner = owner; - line.ircChanName = iid.get_local(); - line.ircServerName = iid.get_server(); - line.date = std::chrono::duration_cast(date.time_since_epoch()).count(); - line.body = body; - line.nick = nick; + line.col() = uuid; + line.col() = owner; + line.col() = chan_name; + line.col() = server_name; + line.col() = std::chrono::duration_cast(date.time_since_epoch()).count(); + line.col() = body; + line.col() = nick; - line.update(); + line.save(Database::db); return uuid; } -std::vector Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, +std::vector Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, int limit, const std::string& start, const std::string& end) { - auto request = litesql::select(*Database::db, - db::MucLogLine::Owner == owner && - db::MucLogLine::IrcChanName == chan_name && - db::MucLogLine::IrcServerName == server); - request.orderBy(db::MucLogLine::Id, false); + auto request = Database::muc_log_lines.select().where() << Database::Owner{} << "=" << owner << \ + " and " << Database::IrcChanName{} << "=" << chan_name << \ + " and " << Database::IrcServerName{} << "=" << server; - if (limit >= 0) - request.limit(limit); if (!start.empty()) { const auto start_time = utils::parse_datetime(start); if (start_time != -1) - request.where(db::MucLogLine::Date >= start_time); + request << " and " << Database::Date{} << ">=" << start_time; } if (!end.empty()) { const auto end_time = utils::parse_datetime(end); if (end_time != -1) - request.where(db::MucLogLine::Date <= end_time); + request << " and " << Database::Date{} << "<=" << end_time; } - const auto& res = request.all(); - return {res.crbegin(), res.crend()}; + + request.order_by() << Id{} << " DESC "; + + if (limit >= 0) + request.limit() << limit; + + auto result = request.execute(Database::db); + + return {result.crbegin(), result.crend()}; } void Database::close() { - Database::db.reset(nullptr); + sqlite3_close_v2(Database::db); } std::string Database::gen_uuid() @@ -182,5 +173,4 @@ std::string Database::gen_uuid() return uuid_str; } - -#endif +#endif \ No newline at end of file diff --git a/src/database/database.hpp b/src/database/database.hpp index b08a175..ebc4878 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -1,22 +1,101 @@ #pragma once - #include #ifdef USE_DATABASE -#include "biboudb.hpp" - -#include +#include +#include +#include -#include #include +#include + +#include -class Iid; class Database { -public: + public: using time_point = std::chrono::system_clock::time_point; + + struct Uuid: Column { static constexpr auto name = "uuid_"; + static constexpr auto options = ""; }; + + struct Owner: Column { static constexpr auto name = "owner_"; + static constexpr auto options = ""; }; + + struct IrcChanName: Column { static constexpr auto name = "ircChanName_"; + static constexpr auto options = ""; }; + + struct Channel: Column { static constexpr auto name = "channel_"; + static constexpr auto options = ""; }; + + struct IrcServerName: Column { static constexpr auto name = "ircServerName_"; + static constexpr auto options = ""; }; + + struct Server: Column { static constexpr auto name = "server_"; + static constexpr auto options = ""; }; + + struct Date: Column { static constexpr auto name = "date_"; + static constexpr auto options = ""; }; + + struct Body: Column { static constexpr auto name = "body_"; + static constexpr auto options = ""; }; + + struct Nick: Column { static constexpr auto name = "nick_"; + static constexpr auto options = ""; }; + + struct Pass: Column { static constexpr auto name = "pass_"; + static constexpr auto options = ""; }; + + struct Ports: Column { static constexpr auto name = "tlsPorts_"; + static constexpr auto options = ""; }; + + struct TlsPorts: Column { static constexpr auto name = "ports_"; + static constexpr auto options = ""; }; + + struct Username: Column { static constexpr auto name = "username_"; + static constexpr auto options = ""; }; + + struct Realname: Column { static constexpr auto name = "realname_"; + static constexpr auto options = ""; }; + + struct AfterConnectionCommand: Column { static constexpr auto name = "afterConnectionCommand_"; + static constexpr auto options = ""; }; + + struct TrustedFingerprint: Column { static constexpr auto name = "trustedFingerprint_"; + static constexpr auto options = ""; }; + + struct EncodingOut: Column { static constexpr auto name = "encodingOut_"; + static constexpr auto options = ""; }; + + struct EncodingIn: Column { static constexpr auto name = "encodingIn_"; + static constexpr auto options = ""; }; + + struct MaxHistoryLength: Column { static constexpr auto name = "maxHistoryLength_"; + static constexpr auto options = ""; }; + + struct RecordHistory: Column { static constexpr auto name = "recordHistory_"; + static constexpr auto options = ""; }; + + struct VerifyCert: Column { static constexpr auto name = "verifyCert_"; + static constexpr auto options = ""; }; + + struct Persistent: Column { static constexpr auto name = "persistent_"; + static constexpr auto options = ""; }; + + using MucLogLineTable = Table; + using MucLogLine = MucLogLineTable::RowType; + + using GlobalOptionsTable = Table; + using GlobalOptions = GlobalOptionsTable::RowType; + + using IrcServerOptionsTable = Table; + using IrcServerOptions = IrcServerOptionsTable::RowType; + + using IrcChannelOptionsTable = Table; + using IrcChannelOptions = IrcChannelOptionsTable::RowType; + Database() = default; ~Database() = default; @@ -25,42 +104,40 @@ public: Database& operator=(const Database&) = delete; Database& operator=(Database&&) = delete; - static void set_verbose(const bool val); - - template - static size_t count() - { - return litesql::select(*Database::db).count(); - } - /** - * Return the object from the db. Create it beforehand (with all default - * values) if it is not already present. - */ - static db::GlobalOptions get_global_options(const std::string& owner); - static db::IrcServerOptions get_irc_server_options(const std::string& owner, + static GlobalOptions get_global_options(const std::string& owner); + static IrcServerOptions get_irc_server_options(const std::string& owner, const std::string& server); - static db::IrcChannelOptions get_irc_channel_options(const std::string& owner, - const std::string& server, - const std::string& channel); - static db::IrcChannelOptions get_irc_channel_options_with_server_default(const std::string& owner, - const std::string& server, - const std::string& channel); - static db::IrcChannelOptions get_irc_channel_options_with_server_and_global_default(const std::string& owner, - const std::string& server, - const std::string& channel); - static std::vector get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, - int limit=-1, const std::string& start="", const std::string& end=""); - static std::string store_muc_message(const std::string& owner, const Iid& iid, - time_point date, const std::string& body, const std::string& nick); + static IrcChannelOptions get_irc_channel_options(const std::string& owner, + const std::string& server, + const std::string& channel); + static IrcChannelOptions get_irc_channel_options_with_server_default(const std::string& owner, + const std::string& server, + const std::string& channel); + static IrcChannelOptions get_irc_channel_options_with_server_and_global_default(const std::string& owner, + const std::string& server, + const std::string& channel); + static std::vector get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, + int limit=-1, const std::string& start="", const std::string& end=""); + static std::string store_muc_message(const std::string& owner, const std::string& chan_name, const std::string& server_name, + time_point date, const std::string& body, const std::string& nick); static void close(); - static void open(const std::string& filename, const std::string& db_type="sqlite3"); + static void open(const std::string& filename); + template + static std::size_t count(const TableType& table) + { + CountQuery query{table.get_name()}; + return query.execute(Database::db); + } + + static MucLogLineTable muc_log_lines; + static GlobalOptionsTable global_options; + static IrcServerOptionsTable irc_server_options; + static IrcChannelOptionsTable irc_channel_options; + static sqlite3* db; -private: + private: static std::string gen_uuid(); - static std::unique_ptr db; }; #endif /* USE_DATABASE */ - - diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp new file mode 100644 index 0000000..00b77c5 --- /dev/null +++ b/src/database/insert_query.hpp @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include + +template +typename std::enable_if, Id>::value, void>::type +actual_bind(sqlite3_stmt* statement, std::vector& params, const std::tuple&) +{ + const auto value = params.front(); + params.erase(params.begin()); + if (sqlite3_bind_text(statement, N + 1, value.data(), static_cast(value.size()), SQLITE_TRANSIENT) != SQLITE_OK) + log_error("Failed to bind ", value, " to param ", N); + else + log_debug("Bound (not id) [", value, "] to ", N); +} + +template +typename std::enable_if, Id>::value, void>::type +actual_bind(sqlite3_stmt* statement, std::vector&, const std::tuple& columns) +{ + auto&& column = std::get(columns); + if (column.value != 0) + { + if (sqlite3_bind_int64(statement, N + 1, column.value) != SQLITE_OK) + log_error("Failed to bind ", column.value, " to id."); + } + else if (sqlite3_bind_null(statement, N + 1) != SQLITE_OK) + log_error("Failed to bind NULL to param ", N); + else + log_debug("Bound NULL to ", N); +} + +struct InsertQuery: public Query +{ + InsertQuery(const std::string& name): + Query("INSERT OR REPLACE INTO ") + { + this->body += name; + } + + template + void execute(const std::tuple& columns, sqlite3* db) + { + auto statement = this->prepare(db); + { + this->bind_param<0>(columns, statement); + if (sqlite3_step(statement) != SQLITE_DONE) + log_error("Failed to execute query: ", sqlite3_errmsg(db)); + } + } + + template + typename std::enable_if::type + bind_param(const std::tuple& columns, sqlite3_stmt* statement) + { + using ColumnType = typename std::remove_reference(columns))>::type; + + actual_bind(statement, this->params, columns); + this->bind_param(columns, statement); + } + + template + typename std::enable_if::type + bind_param(const std::tuple&, sqlite3_stmt*) + {} + + template + void insert_values(const std::tuple& columns) + { + this->body += "VALUES ("; + this->insert_value<0>(columns); + this->body += ")"; + } + + template + typename std::enable_if::type + insert_value(const std::tuple& columns) + { + this->body += "?"; + if (N != sizeof...(T) - 1) + this->body += ","; + this->body += " "; + add_param(*this, std::get(columns)); + this->insert_value(columns); + } + template + typename std::enable_if::type + insert_value(const std::tuple&) + { } + + template + void insert_col_names(const std::tuple& columns) + { + this->body += " ("; + this->insert_col_name<0>(columns); + this->body += ")\n"; + } + + template + typename std::enable_if::type + insert_col_name(const std::tuple& columns) + { + auto value = std::get(columns); + + this->body += value.name; + + if (N < (sizeof...(T) - 1)) + this->body += ", "; + + this->insert_col_name(columns); + } + template + typename std::enable_if::type + insert_col_name(const std::tuple&) + {} + + + private: +}; diff --git a/src/database/query.cpp b/src/database/query.cpp new file mode 100644 index 0000000..fb8c055 --- /dev/null +++ b/src/database/query.cpp @@ -0,0 +1,11 @@ +#include +#include + +template <> +void add_param(Query&, const Id&) +{} + +void actual_add_param(Query& query, const std::string& val) +{ + query.params.push_back(val); +} diff --git a/src/database/query.hpp b/src/database/query.hpp new file mode 100644 index 0000000..92845d0 --- /dev/null +++ b/src/database/query.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include + +#include +#include + +#include + +struct Query +{ + std::string body; + std::vector params; + + Query(std::string str): + body(std::move(str)) + {} + + sqlite3_stmt* prepare(sqlite3* db) + { + sqlite3_stmt* statement; + log_debug(this->body); + auto res = sqlite3_prepare(db, this->body.data(), static_cast(this->body.size()) + 1, + &statement, nullptr); + if (res != SQLITE_OK) + { + log_error("Error preparing statement: ", sqlite3_errmsg(db)); + return nullptr; + } + return statement; + } +}; + +template +void add_param(Query& query, const ColumnType& column) +{ + actual_add_param(query, column.value); +} + +template +void actual_add_param(Query& query, const T& val) +{ + query.params.push_back(std::to_string(val)); +} + +void actual_add_param(Query& query, const std::string& val); diff --git a/src/database/row.hpp b/src/database/row.hpp new file mode 100644 index 0000000..ca686c1 --- /dev/null +++ b/src/database/row.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include +#include + +#include + +#include + +template +typename std::enable_if, Id>::value, void>::type +update_id(std::tuple&, sqlite3*) +{} + +template +typename std::enable_if, Id>::value, void>::type +update_id(std::tuple& columns, sqlite3* db) +{ + auto&& column = std::get(columns); + log_debug("Found an autoincrement col."); + auto res = sqlite3_last_insert_rowid(db); + log_debug("Value is now: ", res); + column.value = res; +} + +template +typename std::enable_if::type +update_autoincrement_id(std::tuple& columns, sqlite3* db) +{ + using ColumnType = typename std::remove_reference(columns))>::type; + update_id(columns, db); + update_autoincrement_id(columns, db); +} + +template +typename std::enable_if::type +update_autoincrement_id(std::tuple&, sqlite3*) +{} + +template +struct Row +{ + Row(std::string name): + table_name(std::move(name)) + {} + + template + auto& col() + { + auto&& col = std::get(this->columns); + return col.value; + } + + template + const auto& col() const + { + auto&& col = std::get(this->columns); + return col.value; + } + + void save(sqlite3* db) + { + InsertQuery query(this->table_name); + query.insert_col_names(this->columns); + query.insert_values(this->columns); + log_debug(query.body); + + query.execute(this->columns, db); + + update_autoincrement_id<0>(this->columns, db); + } + + std::tuple columns; + std::string table_name; +}; diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp new file mode 100644 index 0000000..b41632e --- /dev/null +++ b/src/database/select_query.hpp @@ -0,0 +1,153 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include + +using namespace std::string_literals; + +template +typename std::enable_if::value, sqlite3_int64>::type +extract_row_value(sqlite3_stmt* statement, const int i) +{ + return sqlite3_column_int64(statement, i); +} + +template +typename std::enable_if::value, std::string>::type +extract_row_value(sqlite3_stmt* statement, const int i) +{ + const auto size = sqlite3_column_bytes(statement, i); + const unsigned char* str = sqlite3_column_text(statement, i); + std::string result(reinterpret_cast(str), size); + return result; +} + +template +typename std::enable_if::type +extract_row_values(Row& row, sqlite3_stmt* statement) +{ + using ColumnType = typename std::remove_reference(row.columns))>::type; + + auto&& column = std::get(row.columns); + column.value = static_cast(extract_row_value(statement, N)); + + extract_row_values(row, statement); +} + +template +typename std::enable_if::type +extract_row_values(Row&, sqlite3_stmt*) +{} + +template +struct SelectQuery: public Query +{ + SelectQuery(std::string table_name): + Query("SELECT"), + table_name(table_name) + { + this->insert_col_name<0>(); + this->body += " from " + this->table_name; + } + + template + typename std::enable_if::type + insert_col_name() + { + using ColumnsType = std::tuple; + ColumnsType tuple{}; + auto value = std::get(tuple); + + this->body += " "s + value.name; + + if (N < (sizeof...(T) - 1)) + this->body += ", "; + + this->insert_col_name(); + } + template + typename std::enable_if::type + insert_col_name() + {} + + SelectQuery& where() + { + this->body += " WHERE "; + return *this; + }; + + SelectQuery& order_by() + { + this->body += " ORDER BY "; + return *this; + } + + SelectQuery& limit() + { + this->body += " LIMIT "; + return *this; + } + + auto execute(sqlite3* db) + { + auto statement = this->prepare(db); + int i = 1; + for (const std::string& param: this->params) + { + if (sqlite3_bind_text(statement, i, param.data(), static_cast(param.size()), SQLITE_TRANSIENT) != SQLITE_OK) + log_debug("Failed to bind ", param, " to param ", i); + else + log_debug("Bound ", param, " to ", i); + + i++; + } + std::vector> rows; + while (sqlite3_step(statement) == SQLITE_ROW) + { + Row row(this->table_name); + extract_row_values(row, statement); + rows.push_back(row); + } + return rows; + } + + const std::string table_name; +}; + +template +typename std::enable_if::value, SelectQuery&>::type +operator<<(SelectQuery& query, const T&) +{ + query.body += T::name; + return query; +} + +template +SelectQuery& operator<<(SelectQuery& query, const char* str) +{ + query.body += str; + return query; +} + +template +SelectQuery& operator<<(SelectQuery& query, const std::string& str) +{ + query.body += "?"; + actual_add_param(query, str); + return query; +} + +template +typename std::enable_if::value, SelectQuery&>::type +operator<<(SelectQuery& query, const Integer& i) +{ + query.body += "?"; + actual_add_param(query, i); + return query; +} diff --git a/src/database/table.hpp b/src/database/table.hpp new file mode 100644 index 0000000..163b97f --- /dev/null +++ b/src/database/table.hpp @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +template +class Table +{ + static_assert(sizeof...(T) > 0, "Table cannot be empty"); + using ColumnTypes = std::tuple; + + public: + using RowType = Row; + + Table(std::string name): + name(std::move(name)) + {} + + void create(sqlite3* db) + { + std::string res{"CREATE TABLE IF NOT EXISTS "}; + res += this->name; + res += " (\n"; + this->add_column_create<0>(res); + res += ")"; + + log_debug(res); + + char* error; + const auto result = sqlite3_exec(db, res.data(), nullptr, nullptr, &error); + log_debug("result: ", +result); + if (result != SQLITE_OK) + { + log_error("Error executing query: ", error); + sqlite3_free(error); + } + } + + RowType row() + { + return {this->name}; + } + + SelectQuery select() + { + SelectQuery select(this->name); + return select; + } + + const std::string& get_name() const + { + return this->name; + } + + private: + template + typename std::enable_if::type + add_column_create(std::string& str) + { + using ColumnType = typename std::remove_reference(std::declval()))>::type; + using RealType = typename ColumnType::real_type; + str += ColumnType::name; + str += " "; + str += TypeToSQLType::type; + str += " "s + ColumnType::options; + if (N != sizeof...(T) - 1) + str += ","; + str += "\n"; + + add_column_create(str); + } + + template + typename std::enable_if::type + add_column_create(std::string&) + { } + + const std::string name; +}; diff --git a/src/database/type_to_sql.cpp b/src/database/type_to_sql.cpp new file mode 100644 index 0000000..5de012e --- /dev/null +++ b/src/database/type_to_sql.cpp @@ -0,0 +1,7 @@ +#include + +template <> const std::string TypeToSQLType::type = "INTEGER"; +template <> const std::string TypeToSQLType::type = "INTEGER"; +template <> const std::string TypeToSQLType::type = "INTEGER"; +template <> const std::string TypeToSQLType::type = "INTEGER"; +template <> const std::string TypeToSQLType::type = "TEXT"; diff --git a/src/database/type_to_sql.hpp b/src/database/type_to_sql.hpp new file mode 100644 index 0000000..5ae03ad --- /dev/null +++ b/src/database/type_to_sql.hpp @@ -0,0 +1,6 @@ +#pragma once + +#include + +template +struct TypeToSQLType { static const std::string type; }; -- cgit v1.2.3