diff --git a/src/main/api.zig b/src/main/api.zig index f02dac9..840a4bf 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -156,34 +156,31 @@ fn ApiConn(comptime DbConn: type) type { self.arena.deinit(); } - fn getAuthenticatedUser(self: *Self) !models.LocalUser { + fn getAuthenticatedLocalUser(self: *Self) !models.LocalUser { if (self.as_user) |user_id| { - const local_user = try self.db.getBy(models.LocalUser, .id, user_id, self.arena.allocator()); - if (local_user == null) return error.UserNotFound; + const local_user = try self.db.getBy(models.LocalUser, .user_id, user_id, self.arena.allocator()); + if (local_user == null) return error.NotAuthorized; return local_user.?; } else { - return error.NotAuthenticated; + return error.NotAuthorized; } } fn getAuthenticatedActor(self: *Self) !models.Actor { - const user = try self.getAuthenticatedUser(); - if (user.actor_id) |actor_id| { - const actor = try self.db.getBy(models.Actor, .id, actor_id, self.arena); - return actor.?; - } else { - return error.NoActor; - } + return if (self.as_user) |user_id| + (try self.db.getBy(models.Actor, .user_id, user_id, self.arena.allocator())) orelse error.NotAuthorized + else + error.NotAuthorized; } pub fn createNote(self: *Self, info: NoteCreateInfo) !models.Note { const id = Uuid.randV4(prng.random()); - const user = try self.getAuthenticatedUser(); + const actor = try self.getAuthenticatedActor(); const note = models.Note{ .id = id, - .author_id = user.actor_id orelse return error.NotAuthorized, + .author_id = actor.user_id, .content = info.content, .created_at = DateTime.now(), @@ -197,18 +194,19 @@ fn ApiConn(comptime DbConn: type) type { return self.db.getBy(models.Note, .id, id, self.arena.allocator()); } - pub fn getActor(self: *Self, id: Uuid) !?models.Actor { - return self.db.getBy(models.Actor, .id, id, self.arena.allocator()); + pub fn getActor(self: *Self, user_id: Uuid) !?models.Actor { + return self.db.getBy(models.Actor, .user_id, user_id, self.arena.allocator()); } pub fn getActorByHandle(self: *Self, handle: []const u8) !?models.Actor { - return self.db.getBy(models.Actor, .handle, handle, self.arena.allocator()); + const user = (try self.db.getBy(models.User, .username, handle, self.arena.allocator())) orelse return null; + return self.db.getBy(models.Actor, .user_id, user.id, self.arena.allocator()); } pub fn react(self: *Self, note_id: Uuid) !void { const id = Uuid.randV4(prng.random()); - const user = try self.getAuthenticatedUser(); - try self.db.insert(models.Reaction, .{ .id = id, .note_id = note_id, .reactor_id = user.actor_id orelse return error.NotAuthorized, .created_at = DateTime.now() }); + const actor = try self.getAuthenticatedActor(); + try self.db.insert(models.Reaction, .{ .id = id, .note_id = note_id, .reactor_id = actor.user_id, .created_at = DateTime.now() }); } pub fn listReacts(self: *Self, note_id: Uuid) ![]models.Reaction { @@ -216,18 +214,13 @@ fn ApiConn(comptime DbConn: type) type { } pub fn register(self: *Self, info: RegistrationInfo) !models.Actor { - const actor_id = Uuid.randV4(prng.random()); const user_id = Uuid.randV4(prng.random()); // TODO: lock for transaction - if (try self.db.existsWhereEq(models.LocalUser, .username, info.username)) { + if (try self.db.existsWhereEq(models.User, .username, info.username)) { return error.UsernameUnavailable; } - if (try self.db.existsWhereEq(models.Actor, .handle, info.username)) { - return error.InconsistentDb; - } - const now = DateTime.now(); const invite_id = if (info.invite_code) |invite_code| blk: { const invite = (try self.db.getBy(models.Invite, .invite_code, invite_code, self.arena.allocator())) orelse return error.InvalidInvite; @@ -244,39 +237,42 @@ fn ApiConn(comptime DbConn: type) type { var buf: [pw_hash_buf_size]u8 = undefined; const hash = try PwHash.strHash(info.password, .{ .allocator = self.internal_alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf); - const actor = models.Actor{ - .id = actor_id, - .handle = info.username, + const user = models.User{ + .id = user_id, + .username = info.username, .created_at = now, }; - const user = models.LocalUser{ - .id = user_id, - .actor_id = actor_id, - .username = info.username, + const actor = models.Actor{ + .user_id = user_id, + .public_id = "abc", // TODO + }; + const local_user = models.LocalUser{ + .user_id = user_id, .email = info.email, .invite_id = invite_id, .hashed_password = hash, .password_changed_at = now, - .created_at = now, }; + try self.db.insert(models.User, user); try self.db.insert(models.Actor, actor); - try self.db.insert(models.LocalUser, user); + try self.db.insert(models.LocalUser, local_user); return actor; } pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult { // TODO: This gives away the existence of a user through a timing side channel. is that acceptable? - const user_info = (try self.db.getBy(models.LocalUser, .username, username, self.arena.allocator())) orelse return error.InvalidLogin; + const user_info = (try self.db.getBy(models.User, .username, username, self.arena.allocator())) orelse return error.InvalidLogin; + const local_user_info = (try self.db.getBy(models.LocalUser, .user_id, user_info.id, self.arena.allocator())) orelse return error.InvalidLogin; //defer free(self.arena.allocator(), user_info); const Hash = std.crypto.pwhash.scrypt; - Hash.strVerify(user_info.hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) { + Hash.strVerify(local_user_info.hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) { error.PasswordVerificationFailed => return error.InvalidLogin, else => return err, }; - const token = try self.createToken(user_info); + const token = try self.createToken(user_info.id); var token_enc: [token_str_len]u8 = undefined; _ = std.base64.standard.Encoder.encode(&token_enc, &token.value); @@ -292,7 +288,7 @@ fn ApiConn(comptime DbConn: type) type { info: models.Token, value: [token_len]u8, }; - fn createToken(self: *Self, user: models.LocalUser) !TokenResult { + fn createToken(self: *Self, user_id: Uuid) !TokenResult { var token: [token_len]u8 = undefined; std.crypto.random.bytes(&token); @@ -302,7 +298,7 @@ fn ApiConn(comptime DbConn: type) type { const db_token = models.Token{ .id = Uuid.randV4(prng.random()), .hash = .{ .data = hash }, - .user_id = user.id, + .user_id = user_id, .issued_at = DateTime.now(), }; @@ -316,7 +312,7 @@ fn ApiConn(comptime DbConn: type) type { pub fn createInvite(self: *Self, options: InviteOptions) !models.Invite { const id = Uuid.randV4(prng.random()); - const user_id = (try self.getAuthenticatedUser()).id; + const user_id = (try self.getAuthenticatedLocalUser()).user_id; var code: [invite_code_len]u8 = undefined; std.crypto.random.bytes(&code); diff --git a/src/main/db.zig b/src/main/db.zig index 068c116..589511e 100644 --- a/src/main/db.zig +++ b/src/main/db.zig @@ -8,11 +8,25 @@ const DateTime = util.DateTime; const String = []const u8; const comptimePrint = std.fmt.comptimePrint; +fn baseTypeName(comptime T: type) []const u8 { + comptime { + const name = @typeName(T); + const start = for (name) |_, i| { + if (name[name.len - i] == '.') break name.len - i; + } else 0; + + return name[start..]; + } +} + fn tableName(comptime T: type) String { + //return util.case.pascalToSnake(baseTypeName(T)); + return switch (T) { models.Note => "note", models.Actor => "actor", models.Reaction => "reaction", + models.User => "user", models.LocalUser => "local_user", models.Token => "token", models.Invite => "invite", @@ -138,33 +152,28 @@ pub const Database = struct { db: sql.Sqlite, const init_sql_stmts = [_][]const u8{ + \\CREATE TABLE IF NOT EXISTS + \\user( + \\ id TEXT NOT NULL PRIMARY KEY, + \\ username TEXT NOT NULL, + \\ + \\ created_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP + \\) STRICT; + , \\CREATE TABLE IF NOT EXISTS \\actor( - \\ id TEXT NOT NULL, - \\ - \\ handle TEXT NOT NULL, - \\ created_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP, - \\ - \\ PRIMARY KEY(id) + \\ user_id TEXT NOT NULL PRIMARY KEY REFERENCES user(id), + \\ public_id TEXT NOT NULL \\) STRICT; , \\CREATE TABLE IF NOT EXISTS \\local_user( - \\ id TEXT NOT NULL, - \\ actor_id TEXT, + \\ user_id TEXT NOT NULL PRIMARY KEY REFERENCES user(id), \\ - \\ username TEXT NOT NULL, \\ email TEXT, \\ \\ hashed_password TEXT NOT NULL, - \\ password_changed_at INTEGER NOT NULL, - \\ - \\ created_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP, - \\ - \\ UNIQUE(actor_id), - \\ FOREIGN KEY(actor_id) REFERENCES actor(id), - \\ - \\ PRIMARY KEY(id) + \\ password_changed_at INTEGER NOT NULL \\) STRICT; , \\CREATE TABLE IF NOT EXISTS @@ -172,38 +181,29 @@ pub const Database = struct { \\ id TEXT NOT NULL, \\ \\ content TEXT NOT NULL, - \\ author_id TEXT NOT NULL, - \\ created_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP, + \\ author_id TEXT NOT NULL REFERENCES actor(id), \\ - \\ FOREIGN KEY(author_id) REFERENCES actor(id), - \\ - \\ PRIMARY KEY(id) + \\ created_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP \\) STRICT; , \\CREATE TABLE IF NOT EXISTS \\reaction( - \\ id TEXT NOT NULL, + \\ id TEXT NOT NULL PRIMARY KEY, \\ - \\ reactor_id TEXT NOT NULL, - \\ note_id TEXT NOT NULL, + \\ reactor_id TEXT NOT NULL REFERENCES actor(id), + \\ note_id TEXT NOT NULL REFERENCES note(id), \\ - \\ created_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP, - \\ - \\ FOREIGN KEY(reactor_id) REFERENCES actor(id), - \\ FOREIGN KEY(note_id) REFERENCES note(id), - \\ - \\ PRIMARY KEY(id) + \\ created_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP \\) STRICT; , \\CREATE TABLE IF NOT EXISTS \\token( - \\ id TEXT NOT NULL, + \\ id TEXT NOT NULL PRIMARY KEY, \\ \\ hash BLOB UNIQUE NOT NULL, \\ user_id TEXT NOT NULL REFERENCES local_user(id), - \\ issued_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP, \\ - \\ PRIMARY KEY(id) + \\ issued_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP \\) STRICT; , \\CREATE TABLE IF NOT EXISTS @@ -211,12 +211,12 @@ pub const Database = struct { \\ id TEXT NOT NULL PRIMARY KEY, \\ \\ name TEXT NOT NULL, - \\ invite_code TEXT NOT NULL, + \\ invite_code TEXT NOT NULL UNIQUE, \\ created_by TEXT NOT NULL REFERENCES local_user(id), \\ \\ max_uses INTEGER, \\ - \\ created_at INTEGER NOT NULL, + \\ created_at INTEGER NOT NULL DEFAULT CURRENT_TIMESTAMP, \\ expires_at INTEGER \\) STRICT; , diff --git a/src/main/db/models.zig b/src/main/db/models.zig index 2c56f8d..31b29ee 100644 --- a/src/main/db/models.zig +++ b/src/main/db/models.zig @@ -56,31 +56,32 @@ fn Ref(comptime _: type) type { return Uuid; } -pub const Note = struct { +pub const User = struct { id: Uuid, - content: []const u8, - author_id: Ref(Actor), + username: []const u8, created_at: DateTime, }; pub const Actor = struct { - id: Uuid, - handle: []const u8, - - created_at: DateTime, + user_id: Ref(User), + public_id: []const u8, }; pub const LocalUser = struct { - id: Uuid, - actor_id: ?Ref(Actor), + user_id: Ref(User), - username: []const u8, email: ?[]const u8, invite_id: ?Ref(Invite), hashed_password: []const u8, // encoded in PHC format, with salt password_changed_at: DateTime, +}; + +pub const Note = struct { + id: Uuid, + content: []const u8, + author_id: Ref(Actor), created_at: DateTime, }; diff --git a/src/util/lib.zig b/src/util/lib.zig index 6e57425..16f570c 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -1,8 +1,62 @@ +const std = @import("std"); + pub const ciutf8 = @import("./ciutf8.zig"); pub const Uuid = @import("./Uuid.zig"); pub const DateTime = @import("./DateTime.zig"); pub const PathIter = @import("./PathIter.zig"); +pub const case = struct { + // returns the number of capital letters in a string. + // only works with ascii characters + fn countCaps(str: []const u8) usize { + var count: usize = 0; + for (str) |ch| { + if (std.ascii.isUpper(ch)) { + count += 1; + } + } + return count; + } + + // converts a string from PascalCase to snake_case at comptime. + // only works with ascii characters + pub fn PascalToSnake(comptime str: []const u8) Return: { + break :Return if (str.len == 0) + *const [0:0]u8 + else + *const [str.len + countCaps(str) - 1:0]u8; + } { + comptime { + if (str.len == 0) return ""; + + var buf = std.mem.zeroes([str.len + countCaps(str) - 1:0]u8); + var i = 0; + for (str) |ch| { + if (std.ascii.isUpper(ch)) { + if (i != 0) { + buf[i] = '_'; + i += 1; + } + buf[i] = std.ascii.toLower(ch); + } else { + buf[i] = ch; + } + i += 1; + } + + return &buf; + } + } +}; + +test "pascalToSnake" { + try std.testing.expectEqual("", case.PascalToSnake("")); + try std.testing.expectEqual("abc", case.PascalToSnake("Abc")); + try std.testing.expectEqual("a_bc", case.PascalToSnake("ABc")); + try std.testing.expectEqual("a_b_c", case.PascalToSnake("ABC")); + try std.testing.expectEqual("ab_c", case.PascalToSnake("AbC")); +} + test { _ = ciutf8; _ = Uuid;