diff --git a/src/main/db.zig b/src/main/db.zig index 9d49186..07417f7 100644 --- a/src/main/db.zig +++ b/src/main/db.zig @@ -3,6 +3,70 @@ const sql = @import("sql"); const models = @import("./models.zig"); const Uuid = @import("util").Uuid; +const String = []const u8; + +fn tableName(comptime T: type) String { + return switch (T) { + models.Note => "note", + models.User => "user", + else => unreachable, + }; +} + +fn join(comptime vals: anytype, comptime joiner: String) String { + comptime { + if (vals.len == 0) return ""; + + var result: String = ""; + for (vals) |v| { + result = std.fmt.comptimePrint("{s}{s}{s}", .{ result, joiner, v }); + } + + return result[2..]; + } +} + +const Query = struct { + select: []const String, + from: String, + where: String = "id = ?", + limit: usize = 1, + + pub fn str(comptime self: Query) String { + comptime { + return std.fmt.comptimePrint( + "SELECT {s} FROM {s} WHERE {s} LIMIT {};", + .{ join(self.select, ", "), self.from, self.where, self.limit }, + ); + } + } +}; + +fn filterOut(comptime vals: []const String, comptime to_ignore: []const String) []const String { + comptime { + var result: [vals.len]String = undefined; + var count = 0; + + for (vals) |v| { + const keep = for (to_ignore) |x| { + if (std.mem.eql(u8, x, v)) break false; + } else true; + + if (keep) { + result[count] = v; + count += 1; + } + } + + return result[0..count]; + } +} + +fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const String { + comptime { + return filterOut(std.meta.fieldNames(T), to_ignore); + } +} pub const Database = struct { db: sql.Sqlite, @@ -30,18 +94,38 @@ pub const Database = struct { self.db.close(); } - pub fn getNoteById(self: *Database, id: Uuid, alloc: std.mem.Allocator) !?models.Note { - var stmt = try self.db.prepare("SELECT content FROM note WHERE id = ? LIMIT 1;"); + pub fn getById(self: *Database, comptime T: type, id: Uuid, alloc: std.mem.Allocator) !?T { + const fields = comptime fieldsExcept(T, &.{"id"}); + const q = comptime (Query{ + .select = fields, + .from = tableName(T), + .where = "id = ?", + .limit = 1, + }).str(); + + var stmt = try self.db.prepare(q); defer stmt.finalize(); const id_str = id.toCharArray(); try stmt.bindText(1, &id_str); const row = (try stmt.step()) orelse return null; - return models.Note{ - .id = id, - .content = try row.getTextAlloc(0, alloc), - }; + var val: T = undefined; + val.id = id; + + inline for (fields) |f, i| { + @field(val, f) = switch (@TypeOf(@field(val, f))) { + // TODO: Handle allocation failures gracefully + []const u8 => row.getTextAlloc(i, alloc) catch unreachable, + else => @compileError("unknown type"), + }; + } + + return val; + } + + pub fn getNoteById(self: *Database, id: Uuid, alloc: std.mem.Allocator) !?models.Note { + return self.getById(models.Note, id, alloc); } pub fn insertNote(self: *Database, note: models.Note) !void {