From 4bddb9f633ae4e4b7beed1695338945f1a64be9c Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Mon, 5 Sep 2022 00:03:31 -0700 Subject: [PATCH] fucking around w/ db stuff --- build.zig | 2 +- src/main/api.zig | 122 ++++++++++----- src/main/db.zig | 168 ++++++++++++-------- src/main/db/query_builder.zig | 284 ---------------------------------- 4 files changed, 187 insertions(+), 389 deletions(-) delete mode 100644 src/main/db/query_builder.zig diff --git a/build.zig b/build.zig index a2b54b1..e3f27cc 100644 --- a/build.zig +++ b/build.zig @@ -24,7 +24,7 @@ pub fn build(b: *std.build.Builder) void { // There are some weird problems relating to sentinel values and function pointers // when using the stage1 compiler. Just disable it entirely for now. - b.use_stage1 = false; + //b.use_stage1 = false; const exe = b.addExecutable("apub", "src/main/main.zig"); exe.setTarget(target); diff --git a/src/main/api.zig b/src/main/api.zig index 08f5a4c..09a4b60 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -127,23 +127,14 @@ pub const ApiSource = struct { var my_db = try db.Database.init(); { - const C = db.builder.Condition; - const qt = db.builder.queryTables(&.{ models.User, models.User, models.LocalUser, models.Invite }); - const UInviter = qt[0]; - const UInvitee = qt[1]; - const LUInvitee = qt[2]; - const Invite = qt[3]; - const q = comptime db.builder.Query - .from(qt) - .select(&.{ UInviter.select(.username), UInvitee.select(.username), Invite.select(.id) }) - .where(C.all(&.{ - C.eql(UInviter.field(.id), Invite.field(.created_by)), - C.eql(LUInvitee.field(.invite_id), Invite.field(.id)), - C.eql(LUInvitee.field(.user_id), UInvitee.field(.id)), - })); + const row = try my_db.execRow2( + &.{Uuid}, + "SELECT id FROM user WHERE username = ?", + .{"heartles"}, + null, + ); - const result = (try my_db.execRowQuery(q, alloc)) orelse unreachable; - std.log.debug("{s} invited {s}", .{ result[0], result[1] }); + std.log.debug("{s}", .{row.?[0]}); } return ApiSource{ @@ -157,8 +148,8 @@ pub const ApiSource = struct { pub fn connectUnauthorized(self: *ApiSource, host: ?[]const u8, alloc: std.mem.Allocator) !Conn { const community_id = blk: { if (host) |h| { - const community = try self.db.getBy(models.Community, .host, h, alloc); - if (community) |c| break :blk c.id; + const result = try self.db.execRow2(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{h}, null); + if (result) |r| break :blk r[0]; } break :blk null; @@ -187,7 +178,14 @@ pub const ApiSource = struct { models.Token.HashFn.hash(&decoded, &hash.data, .{}); const db_token = (try self.db.getBy(models.Token, .hash, hash, conn.arena.allocator())) orelse return error.InvalidToken; + //const token_result = (try self.db.execRow2( + //&.{Uuid}, + //"SELECT id FROM token WHERE hash = ?", + //.{hash}, + //null, + //)) orelse return error.InvalidToken; + //conn.as_user = token_result[0]; conn.as_user = db_token.user_id; return conn; @@ -332,20 +330,42 @@ fn ApiConn(comptime DbConn: type) type { const user_id = Uuid.randV4(prng.random()); // TODO: lock for transaction - if (try self.db.existsWhereEq(models.User, .username, info.username)) { + // TODO: not community aware :( + if (try self.db.execRow2(&.{}, "SELECT 1 FROM user WHERE username = ?", .{info.username}, null) != null) { + //if (try self.db.existsWhereEq(models.User, .username, info.username)) { return error.UsernameUnavailable; } const now = DateTime.now(); const invite_id = if (info.invite_code) |invite_code| blk: { - const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite; - const uses = try self.db.countWhereEq(models.LocalUser, .invite_id, invite.id); - const uses_left = if (invite.max_uses) |max_uses| uses < max_uses else true; - const expired = if (invite.expires_at) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false; + // TODO have this query also check for time-based expiration + const result = (try self.db.execRow2( + &.{ Uuid, ?DateTime }, + \\SELECT invite.id, invite.expires_at + \\FROM invite + \\ LEFT OUTER JOIN local_user ON invite.id = local_user.invite_id + \\WHERE invite.invite_code = ? + \\GROUP BY invite.id + \\HAVING + \\ (invite.max_uses IS NULL OR invite.max_uses > COUNT(local_user.user_id)) + \\ + , + .{invite_code}, + null, + )) orelse return error.InvalidInvite; - if (!uses_left or expired) return error.InvalidInvite; + const expired = if (result[1]) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false; + if (expired) return error.InvalidInvite; + + //const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite; + //const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite; + //const uses = try self.db.countWhereEq(models.LocalUser, .invite_id, invite.id); + //const uses_left = if (invite.max_uses) |max_uses| uses < max_uses else true; + //const expired = if (invite.expires_at) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false; + + //if (!uses_left or expired) return error.InvalidInvite; // TODO: increment uses - break :blk invite.id; + break :blk result[0]; } else null; // use internal alloc because necessary buffer is *big* @@ -354,8 +374,15 @@ fn ApiConn(comptime DbConn: type) type { const community_id = if (info.community_host) |host| blk: { //const id_tuple = (try self.db.execRow("select id from community where host = '?'", host, &.{Uuid}, self.arena.allocator())) orelse return error.CommunityNotFound; - const community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound; - break :blk community.id; + const community_result = (try self.db.execRow2( + &.{Uuid}, + "SELECT id FROM community WHERE host = ?", + .{host}, + null, + )) orelse return error.CommunityNotFound; + + //const community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound; + break :blk community_result[0]; //break :blk id_tuple[0]; } else null; @@ -385,23 +412,37 @@ fn ApiConn(comptime DbConn: type) type { pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult { // TODO: This gives away the existence of a user through a timing side channel. is that acceptable? - const user_info = (try self.db.getBy(models.User, .username, username, self.arena.allocator())) orelse return error.InvalidLogin; - const local_user_info = (try self.db.getBy(models.LocalUser, .user_id, user_info.id, self.arena.allocator())) orelse return error.InvalidLogin; + //const user_info = (try self.db.getBy(models.User, .username, username, self.arena.allocator())) orelse return error.InvalidLogin; + //const local_user_info = (try self.db.getBy(models.LocalUser, .user_id, user_info.id, self.arena.allocator())) orelse return error.InvalidLogin; + + const user_info = (try self.db.execRow2( + &.{ Uuid, []const u8 }, + \\SELECT user.id, local_user.hashed_password + \\FROM user JOIN local_user ON local_user.user_id = user.id + \\WHERE user.username = ? + , + .{username}, + self.arena.allocator(), + )) orelse return error.InvalidLogin; + + const user_id = user_info[0]; + const hashed_password = user_info[1]; + //defer free(self.arena.allocator(), user_info); const Hash = std.crypto.pwhash.scrypt; - Hash.strVerify(local_user_info.hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) { + Hash.strVerify(hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) { error.PasswordVerificationFailed => return error.InvalidLogin, else => return err, }; - const token = try self.createToken(user_info.id); + const token = try self.createToken(user_id); var token_enc: [token_str_len]u8 = undefined; _ = std.base64.standard.Encoder.encode(&token_enc, &token.value); return LoginResult{ - .user_id = user_info.id, + .user_id = user_id, .token = token_enc, .issued_at = token.info.issued_at, }; @@ -425,7 +466,7 @@ fn ApiConn(comptime DbConn: type) type { .issued_at = DateTime.now(), }; - try self.db.insert(models.Token, db_token); + try self.db.insert2("token", db_token); return TokenResult{ .info = db_token, .value = token, @@ -440,14 +481,21 @@ fn ApiConn(comptime DbConn: type) type { // Users can only make invites to their own community, unless they // are system users const community_id = if (options.to_community) |host| blk: { - const desired_community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound; - if (user.community_id != null and !Uuid.eql(desired_community.id, user.community_id.?)) { + const desired_community = (try self.db.execRow2( + &.{Uuid}, + "SELECT id FROM community WHERE host = ?", + .{host}, + null, + )) orelse return error.CommunityNotFound; + + if (user.community_id != null and !Uuid.eql(desired_community[0], user.community_id.?)) { return error.WrongCommunity; } - break :blk desired_community.id; + break :blk desired_community[0]; } else null; - if (user.community_id != null and options.to_community == null) { + + if (user.community_id != null and community_id == null) { return error.WrongCommunity; } diff --git a/src/main/db.zig b/src/main/db.zig index c5a0e16..2288b3d 100644 --- a/src/main/db.zig +++ b/src/main/db.zig @@ -9,8 +9,6 @@ const DateTime = util.DateTime; const String = []const u8; const comptimePrint = std.fmt.comptimePrint; -pub const builder = @import("./db/query_builder.zig"); - fn tableName(comptime T: type) String { return switch (T) { models.Note => "note", @@ -25,6 +23,44 @@ fn tableName(comptime T: type) String { }; } +fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple { + var result: RowTuple = undefined; + // TODO: undo allocations on failure + inline for (std.meta.fields(RowTuple)) |f, i| { + @field(result, f.name) = try getAlloc(row, f.field_type, i, allocator); + } + + return result; +} + +pub fn ResultSet(comptime result_types: []const type) type { + return struct { + pub const QueryError = anyerror; + pub const Row = std.meta.Tuple(result_types); + + _stmt: sql.PreparedStmt, + err: ?QueryError = null, + + pub fn finish(self: *@This()) void { + self._stmt.finalize(); + } + + pub fn row(self: *@This(), allocator: ?std.mem.Allocator) ?Row { + const sql_result = self._stmt.step() catch |err| { + self.err = err; + return null; + }; + + if (sql_result) |sql_row| { + return readRow(Row, sql_row, allocator) catch |err| { + self.err = err; + return null; + }; + } else return null; + } + }; +} + // Combines an array/tuple of strings into a single string, with a copy of // joiner in between each one fn join(comptime vals: anytype, comptime joiner: String) String { @@ -121,8 +157,9 @@ fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const St // pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void // TODO define what error set this ^ should return fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void { + if (comptime std.meta.trait.isZigString(@TypeOf(val))) return stmt.bindText(idx, val); + return switch (@TypeOf(val)) { - []u8, []const u8 => stmt.bindText(idx, val), i64 => stmt.bindI64(idx, val), Uuid => stmt.bindUuid(idx, val), DateTime => stmt.bindDateTime(idx, val), @@ -134,7 +171,8 @@ fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void { val.bindToSql(stmt, idx) else @compileError("unsupported type " ++ @typeName(T)), - else => @compileError("unsupported Type " ++ @typeName(T)), + else => unreachable, + //@compileError("unsupported type " ++ @typeName(T)), }, }; } @@ -144,9 +182,9 @@ fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void { // declaring a method with the given signature: // pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T // TODO define what error set this ^ should return -fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator) !T { +fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) !T { return switch (T) { - []u8, []const u8 => row.getTextAlloc(idx, alloc), + []u8, []const u8 => row.getTextAlloc(idx, alloc orelse return error.AllocatorRequired), i64 => row.getI64(idx), Uuid => row.getUuid(idx), DateTime => row.getDateTime(idx), @@ -158,11 +196,11 @@ fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator) try getAlloc(row, std.meta.Child(T), idx, alloc), .Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql")) - T.getFromSql(row, idx, alloc) + T.getFromSql(row, idx, alloc orelse return error.AllocatorRequired) else @compileError("unknown type " ++ @typeName(T)), - .Enum => try getEnum(row, T, idx, alloc), + .Enum => try getEnum(row, T, idx, alloc orelse return error.AllocatorRequired), else => @compileError("unknown type " ++ @typeName(T)), }, @@ -195,20 +233,65 @@ pub const Database = struct { self.db.close(); } - pub fn execRowQuery(self: *Database, comptime q: builder.Query, alloc: std.mem.Allocator) !?q.rowType() { - std.log.debug("executing sql:\n===\n{s}\n===", .{q.str()}); - var stmt = try self.db.prepare(q.str()); + pub fn exec2( + self: *Database, + comptime result_types: []const type, + comptime q: []const u8, + args: anytype, + ) !ResultSet(result_types) { + std.log.debug("executing sql:\n===\n{s}\n===", .{q}); + + const stmt = try self.db.prepare(q); errdefer stmt.finalize(); - const row = (try stmt.step()) orelse return null; - - std.log.debug("successful query", .{}); - var result: q.rowType() = undefined; - inline for (std.meta.fields(q.rowType())) |f, i| { - result[i] = try getAlloc(row, f.field_type, i, alloc); + inline for (std.meta.fields(@TypeOf(args))) |field, i| { + try bind(stmt, @intCast(u15, i + 1), @field(args, field.name)); } - return result; + return ResultSet(result_types){ + ._stmt = stmt, + }; + } + + pub fn execRow2( + self: *Database, + comptime result_types: []const type, + comptime q: []const u8, + args: anytype, + allocator: ?std.mem.Allocator, + ) !?ResultSet(result_types).Row { + var results = try self.exec2(result_types, q, args); + defer results.finish(); + + const row = results.row(allocator); + return row orelse (results.err orelse null); + } + + fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 { + comptime { + const joiner = ","; + var result: []const u8 = ""; + inline for (std.meta.fields(T)) |f| { + result = result ++ joiner ++ (placeholder orelse f.name); + } + + return "(" ++ result[joiner.len..] ++ ")"; + } + } + + pub fn insert2( + self: *Database, + comptime table: []const u8, + value: anytype, + ) !void { + const ValueType = comptime @TypeOf(value); + const table_spec = comptime table ++ build_field_list(ValueType, null); + const value_spec = comptime build_field_list(ValueType, "?"); + const q = comptime std.fmt.comptimePrint( + "INSERT INTO {s} VALUES {s}", + .{ table_spec, value_spec }, + ); + _ = try self.execRow2(&.{}, q, value, null); } // Lower level function @@ -306,55 +389,6 @@ pub const Database = struct { return results.toOwnedSlice(); } - // Returns the number of rows that satisfy an equality check on - // one of their fields - pub fn countWhereEq( - self: *Database, - comptime T: type, - comptime field: std.meta.FieldEnum(T), - val: std.meta.fieldInfo(T, field).field_type, - ) !usize { - const field_name = std.meta.fieldInfo(T, field).name; - const q = comptime (Query{ - .select = &.{"COUNT()"}, - .from = tableName(T), - .where = field_name ++ " = ?", - }).str(); - - var stmt = try self.db.prepare(q); - defer stmt.finalize(); - - try bind(stmt, 1, val); - - const row = (try stmt.step()) orelse unreachable; - return @intCast(usize, try row.getI64(0)); - } - - // Returns whether a row with the given value exists. - pub fn existsWhereEq( - self: *Database, - comptime T: type, - comptime field: std.meta.FieldEnum(T), - val: std.meta.fieldInfo(T, field).field_type, - ) !bool { - const field_name = std.meta.fieldInfo(T, field).name; - // TODO: don't like this query - const q = comptime (Query{ - .select = &.{"COUNT(1)"}, - .from = tableName(T), - .where = field_name ++ " = ?", - .limit = 1, - }).str(); - - var stmt = try self.db.prepare(q); - defer stmt.finalize(); - - try bind(stmt, 1, val); - - const row = (try stmt.step()) orelse unreachable; - return (try row.getI64(0)) > 0; - } - // Inserts a row into the database // TODO: consider making this generic? pub fn insert(self: *Database, comptime T: type, val: T) !void { diff --git a/src/main/db/query_builder.zig b/src/main/db/query_builder.zig deleted file mode 100644 index a651c6f..0000000 --- a/src/main/db/query_builder.zig +++ /dev/null @@ -1,284 +0,0 @@ -const std = @import("std"); -const util = @import("util"); -const builtin = @import("builtin"); - -const String = []const u8; -const comptimePrint = std.fmt.comptimePrint; - -fn baseTypeName(comptime T: type) []const u8 { - comptime { - const name = @typeName(T); - const start = for (name) |_, i| { - if (name[name.len - i] == '.') { - // This function has an off-by-one error in the self hosted compiler (-fno-stage1) - // The following code fixes it as of 2022-08-07 - // TODO: Figure out what's going on here - if (builtin.zig_backend == .stage1) { - break name.len - i; - } else { - break name.len - i + 1; - } - } - } else 0; - - return name[start..]; - } -} - -fn tableName(comptime T: type) String { - return comptime util.case.pascalToSnake(baseTypeName(T)); -} - -// Represents a table bound to an identifier in a sql query -pub const QueryTable = struct { - Model: type, - index: comptime_int, - - // Gets a fully qualified field from a literal - pub fn field(comptime self: QueryTable, comptime lit: @Type(.EnumLiteral)) String { - comptime { - const f = @as(std.meta.FieldEnum(self.Model), lit); - return comptimePrint("{s}.{s}", .{ self.as(), @tagName(f) }); - } - } - - pub fn select(comptime self: QueryTable, comptime lit: @Type(.EnumLiteral)) ResultColumn { - return .{ - .@"type" = std.meta.fieldInfo(self.Model, lit).field_type, - .field = self.field(lit), - }; - } - - // returns the declaration to put in the FROM clause - fn declarationStr(comptime self: QueryTable) String { - comptime { - return comptimePrint("{s} AS {s}", .{ tableName(self.Model), self.as() }); - } - } - - fn as(comptime self: QueryTable) String { - comptime { - return comptimePrint("{s}_{}", .{ tableName(self.Model), self.index }); - } - } -}; - -fn makeQueryTable(comptime Model: type, comptime table_index: usize) QueryTable { - return .{ .Model = Model, .index = table_index }; -} - -pub fn queryTables(comptime models: []const type) *const [models.len]QueryTable { - return map(type, QueryTable, models, makeQueryTable); -} - -test "QueryTable.declarationStr" { - const MyTable = struct { id: i64 }; - const tbl = QueryTable{ - .Model = MyTable, - .index = 0, - }; - - try std.testing.expectEqualStrings("my_table AS my_table_0", tbl.declarationStr()); - try std.testing.expectEqualStrings("my_table_0.id", tbl.field(.id)); -} - -test "queryTables constructor" { - const MyTable = struct { id: i64 }; - const MyOtherTable = struct { val: i64 }; - - const qt = queryTables(&.{ MyTable, MyOtherTable }); - - try std.testing.expectEqual(MyTable, qt[0].Model); - try std.testing.expectEqual(MyOtherTable, qt[1].Model); - try std.testing.expectEqualStrings("my_table_0", qt[0].as()); - try std.testing.expectEqualStrings("my_other_table_1", qt[1].as()); -} - -fn map(comptime T: type, comptime R: type, comptime vals: []const T, comptime func: anytype) *const [vals.len]R { - var result: [vals.len]R = undefined; - if (@typeInfo(@TypeOf(func)).Fn.args.len == 2) { - inline for (vals) |v, i| result[i] = @as(R, func(v, i)); - } else { - inline for (vals) |v, i| result[i] = @as(R, func(v)); - } - - return &result; -} - -// Combines an array/tuple of strings into a single string, with a copy of -// joiner in between each one -fn join(comptime vals: []const String, comptime joiner: String) String { - if (vals.len == 0) return ""; - - var result: String = ""; - for (vals) |v| { - result = comptimePrint("{s}{s}{s}", .{ result, joiner, v }); - } - - return result[joiner.len..]; -} - -// Stringifies and joins an array of conditions into a single string -fn joinConditions(comptime cs: []const Condition, comptime joiner: String) String { - var strs: [cs.len]String = undefined; - for (cs) |v, i| strs[i] = v.str(); - return join(&strs, joiner); -} - -// Represents a condition in a SQL statement -pub const Condition = union(enum) { - const BinaryOp = struct { - lhs: String, - rhs: String, - }; - - eql: BinaryOp, - is_null: String, - val: String, - not: *const Condition, - all: []const Condition, - any: []const Condition, - - fn str(comptime self: Condition) String { - comptime { - return comptimePrint("({s})", .{switch (self) { - .eql => |op| comptimePrint("{s} = {s}", .{ op.lhs, op.rhs }), - .is_null => |val| comptimePrint("{s} IS NULL", .{val}), - .val => |val| val, - .not => |c| comptimePrint("NOT {s}", .{c.str()}), - .all => |cs| joinConditions(cs, " AND "), - .any => |cs| joinConditions(cs, " OR "), - }}); - } - } - - pub fn eql(comptime lhs: String, comptime rhs: String) Condition { - return .{ - .eql = .{ .lhs = lhs, .rhs = rhs }, - }; - } - - pub fn all(comptime cs: []const Condition) Condition { - return .{ - .all = cs, - }; - } -}; - -test "Condition.str()" { - try std.testing.expectEqualStrings( - "((abc = def) AND (def = abc))", - (comptime Condition{ .all = &.{ - .{ .eql = .{ .lhs = "abc", .rhs = "def" } }, - .{ .eql = .{ .lhs = "def", .rhs = "abc" } }, - } }).str(), - ); - - try std.testing.expectEqualStrings( - "((abc IS NULL) OR (NOT (def)))", - (comptime Condition{ .any = &.{ - .{ .is_null = "abc" }, - .{ .not = &.{ .val = "def" } }, - } }).str(), - ); -} - -const ResultColumn = struct { - @"type": type, - field: []const u8, - - pub fn toSelectClause(comptime self: ResultColumn) String { - return self.field; - } - - pub fn toStructField(comptime self: ResultColumn, comptime index: usize) std.builtin.Type.StructField { - return .{ - .name = comptimePrint("{}", .{index}), - .field_type = self.@"type", - .default_value = null, - .is_comptime = false, - .alignment = 0, - }; - } -}; - -// Represents a full SQL query -pub const Query = struct { - tables: []const QueryTable, - fields: []const ResultColumn, - filter: Condition, - - pub fn from(comptime tables: []const QueryTable) Query { - return .{ - .tables = tables, - .fields = &.{}, - .filter = .{ .val = "TRUE" }, // TODO - }; - } - - pub fn str(comptime self: Query) String { - comptime { - const table_aliases = map(QueryTable, String, self.tables, QueryTable.declarationStr); - const select_clauses = map(ResultColumn, String, self.fields, ResultColumn.toSelectClause); - return comptimePrint("SELECT {s} FROM {s} WHERE {s}", .{ join(select_clauses, ", "), join(table_aliases, ", "), self.filter.str() }); - } - } - - pub fn rowType(comptime self: *const Query) type { - const struct_fields = map(ResultColumn, std.builtin.Type.StructField, self.fields, ResultColumn.toStructField); - - return @Type(.{ .Struct = .{ - .layout = .Auto, - .fields = struct_fields, - .decls = &.{}, - .is_tuple = true, - } }); - } - - pub fn select(comptime self: Query, comptime fields: []const ResultColumn) Query { - return .{ - .tables = self.tables, - .fields = fields, - .filter = self.filter, - }; - } - - pub fn where(comptime self: Query, comptime condition: Condition) Query { - return .{ - .tables = self.tables, - .fields = self.fields, - .filter = condition, - }; - } -}; - -test "Query" { - const C = Condition; - const MyTable = struct { id: i64 }; - const MyOtherTable = struct { - val: []const u8, - }; - const qt = queryTables(&.{ MyTable, MyOtherTable, MyTable }); - const t1 = qt[0]; - const t2 = qt[2]; - const t_other = qt[1]; - - const q = comptime Query - .from(qt) - .select(&.{ t1.select(.id), t_other.select(.val) }) - .where(C.all(&.{ - C.eql(t1.field(.id), t2.field(.id)), - C.eql(t1.field(.id), t2.field(.id)), - })); - - try std.testing.expectEqualStrings( - "SELECT my_table_0.id, my_other_table_1.val " ++ - "FROM my_table AS my_table_0, my_other_table AS my_other_table_1, my_table AS my_table_2 " ++ - "WHERE ((my_table_0.id = my_table_2.id) AND (my_table_0.id = my_table_2.id))", - comptime q.str(), - ); - - const fields = std.meta.fields(q.rowType()); - try std.testing.expectEqual(i64, fields[0].field_type); - try std.testing.expectEqual([]const u8, fields[1].field_type); -}