diff --git a/src/main/api.zig b/src/main/api.zig index abfe8aa..fae758a 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -1,9 +1,10 @@ const std = @import("std"); const util = @import("util"); const builtin = @import("builtin"); +const sql = @import("sql"); -const db = @import("./db.zig"); const models = @import("./db/models.zig"); +const migrations = @import("./migrations.zig"); pub const DateTime = util.DateTime; pub const Uuid = util.Uuid; const Config = @import("./main.zig").Config; @@ -28,7 +29,7 @@ pub const InviteRequest = struct { name: ?[]const u8 = null, expires_at: ?DateTime = null, // TODO: Change this to lifespan - max_uses: ?usize = null, + max_uses: ?u16 = null, invite_type: Type = .user, // must be user unless the creator is an admin to_community: ?[]const u8 = null, // only valid on admin community @@ -94,40 +95,41 @@ pub fn getRandom() std.rand.Random { } pub const ApiSource = struct { - db: db.Database, + db: sql.Db, internal_alloc: std.mem.Allocator, config: Config, - pub const Conn = ApiConn(db.Database); + pub const Conn = ApiConn(sql.Db); const root_username = "root"; - pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8) !ApiSource { + pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8, db_conn: sql.Db) !ApiSource { var self = ApiSource{ - .db = try db.Database.init(cfg.db.sqlite.db_file), + .db = db_conn, .internal_alloc = alloc, .config = cfg, }; + try migrations.up(db_conn); + if ((try services.users.lookupByUsername(&self.db, root_username, null)) == null) { std.log.info("No cluster root user detected. Creating...", .{}); // TODO: Fix this const password = root_password orelse return error.NeedRootPassword; - std.debug.print("\npassword: {s}\n", .{password}); var arena = std.heap.ArenaAllocator.init(alloc); defer arena.deinit(); const user_id = try services.users.create(&self.db, root_username, password, null, .{}, arena.allocator()); - std.debug.print("Created {s} ID {}", .{ root_username, user_id }); + std.log.debug("Created {s} ID {}", .{ root_username, user_id }); } return self; } fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid { - if (try self.db.execRow( + if (try self.db.queryRow( &.{Uuid}, - "SELECT id FROM community WHERE host = ?", + "SELECT id FROM community WHERE host = $1", .{host}, null, )) |result| return result[0]; @@ -204,9 +206,9 @@ fn ApiConn(comptime DbConn: type) type { }; pub fn getTokenInfo(self: *Self) !TokenInfo { if (self.user_id) |user_id| { - const result = (try self.db.execRow( + const result = (try self.db.queryRow( &.{[]const u8}, - "SELECT username FROM user WHERE id = ?", + "SELECT username FROM user WHERE id = $1", .{user_id}, self.arena.allocator(), )) orelse { diff --git a/src/main/api/auth.zig b/src/main/api/auth.zig index 04b36b3..b6722af 100644 --- a/src/main/api/auth.zig +++ b/src/main/api/auth.zig @@ -23,9 +23,9 @@ pub const passwords = struct { }; pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void { // TODO: This could be done w/o the dynamically allocated hash buf - const hash = (db.execRow( + const hash = (db.queryRow( &.{[]const u8}, - "SELECT hashed_password FROM account_password WHERE user_id = ? LIMIT 1", + "SELECT hashed_password FROM account_password WHERE user_id = $1 LIMIT 1", .{user_id}, alloc, ) catch return error.DbError) orelse return error.InvalidLogin; @@ -95,11 +95,11 @@ pub const tokens = struct { } fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info { - return if (try db.execRow( + return if (try db.queryRow( &.{ Uuid, DateTime }, \\SELECT user.id, token.issued_at \\FROM token JOIN user ON token.user_id = user.id - \\WHERE user.community_id = ? AND token.hash = ? + \\WHERE user.community_id = $1 AND token.hash = $2 \\LIMIT 1 , .{ community_id, hash }, @@ -114,11 +114,11 @@ pub const tokens = struct { } fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info { - return if (try db.execRow( + return if (try db.queryRow( &.{ Uuid, DateTime }, \\SELECT user.id, token.issued_at \\FROM token JOIN user ON token.user_id = user.id - \\WHERE user.community_id IS NULL AND token.hash = ? + \\WHERE user.community_id IS NULL AND token.hash = $1 \\LIMIT 1 , .{hash}, diff --git a/src/main/api/communities.zig b/src/main/api/communities.zig index 209f185..e7cbbdb 100644 --- a/src/main/api/communities.zig +++ b/src/main/api/communities.zig @@ -3,8 +3,6 @@ const builtin = @import("builtin"); const util = @import("util"); const models = @import("../db/models.zig"); -const DbError = @import("../db.zig").ExecError; - const getRandom = @import("../api.zig").getRandom; const Uuid = util.Uuid; @@ -14,7 +12,7 @@ const CreateError = error{ InvalidOrigin, UnsupportedScheme, CommunityExists, -} || DbError; +} || anyerror; // TODO pub const Scheme = enum { https, @@ -76,7 +74,7 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co .created_at = DateTime.now(), }; - if ((try db.execRow(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, null)) != null) { + if ((try db.queryRow(&.{Uuid}, "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) { return error.CommunityExists; } @@ -94,7 +92,7 @@ fn firstIndexOf(str: []const u8, ch: u8) ?usize { } pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community { - const result = (try db.execRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = ?", .{host}, alloc)) orelse return error.NotFound; + const result = (try db.queryRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = $1", .{host}, alloc)) orelse return error.NotFound; return Community{ .id = result[0], @@ -107,7 +105,7 @@ pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Commu } pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void { - _ = try db.execRow(&.{i64}, "UPDATE community SET owner_id = ? WHERE id = ?", .{ new_owner, community_id }, null); + try db.exec("UPDATE community SET owner_id = $1 WHERE id = $2", .{ new_owner, community_id }, null); } pub const QueryArgs = struct { @@ -238,7 +236,7 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit max_items, }; - var results = try db.exec( + var results = try db.query( &.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, builder.array.items, query_args, diff --git a/src/main/api/invites.zig b/src/main/api/invites.zig index 6e1f9c4..f07d286 100644 --- a/src/main/api/invites.zig +++ b/src/main/api/invites.zig @@ -2,7 +2,6 @@ const std = @import("std"); const builtin = @import("builtin"); const util = @import("util"); const models = @import("../db/models.zig"); -const DbError = @import("../db.zig").ExecError; const getRandom = @import("../api.zig").getRandom; const Uuid = util.Uuid; @@ -31,6 +30,7 @@ fn defaultJsonStringify(comptime T: type) fn (T, std.json.StringifyOptions, anyt }.jsonStringify; } +const InviteCount = u16; pub const Invite = struct { id: Uuid, @@ -40,10 +40,10 @@ pub const Invite = struct { code: []const u8, created_at: DateTime, - times_used: usize, + times_used: InviteCount, expires_at: ?DateTime, - max_uses: ?usize, + max_uses: ?InviteCount, invite_type: InviteType, }; @@ -59,7 +59,7 @@ const DbModel = struct { created_at: DateTime, expires_at: ?DateTime, - max_uses: ?usize, + max_uses: ?InviteCount, @"type": InviteType, }; @@ -72,7 +72,7 @@ fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 { pub const InviteOptions = struct { name: ?[]const u8 = null, - max_uses: ?usize = null, + max_uses: ?InviteCount = null, expires_at: ?DateTime = null, invite_type: InviteType = .user, }; @@ -130,14 +130,14 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite { const code_clone = try cloneStr(code, alloc); - const info = (try db.execRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, usize, ?usize, InviteType }, + const info = (try db.queryRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType }, \\SELECT \\ invite.id, invite.created_by, invite.to_community, invite.name, \\ invite.created_at, invite.expires_at, \\ COUNT(local_user.user_id) as uses, invite.max_uses, \\ invite.type \\FROM invite LEFT OUTER JOIN local_user ON invite.id = local_user.invite_id - \\WHERE invite.code = ? + \\WHERE invite.code = $1 \\GROUP BY invite.id , .{code}, alloc)) orelse return error.NotFound; diff --git a/src/main/api/notes.zig b/src/main/api/notes.zig index 7db4c9a..8379e1e 100644 --- a/src/main/api/notes.zig +++ b/src/main/api/notes.zig @@ -40,11 +40,11 @@ pub fn create( } pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Note { - const result = (try db.execRow( + const result = (try db.queryRow( &.{ Uuid, []const u8, DateTime }, \\SELECT author_id, content, created_at \\FROM note - \\WHERE id = ? + \\WHERE id = $1 \\LIMIT 1 , .{id}, diff --git a/src/main/api/users.zig b/src/main/api/users.zig index 0a03efa..3f8c440 100644 --- a/src/main/api/users.zig +++ b/src/main/api/users.zig @@ -37,9 +37,9 @@ pub const CreateOptions = struct { }; fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid { - return if (try db.execRow( + return if (try db.queryRow( &.{Uuid}, - "SELECT user.id FROM user WHERE community_id IS NULL AND username = ?", + "SELECT user.id FROM user WHERE community_id IS NULL AND username = $1", .{username}, null, )) |result| @@ -49,9 +49,9 @@ fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid { } fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid { - return if (try db.execRow( + return if (try db.queryRow( &.{Uuid}, - "SELECT user.id FROM user WHERE community_id = ? AND username = ?", + "SELECT user.id FROM user WHERE community_id = $1 AND username = $2", .{ community_id, username }, null, )) |result| @@ -107,11 +107,11 @@ pub const User = struct { }; pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User { - const result = (try db.execRow( + const result = (try db.queryRow( &.{ []const u8, []const u8, Uuid, DateTime }, \\SELECT user.username, community.host, community.id, user.created_at \\FROM user JOIN community ON user.community_id = community.id - \\WHERE user.id = ? + \\WHERE user.id = $1 \\LIMIT 1 , .{id}, diff --git a/src/main/controllers/auth.zig b/src/main/controllers/auth.zig index 961adca..c2b356a 100644 --- a/src/main/controllers/auth.zig +++ b/src/main/controllers/auth.zig @@ -14,6 +14,7 @@ pub const login = struct { pub const path = "/auth/login"; pub const method = .POST; pub fn handler(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { + std.debug.print("{s}", .{ctx.request.body.?}); const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx); defer utils.freeRequestBody(credentials, ctx.alloc); diff --git a/src/main/db.zig b/src/main/db.zig deleted file mode 100644 index 1e24d30..0000000 --- a/src/main/db.zig +++ /dev/null @@ -1,214 +0,0 @@ -const std = @import("std"); -const sql = @import("sql"); -const models = @import("./db/models.zig"); -const migrations = @import("./db/migrations.zig"); -const util = @import("util"); - -const Uuid = util.Uuid; -const DateTime = util.DateTime; -const String = []const u8; -const comptimePrint = std.fmt.comptimePrint; - -fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple { - var result: RowTuple = undefined; - // TODO: undo allocations on failure - inline for (std.meta.fields(RowTuple)) |f, i| { - @field(result, f.name) = try getAlloc(row, f.field_type, i, allocator); - } - - return result; -} - -pub fn ResultSet(comptime result_types: []const type) type { - return struct { - pub const Row = std.meta.Tuple(result_types); - - _stmt: sql.PreparedStmt, - err: ?ExecError = null, - - pub fn finish(self: *@This()) void { - self._stmt.finalize(); - } - - pub fn row(self: *@This(), allocator: ?std.mem.Allocator) ?Row { - const sql_result = self._stmt.step() catch |err| { - self.err = err; - return null; - }; - - if (sql_result) |sql_row| { - return readRow(Row, sql_row, allocator) catch |err| { - self.err = err; - return null; - }; - } else return null; - } - }; -} - -// Binds a value to a parameter in the query. Use this instead of string -// concatenation to avoid injection attacks; -// If a given type is not supported by this function, you can add support by -// declaring a method with the given signature: -// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void -// TODO define what error set this ^ should return -fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void { - if (comptime std.meta.trait.isZigString(@TypeOf(val))) return stmt.bindText(idx, val); - - return switch (@TypeOf(val)) { - i64 => stmt.bindI64(idx, val), - Uuid => stmt.bindUuid(idx, val), - DateTime => stmt.bindDateTime(idx, val), - @TypeOf(null) => stmt.bindNull(idx), - else => |T| switch (@typeInfo(T)) { - .Optional => if (val) |v| bind(stmt, idx, v) else stmt.bindNull(idx), - .Enum => stmt.bindText(idx, @tagName(val)), - .Struct, .Union, .Opaque => if (@hasDecl(T, "bindToSql")) - val.bindToSql(stmt, idx) - else - @compileError("unsupported type " ++ @typeName(T)), - .Int => stmt.bindI64(idx, @intCast(i64, val)), - else => @compileError("unsupported type " ++ @typeName(T)), - }, - }; -} - -// Gets a value from the row, allocating memory if necessary. -// If a given type is not supported by this function, you can add support by -// declaring a method with the given signature: -// pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T -// TODO define what error set this ^ should return -fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) !T { - return switch (T) { - []u8, []const u8 => row.getTextAlloc(idx, alloc orelse return error.AllocatorRequired), - i64 => row.getI64(idx), - Uuid => row.getUuid(idx), - DateTime => row.getDateTime(idx), - - else => switch (@typeInfo(T)) { - .Optional => if (try row.isNull(idx)) - null - else - try getAlloc(row, std.meta.Child(T), idx, alloc), - - .Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql")) - T.getFromSql(row, idx, alloc) - else - @compileError("unknown type " ++ @typeName(T)), - - .Enum => try getEnum(row, T, idx), - - .Int => @intCast(T, try row.getI64(idx)), - - //else => unreachable, - else => @compileError("unknown type " ++ @typeName(T)), - }, - }; -} - -fn maxTagLen(comptime T: type) usize { - var max: usize = 0; - for (std.meta.fields(T)) |f| { - if (f.name.len > max) { - max = f.name.len; - } - } - return max; -} - -fn getEnum(row: sql.Row, comptime T: type, idx: u15) !T { - var tag_buf: [maxTagLen(T)]u8 = undefined; - const tag_name = try row.getText(idx, &tag_buf); - inline for (std.meta.fields(T)) |tag| { - if (std.mem.eql(u8, tag_name, tag.name)) return @intToEnum(T, tag.value); - } - - return error.UnknownTag; -} - -pub const ExecError = sql.PrepareError || sql.RowGetError || sql.BindError || std.mem.Allocator.Error || error{ AllocatorRequired, UnknownTag }; - -pub const Database = struct { - db: sql.Sqlite, - - pub fn init(file_path: [:0]const u8) !Database { - var db = try sql.Sqlite.open(file_path); - errdefer db.close(); - - try migrations.up(&db); - - return Database{ .db = db }; - } - - pub fn deinit(self: *Database) void { - self.db.close(); - } - - pub fn exec( - self: *Database, - comptime result_types: []const type, - comptime q: []const u8, - args: anytype, - ) ExecError!ResultSet(result_types) { - std.log.debug("executing sql:\n===\n{s}\n===", .{q}); - - const stmt = try self.db.prepare(q); - errdefer stmt.finalize(); - - inline for (std.meta.fields(@TypeOf(args))) |field, i| { - try bind(stmt, @intCast(u15, i + 1), @field(args, field.name)); - } - - return ResultSet(result_types){ - ._stmt = stmt, - }; - } - - pub fn execRow( - self: *Database, - comptime result_types: []const type, - comptime q: []const u8, - args: anytype, - allocator: ?std.mem.Allocator, - ) ExecError!?ResultSet(result_types).Row { - var results = try self.exec(result_types, q, args); - defer results.finish(); - - const row = results.row(allocator); - std.log.debug("done exec", .{}); - if (row) |r| return r; - if (results.err) |err| { - std.log.debug("{}", .{err}); - std.log.debug("{?}", .{@errorReturnTrace()}); - return err; - } - return null; - } - - fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 { - comptime { - const joiner = ","; - var result: []const u8 = ""; - inline for (std.meta.fields(T)) |f| { - result = result ++ joiner ++ (placeholder orelse f.name); - } - - return "(" ++ result[joiner.len..] ++ ")"; - } - } - - pub fn insert( - self: *Database, - comptime table: []const u8, - value: anytype, - ) ExecError!void { - const ValueType = comptime @TypeOf(value); - const table_spec = comptime table ++ build_field_list(ValueType, null); - const value_spec = comptime build_field_list(ValueType, "?"); - const q = comptime std.fmt.comptimePrint( - "INSERT INTO {s} VALUES {s}", - .{ table_spec, value_spec }, - ); - _ = try self.execRow(&.{}, q, value, null); - } -}; diff --git a/src/main/main.zig b/src/main/main.zig index 70f628b..e1bad55 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -1,5 +1,6 @@ const std = @import("std"); const builtin = @import("builtin"); +const sql = @import("sql"); const http = @import("http"); const util = @import("util"); @@ -82,11 +83,7 @@ pub const RequestServer = struct { pub const Config = struct { cluster_host: []const u8, - db: struct { - sqlite: struct { - db_file: [:0]const u8, - }, - }, + db: sql.Config, root_password: ?[]const u8 = null, }; @@ -105,7 +102,8 @@ const root_password_envvar = "CLUSTER_ROOT_PASSWORD"; pub fn main() anyerror!void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var cfg = try loadConfig(gpa.allocator()); - var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar)) catch |err| switch (err) { + var db_conn = try sql.Db.open(cfg.db); + var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar), db_conn) catch |err| switch (err) { error.NeedRootPassword => { std.log.err( "No root user created and no password specified. Please provide the password for the root user by the ${s} environment variable for initial startup. This only needs to be done once", diff --git a/src/main/db/migrations.zig b/src/main/migrations.zig similarity index 66% rename from src/main/db/migrations.zig rename to src/main/migrations.zig index 8b993c5..cc4c1c9 100644 --- a/src/main/db/migrations.zig +++ b/src/main/migrations.zig @@ -1,8 +1,9 @@ +const std = @import("std"); const sql = @import("sql"); const DateTime = @import("util").DateTime; pub const Migration = struct { - name: []const u8, + name: [:0]const u8, up: []const u8, down: []const u8, }; @@ -15,54 +16,44 @@ fn firstIndexOf(str: []const u8, char: u8) ?usize { return null; } -fn execStmt(db: *sql.Sqlite, stmt_sql: []const u8) !void { - const stmt = try db.prepare(stmt_sql); - defer stmt.finalize(); - while (try stmt.step()) |_| {} +fn execStmt(tx: sql.Tx, stmt: []const u8, alloc: std.mem.Allocator) !void { + const stmt_null = try std.cstr.addNullByte(alloc, stmt); + defer alloc.free(stmt_null); + try tx.exec(stmt_null, .{}, null); } -fn execScript(db: *sql.Sqlite, script: []const u8) !void { - try execStmt(db, "BEGIN;"); - errdefer { - _ = execStmt(db, "ROLLBACK;") catch unreachable; - } +fn execScript(db: sql.Db, script: []const u8, alloc: std.mem.Allocator) !void { + const tx = try db.begin(); + errdefer tx.rollback(); var remaining = script; while (firstIndexOf(remaining, ';')) |last| { - try execStmt(db, remaining[0 .. last + 1]); + try execStmt(tx, remaining[0 .. last + 1], alloc); remaining = remaining[last + 1 ..]; } + if (remaining.len > 1) try execStmt(tx, remaining, alloc); - try execStmt(db, "COMMIT;"); + try tx.commit(); } -fn wasMigrationRan(db: *sql.Sqlite, name: []const u8) !bool { - const stmt = try db.prepare("SELECT COUNT(*) FROM migration WHERE name = ?;"); - defer stmt.finalize(); - - try stmt.bindText(1, name); - const result = (try stmt.step()).?; - - const count = try result.getI64(0); - return count != 0; +fn wasMigrationRan(db: sql.Db, name: []const u8, alloc: std.mem.Allocator) !bool { + const row = (try db.queryRow(&.{i32}, "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false; + return row[0] != 0; } -fn markMigrationAsRan(db: *sql.Sqlite, name: []const u8) !void { - const stmt = try db.prepare("INSERT INTO migration(name) VALUES(?);"); - defer stmt.finalize(); - - try stmt.bindText(1, name); - _ = try stmt.step(); -} - -pub fn up(db: *sql.Sqlite) !void { - try execScript(db, create_migration_table); +pub fn up(db: sql.Db) !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = gpa.deinit(); + std.log.info("Running migrations...", .{}); + try execScript(db, create_migration_table, gpa.allocator()); for (migrations) |migration| { - if (!try wasMigrationRan(db, migration.name)) { - try execScript(db, migration.up); - try markMigrationAsRan(db, migration.name); + const was_ran = try wasMigrationRan(db, migration.name, gpa.allocator()); + if (!was_ran) { + std.log.info("Running migration {s}", .{migration.name}); + try execScript(db, migration.up, gpa.allocator()); + try db.insert("migration", .{ .name = migration.name }); } } } @@ -71,7 +62,7 @@ const create_migration_table = \\CREATE TABLE IF NOT EXISTS \\migration( \\ name TEXT NOT NULL PRIMARY KEY, - \\ applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + \\ applied_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); ; @@ -85,7 +76,7 @@ const migrations: []const Migration = &.{ \\ id TEXT NOT NULL PRIMARY KEY, \\ username TEXT NOT NULL, \\ - \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); \\ \\CREATE TABLE local_user( @@ -115,7 +106,7 @@ const migrations: []const Migration = &.{ \\ content TEXT NOT NULL, \\ author_id TEXT NOT NULL REFERENCES user(id), \\ - \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); , .down = "DROP TABLE note;", @@ -129,7 +120,7 @@ const migrations: []const Migration = &.{ \\ user_id TEXT NOT NULL REFERENCES user(id), \\ note_id TEXT NOT NULL REFERENCES note(id), \\ - \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); , .down = "DROP TABLE reaction;", @@ -141,7 +132,7 @@ const migrations: []const Migration = &.{ \\ hash TEXT NOT NULL PRIMARY KEY, \\ user_id TEXT NOT NULL REFERENCES local_user(id), \\ - \\ issued_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + \\ issued_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); , .down = "DROP TABLE token;", @@ -158,8 +149,8 @@ const migrations: []const Migration = &.{ \\ \\ max_uses INTEGER, \\ - \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - \\ expires_at DATETIME, + \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + \\ expires_at TIMESTAMPTZ, \\ \\ type TEXT NOT NULL CHECK (type in ('system', 'community_owner', 'user')) \\); @@ -181,7 +172,7 @@ const migrations: []const Migration = &.{ \\ host TEXT NOT NULL UNIQUE, \\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')), \\ - \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); \\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id); \\ALTER TABLE invite ADD COLUMN to_community TEXT REFERENCES community(id); diff --git a/src/sql/common.zig b/src/sql/common.zig index 42a776f..3b2c5e6 100644 --- a/src/sql/common.zig +++ b/src/sql/common.zig @@ -5,10 +5,30 @@ const Uuid = util.Uuid; const DateTime = util.DateTime; const Allocator = std.mem.Allocator; +pub const QueryOptions = struct { + // If true, then it will not return an error on the SQLite backend + // if an argument passed does not map to a parameter in the query. + // Has no effect on the postgres backend. + ignore_unknown_parameters: bool = false, + + // The allocator to use for query preparation and submission. + // All memory allocated with this allocator will be freed before results + // are retrieved. + // Some types (enums with constant representation, null terminated strings) + // do not require allocators for prep. If an allocator is needed but not + // provided, `error.AllocatorRequired` will be returned. + // Only used with the postgres backend. + prep_allocator: ?Allocator = null, +}; + // Turns a value into its appropriate textual value (or null) // as appropriate using the given arena allocator -pub fn prepareParamText(arena: std.heap.ArenaAllocator, val: anytype) !?[:0]const u8 { - if (comptime std.meta.trait.isZigString(@TypeOf(val))) return val; +pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]const u8 { + if (comptime std.meta.trait.isZigString(@TypeOf(val))) { + if (comptime std.meta.sentinel(@TypeOf(val))) |s| if (comptime s == 0) return val; + + return try std.cstr.addNullByte(arena.allocator(), val); + } return switch (@TypeOf(val)) { [:0]u8, [:0]const u8 => val, @@ -36,7 +56,9 @@ pub fn parseValueNotNull(alloc: ?Allocator, comptime T: type, str: []const u8) ! []u8, []const u8 => if (alloc) |a| util.deepClone(a, str) else return error.AllocatorRequired, else => switch (@typeInfo(T)) { - .Enum => parseEnum(T, str), + .Int => std.fmt.parseInt(T, str, 0), + .Enum => std.meta.stringToEnum(T, str) orelse return error.InvalidValue, + .Optional => try parseValueNotNull(alloc, std.meta.Child(T), str), else => @compileError("Type " ++ @typeName(T) ++ " not supported"), }, diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 3c64082..26cbf22 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -3,8 +3,11 @@ const util = @import("util"); const postgres = @import("./postgres.zig"); const sqlite = @import("./sqlite.zig"); +const common = @import("./common.zig"); const Allocator = std.mem.Allocator; +pub const QueryOptions = common.QueryOptions; + pub const Type = enum { postgres, sqlite, @@ -12,10 +15,10 @@ pub const Type = enum { pub const Config = union(Type) { postgres: struct { - conn_str: [:0]const u8, + pg_conn_str: [:0]const u8, }, sqlite: struct { - file_path: [:0]const u8, + sqlite_file_path: [:0]const u8, }, }; @@ -106,12 +109,12 @@ pub const Db = struct { return switch (cfg) { .postgres => |postgres_cfg| Db{ .underlying = .{ - .postgres = try postgres.Db.open(postgres_cfg.conn_str), + .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str), }, }, .sqlite => |lite_cfg| Db{ .underlying = .{ - .sqlite = try sqlite.Db.open(lite_cfg.file_path), + .sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path), }, }, }; @@ -124,6 +127,17 @@ pub const Db = struct { } } + pub fn queryWithOptions( + self: Db, + comptime result_types: []const type, + sql: [:0]const u8, + args: anytype, + opt: QueryOptions, + ) !Results(result_types) { + // Create fake transaction to use its functions + return (Tx{ .underlying = self.underlying }).queryWithOptions(result_types, sql, args, opt); + } + pub fn query( self: Db, comptime result_types: []const type, @@ -153,7 +167,7 @@ pub const Db = struct { alloc: ?Allocator, ) !?Results(result_types).RowTuple { // Create fake transaction to use its functions - return (Tx{ .underlying = self.underlying }).exec(sql, args, alloc); + return (Tx{ .underlying = self.underlying }).queryRow(result_types, sql, args, alloc); } pub fn insert( @@ -182,14 +196,24 @@ pub const Tx = struct { self: Tx, sql: [:0]const u8, args: anytype, - alloc: ?Allocator, + opt: QueryOptions, ) !RawResults { return switch (self.underlying) { - .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, alloc) }, - .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args) }, + .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt.prep_allocator) }, + .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) }, }; } + pub fn queryWithOptions( + self: Tx, + comptime result_types: []const type, + sql: [:0]const u8, + args: anytype, + options: QueryOptions, + ) !Results(result_types) { + return Results(result_types){ .underlying = try self.queryInternal(sql, args, options) }; + } + // Executes a query and returns the result set pub fn query( self: Tx, @@ -198,7 +222,7 @@ pub const Tx = struct { args: anytype, alloc: ?Allocator, ) !Results(result_types) { - return Results(result_types){ .unerlying = try self.queryInternal(sql, args, alloc) }; + return self.queryWithOptions(result_types, sql, args, .{ .prep_allocator = alloc }); } // Executes a query without returning results @@ -208,7 +232,7 @@ pub const Tx = struct { args: anytype, alloc: ?Allocator, ) !void { - (try self.queryInternal(sql, args, alloc)).finish(); + _ = try self.queryRow(&.{}, sql, args, alloc); } // Runs a query and returns a single row @@ -242,13 +266,28 @@ pub const Tx = struct { value: anytype, ) !void { const ValueType = comptime @TypeOf(value); - const table_spec = comptime table ++ build_field_list(ValueType, null); - const value_spec = comptime build_field_list(ValueType, "?"); + + const fields = std.meta.fields(ValueType); + comptime var types: [fields.len]type = undefined; + comptime var table_spec: []const u8 = table ++ "("; + comptime var value_spec: []const u8 = "("; + inline for (fields) |field, i| { + types[i] = field.field_type; + table_spec = comptime (table_spec ++ field.name ++ ","); + value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1}); + } + table_spec = comptime table_spec[0 .. table_spec.len - 1] ++ ")"; + value_spec = comptime value_spec[0 .. value_spec.len - 1] ++ ")"; const q = comptime std.fmt.comptimePrint( "INSERT INTO {s} VALUES {s}", .{ table_spec, value_spec }, ); - try self.exec(q, value, null); + + var args_tuple: std.meta.Tuple(&types) = undefined; + inline for (fields) |field, i| { + args_tuple[i] = @field(value, field.name); + } + try self.exec(q, args_tuple, null); } pub fn rollback(self: Tx) void { @@ -261,15 +300,3 @@ pub const Tx = struct { try self.exec("COMMIT", .{}, null); } }; - -fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 { - comptime { - const joiner = ","; - var result: []const u8 = ""; - inline for (std.meta.fields(T)) |f| { - result = result ++ joiner ++ (placeholder orelse f.name); - } - - return "(" ++ result[joiner.len..] ++ ")"; - } -} diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index 2032ef5..a7a2fd1 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -90,7 +90,7 @@ pub const Db = struct { var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired); defer arena.deinit(); const params = try arena.allocator().alloc(?[*]const u8, args.len); - inline for (args) |a, i| params[i] = if (try common.prepareParamText(arena, a)) |slice| slice.ptr else null; + inline for (args) |a, i| params[i] = if (try common.prepareParamText(&arena, a)) |slice| slice.ptr else null; break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text); } else { diff --git a/src/sql/sqlite.zig b/src/sql/sqlite.zig index 1ace6c2..9a8de0e 100644 --- a/src/sql/sqlite.zig +++ b/src/sql/sqlite.zig @@ -103,7 +103,7 @@ pub const Db = struct { } } - pub fn exec(self: Db, sql: []const u8, args: anytype) !Results { + pub fn exec(self: Db, sql: []const u8, args: anytype, opts: common.QueryOptions) !Results { var stmt: ?*c.sqlite3_stmt = undefined; switch (c.sqlite3_prepare_v2(self.db, sql.ptr, @intCast(c_int, sql.len), &stmt, null)) { c.SQLITE_OK => {}, @@ -134,7 +134,7 @@ pub const Db = struct { return handleUnexpectedError(self.db, err, sql); }, } - } else unreachable; + } else if (!opts.ignore_unknown_parameters) return error.UnknownParameter; } return Results{ .stmt = stmt.?, .db = self.db }; @@ -216,7 +216,7 @@ pub const Row = struct { @intCast(T, c.sqlite3_column_int64(self.stmt, idx)) else self.getFromString(T, idx, alloc), - .Optional => self.getNotNull(std.meta.Child(T), idx, alloc), + .Optional => try self.getNotNull(std.meta.Child(T), idx, alloc), else => self.getFromString(T, idx, alloc), }, }; diff --git a/src/sql/test.zig b/src/sql/test.zig new file mode 100644 index 0000000..379d411 --- /dev/null +++ b/src/sql/test.zig @@ -0,0 +1,41 @@ +const sql = @import("./lib.zig"); +const std = @import("std"); +const Uuid = @import("util").Uuid; + +const alloc = std.testing.allocator; + +pub fn main() !void { + const db = try sql.Db.open(.{ + //.postgres = .{ + //.conn_str = "postgresql://localhost", + //}, + .sqlite = .{ + .file_path = "./test.db", + }, + }); + defer db.close(); + + const tx = try db.begin(); + try tx.commit(); + tx.rollback(); + try tx.commit(); +} + +test { + const db = try sql.Db.open(.{ + .sqlite = .{ + .file_path = "./test.db", + }, + }); + defer db.close(); + + var results = try db.query(&.{[]const u8}, "SELECT $1 as id", .{"abcdefg"}, alloc); + defer results.finish(); + + const row = (try results.row(alloc)) orelse unreachable; + defer alloc.free(row[0]); + + try std.testing.expectEqualStrings("abcdefg", row[0]); + + std.log.info("value: {s}", .{row[0]}); +} diff --git a/src/util/DateTime.zig b/src/util/DateTime.zig index 072191a..ab78541 100644 --- a/src/util/DateTime.zig +++ b/src/util/DateTime.zig @@ -1,9 +1,48 @@ const DateTime = @This(); const std = @import("std"); +const epoch = std.time.epoch; seconds_since_epoch: i64, +pub fn parse(str: []const u8) !DateTime { + // TODO: Try other formats + + return try parseRfc3339(str); +} + +// TODO: Validate non-numeric aspects of datetime +// TODO: Don't panic on bad string +// TODO: Make seconds optional (see ActivityStreams 2.0 spec ยง2.3) +// TODO: Handle times before 1970 +pub fn parseRfc3339(str: []const u8) !DateTime { + const year_num = try std.fmt.parseInt(u16, str[0..4], 10); + const month_num = try std.fmt.parseInt(std.meta.Tag(epoch.Month), str[5..7], 10); + const day_num = @as(i64, try std.fmt.parseInt(u9, str[8..10], 10)); + const hour_num = @as(i64, try std.fmt.parseInt(u5, str[11..13], 10)); + const minute_num = @as(i64, try std.fmt.parseInt(u6, str[14..15], 10)); + const second_num = @as(i64, try std.fmt.parseInt(u6, str[16..17], 10)); + + const is_leap_year = epoch.isLeapYear(year_num); + const leap_days_preceding_epoch = comptime epoch.epoch_year / 4 - epoch.epoch_year / 100 + epoch.epoch_year / 400; + const leap_days_preceding_year = year_num / 4 - year_num / 100 + year_num / 400 - leap_days_preceding_epoch - if (is_leap_year) @as(i64, 1) else 0; + + const epoch_day = (year_num - epoch.epoch_year) * 365 + leap_days_preceding_year + year_day: { + var days_preceding_month: i64 = 0; + var month_i: i64 = 1; + while (month_i < month_num) : (month_i += 1) { + days_preceding_month += epoch.getDaysInMonth(if (is_leap_year) .leap else .not_leap, @intToEnum(epoch.Month, month_i)); + } + break :year_day days_preceding_month + day_num; + }; + + const day_second = (hour_num * 60 + minute_num) * 60 + second_num; + + return DateTime{ + .seconds_since_epoch = epoch_day * epoch.secs_per_day + day_second, + }; +} + pub fn now() DateTime { return .{ .seconds_since_epoch = std.time.timestamp() }; } @@ -40,10 +79,24 @@ pub fn second(value: DateTime) u6 { return value.epochSeconds().getDaySeconds().getSecondsIntoMinute(); } +const array_len = 20; + +pub fn toCharArray(value: DateTime) [array_len + 1]u8 { + var buf: [array_len]u8 = undefined; + _ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable; + return buf; +} + +pub fn toCharArrayZ(value: DateTime) [array_len + 1:0]u8 { + var buf: [array_len + 1:0]u8 = undefined; + _ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable; + return buf; +} + pub fn format(value: DateTime, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { return std.fmt.format( writer, - "{}-{}-{} {}:{}:{}", + "{:0>4}-{:0>2}-{:0>2}T{:0>2}:{:0>2}:{:0>2}Z", .{ value.year(), value.month().numeric(), value.day(), value.hour(), value.minute(), value.second() }, ); } diff --git a/src/util/Uuid.zig b/src/util/Uuid.zig index 1aa95a0..7d1c8b9 100644 --- a/src/util/Uuid.zig +++ b/src/util/Uuid.zig @@ -58,7 +58,7 @@ pub const ParseError = error{ }; pub fn parse(str: []const u8) ParseError!Uuid { - if (str.len != string_len) return error.InvalidLength; + if (str.len != string_len and (str.len != string_len + 1 or str[str.len - 1] != 0)) return error.InvalidLength; var data: [16]u8 = undefined; var str_i: usize = 0; diff --git a/src/util/lib.zig b/src/util/lib.zig index a5639db..cd2eddc 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -24,11 +24,11 @@ pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void { else => @compileError("Many and C-style pointers not supported by deepfree"), }, .Optional => if (val) |v| deepFree(alloc, v) else {}, - .Struct => |struct_info| for (struct_info.fields) |field| deepFree(alloc, @field(val, field.name)), + .Struct => |struct_info| inline for (struct_info.fields) |field| deepFree(alloc, @field(val, field.name)), .Union, .ErrorUnion => @compileError("TODO: Unions not yet supported by deepFree"), .Array => for (val) |v| deepFree(alloc, v), - .Int, .Float, .Bool, .Void, .Type => {}, + .Enum, .Int, .Float, .Bool, .Void, .Type => {}, else => @compileError("Type " ++ @typeName(T) ++ " not supported by deepFree"), } @@ -87,7 +87,7 @@ pub fn deepClone(alloc: std.mem.Allocator, val: anytype) !@TypeOf(val) { count += 1; } }, - .Int, .Float, .Bool, .Void, .Type => { + .Enum, .Int, .Float, .Bool, .Void, .Type => { result = val; }, else => @compileError("Type " ++ @typeName(T) ++ " not supported"),