diff --git a/src/main/api.zig b/src/main/api.zig index fae758a..ca9b119 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -95,15 +95,15 @@ pub fn getRandom() std.rand.Random { } pub const ApiSource = struct { - db: sql.Db, + db: *sql.Db, internal_alloc: std.mem.Allocator, config: Config, - pub const Conn = ApiConn(sql.Db); + pub const Conn = ApiConn(*sql.Db); const root_username = "root"; - pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8, db_conn: sql.Db) !ApiSource { + pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8, db_conn: *sql.Db) !ApiSource { var self = ApiSource{ .db = db_conn, .internal_alloc = alloc, @@ -112,14 +112,14 @@ pub const ApiSource = struct { try migrations.up(db_conn); - if ((try services.users.lookupByUsername(&self.db, root_username, null)) == null) { + if ((try services.users.lookupByUsername(self.db, root_username, null)) == null) { std.log.info("No cluster root user detected. Creating...", .{}); // TODO: Fix this const password = root_password orelse return error.NeedRootPassword; var arena = std.heap.ArenaAllocator.init(alloc); defer arena.deinit(); - const user_id = try services.users.create(&self.db, root_username, password, null, .{}, arena.allocator()); + const user_id = try services.users.create(self.db, root_username, password, null, .{}, arena.allocator()); std.log.debug("Created {s} ID {}", .{ root_username, user_id }); } @@ -157,7 +157,7 @@ pub const ApiSource = struct { pub fn connectToken(self: *ApiSource, host: []const u8, token: []const u8, alloc: std.mem.Allocator) !Conn { const community_id = try self.getCommunityFromHost(host); - const token_info = try services.auth.tokens.verify(&self.db, token, community_id); + const token_info = try services.auth.tokens.verify(self.db, token, community_id); return Conn{ .db = self.db, @@ -189,10 +189,10 @@ fn ApiConn(comptime DbConn: type) type { } pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResponse { - const user_id = (try services.users.lookupByUsername(&self.db, username, self.community_id)) orelse return error.InvalidLogin; - try services.auth.passwords.verify(&self.db, user_id, password, self.internal_alloc); + const user_id = (try services.users.lookupByUsername(self.db, username, self.community_id)) orelse return error.InvalidLogin; + try services.auth.passwords.verify(self.db, user_id, password, self.internal_alloc); - const token = try services.auth.tokens.create(&self.db, user_id); + const token = try services.auth.tokens.create(self.db, user_id); return LoginResponse{ .user_id = user_id, @@ -225,7 +225,7 @@ fn ApiConn(comptime DbConn: type) type { return error.PermissionDenied; } - return services.communities.create(&self.db, origin, null); + return services.communities.create(self.db, origin, null); } pub fn createInvite(self: *Self, options: InviteRequest) !services.invites.Invite { @@ -239,13 +239,13 @@ fn ApiConn(comptime DbConn: type) type { // Only admins can invite on the admin domain if (!self.isAdmin()) return error.PermissionDenied; - break :blk (try services.communities.getByHost(&self.db, host, self.arena.allocator())).id; + break :blk (try services.communities.getByHost(self.db, host, self.arena.allocator())).id; } else self.community_id; // Users can only make user invites if (options.invite_type != .user and !self.isAdmin()) return error.PermissionDenied; - return try services.invites.create(&self.db, user_id, community_id, .{ + return try services.invites.create(self.db, user_id, community_id, .{ .name = options.name, .expires_at = options.expires_at, .max_uses = options.max_uses, @@ -255,7 +255,7 @@ fn ApiConn(comptime DbConn: type) type { pub fn register(self: *Self, request: RegistrationRequest) !UserResponse { std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code }); - const invite = try services.invites.getByCode(&self.db, request.invite_code, self.arena.allocator()); + const invite = try services.invites.getByCode(self.db, request.invite_code, self.arena.allocator()); if (!Uuid.eql(invite.to_community, self.community_id)) return error.NotFound; if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired; @@ -263,13 +263,13 @@ fn ApiConn(comptime DbConn: type) type { if (self.community_id == null) @panic("Unimplmented"); - const user_id = try services.users.create(&self.db, request.username, request.password, self.community_id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc); + const user_id = try services.users.create(self.db, request.username, request.password, self.community_id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc); switch (invite.invite_type) { .user => {}, .system => @panic("System user invites unimplemented"), .community_owner => { - try services.communities.transferOwnership(&self.db, self.community_id.?, user_id); + try services.communities.transferOwnership(self.db, self.community_id.?, user_id); }, } @@ -280,7 +280,7 @@ fn ApiConn(comptime DbConn: type) type { } pub fn getUser(self: *Self, user_id: Uuid) !UserResponse { - const user = try services.users.get(&self.db, user_id, self.arena.allocator()); + const user = try services.users.get(self.db, user_id, self.arena.allocator()); if (self.user_id == null) { if (!Uuid.eql(self.community_id, user.community_id)) return error.NotFound; @@ -298,7 +298,7 @@ fn ApiConn(comptime DbConn: type) type { if (self.community_id == null) return error.WrongCommunity; const user_id = self.user_id orelse return error.TokenRequired; - const note_id = try services.notes.create(&self.db, user_id, content); + const note_id = try services.notes.create(self.db, user_id, content); return self.getNote(note_id) catch |err| switch (err) { error.NotFound => error.Unexpected, @@ -307,8 +307,8 @@ fn ApiConn(comptime DbConn: type) type { } pub fn getNote(self: *Self, note_id: Uuid) !NoteResponse { - const note = try services.notes.get(&self.db, note_id, self.arena.allocator()); - const user = try services.users.get(&self.db, note.author_id, self.arena.allocator()); + const note = try services.notes.get(self.db, note_id, self.arena.allocator()); + const user = try services.users.get(self.db, note.author_id, self.arena.allocator()); // Only serve community-specific notes on unauthenticated requests if (self.user_id == null) { @@ -329,7 +329,7 @@ fn ApiConn(comptime DbConn: type) type { pub fn queryCommunities(self: *Self, args: services.communities.QueryArgs) ![]services.communities.Community { if (!self.isAdmin()) return error.PermissionDenied; - return services.communities.query(&self.db, args, self.arena.allocator()); + return services.communities.query(self.db, args, self.arena.allocator()); } }; } diff --git a/src/main/main.zig b/src/main/main.zig index e1bad55..e03d54c 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -103,7 +103,7 @@ pub fn main() anyerror!void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var cfg = try loadConfig(gpa.allocator()); var db_conn = try sql.Db.open(cfg.db); - var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar), db_conn) catch |err| switch (err) { + var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar), &db_conn) catch |err| switch (err) { error.NeedRootPassword => { std.log.err( "No root user created and no password specified. Please provide the password for the root user by the ${s} environment variable for initial startup. This only needs to be done once", diff --git a/src/main/migrations.zig b/src/main/migrations.zig index cc4c1c9..83fba1d 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -22,7 +22,7 @@ fn execStmt(tx: sql.Tx, stmt: []const u8, alloc: std.mem.Allocator) !void { try tx.exec(stmt_null, .{}, null); } -fn execScript(db: sql.Db, script: []const u8, alloc: std.mem.Allocator) !void { +fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void { const tx = try db.begin(); errdefer tx.rollback(); @@ -37,12 +37,12 @@ fn execScript(db: sql.Db, script: []const u8, alloc: std.mem.Allocator) !void { try tx.commit(); } -fn wasMigrationRan(db: sql.Db, name: []const u8, alloc: std.mem.Allocator) !bool { +fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool { const row = (try db.queryRow(&.{i32}, "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false; return row[0] != 0; } -pub fn up(db: sql.Db) !void { +pub fn up(db: *sql.Db) !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer _ = gpa.deinit(); std.log.info("Running migrations...", .{}); diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 26cbf22..e122f8c 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -8,12 +8,12 @@ const Allocator = std.mem.Allocator; pub const QueryOptions = common.QueryOptions; -pub const Type = enum { +pub const Engine = enum { postgres, sqlite, }; -pub const Config = union(Type) { +pub const Config = union(Engine) { postgres: struct { pg_conn_str: [:0]const u8, }, @@ -23,7 +23,7 @@ pub const Config = union(Type) { }; //pub const OpenError = sqlite.OpenError | postgres.OpenError; -const RawResults = union(Type) { +const RawResults = union(Engine) { postgres: postgres.Results, sqlite: sqlite.Results, @@ -82,7 +82,7 @@ pub fn Results(comptime result_types: []const type) type { } // Row is invalidated by the next call to result.row() -const Row = union(Type) { +const Row = union(Engine) { postgres: postgres.Row, sqlite: sqlite.Row, @@ -98,13 +98,20 @@ const Row = union(Type) { } }; -const DbUnion = union(Type) { +const DbUnion = union(Engine) { postgres: postgres.Db, sqlite: sqlite.Db, }; +pub const ConstraintMode = enum { + deferred, + immediate, +}; + pub const Db = struct { + tx_open: bool = false, underlying: DbUnion, + pub fn open(cfg: Config) !Db { return switch (cfg) { .postgres => |postgres_cfg| Db{ @@ -120,7 +127,7 @@ pub const Db = struct { }; } - pub fn close(self: Db) void { + pub fn close(self: *Db) void { switch (self.underlying) { .postgres => |pg| pg.close(), .sqlite => |lite| lite.close(), @@ -128,60 +135,69 @@ pub const Db = struct { } pub fn queryWithOptions( - self: Db, + self: *Db, comptime result_types: []const type, sql: [:0]const u8, args: anytype, opt: QueryOptions, ) !Results(result_types) { + if (self.tx_open) return error.TransactionOpen; // Create fake transaction to use its functions - return (Tx{ .underlying = self.underlying }).queryWithOptions(result_types, sql, args, opt); + return (Tx{ .db = self }).queryWithOptions(result_types, sql, args, opt); } pub fn query( - self: Db, + self: *Db, comptime result_types: []const type, sql: [:0]const u8, args: anytype, alloc: ?Allocator, ) !Results(result_types) { + if (self.tx_open) return error.TransactionOpen; // Create fake transaction to use its functions - return (Tx{ .underlying = self.underlying }).query(result_types, sql, args, alloc); + return (Tx{ .db = self }).query(result_types, sql, args, alloc); } pub fn exec( - self: Db, + self: *Db, sql: [:0]const u8, args: anytype, alloc: ?Allocator, ) !void { + if (self.tx_open) return error.TransactionOpen; // Create fake transaction to use its functions - return (Tx{ .underlying = self.underlying }).exec(sql, args, alloc); + return (Tx{ .db = self }).exec(sql, args, alloc); } pub fn queryRow( - self: Db, + self: *Db, comptime result_types: []const type, sql: [:0]const u8, args: anytype, alloc: ?Allocator, ) !?Results(result_types).RowTuple { + if (self.tx_open) return error.TransactionOpen; // Create fake transaction to use its functions - return (Tx{ .underlying = self.underlying }).queryRow(result_types, sql, args, alloc); + return (Tx{ .db = self }).queryRow(result_types, sql, args, alloc); } pub fn insert( - self: Db, + self: *Db, comptime table: []const u8, value: anytype, ) !void { + if (self.tx_open) return error.TransactionOpen; // Create fake transaction to use its functions - return (Tx{ .underlying = self.underlying }).insert(table, value); + return (Tx{ .db = self }).insert(table, value); + } + + pub fn sqlEngine(self: *Db) Engine { + return self.underlying; } // Begins a transaction - pub fn begin(self: Db) !Tx { - const tx = Tx{ .underlying = self.underlying }; + pub fn begin(self: *Db) !Tx { + const tx = Tx{ .db = self }; try tx.exec("BEGIN", .{}, null); return tx; @@ -189,7 +205,7 @@ pub const Db = struct { }; pub const Tx = struct { - underlying: DbUnion, + db: *Db, // internal helper fn fn queryInternal( @@ -198,7 +214,7 @@ pub const Tx = struct { args: anytype, opt: QueryOptions, ) !RawResults { - return switch (self.underlying) { + return switch (self.db.underlying) { .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt.prep_allocator) }, .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) }, }; @@ -290,6 +306,31 @@ pub const Tx = struct { try self.exec(q, args_tuple, null); } + pub fn sqlEngine(self: Tx) Engine { + return self.db.underlying; + } + + pub fn setConstraintMode(self: Tx, mode: ConstraintMode) !void { + switch (self.db.underlying) { + .sqlite => try self.exec( + switch (mode) { + .immediate => "PRAGMA defer_foreign_keys = FALSE", + .deferred => "PRAGMA defer_foreign_keys = TRUE", + }, + .{}, + null, + ), + .postgres => try self.exec( + switch (mode) { + .immediate => "SET CONSTRAINTS ALL IMMEDIATE", + .deferred => "SET CONSTRAINTS ALL DEFERRED", + }, + .{}, + null, + ), + } + } + pub fn rollback(self: Tx) void { self.exec("ROLLBACK", .{}, null) catch |err| { std.log.err("Error occured during rollback operation: {}", .{err});