diff --git a/build.zig b/build.zig index ada7f9a..df7c55d 100644 --- a/build.zig +++ b/build.zig @@ -118,9 +118,7 @@ pub fn build(b: *std.build.Builder) !void { const unittest_sql = b.addTest("src/sql/lib.zig"); unittest_sql_cmd.dependOn(&unittest_sql.step); unittest_sql.addPackage(unittest_pkgs.util); - unittest_sql.addSystemIncludePath("/nix/store/c2vkxzbqb3z5220bsgdw1s0kasg61lry-sqlite-3.41.2-dev/include/"); // TODO: why - unittest_sql.linkSystemLibrary("sqlite3"); - unittest_sql.linkLibC(); + //unittest_sql.linkLibC(); const unittest_template_cmd = b.step("unit:template", "Run tests for template package"); const unittest_template = b.addTest("src/template/lib.zig"); diff --git a/default.nix b/default.nix deleted file mode 100644 index a5cfb51..0000000 --- a/default.nix +++ /dev/null @@ -1,10 +0,0 @@ -{ }: - -let pkgs = import { }; - -in pkgs.stdenv.mkDerivation { - name = "fediglam"; - src = ./.; - - buildInputs = with pkgs; [ zig postgresql sqlite ]; -} diff --git a/src/api/methods/auth.zig b/src/api/methods/auth.zig index f56d88a..0d085d6 100644 --- a/src/api/methods/auth.zig +++ b/src/api/methods/auth.zig @@ -17,8 +17,6 @@ pub fn register( alloc: std.mem.Allocator, ctx: ApiContext, svcs: anytype, - username: []const u8, - password: []const u8, opt: RegistrationOptions, ) !Uuid { const tx = try svcs.beginTx(); @@ -46,8 +44,8 @@ pub fn register( alloc, tx, .{ - .username = username, - .password = password, + .username = opt.username, + .password = opt.password, .community_id = ctx.community.id, .invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null, .email = opt.email, @@ -358,9 +356,9 @@ test "register" { allocator, .{ .community = community }, &svc, - test_args.username, - test_args.password, .{ + .username = test_args.username, + .password = test_args.password, .invite_code = if (test_args.use_invite) test_invite_code else null, }, // shortcut out of memory errors to test allocation diff --git a/src/api/types.zig b/src/api/types.zig index e072f3d..390306a 100644 --- a/src/api/types.zig +++ b/src/api/types.zig @@ -15,6 +15,8 @@ fn QueryResult(comptime R: type, comptime A: type) type { pub const auth = struct { pub const RegistrationOptions = struct { + username: []const u8, + password: []const u8, invite_code: ?[]const u8 = null, email: ?[]const u8 = null, }; diff --git a/src/main/controllers/api/users.zig b/src/main/controllers/api/users.zig index 3bf0190..321b9ba 100644 --- a/src/main/controllers/api/users.zig +++ b/src/main/controllers/api/users.zig @@ -14,10 +14,12 @@ pub const create = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const options = .{ + .username = req.body.username, + .password = req.body.password, .invite_code = req.body.invite_code, .email = req.body.email, }; - const user = srv.register(req.body.username, req.body.password, options) catch |err| switch (err) { + const user = srv.register(options) catch |err| switch (err) { error.UsernameTaken => return res.err(.unprocessable_entity, "Username Unavailable", {}), else => return err, }; diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index 1c34ae6..99d1f50 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -1,270 +1,14 @@ -const sql = @import("../lib.zig"); const std = @import("std"); +const util = @import("util"); +const common = @import("./common.zig"); const c = @cImport({ @cInclude("sqlite3.h"); }); -const QueryOptions = sql.QueryOptions; -const ExecError = sql.ExecError; -const SqlValue = sql.SqlValue; -const Results = sql.Results; - +const Uuid = util.Uuid; +const DateTime = util.DateTime; const Allocator = std.mem.Allocator; -pub const Engine = struct { - conn: *c.sqlite3, - - pub fn open(path: [:0]const u8) !Engine { - return openInternal(path, false); - } - - pub fn openUri(path: [:0]const u8) !Engine { - return openInternal(path, true); - } - - fn openInternal(path: [:0]const u8, is_uri: bool) !Engine { - const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE | if (is_uri) c.SQLITE_OPEN_URI else 0; - - var conn: ?*c.sqlite3 = null; - switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &conn, flags, null)) { - c.SQLITE_OK => {}, - else => |code| { - if (conn == null) { - // this path should only be hit if out of memory, but log it anyways - std.log.err( - "Unable to open SQLite DB \"{s}\". Error: {?s} ({})", - .{ path, c.sqlite3_errstr(code), code }, - ); - return error.BadConnection; - } - - const ext_code = c.sqlite3_extended_errcode(conn); - std.log.err( - \\Unable to open SQLite DB "{s}". Error: {?s} ({}) - \\Details: {?s} - , - .{ path, c.sqlite3_errstr(ext_code), ext_code, c.sqlite3_errmsg(conn) }, - ); - - return error.Unexpected; - }, - } - - return Engine{ - .conn = conn.?, - }; - } - - pub fn db(self: *Engine) sql.Db { - return .{ - .ptr = self, - .vtable = &.{ - .exec = Engine.exec, - }, - }; - } - - pub fn close(self: *Engine) void { - switch (c.sqlite3_close(self.conn)) { - c.SQLITE_OK => {}, - - c.SQLITE_BUSY => { - std.log.err("SQLite DB could not be closed as it is busy.\n{s}", .{c.sqlite3_errmsg(self.conn)}); - }, - - else => |err| { - std.log.err("Could not close SQLite DB", .{}); - handleUnexpectedError(self.conn, err, null) catch {}; - }, - } - } - - fn exec( - ctx: *anyopaque, - cmd: []const u8, - args: []const SqlValue, - opts: QueryOptions, - _: Allocator, - ) ExecError!Results { - const self = @ptrCast(*Engine, @alignCast(@alignOf(Engine), ctx)); - - var stmt: ?*c.sqlite3_stmt = undefined; - switch (c.sqlite3_prepare_v2(self.conn, cmd.ptr, @intCast(c_int, cmd.len), &stmt, null)) { - c.SQLITE_OK => {}, - else => |err| return handleUnexpectedError(self.conn, err, sql), - } - errdefer switch (c.sqlite3_finalize(stmt)) { - c.SQLITE_OK => {}, - else => |err| { - handleUnexpectedError(self.conn, err, sql) catch {}; - }, - }; - - if (args.len != 0) { - for (args) |arg, i| { - const buf_size = 21; // ceil(log10(2^64)) + 1 - var name_buf: [buf_size]u8 = undefined; - const name = std.fmt.bufPrintZ(&name_buf, "{}", .{i}) catch unreachable; - - const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name.ptr); - if (db_idx != 0) - try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg) - else if (!opts.ignore_unused_arguments) - return error.UnusedArgument; - } - } - - @panic("unimplemented"); - } - - fn bindArgument(self: *Engine, stmt: *c.sqlite3_stmt, idx: u15, val: SqlValue) !void { - return switch (val) { - .int => |v| self.bindInt(stmt, idx, i64, v), - .uint => |v| self.bindInt(stmt, idx, u64, v), - .str => |str| { - const len = std.math.cast(c_int, str.len) orelse { - std.log.err("SQLite: string len {} too large", .{str.len}); - return error.BindException; - }; - - switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) { - c.SQLITE_OK => {}, - else => |result| { - std.log.err("SQLite: Unable to bind string to index {}", .{idx}); - std.log.debug("SQLite: {s}", .{str}); - return handleUnexpectedError(self.conn, result, null); - }, - } - }, - .float => |v| { - switch (c.sqlite3_bind_double(stmt, idx, v)) { - c.SQLITE_OK => {}, - else => |result| { - std.log.err("SQLite: Unable to bind float to index {}", .{idx}); - std.log.debug("SQLite: {}", .{v}); - return handleUnexpectedError(self.conn, result, null); - }, - } - }, - .@"null" => { - switch (c.sqlite3_bind_null(stmt, idx)) { - c.SQLITE_OK => {}, - else => |result| { - std.log.err("SQLite: Unable to bind NULL to index {}", .{idx}); - return handleUnexpectedError(self.conn, result, null); - }, - } - }, - }; - } - - fn bindInt(self: *Engine, stmt: *c.sqlite3_stmt, idx: u15, comptime T: type, val: T) !void { - const v = std.math.cast(i64, val) orelse { - std.log.err("SQLite: integer {} does not fit within i64", .{val}); - return error.BindException; - }; - switch (c.sqlite3_bind_int64(stmt, idx, v)) { - c.SQLITE_OK => {}, - else => |result| { - std.log.err("SQLite: Unable to bind int to index {}", .{idx}); - std.log.debug("SQLite: {}", .{v}); - return handleUnexpectedError(self.conn, result, null); - }, - } - } -}; - -const SqliteResults = struct { - stmt: *c.sqlite3_stmt, - conn: *c.sqlite3, - allocator: std.mem.Allocator, - - fn results(self: *SqliteResults) Results { - return .{ - .ptr = self, - .vtable = &.{}, - }; - } - - fn finish(ctx: *anyopaque) void { - const self = @ptrCast(*SqliteResults, @alignCast(@alignOf(SqliteResults), ctx)); - - _ = c.sqlite3_finalize(self.stmt); - self.allocator.destroy(self); - } - - fn row(ctx: *anyopaque) sql.RowError!?sql.Row { - const self = @ptrCast(*SqliteResults, @alignCast(@alignOf(SqliteResults), ctx)); - - switch (c.sqlite3_step(self.stmt)) { - c.SQLITE_ROW => {}, - - c.SQLITE_DONE => return null, - - c.SQLITE_CONSTRAINT_UNIQUE => return error.UniqueViolation, - c.SQLITE_CONSTRAINT_CHECK => return error.CheckViolation, - c.SQLITE_CONSTRAINT_NOTNULL => return error.NotNullViolation, - c.SQLITE_CONSTRAINT_FOREIGNKEY => return error.ForeignKeyViolation, - c.SQLITE_CONSTRAINT => return error.ConstraintViolation, - - else => |err| return handleUnexpectedError(self.db, err, self.getGeneratingSql()), - } - - return sql.Row{ - .ptr = self, - .vtable = &.{}, - }; - } - - fn columnCount(ctx: *anyopaque) sql.ColumnCountError!sql.ColumnIndex { - const self = @ptrCast(*SqliteResults, @alignCast(@alignOf(SqliteResults), ctx)); - - return @intCast(sql.ColumnIndex, c.sqlite3_column_count(self.stmt)); - } - - fn columnIndex(ctx: *anyopaque, name: []const u8) sql.ColumnIndexError!sql.ColumnIndex { - const self = @ptrCast(*SqliteResults, @alignCast(@alignOf(SqliteResults), ctx)); - - var i: u15 = 0; - const count = try self.columnCount(); - while (i < count) : (i += 1) { - const column = try self.columnName(i); - if (std.mem.eql(u8, name, column)) return i; - } - - return error.NotFound; - } - - fn isNull(ctx: *anyopaque, idx: sql.ColumnIndex) sql.GetError!bool { - const self = @ptrCast(*SqliteResults, @alignCast(@alignOf(SqliteResults), ctx)); - - return c.sqlite3_column_type(self.stmt, idx) == c.SQLITE_NULL; - } - - fn getStr(ctx: *anyopaque, idx: sql.ColumnIndex) sql.GetError![]const u8 { - const self = @ptrCast(*SqliteResults, @alignCast(@alignOf(SqliteResults), ctx)); - - if (c.sqlite3_column_text(self.stmt, idx)) |ptr| { - const size = @intCast(usize, c.sqlite3_column_bytes(self.stmt, idx)); - const str = std.mem.sliceTo(ptr[0..size], 0); - - return str; - } else { - std.log.err("SQLite column {}: TEXT value requested but null pointer returned (out of memory?)", .{idx}); - return error.Unexpected; - } - } - - fn getFloat(ctx: *anyopaque, idx: sql.ColumnIndex) sql.getError!f64 {} - - fn columnName(self: Results, idx: u15) ![]const u8 { - return if (c.sqlite3_column_name(self.stmt, idx)) |ptr| - ptr[0..std.mem.len(ptr)] - else - unreachable; - } -}; - fn getCharPos(text: []const u8, offset: c_int) struct { row: usize, col: usize } { var row: usize = 0; var col: usize = 0; @@ -289,14 +33,352 @@ fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) err std.log.debug("Additional details:", .{}); std.log.debug("{?s}", .{c.sqlite3_errmsg(db)}); - if (sql_text) |text| { + if (sql_text) |sql| { const byte_offset = c.sqlite3_error_offset(db); if (byte_offset >= 0) { - const pos = getCharPos(text, byte_offset); - std.log.debug("Failed at char ({}:{}) of SQL:\n{s}", .{ pos.row, pos.col, text }); + const pos = getCharPos(sql, byte_offset); + std.log.debug("Failed at char ({}:{}) of SQL:\n{s}", .{ pos.row, pos.col, sql }); } } std.log.debug("{?}", .{@errorReturnTrace()}); return error.Unexpected; } + +pub const Db = struct { + db: *c.sqlite3, + + pub fn open(path: [:0]const u8) common.OpenError!Db { + return openInternal(path, false); + } + + pub fn openUri(path: [:0]const u8) common.OpenError!Db { + return openInternal(path, true); + } + + fn openInternal(path: [:0]const u8, is_uri: bool) common.OpenError!Db { + const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE | if (is_uri) c.SQLITE_OPEN_URI else 0; + + var db: ?*c.sqlite3 = null; + switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) { + c.SQLITE_OK => {}, + else => |code| { + if (db == null) { + // this path should only be hit if out of memory, but log it anyways + std.log.err( + "Unable to open SQLite DB \"{s}\". Error: {?s} ({})", + .{ path, c.sqlite3_errstr(code), code }, + ); + return error.BadConnection; + } + + const ext_code = c.sqlite3_extended_errcode(db); + std.log.err( + \\Unable to open SQLite DB "{s}". Error: {?s} ({}) + \\Details: {?s} + , + .{ path, c.sqlite3_errstr(ext_code), ext_code, c.sqlite3_errmsg(db) }, + ); + + return error.Unexpected; + }, + } + + return Db{ + .db = db.?, + }; + } + + pub fn close(self: Db) void { + switch (c.sqlite3_close(self.db)) { + c.SQLITE_OK => {}, + + c.SQLITE_BUSY => { + std.log.err("SQLite DB could not be closed as it is busy.\n{s}", .{c.sqlite3_errmsg(self.db)}); + }, + + else => |err| { + std.log.err("Could not close SQLite DB", .{}); + handleUnexpectedError(self.db, err, null) catch {}; + }, + } + } + + pub fn exec(self: Db, sql: []const u8, args: anytype, opts: common.QueryOptions) common.ExecError!Results { + var stmt: ?*c.sqlite3_stmt = undefined; + switch (c.sqlite3_prepare_v2(self.db, sql.ptr, @intCast(c_int, sql.len), &stmt, null)) { + c.SQLITE_OK => {}, + else => |err| return handleUnexpectedError(self.db, err, sql), + } + errdefer switch (c.sqlite3_finalize(stmt)) { + c.SQLITE_OK => {}, + else => |err| { + handleUnexpectedError(self.db, err, sql) catch {}; + }, + }; + + if (@TypeOf(args) != void) { + // TODO: Fix for stage1 compiler + //inline for (args) |arg, i| { + inline for (std.meta.fields(@TypeOf(args))) |field, i| { + const arg = @field(args, field.name); + // SQLite treats $NNN args as having the name NNN, not index NNN. + // As such, if you reference $2 and not $1 in your query (such as + // when dynamically constructing queries), it could assign $2 the + // index 1. So we can't assume the index according to the 1-indexed + // arg array is equivalent to the param to bind it to. + // We can, however, look up the exact index to bind to. + // If the argument is not used in the query, then it will have an "index" + // of 0, and we must not bind the argument. + const name = std.fmt.comptimePrint("${}", .{i + 1}); + const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); + if (db_idx != 0) + try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg) + else if (!opts.ignore_unused_arguments) + return error.UnusedArgument; + } + } + + return Results{ .stmt = stmt.?, .db = self.db }; + } + + fn bindArgument(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: anytype) !void { + if (comptime std.meta.trait.isZigString(@TypeOf(val))) { + return self.bindString(stmt, idx, val); + } + + const T = @TypeOf(val); + + switch (@typeInfo(T)) { + .Union => { + const arr = if (@hasDecl(T, "toCharArray")) + val.toCharArray() + else if (@hasDecl(T, "toCharArrayZ")) + val.toCharArrayZ() + else { + inline for (std.meta.fields(T)) |field| { + const Tag = std.meta.Tag(T); + const tag = @field(Tag, field.name); + + if (val == tag) return try self.bindArgument(stmt, idx, @field(val, field.name)); + } + unreachable; + }; + + const len = std.mem.len(&arr); + return self.bindString(stmt, idx, arr[0..len]); + }, + + .Struct => { + const arr = if (@hasDecl(T, "toCharArray")) + val.toCharArray() + else if (@hasDecl(T, "toCharArrayZ")) + val.toCharArrayZ() + else + @compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string"); + + const len = std.mem.len(&arr); + return self.bindString(stmt, idx, arr[0..len]); + }, + .Enum => |info| { + const name = if (info.is_exhaustive) + @tagName(val) + else + @compileError("SQLite: Could not serialize non-exhaustive enum " ++ @typeName(T) ++ " into string"); + return self.bindString(stmt, idx, name); + }, + .Optional => { + return if (val) |v| self.bindArgument(stmt, idx, v) else self.bindNull(stmt, idx); + }, + .Null => return self.bindNull(stmt, idx), + .Int => return self.bindInt(stmt, idx, std.math.cast(i64, val) orelse unreachable), + .Float => return self.bindFloat(stmt, idx, val), + .Bool => return self.bindInt(stmt, idx, if (val) 1 else 0), + else => @compileError("Unable to serialize type " ++ @typeName(T)), + } + } + + fn bindString(self: Db, stmt: *c.sqlite3_stmt, idx: u15, str: []const u8) !void { + const len = std.math.cast(c_int, str.len) orelse { + std.log.err("SQLite: string len {} too large", .{str.len}); + return error.BindException; + }; + + switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) { + c.SQLITE_OK => {}, + else => |result| { + std.log.err("SQLite: Unable to bind string to index {}", .{idx}); + std.log.debug("SQLite: {s}", .{str}); + return handleUnexpectedError(self.db, result, null); + }, + } + } + + fn bindNull(self: Db, stmt: *c.sqlite3_stmt, idx: u15) !void { + switch (c.sqlite3_bind_null(stmt, idx)) { + c.SQLITE_OK => {}, + else => |result| { + std.log.err("SQLite: Unable to bind NULL to index {}", .{idx}); + return handleUnexpectedError(self.db, result, null); + }, + } + } + + fn bindInt(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: i64) !void { + switch (c.sqlite3_bind_int64(stmt, idx, val)) { + c.SQLITE_OK => {}, + else => |result| { + std.log.err("SQLite: Unable to bind int to index {}", .{idx}); + std.log.debug("SQLite: {}", .{val}); + return handleUnexpectedError(self.db, result, null); + }, + } + } + + fn bindFloat(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: f64) !void { + switch (c.sqlite3_bind_double(stmt, idx, val)) { + c.SQLITE_OK => {}, + else => |result| { + std.log.err("SQLite: Unable to bind float to index {}", .{idx}); + std.log.debug("SQLite: {}", .{val}); + return handleUnexpectedError(self.db, result, null); + }, + } + } +}; + +pub const Results = struct { + stmt: *c.sqlite3_stmt, + db: *c.sqlite3, + + pub fn finish(self: Results) void { + _ = c.sqlite3_finalize(self.stmt); + } + + pub fn row(self: Results) common.RowError!?Row { + return switch (c.sqlite3_step(self.stmt)) { + c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db }, + c.SQLITE_DONE => null, + + c.SQLITE_CONSTRAINT_UNIQUE => return error.UniqueViolation, + c.SQLITE_CONSTRAINT_CHECK => return error.CheckViolation, + c.SQLITE_CONSTRAINT_NOTNULL => return error.NotNullViolation, + c.SQLITE_CONSTRAINT_FOREIGNKEY => return error.ForeignKeyViolation, + c.SQLITE_CONSTRAINT => return error.ConstraintViolation, + + else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()), + }; + } + + fn getGeneratingSql(self: Results) ?[]const u8 { + const ptr = c.sqlite3_sql(self.stmt) orelse return null; + return ptr[0..std.mem.len(ptr)]; + } + + pub fn columnCount(self: Results) common.ColumnCountError!u15 { + return @intCast(u15, c.sqlite3_column_count(self.stmt)); + } + + fn columnName(self: Results, idx: u15) ![]const u8 { + return if (c.sqlite3_column_name(self.stmt, idx)) |ptr| + ptr[0..std.mem.len(ptr)] + else + unreachable; + } + + pub fn columnIndex(self: Results, name: []const u8) common.ColumnIndexError!u15 { + var i: u15 = 0; + const count = try self.columnCount(); + while (i < count) : (i += 1) { + const column = try self.columnName(i); + if (std.mem.eql(u8, name, column)) return i; + } + + return error.NotFound; + } +}; + +pub const Row = struct { + stmt: *c.sqlite3_stmt, + db: *c.sqlite3, + + pub fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { + return getColumn(self.stmt, T, idx, alloc); + } +}; + +fn getColumn(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { + const Eff = if (comptime std.meta.trait.is(.Optional)(T)) std.meta.Child(T) else T; + return switch (c.sqlite3_column_type(stmt, idx)) { + c.SQLITE_INTEGER => try getColumnInt(stmt, Eff, idx), + c.SQLITE_FLOAT => try getColumnFloat(stmt, Eff, idx), + c.SQLITE_TEXT => try getColumnText(stmt, Eff, idx, alloc), + c.SQLITE_NULL => { + if (T == DateTime) { + std.log.warn("SQLite: Treating NULL as DateTime epoch", .{}); + return std.mem.zeroes(DateTime); + } + + if (@typeInfo(T) != .Optional) { + std.log.err("SQLite column {}: Expected value of type {}, got (null)", .{ idx, T }); + return error.ResultTypeMismatch; + } + + return null; + }, + c.SQLITE_BLOB => { + std.log.err("SQLite column {}: SQLite value had unsupported storage class BLOB", .{idx}); + return error.ResultTypeMismatch; + }, + else => |class| { + std.log.err("SQLite column {}: SQLite value had unknown storage class {}", .{ idx, class }); + return error.ResultTypeMismatch; + }, + }; +} + +fn getColumnInt(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15) common.GetError!T { + const val: i64 = c.sqlite3_column_int64(stmt, idx); + + switch (T) { + DateTime => return DateTime{ .seconds_since_epoch = val }, + else => switch (@typeInfo(T)) { + .Int => if (std.math.cast(T, val)) |v| return v else { + std.log.err("SQLite column {}: Expected value of type {}, got {} (outside of range)", .{ idx, T, val }); + return error.ResultTypeMismatch; + }, + .Bool => if (val == 0) return false else return true, + else => { + std.log.err("SQLite column {}: Storage class INT cannot be parsed into type {}", .{ idx, T }); + return error.ResultTypeMismatch; + }, + }, + } +} + +fn getColumnFloat(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15) common.GetError!T { + const val: f64 = c.sqlite3_column_double(stmt, idx); + switch (T) { + // Only support floats that fit in range for now + f16, f32, f64 => return @floatCast(T, val), + DateTime => return DateTime{ + .seconds_since_epoch = std.math.lossyCast(i64, val * @intToFloat(f64, std.time.epoch.secs_per_day)), + }, + else => { + std.log.err("SQLite column {}: Storage class FLOAT cannot be parsed into type {}", .{ idx, T }); + return error.ResultTypeMismatch; + }, + } +} + +fn getColumnText(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { + if (c.sqlite3_column_text(stmt, idx)) |ptr| { + const size = @intCast(usize, c.sqlite3_column_bytes(stmt, idx)); + const str = std.mem.sliceTo(ptr[0..size], 0); + + return common.parseValueNotNull(alloc, T, str); + } else { + std.log.err("SQLite column {}: TEXT value stored but engine returned null pointer (out of memory?)", .{idx}); + return error.ResultTypeMismatch; + } +} diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 2acc7ee..bb55dd3 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -1,110 +1,726 @@ const std = @import("std"); +const util = @import("util"); +const build_options = @import("build_options"); + +const postgres = if (build_options.enable_postgres) + @import("./engines/postgres.zig") +else + @import("./engines/null.zig"); + +const sqlite = if (build_options.enable_sqlite) + @import("./engines/sqlite.zig") +else + @import("./engines/null.zig"); +const common = @import("./engines/common.zig"); const Allocator = std.mem.Allocator; -pub const SqlValue = union(enum) { - int: i64, - uint: u64, - str: []const u8, - @"null": void, - float: f64, +const errors = @import("./errors.zig").library_errors; + +pub const AcquireError = OpenError || error{NoConnectionsLeft}; +pub const OpenError = errors.OpenError; +pub const QueryError = errors.QueryError; +pub const RowError = errors.RowError; +pub const QueryRowError = errors.QueryRowError; +pub const BeginError = errors.BeginError; +pub const CommitError = errors.CommitError; + +pub const DatabaseError = QueryError || RowError || QueryRowError || BeginError || CommitError; + +pub const QueryOptions = common.QueryOptions; + +pub const Engine = enum { + postgres, + sqlite, }; -pub const QueryOptions = struct { - // If true, then it will not return an error on the SQLite backend - // if an argument passed does not map to a parameter in the query. - // Has no effect on the postgres backend. - ignore_unused_arguments: bool = false, +/// Helper for building queries at runtime. All constituent parts of the +/// query should be defined at comptime, however the choice of whether +/// or not to include them can occur at runtime. +pub const QueryBuilder = struct { + array: std.ArrayList(u8), + where_clauses_appended: usize = 0, + set_statements_appended: usize = 0, + + pub fn init(alloc: std.mem.Allocator) QueryBuilder { + return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) }; + } + + pub fn deinit(self: *const QueryBuilder) void { + self.array.deinit(); + } + + /// Add a chunk of sql to the query without processing + pub fn appendSlice(self: *QueryBuilder, comptime sql: []const u8) !void { + try self.array.appendSlice(sql); + } + + /// Add a where clause to the query. Clauses are assumed to be components + /// in an overall expression in Conjunctive Normal Form (AND of OR's). + /// https://en.wikipedia.org/wiki/Conjunctive_normal_form + /// All calls to andWhere must be contiguous, that is, they cannot be + /// interspersed with calls to appendSlice + pub fn andWhere(self: *QueryBuilder, comptime clause: []const u8) !void { + if (self.where_clauses_appended == 0) { + try self.array.appendSlice("\nWHERE "); + } else { + try self.array.appendSlice(" AND "); + } + + try self.array.appendSlice(clause); + self.where_clauses_appended += 1; + } + + pub fn set(self: *QueryBuilder, comptime col: []const u8, comptime val: []const u8) !void { + if (self.set_statements_appended == 0) { + try self.array.appendSlice("\nSET "); + } else { + try self.array.appendSlice(", "); + } + + try self.array.appendSlice(col ++ " = " ++ val); + self.set_statements_appended += 1; + } + + pub fn str(self: *const QueryBuilder) []const u8 { + return self.array.items; + } + + pub fn terminate(self: *QueryBuilder) ![:0]const u8 { + std.debug.assert(self.array.items.len != 0); + if (self.array.items[self.array.items.len - 1] != 0) try self.array.append(0); + + return std.meta.assumeSentinel(self.array.items, 0); + } }; -pub const UnexpectedError = error{Unexpected}; -pub const ConstraintError = error{ - NotNullViolation, - ForeignKeyViolation, - UniqueViolation, - CheckViolation, - - /// Catchall for miscellaneous types of constraints - ConstraintViolation, +// TODO: make this suck less +pub const Config = union(Engine) { + postgres: struct { + pg_conn_str: [:0]const u8, + }, + sqlite: struct { + sqlite_file_path: [:0]const u8, + sqlite_is_uri: bool = false, + }, }; -pub const ExecError = error{ - Cancelled, - BadConnection, - InternalException, - DatabaseBusy, - PermissionDenied, - SqlException, +const RawResults = union(Engine) { + postgres: postgres.Results, + sqlite: sqlite.Results, - /// Argument could not be marshalled for query - BindException, + fn finish(self: RawResults) void { + switch (self) { + .postgres => |pg| pg.finish(), + .sqlite => |lite| lite.finish(), + } + } - /// An argument was not used by the query (not checked in all DB engines) - UnusedArgument, + fn columnCount(self: RawResults) !u15 { + return try switch (self) { + .postgres => |pg| pg.columnCount(), + .sqlite => |lite| lite.columnCount(), + }; + } - /// Memory error when marshalling argument for query - OutOfMemory, - AllocatorRequired, -} || ConstraintError || UnexpectedError; + fn columnIndex(self: RawResults, name: []const u8) error{ NotFound, Unexpected }!u15 { + return switch (self) { + .postgres => |pg| pg.columnIndex(name), + .sqlite => |lite| lite.columnIndex(name), + } catch |err| switch (err) { + error.OutOfRange => error.Unexpected, + error.NotFound => error.NotFound, + }; + } -pub const Db = struct { - pub const VTable = struct { - /// Executes a SQL query. - exec: *const fn (ctx: *anyopaque, sql: []const u8, args: []const SqlValue, opt: QueryOptions, allocator: Allocator) ExecError!Results, - }; - - vtable: *const VTable, - ptr: *anyopaque, + fn row(self: *RawResults) RowError!?Row { + return switch (self.*) { + .postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else null, + .sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else null, + }; + } }; -pub const ColumnCountError = error{OutOfRange}; -pub const ColumnIndexError = error{ NotFound, OutOfRange }; -pub const ColumnIndex = u32; +const FieldRef = []const []const u8; +fn FieldPtr(comptime Ptr: type, comptime names: FieldRef) type { + if (names.len == 0) return Ptr; -pub const RowError = error{ - Cancelled, - BadConnection, - InternalException, - DatabaseBusy, - PermissionDenied, - SqlException, -} || ConstraintError || UnexpectedError; + const T = std.meta.Child(Ptr); -pub const Results = struct { - pub const VTable = struct { - columnCount: *const fn (ctx: *anyopaque) ColumnCountError!ColumnIndex, - columnIndex: *const fn (ctx: *anyopaque) ColumnIndexError!ColumnIndex, - row: *const fn (ctx: *anyopaque) RowError!?Row, - finish: *const fn (ctx: *anyopaque) void, - }; + const field = for (@typeInfo(T).Struct.fields) |f| { + if (std.mem.eql(u8, f.name, names[0])) break f; + } else @compileError("Unknown field " ++ names[0] ++ " in type " ++ @typeName(T)); - vtable: *const VTable, - ptr: *anyopaque, -}; - -pub const GetError = error{ - TypeMismatch, - InvalidIndex, -} || UnexpectedError; - -pub const Row = struct { - pub const VTable = struct { - isNull: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError!bool, - getStr: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError![]const u8, - getInt: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError!i64, - getUint: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError!u64, - getFloat: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError!f64, - }; - - vtable: *const VTable, - ptr: *anyopaque, -}; - -test "test" { - const backend = @import("./engines/sqlite.zig"); - var engine = try backend.Engine.open(":memory:"); - defer engine.close(); - - const db = engine.db(); - - _ = try db.vtable.exec(db.ptr, "CREATE TABLE foo(bar INT PRIMARY KEY);", &.{}, .{}, std.testing.allocator); + return FieldPtr(*field.field_type, names[1..]); +} + +fn fieldPtr(ptr: anytype, comptime names: FieldRef) FieldPtr(@TypeOf(ptr), names) { + if (names.len == 0) return ptr; + + return fieldPtr(&@field(ptr.*, names[0]), names[1..]); +} + +fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: anytype) []const FieldRef { + comptime { + if (std.meta.trait.is(.Union)(T) and prefix.len == 0 and options.embed_unions) { + @compileError("Cannot embed a union into nothing"); + } + + if (options.isScalar(T)) return &.{prefix}; + if (std.meta.trait.is(.Optional)(T)) return getRecursiveFieldList(std.meta.Child(T), prefix, options); + + const eff_prefix: FieldRef = if (std.meta.trait.is(.Union)(T) and options.embed_unions) + prefix[0 .. prefix.len - 1] + else + prefix; + + var fields: []const FieldRef = &.{}; + + for (std.meta.fields(T)) |f| { + const new_prefix = eff_prefix ++ &[_][]const u8{f.name}; + if (@hasDecl(T, "sql_serialize") and @hasDecl(T.sql_serialize, f.name) and @field(T.sql_serialize, f.name) == .json) { + fields = fields ++ &[_]FieldRef{new_prefix}; + } else { + const F = f.field_type; + fields = fields ++ getRecursiveFieldList(F, new_prefix, options); + } + } + + return fields; + } +} + +// Represents a set of results. +// row() must be called until it returns null, or the query may not complete +// Must be deallocated by a call to finish() +pub fn Results(comptime T: type) type { + // would normally make this a declaration of the struct, but it causes the compiler to crash + const fields = if (T == void) .{} else getRecursiveFieldList( + T, + &.{}, + util.serialize.default_options, + ); + return struct { + const Self = @This(); + + underlying: RawResults, + column_indices: [fields.len]u15, + + fn from(underlying: RawResults) QueryError!Self { + if (std.debug.runtime_safety and fields.len != underlying.columnCount() catch unreachable) { + std.log.err("Expected {} columns in result, got {}", .{ fields.len, underlying.columnCount() catch unreachable }); + return error.ColumnMismatch; + } + + return Self{ .underlying = underlying, .column_indices = blk: { + var indices: [fields.len]u15 = undefined; + inline for (fields) |f, i| { + if (comptime std.meta.trait.isTuple(T)) { + indices[i] = i; + } else { + const name = util.comptimeJoin(".", f); + indices[i] = + underlying.columnIndex(name) catch { + std.log.err("Could not find column index for field {s}", .{name}); + return error.ColumnMismatch; + }; + } + } + break :blk indices; + } }; + } + + pub fn finish(self: Self) void { + self.underlying.finish(); + } + + // Returns the next row of results, or null if there are no more rows. + // Caller owns all memory allocated. The entire object can be deallocated with a + // call to util.deepFree + pub fn row(self: *Self, alloc: ?Allocator) RowError!?T { + if (try self.underlying.row()) |row_val| { + var result: T = undefined; + var fields_allocated: usize = 0; + errdefer inline for (fields) |f, i| { + // Iteration bounds must be defined at comptime (inline for) but the number of fields we could + // successfully allocate is defined at runtime. So we iterate over the entire field array and + // conditionally deallocate fields in the loop. + const ptr = fieldPtr(&result, f); + if (i < fields_allocated) util.deepFree(alloc, ptr.*); + }; + + inline for (fields) |f, i| { + // TODO: Causes compiler segfault. why? + //const F = f.field_type; + //const F = @TypeOf(@field(result, f.name)); + const F = std.meta.Child(FieldPtr(*@TypeOf(result), f)); + const ptr = fieldPtr(&result, f); + const name = comptime util.comptimeJoin(".", f); + + const mode = comptime if (@hasDecl(T, "sql_serialize")) blk: { + if (@hasDecl(T.sql_serialize, name)) { + break :blk @field(T.sql_serialize, name); + } + break :blk .default; + } else .default; + switch (mode) { + .default => ptr.* = row_val.get(F, self.column_indices[i], alloc) catch |err| { + std.log.err("SQL: Error getting column {s} of type {}", .{ name, F }); + return err; + }, + .json => { + const str = row_val.get([]const u8, self.column_indices[i], alloc) catch |err| { + std.log.err("SQL: Error getting column {s} of type {}", .{ name, F }); + return err; + }; + const a = alloc orelse return error.AllocatorRequired; + defer a.free(str); + + var ts = std.json.TokenStream.init(str); + ptr.* = std.json.parse(F, &ts, .{ .allocator = a }) catch |err| { + std.log.err("SQL: Error parsing columns {s} of type {}: {}", .{ name, F, err }); + return error.ResultTypeMismatch; + }; + }, + else => @compileError("unknown mode"), + } + fields_allocated += 1; + } + + return result; + } else return null; + } + }; +} + +// Row is invalidated by the next call to result.row() +const Row = union(Engine) { + postgres: postgres.Row, + sqlite: sqlite.Row, + + // Returns a value of type T from the zero-indexed column given by idx. + // Not all types require an allocator to be present. If an allocator is needed but + // not required, it will return error.AllocatorRequired. + // The caller is responsible for deallocating T, if relevant. + fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { + if (T == void) return; + return switch (self) { + .postgres => |pg| try pg.get(T, idx, alloc), + .sqlite => |lite| try lite.get(T, idx, alloc), + }; + } +}; + +pub const ConstraintMode = enum { + deferred, + immediate, +}; + +pub const ConnPool = struct { + const max_conns = 4; + const Conn = struct { + engine: union(Engine) { + postgres: postgres.Db, + sqlite: sqlite.Db, + }, + in_use: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), + current_tx_level: u8 = 0, + }; + + config: Config, + connections: [max_conns]Conn, + + pub fn init(cfg: Config) OpenError!ConnPool { + var self = ConnPool{ + .config = cfg, + .connections = undefined, + }; + var count: usize = 0; + errdefer for (self.connections[0..count]) |*c| closeConn(c); + for (self.connections) |*c| { + c.* = try self.createConn(); + count += 1; + } + + return self; + } + + pub fn deinit(self: *ConnPool) void { + for (self.connections) |*c| closeConn(c); + } + + pub fn acquire(self: *ConnPool) AcquireError!Db { + for (self.connections) |*c| { + if (tryAcquire(c)) return Db{ .conn = c }; + } + return error.NoConnectionsLeft; + } + + fn tryAcquire(conn: *Conn) bool { + const acquired = !conn.in_use.swap(true, .AcqRel); + if (acquired) { + if (conn.current_tx_level != 0) @panic("Transaction still open on unused db connection"); + return true; + } + + return false; + } + + fn createConn(self: *ConnPool) OpenError!Conn { + return switch (self.config) { + .postgres => |postgres_cfg| Conn{ + .engine = .{ + .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str), + }, + }, + .sqlite => |lite_cfg| Conn{ + .engine = .{ + .sqlite = if (lite_cfg.sqlite_is_uri) + try sqlite.Db.openUri(lite_cfg.sqlite_file_path) + else + try sqlite.Db.open(lite_cfg.sqlite_file_path), + }, + }, + }; + } + + fn closeConn(conn: *Conn) void { + if (conn.in_use.loadUnchecked()) @panic("DB Conn still open"); + switch (conn.engine) { + .postgres => |pg| pg.close(), + .sqlite => |lite| lite.close(), + } + } +}; + +pub const Db = Tx(0); + +/// When tx_level == 0, the DB is operating in "implied transaction" mode where +/// every command is its own transaction +/// When tx_level >= 1, the DB has an explicit transaction open +/// When tx_level >= 2, the DB has (tx_level - 1) levels of transaction savepoints open +/// (like nested transactions) +fn Tx(comptime tx_level: u8) type { + return struct { + const Self = @This(); + const savepoint_name = if (tx_level == 0) + @compileError("Transaction not started") + else + std.fmt.comptimePrint("save_{}", .{tx_level}); + const next_savepoint_name = Tx(tx_level + 1).savepoint_name; + + conn: *ConnPool.Conn, + + /// The type of SQL engine being used. Use of this function should be discouraged + pub fn sqlEngine(self: Self) Engine { + return self.conn.engine; + } + + /// Return the connection to the pool + pub fn releaseConnection(self: Self) void { + if (tx_level != 0) @compileError("close must be called on root db"); + if (self.conn.current_tx_level != 0) { + std.log.warn("Database released while transaction in progress!", .{}); + self.rollbackUnchecked() catch {}; // TODO: Burn database connection + } + + if (!self.conn.in_use.swap(false, .AcqRel)) @panic("Double close on db conection"); + } + + // ********* Transaction management functions ********** + + /// Start an explicit transaction + pub fn begin(self: Self) !Tx(1) { + if (tx_level != 0) @compileError("Transaction already started"); + if (self.conn.current_tx_level != 0) return error.BadTransactionState; + + try self.exec("BEGIN", {}, null); + + self.conn.current_tx_level = 1; + + return Tx(1){ .conn = self.conn }; + } + + /// Create a savepoint (nested transaction) + pub fn savepoint(self: Self) !Tx(tx_level + 1) { + if (tx_level == 0) @compileError("Cannot place a savepoint on an implicit transaction"); + if (self.conn.current_tx_level != tx_level) return error.BadTransactionState; + + try self.exec("SAVEPOINT " ++ next_savepoint_name, {}, null); + + self.conn.current_tx_level = tx_level + 1; + + return Tx(tx_level + 1){ .conn = self.conn }; + } + + /// Commit the entire transaction + pub fn commit(self: Self) !void { + if (tx_level == 0) @compileError("Transaction not started"); + if (tx_level >= 2) @compileError("Cannot commit a savepoint"); + if (self.conn.current_tx_level == 0) return error.BadTransactionState; + + try self.exec("COMMIT", {}, null); + + self.conn.current_tx_level = 0; + } + + /// Release the current savepoint and all savepoints created from it. + pub fn release(self: Self) !void { + if (tx_level == 0) @compileError("Transaction not started"); + if (tx_level == 1) @compileError("Cannot release a transaction"); + if (self.conn.current_tx_level < tx_level) return error.BadTransactionState; + + try self.exec("RELEASE SAVEPOINT " ++ savepoint_name, {}, null); + + self.conn.current_tx_level = tx_level - 1; + } + + /// Rolls back the entire transaction + pub fn rollbackTx(self: Self) !void { + if (tx_level == 0) @compileError("Transaction not started"); + if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint"); + if (self.conn.current_tx_level == 0) return error.BadTransactionState; + + try self.rollbackUnchecked(); + + self.conn.current_tx_level = 0; + } + + /// Attempts to roll back to a savepoint + pub fn rollbackSavepoint(self: Self) !void { + if (tx_level == 0) @compileError("Transaction not started"); + if (tx_level == 1) @compileError("Cannot rollback a savepoint on the entire transaction"); + if (self.conn.current_tx_level < tx_level) return error.BadTransactionState; + + try self.exec("ROLLBACK TO " ++ savepoint_name, {}, null); + + self.conn.current_tx_level = tx_level - 1; + } + + /// Perform whichever rollback is appropriate for the situation + pub fn rollback(self: Self) void { + (if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| { + std.log.err("Failed to rollback transaction: {}", .{err}); + std.log.err("{any}", .{@errorReturnTrace()}); + @panic("TODO: more gracefully handle rollback failures"); + }; + } + + pub const BeginOrSavepoint = Tx(tx_level + 1); + pub const beginOrSavepoint = if (tx_level == 0) begin else savepoint; + pub const commitOrRelease = if (tx_level < 2) commit else release; + + // Allows relaxing *some* constraints for the lifetime of the transaction. + // You should generally not do this, but it's useful when bootstrapping + // the initial admin community and cluster operator user. + pub fn setConstraintMode(self: Self, mode: ConstraintMode) !void { + if (tx_level == 0) @compileError("Transaction not started"); + if (tx_level >= 2) @compileError("Cannot set constraint mode on a savepoint"); + switch (self.sqlEngine()) { + .sqlite => try self.exec( + switch (mode) { + .immediate => "PRAGMA defer_foreign_keys = FALSE", + .deferred => "PRAGMA defer_foreign_keys = TRUE", + }, + {}, + null, + ), + .postgres => try self.exec( + switch (mode) { + .immediate => "SET CONSTRAINTS ALL IMMEDIATE", + .deferred => "SET CONSTRAINTS ALL DEFERRED", + }, + {}, + null, + ), + } + } + + // ********** Query Helpers ************ + + /// Runs a command without returning results + pub fn exec( + self: Self, + sql: [:0]const u8, + args: anytype, + alloc: ?std.mem.Allocator, + ) !void { + try self.execInternal(sql, args, .{ .allocator = alloc }, true); + } + + pub fn execWithOptions( + self: Self, + sql: [:0]const u8, + args: anytype, + options: QueryOptions, + ) !void { + try self.execInternal(sql, args, options, true); + } + + pub fn queryWithOptions( + self: Self, + comptime RowType: type, + sql: [:0]const u8, + args: anytype, + options: QueryOptions, + ) QueryError!Results(RowType) { + return Results(RowType).from(try self.runSql(sql, args, options, true)); + } + + pub fn query( + self: Self, + comptime RowType: type, + sql: [:0]const u8, + args: anytype, + alloc: ?Allocator, + ) QueryError!Results(RowType) { + return self.queryWithOptions(RowType, sql, args, .{ .allocator = alloc }); + } + + /// Runs a query to completion and returns a row of results, unless the query + /// returned a different number of rows. + pub fn queryRow( + self: Self, + comptime RowType: type, + q: [:0]const u8, + args: anytype, + alloc: ?Allocator, + ) QueryRowError!RowType { + var results = try self.query(RowType, q, args, alloc); + defer results.finish(); + + const row = (try results.row(alloc)) orelse return error.NoRows; + errdefer util.deepFree(alloc, row); + + // execute query to completion + var more_rows = false; + while (try results.row(alloc)) |r| { + util.deepFree(alloc, r); + more_rows = true; + } + if (more_rows) return error.TooManyRows; + + return row; + } + + pub fn queryRows( + self: Self, + comptime RowType: type, + q: [:0]const u8, + args: anytype, + max_items: ?usize, + alloc: std.mem.Allocator, + ) QueryRowError![]RowType { + return try self.queryRowsWithOptions(RowType, q, args, max_items, .{ .allocator = alloc }); + } + + // Runs a query to completion and returns the results as a slice + pub fn queryRowsWithOptions( + self: Self, + comptime RowType: type, + q: [:0]const u8, + args: anytype, + max_items: ?usize, + options: QueryOptions, + ) QueryRowError![]RowType { + var results = try self.queryWithOptions(RowType, q, args, options); + defer results.finish(); + + const alloc = options.allocator orelse return error.AllocatorRequired; + + var result_array = std.ArrayList(RowType).init(alloc); + errdefer result_array.deinit(); + if (max_items) |max| try result_array.ensureTotalCapacity(max); + + errdefer for (result_array.items) |r| util.deepFree(alloc, r); + + var too_many: bool = false; + while (try results.row(alloc)) |row| { + errdefer util.deepFree(alloc, row); + if (max_items) |max| { + if (result_array.items.len >= max) { + util.deepFree(alloc, row); + too_many = true; + continue; + } + } + + try result_array.append(row); + } + + if (too_many) return error.TooManyRows; + + return result_array.toOwnedSlice(); + } + + // Inserts a single value into a table + pub fn insert( + self: Self, + comptime table: []const u8, + value: anytype, + alloc: ?std.mem.Allocator, + ) !void { + const ValueType = comptime @TypeOf(value); + + const fields = std.meta.fields(ValueType); + comptime var types: [fields.len]type = undefined; + comptime var table_spec: []const u8 = table ++ "("; + comptime var value_spec: []const u8 = "("; + inline for (fields) |field, i| { + // This causes a compiler crash. Why? + //const F = field.field_type; + const F = @TypeOf(@field(value, field.name)); + // causes issues if F is @TypeOf(null), use dummy type + types[i] = if (F == @TypeOf(null)) ?i64 else F; + table_spec = comptime (table_spec ++ field.name ++ ","); + value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1}); + } + table_spec = comptime table_spec[0 .. table_spec.len - 1] ++ ")"; + value_spec = comptime value_spec[0 .. value_spec.len - 1] ++ ")"; + const q = comptime std.fmt.comptimePrint( + "INSERT INTO {s} VALUES {s}", + .{ table_spec, value_spec }, + ); + + var args_tuple: std.meta.Tuple(&types) = undefined; + inline for (fields) |field, i| { + args_tuple[i] = @field(value, field.name); + } + try self.exec(q, args_tuple, alloc); + } + + // internal helpers + + fn runSql( + self: Self, + sql: [:0]const u8, + args: anytype, + opt: QueryOptions, + comptime check_tx: bool, + ) !RawResults { + if (check_tx and self.conn.current_tx_level != tx_level) return error.BadTransactionState; + + return switch (self.conn.engine) { + .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) }, + .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) }, + }; + } + + fn execInternal( + self: Self, + sql: [:0]const u8, + args: anytype, + options: QueryOptions, + comptime check_tx: bool, + ) !void { + var results = try self.runSql(sql, args, options, check_tx); + defer results.finish(); + + while (try results.row()) |_| {} + } + + fn rollbackUnchecked(self: Self) !void { + try self.execInternal("ROLLBACK", {}, .{}, false); + self.conn.current_tx_level = 0; + } + }; } diff --git a/src/sql/sqlite.bak.zig b/src/sql/sqlite.bak.zig deleted file mode 100644 index 99d1f50..0000000 --- a/src/sql/sqlite.bak.zig +++ /dev/null @@ -1,384 +0,0 @@ -const std = @import("std"); -const util = @import("util"); -const common = @import("./common.zig"); -const c = @cImport({ - @cInclude("sqlite3.h"); -}); - -const Uuid = util.Uuid; -const DateTime = util.DateTime; -const Allocator = std.mem.Allocator; - -fn getCharPos(text: []const u8, offset: c_int) struct { row: usize, col: usize } { - var row: usize = 0; - var col: usize = 0; - var i: usize = 0; - - if (offset > text.len) return .{ .row = 0, .col = 0 }; - - while (i != offset) : (i += 1) { - if (text[i] == '\n') { - row += 1; - col = 0; - } else { - col += 1; - } - } - - return .{ .row = row, .col = col }; -} - -fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) error{Unexpected} { - std.log.err("Unexpected error in SQLite engine: {s} ({})", .{ c.sqlite3_errstr(code), code }); - - std.log.debug("Additional details:", .{}); - std.log.debug("{?s}", .{c.sqlite3_errmsg(db)}); - if (sql_text) |sql| { - const byte_offset = c.sqlite3_error_offset(db); - if (byte_offset >= 0) { - const pos = getCharPos(sql, byte_offset); - std.log.debug("Failed at char ({}:{}) of SQL:\n{s}", .{ pos.row, pos.col, sql }); - } - } - std.log.debug("{?}", .{@errorReturnTrace()}); - - return error.Unexpected; -} - -pub const Db = struct { - db: *c.sqlite3, - - pub fn open(path: [:0]const u8) common.OpenError!Db { - return openInternal(path, false); - } - - pub fn openUri(path: [:0]const u8) common.OpenError!Db { - return openInternal(path, true); - } - - fn openInternal(path: [:0]const u8, is_uri: bool) common.OpenError!Db { - const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE | if (is_uri) c.SQLITE_OPEN_URI else 0; - - var db: ?*c.sqlite3 = null; - switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) { - c.SQLITE_OK => {}, - else => |code| { - if (db == null) { - // this path should only be hit if out of memory, but log it anyways - std.log.err( - "Unable to open SQLite DB \"{s}\". Error: {?s} ({})", - .{ path, c.sqlite3_errstr(code), code }, - ); - return error.BadConnection; - } - - const ext_code = c.sqlite3_extended_errcode(db); - std.log.err( - \\Unable to open SQLite DB "{s}". Error: {?s} ({}) - \\Details: {?s} - , - .{ path, c.sqlite3_errstr(ext_code), ext_code, c.sqlite3_errmsg(db) }, - ); - - return error.Unexpected; - }, - } - - return Db{ - .db = db.?, - }; - } - - pub fn close(self: Db) void { - switch (c.sqlite3_close(self.db)) { - c.SQLITE_OK => {}, - - c.SQLITE_BUSY => { - std.log.err("SQLite DB could not be closed as it is busy.\n{s}", .{c.sqlite3_errmsg(self.db)}); - }, - - else => |err| { - std.log.err("Could not close SQLite DB", .{}); - handleUnexpectedError(self.db, err, null) catch {}; - }, - } - } - - pub fn exec(self: Db, sql: []const u8, args: anytype, opts: common.QueryOptions) common.ExecError!Results { - var stmt: ?*c.sqlite3_stmt = undefined; - switch (c.sqlite3_prepare_v2(self.db, sql.ptr, @intCast(c_int, sql.len), &stmt, null)) { - c.SQLITE_OK => {}, - else => |err| return handleUnexpectedError(self.db, err, sql), - } - errdefer switch (c.sqlite3_finalize(stmt)) { - c.SQLITE_OK => {}, - else => |err| { - handleUnexpectedError(self.db, err, sql) catch {}; - }, - }; - - if (@TypeOf(args) != void) { - // TODO: Fix for stage1 compiler - //inline for (args) |arg, i| { - inline for (std.meta.fields(@TypeOf(args))) |field, i| { - const arg = @field(args, field.name); - // SQLite treats $NNN args as having the name NNN, not index NNN. - // As such, if you reference $2 and not $1 in your query (such as - // when dynamically constructing queries), it could assign $2 the - // index 1. So we can't assume the index according to the 1-indexed - // arg array is equivalent to the param to bind it to. - // We can, however, look up the exact index to bind to. - // If the argument is not used in the query, then it will have an "index" - // of 0, and we must not bind the argument. - const name = std.fmt.comptimePrint("${}", .{i + 1}); - const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); - if (db_idx != 0) - try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg) - else if (!opts.ignore_unused_arguments) - return error.UnusedArgument; - } - } - - return Results{ .stmt = stmt.?, .db = self.db }; - } - - fn bindArgument(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: anytype) !void { - if (comptime std.meta.trait.isZigString(@TypeOf(val))) { - return self.bindString(stmt, idx, val); - } - - const T = @TypeOf(val); - - switch (@typeInfo(T)) { - .Union => { - const arr = if (@hasDecl(T, "toCharArray")) - val.toCharArray() - else if (@hasDecl(T, "toCharArrayZ")) - val.toCharArrayZ() - else { - inline for (std.meta.fields(T)) |field| { - const Tag = std.meta.Tag(T); - const tag = @field(Tag, field.name); - - if (val == tag) return try self.bindArgument(stmt, idx, @field(val, field.name)); - } - unreachable; - }; - - const len = std.mem.len(&arr); - return self.bindString(stmt, idx, arr[0..len]); - }, - - .Struct => { - const arr = if (@hasDecl(T, "toCharArray")) - val.toCharArray() - else if (@hasDecl(T, "toCharArrayZ")) - val.toCharArrayZ() - else - @compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string"); - - const len = std.mem.len(&arr); - return self.bindString(stmt, idx, arr[0..len]); - }, - .Enum => |info| { - const name = if (info.is_exhaustive) - @tagName(val) - else - @compileError("SQLite: Could not serialize non-exhaustive enum " ++ @typeName(T) ++ " into string"); - return self.bindString(stmt, idx, name); - }, - .Optional => { - return if (val) |v| self.bindArgument(stmt, idx, v) else self.bindNull(stmt, idx); - }, - .Null => return self.bindNull(stmt, idx), - .Int => return self.bindInt(stmt, idx, std.math.cast(i64, val) orelse unreachable), - .Float => return self.bindFloat(stmt, idx, val), - .Bool => return self.bindInt(stmt, idx, if (val) 1 else 0), - else => @compileError("Unable to serialize type " ++ @typeName(T)), - } - } - - fn bindString(self: Db, stmt: *c.sqlite3_stmt, idx: u15, str: []const u8) !void { - const len = std.math.cast(c_int, str.len) orelse { - std.log.err("SQLite: string len {} too large", .{str.len}); - return error.BindException; - }; - - switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) { - c.SQLITE_OK => {}, - else => |result| { - std.log.err("SQLite: Unable to bind string to index {}", .{idx}); - std.log.debug("SQLite: {s}", .{str}); - return handleUnexpectedError(self.db, result, null); - }, - } - } - - fn bindNull(self: Db, stmt: *c.sqlite3_stmt, idx: u15) !void { - switch (c.sqlite3_bind_null(stmt, idx)) { - c.SQLITE_OK => {}, - else => |result| { - std.log.err("SQLite: Unable to bind NULL to index {}", .{idx}); - return handleUnexpectedError(self.db, result, null); - }, - } - } - - fn bindInt(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: i64) !void { - switch (c.sqlite3_bind_int64(stmt, idx, val)) { - c.SQLITE_OK => {}, - else => |result| { - std.log.err("SQLite: Unable to bind int to index {}", .{idx}); - std.log.debug("SQLite: {}", .{val}); - return handleUnexpectedError(self.db, result, null); - }, - } - } - - fn bindFloat(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: f64) !void { - switch (c.sqlite3_bind_double(stmt, idx, val)) { - c.SQLITE_OK => {}, - else => |result| { - std.log.err("SQLite: Unable to bind float to index {}", .{idx}); - std.log.debug("SQLite: {}", .{val}); - return handleUnexpectedError(self.db, result, null); - }, - } - } -}; - -pub const Results = struct { - stmt: *c.sqlite3_stmt, - db: *c.sqlite3, - - pub fn finish(self: Results) void { - _ = c.sqlite3_finalize(self.stmt); - } - - pub fn row(self: Results) common.RowError!?Row { - return switch (c.sqlite3_step(self.stmt)) { - c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db }, - c.SQLITE_DONE => null, - - c.SQLITE_CONSTRAINT_UNIQUE => return error.UniqueViolation, - c.SQLITE_CONSTRAINT_CHECK => return error.CheckViolation, - c.SQLITE_CONSTRAINT_NOTNULL => return error.NotNullViolation, - c.SQLITE_CONSTRAINT_FOREIGNKEY => return error.ForeignKeyViolation, - c.SQLITE_CONSTRAINT => return error.ConstraintViolation, - - else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()), - }; - } - - fn getGeneratingSql(self: Results) ?[]const u8 { - const ptr = c.sqlite3_sql(self.stmt) orelse return null; - return ptr[0..std.mem.len(ptr)]; - } - - pub fn columnCount(self: Results) common.ColumnCountError!u15 { - return @intCast(u15, c.sqlite3_column_count(self.stmt)); - } - - fn columnName(self: Results, idx: u15) ![]const u8 { - return if (c.sqlite3_column_name(self.stmt, idx)) |ptr| - ptr[0..std.mem.len(ptr)] - else - unreachable; - } - - pub fn columnIndex(self: Results, name: []const u8) common.ColumnIndexError!u15 { - var i: u15 = 0; - const count = try self.columnCount(); - while (i < count) : (i += 1) { - const column = try self.columnName(i); - if (std.mem.eql(u8, name, column)) return i; - } - - return error.NotFound; - } -}; - -pub const Row = struct { - stmt: *c.sqlite3_stmt, - db: *c.sqlite3, - - pub fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { - return getColumn(self.stmt, T, idx, alloc); - } -}; - -fn getColumn(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { - const Eff = if (comptime std.meta.trait.is(.Optional)(T)) std.meta.Child(T) else T; - return switch (c.sqlite3_column_type(stmt, idx)) { - c.SQLITE_INTEGER => try getColumnInt(stmt, Eff, idx), - c.SQLITE_FLOAT => try getColumnFloat(stmt, Eff, idx), - c.SQLITE_TEXT => try getColumnText(stmt, Eff, idx, alloc), - c.SQLITE_NULL => { - if (T == DateTime) { - std.log.warn("SQLite: Treating NULL as DateTime epoch", .{}); - return std.mem.zeroes(DateTime); - } - - if (@typeInfo(T) != .Optional) { - std.log.err("SQLite column {}: Expected value of type {}, got (null)", .{ idx, T }); - return error.ResultTypeMismatch; - } - - return null; - }, - c.SQLITE_BLOB => { - std.log.err("SQLite column {}: SQLite value had unsupported storage class BLOB", .{idx}); - return error.ResultTypeMismatch; - }, - else => |class| { - std.log.err("SQLite column {}: SQLite value had unknown storage class {}", .{ idx, class }); - return error.ResultTypeMismatch; - }, - }; -} - -fn getColumnInt(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15) common.GetError!T { - const val: i64 = c.sqlite3_column_int64(stmt, idx); - - switch (T) { - DateTime => return DateTime{ .seconds_since_epoch = val }, - else => switch (@typeInfo(T)) { - .Int => if (std.math.cast(T, val)) |v| return v else { - std.log.err("SQLite column {}: Expected value of type {}, got {} (outside of range)", .{ idx, T, val }); - return error.ResultTypeMismatch; - }, - .Bool => if (val == 0) return false else return true, - else => { - std.log.err("SQLite column {}: Storage class INT cannot be parsed into type {}", .{ idx, T }); - return error.ResultTypeMismatch; - }, - }, - } -} - -fn getColumnFloat(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15) common.GetError!T { - const val: f64 = c.sqlite3_column_double(stmt, idx); - switch (T) { - // Only support floats that fit in range for now - f16, f32, f64 => return @floatCast(T, val), - DateTime => return DateTime{ - .seconds_since_epoch = std.math.lossyCast(i64, val * @intToFloat(f64, std.time.epoch.secs_per_day)), - }, - else => { - std.log.err("SQLite column {}: Storage class FLOAT cannot be parsed into type {}", .{ idx, T }); - return error.ResultTypeMismatch; - }, - } -} - -fn getColumnText(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { - if (c.sqlite3_column_text(stmt, idx)) |ptr| { - const size = @intCast(usize, c.sqlite3_column_bytes(stmt, idx)); - const str = std.mem.sliceTo(ptr[0..size], 0); - - return common.parseValueNotNull(alloc, T, str); - } else { - std.log.err("SQLite column {}: TEXT value stored but engine returned null pointer (out of memory?)", .{idx}); - return error.ResultTypeMismatch; - } -}