Remove dead code
This commit is contained in:
parent
91c116a303
commit
99337b6429
6 changed files with 19 additions and 419 deletions
|
@ -127,7 +127,7 @@ pub const ApiSource = struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid {
|
fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid {
|
||||||
if (try self.db.execRow2(
|
if (try self.db.execRow(
|
||||||
&.{Uuid},
|
&.{Uuid},
|
||||||
"SELECT id FROM community WHERE host = ?",
|
"SELECT id FROM community WHERE host = ?",
|
||||||
.{host},
|
.{host},
|
||||||
|
@ -183,35 +183,6 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
self.arena.deinit();
|
self.arena.deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn getAuthenticatedUser(self: *Self) !models.User {
|
|
||||||
if (self.user_id) |id| {
|
|
||||||
const user = try self.db.getBy(models.User, .id, id, self.arena.allocator());
|
|
||||||
if (user == null) return error.NotAuthorized;
|
|
||||||
|
|
||||||
return user.?;
|
|
||||||
} else {
|
|
||||||
return error.NotAuthorized;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn getAuthenticatedLocalUser(self: *Self) !models.LocalUser {
|
|
||||||
if (self.user_id) |user_id| {
|
|
||||||
const local_user = try self.db.getBy(models.LocalUser, .user_id, user_id, self.arena.allocator());
|
|
||||||
if (local_user == null) return error.NotAuthorized;
|
|
||||||
|
|
||||||
return local_user.?;
|
|
||||||
} else {
|
|
||||||
return error.NotAuthorized;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn getAuthenticatedActor(self: *Self) !models.Actor {
|
|
||||||
return if (self.user_id) |user_id|
|
|
||||||
(try self.db.getBy(models.Actor, .user_id, user_id, self.arena.allocator())) orelse error.NotAuthorized
|
|
||||||
else
|
|
||||||
error.NotAuthorized;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult {
|
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult {
|
||||||
const user_id = (try services.users.lookupByUsername(&self.db, username, self.community_id)) orelse return error.InvalidLogin;
|
const user_id = (try services.users.lookupByUsername(&self.db, username, self.community_id)) orelse return error.InvalidLogin;
|
||||||
try services.auth.passwords.verify(&self.db, user_id, password, self.internal_alloc);
|
try services.auth.passwords.verify(&self.db, user_id, password, self.internal_alloc);
|
||||||
|
@ -230,7 +201,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
};
|
};
|
||||||
pub fn getTokenInfo(self: *Self) !TokenInfo {
|
pub fn getTokenInfo(self: *Self) !TokenInfo {
|
||||||
if (self.user_id) |user_id| {
|
if (self.user_id) |user_id| {
|
||||||
const result = (try self.db.execRow2(
|
const result = (try self.db.execRow(
|
||||||
&.{[]const u8},
|
&.{[]const u8},
|
||||||
"SELECT username FROM user WHERE id = ?",
|
"SELECT username FROM user WHERE id = ?",
|
||||||
.{user_id},
|
.{user_id},
|
||||||
|
@ -260,7 +231,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
// Users can only make invites to their own community, unless they
|
// Users can only make invites to their own community, unless they
|
||||||
// are system users
|
// are system users
|
||||||
const community_id = if (options.to_community) |host| blk: {
|
const community_id = if (options.to_community) |host| blk: {
|
||||||
const desired_community = (try self.db.execRow2(
|
const desired_community = (try self.db.execRow(
|
||||||
&.{Uuid},
|
&.{Uuid},
|
||||||
"SELECT id FROM community WHERE host = ?",
|
"SELECT id FROM community WHERE host = ?",
|
||||||
.{host},
|
.{host},
|
||||||
|
@ -307,9 +278,5 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
|
|
||||||
return invite;
|
return invite;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getInvite(self: *Self, id: Uuid) !?models.Invite {
|
|
||||||
return self.db.getBy(models.Invite, .id, id, self.arena.allocator());
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
149
src/main/api/'
149
src/main/api/'
|
@ -1,149 +0,0 @@
|
||||||
const std = @import("std");
|
|
||||||
const util = @import("util");
|
|
||||||
|
|
||||||
const Uuid = util.Uuid;
|
|
||||||
const DateTime = util.DateTime;
|
|
||||||
|
|
||||||
pub const passwords = struct {
|
|
||||||
const PwHash = std.crypto.pwhash.scrypt;
|
|
||||||
const pw_hash_params = PwHash.Params.interactive;
|
|
||||||
const pw_hash_encoding = .phc;
|
|
||||||
const pw_hash_buf_size = 128;
|
|
||||||
|
|
||||||
const PwHashBuf = [pw_hash_buf_size]u8;
|
|
||||||
|
|
||||||
pub const Password = struct {
|
|
||||||
user_id: Uuid,
|
|
||||||
|
|
||||||
hashed_password: []const u8,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Returned slice points into buf
|
|
||||||
fn hashPassword(password: []const u8, alloc: std.mem.Allocator, buf: *PwHashBuf) []const u8 {
|
|
||||||
return PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, buf) catch unreachable;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const VerifyError = error{
|
|
||||||
InvalidLogin,
|
|
||||||
DbError,
|
|
||||||
};
|
|
||||||
pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
|
|
||||||
const hash = (try db.execRow2(
|
|
||||||
&.{PwHashBuf},
|
|
||||||
"SELECT hashed_password FROM account_password WHERE user_id = ? LIMIT 1",
|
|
||||||
.{user_id},
|
|
||||||
null,
|
|
||||||
)) orelse return error.PasswordNotFound;
|
|
||||||
|
|
||||||
try PwHash.strVerify(&hash[0], password, .{ .allocator = alloc });
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const CreateError = error{DbError};
|
|
||||||
pub fn create(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) CreateError!void {
|
|
||||||
var buf: PwHashBuf = undefined;
|
|
||||||
const hash = PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf) catch unreachable;
|
|
||||||
|
|
||||||
try db.insert2("account_password", .{
|
|
||||||
.user_id = user_id,
|
|
||||||
.hashed_password = hash,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const tokens = struct {
|
|
||||||
const token_len = 20;
|
|
||||||
pub const Token = struct {
|
|
||||||
pub const Value = [token_len]u8;
|
|
||||||
pub const Info = struct {
|
|
||||||
user_id: Uuid,
|
|
||||||
issued_at: DateTime,
|
|
||||||
};
|
|
||||||
|
|
||||||
value: Value,
|
|
||||||
|
|
||||||
issued_at: DateTime,
|
|
||||||
};
|
|
||||||
|
|
||||||
const TokenHash = std.crypto.hash.sha2.Sha256;
|
|
||||||
|
|
||||||
const DbToken = struct {
|
|
||||||
hash: []const u8,
|
|
||||||
user_id: Uuid,
|
|
||||||
issued_at: DateTime,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const CreateError = error{DbError};
|
|
||||||
pub fn create(db: anytype, user_id: Uuid) CreateError!Token {
|
|
||||||
var token: [token_len]u8 = undefined;
|
|
||||||
std.crypto.random.bytes(&token);
|
|
||||||
|
|
||||||
var hash: [TokenHash.digest_length]u8 = undefined;
|
|
||||||
TokenHash.hash(&token, &hash, .{});
|
|
||||||
|
|
||||||
const issued_at = DateTime.now();
|
|
||||||
|
|
||||||
db.insert2("token", DbToken{
|
|
||||||
.hash = &hash,
|
|
||||||
.user_id = user_id,
|
|
||||||
.issued_at = issued_at,
|
|
||||||
}) catch return error.DbError;
|
|
||||||
|
|
||||||
return Token{
|
|
||||||
.value = token,
|
|
||||||
.issued_at = issued_at,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Uuid {
|
|
||||||
return if (try db.execRow2(
|
|
||||||
&.{ Uuid, DateTime },
|
|
||||||
\\SELECT user.id, token.issued_at
|
|
||||||
\\FROM token JOIN user ON token.user_id = user.id
|
|
||||||
\\WHERE user.community_id = ? AND token.hash = ?
|
|
||||||
\\LIMIT 1
|
|
||||||
,
|
|
||||||
.{ community_id, hash },
|
|
||||||
null,
|
|
||||||
)) |result|
|
|
||||||
Token.Info{
|
|
||||||
.user_id = result[0],
|
|
||||||
.issued_at = result[1],
|
|
||||||
}
|
|
||||||
else
|
|
||||||
null;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
|
|
||||||
return if (try db.execRow2(
|
|
||||||
&.{ Uuid, DateTime },
|
|
||||||
\\SELECT user.id, token.issued_at
|
|
||||||
\\FROM token JOIN user ON token.user_id = user.id
|
|
||||||
\\WHERE user.community_id IS NULL AND token.hash = ?
|
|
||||||
\\LIMIT 1
|
|
||||||
,
|
|
||||||
.{hash},
|
|
||||||
null,
|
|
||||||
)) |result|
|
|
||||||
Token.Info{
|
|
||||||
.user_id = result[0],
|
|
||||||
.issued_at = result[1],
|
|
||||||
}
|
|
||||||
else
|
|
||||||
null;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const VerifyError = error{ InvalidToken, DbError };
|
|
||||||
pub fn verifyToken(db: anytype, token: []const u8, community_id: ?Uuid) VerifyError!Token.Info {
|
|
||||||
var hash: [TokenHash.digest_length]u8 = undefined;
|
|
||||||
TokenHash.hash(&token, &hash, .{});
|
|
||||||
|
|
||||||
const token_info = if (community_id) |id|
|
|
||||||
lookupUserTokenFromHash(db, &hash, id) catch return error.DbError
|
|
||||||
else
|
|
||||||
lookupSystemTokenFromHash(db, &hash) catch return error.DbError;
|
|
||||||
|
|
||||||
if (token_info) |info| return info;
|
|
||||||
|
|
||||||
return error.InvalidToken;
|
|
||||||
}
|
|
||||||
};
|
|
|
@ -23,7 +23,7 @@ pub const passwords = struct {
|
||||||
};
|
};
|
||||||
pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
|
pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
|
||||||
// TODO: This could be done w/o the dynamically allocated hash buf
|
// TODO: This could be done w/o the dynamically allocated hash buf
|
||||||
const hash = (db.execRow2(
|
const hash = (db.execRow(
|
||||||
&.{[]const u8},
|
&.{[]const u8},
|
||||||
"SELECT hashed_password FROM account_password WHERE user_id = ? LIMIT 1",
|
"SELECT hashed_password FROM account_password WHERE user_id = ? LIMIT 1",
|
||||||
.{user_id},
|
.{user_id},
|
||||||
|
@ -39,7 +39,7 @@ pub const passwords = struct {
|
||||||
var buf: PwHashBuf = undefined;
|
var buf: PwHashBuf = undefined;
|
||||||
const hash = PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf) catch unreachable;
|
const hash = PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf) catch unreachable;
|
||||||
|
|
||||||
db.insert2("account_password", .{
|
db.insert("account_password", .{
|
||||||
.user_id = user_id,
|
.user_id = user_id,
|
||||||
.hashed_password = hash,
|
.hashed_password = hash,
|
||||||
}) catch return error.DbError;
|
}) catch return error.DbError;
|
||||||
|
@ -79,7 +79,7 @@ pub const tokens = struct {
|
||||||
|
|
||||||
const issued_at = DateTime.now();
|
const issued_at = DateTime.now();
|
||||||
|
|
||||||
db.insert2("token", DbToken{
|
db.insert("token", DbToken{
|
||||||
.hash = &hash,
|
.hash = &hash,
|
||||||
.user_id = user_id,
|
.user_id = user_id,
|
||||||
.issued_at = issued_at,
|
.issued_at = issued_at,
|
||||||
|
@ -95,7 +95,7 @@ pub const tokens = struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info {
|
fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info {
|
||||||
return if (try db.execRow2(
|
return if (try db.execRow(
|
||||||
&.{ Uuid, DateTime },
|
&.{ Uuid, DateTime },
|
||||||
\\SELECT user.id, token.issued_at
|
\\SELECT user.id, token.issued_at
|
||||||
\\FROM token JOIN user ON token.user_id = user.id
|
\\FROM token JOIN user ON token.user_id = user.id
|
||||||
|
@ -114,7 +114,7 @@ pub const tokens = struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
|
fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
|
||||||
return if (try db.execRow2(
|
return if (try db.execRow(
|
||||||
&.{ Uuid, DateTime },
|
&.{ Uuid, DateTime },
|
||||||
\\SELECT user.id, token.issued_at
|
\\SELECT user.id, token.issued_at
|
||||||
\\FROM token JOIN user ON token.user_id = user.id
|
\\FROM token JOIN user ON token.user_id = user.id
|
||||||
|
|
|
@ -66,11 +66,11 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
|
||||||
.scheme = scheme,
|
.scheme = scheme,
|
||||||
};
|
};
|
||||||
|
|
||||||
if ((try db.execRow2(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, null)) != null) {
|
if ((try db.execRow(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, null)) != null) {
|
||||||
return error.CommunityExists;
|
return error.CommunityExists;
|
||||||
}
|
}
|
||||||
|
|
||||||
try db.insert2("community", community);
|
try db.insert("community", community);
|
||||||
|
|
||||||
return community;
|
return community;
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ pub const CreateOptions = struct {
|
||||||
};
|
};
|
||||||
|
|
||||||
fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
|
fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
|
||||||
return if (try db.execRow2(
|
return if (try db.execRow(
|
||||||
&.{Uuid},
|
&.{Uuid},
|
||||||
"SELECT user.id FROM user WHERE community_id IS NULL AND username = ?",
|
"SELECT user.id FROM user WHERE community_id IS NULL AND username = ?",
|
||||||
.{username},
|
.{username},
|
||||||
|
@ -48,7 +48,7 @@ fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid {
|
fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid {
|
||||||
return if (try db.execRow2(
|
return if (try db.execRow(
|
||||||
&.{Uuid},
|
&.{Uuid},
|
||||||
"SELECT user.id FROM user WHERE community_id = ? AND username = ?",
|
"SELECT user.id FROM user WHERE community_id = ? AND username = ?",
|
||||||
.{ community_id, username },
|
.{ community_id, username },
|
||||||
|
@ -79,13 +79,13 @@ pub fn create(
|
||||||
return error.UsernameTaken;
|
return error.UsernameTaken;
|
||||||
}
|
}
|
||||||
|
|
||||||
db.insert2("user", .{
|
db.insert("user", .{
|
||||||
.id = id,
|
.id = id,
|
||||||
.username = username,
|
.username = username,
|
||||||
.community_id = community_id,
|
.community_id = community_id,
|
||||||
}) catch return error.DbError;
|
}) catch return error.DbError;
|
||||||
try auth.passwords.create(db, id, password, alloc);
|
try auth.passwords.create(db, id, password, alloc);
|
||||||
db.insert2("local_user", .{
|
db.insert("local_user", .{
|
||||||
.user_id = id,
|
.user_id = id,
|
||||||
.invite_id = options.invite_id,
|
.invite_id = options.invite_id,
|
||||||
.email = options.email,
|
.email = options.email,
|
||||||
|
|
228
src/main/db.zig
228
src/main/db.zig
|
@ -9,20 +9,6 @@ const DateTime = util.DateTime;
|
||||||
const String = []const u8;
|
const String = []const u8;
|
||||||
const comptimePrint = std.fmt.comptimePrint;
|
const comptimePrint = std.fmt.comptimePrint;
|
||||||
|
|
||||||
fn tableName(comptime T: type) String {
|
|
||||||
return switch (T) {
|
|
||||||
models.Note => "note",
|
|
||||||
models.Actor => "actor",
|
|
||||||
models.Reaction => "reaction",
|
|
||||||
models.User => "user",
|
|
||||||
models.LocalUser => "local_user",
|
|
||||||
models.Token => "token",
|
|
||||||
models.Invite => "invite",
|
|
||||||
models.Community => "community",
|
|
||||||
else => unreachable,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple {
|
fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple {
|
||||||
var result: RowTuple = undefined;
|
var result: RowTuple = undefined;
|
||||||
// TODO: undo allocations on failure
|
// TODO: undo allocations on failure
|
||||||
|
@ -60,95 +46,6 @@ pub fn ResultSet(comptime result_types: []const type) type {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combines an array/tuple of strings into a single string, with a copy of
|
|
||||||
// joiner in between each one
|
|
||||||
fn join(comptime vals: anytype, comptime joiner: String) String {
|
|
||||||
comptime {
|
|
||||||
if (vals.len == 0) return "";
|
|
||||||
|
|
||||||
var result: String = "";
|
|
||||||
for (vals) |v| {
|
|
||||||
result = comptimePrint("{s}{s}{s}", .{ result, joiner, v });
|
|
||||||
}
|
|
||||||
|
|
||||||
return result[joiner.len..];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select query builder struct
|
|
||||||
const Query = struct {
|
|
||||||
select: []const String, // the fields to grab
|
|
||||||
from: String, // what table to query
|
|
||||||
where: String, // conditions on records to query
|
|
||||||
order_by: ?[]const String = null,
|
|
||||||
group_by: ?[]const String = null,
|
|
||||||
limit: ?usize = null,
|
|
||||||
offset: ?usize = null,
|
|
||||||
|
|
||||||
pub fn str(comptime self: Query) String {
|
|
||||||
comptime {
|
|
||||||
const order_expr = if (self.order_by == null) "" else comptimePrint(" ORDER BY {s}", .{join(self.order_by.?, ", ")});
|
|
||||||
const group_expr = if (self.group_by == null) "" else comptimePrint(" GROUP BY {s}", .{join(self.group_by.?, ", ")});
|
|
||||||
const limit_expr = if (self.limit == null) "" else comptimePrint(" LIMIT {?}", .{self.limit});
|
|
||||||
const offset_expr = if (self.offset == null) "" else comptimePrint(" OFFSET {?}", .{self.offset});
|
|
||||||
return comptimePrint(
|
|
||||||
"SELECT {s} FROM {s} WHERE {s}{s}{s}{s}{s};",
|
|
||||||
.{ join(self.select, ", "), self.from, self.where, order_expr, group_expr, limit_expr, offset_expr },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Insert query builder struct
|
|
||||||
const Insert = struct {
|
|
||||||
into: String, // the table to modify
|
|
||||||
columns: []const String, // the columns to provide
|
|
||||||
count: usize = 1, // the number of records to insert
|
|
||||||
|
|
||||||
pub fn str(comptime self: Insert) String {
|
|
||||||
comptime {
|
|
||||||
const row = comptimePrint(
|
|
||||||
"({s})",
|
|
||||||
.{join(.{"?"} ** self.columns.len, ", ")},
|
|
||||||
);
|
|
||||||
|
|
||||||
return comptimePrint(
|
|
||||||
"INSERT INTO {s} ({s}) VALUES {s};",
|
|
||||||
.{ self.into, join(self.columns, ", "), join(.{row} ** self.count, ", ") },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// treats the inputs as sets and performs set subtraction. Assumes that elements do not appear
|
|
||||||
// multiple times.
|
|
||||||
fn setSubtract(comptime lhs: []const String, comptime rhs: []const String) []const String {
|
|
||||||
comptime {
|
|
||||||
var result: [lhs.len]String = undefined;
|
|
||||||
var count = 0;
|
|
||||||
|
|
||||||
for (lhs) |l| {
|
|
||||||
const keep = for (rhs) |r| {
|
|
||||||
if (std.mem.eql(u8, l, r)) break false;
|
|
||||||
} else true;
|
|
||||||
|
|
||||||
if (keep) {
|
|
||||||
result[count] = l;
|
|
||||||
count += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result[0..count];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// returns all fields of T except for those in a specific set
|
|
||||||
fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const String {
|
|
||||||
comptime {
|
|
||||||
return setSubtract(std.meta.fieldNames(T), to_ignore);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Binds a value to a parameter in the query. Use this instead of string
|
// Binds a value to a parameter in the query. Use this instead of string
|
||||||
// concatenation to avoid injection attacks;
|
// concatenation to avoid injection attacks;
|
||||||
// If a given type is not supported by this function, you can add support by
|
// If a given type is not supported by this function, you can add support by
|
||||||
|
@ -234,7 +131,7 @@ pub const Database = struct {
|
||||||
self.db.close();
|
self.db.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn exec2(
|
pub fn exec(
|
||||||
self: *Database,
|
self: *Database,
|
||||||
comptime result_types: []const type,
|
comptime result_types: []const type,
|
||||||
comptime q: []const u8,
|
comptime q: []const u8,
|
||||||
|
@ -254,14 +151,14 @@ pub const Database = struct {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execRow2(
|
pub fn execRow(
|
||||||
self: *Database,
|
self: *Database,
|
||||||
comptime result_types: []const type,
|
comptime result_types: []const type,
|
||||||
comptime q: []const u8,
|
comptime q: []const u8,
|
||||||
args: anytype,
|
args: anytype,
|
||||||
allocator: ?std.mem.Allocator,
|
allocator: ?std.mem.Allocator,
|
||||||
) ExecError!?ResultSet(result_types).Row {
|
) ExecError!?ResultSet(result_types).Row {
|
||||||
var results = try self.exec2(result_types, q, args);
|
var results = try self.exec(result_types, q, args);
|
||||||
defer results.finish();
|
defer results.finish();
|
||||||
|
|
||||||
const row = results.row(allocator);
|
const row = results.row(allocator);
|
||||||
|
@ -287,7 +184,7 @@ pub const Database = struct {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert2(
|
pub fn insert(
|
||||||
self: *Database,
|
self: *Database,
|
||||||
comptime table: []const u8,
|
comptime table: []const u8,
|
||||||
value: anytype,
|
value: anytype,
|
||||||
|
@ -299,121 +196,6 @@ pub const Database = struct {
|
||||||
"INSERT INTO {s} VALUES {s}",
|
"INSERT INTO {s} VALUES {s}",
|
||||||
.{ table_spec, value_spec },
|
.{ table_spec, value_spec },
|
||||||
);
|
);
|
||||||
_ = try self.execRow2(&.{}, q, value, null);
|
_ = try self.execRow(&.{}, q, value, null);
|
||||||
}
|
|
||||||
|
|
||||||
// Lower level function
|
|
||||||
pub fn execRow(
|
|
||||||
self: *Database,
|
|
||||||
comptime q: []const u8,
|
|
||||||
args: anytype,
|
|
||||||
comptime return_types: []const type,
|
|
||||||
alloc: std.mem.Allocator,
|
|
||||||
) !?std.meta.Tuple(return_types) {
|
|
||||||
var stmt = try self.db.prepare(q);
|
|
||||||
errdefer stmt.finalize();
|
|
||||||
|
|
||||||
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
|
|
||||||
try bind(stmt, @intCast(u15, i + 1), @field(args, field.name));
|
|
||||||
}
|
|
||||||
|
|
||||||
const row = (try stmt.step()) orelse return null;
|
|
||||||
var result: std.meta.Tuple(return_types) = undefined;
|
|
||||||
inline for (std.meta.fields(@TypeOf(result))) |field, i| {
|
|
||||||
@field(result, field.name) = try getAlloc(row, field.field_type, i, alloc);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the first row that satisfies an equality check on the
|
|
||||||
// field specified
|
|
||||||
pub fn getBy(
|
|
||||||
self: *Database,
|
|
||||||
comptime T: type,
|
|
||||||
comptime field: std.meta.FieldEnum(T),
|
|
||||||
val: std.meta.fieldInfo(T, field).field_type,
|
|
||||||
alloc: std.mem.Allocator,
|
|
||||||
) !?T {
|
|
||||||
const field_name = std.meta.fieldInfo(T, field).name;
|
|
||||||
const fields = comptime fieldsExcept(T, &.{field_name});
|
|
||||||
const q = comptime (Query{
|
|
||||||
.select = fields,
|
|
||||||
.from = tableName(T),
|
|
||||||
.where = field_name ++ " = ?",
|
|
||||||
.limit = 1,
|
|
||||||
}).str();
|
|
||||||
|
|
||||||
var stmt = try self.db.prepare(q);
|
|
||||||
defer stmt.finalize();
|
|
||||||
|
|
||||||
try bind(stmt, 1, val);
|
|
||||||
|
|
||||||
const row = (try stmt.step()) orelse return null;
|
|
||||||
var result: T = undefined;
|
|
||||||
@field(result, field_name) = val;
|
|
||||||
|
|
||||||
inline for (fields) |f, i| {
|
|
||||||
@field(result, f) = getAlloc(row, @TypeOf(@field(result, f)), i, alloc) catch unreachable;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns an array of all rows that satisfy an equality check
|
|
||||||
// TODO: paginate this
|
|
||||||
pub fn getWhereEq(
|
|
||||||
self: *Database,
|
|
||||||
comptime T: type,
|
|
||||||
comptime field: std.meta.FieldEnum(T),
|
|
||||||
val: std.meta.fieldInfo(T, field).field_type,
|
|
||||||
alloc: std.mem.Allocator,
|
|
||||||
) ![]T {
|
|
||||||
const field_name = std.meta.fieldInfo(T, field).name;
|
|
||||||
const fields = comptime fieldsExcept(T, &.{field_name});
|
|
||||||
const q = comptime (Query{
|
|
||||||
.select = fields,
|
|
||||||
.from = tableName(T),
|
|
||||||
.where = field_name ++ " = ?",
|
|
||||||
}).str();
|
|
||||||
|
|
||||||
var stmt = try self.db.prepare(q);
|
|
||||||
defer stmt.finalize();
|
|
||||||
|
|
||||||
try bind(stmt, 1, val);
|
|
||||||
|
|
||||||
var results = std.ArrayList(T).init(alloc);
|
|
||||||
|
|
||||||
while (try stmt.step()) |row| {
|
|
||||||
var item: T = undefined;
|
|
||||||
@field(item, field_name) = val;
|
|
||||||
inline for (fields) |f, i| {
|
|
||||||
@field(item, f) = getAlloc(row, @TypeOf(@field(item, f)), i, alloc) catch unreachable;
|
|
||||||
}
|
|
||||||
|
|
||||||
try results.append(item);
|
|
||||||
}
|
|
||||||
|
|
||||||
return results.toOwnedSlice();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inserts a row into the database
|
|
||||||
// TODO: consider making this generic?
|
|
||||||
pub fn insert(self: *Database, comptime T: type, val: T) !void {
|
|
||||||
const fields = comptime std.meta.fieldNames(T);
|
|
||||||
const q = comptime (Insert{
|
|
||||||
.into = tableName(T),
|
|
||||||
.columns = fields,
|
|
||||||
.count = 1,
|
|
||||||
}).str();
|
|
||||||
|
|
||||||
var stmt = try self.db.prepare(q);
|
|
||||||
defer stmt.finalize();
|
|
||||||
|
|
||||||
inline for (fields) |f, i| {
|
|
||||||
try bind(stmt, i + 1, @field(val, f));
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((try stmt.step()) != null) return error.UnknownError;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in a new issue