const std = @import("std");
const util = @import("util");
const build_options = @import("build_options");

const postgres = if (build_options.enable_postgres)
    @import("./engines/postgres.zig")
else
    @import("./engines/null.zig");

const sqlite = if (build_options.enable_sqlite)
    @import("./engines/sqlite.zig")
else
    @import("./engines/null.zig");

const common = @import("./engines/common.zig");
const Allocator = std.mem.Allocator;

const errors = @import("./errors.zig").library_errors;

pub const AcquireError = OpenError || error{NoConnectionsLeft};

pub const OpenError = errors.OpenError;
pub const QueryError = errors.QueryError;
pub const RowError = errors.RowError;
pub const QueryRowError = errors.QueryRowError;
pub const BeginError = errors.BeginError;
pub const CommitError = errors.CommitError;

pub const DatabaseError = QueryError || RowError || QueryRowError || BeginError || CommitError;

pub const QueryOptions = common.QueryOptions;

/// The SQL backends this module can talk to. Also used as the tag type of the
/// engine-dispatching unions below.
pub const Engine = enum {
    postgres,
    sqlite,
};

/// Helper for building queries at runtime. All constituent parts of the
/// query should be defined at comptime, however the choice of whether
/// or not to include them can occur at runtime.
pub const QueryBuilder = struct {
    array: std.ArrayList(u8),
    where_clauses_appended: usize = 0,
    set_statements_appended: usize = 0,

    /// Create an empty builder whose buffer is allocated from `alloc`.
    /// Must be released with `deinit`.
    pub fn init(alloc: std.mem.Allocator) QueryBuilder {
        return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
    }

    /// Free the query buffer. Any slice previously returned by `str` or
    /// `terminate` is invalid afterwards.
    pub fn deinit(self: *const QueryBuilder) void {
        self.array.deinit();
    }

    /// Add a chunk of sql to the query without processing
    pub fn appendSlice(self: *QueryBuilder, comptime sql: []const u8) !void {
        try self.array.appendSlice(sql);
    }

    /// Add a where clause to the query. Clauses are assumed to be components
    /// in an overall expression in Conjunctive Normal Form (AND of OR's).
    /// https://en.wikipedia.org/wiki/Conjunctive_normal_form
    /// All calls to andWhere must be contiguous, that is, they cannot be
    /// interspersed with calls to appendSlice
    /// The clause is appended verbatim after " AND " (no parentheses are
    /// added), so a clause containing OR must be parenthesized by the caller.
    pub fn andWhere(self: *QueryBuilder, comptime clause: []const u8) !void {
        if (self.where_clauses_appended == 0) {
            try self.array.appendSlice("\nWHERE ");
        } else {
            try self.array.appendSlice(" AND ");
        }

        try self.array.appendSlice(clause);
        self.where_clauses_appended += 1;
    }

    /// Append `col = val` to the SET clause. The first call emits "\nSET ",
    /// later calls separate assignments with ", ". Like andWhere, calls must
    /// be contiguous.
    pub fn set(self: *QueryBuilder, comptime col: []const u8, comptime val: []const u8) !void {
        if (self.set_statements_appended == 0) {
            try self.array.appendSlice("\nSET ");
        } else {
            try self.array.appendSlice(", ");
        }

        try self.array.appendSlice(col ++ " = " ++ val);
        self.set_statements_appended += 1;
    }

    /// View of the query text built so far (not null-terminated unless
    /// `terminate` has been called, in which case it ends in a 0 byte).
    pub fn str(self: *const QueryBuilder) []const u8 {
        return self.array.items;
    }

    /// Null-terminate the query so it can be passed to the query functions,
    /// which take `[:0]const u8`. A 0 byte is appended unless the buffer
    /// already ends in one, so calling this more than once is harmless.
    /// The returned slice does not count the terminator in its length and is
    /// invalidated by any further append to (or deinit of) the builder.
    pub fn terminate(self: *QueryBuilder) ![:0]const u8 {
        std.debug.assert(self.array.items.len != 0);
        if (self.array.items[self.array.items.len - 1] != 0) try self.array.append(0);

        // Slice with an explicit sentinel so the terminator sits exactly at
        // result[result.len]. Reinterpreting the whole buffer (as
        // assumeSentinel did) would keep the 0 inside the slice and claim a
        // sentinel one byte past the end of the list's items.
        return self.array.items[0 .. self.array.items.len - 1 :0];
    }
};

// TODO: make this suck less
pub const Config = union(Engine) {
    postgres: struct {
        pg_conn_str: [:0]const u8,
    },
    sqlite: struct {
        sqlite_file_path: [:0]const u8,
        sqlite_is_uri: bool = false,
    },
};

/// Engine-specific result set; each method dispatches on the active engine.
const RawResults = union(Engine) {
    postgres: postgres.Results,
    sqlite: sqlite.Results,

    fn finish(self: RawResults) void {
        switch (self) {
            .postgres => |pg| pg.finish(),
            .sqlite => |lite| lite.finish(),
        }
    }

    fn columnCount(self: RawResults) !u15 {
        return try switch (self) {
            .postgres => |pg| pg.columnCount(),
            .sqlite => |lite| lite.columnCount(),
        };
    }

    /// Look up a column by name. An out-of-range index reported by the engine
    /// is surfaced as error.Unexpected.
    fn columnIndex(self: RawResults, name: []const u8) error{ NotFound, Unexpected }!u15 {
        return switch (self) {
            .postgres => |pg| pg.columnIndex(name),
            .sqlite => |lite| lite.columnIndex(name),
        } catch |err| switch (err) {
            error.OutOfRange => error.Unexpected,
            error.NotFound => error.NotFound,
        };
    }

    /// Advance to the next row, wrapping it in the matching Row variant.
    /// Returns null once the engine reports no more rows.
    fn row(self: *RawResults) RowError!?Row {
        return switch (self.*) {
            .postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else null,
            .sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else null,
        };
    }
};

/// A path of field names, e.g. &.{ "a", "b" } refers to `x.a.b`.
const FieldRef = []const []const u8;

/// The pointer type obtained by following the field path `names` starting
/// from a value of pointer type `Ptr`. Unknown field names are compile errors.
fn FieldPtr(comptime Ptr: type, comptime names: FieldRef) type {
    if (names.len == 0) return Ptr;

    const T = std.meta.Child(Ptr);

    const field = for (@typeInfo(T).Struct.fields) |f| {
        if (std.mem.eql(u8, f.name, names[0])) break f;
    } else @compileError("Unknown field " ++ names[0] ++ " in type " ++ @typeName(T));

    return FieldPtr(*field.field_type, names[1..]);
}

/// Follow the field path `names` from `ptr`, returning a pointer to the
/// nested field.
fn fieldPtr(ptr: anytype, comptime names: FieldRef) FieldPtr(@TypeOf(ptr), names) {
    if (names.len == 0) return ptr;

    return fieldPtr(&@field(ptr.*, names[0]), names[1..]);
}

/// Flatten `T` into the list of leaf field paths (used by Results to map
/// fields to result columns). Leaves are types accepted by
/// `options.isScalar` and fields marked `.json` in `T.sql_serialize`;
/// optionals are unwrapped, and unions (when `options.embed_unions` is set)
/// replace the last prefix element with their own field names.
fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: anytype) []const FieldRef {
    comptime {
        if (std.meta.trait.is(.Union)(T) and prefix.len == 0 and options.embed_unions) {
            @compileError("Cannot embed a union into nothing");
        }

        if (options.isScalar(T)) return &.{prefix};
        if (std.meta.trait.is(.Optional)(T)) return getRecursiveFieldList(std.meta.Child(T), prefix, options);

        const eff_prefix: FieldRef = if (std.meta.trait.is(.Union)(T) and options.embed_unions)
            prefix[0 .. prefix.len - 1]
        else
            prefix;

        var fields: []const FieldRef = &.{};

        for (std.meta.fields(T)) |f| {
            const new_prefix = eff_prefix ++ &[_][]const u8{f.name};
            if (@hasDecl(T, "sql_serialize") and @hasDecl(T.sql_serialize, f.name) and @field(T.sql_serialize, f.name) == .json) {
                fields = fields ++ &[_]FieldRef{new_prefix};
            } else {
                const F = f.field_type;
                fields = fields ++ getRecursiveFieldList(F, new_prefix, options);
            }
        }

        return fields;
    }
}

// Represents a set of results.
// row() must be called until it returns null, or the query may not complete
// Must be deallocated by a call to finish()
pub fn Results(comptime T: type) type {
    // would normally make this a declaration of the struct, but it causes the compiler to crash
    const fields = if (T == void) .{} else getRecursiveFieldList(
        T,
        &.{},
        util.serialize.default_options,
    );
    return struct {
        const Self = @This();

        underlying: RawResults,
        column_indices: [fields.len]u15,

        /// Wrap a raw engine result, resolving which column feeds each field of T
        /// (tuples map positionally, structs by the '.'-joined field path).
        /// Takes ownership of `underlying`: it is finished here on error.
        fn from(underlying: RawResults) QueryError!Self {
            // On failure the caller never receives a Self to finish(), so the
            // raw result (statement/result handle) must be released here.
            errdefer underlying.finish();

            if (std.debug.runtime_safety and fields.len != underlying.columnCount() catch unreachable) {
                std.log.err("Expected {} columns in result, got {}", .{ fields.len, underlying.columnCount() catch unreachable });
                return error.ColumnMismatch;
            }

            return Self{ .underlying = underlying, .column_indices = blk: {
                var indices: [fields.len]u15 = undefined;
                inline for (fields) |f, i| {
                    if (comptime std.meta.trait.isTuple(T)) {
                        indices[i] = i;
                    } else {
                        const name = util.comptimeJoin(".", f);
                        indices[i] =
                            underlying.columnIndex(name) catch {
                            std.log.err("Could not find column index for field {s}", .{name});
                            return error.ColumnMismatch;
                        };
                    }
                }
                break :blk indices;
            } };
        }

        /// Release the underlying engine result.
        pub fn finish(self: Self) void {
            self.underlying.finish();
        }

        // Returns the next row of results, or null if there are no more rows.
        // Caller owns all memory allocated. The entire object can be deallocated with a
        // call to util.deepFree
        pub fn row(self: *Self, alloc: ?Allocator) RowError!?T {
            if (try self.underlying.row()) |row_val| {
                var result: T = undefined;
                var fields_allocated: usize = 0;
                errdefer inline for (fields) |f, i| {
                    // Iteration bounds must be defined at comptime (inline for) but the number of fields we could
                    // successfully allocate is defined at runtime. So we iterate over the entire field array and
                    // conditionally deallocate fields in the loop.
                    const ptr = fieldPtr(&result, f);
                    if (i < fields_allocated) util.deepFree(alloc, ptr.*);
                };

                inline for (fields) |f, i| {
                    // TODO: Causes compiler segfault. why?
                    //const F = f.field_type;
                    //const F = @TypeOf(@field(result, f.name));
                    const F = std.meta.Child(FieldPtr(*@TypeOf(result), f));
                    const ptr = fieldPtr(&result, f);
                    const name = comptime util.comptimeJoin(".", f);

                    // Per-field decoding mode, taken from T.sql_serialize when declared.
                    const mode = comptime if (@hasDecl(T, "sql_serialize")) blk: {
                        if (@hasDecl(T.sql_serialize, name)) {
                            break :blk @field(T.sql_serialize, name);
                        }
                        break :blk .default;
                    } else .default;
                    switch (mode) {
                        .default => ptr.* = row_val.get(F, self.column_indices[i], alloc) catch |err| {
                            std.log.err("SQL: Error getting column {s} of type {}", .{ name, F });
                            return err;
                        },
                        .json => {
                            // Read the column as text, then parse it as JSON into F.
                            const str = row_val.get([]const u8, self.column_indices[i], alloc) catch |err| {
                                std.log.err("SQL: Error getting column {s} of type {}", .{ name, F });
                                return err;
                            };
                            const a = alloc orelse return error.AllocatorRequired;
                            defer a.free(str);

                            var ts = std.json.TokenStream.init(str);
                            ptr.* = std.json.parse(F, &ts, .{ .allocator = a }) catch |err| {
                                std.log.err("SQL: Error parsing columns {s} of type {}: {}", .{ name, F, err });
                                return error.ResultTypeMismatch;
                            };
                        },
                        else => @compileError("unknown mode"),
                    }

                    fields_allocated += 1;
                }

                return result;
            } else return null;
        }
    };
}

// Row is invalidated by the next call to result.row()
const Row = union(Engine) {
    postgres: postgres.Row,
    sqlite: sqlite.Row,

    // Returns a value of type T from the zero-indexed column given by idx.
    // Not all types require an allocator to be present. If an allocator is needed but
    // not provided, it will return error.AllocatorRequired.
    // The caller is responsible for deallocating T, if relevant.
    fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T {
        if (T == void) return;
        return switch (self) {
            .postgres => |pg| try pg.get(T, idx, alloc),
            .sqlite => |lite| try lite.get(T, idx, alloc),
        };
    }
};

/// Constraint checking mode selectable via Tx.setConstraintMode.
pub const ConstraintMode = enum {
    deferred,
    immediate,
};

/// Fixed-size pool of `max_conns` database connections, all opened eagerly in
/// `init`. `Db` handles returned by `acquire` hold pointers into
/// `connections`, so the pool must not be moved or copied while any handle is
/// outstanding.
pub const ConnPool = struct {
    const max_conns = 4;
    const Conn = struct {
        engine: union(Engine) {
            postgres: postgres.Db,
            sqlite: sqlite.Db,
        },
        // Claimed/released with atomic swaps in tryAcquire/releaseConnection.
        in_use: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false),
        current_tx_level: u8 = 0,
    };

    config: Config,
    connections: [max_conns]Conn,

    /// Open every connection described by `cfg`. If any open fails, the ones
    /// already opened are closed before the error is returned.
    pub fn init(cfg: Config) OpenError!ConnPool {
        var self = ConnPool{
            .config = cfg,
            .connections = undefined,
        };
        var count: usize = 0;
        errdefer for (self.connections[0..count]) |*c| closeConn(c);
        for (self.connections) |*c| {
            c.* = try self.createConn();
            count += 1;
        }

        return self;
    }

    /// Close every connection. Panics if any connection is still in use.
    pub fn deinit(self: *ConnPool) void {
        for (self.connections) |*c| closeConn(c);
    }

    /// Claim the first free connection. Does not block: returns
    /// error.NoConnectionsLeft when every connection is in use.
    pub fn acquire(self: *ConnPool) AcquireError!Db {
        for (self.connections) |*c| {
            if (tryAcquire(c)) return Db{ .conn = c };
        }
        return error.NoConnectionsLeft;
    }

    fn tryAcquire(conn: *Conn) bool {
        const acquired = !conn.in_use.swap(true, .AcqRel);
        if (acquired) {
            if (conn.current_tx_level != 0) @panic("Transaction still open on unused db connection");
            return true;
        }
        return false;
    }

    fn createConn(self: *ConnPool) OpenError!Conn {
        return switch (self.config) {
            .postgres => |postgres_cfg| Conn{
                .engine = .{
                    .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str),
                },
            },
            .sqlite => |lite_cfg| Conn{
                .engine = .{
                    .sqlite = if (lite_cfg.sqlite_is_uri)
                        try sqlite.Db.openUri(lite_cfg.sqlite_file_path)
                    else
                        try sqlite.Db.open(lite_cfg.sqlite_file_path),
                },
            },
        };
    }

    fn closeConn(conn: *Conn) void {
        if (conn.in_use.loadUnchecked()) @panic("DB Conn still open");
        switch (conn.engine) {
            .postgres => |pg| pg.close(),
            .sqlite => |lite| lite.close(),
        }
    }
};

/// Root (non-transactional) database handle obtained from ConnPool.acquire.
pub const Db = Tx(0);

/// When tx_level == 0, the DB is operating in "implied transaction" mode where
/// every command is its own transaction
/// When tx_level >= 1, the DB has an explicit transaction open
/// When tx_level >= 2, the DB has (tx_level - 1) levels of transaction savepoints open
/// (like nested transactions)
fn Tx(comptime tx_level: u8) type {
    return struct {
        const Self = @This();

        // Savepoint names are derived from the nesting level: save_1, save_2, ...
        const savepoint_name = if (tx_level == 0)
            @compileError("Transaction not started")
        else
            std.fmt.comptimePrint("save_{}", .{tx_level});
        const next_savepoint_name = Tx(tx_level + 1).savepoint_name;

        conn: *ConnPool.Conn,

        /// The type of SQL engine being used. Use of this function should be discouraged
        pub fn sqlEngine(self: Self) Engine {
            return self.conn.engine;
        }

        /// Return the connection to the pool. Only valid on the root handle (Db).
        /// Any transaction still open on the connection is rolled back first.
        pub fn releaseConnection(self: Self) void {
            if (tx_level != 0) @compileError("close must be called on root db");
            if (self.conn.current_tx_level != 0) {
                std.log.warn("Database released while transaction in progress!", .{});
                // Best-effort, but not silent: if the rollback fails,
                // current_tx_level stays non-zero and the next
                // ConnPool.tryAcquire of this connection will panic, so record why.
                self.rollbackUnchecked() catch |err| {
                    std.log.err("Failed to roll back transaction on released connection: {}", .{err});
                };
                // TODO: Burn database connection
            }

            if (!self.conn.in_use.swap(false, .AcqRel)) @panic("Double close on db conection");
        }

        // ********* Transaction management functions **********

        /// Start an explicit transaction
        pub fn begin(self: Self) !Tx(1) {
            if (tx_level != 0) @compileError("Transaction already started");
            if (self.conn.current_tx_level != 0) return error.BadTransactionState;

            try self.exec("BEGIN", {}, null);

            self.conn.current_tx_level = 1;

            return Tx(1){ .conn = self.conn };
        }

        /// Create a savepoint (nested transaction)
        pub fn savepoint(self: Self) !Tx(tx_level + 1) {
            if (tx_level == 0) @compileError("Cannot place a savepoint on an implicit transaction");
            if (self.conn.current_tx_level != tx_level) return error.BadTransactionState;

            try self.exec("SAVEPOINT " ++ next_savepoint_name, {}, null);

            self.conn.current_tx_level = tx_level + 1;

            return Tx(tx_level + 1){ .conn = self.conn };
        }

        /// Commit the entire transaction
        pub fn commit(self: Self) !void {
            if (tx_level == 0) @compileError("Transaction not started");
            if (tx_level >= 2) @compileError("Cannot commit a savepoint");
            if (self.conn.current_tx_level == 0) return error.BadTransactionState;

            try self.exec("COMMIT", {}, null);

            self.conn.current_tx_level = 0;
        }

        /// Release the current savepoint and all savepoints created from it.
        pub fn release(self: Self) !void {
            if (tx_level == 0) @compileError("Transaction not started");
            if (tx_level == 1) @compileError("Cannot release a transaction");
            if (self.conn.current_tx_level < tx_level) return error.BadTransactionState;

            try self.exec("RELEASE SAVEPOINT " ++ savepoint_name, {}, null);

            self.conn.current_tx_level = tx_level - 1;
        }

        /// Rolls back the entire transaction
        pub fn rollbackTx(self: Self) !void {
            if (tx_level == 0) @compileError("Transaction not started");
            if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint");
            if (self.conn.current_tx_level == 0) return error.BadTransactionState;

            try self.rollbackUnchecked();

            self.conn.current_tx_level = 0;
        }

        /// Attempts to roll back to a savepoint
        // NOTE(review): in both PostgreSQL and SQLite, ROLLBACK TO leaves the
        // savepoint itself defined (it is not released here); it goes away
        // with the enclosing transaction.
        pub fn rollbackSavepoint(self: Self) !void {
            if (tx_level == 0) @compileError("Transaction not started");
            if (tx_level == 1) @compileError("Cannot rollback a savepoint on the entire transaction");
            if (self.conn.current_tx_level < tx_level) return error.BadTransactionState;

            try self.exec("ROLLBACK TO " ++ savepoint_name, {}, null);

            self.conn.current_tx_level = tx_level - 1;
        }

        /// Perform whichever rollback is appropriate for the situation
        /// (whole transaction at level 1, savepoint at level >= 2). Panics on failure.
        pub fn rollback(self: Self) void {
            (if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| {
                std.log.err("Failed to rollback transaction: {}", .{err});
                std.log.err("{any}", .{@errorReturnTrace()});
                @panic("TODO: more gracefully handle rollback failures");
            };
        }

        /// Handle type produced by beginOrSavepoint.
        pub const BeginOrSavepoint = Tx(tx_level + 1);
        /// begin() on the root handle, savepoint() inside a transaction.
        pub const beginOrSavepoint = if (tx_level == 0) begin else savepoint;
        /// commit() for a transaction, release() for a savepoint.
        pub const commitOrRelease = if (tx_level < 2) commit else release;

        /// Allows relaxing *some* constraints for the lifetime of the transaction.
        /// You should generally not do this, but it's useful when bootstrapping
        /// the initial admin community and cluster operator user.
        pub fn setConstraintMode(self: Self, mode: ConstraintMode) !void {
            if (tx_level == 0) @compileError("Transaction not started");
            if (tx_level >= 2) @compileError("Cannot set constraint mode on a savepoint");
            switch (self.sqlEngine()) {
                .sqlite => try self.exec(
                    switch (mode) {
                        .immediate => "PRAGMA defer_foreign_keys = FALSE",
                        .deferred => "PRAGMA defer_foreign_keys = TRUE",
                    },
                    {},
                    null,
                ),
                .postgres => try self.exec(
                    switch (mode) {
                        .immediate => "SET CONSTRAINTS ALL IMMEDIATE",
                        .deferred => "SET CONSTRAINTS ALL DEFERRED",
                    },
                    {},
                    null,
                ),
            }
        }

        // ********** Query Helpers ************

        /// Runs a command without returning results
        pub fn exec(
            self: Self,
            sql: [:0]const u8,
            args: anytype,
            alloc: ?std.mem.Allocator,
        ) !void {
            try self.execInternal(sql, args, .{ .allocator = alloc }, true);
        }

        /// Like exec, but with full QueryOptions.
        pub fn execWithOptions(
            self: Self,
            sql: [:0]const u8,
            args: anytype,
            options: QueryOptions,
        ) !void {
            try self.execInternal(sql, args, options, true);
        }

        /// Start a query whose rows decode into RowType. The caller must call
        /// finish() on the returned Results.
        pub fn queryWithOptions(
            self: Self,
            comptime RowType: type,
            sql: [:0]const u8,
            args: anytype,
            options: QueryOptions,
        ) QueryError!Results(RowType) {
            return Results(RowType).from(try self.runSql(sql, args, options, true));
        }

        /// Like queryWithOptions, taking only an (optional) allocator.
        pub fn query(
            self: Self,
            comptime RowType: type,
            sql: [:0]const u8,
            args: anytype,
            alloc: ?Allocator,
        ) QueryError!Results(RowType) {
            return self.queryWithOptions(RowType, sql, args, .{ .allocator = alloc });
        }

        /// Runs a query to completion and returns a row of results, unless the query
        /// returned a different number of rows.
        pub fn queryRow(
            self: Self,
            comptime RowType: type,
            q: [:0]const u8,
            args: anytype,
            alloc: ?Allocator,
        ) QueryRowError!RowType {
            var results = try self.query(RowType, q, args, alloc);
            defer results.finish();

            const row = (try results.row(alloc)) orelse return error.NoRows;
            errdefer util.deepFree(alloc, row);

            // execute query to completion
            var more_rows = false;
            while (try results.row(alloc)) |r| {
                util.deepFree(alloc, r);
                more_rows = true;
            }
            if (more_rows) return error.TooManyRows;

            return row;
        }

        /// Like queryRowsWithOptions, taking only an allocator.
        pub fn queryRows(
            self: Self,
            comptime RowType: type,
            q: [:0]const u8,
            args: anytype,
            max_items: ?usize,
            alloc: std.mem.Allocator,
        ) QueryRowError![]RowType {
            return try self.queryRowsWithOptions(RowType, q, args, max_items, .{ .allocator = alloc });
        }

        /// Runs a query to completion and returns the results as a slice owned by
        /// the caller. Requires options.allocator. If more than max_items rows are
        /// produced, the query is still drained and error.TooManyRows is returned.
        pub fn queryRowsWithOptions(
            self: Self,
            comptime RowType: type,
            q: [:0]const u8,
            args: anytype,
            max_items: ?usize,
            options: QueryOptions,
        ) QueryRowError![]RowType {
            var results = try self.queryWithOptions(RowType, q, args, options);
            defer results.finish();

            const alloc = options.allocator orelse return error.AllocatorRequired;

            var result_array = std.ArrayList(RowType).init(alloc);
            errdefer result_array.deinit();
            if (max_items) |max| try result_array.ensureTotalCapacity(max);

            errdefer for (result_array.items) |r| util.deepFree(alloc, r);

            var too_many: bool = false;
            while (try results.row(alloc)) |row| {
                errdefer util.deepFree(alloc, row);
                if (max_items) |max| {
                    if (result_array.items.len >= max) {
                        util.deepFree(alloc, row);
                        too_many = true;
                        continue;
                    }
                }

                try result_array.append(row);
            }

            if (too_many) return error.TooManyRows;

            return result_array.toOwnedSlice();
        }

        /// Inserts a single value into a table. Each field of `value` becomes a
        /// column of the same name, bound as a positional parameter ($1, $2, ...).
        pub fn insert(
            self: Self,
            comptime table: []const u8,
            value: anytype,
            alloc: ?std.mem.Allocator,
        ) !void {
            const ValueType = comptime @TypeOf(value);

            const fields = std.meta.fields(ValueType);
            comptime var types: [fields.len]type = undefined;
            comptime var table_spec: []const u8 = table ++ "(";
            comptime var value_spec: []const u8 = "(";
            inline for (fields) |field, i| {
                // This causes a compiler crash. Why?
                //const F = field.field_type;
                const F = @TypeOf(@field(value, field.name));
                // causes issues if F is @TypeOf(null), use dummy type
                types[i] = if (F == @TypeOf(null)) ?i64 else F;
                table_spec = comptime (table_spec ++ field.name ++ ",");
                value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1});
            }
            // Replace the trailing "," of each list with ")".
            table_spec = comptime table_spec[0 .. table_spec.len - 1] ++ ")";
            value_spec = comptime value_spec[0 .. value_spec.len - 1] ++ ")";
            const q = comptime std.fmt.comptimePrint(
                "INSERT INTO {s} VALUES {s}",
                .{ table_spec, value_spec },
            );

            var args_tuple: std.meta.Tuple(&types) = undefined;
            inline for (fields) |field, i| {
                args_tuple[i] = @field(value, field.name);
            }
            try self.exec(q, args_tuple, alloc);
        }

        // internal helpers

        // Dispatch `sql` to the active engine. When check_tx is set, refuse to run
        // unless this handle's level matches the connection's current level
        // (i.e. the handle has not been superseded or closed).
        fn runSql(
            self: Self,
            sql: [:0]const u8,
            args: anytype,
            opt: QueryOptions,
            comptime check_tx: bool,
        ) !RawResults {
            if (check_tx and self.conn.current_tx_level != tx_level) return error.BadTransactionState;

            return switch (self.conn.engine) {
                .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) },
                .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) },
            };
        }

        // Run `sql`, discarding every row so the statement runs to completion;
        // the raw results are always finished.
        fn execInternal(
            self: Self,
            sql: [:0]const u8,
            args: anytype,
            options: QueryOptions,
            comptime check_tx: bool,
        ) !void {
            var results = try self.runSql(sql, args, options, check_tx);
            defer results.finish();

            while (try results.row()) |_| {}
        }

        // Issue ROLLBACK without the tx-level check; resets the level only on success.
        fn rollbackUnchecked(self: Self) !void {
            try self.execInternal("ROLLBACK", {}, .{}, false);
            self.conn.current_tx_level = 0;
        }
    };
}