diff options
-rw-r--r-- | src/database/count_query.hpp | 6 | ||||
-rw-r--r-- | src/database/insert_query.hpp | 17 | ||||
-rw-r--r-- | src/database/query.hpp | 6 | ||||
-rw-r--r-- | src/database/select_query.hpp | 19 | ||||
-rw-r--r-- | src/database/statement.hpp | 35 |
5 files changed, 61 insertions, 22 deletions
diff --git a/src/database/count_query.hpp b/src/database/count_query.hpp index 863fad1..322ad1b 100644 --- a/src/database/count_query.hpp +++ b/src/database/count_query.hpp @@ -19,14 +19,14 @@ struct CountQuery: public Query { auto statement = this->prepare(db); std::size_t res = 0; - if (sqlite3_step(statement) == SQLITE_ROW) - res = sqlite3_column_int64(statement, 0); + if (sqlite3_step(statement.get()) == SQLITE_ROW) + res = sqlite3_column_int64(statement.get(), 0); else { log_error("Count request didn’t return a result"); return 0; } - if (sqlite3_step(statement) != SQLITE_DONE) + if (sqlite3_step(statement.get()) != SQLITE_DONE) log_warning("Count request returned more than one result."); log_debug("Returning count: ", res); diff --git a/src/database/insert_query.hpp b/src/database/insert_query.hpp index 00b77c5..1712916 100644 --- a/src/database/insert_query.hpp +++ b/src/database/insert_query.hpp @@ -1,5 +1,6 @@ #pragma once +#include <database/statement.hpp> #include <database/column.hpp> #include <database/query.hpp> #include <logger/logger.hpp> @@ -13,11 +14,11 @@ template <int N, typename ColumnType, typename... T> typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type -actual_bind(sqlite3_stmt* statement, std::vector<std::string>& params, const std::tuple<T...>&) +actual_bind(Statement& statement, std::vector<std::string>& params, const std::tuple<T...>&) { const auto value = params.front(); params.erase(params.begin()); - if (sqlite3_bind_text(statement, N + 1, value.data(), static_cast<int>(value.size()), SQLITE_TRANSIENT) != SQLITE_OK) + if (sqlite3_bind_text(statement.get(), N + 1, value.data(), static_cast<int>(value.size()), SQLITE_TRANSIENT) != SQLITE_OK) log_error("Failed to bind ", value, " to param ", N); else log_debug("Bound (not id) [", value, "] to ", N); @@ -25,15 +26,15 @@ actual_bind(sqlite3_stmt* statement, std::vector<std::string>& params, const std template <int N, typename ColumnType, typename... T> typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type -actual_bind(sqlite3_stmt* statement, std::vector<std::string>&, const std::tuple<T...>& columns) +actual_bind(Statement& statement, std::vector<std::string>&, const std::tuple<T...>& columns) { auto&& column = std::get<Id>(columns); if (column.value != 0) { - if (sqlite3_bind_int64(statement, N + 1, column.value) != SQLITE_OK) + if (sqlite3_bind_int64(statement.get(), N + 1, column.value) != SQLITE_OK) log_error("Failed to bind ", column.value, " to id."); } - else if (sqlite3_bind_null(statement, N + 1) != SQLITE_OK) + else if (sqlite3_bind_null(statement.get(), N + 1) != SQLITE_OK) log_error("Failed to bind NULL to param ", N); else log_debug("Bound NULL to ", N); @@ -53,14 +54,14 @@ struct InsertQuery: public Query auto statement = this->prepare(db); { this->bind_param<0>(columns, statement); - if (sqlite3_step(statement) != SQLITE_DONE) + if (sqlite3_step(statement.get()) != SQLITE_DONE) log_error("Failed to execute query: ", sqlite3_errmsg(db)); } } template <int N, typename... T> typename std::enable_if<N < sizeof...(T), void>::type - bind_param(const std::tuple<T...>& columns, sqlite3_stmt* statement) + bind_param(const std::tuple<T...>& columns, Statement& statement) { using ColumnType = typename std::remove_reference<decltype(std::get<N>(columns))>::type; @@ -70,7 +71,7 @@ struct InsertQuery: public Query template <int N, typename... T> typename std::enable_if<N == sizeof...(T), void>::type - bind_param(const std::tuple<T...>&, sqlite3_stmt*) + bind_param(const std::tuple<T...>&, Statement&) {} template <typename... T> diff --git a/src/database/query.hpp b/src/database/query.hpp index 92845d0..b77a421 100644 --- a/src/database/query.hpp +++ b/src/database/query.hpp @@ -1,5 +1,7 @@ #pragma once +#include <database/statement.hpp> + #include <logger/logger.hpp> #include <vector> @@ -16,7 +18,7 @@ struct Query body(std::move(str)) {} - sqlite3_stmt* prepare(sqlite3* db) + Statement prepare(sqlite3* db) { sqlite3_stmt* statement; log_debug(this->body); @@ -27,7 +29,7 @@ struct Query log_error("Error preparing statement: ", sqlite3_errmsg(db)); return nullptr; } - return statement; + return {statement}; } }; diff --git a/src/database/select_query.hpp b/src/database/select_query.hpp index b41632e..80d1424 100644 --- a/src/database/select_query.hpp +++ b/src/database/select_query.hpp @@ -1,5 +1,6 @@ #pragma once +#include <database/statement.hpp> #include <database/query.hpp> #include <logger/logger.hpp> #include <database/row.hpp> @@ -13,24 +14,24 @@ using namespace std::string_literals; template <typename T> typename std::enable_if<std::is_integral<T>::value, sqlite3_int64>::type -extract_row_value(sqlite3_stmt* statement, const int i) +extract_row_value(Statement& statement, const int i) { - return sqlite3_column_int64(statement, i); + return sqlite3_column_int64(statement.get(), i); } template <typename T> typename std::enable_if<std::is_same<std::string, T>::value, std::string>::type -extract_row_value(sqlite3_stmt* statement, const int i) +extract_row_value(Statement& statement, const int i) { - const auto size = sqlite3_column_bytes(statement, i); - const unsigned char* str = sqlite3_column_text(statement, i); + const auto size = sqlite3_column_bytes(statement.get(), i); + const unsigned char* str = sqlite3_column_text(statement.get(), i); std::string result(reinterpret_cast<const char*>(str), size); return result; } template <std::size_t N=0, typename... T> typename std::enable_if<N < sizeof...(T), void>::type -extract_row_values(Row<T...>& row, sqlite3_stmt* statement) +extract_row_values(Row<T...>& row, Statement& statement) { using ColumnType = typename std::remove_reference<decltype(std::get<N>(row.columns))>::type; @@ -42,7 +43,7 @@ extract_row_values(Row<T...>& row, sqlite3_stmt* statement) template <std::size_t N=0, typename... T> typename std::enable_if<N == sizeof...(T), void>::type -extract_row_values(Row<T...>&, sqlite3_stmt*) +extract_row_values(Row<T...>&, Statement&) {} template <typename... T> @@ -100,7 +101,7 @@ struct SelectQuery: public Query int i = 1; for (const std::string& param: this->params) { - if (sqlite3_bind_text(statement, i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK) + if (sqlite3_bind_text(statement.get(), i, param.data(), static_cast<int>(param.size()), SQLITE_TRANSIENT) != SQLITE_OK) log_debug("Failed to bind ", param, " to param ", i); else log_debug("Bound ", param, " to ", i); @@ -108,7 +109,7 @@ struct SelectQuery: public Query i++; } std::vector<Row<T...>> rows; - while (sqlite3_step(statement) == SQLITE_ROW) + while (sqlite3_step(statement.get()) == SQLITE_ROW) { Row<T...> row(this->table_name); extract_row_values(row, statement); diff --git a/src/database/statement.hpp b/src/database/statement.hpp new file mode 100644 index 0000000..87cd70f --- /dev/null +++ b/src/database/statement.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include <sqlite3.h> + +class Statement +{ + public: + Statement(sqlite3_stmt* stmt): + stmt(stmt) {} + ~Statement() + { + sqlite3_finalize(this->stmt); + } + + Statement(const Statement&) = delete; + Statement& operator=(const Statement&) = delete; + Statement(Statement&& other): + stmt(other.stmt) + { + other.stmt = nullptr; + } + Statement& operator=(Statement&& other) + { + this->stmt = other.stmt; + other.stmt = nullptr; + return *this; + } + sqlite3_stmt* get() + { + return this->stmt; + } + + private: + sqlite3_stmt* stmt; +}; |