fucking around w/ db stuff
This commit is contained in:
parent
d694674585
commit
4bddb9f633
|
@ -24,7 +24,7 @@ pub fn build(b: *std.build.Builder) void {
|
||||||
|
|
||||||
// There are some weird problems relating to sentinel values and function pointers
|
// There are some weird problems relating to sentinel values and function pointers
|
||||||
// when using the stage1 compiler. Just disable it entirely for now.
|
// when using the stage1 compiler. Just disable it entirely for now.
|
||||||
b.use_stage1 = false;
|
//b.use_stage1 = false;
|
||||||
|
|
||||||
const exe = b.addExecutable("apub", "src/main/main.zig");
|
const exe = b.addExecutable("apub", "src/main/main.zig");
|
||||||
exe.setTarget(target);
|
exe.setTarget(target);
|
||||||
|
|
122
src/main/api.zig
122
src/main/api.zig
|
@ -127,23 +127,14 @@ pub const ApiSource = struct {
|
||||||
var my_db = try db.Database.init();
|
var my_db = try db.Database.init();
|
||||||
|
|
||||||
{
|
{
|
||||||
const C = db.builder.Condition;
|
const row = try my_db.execRow2(
|
||||||
const qt = db.builder.queryTables(&.{ models.User, models.User, models.LocalUser, models.Invite });
|
&.{Uuid},
|
||||||
const UInviter = qt[0];
|
"SELECT id FROM user WHERE username = ?",
|
||||||
const UInvitee = qt[1];
|
.{"heartles"},
|
||||||
const LUInvitee = qt[2];
|
null,
|
||||||
const Invite = qt[3];
|
);
|
||||||
const q = comptime db.builder.Query
|
|
||||||
.from(qt)
|
|
||||||
.select(&.{ UInviter.select(.username), UInvitee.select(.username), Invite.select(.id) })
|
|
||||||
.where(C.all(&.{
|
|
||||||
C.eql(UInviter.field(.id), Invite.field(.created_by)),
|
|
||||||
C.eql(LUInvitee.field(.invite_id), Invite.field(.id)),
|
|
||||||
C.eql(LUInvitee.field(.user_id), UInvitee.field(.id)),
|
|
||||||
}));
|
|
||||||
|
|
||||||
const result = (try my_db.execRowQuery(q, alloc)) orelse unreachable;
|
std.log.debug("{s}", .{row.?[0]});
|
||||||
std.log.debug("{s} invited {s}", .{ result[0], result[1] });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ApiSource{
|
return ApiSource{
|
||||||
|
@ -157,8 +148,8 @@ pub const ApiSource = struct {
|
||||||
pub fn connectUnauthorized(self: *ApiSource, host: ?[]const u8, alloc: std.mem.Allocator) !Conn {
|
pub fn connectUnauthorized(self: *ApiSource, host: ?[]const u8, alloc: std.mem.Allocator) !Conn {
|
||||||
const community_id = blk: {
|
const community_id = blk: {
|
||||||
if (host) |h| {
|
if (host) |h| {
|
||||||
const community = try self.db.getBy(models.Community, .host, h, alloc);
|
const result = try self.db.execRow2(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{h}, null);
|
||||||
if (community) |c| break :blk c.id;
|
if (result) |r| break :blk r[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
break :blk null;
|
break :blk null;
|
||||||
|
@ -187,7 +178,14 @@ pub const ApiSource = struct {
|
||||||
models.Token.HashFn.hash(&decoded, &hash.data, .{});
|
models.Token.HashFn.hash(&decoded, &hash.data, .{});
|
||||||
|
|
||||||
const db_token = (try self.db.getBy(models.Token, .hash, hash, conn.arena.allocator())) orelse return error.InvalidToken;
|
const db_token = (try self.db.getBy(models.Token, .hash, hash, conn.arena.allocator())) orelse return error.InvalidToken;
|
||||||
|
//const token_result = (try self.db.execRow2(
|
||||||
|
//&.{Uuid},
|
||||||
|
//"SELECT id FROM token WHERE hash = ?",
|
||||||
|
//.{hash},
|
||||||
|
//null,
|
||||||
|
//)) orelse return error.InvalidToken;
|
||||||
|
|
||||||
|
//conn.as_user = token_result[0];
|
||||||
conn.as_user = db_token.user_id;
|
conn.as_user = db_token.user_id;
|
||||||
|
|
||||||
return conn;
|
return conn;
|
||||||
|
@ -332,20 +330,42 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
const user_id = Uuid.randV4(prng.random());
|
const user_id = Uuid.randV4(prng.random());
|
||||||
// TODO: lock for transaction
|
// TODO: lock for transaction
|
||||||
|
|
||||||
if (try self.db.existsWhereEq(models.User, .username, info.username)) {
|
// TODO: not community aware :(
|
||||||
|
if (try self.db.execRow2(&.{}, "SELECT 1 FROM user WHERE username = ?", .{info.username}, null) != null) {
|
||||||
|
//if (try self.db.existsWhereEq(models.User, .username, info.username)) {
|
||||||
return error.UsernameUnavailable;
|
return error.UsernameUnavailable;
|
||||||
}
|
}
|
||||||
|
|
||||||
const now = DateTime.now();
|
const now = DateTime.now();
|
||||||
const invite_id = if (info.invite_code) |invite_code| blk: {
|
const invite_id = if (info.invite_code) |invite_code| blk: {
|
||||||
const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite;
|
// TODO have this query also check for time-based expiration
|
||||||
const uses = try self.db.countWhereEq(models.LocalUser, .invite_id, invite.id);
|
const result = (try self.db.execRow2(
|
||||||
const uses_left = if (invite.max_uses) |max_uses| uses < max_uses else true;
|
&.{ Uuid, ?DateTime },
|
||||||
const expired = if (invite.expires_at) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false;
|
\\SELECT invite.id, invite.expires_at
|
||||||
|
\\FROM invite
|
||||||
|
\\ LEFT OUTER JOIN local_user ON invite.id = local_user.invite_id
|
||||||
|
\\WHERE invite.invite_code = ?
|
||||||
|
\\GROUP BY invite.id
|
||||||
|
\\HAVING
|
||||||
|
\\ (invite.max_uses IS NULL OR invite.max_uses > COUNT(local_user.user_id))
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{invite_code},
|
||||||
|
null,
|
||||||
|
)) orelse return error.InvalidInvite;
|
||||||
|
|
||||||
if (!uses_left or expired) return error.InvalidInvite;
|
const expired = if (result[1]) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false;
|
||||||
|
if (expired) return error.InvalidInvite;
|
||||||
|
|
||||||
|
//const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite;
|
||||||
|
//const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite;
|
||||||
|
//const uses = try self.db.countWhereEq(models.LocalUser, .invite_id, invite.id);
|
||||||
|
//const uses_left = if (invite.max_uses) |max_uses| uses < max_uses else true;
|
||||||
|
//const expired = if (invite.expires_at) |expires_at| now.seconds_since_epoch > expires_at.seconds_since_epoch else false;
|
||||||
|
|
||||||
|
//if (!uses_left or expired) return error.InvalidInvite;
|
||||||
// TODO: increment uses
|
// TODO: increment uses
|
||||||
break :blk invite.id;
|
break :blk result[0];
|
||||||
} else null;
|
} else null;
|
||||||
|
|
||||||
// use internal alloc because necessary buffer is *big*
|
// use internal alloc because necessary buffer is *big*
|
||||||
|
@ -354,8 +374,15 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
|
|
||||||
const community_id = if (info.community_host) |host| blk: {
|
const community_id = if (info.community_host) |host| blk: {
|
||||||
//const id_tuple = (try self.db.execRow("select id from community where host = '?'", host, &.{Uuid}, self.arena.allocator())) orelse return error.CommunityNotFound;
|
//const id_tuple = (try self.db.execRow("select id from community where host = '?'", host, &.{Uuid}, self.arena.allocator())) orelse return error.CommunityNotFound;
|
||||||
const community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound;
|
const community_result = (try self.db.execRow2(
|
||||||
break :blk community.id;
|
&.{Uuid},
|
||||||
|
"SELECT id FROM community WHERE host = ?",
|
||||||
|
.{host},
|
||||||
|
null,
|
||||||
|
)) orelse return error.CommunityNotFound;
|
||||||
|
|
||||||
|
//const community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound;
|
||||||
|
break :blk community_result[0];
|
||||||
//break :blk id_tuple[0];
|
//break :blk id_tuple[0];
|
||||||
} else null;
|
} else null;
|
||||||
|
|
||||||
|
@ -385,23 +412,37 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
|
|
||||||
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult {
|
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult {
|
||||||
// TODO: This gives away the existence of a user through a timing side channel. is that acceptable?
|
// TODO: This gives away the existence of a user through a timing side channel. is that acceptable?
|
||||||
const user_info = (try self.db.getBy(models.User, .username, username, self.arena.allocator())) orelse return error.InvalidLogin;
|
//const user_info = (try self.db.getBy(models.User, .username, username, self.arena.allocator())) orelse return error.InvalidLogin;
|
||||||
const local_user_info = (try self.db.getBy(models.LocalUser, .user_id, user_info.id, self.arena.allocator())) orelse return error.InvalidLogin;
|
//const local_user_info = (try self.db.getBy(models.LocalUser, .user_id, user_info.id, self.arena.allocator())) orelse return error.InvalidLogin;
|
||||||
|
|
||||||
|
const user_info = (try self.db.execRow2(
|
||||||
|
&.{ Uuid, []const u8 },
|
||||||
|
\\SELECT user.id, local_user.hashed_password
|
||||||
|
\\FROM user JOIN local_user ON local_user.user_id = user.id
|
||||||
|
\\WHERE user.username = ?
|
||||||
|
,
|
||||||
|
.{username},
|
||||||
|
self.arena.allocator(),
|
||||||
|
)) orelse return error.InvalidLogin;
|
||||||
|
|
||||||
|
const user_id = user_info[0];
|
||||||
|
const hashed_password = user_info[1];
|
||||||
|
|
||||||
//defer free(self.arena.allocator(), user_info);
|
//defer free(self.arena.allocator(), user_info);
|
||||||
|
|
||||||
const Hash = std.crypto.pwhash.scrypt;
|
const Hash = std.crypto.pwhash.scrypt;
|
||||||
Hash.strVerify(local_user_info.hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) {
|
Hash.strVerify(hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) {
|
||||||
error.PasswordVerificationFailed => return error.InvalidLogin,
|
error.PasswordVerificationFailed => return error.InvalidLogin,
|
||||||
else => return err,
|
else => return err,
|
||||||
};
|
};
|
||||||
|
|
||||||
const token = try self.createToken(user_info.id);
|
const token = try self.createToken(user_id);
|
||||||
|
|
||||||
var token_enc: [token_str_len]u8 = undefined;
|
var token_enc: [token_str_len]u8 = undefined;
|
||||||
_ = std.base64.standard.Encoder.encode(&token_enc, &token.value);
|
_ = std.base64.standard.Encoder.encode(&token_enc, &token.value);
|
||||||
|
|
||||||
return LoginResult{
|
return LoginResult{
|
||||||
.user_id = user_info.id,
|
.user_id = user_id,
|
||||||
.token = token_enc,
|
.token = token_enc,
|
||||||
.issued_at = token.info.issued_at,
|
.issued_at = token.info.issued_at,
|
||||||
};
|
};
|
||||||
|
@ -425,7 +466,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
.issued_at = DateTime.now(),
|
.issued_at = DateTime.now(),
|
||||||
};
|
};
|
||||||
|
|
||||||
try self.db.insert(models.Token, db_token);
|
try self.db.insert2("token", db_token);
|
||||||
return TokenResult{
|
return TokenResult{
|
||||||
.info = db_token,
|
.info = db_token,
|
||||||
.value = token,
|
.value = token,
|
||||||
|
@ -440,14 +481,21 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
// Users can only make invites to their own community, unless they
|
// Users can only make invites to their own community, unless they
|
||||||
// are system users
|
// are system users
|
||||||
const community_id = if (options.to_community) |host| blk: {
|
const community_id = if (options.to_community) |host| blk: {
|
||||||
const desired_community = (try self.db.getBy(models.Community, .host, host, self.arena.allocator())) orelse return error.CommunityNotFound;
|
const desired_community = (try self.db.execRow2(
|
||||||
if (user.community_id != null and !Uuid.eql(desired_community.id, user.community_id.?)) {
|
&.{Uuid},
|
||||||
|
"SELECT id FROM community WHERE host = ?",
|
||||||
|
.{host},
|
||||||
|
null,
|
||||||
|
)) orelse return error.CommunityNotFound;
|
||||||
|
|
||||||
|
if (user.community_id != null and !Uuid.eql(desired_community[0], user.community_id.?)) {
|
||||||
return error.WrongCommunity;
|
return error.WrongCommunity;
|
||||||
}
|
}
|
||||||
|
|
||||||
break :blk desired_community.id;
|
break :blk desired_community[0];
|
||||||
} else null;
|
} else null;
|
||||||
if (user.community_id != null and options.to_community == null) {
|
|
||||||
|
if (user.community_id != null and community_id == null) {
|
||||||
return error.WrongCommunity;
|
return error.WrongCommunity;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
168
src/main/db.zig
168
src/main/db.zig
|
@ -9,8 +9,6 @@ const DateTime = util.DateTime;
|
||||||
const String = []const u8;
|
const String = []const u8;
|
||||||
const comptimePrint = std.fmt.comptimePrint;
|
const comptimePrint = std.fmt.comptimePrint;
|
||||||
|
|
||||||
pub const builder = @import("./db/query_builder.zig");
|
|
||||||
|
|
||||||
fn tableName(comptime T: type) String {
|
fn tableName(comptime T: type) String {
|
||||||
return switch (T) {
|
return switch (T) {
|
||||||
models.Note => "note",
|
models.Note => "note",
|
||||||
|
@ -25,6 +23,44 @@ fn tableName(comptime T: type) String {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple {
|
||||||
|
var result: RowTuple = undefined;
|
||||||
|
// TODO: undo allocations on failure
|
||||||
|
inline for (std.meta.fields(RowTuple)) |f, i| {
|
||||||
|
@field(result, f.name) = try getAlloc(row, f.field_type, i, allocator);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ResultSet(comptime result_types: []const type) type {
|
||||||
|
return struct {
|
||||||
|
pub const QueryError = anyerror;
|
||||||
|
pub const Row = std.meta.Tuple(result_types);
|
||||||
|
|
||||||
|
_stmt: sql.PreparedStmt,
|
||||||
|
err: ?QueryError = null,
|
||||||
|
|
||||||
|
pub fn finish(self: *@This()) void {
|
||||||
|
self._stmt.finalize();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn row(self: *@This(), allocator: ?std.mem.Allocator) ?Row {
|
||||||
|
const sql_result = self._stmt.step() catch |err| {
|
||||||
|
self.err = err;
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (sql_result) |sql_row| {
|
||||||
|
return readRow(Row, sql_row, allocator) catch |err| {
|
||||||
|
self.err = err;
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
} else return null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Combines an array/tuple of strings into a single string, with a copy of
|
// Combines an array/tuple of strings into a single string, with a copy of
|
||||||
// joiner in between each one
|
// joiner in between each one
|
||||||
fn join(comptime vals: anytype, comptime joiner: String) String {
|
fn join(comptime vals: anytype, comptime joiner: String) String {
|
||||||
|
@ -121,8 +157,9 @@ fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const St
|
||||||
// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
|
// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
|
||||||
// TODO define what error set this ^ should return
|
// TODO define what error set this ^ should return
|
||||||
fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
|
fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
|
||||||
|
if (comptime std.meta.trait.isZigString(@TypeOf(val))) return stmt.bindText(idx, val);
|
||||||
|
|
||||||
return switch (@TypeOf(val)) {
|
return switch (@TypeOf(val)) {
|
||||||
[]u8, []const u8 => stmt.bindText(idx, val),
|
|
||||||
i64 => stmt.bindI64(idx, val),
|
i64 => stmt.bindI64(idx, val),
|
||||||
Uuid => stmt.bindUuid(idx, val),
|
Uuid => stmt.bindUuid(idx, val),
|
||||||
DateTime => stmt.bindDateTime(idx, val),
|
DateTime => stmt.bindDateTime(idx, val),
|
||||||
|
@ -134,7 +171,8 @@ fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
|
||||||
val.bindToSql(stmt, idx)
|
val.bindToSql(stmt, idx)
|
||||||
else
|
else
|
||||||
@compileError("unsupported type " ++ @typeName(T)),
|
@compileError("unsupported type " ++ @typeName(T)),
|
||||||
else => @compileError("unsupported Type " ++ @typeName(T)),
|
else => unreachable,
|
||||||
|
//@compileError("unsupported type " ++ @typeName(T)),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -144,9 +182,9 @@ fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
|
||||||
// declaring a method with the given signature:
|
// declaring a method with the given signature:
|
||||||
// pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T
|
// pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T
|
||||||
// TODO define what error set this ^ should return
|
// TODO define what error set this ^ should return
|
||||||
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator) !T {
|
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) !T {
|
||||||
return switch (T) {
|
return switch (T) {
|
||||||
[]u8, []const u8 => row.getTextAlloc(idx, alloc),
|
[]u8, []const u8 => row.getTextAlloc(idx, alloc orelse return error.AllocatorRequired),
|
||||||
i64 => row.getI64(idx),
|
i64 => row.getI64(idx),
|
||||||
Uuid => row.getUuid(idx),
|
Uuid => row.getUuid(idx),
|
||||||
DateTime => row.getDateTime(idx),
|
DateTime => row.getDateTime(idx),
|
||||||
|
@ -158,11 +196,11 @@ fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator)
|
||||||
try getAlloc(row, std.meta.Child(T), idx, alloc),
|
try getAlloc(row, std.meta.Child(T), idx, alloc),
|
||||||
|
|
||||||
.Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql"))
|
.Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql"))
|
||||||
T.getFromSql(row, idx, alloc)
|
T.getFromSql(row, idx, alloc orelse return error.AllocatorRequired)
|
||||||
else
|
else
|
||||||
@compileError("unknown type " ++ @typeName(T)),
|
@compileError("unknown type " ++ @typeName(T)),
|
||||||
|
|
||||||
.Enum => try getEnum(row, T, idx, alloc),
|
.Enum => try getEnum(row, T, idx, alloc orelse return error.AllocatorRequired),
|
||||||
|
|
||||||
else => @compileError("unknown type " ++ @typeName(T)),
|
else => @compileError("unknown type " ++ @typeName(T)),
|
||||||
},
|
},
|
||||||
|
@ -195,20 +233,65 @@ pub const Database = struct {
|
||||||
self.db.close();
|
self.db.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execRowQuery(self: *Database, comptime q: builder.Query, alloc: std.mem.Allocator) !?q.rowType() {
|
pub fn exec2(
|
||||||
std.log.debug("executing sql:\n===\n{s}\n===", .{q.str()});
|
self: *Database,
|
||||||
var stmt = try self.db.prepare(q.str());
|
comptime result_types: []const type,
|
||||||
|
comptime q: []const u8,
|
||||||
|
args: anytype,
|
||||||
|
) !ResultSet(result_types) {
|
||||||
|
std.log.debug("executing sql:\n===\n{s}\n===", .{q});
|
||||||
|
|
||||||
|
const stmt = try self.db.prepare(q);
|
||||||
errdefer stmt.finalize();
|
errdefer stmt.finalize();
|
||||||
|
|
||||||
const row = (try stmt.step()) orelse return null;
|
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
|
||||||
|
try bind(stmt, @intCast(u15, i + 1), @field(args, field.name));
|
||||||
std.log.debug("successful query", .{});
|
|
||||||
var result: q.rowType() = undefined;
|
|
||||||
inline for (std.meta.fields(q.rowType())) |f, i| {
|
|
||||||
result[i] = try getAlloc(row, f.field_type, i, alloc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return ResultSet(result_types){
|
||||||
|
._stmt = stmt,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execRow2(
|
||||||
|
self: *Database,
|
||||||
|
comptime result_types: []const type,
|
||||||
|
comptime q: []const u8,
|
||||||
|
args: anytype,
|
||||||
|
allocator: ?std.mem.Allocator,
|
||||||
|
) !?ResultSet(result_types).Row {
|
||||||
|
var results = try self.exec2(result_types, q, args);
|
||||||
|
defer results.finish();
|
||||||
|
|
||||||
|
const row = results.row(allocator);
|
||||||
|
return row orelse (results.err orelse null);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
|
||||||
|
comptime {
|
||||||
|
const joiner = ",";
|
||||||
|
var result: []const u8 = "";
|
||||||
|
inline for (std.meta.fields(T)) |f| {
|
||||||
|
result = result ++ joiner ++ (placeholder orelse f.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
return "(" ++ result[joiner.len..] ++ ")";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert2(
|
||||||
|
self: *Database,
|
||||||
|
comptime table: []const u8,
|
||||||
|
value: anytype,
|
||||||
|
) !void {
|
||||||
|
const ValueType = comptime @TypeOf(value);
|
||||||
|
const table_spec = comptime table ++ build_field_list(ValueType, null);
|
||||||
|
const value_spec = comptime build_field_list(ValueType, "?");
|
||||||
|
const q = comptime std.fmt.comptimePrint(
|
||||||
|
"INSERT INTO {s} VALUES {s}",
|
||||||
|
.{ table_spec, value_spec },
|
||||||
|
);
|
||||||
|
_ = try self.execRow2(&.{}, q, value, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lower level function
|
// Lower level function
|
||||||
|
@ -306,55 +389,6 @@ pub const Database = struct {
|
||||||
return results.toOwnedSlice();
|
return results.toOwnedSlice();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the number of rows that satisfy an equality check on
|
|
||||||
// one of their fields
|
|
||||||
pub fn countWhereEq(
|
|
||||||
self: *Database,
|
|
||||||
comptime T: type,
|
|
||||||
comptime field: std.meta.FieldEnum(T),
|
|
||||||
val: std.meta.fieldInfo(T, field).field_type,
|
|
||||||
) !usize {
|
|
||||||
const field_name = std.meta.fieldInfo(T, field).name;
|
|
||||||
const q = comptime (Query{
|
|
||||||
.select = &.{"COUNT()"},
|
|
||||||
.from = tableName(T),
|
|
||||||
.where = field_name ++ " = ?",
|
|
||||||
}).str();
|
|
||||||
|
|
||||||
var stmt = try self.db.prepare(q);
|
|
||||||
defer stmt.finalize();
|
|
||||||
|
|
||||||
try bind(stmt, 1, val);
|
|
||||||
|
|
||||||
const row = (try stmt.step()) orelse unreachable;
|
|
||||||
return @intCast(usize, try row.getI64(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns whether a row with the given value exists.
|
|
||||||
pub fn existsWhereEq(
|
|
||||||
self: *Database,
|
|
||||||
comptime T: type,
|
|
||||||
comptime field: std.meta.FieldEnum(T),
|
|
||||||
val: std.meta.fieldInfo(T, field).field_type,
|
|
||||||
) !bool {
|
|
||||||
const field_name = std.meta.fieldInfo(T, field).name;
|
|
||||||
// TODO: don't like this query
|
|
||||||
const q = comptime (Query{
|
|
||||||
.select = &.{"COUNT(1)"},
|
|
||||||
.from = tableName(T),
|
|
||||||
.where = field_name ++ " = ?",
|
|
||||||
.limit = 1,
|
|
||||||
}).str();
|
|
||||||
|
|
||||||
var stmt = try self.db.prepare(q);
|
|
||||||
defer stmt.finalize();
|
|
||||||
|
|
||||||
try bind(stmt, 1, val);
|
|
||||||
|
|
||||||
const row = (try stmt.step()) orelse unreachable;
|
|
||||||
return (try row.getI64(0)) > 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inserts a row into the database
|
// Inserts a row into the database
|
||||||
// TODO: consider making this generic?
|
// TODO: consider making this generic?
|
||||||
pub fn insert(self: *Database, comptime T: type, val: T) !void {
|
pub fn insert(self: *Database, comptime T: type, val: T) !void {
|
||||||
|
|
|
@ -1,284 +0,0 @@
|
||||||
const std = @import("std");
|
|
||||||
const util = @import("util");
|
|
||||||
const builtin = @import("builtin");
|
|
||||||
|
|
||||||
const String = []const u8;
|
|
||||||
const comptimePrint = std.fmt.comptimePrint;
|
|
||||||
|
|
||||||
fn baseTypeName(comptime T: type) []const u8 {
|
|
||||||
comptime {
|
|
||||||
const name = @typeName(T);
|
|
||||||
const start = for (name) |_, i| {
|
|
||||||
if (name[name.len - i] == '.') {
|
|
||||||
// This function has an off-by-one error in the self hosted compiler (-fno-stage1)
|
|
||||||
// The following code fixes it as of 2022-08-07
|
|
||||||
// TODO: Figure out what's going on here
|
|
||||||
if (builtin.zig_backend == .stage1) {
|
|
||||||
break name.len - i;
|
|
||||||
} else {
|
|
||||||
break name.len - i + 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else 0;
|
|
||||||
|
|
||||||
return name[start..];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn tableName(comptime T: type) String {
|
|
||||||
return comptime util.case.pascalToSnake(baseTypeName(T));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Represents a table bound to an identifier in a sql query
|
|
||||||
pub const QueryTable = struct {
|
|
||||||
Model: type,
|
|
||||||
index: comptime_int,
|
|
||||||
|
|
||||||
// Gets a fully qualified field from a literal
|
|
||||||
pub fn field(comptime self: QueryTable, comptime lit: @Type(.EnumLiteral)) String {
|
|
||||||
comptime {
|
|
||||||
const f = @as(std.meta.FieldEnum(self.Model), lit);
|
|
||||||
return comptimePrint("{s}.{s}", .{ self.as(), @tagName(f) });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select(comptime self: QueryTable, comptime lit: @Type(.EnumLiteral)) ResultColumn {
|
|
||||||
return .{
|
|
||||||
.@"type" = std.meta.fieldInfo(self.Model, lit).field_type,
|
|
||||||
.field = self.field(lit),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// returns the declaration to put in the FROM clause
|
|
||||||
fn declarationStr(comptime self: QueryTable) String {
|
|
||||||
comptime {
|
|
||||||
return comptimePrint("{s} AS {s}", .{ tableName(self.Model), self.as() });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn as(comptime self: QueryTable) String {
|
|
||||||
comptime {
|
|
||||||
return comptimePrint("{s}_{}", .{ tableName(self.Model), self.index });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
fn makeQueryTable(comptime Model: type, comptime table_index: usize) QueryTable {
|
|
||||||
return .{ .Model = Model, .index = table_index };
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn queryTables(comptime models: []const type) *const [models.len]QueryTable {
|
|
||||||
return map(type, QueryTable, models, makeQueryTable);
|
|
||||||
}
|
|
||||||
|
|
||||||
test "QueryTable.declarationStr" {
|
|
||||||
const MyTable = struct { id: i64 };
|
|
||||||
const tbl = QueryTable{
|
|
||||||
.Model = MyTable,
|
|
||||||
.index = 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
try std.testing.expectEqualStrings("my_table AS my_table_0", tbl.declarationStr());
|
|
||||||
try std.testing.expectEqualStrings("my_table_0.id", tbl.field(.id));
|
|
||||||
}
|
|
||||||
|
|
||||||
test "queryTables constructor" {
|
|
||||||
const MyTable = struct { id: i64 };
|
|
||||||
const MyOtherTable = struct { val: i64 };
|
|
||||||
|
|
||||||
const qt = queryTables(&.{ MyTable, MyOtherTable });
|
|
||||||
|
|
||||||
try std.testing.expectEqual(MyTable, qt[0].Model);
|
|
||||||
try std.testing.expectEqual(MyOtherTable, qt[1].Model);
|
|
||||||
try std.testing.expectEqualStrings("my_table_0", qt[0].as());
|
|
||||||
try std.testing.expectEqualStrings("my_other_table_1", qt[1].as());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn map(comptime T: type, comptime R: type, comptime vals: []const T, comptime func: anytype) *const [vals.len]R {
|
|
||||||
var result: [vals.len]R = undefined;
|
|
||||||
if (@typeInfo(@TypeOf(func)).Fn.args.len == 2) {
|
|
||||||
inline for (vals) |v, i| result[i] = @as(R, func(v, i));
|
|
||||||
} else {
|
|
||||||
inline for (vals) |v, i| result[i] = @as(R, func(v));
|
|
||||||
}
|
|
||||||
|
|
||||||
return &result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Combines an array/tuple of strings into a single string, with a copy of
|
|
||||||
// joiner in between each one
|
|
||||||
fn join(comptime vals: []const String, comptime joiner: String) String {
|
|
||||||
if (vals.len == 0) return "";
|
|
||||||
|
|
||||||
var result: String = "";
|
|
||||||
for (vals) |v| {
|
|
||||||
result = comptimePrint("{s}{s}{s}", .{ result, joiner, v });
|
|
||||||
}
|
|
||||||
|
|
||||||
return result[joiner.len..];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stringifies and joins an array of conditions into a single string
|
|
||||||
fn joinConditions(comptime cs: []const Condition, comptime joiner: String) String {
|
|
||||||
var strs: [cs.len]String = undefined;
|
|
||||||
for (cs) |v, i| strs[i] = v.str();
|
|
||||||
return join(&strs, joiner);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Represents a condition in a SQL statement
|
|
||||||
pub const Condition = union(enum) {
|
|
||||||
const BinaryOp = struct {
|
|
||||||
lhs: String,
|
|
||||||
rhs: String,
|
|
||||||
};
|
|
||||||
|
|
||||||
eql: BinaryOp,
|
|
||||||
is_null: String,
|
|
||||||
val: String,
|
|
||||||
not: *const Condition,
|
|
||||||
all: []const Condition,
|
|
||||||
any: []const Condition,
|
|
||||||
|
|
||||||
fn str(comptime self: Condition) String {
|
|
||||||
comptime {
|
|
||||||
return comptimePrint("({s})", .{switch (self) {
|
|
||||||
.eql => |op| comptimePrint("{s} = {s}", .{ op.lhs, op.rhs }),
|
|
||||||
.is_null => |val| comptimePrint("{s} IS NULL", .{val}),
|
|
||||||
.val => |val| val,
|
|
||||||
.not => |c| comptimePrint("NOT {s}", .{c.str()}),
|
|
||||||
.all => |cs| joinConditions(cs, " AND "),
|
|
||||||
.any => |cs| joinConditions(cs, " OR "),
|
|
||||||
}});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn eql(comptime lhs: String, comptime rhs: String) Condition {
|
|
||||||
return .{
|
|
||||||
.eql = .{ .lhs = lhs, .rhs = rhs },
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn all(comptime cs: []const Condition) Condition {
|
|
||||||
return .{
|
|
||||||
.all = cs,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
test "Condition.str()" {
|
|
||||||
try std.testing.expectEqualStrings(
|
|
||||||
"((abc = def) AND (def = abc))",
|
|
||||||
(comptime Condition{ .all = &.{
|
|
||||||
.{ .eql = .{ .lhs = "abc", .rhs = "def" } },
|
|
||||||
.{ .eql = .{ .lhs = "def", .rhs = "abc" } },
|
|
||||||
} }).str(),
|
|
||||||
);
|
|
||||||
|
|
||||||
try std.testing.expectEqualStrings(
|
|
||||||
"((abc IS NULL) OR (NOT (def)))",
|
|
||||||
(comptime Condition{ .any = &.{
|
|
||||||
.{ .is_null = "abc" },
|
|
||||||
.{ .not = &.{ .val = "def" } },
|
|
||||||
} }).str(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
const ResultColumn = struct {
|
|
||||||
@"type": type,
|
|
||||||
field: []const u8,
|
|
||||||
|
|
||||||
pub fn toSelectClause(comptime self: ResultColumn) String {
|
|
||||||
return self.field;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn toStructField(comptime self: ResultColumn, comptime index: usize) std.builtin.Type.StructField {
|
|
||||||
return .{
|
|
||||||
.name = comptimePrint("{}", .{index}),
|
|
||||||
.field_type = self.@"type",
|
|
||||||
.default_value = null,
|
|
||||||
.is_comptime = false,
|
|
||||||
.alignment = 0,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Represents a full SQL query
|
|
||||||
pub const Query = struct {
|
|
||||||
tables: []const QueryTable,
|
|
||||||
fields: []const ResultColumn,
|
|
||||||
filter: Condition,
|
|
||||||
|
|
||||||
pub fn from(comptime tables: []const QueryTable) Query {
|
|
||||||
return .{
|
|
||||||
.tables = tables,
|
|
||||||
.fields = &.{},
|
|
||||||
.filter = .{ .val = "TRUE" }, // TODO
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn str(comptime self: Query) String {
|
|
||||||
comptime {
|
|
||||||
const table_aliases = map(QueryTable, String, self.tables, QueryTable.declarationStr);
|
|
||||||
const select_clauses = map(ResultColumn, String, self.fields, ResultColumn.toSelectClause);
|
|
||||||
return comptimePrint("SELECT {s} FROM {s} WHERE {s}", .{ join(select_clauses, ", "), join(table_aliases, ", "), self.filter.str() });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rowType(comptime self: *const Query) type {
|
|
||||||
const struct_fields = map(ResultColumn, std.builtin.Type.StructField, self.fields, ResultColumn.toStructField);
|
|
||||||
|
|
||||||
return @Type(.{ .Struct = .{
|
|
||||||
.layout = .Auto,
|
|
||||||
.fields = struct_fields,
|
|
||||||
.decls = &.{},
|
|
||||||
.is_tuple = true,
|
|
||||||
} });
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select(comptime self: Query, comptime fields: []const ResultColumn) Query {
|
|
||||||
return .{
|
|
||||||
.tables = self.tables,
|
|
||||||
.fields = fields,
|
|
||||||
.filter = self.filter,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn where(comptime self: Query, comptime condition: Condition) Query {
|
|
||||||
return .{
|
|
||||||
.tables = self.tables,
|
|
||||||
.fields = self.fields,
|
|
||||||
.filter = condition,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
test "Query" {
|
|
||||||
const C = Condition;
|
|
||||||
const MyTable = struct { id: i64 };
|
|
||||||
const MyOtherTable = struct {
|
|
||||||
val: []const u8,
|
|
||||||
};
|
|
||||||
const qt = queryTables(&.{ MyTable, MyOtherTable, MyTable });
|
|
||||||
const t1 = qt[0];
|
|
||||||
const t2 = qt[2];
|
|
||||||
const t_other = qt[1];
|
|
||||||
|
|
||||||
const q = comptime Query
|
|
||||||
.from(qt)
|
|
||||||
.select(&.{ t1.select(.id), t_other.select(.val) })
|
|
||||||
.where(C.all(&.{
|
|
||||||
C.eql(t1.field(.id), t2.field(.id)),
|
|
||||||
C.eql(t1.field(.id), t2.field(.id)),
|
|
||||||
}));
|
|
||||||
|
|
||||||
try std.testing.expectEqualStrings(
|
|
||||||
"SELECT my_table_0.id, my_other_table_1.val " ++
|
|
||||||
"FROM my_table AS my_table_0, my_other_table AS my_other_table_1, my_table AS my_table_2 " ++
|
|
||||||
"WHERE ((my_table_0.id = my_table_2.id) AND (my_table_0.id = my_table_2.id))",
|
|
||||||
comptime q.str(),
|
|
||||||
);
|
|
||||||
|
|
||||||
const fields = std.meta.fields(q.rowType());
|
|
||||||
try std.testing.expectEqual(i64, fields[0].field_type);
|
|
||||||
try std.testing.expectEqual([]const u8, fields[1].field_type);
|
|
||||||
}
|
|
Loading…
Reference in New Issue