fucking around w/ db stuff
This commit is contained in:
parent
d694674585
commit
4bddb9f633
|
@ -24,7 +24,7 @@ pub fn build(b: *std.build.Builder) void {
|
|||
|
||||
// There are some weird problems relating to sentinel values and function pointers
|
||||
// when using the stage1 compiler. Just disable it entirely for now.
|
||||
b.use_stage1 = false;
|
||||
//b.use_stage1 = false;
|
||||
|
||||
const exe = b.addExecutable("apub", "src/main/main.zig");
|
||||
exe.setTarget(target);
|
||||
|
|
122
src/main/api.zig
122
src/main/api.zig
|
@ -127,23 +127,14 @@ pub const ApiSource = struct {
|
|||
var my_db = try db.Database.init();
|
||||
|
||||
{
|
||||
const C = db.builder.Condition;
|
||||
const qt = db.builder.queryTables(&.{ models.User, models.User, models.LocalUser, models.Invite });
|
||||
const UInviter = qt[0];
|
||||
const UInvitee = qt[1];
|
||||
const LUInvitee = qt[2];
|
||||
const Invite = qt[3];
|
||||
const q = comptime db.builder.Query
|
||||
.from(qt)
|
||||
.select(&.{ UInviter.select(.username), UInvitee.select(.username), Invite.select(.id) })
|
||||
.where(C.all(&.{
|
||||
C.eql(UInviter.field(.id), Invite.field(.created_by)),
|
||||
C.eql(LUInvitee.field(.invite_id), Invite.field(.id)),
|
||||
C.eql(LUInvitee.field(.user_id), UInvitee.field(.id)),
|
||||
}));
|
||||
const row = try my_db.execRow2(
|
||||
&.{Uuid},
|
||||
"SELECT id FROM user WHERE username = ?",
|
||||
.{"heartles"},
|
||||
null,
|
||||
);
|
||||
|
||||
const result = (try my_db.execRowQuery(q, alloc)) orelse unreachable;
|
||||
std.log.debug("{s} invited {s}", .{ result[0], result[1] });
|
||||
std.log.debug("{s}", .{row.?[0]});
|
||||
}
|
||||
|
||||
return ApiSource{
|
||||
|
@ -157,8 +148,8 @@ pub const ApiSource = struct {
|
|||
pub fn connectUnauthorized(self: *ApiSource, host: ?[]const u8, alloc: std.mem.Allocator) !Conn {
|
||||
const community_id = blk: {
|
||||
if (host) |h| {
|
||||
const community = try self.db.getBy(models.Community, .host, h, alloc);
|
||||
if (community) |c| break :blk c.id;
|
||||
const result = try self.db.execRow2(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{h}, null);
|
||||
if (result) |r| break :blk r[0];
|
||||
}
|
||||
|
||||
break :blk null;
|
||||
|
@ -187,7 +178,14 @@ pub const ApiSource = struct {
|
|||
models.Token.HashFn.hash(&decoded, &hash.data, .{});
|
||||
|
||||
const db_token = (try self.db.getBy(models.Token, .hash, hash, conn.arena.allocator())) orelse return error.InvalidToken;
|
||||
//const token_result = (try self.db.execRow2(
|
||||
//&.{Uuid},
|
||||
//"SELECT id FROM token WHERE hash = ?",
|
||||
//.{hash},
|
||||
//null,
|
||||
//)) orelse return error.InvalidToken;
|
||||
|
||||
//conn.as_user = token_result[0];
|
||||
conn.as_user = db_token.user_id;
|
||||
|
||||
return conn;
|
||||
|
@ -332,20 +330,42 @@ fn ApiConn(comptime DbConn: type) type {
|
|||
const user_id = Uuid.randV4(prng.random());
|
||||
// TODO: lock for transaction
|
||||
|
||||
if (try self.db.existsWhereEq(models.User, .username, info.username)) {
|
||||
// TODO: not community aware :(
|
||||
if (try self.db.execRow2(&.{}, "SELECT 1 FROM user WHERE username = ?", .{info.username}, null) != null) {
|
||||
//if (try self.db.existsWhereEq(models.User, .username, info.username)) {
|
||||
return error.UsernameUnavailable;
|
||||
}
|
||||
|
||||
const now = DateTime.now();
|
||||
const invite_id = if (info.invite_code) |invite_code| blk: {
|
||||
const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite;
|
||||
const uses = try self.db.countWhereEq(models.LocalUser, .invite_id, invite.id);
|
||||
const uses_left = if (invite.max_uses) |max_uses| uses < max_uses else true;
|
||||
const expired = if (invite.expires_at) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false;
|
||||
// TODO have this query also check for time-based expiration
|
||||
const result = (try self.db.execRow2(
|
||||
&.{ Uuid, ?DateTime },
|
||||
\\SELECT invite.id, invite.expires_at
|
||||
\\FROM invite
|
||||
\\ LEFT OUTER JOIN local_user ON invite.id = local_user.invite_id
|
||||
\\WHERE invite.invite_code = ?
|
||||
\\GROUP BY invite.id
|
||||
\\HAVING
|
||||
\\ (invite.max_uses IS NULL OR invite.max_uses > COUNT(local_user.user_id))
|
||||
\\
|
||||
,
|
||||
.{invite_code},
|
||||
null,
|
||||
)) orelse return error.InvalidInvite;
|
||||
|
||||
if (!uses_left or expired) return error.InvalidInvite;
|
||||
const expired = if (result[1]) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false;
|
||||
if (expired) return error.InvalidInvite;
|
||||
|
||||
//const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite;
|
||||
//const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite;
|
||||
//const uses = try self.db.countWhereEq(models.LocalUser, .invite_id, invite.id);
|
||||
//const uses_left = if (invite.max_uses) |max_uses| uses < max_uses else true;
|
||||
//const expired = if (invite.expires_at) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false;
|
||||
|
||||
//if (!uses_left or expired) return error.InvalidInvite;
|
||||
// TODO: increment uses
|
||||
break :blk invite.id;
|
||||
break :blk result[0];
|
||||
} else null;
|
||||
|
||||
// use internal alloc because necessary buffer is *big*
|
||||
|
@ -354,8 +374,15 @@ fn ApiConn(comptime DbConn: type) type {
|
|||
|
||||
const community_id = if (info.community_host) |host| blk: {
|
||||
//const id_tuple = (try self.db.execRow("select id from community where host = '?'", host, &.{Uuid}, self.arena.allocator())) orelse return error.CommunityNotFound;
|
||||
const community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound;
|
||||
break :blk community.id;
|
||||
const community_result = (try self.db.execRow2(
|
||||
&.{Uuid},
|
||||
"SELECT id FROM community WHERE host = ?",
|
||||
.{host},
|
||||
null,
|
||||
)) orelse return error.CommunityNotFound;
|
||||
|
||||
//const community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound;
|
||||
break :blk community_result[0];
|
||||
//break :blk id_tuple[0];
|
||||
} else null;
|
||||
|
||||
|
@ -385,23 +412,37 @@ fn ApiConn(comptime DbConn: type) type {
|
|||
|
||||
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult {
|
||||
// TODO: This gives away the existence of a user through a timing side channel. is that acceptable?
|
||||
const user_info = (try self.db.getBy(models.User, .username, username, self.arena.allocator())) orelse return error.InvalidLogin;
|
||||
const local_user_info = (try self.db.getBy(models.LocalUser, .user_id, user_info.id, self.arena.allocator())) orelse return error.InvalidLogin;
|
||||
//const user_info = (try self.db.getBy(models.User, .username, username, self.arena.allocator())) orelse return error.InvalidLogin;
|
||||
//const local_user_info = (try self.db.getBy(models.LocalUser, .user_id, user_info.id, self.arena.allocator())) orelse return error.InvalidLogin;
|
||||
|
||||
const user_info = (try self.db.execRow2(
|
||||
&.{ Uuid, []const u8 },
|
||||
\\SELECT user.id, local_user.hashed_password
|
||||
\\FROM user JOIN local_user ON local_user.user_id = user.id
|
||||
\\WHERE user.username = ?
|
||||
,
|
||||
.{username},
|
||||
self.arena.allocator(),
|
||||
)) orelse return error.InvalidLogin;
|
||||
|
||||
const user_id = user_info[0];
|
||||
const hashed_password = user_info[1];
|
||||
|
||||
//defer free(self.arena.allocator(), user_info);
|
||||
|
||||
const Hash = std.crypto.pwhash.scrypt;
|
||||
Hash.strVerify(local_user_info.hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) {
|
||||
Hash.strVerify(hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) {
|
||||
error.PasswordVerificationFailed => return error.InvalidLogin,
|
||||
else => return err,
|
||||
};
|
||||
|
||||
const token = try self.createToken(user_info.id);
|
||||
const token = try self.createToken(user_id);
|
||||
|
||||
var token_enc: [token_str_len]u8 = undefined;
|
||||
_ = std.base64.standard.Encoder.encode(&token_enc, &token.value);
|
||||
|
||||
return LoginResult{
|
||||
.user_id = user_info.id,
|
||||
.user_id = user_id,
|
||||
.token = token_enc,
|
||||
.issued_at = token.info.issued_at,
|
||||
};
|
||||
|
@ -425,7 +466,7 @@ fn ApiConn(comptime DbConn: type) type {
|
|||
.issued_at = DateTime.now(),
|
||||
};
|
||||
|
||||
try self.db.insert(models.Token, db_token);
|
||||
try self.db.insert2("token", db_token);
|
||||
return TokenResult{
|
||||
.info = db_token,
|
||||
.value = token,
|
||||
|
@ -440,14 +481,21 @@ fn ApiConn(comptime DbConn: type) type {
|
|||
// Users can only make invites to their own community, unless they
|
||||
// are system users
|
||||
const community_id = if (options.to_community) |host| blk: {
|
||||
const desired_community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound;
|
||||
if (user.community_id != null and !Uuid.eql(desired_community.id, user.community_id.?)) {
|
||||
const desired_community = (try self.db.execRow2(
|
||||
&.{Uuid},
|
||||
"SELECT id FROM community WHERE host = ?",
|
||||
.{host},
|
||||
null,
|
||||
)) orelse return error.CommunityNotFound;
|
||||
|
||||
if (user.community_id != null and !Uuid.eql(desired_community[0], user.community_id.?)) {
|
||||
return error.WrongCommunity;
|
||||
}
|
||||
|
||||
break :blk desired_community.id;
|
||||
break :blk desired_community[0];
|
||||
} else null;
|
||||
if (user.community_id != null and options.to_community == null) {
|
||||
|
||||
if (user.community_id != null and community_id == null) {
|
||||
return error.WrongCommunity;
|
||||
}
|
||||
|
||||
|
|
168
src/main/db.zig
168
src/main/db.zig
|
@ -9,8 +9,6 @@ const DateTime = util.DateTime;
|
|||
const String = []const u8;
|
||||
const comptimePrint = std.fmt.comptimePrint;
|
||||
|
||||
pub const builder = @import("./db/query_builder.zig");
|
||||
|
||||
fn tableName(comptime T: type) String {
|
||||
return switch (T) {
|
||||
models.Note => "note",
|
||||
|
@ -25,6 +23,44 @@ fn tableName(comptime T: type) String {
|
|||
};
|
||||
}
|
||||
|
||||
fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple {
|
||||
var result: RowTuple = undefined;
|
||||
// TODO: undo allocations on failure
|
||||
inline for (std.meta.fields(RowTuple)) |f, i| {
|
||||
@field(result, f.name) = try getAlloc(row, f.field_type, i, allocator);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn ResultSet(comptime result_types: []const type) type {
|
||||
return struct {
|
||||
pub const QueryError = anyerror;
|
||||
pub const Row = std.meta.Tuple(result_types);
|
||||
|
||||
_stmt: sql.PreparedStmt,
|
||||
err: ?QueryError = null,
|
||||
|
||||
pub fn finish(self: *@This()) void {
|
||||
self._stmt.finalize();
|
||||
}
|
||||
|
||||
pub fn row(self: *@This(), allocator: ?std.mem.Allocator) ?Row {
|
||||
const sql_result = self._stmt.step() catch |err| {
|
||||
self.err = err;
|
||||
return null;
|
||||
};
|
||||
|
||||
if (sql_result) |sql_row| {
|
||||
return readRow(Row, sql_row, allocator) catch |err| {
|
||||
self.err = err;
|
||||
return null;
|
||||
};
|
||||
} else return null;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Combines an array/tuple of strings into a single string, with a copy of
|
||||
// joiner in between each one
|
||||
fn join(comptime vals: anytype, comptime joiner: String) String {
|
||||
|
@ -121,8 +157,9 @@ fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const St
|
|||
// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
|
||||
// TODO define what error set this ^ should return
|
||||
fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
|
||||
if (comptime std.meta.trait.isZigString(@TypeOf(val))) return stmt.bindText(idx, val);
|
||||
|
||||
return switch (@TypeOf(val)) {
|
||||
[]u8, []const u8 => stmt.bindText(idx, val),
|
||||
i64 => stmt.bindI64(idx, val),
|
||||
Uuid => stmt.bindUuid(idx, val),
|
||||
DateTime => stmt.bindDateTime(idx, val),
|
||||
|
@ -134,7 +171,8 @@ fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
|
|||
val.bindToSql(stmt, idx)
|
||||
else
|
||||
@compileError("unsupported type " ++ @typeName(T)),
|
||||
else => @compileError("unsupported Type " ++ @typeName(T)),
|
||||
else => unreachable,
|
||||
//@compileError("unsupported type " ++ @typeName(T)),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
@ -144,9 +182,9 @@ fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
|
|||
// declaring a method with the given signature:
|
||||
// pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T
|
||||
// TODO define what error set this ^ should return
|
||||
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator) !T {
|
||||
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) !T {
|
||||
return switch (T) {
|
||||
[]u8, []const u8 => row.getTextAlloc(idx, alloc),
|
||||
[]u8, []const u8 => row.getTextAlloc(idx, alloc orelse return error.AllocatorRequired),
|
||||
i64 => row.getI64(idx),
|
||||
Uuid => row.getUuid(idx),
|
||||
DateTime => row.getDateTime(idx),
|
||||
|
@ -158,11 +196,11 @@ fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator)
|
|||
try getAlloc(row, std.meta.Child(T), idx, alloc),
|
||||
|
||||
.Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql"))
|
||||
T.getFromSql(row, idx, alloc)
|
||||
T.getFromSql(row, idx, alloc orelse return error.AllocatorRequired)
|
||||
else
|
||||
@compileError("unknown type " ++ @typeName(T)),
|
||||
|
||||
.Enum => try getEnum(row, T, idx, alloc),
|
||||
.Enum => try getEnum(row, T, idx, alloc orelse return error.AllocatorRequired),
|
||||
|
||||
else => @compileError("unknown type " ++ @typeName(T)),
|
||||
},
|
||||
|
@ -195,20 +233,65 @@ pub const Database = struct {
|
|||
self.db.close();
|
||||
}
|
||||
|
||||
pub fn execRowQuery(self: *Database, comptime q: builder.Query, alloc: std.mem.Allocator) !?q.rowType() {
|
||||
std.log.debug("executing sql:\n===\n{s}\n===", .{q.str()});
|
||||
var stmt = try self.db.prepare(q.str());
|
||||
pub fn exec2(
|
||||
self: *Database,
|
||||
comptime result_types: []const type,
|
||||
comptime q: []const u8,
|
||||
args: anytype,
|
||||
) !ResultSet(result_types) {
|
||||
std.log.debug("executing sql:\n===\n{s}\n===", .{q});
|
||||
|
||||
const stmt = try self.db.prepare(q);
|
||||
errdefer stmt.finalize();
|
||||
|
||||
const row = (try stmt.step()) orelse return null;
|
||||
|
||||
std.log.debug("successful query", .{});
|
||||
var result: q.rowType() = undefined;
|
||||
inline for (std.meta.fields(q.rowType())) |f, i| {
|
||||
result[i] = try getAlloc(row, f.field_type, i, alloc);
|
||||
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
|
||||
try bind(stmt, @intCast(u15, i + 1), @field(args, field.name));
|
||||
}
|
||||
|
||||
return result;
|
||||
return ResultSet(result_types){
|
||||
._stmt = stmt,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn execRow2(
|
||||
self: *Database,
|
||||
comptime result_types: []const type,
|
||||
comptime q: []const u8,
|
||||
args: anytype,
|
||||
allocator: ?std.mem.Allocator,
|
||||
) !?ResultSet(result_types).Row {
|
||||
var results = try self.exec2(result_types, q, args);
|
||||
defer results.finish();
|
||||
|
||||
const row = results.row(allocator);
|
||||
return row orelse (results.err orelse null);
|
||||
}
|
||||
|
||||
fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
|
||||
comptime {
|
||||
const joiner = ",";
|
||||
var result: []const u8 = "";
|
||||
inline for (std.meta.fields(T)) |f| {
|
||||
result = result ++ joiner ++ (placeholder orelse f.name);
|
||||
}
|
||||
|
||||
return "(" ++ result[joiner.len..] ++ ")";
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert2(
|
||||
self: *Database,
|
||||
comptime table: []const u8,
|
||||
value: anytype,
|
||||
) !void {
|
||||
const ValueType = comptime @TypeOf(value);
|
||||
const table_spec = comptime table ++ build_field_list(ValueType, null);
|
||||
const value_spec = comptime build_field_list(ValueType, "?");
|
||||
const q = comptime std.fmt.comptimePrint(
|
||||
"INSERT INTO {s} VALUES {s}",
|
||||
.{ table_spec, value_spec },
|
||||
);
|
||||
_ = try self.execRow2(&.{}, q, value, null);
|
||||
}
|
||||
|
||||
// Lower level function
|
||||
|
@ -306,55 +389,6 @@ pub const Database = struct {
|
|||
return results.toOwnedSlice();
|
||||
}
|
||||
|
||||
// Returns the number of rows that satisfy an equality check on
|
||||
// one of their fields
|
||||
pub fn countWhereEq(
|
||||
self: *Database,
|
||||
comptime T: type,
|
||||
comptime field: std.meta.FieldEnum(T),
|
||||
val: std.meta.fieldInfo(T, field).field_type,
|
||||
) !usize {
|
||||
const field_name = std.meta.fieldInfo(T, field).name;
|
||||
const q = comptime (Query{
|
||||
.select = &.{"COUNT()"},
|
||||
.from = tableName(T),
|
||||
.where = field_name ++ " = ?",
|
||||
}).str();
|
||||
|
||||
var stmt = try self.db.prepare(q);
|
||||
defer stmt.finalize();
|
||||
|
||||
try bind(stmt, 1, val);
|
||||
|
||||
const row = (try stmt.step()) orelse unreachable;
|
||||
return @intCast(usize, try row.getI64(0));
|
||||
}
|
||||
|
||||
// Returns whether a row with the given value exists.
|
||||
pub fn existsWhereEq(
|
||||
self: *Database,
|
||||
comptime T: type,
|
||||
comptime field: std.meta.FieldEnum(T),
|
||||
val: std.meta.fieldInfo(T, field).field_type,
|
||||
) !bool {
|
||||
const field_name = std.meta.fieldInfo(T, field).name;
|
||||
// TODO: don't like this query
|
||||
const q = comptime (Query{
|
||||
.select = &.{"COUNT(1)"},
|
||||
.from = tableName(T),
|
||||
.where = field_name ++ " = ?",
|
||||
.limit = 1,
|
||||
}).str();
|
||||
|
||||
var stmt = try self.db.prepare(q);
|
||||
defer stmt.finalize();
|
||||
|
||||
try bind(stmt, 1, val);
|
||||
|
||||
const row = (try stmt.step()) orelse unreachable;
|
||||
return (try row.getI64(0)) > 0;
|
||||
}
|
||||
|
||||
// Inserts a row into the database
|
||||
// TODO: consider making this generic?
|
||||
pub fn insert(self: *Database, comptime T: type, val: T) !void {
|
||||
|
|
|
@ -1,284 +0,0 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const String = []const u8;
|
||||
const comptimePrint = std.fmt.comptimePrint;
|
||||
|
||||
fn baseTypeName(comptime T: type) []const u8 {
|
||||
comptime {
|
||||
const name = @typeName(T);
|
||||
const start = for (name) |_, i| {
|
||||
if (name[name.len - i] == '.') {
|
||||
// This function has an off-by-one error in the self hosted compiler (-fno-stage1)
|
||||
// The following code fixes it as of 2022-08-07
|
||||
// TODO: Figure out what's going on here
|
||||
if (builtin.zig_backend == .stage1) {
|
||||
break name.len - i;
|
||||
} else {
|
||||
break name.len - i + 1;
|
||||
}
|
||||
}
|
||||
} else 0;
|
||||
|
||||
return name[start..];
|
||||
}
|
||||
}
|
||||
|
||||
fn tableName(comptime T: type) String {
|
||||
return comptime util.case.pascalToSnake(baseTypeName(T));
|
||||
}
|
||||
|
||||
// Represents a table bound to an identifier in a sql query
|
||||
pub const QueryTable = struct {
|
||||
Model: type,
|
||||
index: comptime_int,
|
||||
|
||||
// Gets a fully qualified field from a literal
|
||||
pub fn field(comptime self: QueryTable, comptime lit: @Type(.EnumLiteral)) String {
|
||||
comptime {
|
||||
const f = @as(std.meta.FieldEnum(self.Model), lit);
|
||||
return comptimePrint("{s}.{s}", .{ self.as(), @tagName(f) });
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select(comptime self: QueryTable, comptime lit: @Type(.EnumLiteral)) ResultColumn {
|
||||
return .{
|
||||
.@"type" = std.meta.fieldInfo(self.Model, lit).field_type,
|
||||
.field = self.field(lit),
|
||||
};
|
||||
}
|
||||
|
||||
// returns the declaration to put in the FROM clause
|
||||
fn declarationStr(comptime self: QueryTable) String {
|
||||
comptime {
|
||||
return comptimePrint("{s} AS {s}", .{ tableName(self.Model), self.as() });
|
||||
}
|
||||
}
|
||||
|
||||
fn as(comptime self: QueryTable) String {
|
||||
comptime {
|
||||
return comptimePrint("{s}_{}", .{ tableName(self.Model), self.index });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
fn makeQueryTable(comptime Model: type, comptime table_index: usize) QueryTable {
|
||||
return .{ .Model = Model, .index = table_index };
|
||||
}
|
||||
|
||||
pub fn queryTables(comptime models: []const type) *const [models.len]QueryTable {
|
||||
return map(type, QueryTable, models, makeQueryTable);
|
||||
}
|
||||
|
||||
test "QueryTable.declarationStr" {
|
||||
const MyTable = struct { id: i64 };
|
||||
const tbl = QueryTable{
|
||||
.Model = MyTable,
|
||||
.index = 0,
|
||||
};
|
||||
|
||||
try std.testing.expectEqualStrings("my_table AS my_table_0", tbl.declarationStr());
|
||||
try std.testing.expectEqualStrings("my_table_0.id", tbl.field(.id));
|
||||
}
|
||||
|
||||
test "queryTables constructor" {
|
||||
const MyTable = struct { id: i64 };
|
||||
const MyOtherTable = struct { val: i64 };
|
||||
|
||||
const qt = queryTables(&.{ MyTable, MyOtherTable });
|
||||
|
||||
try std.testing.expectEqual(MyTable, qt[0].Model);
|
||||
try std.testing.expectEqual(MyOtherTable, qt[1].Model);
|
||||
try std.testing.expectEqualStrings("my_table_0", qt[0].as());
|
||||
try std.testing.expectEqualStrings("my_other_table_1", qt[1].as());
|
||||
}
|
||||
|
||||
fn map(comptime T: type, comptime R: type, comptime vals: []const T, comptime func: anytype) *const [vals.len]R {
|
||||
var result: [vals.len]R = undefined;
|
||||
if (@typeInfo(@TypeOf(func)).Fn.args.len == 2) {
|
||||
inline for (vals) |v, i| result[i] = @as(R, func(v, i));
|
||||
} else {
|
||||
inline for (vals) |v, i| result[i] = @as(R, func(v));
|
||||
}
|
||||
|
||||
return &result;
|
||||
}
|
||||
|
||||
// Combines an array/tuple of strings into a single string, with a copy of
|
||||
// joiner in between each one
|
||||
fn join(comptime vals: []const String, comptime joiner: String) String {
|
||||
if (vals.len == 0) return "";
|
||||
|
||||
var result: String = "";
|
||||
for (vals) |v| {
|
||||
result = comptimePrint("{s}{s}{s}", .{ result, joiner, v });
|
||||
}
|
||||
|
||||
return result[joiner.len..];
|
||||
}
|
||||
|
||||
// Stringifies and joins an array of conditions into a single string
|
||||
fn joinConditions(comptime cs: []const Condition, comptime joiner: String) String {
|
||||
var strs: [cs.len]String = undefined;
|
||||
for (cs) |v, i| strs[i] = v.str();
|
||||
return join(&strs, joiner);
|
||||
}
|
||||
|
||||
// Represents a condition in a SQL statement
|
||||
pub const Condition = union(enum) {
|
||||
const BinaryOp = struct {
|
||||
lhs: String,
|
||||
rhs: String,
|
||||
};
|
||||
|
||||
eql: BinaryOp,
|
||||
is_null: String,
|
||||
val: String,
|
||||
not: *const Condition,
|
||||
all: []const Condition,
|
||||
any: []const Condition,
|
||||
|
||||
fn str(comptime self: Condition) String {
|
||||
comptime {
|
||||
return comptimePrint("({s})", .{switch (self) {
|
||||
.eql => |op| comptimePrint("{s} = {s}", .{ op.lhs, op.rhs }),
|
||||
.is_null => |val| comptimePrint("{s} IS NULL", .{val}),
|
||||
.val => |val| val,
|
||||
.not => |c| comptimePrint("NOT {s}", .{c.str()}),
|
||||
.all => |cs| joinConditions(cs, " AND "),
|
||||
.any => |cs| joinConditions(cs, " OR "),
|
||||
}});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn eql(comptime lhs: String, comptime rhs: String) Condition {
|
||||
return .{
|
||||
.eql = .{ .lhs = lhs, .rhs = rhs },
|
||||
};
|
||||
}
|
||||
|
||||
pub fn all(comptime cs: []const Condition) Condition {
|
||||
return .{
|
||||
.all = cs,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
test "Condition.str()" {
|
||||
try std.testing.expectEqualStrings(
|
||||
"((abc = def) AND (def = abc))",
|
||||
(comptime Condition{ .all = &.{
|
||||
.{ .eql = .{ .lhs = "abc", .rhs = "def" } },
|
||||
.{ .eql = .{ .lhs = "def", .rhs = "abc" } },
|
||||
} }).str(),
|
||||
);
|
||||
|
||||
try std.testing.expectEqualStrings(
|
||||
"((abc IS NULL) OR (NOT (def)))",
|
||||
(comptime Condition{ .any = &.{
|
||||
.{ .is_null = "abc" },
|
||||
.{ .not = &.{ .val = "def" } },
|
||||
} }).str(),
|
||||
);
|
||||
}
|
||||
|
||||
const ResultColumn = struct {
|
||||
@"type": type,
|
||||
field: []const u8,
|
||||
|
||||
pub fn toSelectClause(comptime self: ResultColumn) String {
|
||||
return self.field;
|
||||
}
|
||||
|
||||
pub fn toStructField(comptime self: ResultColumn, comptime index: usize) std.builtin.Type.StructField {
|
||||
return .{
|
||||
.name = comptimePrint("{}", .{index}),
|
||||
.field_type = self.@"type",
|
||||
.default_value = null,
|
||||
.is_comptime = false,
|
||||
.alignment = 0,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Represents a full SQL query
|
||||
pub const Query = struct {
|
||||
tables: []const QueryTable,
|
||||
fields: []const ResultColumn,
|
||||
filter: Condition,
|
||||
|
||||
pub fn from(comptime tables: []const QueryTable) Query {
|
||||
return .{
|
||||
.tables = tables,
|
||||
.fields = &.{},
|
||||
.filter = .{ .val = "TRUE" }, // TODO
|
||||
};
|
||||
}
|
||||
|
||||
pub fn str(comptime self: Query) String {
|
||||
comptime {
|
||||
const table_aliases = map(QueryTable, String, self.tables, QueryTable.declarationStr);
|
||||
const select_clauses = map(ResultColumn, String, self.fields, ResultColumn.toSelectClause);
|
||||
return comptimePrint("SELECT {s} FROM {s} WHERE {s}", .{ join(select_clauses, ", "), join(table_aliases, ", "), self.filter.str() });
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rowType(comptime self: *const Query) type {
|
||||
const struct_fields = map(ResultColumn, std.builtin.Type.StructField, self.fields, ResultColumn.toStructField);
|
||||
|
||||
return @Type(.{ .Struct = .{
|
||||
.layout = .Auto,
|
||||
.fields = struct_fields,
|
||||
.decls = &.{},
|
||||
.is_tuple = true,
|
||||
} });
|
||||
}
|
||||
|
||||
pub fn select(comptime self: Query, comptime fields: []const ResultColumn) Query {
|
||||
return .{
|
||||
.tables = self.tables,
|
||||
.fields = fields,
|
||||
.filter = self.filter,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn where(comptime self: Query, comptime condition: Condition) Query {
|
||||
return .{
|
||||
.tables = self.tables,
|
||||
.fields = self.fields,
|
||||
.filter = condition,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
test "Query" {
|
||||
const C = Condition;
|
||||
const MyTable = struct { id: i64 };
|
||||
const MyOtherTable = struct {
|
||||
val: []const u8,
|
||||
};
|
||||
const qt = queryTables(&.{ MyTable, MyOtherTable, MyTable });
|
||||
const t1 = qt[0];
|
||||
const t2 = qt[2];
|
||||
const t_other = qt[1];
|
||||
|
||||
const q = comptime Query
|
||||
.from(qt)
|
||||
.select(&.{ t1.select(.id), t_other.select(.val) })
|
||||
.where(C.all(&.{
|
||||
C.eql(t1.field(.id), t2.field(.id)),
|
||||
C.eql(t1.field(.id), t2.field(.id)),
|
||||
}));
|
||||
|
||||
try std.testing.expectEqualStrings(
|
||||
"SELECT my_table_0.id, my_other_table_1.val " ++
|
||||
"FROM my_table AS my_table_0, my_other_table AS my_other_table_1, my_table AS my_table_2 " ++
|
||||
"WHERE ((my_table_0.id = my_table_2.id) AND (my_table_0.id = my_table_2.id))",
|
||||
comptime q.str(),
|
||||
);
|
||||
|
||||
const fields = std.meta.fields(q.rowType());
|
||||
try std.testing.expectEqual(i64, fields[0].field_type);
|
||||
try std.testing.expectEqual([]const u8, fields[1].field_type);
|
||||
}
|
Loading…
Reference in New Issue