diff options
Diffstat (limited to 'src/database')
-rw-r--r-- | src/database/column_escape.cpp | 46 | ||||
-rw-r--r-- | src/database/database.cpp | 43 | ||||
-rw-r--r-- | src/database/database.hpp | 30 | ||||
-rw-r--r-- | src/database/datetime_writer.hpp | 32 | ||||
-rw-r--r-- | src/database/engine.hpp | 18 | ||||
-rw-r--r-- | src/database/insert_query.hpp | 16 | ||||
-rw-r--r-- | src/database/postgresql_engine.cpp | 44 | ||||
-rw-r--r-- | src/database/postgresql_engine.hpp | 13 | ||||
-rw-r--r-- | src/database/query.cpp | 16 | ||||
-rw-r--r-- | src/database/query.hpp | 10 | ||||
-rw-r--r-- | src/database/select_query.hpp | 25 | ||||
-rw-r--r-- | src/database/sqlite3_engine.cpp | 66 | ||||
-rw-r--r-- | src/database/sqlite3_engine.hpp | 11 | ||||
-rw-r--r-- | src/database/table.hpp | 8 |
14 files changed, 339 insertions, 39 deletions
diff --git a/src/database/column_escape.cpp b/src/database/column_escape.cpp new file mode 100644 index 0000000..0f1f611 --- /dev/null +++ b/src/database/column_escape.cpp @@ -0,0 +1,46 @@ +#include <string> + +#include <database/database.hpp> +#include <database/select_query.hpp> + +template <> +std::string before_column<Database::Date>() +{ + if (Database::engine_type() == DatabaseEngine::EngineType::Sqlite3) + return "strftime(\"%Y-%m-%dT%H:%M:%SZ\", "; + else if (Database::engine_type() == DatabaseEngine::EngineType::Postgresql) + return "to_char("; + return {}; +} + +template <> +std::string after_column<Database::Date>() +{ + if (Database::engine_type() == DatabaseEngine::EngineType::Sqlite3) + return ")"; + else if (Database::engine_type() == DatabaseEngine::EngineType::Postgresql) + return R"(, 'YYYY-MM-DD"T"HH24:MM:SS"Z"'))"; + return {}; +} + +#include <database/insert_query.hpp> + +template <> +std::string before_value<Database::Date>() +{ + if (Database::engine_type() == DatabaseEngine::EngineType::Sqlite3) + return "julianday("; + if (Database::engine_type() == DatabaseEngine::EngineType::Postgresql) + return "to_timestamp("; + return {}; +} + +template <> +std::string after_value<Database::Date>() +{ + if (Database::engine_type() == DatabaseEngine::EngineType::Sqlite3) + return ", \"unixepoch\")"; + if (Database::engine_type() == DatabaseEngine::EngineType::Postgresql) + return ")"; + return {}; +} diff --git a/src/database/database.cpp b/src/database/database.cpp index 02c5b4f..7cb0a45 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -46,6 +46,7 @@ void Database::open(const std::string& filename) Database::db = std::move(new_db); Database::muc_log_lines.create(*Database::db); Database::muc_log_lines.upgrade(*Database::db); + convert_date_format(*Database::db, Database::muc_log_lines); Database::global_options.create(*Database::db); Database::global_options.upgrade(*Database::db); Database::irc_server_options.create(*Database::db); @@ -57,9 +58,9 @@ void Database::open(const std::string& filename) Database::after_connection_commands.create(*Database::db); Database::after_connection_commands.upgrade(*Database::db); create_index<Database::Owner, Database::IrcChanName, Database::IrcServerName>(*Database::db, "archive_index", Database::muc_log_lines.get_name()); + Database::db->init_session(); } - Database::GlobalOptions Database::get_global_options(const std::string& owner) { auto request = Database::global_options.select(); @@ -175,7 +176,7 @@ Database::IrcChannelOptions Database::get_irc_channel_options_with_server_and_gl } std::string Database::store_muc_message(const std::string& owner, const std::string& chan_name, - const std::string& server_name, Database::time_point date, + const std::string& server_name, DateTime::time_point date, const std::string& body, const std::string& nick) { auto line = Database::muc_log_lines.row(); @@ -186,7 +187,7 @@ std::string Database::store_muc_message(const std::string& owner, const std::str line.col<Owner>() = owner; line.col<IrcChanName>() = chan_name; line.col<IrcServerName>() = server_name; - line.col<Date>() = std::chrono::duration_cast<std::chrono::seconds>(date.time_since_epoch()).count(); + line.col<Date>() = date; line.col<Body>() = body; line.col<Nick>() = nick; @@ -210,13 +211,21 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne { const auto start_time = utils::parse_datetime(start); if (start_time != -1) - request << " and " << Database::Date{} << ">=" << start_time; + { + DateTime datetime(start_time); + DatetimeWriter writer(datetime, *Database::db); + request << " and " << Database::Date{} << ">=" << writer; + } } if (!end.empty()) { const auto end_time = utils::parse_datetime(end); if (end_time != -1) - request << " and " << Database::Date{} << "<=" << end_time; + { + DateTime datetime(end_time); + DatetimeWriter writer(datetime, *Database::db); + request << " and " << Database::Date{} << "<=" << writer; + } } if (reference_record_id != Id::unset_value) { @@ -229,9 +238,9 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne } if (paging == Database::Paging::first) - request.order_by() << Database::Date{} << " ASC, " << Id{} << " ASC "; + request.order_by() << Database::Date{} << " ASC"; else - request.order_by() << Database::Date{} << " DESC, " << Id{} << " DESC "; + request.order_by() << Database::Date{} << " DESC"; if (limit >= 0) request.limit() << limit; @@ -257,13 +266,21 @@ Database::MucLogLine Database::get_muc_log(const std::string& owner, const std:: { const auto start_time = utils::parse_datetime(start); if (start_time != -1) - request << " and " << Database::Date{} << ">=" << start_time; + { + DateTime datetime(start_time); + DatetimeWriter writer(datetime, *Database::db); + request << " and " << Database::Date{} << ">=" << writer; + } } if (!end.empty()) { const auto end_time = utils::parse_datetime(end); if (end_time != -1) - request << " and " << Database::Date{} << "<=" << end_time; + { + DateTime datetime(end_time); + DatetimeWriter writer(datetime, *Database::db); + request << " and " << Database::Date{} << "<=" << writer; + } } auto result = request.execute(*Database::db); @@ -347,4 +364,12 @@ Transaction::~Transaction() log_error("Failed to end SQL transaction: ", std::get<std::string>(result)); } } + +void Transaction::rollback() +{ + this->success = false; + const auto result = Database::raw_exec("ROLLBACK"); + if (std::get<bool>(result) == false) + log_error("Failed to rollback SQL transaction: ", std::get<std::string>(result)); +} #endif diff --git a/src/database/database.hpp b/src/database/database.hpp index d986ecc..75ff8f3 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -10,6 +10,7 @@ #include <database/engine.hpp> #include <utils/optional_bool.hpp> +#include <utils/datetime.hpp> #include <chrono> #include <string> @@ -17,11 +18,9 @@ #include <memory> #include <map> - class Database { public: - using time_point = std::chrono::system_clock::time_point; struct RecordNotFound: public std::exception {}; enum class Paging { first, last }; @@ -37,7 +36,8 @@ class Database struct Server: Column<std::string> { static constexpr auto name = "server_"; }; - struct Date: Column<time_point::rep> { static constexpr auto name = "date_"; }; + struct OldDate: Column<std::chrono::system_clock::time_point::rep> { static constexpr auto name = "date_"; }; + struct Date: Column<DateTime> { static constexpr auto name = "date_"; }; struct Body: Column<std::string> { static constexpr auto name = "body_"; }; @@ -88,6 +88,8 @@ class Database using MucLogLineTable = Table<Id, Uuid, Owner, IrcChanName, IrcServerName, Date, Body, Nick>; using MucLogLine = MucLogLineTable::RowType; + using OldMucLogLineTable = Table<Id, Uuid, Owner, IrcChanName, IrcServerName, OldDate, Body, Nick>; + using OldMucLogLine = OldMucLogLineTable::RowType; using GlobalOptionsTable = Table<Id, Owner, MaxHistoryLength, RecordHistory, GlobalPersistent>; using GlobalOptions = GlobalOptionsTable::RowType; @@ -141,7 +143,7 @@ class Database */ static MucLogLine get_muc_log(const std::string& owner, const std::string& chan_name, const std::string& server, const std::string& uuid, const std::string& start="", const std::string& end=""); static std::string store_muc_message(const std::string& owner, const std::string& chan_name, const std::string& server_name, - time_point date, const std::string& body, const std::string& nick); + DateTime::time_point date, const std::string& body, const std::string& nick); static void add_roster_item(const std::string& local, const std::string& remote); static bool has_roster_item(const std::string& local, const std::string& remote); @@ -168,6 +170,13 @@ class Database static std::unique_ptr<DatabaseEngine> db; + static DatabaseEngine::EngineType engine_type() + { + if (Database::db) + return Database::db->engine_type(); + return DatabaseEngine::EngineType::None; + } + /** * Some caches, to avoid doing very frequent query requests for a few options. */ @@ -216,7 +225,20 @@ class Transaction public: Transaction(); ~Transaction(); + void rollback(); bool success{false}; }; +template <typename... T> +void convert_date_format(DatabaseEngine& db, Table<T...> table) +{ + const auto existing_columns = db.get_all_columns_from_table(table.get_name()); + const auto date_pair = existing_columns.find(Database::Date::name); + if (date_pair != existing_columns.end() && date_pair->second == "integer") + { + log_info("Converting Date_ format to the new one."); + db.convert_date_format(db); + } +} + #endif /* USE_DATABASE */ diff --git a/src/database/datetime_writer.hpp b/src/database/datetime_writer.hpp new file mode 100644 index 0000000..b104911 --- /dev/null +++ b/src/database/datetime_writer.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include <utils/datetime.hpp> +#include <database/engine.hpp> + +#include <logger/logger.hpp> +#include <database/postgresql_engine.hpp> +#include <database/sqlite3_engine.hpp> + +class DatetimeWriter +{ +public: + DatetimeWriter(DateTime datetime, const DatabaseEngine& engine): + datetime(datetime), + engine(engine) + {} + + long double get_value() const + { + const long double epoch_duration = this->datetime.epoch().count(); + const long double epoch_seconds = epoch_duration / std::chrono::system_clock::period::den; + return this->engine.epoch_to_floating_value(epoch_seconds); + } + std::string escape_param_number(int value) const + { + return this->engine.escape_param_number(value); + } + +private: + const DateTime datetime; + const DatabaseEngine& engine; +}; diff --git a/src/database/engine.hpp b/src/database/engine.hpp index 41dccf5..ecf047f 100644 --- a/src/database/engine.hpp +++ b/src/database/engine.hpp @@ -13,6 +13,7 @@ #include <string> #include <vector> #include <tuple> +#include <map> #include <set> class DatabaseEngine @@ -27,7 +28,10 @@ class DatabaseEngine DatabaseEngine(DatabaseEngine&&) = delete; DatabaseEngine& operator=(DatabaseEngine&&) = delete; - virtual std::set<std::string> get_all_columns_from_table(const std::string& table_name) = 0; + enum class EngineType { None, Postgresql, Sqlite3, }; + virtual EngineType engine_type() const = 0; + + virtual std::map<std::string, std::string> get_all_columns_from_table(const std::string& table_name) = 0; virtual std::tuple<bool, std::string> raw_exec(const std::string& query) = 0; virtual std::unique_ptr<Statement> prepare(const std::string& query) = 0; virtual void extract_last_insert_rowid(Statement& statement) = 0; @@ -35,7 +39,17 @@ class DatabaseEngine { return {}; } - virtual std::string id_column_type() = 0; + virtual void convert_date_format(DatabaseEngine&) = 0; + virtual std::string id_column_type() const = 0; + virtual std::string datetime_column_type() const = 0; + virtual long double epoch_to_floating_value(long double seconds) const = 0; + virtual std::string escape_param_number(int nb) const + { + return "$" + std::to_string(nb); + } + virtual void init_session() + { + } int64_t last_inserted_rowid{-1}; }; diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp index 04c098c..e3a7e83 100644 --- a/src/database/insert_query.hpp +++ b/src/database/insert_query.hpp @@ -25,6 +25,18 @@ typename std::enable_if<N == sizeof...(T), void>::type update_autoincrement_id(std::tuple<T...>&, Statement&) {} +template <typename T> +std::string before_value() +{ + return {}; +} + +template <typename T> +std::string after_value() +{ + return {}; +} + struct InsertQuery: public Query { template <typename... T> @@ -73,7 +85,7 @@ struct InsertQuery: public Query template <typename... T> void insert_values(const std::tuple<T...>& columns) { - this->body += "VALUES ("; + this->body += " VALUES ("; this->insert_value(columns); this->body += ")"; } @@ -86,7 +98,9 @@ struct InsertQuery: public Query if (!std::is_same<ColumnType, Id>::value) { + this->body += before_value<ColumnType>(); this->body += "$" + std::to_string(index++); + this->body += after_value<ColumnType>(); if (N != sizeof...(T) - 1) this->body += ", "; } diff --git a/src/database/postgresql_engine.cpp b/src/database/postgresql_engine.cpp index 59bc885..abeb779 100644 --- a/src/database/postgresql_engine.cpp +++ b/src/database/postgresql_engine.cpp @@ -2,6 +2,7 @@ #ifdef PQ_FOUND #include <utils/scopeguard.hpp> +#include <utils/tolower.hpp> #include <database/query.hpp> @@ -12,6 +13,7 @@ #include <logger/logger.hpp> #include <cstring> +#include <database/database.hpp> PostgresqlEngine::PostgresqlEngine(PGconn*const conn): conn(conn) @@ -52,14 +54,14 @@ std::unique_ptr<DatabaseEngine> PostgresqlEngine::open(const std::string& connin return std::make_unique<PostgresqlEngine>(con); } -std::set<std::string> PostgresqlEngine::get_all_columns_from_table(const std::string& table_name) +std::map<std::string, std::string> PostgresqlEngine::get_all_columns_from_table(const std::string& table_name) { - const auto query = "SELECT column_name from information_schema.columns where table_name='" + table_name + "'"; + const auto query = "SELECT column_name, data_type from information_schema.columns where table_name='" + table_name + "'"; auto statement = this->prepare(query); - std::set<std::string> columns; + std::map<std::string, std::string> columns; while (statement->step() == StepResult::Row) - columns.insert(statement->get_column_text(0)); + columns[utils::tolower(statement->get_column_text(0))] = utils::tolower(statement->get_column_text(1)); return columns; } @@ -96,9 +98,41 @@ std::string PostgresqlEngine::get_returning_id_sql_string(const std::string& col return " RETURNING " + col_name; } -std::string PostgresqlEngine::id_column_type() +std::string PostgresqlEngine::id_column_type() const { return "SERIAL"; } +std::string PostgresqlEngine::datetime_column_type() const +{ + return "TIMESTAMP"; +} + +void PostgresqlEngine::convert_date_format(DatabaseEngine& db) +{ + const auto table_name = Database::muc_log_lines.get_name(); + const std::string column_name = Database::Date::name; + const std::string query = "ALTER TABLE " + table_name + " ALTER COLMUN " + column_name + " SET DATA TYPE timestamp USING to_timestamp(" + column_name + ")"; + + auto result = db.raw_exec(query); + if (!std::get<bool>(result)) + log_error("Failed to execute query: ", std::get<std::string>(result)); +} + +std::string PostgresqlEngine::escape_param_number(int nb) const +{ + return "to_timestamp(" + DatabaseEngine::escape_param_number(nb) + ")"; +} + +void PostgresqlEngine::init_session() +{ + const auto res = this->raw_exec("SET SESSION TIME ZONE 'UTC'"); + if (!std::get<bool>(res)) + log_debug("Failed to set UTC timezone: ", std::get<std::string>(res)); +} +long double PostgresqlEngine::epoch_to_floating_value(long double seconds) const +{ + return seconds; +} + #endif diff --git a/src/database/postgresql_engine.hpp b/src/database/postgresql_engine.hpp index fe4fb53..f2dcec3 100644 --- a/src/database/postgresql_engine.hpp +++ b/src/database/postgresql_engine.hpp @@ -23,13 +23,22 @@ class PostgresqlEngine: public DatabaseEngine ~PostgresqlEngine(); static std::unique_ptr<DatabaseEngine> open(const std::string& string); + EngineType engine_type() const override + { + return EngineType::Postgresql; + } - std::set<std::string> get_all_columns_from_table(const std::string& table_name) override final; + std::map<std::string, std::string> get_all_columns_from_table(const std::string& table_name) override final; std::tuple<bool, std::string> raw_exec(const std::string& query) override final; std::unique_ptr<Statement> prepare(const std::string& query) override; void extract_last_insert_rowid(Statement& statement) override; std::string get_returning_id_sql_string(const std::string& col_name) override; - std::string id_column_type() override; + std::string id_column_type() const override; + std::string datetime_column_type() const override; + void convert_date_format(DatabaseEngine& engine) override; + long double epoch_to_floating_value(long double seconds) const override; + void init_session() override; + std::string escape_param_number(int nb) const override; private: PGconn* const conn; }; diff --git a/src/database/query.cpp b/src/database/query.cpp index d72066e..6d20302 100644 --- a/src/database/query.cpp +++ b/src/database/query.cpp @@ -21,6 +21,13 @@ void actual_bind(Statement& statement, const OptionalBool& value, int index) statement.bind_int64(index, -1); } +void actual_bind(Statement& statement, const DateTime& value, int index) +{ + const auto epoch = value.epoch().count(); + const auto result = std::to_string(static_cast<long double>(epoch) / std::chrono::system_clock::period::den); + statement.bind_text(index, result); +} + void actual_add_param(Query& query, const std::string& val) { query.params.push_back(val); @@ -49,3 +56,12 @@ Query& operator<<(Query& query, const std::string& str) actual_add_param(query, str); return query; } + +Query& operator<<(Query& query, const DatetimeWriter& datetime_writer) +{ + query.body += datetime_writer.escape_param_number(query.current_param++); + actual_add_param(query, datetime_writer.get_value()); + return query; +} + + diff --git a/src/database/query.hpp b/src/database/query.hpp index ba28b1a..910271a 100644 --- a/src/database/query.hpp +++ b/src/database/query.hpp @@ -3,6 +3,7 @@ #include <biboumi.h> #include <utils/optional_bool.hpp> +#include <utils/datetime.hpp> #include <database/statement.hpp> #include <database/column.hpp> @@ -10,6 +11,7 @@ #include <vector> #include <string> +#include <database/datetime_writer.hpp> void actual_bind(Statement& statement, const std::string& value, int index); void actual_bind(Statement& statement, const std::int64_t& value, int index); @@ -19,6 +21,7 @@ void actual_bind(Statement& statement, const T& value, int index) actual_bind(statement, static_cast<std::int64_t>(value), index); } void actual_bind(Statement& statement, const OptionalBool& value, int index); +void actual_bind(Statement& statement, const DateTime& value, int index); #ifdef DEBUG_SQL_QUERIES #include <utils/scopetimer.hpp> @@ -74,15 +77,13 @@ void actual_add_param(Query& query, const std::string& val); void actual_add_param(Query& query, const OptionalBool& val); template <typename T> -typename std::enable_if<!std::is_integral<T>::value, Query&>::type +typename std::enable_if<!std::is_arithmetic<T>::value, Query&>::type operator<<(Query& query, const T&) { query.body += T::name; return query; } -Query& operator<<(Query& query, const char* str); -Query& operator<<(Query& query, const std::string& str); template <typename Integer> typename std::enable_if<std::is_integral<Integer>::value, Query&>::type operator<<(Query& query, const Integer& i) @@ -92,3 +93,6 @@ operator<<(Query& query, const Integer& i) return query; } +Query& operator<<(Query& query, const char* str); +Query& operator<<(Query& query, const std::string& str); +Query& operator<<(Query& query, const DatetimeWriter& datetime); diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp index 743a011..3013dd8 100644 --- a/src/database/select_query.hpp +++ b/src/database/select_query.hpp @@ -3,8 +3,8 @@ #include <database/engine.hpp> #include <database/statement.hpp> +#include <utils/datetime.hpp> #include <database/query.hpp> -#include <logger/logger.hpp> #include <database/row.hpp> #include <utils/optional_bool.hpp> @@ -41,6 +41,14 @@ extract_row_value(Statement& statement, const int i) return result; } +template <typename T> +typename std::enable_if<std::is_same<DateTime, T>::value, T>::type +extract_row_value(Statement& statement, const int i) +{ + const std::string timestamp = statement.get_column_text(i); + return {timestamp}; +} + template <std::size_t N=0, typename... T> typename std::enable_if<N < sizeof...(T), void>::type extract_row_values(Row<T...>& row, Statement& statement) @@ -58,6 +66,18 @@ typename std::enable_if<N == sizeof...(T), void>::type extract_row_values(Row<T...>&, Statement&) {} +template <typename ColumnType> +std::string before_column() +{ + return {}; +} + +template <typename ColumnType> +std::string after_column() +{ + return {}; +} + template <typename... T> struct SelectQuery: public Query { @@ -76,7 +96,8 @@ struct SelectQuery: public Query using ColumnsType = std::tuple<T...>; using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnsType>()))>::type; - this->body += " " + std::string{ColumnType::name}; + this->body += " "; + this->body += before_column<ColumnType>() + ColumnType::name + after_column<ColumnType>(); if (N < (sizeof...(T) - 1)) this->body += ", "; diff --git a/src/database/sqlite3_engine.cpp b/src/database/sqlite3_engine.cpp index ae4a146..1fa6316 100644 --- a/src/database/sqlite3_engine.cpp +++ b/src/database/sqlite3_engine.cpp @@ -2,6 +2,7 @@ #ifdef SQLITE3_FOUND +#include <database/database.hpp> #include <database/sqlite3_engine.hpp> #include <database/sqlite3_statement.hpp> @@ -22,16 +23,17 @@ Sqlite3Engine::~Sqlite3Engine() sqlite3_close(this->db); } -std::set<std::string> Sqlite3Engine::get_all_columns_from_table(const std::string& table_name) +std::map<std::string, std::string> Sqlite3Engine::get_all_columns_from_table(const std::string& table_name) { - std::set<std::string> result; + std::map<std::string, std::string> result; char* errmsg; std::string query{"PRAGMA table_info(" + table_name + ")"}; int res = sqlite3_exec(this->db, query.data(), [](void* param, int columns_nb, char** columns, char**) -> int { constexpr int name_column = 1; - std::set<std::string>* result = static_cast<std::set<std::string>*>(param); - if (name_column < columns_nb) - result->insert(utils::tolower(columns[name_column])); + constexpr int data_type_column = 2; + auto* result = static_cast<std::map<std::string, std::string>*>(param); + if (name_column < columns_nb && data_type_column < columns_nb) + (*result)[utils::tolower(columns[name_column])] = utils::tolower(columns[data_type_column]); return 0; }, &result, &errmsg); @@ -44,6 +46,48 @@ std::set<std::string> Sqlite3Engine::get_all_columns_from_table(const std::strin return result; } +template <typename... T> +static auto make_select_query(const Row<T...>&, const std::string& name) +{ + return SelectQuery<T...>{name}; +} + +void Sqlite3Engine::convert_date_format(DatabaseEngine& db) +{ + Transaction transaction{}; + auto rollback = [&transaction] (const std::string& error_msg) + { + log_error("Failed to execute query: ", error_msg); + transaction.rollback(); + }; + + const auto real_name = Database::muc_log_lines.get_name(); + const auto tmp_name = real_name + "tmp_"; + const std::string date_name = Database::Date::name; + + auto result = db.raw_exec("ALTER TABLE " + real_name + " RENAME TO " + tmp_name); + if (!std::get<bool>(result)) + return rollback(std::get<std::string>(result)); + + Database::muc_log_lines.create(db); + + Database::OldMucLogLineTable old_muc_log_line(tmp_name); + auto select_query = make_select_query(old_muc_log_line.row(), old_muc_log_line.get_name()); + + auto& select_body = select_query.body; + auto begin = select_body.find(date_name); + select_body.replace(begin, date_name.size(), "julianday("+date_name+", 'unixepoch')"); + select_body = "INSERT INTO " + real_name + " " + select_body; + + result = db.raw_exec(select_body); + if (!std::get<bool>(result)) + return rollback(std::get<std::string>(result)); + + result = db.raw_exec("DROP TABLE " + tmp_name); + if (!std::get<bool>(result)) + return rollback(std::get<std::string>(result)); +} + std::unique_ptr<DatabaseEngine> Sqlite3Engine::open(const std::string& filename) { sqlite3* new_db; @@ -93,9 +137,19 @@ void Sqlite3Engine::extract_last_insert_rowid(Statement&) this->last_inserted_rowid = sqlite3_last_insert_rowid(this->db); } -std::string Sqlite3Engine::id_column_type() +std::string Sqlite3Engine::id_column_type() const { return "INTEGER PRIMARY KEY AUTOINCREMENT"; } +std::string Sqlite3Engine::datetime_column_type() const +{ + return "REAL"; +} + +long double Sqlite3Engine::epoch_to_floating_value(long double d) const +{ + return (d / 86400.0) + 2440587.5; +} + #endif diff --git a/src/database/sqlite3_engine.hpp b/src/database/sqlite3_engine.hpp index 5b8176c..82d01c9 100644 --- a/src/database/sqlite3_engine.hpp +++ b/src/database/sqlite3_engine.hpp @@ -23,12 +23,19 @@ class Sqlite3Engine: public DatabaseEngine ~Sqlite3Engine(); static std::unique_ptr<DatabaseEngine> open(const std::string& string); + EngineType engine_type() const override + { + return EngineType::Sqlite3; + } - std::set<std::string> get_all_columns_from_table(const std::string& table_name) override final; + std::map<std::string, std::string> get_all_columns_from_table(const std::string& table_name) override final; std::tuple<bool, std::string> raw_exec(const std::string& query) override final; std::unique_ptr<Statement> prepare(const std::string& query) override; void extract_last_insert_rowid(Statement& statement) override; - std::string id_column_type() override; + std::string id_column_type() const override; + std::string datetime_column_type() const override; + void convert_date_format(DatabaseEngine&) override; + long double epoch_to_floating_value(long double d) const override; private: sqlite3* const db; }; diff --git a/src/database/table.hpp b/src/database/table.hpp index c8c1bdd..31b92a7 100644 --- a/src/database/table.hpp +++ b/src/database/table.hpp @@ -19,6 +19,8 @@ std::string ToSQLType(DatabaseEngine& db) return db.id_column_type(); else if (std::is_same<typename T::real_type, std::string>::value) return "TEXT"; + else if (std::is_same<typename T::real_type, DateTime>::value) + return db.datetime_column_type(); else return "INTEGER"; } @@ -101,16 +103,16 @@ class Table template <std::size_t N=0> typename std::enable_if<N < sizeof...(T), void>::type - add_column_if_not_exists(DatabaseEngine& db, const std::set<std::string>& existing_columns) + add_column_if_not_exists(DatabaseEngine& db, const std::map<std::string, std::string>& existing_columns) { using ColumnType = typename std::remove_reference<decltype(std::get<N>(std::declval<ColumnTypes>()))>::type; - if (existing_columns.count(ColumnType::name) == 0) + if (existing_columns.find(ColumnType::name) == existing_columns.end()) add_column_to_table<ColumnType>(db, this->name); add_column_if_not_exists<N+1>(db, existing_columns); } template <std::size_t N=0> typename std::enable_if<N == sizeof...(T), void>::type - add_column_if_not_exists(DatabaseEngine&, const std::set<std::string>&) + add_column_if_not_exists(DatabaseEngine&, const std::map<std::string, std::string>&) {} template <std::size_t N=0> |