const std = @import("std");
const util = @import("util");
const postgres = @import("./postgres.zig");
const sqlite = @import("./sqlite.zig");
const Allocator = std.mem.Allocator;

/// Supported database backends.
pub const Type = enum {
    postgres,
    sqlite,
};

/// Backend-specific connection configuration; the active tag selects the backend.
pub const Config = union(Type) {
    postgres: struct {
        /// Connection string passed to `postgres.Db.open`.
        conn_str: [:0]const u8,
    },
    sqlite: struct {
        /// Database file path passed to `sqlite.Db.open`.
        file_path: [:0]const u8,
    },
};

// TODO: expose a combined open error set once both backends define one
// (error sets are merged with `||`):
// pub const OpenError = sqlite.OpenError || postgres.OpenError;

/// Backend-specific result set, dispatched on the active backend.
const RawResults = union(Type) {
    postgres: postgres.Results,
    sqlite: sqlite.Results,

    /// Releases the underlying backend result set.
    fn finish(self: RawResults) void {
        switch (self) {
            .postgres => |pg| pg.finish(),
            .sqlite => |lite| lite.finish(),
        }
    }
};

/// Represents a set of results whose rows are decoded into a tuple of `result_types`.
/// row() must be called until it returns null, or the query may not complete.
/// Must be deallocated by a call to finish().
pub fn Results(comptime result_types: []const type) type {
    return struct {
        const Self = @This();

        /// Tuple type produced by `row()`: one element per entry in `result_types`.
        /// Public because it appears in the return type of `queryRow`.
        pub const RowTuple = std.meta.Tuple(result_types);

        underlying: RawResults,

        /// Releases the underlying result set.
        pub fn finish(self: Self) void {
            self.underlying.finish();
        }

        /// Number of rows in the result set, if the backend can report it.
        /// Can be used as an optimization to reduce memory reallocation.
        /// Only works on postgres; always null on sqlite.
        pub fn rowCount(self: Self) ?usize {
            return switch (self.underlying) {
                .postgres => |pg| pg.rowCount(),
                .sqlite => null, // not possible without repeating the query
            };
        }

        /// Fetches the next row and decodes column i into element i of `RowTuple`.
        /// Returns null once the backend reports no more rows.
        /// `alloc` is forwarded to every column decoder; the caller owns any
        /// memory held by the returned tuple. If decoding a column fails, the
        /// columns already decoded for this row are freed before the error is
        /// returned.
        pub fn row(self: *Self, alloc: ?Allocator) !?RowTuple {
            const row_val = switch (self.underlying) {
                .postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else return null,
                .sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else return null,
            };

            var result: RowTuple = undefined;
            // Records which tuple elements hold decoded values, so the errdefer
            // below only frees initialized fields.
            var fields_allocated = [_]bool{false} ** result_types.len;
            errdefer {
                inline for (result_types) |_, i| {
                    if (fields_allocated[i]) util.deepFree(alloc, result[i]);
                }
            }

            inline for (result_types) |T, i| {
                result[i] = try row_val.get(T, i, alloc);
                fields_allocated[i] = true;
            }

            return result;
        }
    };
}

/// A single row of a result set.
/// Row is invalidated by the next call to result.row().
const Row = union(Type) {
    postgres: postgres.Row,
    sqlite: sqlite.Row,

    /// Returns a value of type T from the zero-indexed column given by idx.
    /// Not all types require an allocator to be present. If an allocator is
    /// needed but not provided, it will return error.AllocatorRequired.
    /// The caller is responsible for deallocating T, if relevant.
    fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) anyerror!T {
        return switch (self) {
            .postgres => |pg| pg.get(T, idx, alloc),
            .sqlite => |lite| lite.get(T, idx, alloc),
        };
    }
};

/// Backend-specific connection handle.
const DbUnion = union(Type) {
    postgres: postgres.Db,
    sqlite: sqlite.Db,
};

/// Backend-agnostic database connection.
pub const Db = struct {
    underlying: DbUnion,

    /// Opens a connection to the backend selected by `cfg`.
    pub fn open(cfg: Config) !Db {
        return switch (cfg) {
            .postgres => |postgres_cfg| Db{
                .underlying = .{
                    .postgres = try postgres.Db.open(postgres_cfg.conn_str),
                },
            },
            .sqlite => |lite_cfg| Db{
                .underlying = .{
                    .sqlite = try sqlite.Db.open(lite_cfg.file_path),
                },
            },
        };
    }

    /// Closes the underlying connection.
    pub fn close(self: Db) void {
        switch (self.underlying) {
            .postgres => |pg| pg.close(),
            .sqlite => |lite| lite.close(),
        }
    }

    /// Executes a query and returns the result set; see `Tx.query`.
    pub fn query(
        self: Db,
        comptime result_types: []const type,
        sql: [:0]const u8,
        args: anytype,
        alloc: ?Allocator,
    ) !Results(result_types) {
        // Create fake transaction to use its functions
        return (Tx{ .underlying = self.underlying }).query(result_types, sql, args, alloc);
    }

    /// Executes a query without returning results; see `Tx.exec`.
    pub fn exec(
        self: Db,
        sql: [:0]const u8,
        args: anytype,
        alloc: ?Allocator,
    ) !void {
        // Create fake transaction to use its functions
        return (Tx{ .underlying = self.underlying }).exec(sql, args, alloc);
    }

    /// Runs a query and returns a single row (null if there are no rows);
    /// see `Tx.queryRow`.
    pub fn queryRow(
        self: Db,
        comptime result_types: []const type,
        sql: [:0]const u8,
        args: anytype,
        alloc: ?Allocator,
    ) !?Results(result_types).RowTuple {
        // Create fake transaction to use its functions
        return (Tx{ .underlying = self.underlying }).queryRow(result_types, sql, args, alloc);
    }

    /// Inserts a single value into a table; see `Tx.insert`.
    pub fn insert(
        self: Db,
        comptime table: []const u8,
        value: anytype,
    ) !void {
        // Create fake transaction to use its functions
        return (Tx{ .underlying = self.underlying }).insert(table, value);
    }

    /// Begins a transaction by executing "BEGIN" on this connection.
    pub fn begin(self: Db) !Tx {
        const tx = Tx{ .underlying = self.underlying };
        try tx.exec("BEGIN", .{}, null);
        return tx;
    }
};

/// A transaction on a connection. Statements run on the same underlying handle
/// as the `Db` it was created from.
pub const Tx = struct {
    underlying: DbUnion,

    /// Internal helper: dispatches a statement to the active backend and wraps
    /// its result set. Note that `alloc` is only forwarded to the postgres backend.
    fn queryInternal(
        self: Tx,
        sql: [:0]const u8,
        args: anytype,
        alloc: ?Allocator,
    ) !RawResults {
        return switch (self.underlying) {
            .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, alloc) },
            .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args) },
        };
    }

    /// Executes a query and returns the result set.
    /// The caller must call `finish()` on the returned value.
    pub fn query(
        self: Tx,
        comptime result_types: []const type,
        sql: [:0]const u8,
        args: anytype,
        alloc: ?Allocator,
    ) !Results(result_types) {
        return Results(result_types){ .underlying = try self.queryInternal(sql, args, alloc) };
    }

    /// Executes a query without returning results.
    pub fn exec(
        self: Tx,
        sql: [:0]const u8,
        args: anytype,
        alloc: ?Allocator,
    ) !void {
        (try self.queryInternal(sql, args, alloc)).finish();
    }

    /// Runs a query and returns a single row.
    /// Returns null if the query yields no rows, and error.TooManyRows if it
    /// yields more than one. The caller owns any memory in the returned tuple.
    pub fn queryRow(
        self: Tx,
        comptime result_types: []const type,
        q: [:0]const u8,
        args: anytype,
        alloc: ?Allocator,
    ) !?Results(result_types).RowTuple {
        var results = try self.query(result_types, q, args, alloc);
        defer results.finish();

        const row = (try results.row(alloc)) orelse return null;
        errdefer util.deepFree(alloc, row);

        // Drain the remaining rows so the query completes, freeing each one.
        var more_rows = false;
        while (try results.row(alloc)) |r| {
            util.deepFree(alloc, r);
            more_rows = true;
        }
        if (more_rows) return error.TooManyRows;

        return row;
    }

    /// Inserts a single value into a table.
    /// Builds `INSERT INTO <table>(<field names>) VALUES (?, ...)` at comptime
    /// from the fields of `value`, and binds `value` as the query arguments.
    /// NOTE(review): uses `?` placeholders for every backend; confirm the
    /// postgres backend translates them to `$n` form.
    pub fn insert(
        self: Tx,
        comptime table: []const u8,
        value: anytype,
    ) !void {
        const ValueType = comptime @TypeOf(value);
        const table_spec = comptime table ++ build_field_list(ValueType, null);
        const value_spec = comptime build_field_list(ValueType, "?");
        const q = comptime std.fmt.comptimePrint(
            "INSERT INTO {s} VALUES {s}",
            .{ table_spec, value_spec },
        );

        try self.exec(q, value, null);
    }

    /// Rolls the transaction back. Errors are logged rather than returned.
    pub fn rollback(self: Tx) void {
        self.exec("ROLLBACK", .{}, null) catch |err| {
            std.log.err("Error occured during rollback operation: {}", .{err});
        };
    }

    /// Commits the transaction.
    pub fn commit(self: Tx) !void {
        try self.exec("COMMIT", .{}, null);
    }
};

/// Builds a parenthesized, comma-separated list with one entry per field of T.
/// Each entry is the field name, or `placeholder` when it is non-null.
/// e.g. for struct { a: i32, b: i32 }: "(a,b)" or, with "?", "(?,?)".
fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
    comptime {
        const fields = std.meta.fields(T);
        if (fields.len == 0) {
            @compileError("cannot build a field list for a type with no fields: " ++ @typeName(T));
        }

        const joiner = ",";
        var result: []const u8 = "";
        inline for (fields) |f| {
            result = result ++ joiner ++ (placeholder orelse f.name);
        }
        // Drop the leading joiner added by the first iteration.
        return "(" ++ result[joiner.len..] ++ ")";
    }
}