diff --git a/src/main/db/query_builder.zig b/src/main/db/query_builder.zig index 6735e9d..f6147ed 100644 --- a/src/main/db/query_builder.zig +++ b/src/main/db/query_builder.zig @@ -16,7 +16,7 @@ fn baseTypeName(comptime T: type) []const u8 { } fn tableName(comptime T: type) String { - return util.case.pascalToSnake(baseTypeName(T)); + return comptime util.case.pascalToSnake(baseTypeName(T)); } // Represents a table bound to an identifier in a sql query @@ -29,6 +29,20 @@ pub const QueryTable = struct { const f = @as(std.meta.FieldEnum(self.Model), lit); return comptimePrint("{s}.{s}", .{ self.as, @tagName(f) }); } + + pub fn select(comptime self: QueryTable, comptime lit: @Type(.EnumLiteral)) ResultColumn { + return .{ + .@"type" = std.meta.fieldInfo(self.Model, lit).field_type, + .field = self.field(lit), + }; + } + + // returns the declaration to put in the FROM clause + fn declarationStr(comptime self: QueryTable) String { + comptime { + return comptimePrint("{s} AS {s}", .{ tableName(self.Model), self.as }); + } + } }; fn makeQueryTable(comptime Model: type, comptime table_index: usize) QueryTable { @@ -51,6 +65,16 @@ test "QueryTable.field" { try std.testing.expectEqualStrings("my_type.id", tbl.field(.id)); } +test "QueryTable.declarationStr" { + const MyTable = struct { id: i64 }; + const tbl = QueryTable{ + .Model = MyTable, + .as = "some_table", + }; + + try std.testing.expectEqualStrings("my_table AS some_table", tbl.declarationStr()); +} + test "queryTables constructor" { const MyTable = struct { id: i64 }; const MyOtherTable = struct { val: i64 }; @@ -63,9 +87,14 @@ test "queryTables constructor" { try std.testing.expectEqualStrings("my_other_table_1", qt[1].as); } -fn map(comptime T: type, comptime R: type, comptime vals: []const T, comptime func: fn (T, usize) R) *const [vals.len]R { +fn map(comptime T: type, comptime R: type, comptime vals: []const T, comptime func: anytype) *const [vals.len]R { var result: [vals.len]R = undefined; - for (vals) |v, i| result[i] = func(v, i); + if (@typeInfo(@TypeOf(func)).Fn.args.len == 2) { + for (vals) |v, i| result[i] = @as(R, func(v, i)); + } else { + for (vals) |v, i| result[i] = @as(R, func(v)); + } + return &result; } @@ -134,3 +163,70 @@ test "Condition.str()" { } }).str(), ); } + +const ResultColumn = struct { + @"type": type, + field: []const u8, + + pub fn toSelectClause(comptime self: ResultColumn) String { + return self.field; + } + + pub fn toStructField(comptime self: ResultColumn) std.builtin.Type.StructField { + return .{ + .name = self.field, + .field_type = self.@"type", + .default_value = null, + .is_comptime = false, + .alignment = 0, + }; + } +}; + +// Represents a full SQL query +pub const Query = struct { + tables: []const QueryTable, + fields: []const ResultColumn, + filter: Condition, + + pub fn str(comptime self: Query) String { + const table_aliases = map(QueryTable, String, self.tables, QueryTable.declarationStr); + const select_clauses = map(ResultColumn, String, self.fields, ResultColumn.toSelectClause); + return comptimePrint("SELECT {s} FROM {s} WHERE {s}", .{ join(select_clauses, ", "), join(table_aliases, ", "), self.filter.str() }); + } + + pub fn rowType(comptime self: Query) type { + const struct_fields = map(ResultColumn, std.builtin.Type.StructField, self.fields, ResultColumn.toStructField); + + return @Type(.{ .Struct = .{ + .layout = .Auto, + .fields = struct_fields, + .decls = &.{}, + .is_tuple = true, + } }); + } +}; + +test "Query" { + const MyTable = struct { id: i64 }; + const MyOtherTable = struct { + val: []const u8, + }; + const qt = queryTables(&.{ MyTable, MyOtherTable, MyTable }); + const q = comptime Query{ + .tables = qt, + .fields = &.{ qt[0].select(.id), qt[1].select(.val) }, + .filter = Condition{ .eql = .{ .lhs = qt[0].field(.id), .rhs = qt[2].field(.id) } }, + }; + + try std.testing.expectEqualStrings( + "SELECT my_table_0.id, my_other_table_1.val " ++ + "FROM my_table AS my_table_0, my_other_table AS my_other_table_1, my_table AS my_table_2 " ++ + "WHERE (my_table_0.id = my_table_2.id)", + comptime q.str(), + ); + + const fields = std.meta.fields(q.rowType()); + try std.testing.expectEqual(i64, fields[0].field_type); + try std.testing.expectEqual([]const u8, fields[1].field_type); +}