diff options
Diffstat (limited to 'src/database')
-rw-r--r-- | src/database/column.hpp | 17 | ||||
-rw-r--r-- | src/database/count_query.hpp | 35 | ||||
-rw-r--r-- | src/database/database.cpp | 223 | ||||
-rw-r--r-- | src/database/database.hpp | 162 | ||||
-rw-r--r-- | src/database/insert_query.hpp | 129 | ||||
-rw-r--r-- | src/database/query.cpp | 34 | ||||
-rw-r--r-- | src/database/query.hpp | 90 | ||||
-rw-r--r-- | src/database/row.hpp | 75 | ||||
-rw-r--r-- | src/database/select_query.hpp | 127 | ||||
-rw-r--r-- | src/database/statement.hpp | 35 | ||||
-rw-r--r-- | src/database/table.cpp | 25 | ||||
-rw-r--r-- | src/database/table.hpp | 127 | ||||
-rw-r--r-- | src/database/type_to_sql.cpp | 9 | ||||
-rw-r--r-- | src/database/type_to_sql.hpp | 16 |
14 files changed, 962 insertions, 142 deletions
diff --git a/src/database/column.hpp b/src/database/column.hpp new file mode 100644 index 0000000..111f9ca --- /dev/null +++ b/src/database/column.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include <cstddef> + +template <typename T> +struct Column +{ + Column(T default_value): + value{default_value} {} + Column(): + value{} {} + using real_type = T; + T value{}; +}; + +struct Id: Column<std::size_t> { static constexpr auto name = "id_"; + static constexpr auto options = "PRIMARY KEY AUTOINCREMENT"; }; diff --git a/src/database/count_query.hpp b/src/database/count_query.hpp new file mode 100644 index 0000000..b7bbf51 --- /dev/null +++ b/src/database/count_query.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include <database/query.hpp> +#include <database/table.hpp> + +#include <string> + +#include <sqlite3.h> + +struct CountQuery: public Query +{ + CountQuery(std::string name): + Query("SELECT count(*) FROM ") + { + this->body += std::move(name); + } + + int64_t execute(sqlite3* db) + { + auto statement = this->prepare(db); + int64_t res = 0; + if (sqlite3_step(statement.get()) == SQLITE_ROW) + res = sqlite3_column_int64(statement.get(), 0); + else + { + log_error("Count request didn’t return a result"); + return 0; + } + if (sqlite3_step(statement.get()) != SQLITE_DONE) + log_warning("Count request returned more than one result."); + + log_debug("Returning count: ", res); + return res; + } +}; diff --git a/src/database/database.cpp b/src/database/database.cpp index f7d309b..92f7682 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -2,171 +2,185 @@ #ifdef USE_DATABASE #include <database/database.hpp> -#include <logger/logger.hpp> -#include <irc/iid.hpp> #include <uuid/uuid.h> #include <utils/get_first_non_empty.hpp> #include <utils/time.hpp> -using namespace std::string_literals; +#include <sqlite3.h> -std::unique_ptr<db::BibouDB> Database::db; +sqlite3* Database::db; +Database::MucLogLineTable Database::muc_log_lines("MucLogLine_"); +Database::GlobalOptionsTable Database::global_options("GlobalOptions_"); +Database::IrcServerOptionsTable Database::irc_server_options("IrcServerOptions_"); +Database::IrcChannelOptionsTable Database::irc_channel_options("IrcChannelOptions_"); -void Database::open(const std::string& filename, const std::string& db_type) +void Database::open(const std::string& filename) { - try + // Try to open the specified database. + // Close and replace the previous database pointer if it succeeded. If it did + // not, just leave things untouched + sqlite3* new_db; + auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); + if (res != SQLITE_OK) { - auto new_db = std::make_unique<db::BibouDB>(db_type, - "database="s + filename); - if (new_db->needsUpgrade()) - new_db->upgrade(); - Database::db.reset(new_db.release()); - } catch (const litesql::DatabaseError& e) { - log_error("Failed to open database ", filename, ". ", e.what()); - throw; + log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(Database::db)); + throw std::runtime_error(""); } + Database::close(); + Database::db = new_db; + Database::muc_log_lines.create(Database::db); + Database::muc_log_lines.upgrade(Database::db); + Database::global_options.create(Database::db); + Database::global_options.upgrade(Database::db); + Database::irc_server_options.create(Database::db); + Database::irc_server_options.upgrade(Database::db); + Database::irc_channel_options.create(Database::db); + Database::irc_channel_options.upgrade(Database::db); } -void Database::set_verbose(const bool val) -{ - Database::db->verbose = val; -} -db::GlobalOptions Database::get_global_options(const std::string& owner) +Database::GlobalOptions Database::get_global_options(const std::string& owner) { - try { - auto options = litesql::select<db::GlobalOptions>(*Database::db, - db::GlobalOptions::Owner == owner).one(); - return options; - } catch (const litesql::NotFound& e) { - db::GlobalOptions options(*Database::db); - options.owner = owner; - return options; - } + auto request = Database::global_options.select(); + request.where() << Owner{} << "=" << owner; + + Database::GlobalOptions options{Database::global_options.get_name()}; + auto result = request.execute(Database::db); + if (result.size() == 1) + options = result.front(); + else + options.col<Owner>() = owner; + return options; } -db::IrcServerOptions Database::get_irc_server_options(const std::string& owner, - const std::string& server) +Database::IrcServerOptions Database::get_irc_server_options(const std::string& owner, const std::string& server) { - try { - auto options = litesql::select<db::IrcServerOptions>(*Database::db, - db::IrcServerOptions::Owner == owner && - db::IrcServerOptions::Server == server).one(); - return options; - } catch (const litesql::NotFound& e) { - db::IrcServerOptions options(*Database::db); - options.owner = owner; - options.server = server; - // options.update(); - return options; - } + auto request = Database::irc_server_options.select(); + request.where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server; + + Database::IrcServerOptions options{Database::irc_server_options.get_name()}; + auto result = request.execute(Database::db); + if (result.size() == 1) + options = result.front(); + else + { + options.col<Owner>() = owner; + options.col<Server>() = server; + } + return options; } -db::IrcChannelOptions Database::get_irc_channel_options(const std::string& owner, - const std::string& server, - const std::string& channel) +Database::IrcChannelOptions Database::get_irc_channel_options(const std::string& owner, const std::string& server, const std::string& channel) { - try { - auto options = litesql::select<db::IrcChannelOptions>(*Database::db, - db::IrcChannelOptions::Owner == owner && - db::IrcChannelOptions::Server == server && - db::IrcChannelOptions::Channel == channel).one(); - return options; - } catch (const litesql::NotFound& e) { - db::IrcChannelOptions options(*Database::db); - options.owner = owner; - options.server = server; - options.channel = channel; - return options; - } + auto request = Database::irc_channel_options.select(); + request.where() << Owner{} << "=" << owner <<\ + " and " << Server{} << "=" << server <<\ + " and " << Channel{} << "=" << channel; + Database::IrcChannelOptions options{Database::irc_channel_options.get_name()}; + auto result = request.execute(Database::db); + if (result.size() == 1) + options = result.front(); + else + { + options.col<Owner>() = owner; + options.col<Server>() = server; + options.col<Channel>() = channel; + } + return options; } -db::IrcChannelOptions Database::get_irc_channel_options_with_server_default(const std::string& owner, - const std::string& server, - const std::string& channel) +Database::IrcChannelOptions Database::get_irc_channel_options_with_server_default(const std::string& owner, const std::string& server, + const std::string& channel) { auto coptions = Database::get_irc_channel_options(owner, server, channel); auto soptions = Database::get_irc_server_options(owner, server); - coptions.encodingIn = get_first_non_empty(coptions.encodingIn.value(), - soptions.encodingIn.value()); - coptions.encodingOut = get_first_non_empty(coptions.encodingOut.value(), - soptions.encodingOut.value()); + coptions.col<EncodingIn>() = get_first_non_empty(coptions.col<EncodingIn>(), + soptions.col<EncodingIn>()); + coptions.col<EncodingOut>() = get_first_non_empty(coptions.col<EncodingOut>(), + soptions.col<EncodingOut>()); - coptions.maxHistoryLength = get_first_non_empty(coptions.maxHistoryLength.value(), - soptions.maxHistoryLength.value()); + coptions.col<MaxHistoryLength>() = get_first_non_empty(coptions.col<MaxHistoryLength>(), + soptions.col<MaxHistoryLength>()); return coptions; } -db::IrcChannelOptions Database::get_irc_channel_options_with_server_and_global_default(const std::string& owner, - const std::string& server, - const std::string& channel) +Database::IrcChannelOptions Database::get_irc_channel_options_with_server_and_global_default(const std::string& owner, const std::string& server, const std::string& channel) { auto coptions = Database::get_irc_channel_options(owner, server, channel); auto soptions = Database::get_irc_server_options(owner, server); auto goptions = Database::get_global_options(owner); - coptions.encodingIn = get_first_non_empty(coptions.encodingIn.value(), - soptions.encodingIn.value()); - coptions.encodingOut = get_first_non_empty(coptions.encodingOut.value(), - soptions.encodingOut.value()); + coptions.col<EncodingIn>() = get_first_non_empty(coptions.col<EncodingIn>(), + soptions.col<EncodingIn>()); - coptions.maxHistoryLength = get_first_non_empty(coptions.maxHistoryLength.value(), - soptions.maxHistoryLength.value(), - goptions.maxHistoryLength.value()); + coptions.col<EncodingOut>() = get_first_non_empty(coptions.col<EncodingOut>(), + soptions.col<EncodingOut>()); + + coptions.col<MaxHistoryLength>() = get_first_non_empty(coptions.col<MaxHistoryLength>(), + soptions.col<MaxHistoryLength>(), + goptions.col<MaxHistoryLength>()); return coptions; } -void Database::store_muc_message(const std::string& owner, const Iid& iid, - Database::time_point date, - const std::string& body, - const std::string& nick) +std::string Database::store_muc_message(const std::string& owner, const std::string& chan_name, + const std::string& server_name, Database::time_point date, + const std::string& body, const std::string& nick) { - db::MucLogLine line(*Database::db); + auto line = Database::muc_log_lines.row(); + + auto uuid = Database::gen_uuid(); - line.uuid = Database::gen_uuid(); - line.owner = owner; - line.ircChanName = iid.get_local(); - line.ircServerName = iid.get_server(); - line.date = date.time_since_epoch().count() / 1'000'000'000; - line.body = body; - line.nick = nick; + line.col<Uuid>() = uuid; + line.col<Owner>() = owner; + line.col<IrcChanName>() = chan_name; + line.col<IrcServerName>() = server_name; + line.col<Date>() = std::chrono::duration_cast<std::chrono::seconds>(date.time_since_epoch()).count(); + line.col<Body>() = body; + line.col<Nick>() = nick; - line.update(); + line.save(Database::db); + + return uuid; } -std::vector<db::MucLogLine> Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, +std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, int limit, const std::string& start, const std::string& end) { - auto request = litesql::select<db::MucLogLine>(*Database::db, - db::MucLogLine::Owner == owner && - db::MucLogLine::IrcChanName == chan_name && - db::MucLogLine::IrcServerName == server); - request.orderBy(db::MucLogLine::Id, false); + auto request = Database::muc_log_lines.select(); + request.where() << Database::Owner{} << "=" << owner << \ + " and " << Database::IrcChanName{} << "=" << chan_name << \ + " and " << Database::IrcServerName{} << "=" << server; - if (limit >= 0) - request.limit(limit); if (!start.empty()) { const auto start_time = utils::parse_datetime(start); if (start_time != -1) - request.where(db::MucLogLine::Date >= start_time); + request << " and " << Database::Date{} << ">=" << start_time; } if (!end.empty()) { const auto end_time = utils::parse_datetime(end); if (end_time != -1) - request.where(db::MucLogLine::Date <= end_time); + request << " and " << Database::Date{} << "<=" << end_time; } - const auto& res = request.all(); - return {res.crbegin(), res.crend()}; + + request.order_by() << Id{} << " DESC "; + + if (limit >= 0) + request.limit() << limit; + + auto result = request.execute(Database::db); + + return {result.crbegin(), result.crend()}; } void Database::close() { - Database::db.reset(nullptr); + sqlite3_close_v2(Database::db); + Database::db = nullptr; } std::string Database::gen_uuid() @@ -178,5 +192,4 @@ std::string Database::gen_uuid() return uuid_str; } - -#endif +#endif
\ No newline at end of file diff --git a/src/database/database.hpp b/src/database/database.hpp index 6823574..8364abc 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -1,22 +1,112 @@ #pragma once - #include <biboumi.h> #ifdef USE_DATABASE -#include "biboudb.hpp" +#include <database/table.hpp> +#include <database/column.hpp> +#include <database/count_query.hpp> -#include <memory> +#include <utils/optional_bool.hpp> -#include <litesql.hpp> #include <chrono> +#include <string> + +#include <memory> -class Iid; class Database { -public: + public: using time_point = std::chrono::system_clock::time_point; + + struct Uuid: Column<std::string> { static constexpr auto name = "uuid_"; + static constexpr auto options = ""; }; + + struct Owner: Column<std::string> { static constexpr auto name = "owner_"; + static constexpr auto options = ""; }; + + struct IrcChanName: Column<std::string> { static constexpr auto name = "ircChanName_"; + static constexpr auto options = ""; }; + + struct Channel: Column<std::string> { static constexpr auto name = "channel_"; + static constexpr auto options = ""; }; + + struct IrcServerName: Column<std::string> { static constexpr auto name = "ircServerName_"; + static constexpr auto options = ""; }; + + struct Server: Column<std::string> { static constexpr auto name = "server_"; + static constexpr auto options = ""; }; + + struct Date: Column<time_point::rep> { static constexpr auto name = "date_"; + static constexpr auto options = ""; }; + + struct Body: Column<std::string> { static constexpr auto name = "body_"; + static constexpr auto options = ""; }; + + struct Nick: Column<std::string> { static constexpr auto name = "nick_"; + static constexpr auto options = ""; }; + + struct Pass: Column<std::string> { static constexpr auto name = "pass_"; + static constexpr auto options = ""; }; + + struct Ports: Column<std::string> { static constexpr auto name = "ports_"; + static constexpr auto options = ""; + Ports(): Column<std::string>("6667") {} }; + + struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsPorts_"; + static constexpr auto options = ""; + TlsPorts(): Column<std::string>("6697;6670") {} }; + + struct Username: Column<std::string> { static constexpr auto name = "username_"; + static constexpr auto options = ""; }; + + struct Realname: Column<std::string> { static constexpr auto name = "realname_"; + static constexpr auto options = ""; }; + + struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterConnectionCommand_"; + static constexpr auto options = ""; }; + + struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedFingerprint_"; + static constexpr auto options = ""; }; + + struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingOut_"; + static constexpr auto options = ""; }; + + struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingIn_"; + static constexpr auto options = ""; }; + + struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxHistoryLength_"; + static constexpr auto options = ""; + MaxHistoryLength(): Column<int>(20) {} }; + + struct RecordHistory: Column<bool> { static constexpr auto name = "recordHistory_"; + static constexpr auto options = ""; + RecordHistory(): Column<bool>(true) {}}; + + struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordHistory_"; + static constexpr auto options = ""; }; + + struct VerifyCert: Column<bool> { static constexpr auto name = "verifyCert_"; + static constexpr auto options = ""; + VerifyCert(): Column<bool>(true) {} }; + + struct Persistent: Column<bool> { static constexpr auto name = "persistent_"; + static constexpr auto options = ""; + Persistent(): Column<bool>(false) {} }; + + using MucLogLineTable = Table<Id, Uuid, Owner, IrcChanName, IrcServerName, Date, Body, Nick>; + using MucLogLine = MucLogLineTable::RowType; + + using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, Persistent>; + using GlobalOptions = GlobalOptionsTable::RowType; + + using IrcServerOptionsTable = Table<Id, Owner, Server, Pass, AfterConnectionCommand, TlsPorts, Ports, Username, Realname, VerifyCert, TrustedFingerprint, EncodingOut, EncodingIn, MaxHistoryLength>; + using IrcServerOptions = IrcServerOptionsTable::RowType; + + using IrcChannelOptionsTable = Table<Id, Owner, Server, Channel, EncodingOut, EncodingIn, MaxHistoryLength, Persistent, RecordHistoryOptional>; + using IrcChannelOptions = IrcChannelOptionsTable::RowType; + Database() = default; ~Database() = default; @@ -25,42 +115,40 @@ public: Database& operator=(const Database&) = delete; Database& operator=(Database&&) = delete; - static void set_verbose(const bool val); - - template<typename PersistentType> - static size_t count() - { - return litesql::select<PersistentType>(*Database::db).count(); - } - /** - * Return the object from the db. Create it beforehand (with all default - * values) if it is not already present. - */ - static db::GlobalOptions get_global_options(const std::string& owner); - static db::IrcServerOptions get_irc_server_options(const std::string& owner, + static GlobalOptions get_global_options(const std::string& owner); + static IrcServerOptions get_irc_server_options(const std::string& owner, const std::string& server); - static db::IrcChannelOptions get_irc_channel_options(const std::string& owner, - const std::string& server, - const std::string& channel); - static db::IrcChannelOptions get_irc_channel_options_with_server_default(const std::string& owner, - const std::string& server, - const std::string& channel); - static db::IrcChannelOptions get_irc_channel_options_with_server_and_global_default(const std::string& owner, - const std::string& server, - const std::string& channel); - static std::vector<db::MucLogLine> get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, - int limit=-1, const std::string& before="", const std::string& after=""); - static void store_muc_message(const std::string& owner, const Iid& iid, - time_point date, const std::string& body, const std::string& nick); + static IrcChannelOptions get_irc_channel_options(const std::string& owner, + const std::string& server, + const std::string& channel); + static IrcChannelOptions get_irc_channel_options_with_server_default(const std::string& owner, + const std::string& server, + const std::string& channel); + static IrcChannelOptions get_irc_channel_options_with_server_and_global_default(const std::string& owner, + const std::string& server, + const std::string& channel); + static std::vector<MucLogLine> get_muc_logs(const std::string& owner, const std::string& chan_name, const std::string& server, + int limit=-1, const std::string& start="", const std::string& end=""); + static std::string store_muc_message(const std::string& owner, const std::string& chan_name, const std::string& server_name, + time_point date, const std::string& body, const std::string& nick); static void close(); - static void open(const std::string& filename, const std::string& db_type="sqlite3"); + static void open(const std::string& filename); + template <typename TableType> + static int64_t count(const TableType& table) + { + CountQuery query{table.get_name()}; + return query.execute(Database::db); + } + + static MucLogLineTable muc_log_lines; + static GlobalOptionsTable global_options; + static IrcServerOptionsTable irc_server_options; + static IrcChannelOptionsTable irc_channel_options; + static sqlite3* db; -private: + private: static std::string gen_uuid(); - static std::unique_ptr<db::BibouDB> db; }; #endif /* USE_DATABASE */ - - diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp new file mode 100644 index 0000000..9e410ce --- /dev/null +++ b/src/database/insert_query.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include <database/statement.hpp> +#include <database/column.hpp> +#include <database/query.hpp> +#include <logger/logger.hpp> + +#include <type_traits> +#include <vector> +#include <string> +#include <tuple> + +#include <sqlite3.h> + +template <int N, typename ColumnType, typename... T> +typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type +actual_bind(Statement& statement, std::vector<std::string>& params, const std::tuple<T...>&) +{ + const auto value = params.front(); + params.erase(params.begin()); + if (sqlite3_bind_text(statement.get(), N + 1, value.data(), static_cast<int>(value.size()), SQLITE_TRANSIENT) != SQLITE_OK) + log_error("Failed to bind ", value, " to param ", N); + else + log_debug("Bound (not id) [", value, "] to ", N); +} + +template <int N, typename ColumnType, typename... T> +typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type +actual_bind(Statement& statement, std::vector<std::string>&, const std::tuple<T...>& columns) +{ + auto&& column = std::get<Id>(columns); + if (column.value != 0) + { + if (sqlite3_bind_int64(statement.get(), N + 1, static_cast<sqlite3_int64>(column.value)) != SQLITE_OK) + log_error("Failed to bind ", column.value, " to id."); + } + else if (sqlite3_bind_null(statement.get(), N + 1) != SQLITE_OK) + log_error("Failed to bind NULL to param ", N); + else + log_debug("Bound NULL to ", N); +} + +struct InsertQuery: public Query +{ + InsertQuery(const std::string& name): + Query("INSERT OR REPLACE INTO ") + { + this->body += name; + } + + template <typename... T> + void execute(const std::tuple<T...>& columns, sqlite3* db) + { + auto statement = this->prepare(db); + { + this->bind_param(columns, statement); + if (sqlite3_step(statement.get()) != SQLITE_DONE) + log_error("Failed to execute query: ", sqlite3_errmsg(db)); + } + } + + template <int N=0, typename... T> + typename std::enable_if<N < sizeof...(T), void>::type + bind_param(const std::tuple<T...>& columns, Statement& statement) + { + using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type; + + actual_bind<N, ColumnType>(statement, this->params, columns); + this->bind_param<N+1>(columns, statement); + } + + template <int N=0, typename... T> + typename std::enable_if<N == sizeof...(T), void>::type + bind_param(const std::tuple<T...>&, Statement&) + {} + + template <typename... T> + void insert_values(const std::tuple<T...>& columns) + { + this->body += "VALUES ("; + this->insert_value(columns); + this->body += ")"; + } + + template <int N=0, typename... T> + typename std::enable_if<N < sizeof...(T), void>::type + insert_value(const std::tuple<T...>& columns) + { + this->body += "?"; + if (N != sizeof...(T) - 1) + this->body += ","; + this->body += " "; + add_param(*this, std::get<N>(columns)); + this->insert_value<N+1>(columns); + } + template <int N=0, typename... T> + typename std::enable_if<N == sizeof...(T), void>::type + insert_value(const std::tuple<T...>&) + { } + + template <typename... T> + void insert_col_names(const std::tuple<T...>& columns) + { + this->body += " ("; + this->insert_col_name(columns); + this->body += ")\n"; + } + + template <int N=0, typename... T> + typename std::enable_if<N < sizeof...(T), void>::type + insert_col_name(const std::tuple<T...>& columns) + { + using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type; + + this->body += ColumnType::name; + + if (N < (sizeof...(T) - 1)) + this->body += ", "; + + this->insert_col_name<N+1>(columns); + } + template <int N=0, typename... T> + typename std::enable_if<N == sizeof...(T), void>::type + insert_col_name(const std::tuple<T...>&) + {} + + + private: +}; diff --git a/src/database/query.cpp b/src/database/query.cpp new file mode 100644 index 0000000..ba63a92 --- /dev/null +++ b/src/database/query.cpp @@ -0,0 +1,34 @@ +#include <database/query.hpp> +#include <database/column.hpp> + +template <> +void add_param<Id>(Query&, const Id&) +{} + +void actual_add_param(Query& query, const std::string& val) +{ + query.params.push_back(val); +} + +void actual_add_param(Query& query, const OptionalBool& val) +{ + if (!val.is_set) + query.params.push_back("0"); + else if (val.value) + query.params.push_back("1"); + else + query.params.push_back("-1"); +} + +Query& operator<<(Query& query, const char* str) +{ + query.body += str; + return query; +} + +Query& operator<<(Query& query, const std::string& str) +{ + query.body += "?"; + actual_add_param(query, str); + return query; +} diff --git a/src/database/query.hpp b/src/database/query.hpp new file mode 100644 index 0000000..f103fe9 --- /dev/null +++ b/src/database/query.hpp @@ -0,0 +1,90 @@ +#pragma once + +#include <utils/optional_bool.hpp> +#include <database/statement.hpp> +#include <database/column.hpp> + +#include <logger/logger.hpp> + +#include <vector> +#include <string> + +#include <sqlite3.h> + +struct Query +{ + std::string body; + std::vector<std::string> params; + + Query(std::string str): + body(std::move(str)) + {} + + Statement prepare(sqlite3* db) + { + sqlite3_stmt* stmt; + log_debug(this->body); + auto res = sqlite3_prepare(db, this->body.data(), static_cast<int>(this->body.size()) + 1, + &stmt, nullptr); + if (res != SQLITE_OK) + { + log_error("Error preparing statement: ", sqlite3_errmsg(db)); + return nullptr; + } + Statement statement(stmt); + int i = 1; + for (const std::string& param: this->params) + { + if (sqlite3_bind_text(statement.get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK) + log_debug("Failed to bind ", param, " to param ", i); + else + log_debug("Bound ", param, " to ", i); + i++; + } + + return statement; + } + + void execute(sqlite3* db) + { + auto statement = this->prepare(db); + while (sqlite3_step(statement.get()) != SQLITE_DONE) + ; + } +}; + +template <typename ColumnType> +void add_param(Query& query, const ColumnType& column) +{ + actual_add_param(query, column.value); +} +template <> +void add_param<Id>(Query& query, const Id& column); + +template <typename T> +void actual_add_param(Query& query, const T& val) +{ + query.params.push_back(std::to_string(val)); +} + +void actual_add_param(Query& query, const std::string& val); +void actual_add_param(Query& query, const OptionalBool& val); + +template <typename T> +typename std::enable_if<!std::is_integral<T>::value, Query&>::type +operator<<(Query& query, const T&) +{ + query.body += T::name; + return query; +} + +Query& operator<<(Query& query, const char* str); +Query& operator<<(Query& query, const std::string& str); +template <typename Integer> +typename std::enable_if<std::is_integral<Integer>::value, Query&>::type +operator<<(Query& query, const Integer& i) +{ + query.body += "?"; + actual_add_param(query, i); + return query; +} diff --git a/src/database/row.hpp b/src/database/row.hpp new file mode 100644 index 0000000..e7a58c4 --- /dev/null +++ b/src/database/row.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include <database/insert_query.hpp> +#include <logger/logger.hpp> + +#include <type_traits> + +#include <sqlite3.h> + +template <typename ColumnType, typename... T> +typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type +update_id(std::tuple<T...>&, sqlite3*) +{} + +template <typename ColumnType, typename... T> +typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type +update_id(std::tuple<T...>& columns, sqlite3* db) +{ + auto&& column = std::get<ColumnType>(columns); + log_debug("Found an autoincrement col."); + auto res = sqlite3_last_insert_rowid(db); + log_debug("Value is now: ", res); + column.value = static_cast<Id::real_type>(res); +} + +template <std::size_t N=0, typename... T> +typename std::enable_if<N < sizeof...(T), void>::type +update_autoincrement_id(std::tuple<T...>& columns, sqlite3* db) +{ + using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type; + update_id<ColumnType>(columns, db); + update_autoincrement_id<N+1>(columns, db); +} + +template <std::size_t N=0, typename... T> +typename std::enable_if<N == sizeof...(T), void>::type +update_autoincrement_id(std::tuple<T...>&, sqlite3*) +{} + +template <typename... T> +struct Row +{ + Row(std::string name): + table_name(std::move(name)) + {} + + template <typename Type> + auto& col() + { + auto&& col = std::get<Type>(this->columns); + return col.value; + } + + template <typename Type> + const auto& col() const + { + auto&& col = std::get<Type>(this->columns); + return col.value; + } + + void save(sqlite3* db) + { + InsertQuery query(this->table_name); + query.insert_col_names(this->columns); + query.insert_values(this->columns); + log_debug(query.body); + + query.execute(this->columns, db); + + update_autoincrement_id(this->columns, db); + } + + std::tuple<T...> columns; + std::string table_name; +}; diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp new file mode 100644 index 0000000..f4d71af --- /dev/null +++ b/src/database/select_query.hpp @@ -0,0 +1,127 @@ +#pragma once + +#include <database/statement.hpp> +#include <database/query.hpp> +#include <logger/logger.hpp> +#include <database/row.hpp> + +#include <utils/optional_bool.hpp> + +#include <vector> +#include <string> + +#include <sqlite3.h> + +using namespace std::string_literals; + +template <typename T> +typename std::enable_if<std::is_integral<T>::value, sqlite3_int64>::type +extract_row_value(Statement& statement, const int i) +{ + return sqlite3_column_int64(statement.get(), i); +} + +template <typename T> +typename std::enable_if<std::is_same<std::string, T>::value, T>::type +extract_row_value(Statement& statement, const int i) +{ + const auto size = sqlite3_column_bytes(statement.get(), i); + const unsigned char* str = sqlite3_column_text(statement.get(), i); + std::string result(reinterpret_cast<const char*>(str), static_cast<std::size_t>(size)); + return result; +} + +template <typename T> +typename std::enable_if<std::is_same<OptionalBool, T>::value, T>::type +extract_row_value(Statement& statement, const int i) +{ + const auto integer = sqlite3_column_int(statement.get(), i); + OptionalBool result; + if (integer > 0) + result.set_value(true); + else if (integer < 0) + result.set_value(false); + return result; +} + +template <std::size_t N=0, typename... T> +typename std::enable_if<N < sizeof...(T), void>::type +extract_row_values(Row<T...>& row, Statement& statement) +{ + using ColumnType = typename std::remove_reference<decltype(std::get<N>(row.columns))>::type; + + auto&& column = std::get<N>(row.columns); + column.value = static_cast<decltype(column.value)>(extract_row_value<typename ColumnType::real_type>(statement, N)); + + extract_row_values<N+1>(row, statement); +} + +template <std::size_t N=0, typename... T> +typename std::enable_if<N == sizeof...(T), void>::type +extract_row_values(Row<T...>&, Statement&) +{} + +template <typename... T> +struct SelectQuery: public Query +{ + SelectQuery(std::string table_name): + Query("SELECT"), + table_name(table_name) + { + this->insert_col_name(); + this->body += " from " + this->table_name; + } + + template <std::size_t N=0> + typename std::enable_if<N < sizeof...(T), void>::type + insert_col_name() + { + using ColumnsType = std::tuple<T...>; + using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnsType>()))>::type; + + this->body += " "s + ColumnType::name; + + if (N < (sizeof...(T) - 1)) + this->body += ", "; + + this->insert_col_name<N+1>(); + } + template <std::size_t N=0> + typename std::enable_if<N == sizeof...(T), void>::type + insert_col_name() + {} + + SelectQuery& where() + { + this->body += " WHERE "; + return *this; + }; + + SelectQuery& order_by() + { + this->body += " ORDER BY "; + return *this; + } + + SelectQuery& limit() + { + this->body += " LIMIT "; + return *this; + } + + auto execute(sqlite3* db) + { + auto statement = this->prepare(db); + std::vector<Row<T...>> rows; + while (sqlite3_step(statement.get()) == SQLITE_ROW) + { + Row<T...> row(this->table_name); + extract_row_values(row, statement); + rows.push_back(row); + } + return rows; + } + + const std::string table_name; +}; + diff --git a/src/database/statement.hpp b/src/database/statement.hpp new file mode 100644 index 0000000..87cd70f --- /dev/null +++ b/src/database/statement.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include <sqlite3.h> + +class Statement +{ + public: + Statement(sqlite3_stmt* stmt): + stmt(stmt) {} + ~Statement() + { + sqlite3_finalize(this->stmt); + } + + Statement(const Statement&) = delete; + Statement& operator=(const Statement&) = delete; + Statement(Statement&& other): + stmt(other.stmt) + { + other.stmt = nullptr; + } + Statement& operator=(Statement&& other) + { + this->stmt = other.stmt; + other.stmt = nullptr; + return *this; + } + sqlite3_stmt* get() + { + return this->stmt; + } + + private: + sqlite3_stmt* stmt; +}; diff --git a/src/database/table.cpp b/src/database/table.cpp new file mode 100644 index 0000000..5929f33 --- /dev/null +++ b/src/database/table.cpp @@ -0,0 +1,25 @@ +#include <database/table.hpp> + +std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name) +{ + std::set<std::string> result; + char* errmsg; + std::string query{"PRAGMA table_info("s + table_name + ")"}; + log_debug(query); + int res = sqlite3_exec(db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int { + constexpr int name_column = 1; + std::set<std::string>* result = static_cast<std::set<std::string>*>(param); + log_debug("Table has column ", columns[name_column]); + if (name_column < columns_nb) + result->insert(columns[name_column]); + return 0; + }, &result, &errmsg); + + if (res != SQLITE_OK) + { + log_error("Error executing ", query, ": ", errmsg); + sqlite3_free(errmsg); + } + + return result; +} diff --git a/src/database/table.hpp b/src/database/table.hpp new file mode 100644 index 0000000..411ac6a --- /dev/null +++ b/src/database/table.hpp @@ -0,0 +1,127 @@ +#pragma once + +#include <database/select_query.hpp> +#include <database/type_to_sql.hpp> +#include <logger/logger.hpp> +#include <database/row.hpp> + +#include <algorithm> +#include <string> +#include <set> + +using namespace std::string_literals; + +std::set<std::string> get_all_columns_from_table(sqlite3* db, const std::string& table_name); + +template <typename ColumnType> +void add_column_to_table(sqlite3* db, const std::string& table_name) +{ + const std::string name = ColumnType::name; + std::string query{"ALTER TABLE "s + table_name + " ADD " + ColumnType::name + " " + TypeToSQLType<typename ColumnType::real_type>::type}; + log_debug(query); + char* error; + const auto result = sqlite3_exec(db, query.data(), nullptr, nullptr, &error); + if (result != SQLITE_OK) + { + log_error("Error adding column ", name, " to table ", table_name, ": ", error); + sqlite3_free(error); + } +} + +template <typename... T> +class Table +{ + static_assert(sizeof...(T) > 0, "Table cannot be empty"); + using ColumnTypes = std::tuple<T...>; + + public: + using RowType = Row<T...>; + + Table(std::string name): + name(std::move(name)) + {} + + void upgrade(sqlite3* db) + { + const auto existing_columns = get_all_columns_from_table(db, this->name); + add_column_if_not_exists(db, existing_columns); + } + + void create(sqlite3* db) + { + std::string res{"CREATE TABLE IF NOT EXISTS "}; + res += this->name; + res += " (\n"; + this->add_column_create(res); + res += ")"; + + log_debug(res); + + char* error; + const auto result = sqlite3_exec(db, res.data(), nullptr, nullptr, &error); + log_debug("result: ", +result); + if (result != SQLITE_OK) + { + log_error("Error executing query: ", error); + sqlite3_free(error); + } + } + + RowType row() + { + return {this->name}; + } + + SelectQuery<T...> select() + { + SelectQuery<T...> select(this->name); + return select; + } + + const std::string& get_name() const + { + return this->name; + } + + private: + + template <std::size_t N=0> + typename std::enable_if<N < sizeof...(T), void>::type + add_column_if_not_exists(sqlite3* db, const std::set<std::string>& existing_columns) + { + using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type; + if (existing_columns.count(ColumnType::name) != 1) + { + add_column_to_table<ColumnType>(db, this->name); + } + add_column_if_not_exists<N+1>(db, existing_columns); + } + template <std::size_t N=0> + typename std::enable_if<N == sizeof...(T), void>::type + add_column_if_not_exists(sqlite3*, const std::set<std::string>&) + {} + + template <std::size_t N=0> + typename std::enable_if<N < sizeof...(T), void>::type + add_column_create(std::string& str) + { + using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type; + using RealType = typename ColumnType::real_type; + str += ColumnType::name; + str += " "; + str += TypeToSQLType<RealType>::type; + str += " "s + ColumnType::options; + if (N != sizeof...(T) - 1) + str += ","; + str += "\n"; + + add_column_create<N+1>(str); + } + + template <std::size_t N=0> + typename std::enable_if<N == sizeof...(T), void>::type + add_column_create(std::string&) + { } + + const std::string name; +}; diff --git a/src/database/type_to_sql.cpp b/src/database/type_to_sql.cpp new file mode 100644 index 0000000..bcd9daa --- /dev/null +++ b/src/database/type_to_sql.cpp @@ -0,0 +1,9 @@ +#include <database/type_to_sql.hpp> + +template <> const std::string TypeToSQLType<int>::type = "INTEGER"; +template <> const std::string TypeToSQLType<std::size_t>::type = "INTEGER"; +template <> const std::string TypeToSQLType<long>::type = "INTEGER"; +template <> const std::string TypeToSQLType<long long>::type = "INTEGER"; +template <> const std::string TypeToSQLType<bool>::type = "INTEGER"; +template <> const std::string TypeToSQLType<std::string>::type = "TEXT"; +template <> const std::string TypeToSQLType<OptionalBool>::type = "INTEGER";
\ No newline at end of file diff --git a/src/database/type_to_sql.hpp b/src/database/type_to_sql.hpp new file mode 100644 index 0000000..ba806ab --- /dev/null +++ b/src/database/type_to_sql.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include <utils/optional_bool.hpp> + +#include <string> + +template <typename T> +struct TypeToSQLType { static const std::string type; }; + +template <> const std::string TypeToSQLType<int>::type; +template <> const std::string TypeToSQLType<std::size_t>::type; +template <> const std::string TypeToSQLType<long>::type; +template <> const std::string TypeToSQLType<long long>::type; +template <> const std::string TypeToSQLType<bool>::type; +template <> const std::string TypeToSQLType<std::string>::type; +template <> const std::string TypeToSQLType<OptionalBool>::type;
\ No newline at end of file |