diff --git a/src/api/lib.zig b/src/api/lib.zig index 487a4a9..e6cf82a 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -94,15 +94,15 @@ pub fn setupAdmin(db: sql.Db, origin: []const u8, username: []const u8, password } pub const ApiSource = struct { - db_conn: *sql.Conn, + db_conn_pool: *sql.ConnPool, pub const Conn = ApiConn(sql.Db); const root_username = "root"; - pub fn init(db_conn: *sql.Conn) !ApiSource { + pub fn init(pool: *sql.ConnPool) !ApiSource { return ApiSource{ - .db_conn = db_conn, + .db_conn_pool = pool, }; } @@ -110,7 +110,7 @@ pub const ApiSource = struct { var arena = std.heap.ArenaAllocator.init(alloc); errdefer arena.deinit(); - const db = try self.db_conn.acquire(); + const db = try self.db_conn_pool.acquire(); const community = try services.communities.getByHost(db, host, arena.allocator()); return Conn{ @@ -125,7 +125,7 @@ pub const ApiSource = struct { var arena = std.heap.ArenaAllocator.init(alloc); errdefer arena.deinit(); - const db = try self.db_conn.acquire(); + const db = try self.db_conn_pool.acquire(); const community = try services.communities.getByHost(db, host, arena.allocator()); const token_info = try services.auth.verifyToken( @@ -157,6 +157,7 @@ fn ApiConn(comptime DbConn: type) type { pub fn close(self: *Self) void { self.arena.deinit(); + self.db.releaseConnection(); } fn isAdmin(self: *Self) bool { diff --git a/src/main/main.zig b/src/main/main.zig index f722294..1d427f0 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -64,7 +64,10 @@ fn runAdminSetup(db: sql.Db, alloc: std.mem.Allocator) !void { try api.setupAdmin(db, origin, username, password, alloc); } -fn prepareDb(db: sql.Db, alloc: std.mem.Allocator) !void { +fn prepareDb(pool: *sql.ConnPool, alloc: std.mem.Allocator) !void { + const db = try pool.acquire(); + defer db.releaseConnection(); + try migrations.up(db); if (!try api.isAdminSetup(db)) { @@ -93,10 +96,10 @@ pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var cfg = try loadConfig(gpa.allocator()); - var db_conn = try sql.Conn.open(cfg.db); - try prepareDb(try db_conn.acquire(), gpa.allocator()); + var pool = try sql.ConnPool.init(cfg.db); + try prepareDb(&pool, gpa.allocator()); - var api_src = try api.ApiSource.init(&db_conn); + var api_src = try api.ApiSource.init(&pool); var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); } diff --git a/src/main/migrations.zig b/src/main/migrations.zig index 595646e..459e75a 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -16,10 +16,15 @@ fn execStmt(tx: anytype, stmt: []const u8, alloc: std.mem.Allocator) !void { } fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void { + const tx = try db.beginOrSavepoint(); + errdefer tx.rollback(); + var iter = util.SqlStmtIter.from(script); while (iter.next()) |stmt| { - try execStmt(db, stmt, alloc); + try execStmt(tx, stmt, alloc); } + + try tx.commitOrRelease(); } fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool { diff --git a/src/sql/engines/null.zig b/src/sql/engines/null.zig index f6b2d9d..1c5a434 100644 --- a/src/sql/engines/null.zig +++ b/src/sql/engines/null.zig @@ -28,6 +28,10 @@ pub const Db = struct { unreachable; } + pub fn openUri(_: anytype) common.OpenError!Db { + unreachable; + } + pub fn close(_: Db) void { unreachable; } diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index c191d36..d1a73c2 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -49,7 +49,15 @@ pub const Db = struct { db: *c.sqlite3, pub fn open(path: [:0]const u8) common.OpenError!Db { - const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE; + return openInternal(path, false); + } + + pub fn openUri(path: [:0]const u8) common.OpenError!Db { + return openInternal(path, true); + } + + fn openInternal(path: [:0]const u8, is_uri: bool) common.OpenError!Db { + const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE | if (is_uri) c.SQLITE_OPEN_URI else 0; var db: ?*c.sqlite3 = null; switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) { @@ -121,7 +129,6 @@ pub const Db = struct { // of 0, and we must not bind the argument. const name = std.fmt.comptimePrint("${}", .{i + 1}); const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); - std.log.debug("param {s} got index {}", .{ name, db_idx }); if (db_idx != 0) try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg) else if (!opts.ignore_unused_arguments) @@ -167,7 +174,6 @@ pub const Db = struct { else @compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string"); - std.log.debug("binding type {any}: {s}", .{ T, arr }); const len = std.mem.len(&arr); return self.bindString(stmt, idx, arr[0..len]); }, @@ -194,8 +200,6 @@ pub const Db = struct { return error.BindException; }; - std.log.debug("binding string {s} to idx {}", .{ str, idx }); - switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) { c.SQLITE_OK => {}, else => |result| { diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 458cf5e..1208c9c 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -10,6 +10,7 @@ const Allocator = std.mem.Allocator; const errors = @import("./errors.zig").library_errors; +pub const AcquireError = OpenError || error{NoConnectionsLeft}; pub const OpenError = errors.OpenError; pub const QueryError = errors.QueryError; pub const RowError = errors.RowError; @@ -24,12 +25,14 @@ pub const Engine = enum { sqlite, }; +// TODO: make this suck less pub const Config = union(Engine) { postgres: struct { pg_conn_str: [:0]const u8, }, sqlite: struct { sqlite_file_path: [:0]const u8, + sqlite_is_uri: bool = false, }, }; @@ -160,16 +163,58 @@ pub const ConstraintMode = enum { immediate, }; -pub const Conn = struct { - engine: union(Engine) { - postgres: postgres.Db, - sqlite: sqlite.Db, - }, - current_tx_level: u8 = 0, - is_tx_failed: bool = false, +pub const ConnPool = struct { + const max_conns = 4; + const Conn = struct { + engine: union(Engine) { + postgres: postgres.Db, + sqlite: sqlite.Db, + }, + in_use: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), + current_tx_level: u8 = 0, + }; - pub fn open(cfg: Config) OpenError!Conn { - return switch (cfg) { + config: Config, + connections: [max_conns]Conn, + + pub fn init(cfg: Config) OpenError!ConnPool { + var self = ConnPool{ + .config = cfg, + .connections = undefined, + }; + var count: usize = 0; + errdefer for (self.connections[0..count]) |*c| closeConn(c); + for (self.connections) |*c| { + c.* = try self.createConn(); + count += 1; + } + + return self; + } + + pub fn deinit(self: *ConnPool) void { + for (self.connections) |*c| closeConn(c); + } + + pub fn acquire(self: *ConnPool) AcquireError!Db { + for (self.connections) |*c| { + if (tryAcquire(c)) return Db{ .conn = c }; + } + return error.NoConnectionsLeft; + } + + fn tryAcquire(conn: *Conn) bool { + const acquired = !conn.in_use.swap(true, .AcqRel); + if (acquired) { + if (conn.current_tx_level != 0) @panic("Transaction still open on unused db connection"); + return true; + } + + return false; + } + + fn createConn(self: *ConnPool) OpenError!Conn { + return switch (self.config) { .postgres => |postgres_cfg| Conn{ .engine = .{ .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str), @@ -177,27 +222,22 @@ pub const Conn = struct { }, .sqlite => |lite_cfg| Conn{ .engine = .{ - .sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path), + .sqlite = if (lite_cfg.sqlite_is_uri) + try sqlite.Db.openUri(lite_cfg.sqlite_file_path) + else + try sqlite.Db.open(lite_cfg.sqlite_file_path), }, }, }; } - pub fn close(self: *Conn) void { - switch (self.engine) { + fn closeConn(conn: *Conn) void { + if (conn.in_use.loadUnchecked()) @panic("DB Conn still open"); + switch (conn.engine) { .postgres => |pg| pg.close(), .sqlite => |lite| lite.close(), } } - - pub fn acquire(self: *Conn) !Db { - if (self.current_tx_level != 0) return error.BadTransactionState; - return Db{ .conn = self }; - } - - pub fn sqlEngine(self: *Conn) Engine { - return self.engine; - } }; pub const Db = Tx(0); @@ -216,11 +256,22 @@ fn Tx(comptime tx_level: u8) type { std.fmt.comptimePrint("save_{}", .{tx_level}); const next_savepoint_name = Tx(tx_level + 1).savepoint_name; - conn: *Conn, + conn: *ConnPool.Conn, /// The type of SQL engine being used. Use of this function should be discouraged pub fn sqlEngine(self: Self) Engine { - return self.conn.sqlEngine(); + return self.conn.engine; + } + + /// Return the connection to the pool + pub fn releaseConnection(self: Self) void { + if (tx_level != 0) @compileError("close must be called on root db"); + if (self.conn.current_tx_level != 0) { + std.log.warn("Database released while transaction in progress!", .{}); + self.rollbackUnchecked() catch {}; + } + + if (!self.conn.in_use.swap(false, .AcqRel)) @panic("Double close on db conection"); } // ********* Transaction management functions ********** @@ -277,7 +328,7 @@ fn Tx(comptime tx_level: u8) type { if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint"); if (self.conn.current_tx_level == 0) return error.BadTransactionState; - try self.exec("ROLLBACK", {}, null); + try self.rollbackUnchecked(); self.conn.current_tx_level = 0; } @@ -453,5 +504,9 @@ fn Tx(comptime tx_level: u8) type { while (try results.row()) |_| {} } + + fn rollbackUnchecked(self: Self) !void { + try self.exec("ROLLBACK", {}, null); + } }; } diff --git a/tests/api_integration/lib.zig b/tests/api_integration/lib.zig index 1027043..ea0fd5d 100644 --- a/tests/api_integration/lib.zig +++ b/tests/api_integration/lib.zig @@ -5,11 +5,10 @@ const sql = @import("sql"); const util = @import("util"); const test_config = .{ - .db = .{ - .sqlite = .{ - .sqlite_file_path = ":memory:", - }, - }, + .db = .{ .sqlite = .{ + .sqlite_file_path = "file::memory:?cache=shared", + .sqlite_is_uri = true, + } }, }; const ApiSource = api.ApiSource; @@ -18,12 +17,16 @@ const root_password = "password1234"; const admin_host = "example.com"; const admin_origin = "https://" ++ admin_host; -fn makeDb(alloc: std.mem.Allocator) !sql.Conn { +fn makeDb(alloc: std.mem.Allocator) !sql.ConnPool { try util.seedThreadPrng(); - var db = try sql.Conn.open(test_config.db); - try migrations.up(try db.acquire()); - try api.setupAdmin(try db.acquire(), admin_origin, root_user, root_password, alloc); - return db; + var pool = try sql.ConnPool.init(test_config.db); + { + var db = try pool.acquire(); + defer db.releaseConnection(); + try migrations.up(db); + try api.setupAdmin(db, admin_origin, root_user, root_password, alloc); + } + return pool; } fn connectAndLogin( @@ -42,6 +45,7 @@ fn connectAndLogin( test "login as root" { const alloc = std.testing.allocator; var db = try makeDb(alloc); + defer db.deinit(); var src = try ApiSource.init(&db); const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); @@ -59,6 +63,7 @@ test "login as root" { test "create community" { const alloc = std.testing.allocator; var db = try makeDb(alloc); + defer db.deinit(); var src = try ApiSource.init(&db); const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); @@ -80,6 +85,7 @@ test "create community" { test "create community and transfer to new owner" { const alloc = std.testing.allocator; var db = try makeDb(alloc); + defer db.deinit(); var src = try ApiSource.init(&db); const root_login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc);