diff --git a/src/main/api.zig b/src/main/api.zig index ac7f499..2042cad 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -127,7 +127,7 @@ pub const ApiSource = struct { } fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid { - if (try self.db.execRow2( + if (try self.db.execRow( &.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, @@ -183,35 +183,6 @@ fn ApiConn(comptime DbConn: type) type { self.arena.deinit(); } - fn getAuthenticatedUser(self: *Self) !models.User { - if (self.user_id) |id| { - const user = try self.db.getBy(models.User, .id, id, self.arena.allocator()); - if (user == null) return error.NotAuthorized; - - return user.?; - } else { - return error.NotAuthorized; - } - } - - fn getAuthenticatedLocalUser(self: *Self) !models.LocalUser { - if (self.user_id) |user_id| { - const local_user = try self.db.getBy(models.LocalUser, .user_id, user_id, self.arena.allocator()); - if (local_user == null) return error.NotAuthorized; - - return local_user.?; - } else { - return error.NotAuthorized; - } - } - - fn getAuthenticatedActor(self: *Self) !models.Actor { - return if (self.user_id) |user_id| - (try self.db.getBy(models.Actor, .user_id, user_id, self.arena.allocator())) orelse error.NotAuthorized - else - error.NotAuthorized; - } - pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult { const user_id = (try services.users.lookupByUsername(&self.db, username, self.community_id)) orelse return error.InvalidLogin; try services.auth.passwords.verify(&self.db, user_id, password, self.internal_alloc); @@ -230,7 +201,7 @@ fn ApiConn(comptime DbConn: type) type { }; pub fn getTokenInfo(self: *Self) !TokenInfo { if (self.user_id) |user_id| { - const result = (try self.db.execRow2( + const result = (try self.db.execRow( &.{[]const u8}, "SELECT username FROM user WHERE id = ?", .{user_id}, @@ -260,7 +231,7 @@ fn ApiConn(comptime DbConn: type) type { // Users can only make invites to their own community, unless they // are system users const community_id = if (options.to_community) |host| blk: { - const desired_community = (try self.db.execRow2( + const desired_community = (try self.db.execRow( &.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, @@ -307,9 +278,5 @@ fn ApiConn(comptime DbConn: type) type { return invite; } - - pub fn getInvite(self: *Self, id: Uuid) !?models.Invite { - return self.db.getBy(models.Invite, .id, id, self.arena.allocator()); - } }; } diff --git a/src/main/api/' b/src/main/api/' deleted file mode 100644 index 5c82228..0000000 --- a/src/main/api/' +++ /dev/null @@ -1,149 +0,0 @@ -const std = @import("std"); -const util = @import("util"); - -const Uuid = util.Uuid; -const DateTime = util.DateTime; - -pub const passwords = struct { - const PwHash = std.crypto.pwhash.scrypt; - const pw_hash_params = PwHash.Params.interactive; - const pw_hash_encoding = .phc; - const pw_hash_buf_size = 128; - - const PwHashBuf = [pw_hash_buf_size]u8; - - pub const Password = struct { - user_id: Uuid, - - hashed_password: []const u8, - }; - - // Returned slice points into buf - fn hashPassword(password: []const u8, alloc: std.mem.Allocator, buf: *PwHashBuf) []const u8 { - return PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, buf) catch unreachable; - } - - pub const VerifyError = error{ - InvalidLogin, - DbError, - }; - pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void { - const hash = (try db.execRow2( - &.{PwHashBuf}, - "SELECT hashed_password FROM account_password WHERE user_id = ? LIMIT 1", - .{user_id}, - null, - )) orelse return error.PasswordNotFound; - - try PwHash.strVerify(&hash[0], password, .{ .allocator = alloc }); - } - - pub const CreateError = error{DbError}; - pub fn create(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) CreateError!void { - var buf: PwHashBuf = undefined; - const hash = PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf) catch unreachable; - - try db.insert2("account_password", .{ - .user_id = user_id, - .hashed_password = hash, - }); - } -}; - -pub const tokens = struct { - const token_len = 20; - pub const Token = struct { - pub const Value = [token_len]u8; - pub const Info = struct { - user_id: Uuid, - issued_at: DateTime, - }; - - value: Value, - - issued_at: DateTime, - }; - - const TokenHash = std.crypto.hash.sha2.Sha256; - - const DbToken = struct { - hash: []const u8, - user_id: Uuid, - issued_at: DateTime, - }; - - pub const CreateError = error{DbError}; - pub fn create(db: anytype, user_id: Uuid) CreateError!Token { - var token: [token_len]u8 = undefined; - std.crypto.random.bytes(&token); - - var hash: [TokenHash.digest_length]u8 = undefined; - TokenHash.hash(&token, &hash, .{}); - - const issued_at = DateTime.now(); - - db.insert2("token", DbToken{ - .hash = &hash, - .user_id = user_id, - .issued_at = issued_at, - }) catch return error.DbError; - - return Token{ - .value = token, - .issued_at = issued_at, - }; - } - - fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Uuid { - return if (try db.execRow2( - &.{ Uuid, DateTime }, - \\SELECT user.id, token.issued_at - \\FROM token JOIN user ON token.user_id = user.id - \\WHERE user.community_id = ? AND token.hash = ? - \\LIMIT 1 - , - .{ community_id, hash }, - null, - )) |result| - Token.Info{ - .user_id = result[0], - .issued_at = result[1], - } - else - null; - } - - fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info { - return if (try db.execRow2( - &.{ Uuid, DateTime }, - \\SELECT user.id, token.issued_at - \\FROM token JOIN user ON token.user_id = user.id - \\WHERE user.community_id IS NULL AND token.hash = ? - \\LIMIT 1 - , - .{hash}, - null, - )) |result| - Token.Info{ - .user_id = result[0], - .issued_at = result[1], - } - else - null; - } - - pub const VerifyError = error{ InvalidToken, DbError }; - pub fn verifyToken(db: anytype, token: []const u8, community_id: ?Uuid) VerifyError!Token.Info { - var hash: [TokenHash.digest_length]u8 = undefined; - TokenHash.hash(&token, &hash, .{}); - - const token_info = if (community_id) |id| - lookupUserTokenFromHash(db, &hash, id) catch return error.DbError - else - lookupSystemTokenFromHash(db, &hash) catch return error.DbError; - - if (token_info) |info| return info; - - return error.InvalidToken; - } -}; diff --git a/src/main/api/auth.zig b/src/main/api/auth.zig index 5e8ccee..04b36b3 100644 --- a/src/main/api/auth.zig +++ b/src/main/api/auth.zig @@ -23,7 +23,7 @@ pub const passwords = struct { }; pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void { // TODO: This could be done w/o the dynamically allocated hash buf - const hash = (db.execRow2( + const hash = (db.execRow( &.{[]const u8}, "SELECT hashed_password FROM account_password WHERE user_id = ? LIMIT 1", .{user_id}, @@ -39,7 +39,7 @@ pub const passwords = struct { var buf: PwHashBuf = undefined; const hash = PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf) catch unreachable; - db.insert2("account_password", .{ + db.insert("account_password", .{ .user_id = user_id, .hashed_password = hash, }) catch return error.DbError; @@ -79,7 +79,7 @@ pub const tokens = struct { const issued_at = DateTime.now(); - db.insert2("token", DbToken{ + db.insert("token", DbToken{ .hash = &hash, .user_id = user_id, .issued_at = issued_at, @@ -95,7 +95,7 @@ pub const tokens = struct { } fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info { - return if (try db.execRow2( + return if (try db.execRow( &.{ Uuid, DateTime }, \\SELECT user.id, token.issued_at \\FROM token JOIN user ON token.user_id = user.id @@ -114,7 +114,7 @@ pub const tokens = struct { } fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info { - return if (try db.execRow2( + return if (try db.execRow( &.{ Uuid, DateTime }, \\SELECT user.id, token.issued_at \\FROM token JOIN user ON token.user_id = user.id diff --git a/src/main/api/communities.zig b/src/main/api/communities.zig index cfeccf9..75c3d9a 100644 --- a/src/main/api/communities.zig +++ b/src/main/api/communities.zig @@ -66,11 +66,11 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co .scheme = scheme, }; - if ((try db.execRow2(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, null)) != null) { + if ((try db.execRow(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, null)) != null) { return error.CommunityExists; } - try db.insert2("community", community); + try db.insert("community", community); return community; } diff --git a/src/main/api/users.zig b/src/main/api/users.zig index cc9b4c9..ff0cf15 100644 --- a/src/main/api/users.zig +++ b/src/main/api/users.zig @@ -36,7 +36,7 @@ pub const CreateOptions = struct { }; fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid { - return if (try db.execRow2( + return if (try db.execRow( &.{Uuid}, "SELECT user.id FROM user WHERE community_id IS NULL AND username = ?", .{username}, @@ -48,7 +48,7 @@ fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid { } fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid { - return if (try db.execRow2( + return if (try db.execRow( &.{Uuid}, "SELECT user.id FROM user WHERE community_id = ? AND username = ?", .{ community_id, username }, @@ -79,13 +79,13 @@ pub fn create( return error.UsernameTaken; } - db.insert2("user", .{ + db.insert("user", .{ .id = id, .username = username, .community_id = community_id, }) catch return error.DbError; try auth.passwords.create(db, id, password, alloc); - db.insert2("local_user", .{ + db.insert("local_user", .{ .user_id = id, .invite_id = options.invite_id, .email = options.email, diff --git a/src/main/db.zig b/src/main/db.zig index 92daf0a..76b07a1 100644 --- a/src/main/db.zig +++ b/src/main/db.zig @@ -9,20 +9,6 @@ const DateTime = util.DateTime; const String = []const u8; const comptimePrint = std.fmt.comptimePrint; -fn tableName(comptime T: type) String { - return switch (T) { - models.Note => "note", - models.Actor => "actor", - models.Reaction => "reaction", - models.User => "user", - models.LocalUser => "local_user", - models.Token => "token", - models.Invite => "invite", - models.Community => "community", - else => unreachable, - }; -} - fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple { var result: RowTuple = undefined; // TODO: undo allocations on failure @@ -60,95 +46,6 @@ pub fn ResultSet(comptime result_types: []const type) type { }; } -// Combines an array/tuple of strings into a single string, with a copy of -// joiner in between each one -fn join(comptime vals: anytype, comptime joiner: String) String { - comptime { - if (vals.len == 0) return ""; - - var result: String = ""; - for (vals) |v| { - result = comptimePrint("{s}{s}{s}", .{ result, joiner, v }); - } - - return result[joiner.len..]; - } -} - -// Select query builder struct -const Query = struct { - select: []const String, // the fields to grab - from: String, // what table to query - where: String, // conditions on records to query - order_by: ?[]const String = null, - group_by: ?[]const String = null, - limit: ?usize = null, - offset: ?usize = null, - - pub fn str(comptime self: Query) String { - comptime { - const order_expr = if (self.order_by == null) "" else comptimePrint(" ORDER BY {s}", .{join(self.order_by.?, ", ")}); - const group_expr = if (self.group_by == null) "" else comptimePrint(" GROUP BY {s}", .{join(self.group_by.?, ", ")}); - const limit_expr = if (self.limit == null) "" else comptimePrint(" LIMIT {?}", .{self.limit}); - const offset_expr = if (self.offset == null) "" else comptimePrint(" OFFSET {?}", .{self.offset}); - return comptimePrint( - "SELECT {s} FROM {s} WHERE {s}{s}{s}{s}{s};", - .{ join(self.select, ", "), self.from, self.where, order_expr, group_expr, limit_expr, offset_expr }, - ); - } - } -}; - -// Insert query builder struct -const Insert = struct { - into: String, // the table to modify - columns: []const String, // the columns to provide - count: usize = 1, // the number of records to insert - - pub fn str(comptime self: Insert) String { - comptime { - const row = comptimePrint( - "({s})", - .{join(.{"?"} ** self.columns.len, ", ")}, - ); - - return comptimePrint( - "INSERT INTO {s} ({s}) VALUES {s};", - .{ self.into, join(self.columns, ", "), join(.{row} ** self.count, ", ") }, - ); - } - } -}; - -// treats the inputs as sets and performs set subtraction. Assumes that elements do not appear -// multiple times. -fn setSubtract(comptime lhs: []const String, comptime rhs: []const String) []const String { - comptime { - var result: [lhs.len]String = undefined; - var count = 0; - - for (lhs) |l| { - const keep = for (rhs) |r| { - if (std.mem.eql(u8, l, r)) break false; - } else true; - - if (keep) { - result[count] = l; - count += 1; - } - } - - return result[0..count]; - } -} - -// returns all fields of T except for those in a specific set -fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const String { - comptime { - return setSubtract(std.meta.fieldNames(T), to_ignore); - } -} - // Binds a value to a parameter in the query. Use this instead of string // concatenation to avoid injection attacks; // If a given type is not supported by this function, you can add support by @@ -234,7 +131,7 @@ pub const Database = struct { self.db.close(); } - pub fn exec2( + pub fn exec( self: *Database, comptime result_types: []const type, comptime q: []const u8, @@ -254,14 +151,14 @@ pub const Database = struct { }; } - pub fn execRow2( + pub fn execRow( self: *Database, comptime result_types: []const type, comptime q: []const u8, args: anytype, allocator: ?std.mem.Allocator, ) ExecError!?ResultSet(result_types).Row { - var results = try self.exec2(result_types, q, args); + var results = try self.exec(result_types, q, args); defer results.finish(); const row = results.row(allocator); @@ -287,7 +184,7 @@ pub const Database = struct { } } - pub fn insert2( + pub fn insert( self: *Database, comptime table: []const u8, value: anytype, @@ -299,121 +196,6 @@ pub const Database = struct { "INSERT INTO {s} VALUES {s}", .{ table_spec, value_spec }, ); - _ = try self.execRow2(&.{}, q, value, null); - } - - // Lower level function - pub fn execRow( - self: *Database, - comptime q: []const u8, - args: anytype, - comptime return_types: []const type, - alloc: std.mem.Allocator, - ) !?std.meta.Tuple(return_types) { - var stmt = try self.db.prepare(q); - errdefer stmt.finalize(); - - inline for (std.meta.fields(@TypeOf(args))) |field, i| { - try bind(stmt, @intCast(u15, i + 1), @field(args, field.name)); - } - - const row = (try stmt.step()) orelse return null; - var result: std.meta.Tuple(return_types) = undefined; - inline for (std.meta.fields(@TypeOf(result))) |field, i| { - @field(result, field.name) = try getAlloc(row, field.field_type, i, alloc); - } - - return result; - } - - // Returns the first row that satisfies an equality check on the - // field specified - pub fn getBy( - self: *Database, - comptime T: type, - comptime field: std.meta.FieldEnum(T), - val: std.meta.fieldInfo(T, field).field_type, - alloc: std.mem.Allocator, - ) !?T { - const field_name = std.meta.fieldInfo(T, field).name; - const fields = comptime fieldsExcept(T, &.{field_name}); - const q = comptime (Query{ - .select = fields, - .from = tableName(T), - .where = field_name ++ " = ?", - .limit = 1, - }).str(); - - var stmt = try self.db.prepare(q); - defer stmt.finalize(); - - try bind(stmt, 1, val); - - const row = (try stmt.step()) orelse return null; - var result: T = undefined; - @field(result, field_name) = val; - - inline for (fields) |f, i| { - @field(result, f) = getAlloc(row, @TypeOf(@field(result, f)), i, alloc) catch unreachable; - } - - return result; - } - - // Returns an array of all rows that satisfy an equality check - // TODO: paginate this - pub fn getWhereEq( - self: *Database, - comptime T: type, - comptime field: std.meta.FieldEnum(T), - val: std.meta.fieldInfo(T, field).field_type, - alloc: std.mem.Allocator, - ) ![]T { - const field_name = std.meta.fieldInfo(T, field).name; - const fields = comptime fieldsExcept(T, &.{field_name}); - const q = comptime (Query{ - .select = fields, - .from = tableName(T), - .where = field_name ++ " = ?", - }).str(); - - var stmt = try self.db.prepare(q); - defer stmt.finalize(); - - try bind(stmt, 1, val); - - var results = std.ArrayList(T).init(alloc); - - while (try stmt.step()) |row| { - var item: T = undefined; - @field(item, field_name) = val; - inline for (fields) |f, i| { - @field(item, f) = getAlloc(row, @TypeOf(@field(item, f)), i, alloc) catch unreachable; - } - - try results.append(item); - } - - return results.toOwnedSlice(); - } - - // Inserts a row into the database - // TODO: consider making this generic? - pub fn insert(self: *Database, comptime T: type, val: T) !void { - const fields = comptime std.meta.fieldNames(T); - const q = comptime (Insert{ - .into = tableName(T), - .columns = fields, - .count = 1, - }).str(); - - var stmt = try self.db.prepare(q); - defer stmt.finalize(); - - inline for (fields) |f, i| { - try bind(stmt, i + 1, @field(val, f)); - } - - if ((try stmt.step()) != null) return error.UnknownError; + _ = try self.execRow(&.{}, q, value, null); } };