diff --git a/src/api/lib.zig b/src/api/lib.zig index 29f18d2..e046498 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -8,7 +8,6 @@ const DateTime = util.DateTime; const Uuid = util.Uuid; pub usingnamespace types; -pub const Account = types.accounts.Account; pub const Actor = types.actors.Actor; pub const Community = types.communities.Community; pub const Invite = types.invites.Invite; @@ -47,10 +46,12 @@ pub fn setupAdmin(db: sql.Db, origin: []const u8, username: []const u8, password const user = try @import("./methods/auth.zig").createLocalAccount( arena.allocator(), tx, - username, - password, - community_id, - .{ .role = .admin }, + .{ + .username = username, + .password = password, + .community_id = community_id, + .role = .admin, + }, ); try tx.transferCommunityOwnership(community_id, user); @@ -280,3 +281,7 @@ fn ApiConn(comptime DbConn: type, comptime methods: anytype) type { } }; } + +test { + std.testing.refAllDecls(@This()); +} diff --git a/src/api/methods/auth.zig b/src/api/methods/auth.zig index 80a462c..0d085d6 100644 --- a/src/api/methods/auth.zig +++ b/src/api/methods/auth.zig @@ -4,29 +4,29 @@ const pkg = @import("../lib.zig"); const services = @import("../services.zig"); const invites = @import("./invites.zig"); +const Allocator = std.mem.Allocator; const Uuid = util.Uuid; const DateTime = util.DateTime; const ApiContext = pkg.ApiContext; -const Invite = pkg.invites.Invite; const Token = pkg.tokens.Token; - const RegistrationOptions = pkg.auth.RegistrationOptions; -const AccountCreateOptions = services.accounts.CreateOptions; +const Invite = services.invites.Invite; pub fn register( alloc: std.mem.Allocator, ctx: ApiContext, svcs: anytype, - username: []const u8, - password: []const u8, opt: RegistrationOptions, ) !Uuid { const tx = try svcs.beginTx(); errdefer tx.rollbackTx(); const maybe_invite = if (opt.invite_code) |code| - try tx.getInviteByCode(alloc, code, ctx.community.id) + tx.getInviteByCode(alloc, code, ctx.community.id) catch |err| switch (err) { + error.NotFound => return error.InvalidInvite, + else => |e| return e, + } else null; defer if (maybe_invite) |inv| util.deepFree(alloc, inv); @@ -43,10 +43,10 @@ pub fn register( const account_id = try createLocalAccount( alloc, tx, - username, - password, - ctx.community.id, .{ + .username = opt.username, + .password = opt.password, + .community_id = ctx.community.id, .invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null, .email = opt.email, }, @@ -64,22 +64,34 @@ pub fn register( return account_id; } -pub fn createLocalAccount( - alloc: std.mem.Allocator, - svcs: anytype, +pub const AccountCreateArgs = struct { username: []const u8, password: []const u8, community_id: Uuid, - opt: AccountCreateOptions, + invite_id: ?Uuid = null, + email: ?[]const u8 = null, + role: services.accounts.Role = .user, +}; + +pub fn createLocalAccount( + alloc: std.mem.Allocator, + svcs: anytype, + args: AccountCreateArgs, ) !Uuid { const tx = try svcs.beginTx(); errdefer tx.rollbackTx(); - const hash = try hashPassword(password, alloc); + const hash = try hashPassword(args.password, alloc); defer alloc.free(hash); - const id = try tx.createActor(alloc, username, community_id, false); - try tx.createAccount(alloc, id, hash, opt); + const id = try tx.createActor(alloc, args.username, args.community_id, false); + try tx.createAccount(alloc, .{ + .for_actor = id, + .password_hash = hash, + .invite_id = args.invite_id, + .email = args.email, + .role = args.role, + }); try tx.commitTx(); @@ -156,7 +168,7 @@ pub fn login( // password hashing. // Attempting to calculate/verify a hash will use about 50mb of work space. const scrypt = std.crypto.pwhash.scrypt; -const password_hash_len = 128; +const max_password_hash_len = 128; fn verifyPassword( hash: []const u8, password: []const u8, @@ -167,23 +179,33 @@ fn verifyPassword( password, .{ .allocator = alloc }, ) catch |err| return switch (err) { - error.PasswordVerificationFailed => error.InvalidLogin, - else => error.HashFailure, + error.PasswordVerificationFailed => return error.InvalidLogin, + error.OutOfMemory => return error.OutOfMemory, + else => |e| return e, }; } +const scrypt_params = if (!@import("builtin").is_test) + scrypt.Params.interactive +else + scrypt.Params{ + .ln = 8, + .r = 8, + .p = 1, + }; fn hashPassword(password: []const u8, alloc: std.mem.Allocator) ![]const u8 { - const buf = try alloc.alloc(u8, password_hash_len); - errdefer alloc.free(buf); - return scrypt.strHash( + var buf: [max_password_hash_len]u8 = undefined; + const hash = try scrypt.strHash( password, .{ .allocator = alloc, - .params = scrypt.Params.interactive, + .params = scrypt_params, .encoding = .phc, }, - buf, - ) catch error.HashFailure; + &buf, + ); + + return util.deepClone(alloc, hash); } /// A raw token is a sequence of N random bytes, base64 encoded. @@ -226,101 +248,190 @@ fn hashToken(token_b64: []const u8, alloc: std.mem.Allocator) ![]const u8 { const hash_b64 = try alloc.alloc(u8, hash_b64_len); return Base64Encoder.encode(hash_b64, &hash); } -const TestDb = struct { - tx_level: usize = 0, - rolled_back: bool = false, - committed: bool = false, - fn beginOrSavepoint(self: *TestDb) !*TestDb { - self.tx_level += 1; - return self; - } - - fn rollback(self: *TestDb) void { - self.rolled_back = true; - self.tx_level -= 1; - } - - fn commitOrRelease(self: *TestDb) !void { - self.committed = true; - self.tx_level -= 1; - } -}; test "register" { - comptime var exp_code = "code"; - comptime var exp_community = Uuid.parse("a210c035-c9e1-4361-82a2-aaeac8e40dc6") catch unreachable; - comptime var uid = Uuid.parse("6d951fcc-1c9f-497b-9c96-31dfb9873708") catch unreachable; + const testCase = struct { + const test_invite_code = "xyz"; + const test_invite_id = Uuid.parse("d24e7f2a-7e6e-4e2a-8e9d-987538a04a40") catch unreachable; + const test_acc_id = Uuid.parse("e8e21e1d-7b80-4e48-876d-9929326af511") catch unreachable; + const test_community_id = Uuid.parse("8bf88bd7-fb07-492d-a89a-6350c036183f") catch unreachable; - const MockSvc = struct { - const invites = struct { - fn getByCode(db: *TestDb, code: []const u8, community_id: Uuid, alloc: std.mem.Allocator) !Invite { - try std.testing.expectEqual(db.tx_level, 1); - try std.testing.expectEqualStrings(exp_code, code); - try std.testing.expectEqual(exp_community, community_id); + const Args = struct { + username: []const u8 = "username", + password: []const u8 = "password1234", - return try util.deepClone(alloc, Invite{ - .id = Uuid.parse("eac18f43-4dcc-489f-9fb5-4c1633e7b4e0") catch unreachable, + use_invite: bool = false, + invite_community_id: Uuid = test_community_id, + invite_kind: services.invites.Kind = .user, + invite_max_uses: ?usize = null, + invite_current_uses: usize = 0, + invite_expires_at: ?DateTime = null, - .created_by = Uuid.parse("6d951fcc-1c9f-497b-9c96-31dfb9873708") catch unreachable, - .community_id = exp_community, - .name = "test invite", - .code = exp_code, + get_invite_error: ?anyerror = null, + create_account_error: ?anyerror = null, + create_actor_error: ?anyerror = null, + transfer_error: ?anyerror = null, - .kind = .user, + expect_error: ?anyerror = null, + expect_transferred: bool = false, + }; - .created_at = DateTime.parse("2022-12-21T09:05:50Z") catch unreachable, - .times_used = 0, + fn runCaseOnce(allocator: std.mem.Allocator, test_args: Args) anyerror!void { + const Svc = struct { + test_args: Args, + tx_level: usize = 0, + rolled_back: bool = false, + committed: bool = false, - .expires_at = null, - .max_uses = null, - }); + account_created: bool = false, + actor_created: bool = false, + community_transferred: bool = false, + + fn beginTx(self: *@This()) !*@This() { + self.tx_level += 1; + return self; + } + fn rollbackTx(self: *@This()) void { + self.tx_level -= 1; + self.rolled_back = true; + } + fn commitTx(self: *@This()) !void { + self.tx_level -= 1; + self.committed = true; + } + + fn getInviteByCode(self: *@This(), alloc: Allocator, code: []const u8, community_id: Uuid) anyerror!services.invites.Invite { + try std.testing.expect(self.tx_level > 0); + try std.testing.expectEqualStrings(test_invite_code, code); + try std.testing.expectEqual(test_community_id, community_id); + if (self.test_args.get_invite_error) |err| return err; + return try util.deepClone(alloc, std.mem.zeroInit(services.invites.Invite, .{ + .id = test_invite_id, + .community_id = self.test_args.invite_community_id, + .code = code, + .kind = self.test_args.invite_kind, + + .times_used = self.test_args.invite_current_uses, + .max_uses = self.test_args.invite_max_uses, + .expires_at = self.test_args.invite_expires_at, + })); + } + + fn createActor(self: *@This(), _: Allocator, username: []const u8, community_id: Uuid, _: bool) anyerror!Uuid { + try std.testing.expect(self.tx_level > 0); + if (self.test_args.create_actor_error) |err| return err; + try std.testing.expectEqualStrings(self.test_args.username, username); + try std.testing.expectEqual(test_community_id, community_id); + self.actor_created = true; + return test_acc_id; + } + + fn createAccount(self: *@This(), alloc: Allocator, args: services.accounts.CreateArgs) anyerror!void { + try std.testing.expect(self.tx_level > 0); + if (self.test_args.create_account_error) |err| return err; + try verifyPassword(args.password_hash, self.test_args.password, alloc); + if (self.test_args.use_invite) + try std.testing.expectEqual(@as(?Uuid, test_invite_id), args.invite_id) + else + try std.testing.expect(args.invite_id == null); + + try std.testing.expectEqual(services.accounts.Role.user, args.role); + self.account_created = true; + } + + fn transferCommunityOwnership(self: *@This(), community_id: Uuid, account_id: Uuid) !void { + try std.testing.expect(self.tx_level > 0); + if (self.test_args.transfer_error) |err| return err; + self.community_transferred = true; + try std.testing.expectEqual(test_community_id, community_id); + try std.testing.expectEqual(test_acc_id, account_id); + } + }; + + var svc = Svc{ .test_args = test_args }; + + const community = std.mem.zeroInit(pkg.Community, .{ .kind = .local, .id = test_community_id }); + + const result = register( + allocator, + .{ .community = community }, + &svc, + .{ + .username = test_args.username, + .password = test_args.password, + .invite_code = if (test_args.use_invite) test_invite_code else null, + }, + // shortcut out of memory errors to test allocation + ) catch |err| if (err == error.OutOfMemory) return err else err; + + if (test_args.expect_error) |err| { + try std.testing.expectError(err, result); + try std.testing.expect(!svc.committed); + if (svc.account_created or svc.actor_created or svc.community_transferred) { + try std.testing.expect(svc.rolled_back); + } + } else { + try std.testing.expectEqual(test_acc_id, try result); + try std.testing.expect(svc.committed); + try std.testing.expect(!svc.rolled_back); + try std.testing.expect(svc.account_created); + try std.testing.expect(svc.actor_created); + try std.testing.expectEqual(test_args.expect_transferred, svc.community_transferred); } - }; - const auth = struct { - fn register( - db: *TestDb, - username: []const u8, - password: []const u8, - community_id: Uuid, - _: AccountCreateOptions, - _: std.mem.Allocator, - ) anyerror!Uuid { - try std.testing.expectEqual(db.tx_level, 1); - try std.testing.expectEqualStrings("root", username); - try std.testing.expectEqualStrings("password", password); - try std.testing.expectEqual(exp_community, community_id); + } - return uid; - } - }; - const actors = struct { - fn get(_: *TestDb, id: Uuid, alloc: std.mem.Allocator) anyerror!pkg.Actor { - try std.testing.expectEqual(uid, id); - return try util.deepClone(alloc, std.mem.zeroInit(pkg.Actor, .{ - .id = id, - .username = "root", - .host = "example.com", - .community_id = exp_community, - })); - } - }; - const communities = struct { - fn transferOwnership(_: *TestDb, _: Uuid, _: Uuid) anyerror!void {} - }; - }; + fn case(args: Args) !void { + try std.testing.checkAllAllocationFailures(std.testing.allocator, runCaseOnce, .{args}); + } + }.case; - var db = TestDb{}; + // regular registration + try testCase(.{}); - _ = MockSvc; - util.deepFree(std.testing.allocator, try register(.{ - .db = &db, - .allocator = std.testing.allocator, - .community = .{ - .id = exp_community, - .kind = .local, - }, - }, "root", "password", .{})); - try std.testing.expectEqual(false, db.rolled_back); - try std.testing.expectEqual(true, db.committed); - try std.testing.expectEqual(@as(usize, 0), db.tx_level); + // registration with invite + try testCase(.{ .use_invite = true }); + + // registration with invite for a different community + try testCase(.{ + .invite_community_id = Uuid.parse("11111111-1111-1111-1111-111111111111") catch unreachable, + .use_invite = true, + .expect_error = error.WrongCommunity, + }); + + // registration as a new community owner + try testCase(.{ + .use_invite = true, + .invite_kind = .community_owner, + .expect_transferred = true, + }); + + // invite with expiration info + try testCase(.{ + .use_invite = true, + .invite_max_uses = 100, + .invite_current_uses = 10, + .invite_expires_at = DateTime{ .seconds_since_epoch = DateTime.test_now_timestamp + 3600 }, + }); + + // missing invite + try testCase(.{ + .use_invite = true, + .get_invite_error = error.NotFound, + .expect_error = error.InvalidInvite, + }); + + // expired invite + try testCase(.{ + .use_invite = true, + .invite_expires_at = DateTime{ .seconds_since_epoch = DateTime.test_now_timestamp - 3600 }, + .expect_error = error.InvalidInvite, + }); + + // used invite + try testCase(.{ + .use_invite = true, + .invite_max_uses = 100, + .invite_current_uses = 110, + .expect_error = error.InvalidInvite, + }); } diff --git a/src/api/services.zig b/src/api/services.zig index dbaf690..8d09f48 100644 --- a/src/api/services.zig +++ b/src/api/services.zig @@ -52,11 +52,9 @@ pub fn Services(comptime Db: type) type { pub fn createAccount( self: Self, alloc: std.mem.Allocator, - actor: Uuid, - password_hash: []const u8, - options: types.accounts.CreateOptions, + args: types.accounts.CreateArgs, ) !void { - return try impl.accounts.create(self.db, actor, password_hash, options, alloc); + return try impl.accounts.create(self.db, args, alloc); } pub fn getCredentialsByUsername( diff --git a/src/api/services/accounts.zig b/src/api/services/accounts.zig index e97014a..2bf5419 100644 --- a/src/api/services/accounts.zig +++ b/src/api/services/accounts.zig @@ -5,29 +5,27 @@ const types = @import("./types.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; -const CreateOptions = types.accounts.CreateOptions; +const CreateArgs = types.accounts.CreateArgs; const Credentials = types.accounts.Credentials; /// Creates a local account with the given information pub fn create( db: anytype, - for_actor: Uuid, - password_hash: []const u8, - options: CreateOptions, + args: CreateArgs, alloc: std.mem.Allocator, ) !void { const tx = try db.beginOrSavepoint(); errdefer tx.rollback(); tx.insert("account", .{ - .id = for_actor, - .invite_id = options.invite_id, - .email = options.email, - .kind = options.role, + .id = args.for_actor, + .invite_id = args.invite_id, + .email = args.email, + .kind = args.role, }, alloc) catch return error.DatabaseFailure; tx.insert("password", .{ - .account_id = for_actor, - .hash = password_hash, + .account_id = args.for_actor, + .hash = args.password_hash, .changed_at = DateTime.now(), }, alloc) catch return error.DatabaseFailure; diff --git a/src/api/services/types.zig b/src/api/services/types.zig index 6db53e7..2b66058 100644 --- a/src/api/services/types.zig +++ b/src/api/services/types.zig @@ -34,7 +34,9 @@ pub const accounts = struct { admin, }; - pub const CreateOptions = struct { + pub const CreateArgs = struct { + for_actor: Uuid, + password_hash: []const u8, invite_id: ?Uuid = null, email: ?[]const u8 = null, role: Role = .user, diff --git a/src/api/types.zig b/src/api/types.zig index e072f3d..390306a 100644 --- a/src/api/types.zig +++ b/src/api/types.zig @@ -15,6 +15,8 @@ fn QueryResult(comptime R: type, comptime A: type) type { pub const auth = struct { pub const RegistrationOptions = struct { + username: []const u8, + password: []const u8, invite_code: ?[]const u8 = null, email: ?[]const u8 = null, }; diff --git a/src/main/controllers/api/users.zig b/src/main/controllers/api/users.zig index 3bf0190..321b9ba 100644 --- a/src/main/controllers/api/users.zig +++ b/src/main/controllers/api/users.zig @@ -14,10 +14,12 @@ pub const create = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const options = .{ + .username = req.body.username, + .password = req.body.password, .invite_code = req.body.invite_code, .email = req.body.email, }; - const user = srv.register(req.body.username, req.body.password, options) catch |err| switch (err) { + const user = srv.register(options) catch |err| switch (err) { error.UsernameTaken => return res.err(.unprocessable_entity, "Username Unavailable", {}), else => return err, };