Add tests for registration api call

This commit is contained in:
jaina heartles 2023-01-08 15:35:58 -08:00
parent 2571043580
commit 3a52aad023
7 changed files with 249 additions and 131 deletions

View File

@ -8,7 +8,6 @@ const DateTime = util.DateTime;
const Uuid = util.Uuid;
pub usingnamespace types;
pub const Account = types.accounts.Account;
pub const Actor = types.actors.Actor;
pub const Community = types.communities.Community;
pub const Invite = types.invites.Invite;
@ -47,10 +46,12 @@ pub fn setupAdmin(db: sql.Db, origin: []const u8, username: []const u8, password
const user = try @import("./methods/auth.zig").createLocalAccount(
arena.allocator(),
tx,
username,
password,
community_id,
.{ .role = .admin },
.{
.username = username,
.password = password,
.community_id = community_id,
.role = .admin,
},
);
try tx.transferCommunityOwnership(community_id, user);
@ -280,3 +281,7 @@ fn ApiConn(comptime DbConn: type, comptime methods: anytype) type {
}
};
}
test {
std.testing.refAllDecls(@This());
}

View File

@ -4,29 +4,29 @@ const pkg = @import("../lib.zig");
const services = @import("../services.zig");
const invites = @import("./invites.zig");
const Allocator = std.mem.Allocator;
const Uuid = util.Uuid;
const DateTime = util.DateTime;
const ApiContext = pkg.ApiContext;
const Invite = pkg.invites.Invite;
const Token = pkg.tokens.Token;
const RegistrationOptions = pkg.auth.RegistrationOptions;
const AccountCreateOptions = services.accounts.CreateOptions;
const Invite = services.invites.Invite;
pub fn register(
alloc: std.mem.Allocator,
ctx: ApiContext,
svcs: anytype,
username: []const u8,
password: []const u8,
opt: RegistrationOptions,
) !Uuid {
const tx = try svcs.beginTx();
errdefer tx.rollbackTx();
const maybe_invite = if (opt.invite_code) |code|
try tx.getInviteByCode(alloc, code, ctx.community.id)
tx.getInviteByCode(alloc, code, ctx.community.id) catch |err| switch (err) {
error.NotFound => return error.InvalidInvite,
else => |e| return e,
}
else
null;
defer if (maybe_invite) |inv| util.deepFree(alloc, inv);
@ -43,10 +43,10 @@ pub fn register(
const account_id = try createLocalAccount(
alloc,
tx,
username,
password,
ctx.community.id,
.{
.username = opt.username,
.password = opt.password,
.community_id = ctx.community.id,
.invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null,
.email = opt.email,
},
@ -64,22 +64,34 @@ pub fn register(
return account_id;
}
pub fn createLocalAccount(
alloc: std.mem.Allocator,
svcs: anytype,
pub const AccountCreateArgs = struct {
username: []const u8,
password: []const u8,
community_id: Uuid,
opt: AccountCreateOptions,
invite_id: ?Uuid = null,
email: ?[]const u8 = null,
role: services.accounts.Role = .user,
};
pub fn createLocalAccount(
alloc: std.mem.Allocator,
svcs: anytype,
args: AccountCreateArgs,
) !Uuid {
const tx = try svcs.beginTx();
errdefer tx.rollbackTx();
const hash = try hashPassword(password, alloc);
const hash = try hashPassword(args.password, alloc);
defer alloc.free(hash);
const id = try tx.createActor(alloc, username, community_id, false);
try tx.createAccount(alloc, id, hash, opt);
const id = try tx.createActor(alloc, args.username, args.community_id, false);
try tx.createAccount(alloc, .{
.for_actor = id,
.password_hash = hash,
.invite_id = args.invite_id,
.email = args.email,
.role = args.role,
});
try tx.commitTx();
@ -156,7 +168,7 @@ pub fn login(
// password hashing.
// Attempting to calculate/verify a hash will use about 50mb of work space.
const scrypt = std.crypto.pwhash.scrypt;
const password_hash_len = 128;
const max_password_hash_len = 128;
fn verifyPassword(
hash: []const u8,
password: []const u8,
@ -167,23 +179,33 @@ fn verifyPassword(
password,
.{ .allocator = alloc },
) catch |err| return switch (err) {
error.PasswordVerificationFailed => error.InvalidLogin,
else => error.HashFailure,
error.PasswordVerificationFailed => return error.InvalidLogin,
error.OutOfMemory => return error.OutOfMemory,
else => |e| return e,
};
}
const scrypt_params = if (!@import("builtin").is_test)
scrypt.Params.interactive
else
scrypt.Params{
.ln = 8,
.r = 8,
.p = 1,
};
fn hashPassword(password: []const u8, alloc: std.mem.Allocator) ![]const u8 {
const buf = try alloc.alloc(u8, password_hash_len);
errdefer alloc.free(buf);
return scrypt.strHash(
var buf: [max_password_hash_len]u8 = undefined;
const hash = try scrypt.strHash(
password,
.{
.allocator = alloc,
.params = scrypt.Params.interactive,
.params = scrypt_params,
.encoding = .phc,
},
buf,
) catch error.HashFailure;
&buf,
);
return util.deepClone(alloc, hash);
}
/// A raw token is a sequence of N random bytes, base64 encoded.
@ -226,101 +248,190 @@ fn hashToken(token_b64: []const u8, alloc: std.mem.Allocator) ![]const u8 {
const hash_b64 = try alloc.alloc(u8, hash_b64_len);
return Base64Encoder.encode(hash_b64, &hash);
}
const TestDb = struct {
tx_level: usize = 0,
rolled_back: bool = false,
committed: bool = false,
fn beginOrSavepoint(self: *TestDb) !*TestDb {
self.tx_level += 1;
return self;
}
fn rollback(self: *TestDb) void {
self.rolled_back = true;
self.tx_level -= 1;
}
fn commitOrRelease(self: *TestDb) !void {
self.committed = true;
self.tx_level -= 1;
}
};
test "register" {
comptime var exp_code = "code";
comptime var exp_community = Uuid.parse("a210c035-c9e1-4361-82a2-aaeac8e40dc6") catch unreachable;
comptime var uid = Uuid.parse("6d951fcc-1c9f-497b-9c96-31dfb9873708") catch unreachable;
const testCase = struct {
const test_invite_code = "xyz";
const test_invite_id = Uuid.parse("d24e7f2a-7e6e-4e2a-8e9d-987538a04a40") catch unreachable;
const test_acc_id = Uuid.parse("e8e21e1d-7b80-4e48-876d-9929326af511") catch unreachable;
const test_community_id = Uuid.parse("8bf88bd7-fb07-492d-a89a-6350c036183f") catch unreachable;
const MockSvc = struct {
const invites = struct {
fn getByCode(db: *TestDb, code: []const u8, community_id: Uuid, alloc: std.mem.Allocator) !Invite {
try std.testing.expectEqual(db.tx_level, 1);
try std.testing.expectEqualStrings(exp_code, code);
try std.testing.expectEqual(exp_community, community_id);
const Args = struct {
username: []const u8 = "username",
password: []const u8 = "password1234",
return try util.deepClone(alloc, Invite{
.id = Uuid.parse("eac18f43-4dcc-489f-9fb5-4c1633e7b4e0") catch unreachable,
use_invite: bool = false,
invite_community_id: Uuid = test_community_id,
invite_kind: services.invites.Kind = .user,
invite_max_uses: ?usize = null,
invite_current_uses: usize = 0,
invite_expires_at: ?DateTime = null,
.created_by = Uuid.parse("6d951fcc-1c9f-497b-9c96-31dfb9873708") catch unreachable,
.community_id = exp_community,
.name = "test invite",
.code = exp_code,
get_invite_error: ?anyerror = null,
create_account_error: ?anyerror = null,
create_actor_error: ?anyerror = null,
transfer_error: ?anyerror = null,
.kind = .user,
expect_error: ?anyerror = null,
expect_transferred: bool = false,
};
.created_at = DateTime.parse("2022-12-21T09:05:50Z") catch unreachable,
.times_used = 0,
fn runCaseOnce(allocator: std.mem.Allocator, test_args: Args) anyerror!void {
const Svc = struct {
test_args: Args,
tx_level: usize = 0,
rolled_back: bool = false,
committed: bool = false,
.expires_at = null,
.max_uses = null,
});
account_created: bool = false,
actor_created: bool = false,
community_transferred: bool = false,
fn beginTx(self: *@This()) !*@This() {
self.tx_level += 1;
return self;
}
fn rollbackTx(self: *@This()) void {
self.tx_level -= 1;
self.rolled_back = true;
}
fn commitTx(self: *@This()) !void {
self.tx_level -= 1;
self.committed = true;
}
fn getInviteByCode(self: *@This(), alloc: Allocator, code: []const u8, community_id: Uuid) anyerror!services.invites.Invite {
try std.testing.expect(self.tx_level > 0);
try std.testing.expectEqualStrings(test_invite_code, code);
try std.testing.expectEqual(test_community_id, community_id);
if (self.test_args.get_invite_error) |err| return err;
return try util.deepClone(alloc, std.mem.zeroInit(services.invites.Invite, .{
.id = test_invite_id,
.community_id = self.test_args.invite_community_id,
.code = code,
.kind = self.test_args.invite_kind,
.times_used = self.test_args.invite_current_uses,
.max_uses = self.test_args.invite_max_uses,
.expires_at = self.test_args.invite_expires_at,
}));
}
fn createActor(self: *@This(), _: Allocator, username: []const u8, community_id: Uuid, _: bool) anyerror!Uuid {
try std.testing.expect(self.tx_level > 0);
if (self.test_args.create_actor_error) |err| return err;
try std.testing.expectEqualStrings(self.test_args.username, username);
try std.testing.expectEqual(test_community_id, community_id);
self.actor_created = true;
return test_acc_id;
}
fn createAccount(self: *@This(), alloc: Allocator, args: services.accounts.CreateArgs) anyerror!void {
try std.testing.expect(self.tx_level > 0);
if (self.test_args.create_account_error) |err| return err;
try verifyPassword(args.password_hash, self.test_args.password, alloc);
if (self.test_args.use_invite)
try std.testing.expectEqual(@as(?Uuid, test_invite_id), args.invite_id)
else
try std.testing.expect(args.invite_id == null);
try std.testing.expectEqual(services.accounts.Role.user, args.role);
self.account_created = true;
}
fn transferCommunityOwnership(self: *@This(), community_id: Uuid, account_id: Uuid) !void {
try std.testing.expect(self.tx_level > 0);
if (self.test_args.transfer_error) |err| return err;
self.community_transferred = true;
try std.testing.expectEqual(test_community_id, community_id);
try std.testing.expectEqual(test_acc_id, account_id);
}
};
var svc = Svc{ .test_args = test_args };
const community = std.mem.zeroInit(pkg.Community, .{ .kind = .local, .id = test_community_id });
const result = register(
allocator,
.{ .community = community },
&svc,
.{
.username = test_args.username,
.password = test_args.password,
.invite_code = if (test_args.use_invite) test_invite_code else null,
},
// shortcut out of memory errors to test allocation
) catch |err| if (err == error.OutOfMemory) return err else err;
if (test_args.expect_error) |err| {
try std.testing.expectError(err, result);
try std.testing.expect(!svc.committed);
if (svc.account_created or svc.actor_created or svc.community_transferred) {
try std.testing.expect(svc.rolled_back);
}
} else {
try std.testing.expectEqual(test_acc_id, try result);
try std.testing.expect(svc.committed);
try std.testing.expect(!svc.rolled_back);
try std.testing.expect(svc.account_created);
try std.testing.expect(svc.actor_created);
try std.testing.expectEqual(test_args.expect_transferred, svc.community_transferred);
}
};
const auth = struct {
fn register(
db: *TestDb,
username: []const u8,
password: []const u8,
community_id: Uuid,
_: AccountCreateOptions,
_: std.mem.Allocator,
) anyerror!Uuid {
try std.testing.expectEqual(db.tx_level, 1);
try std.testing.expectEqualStrings("root", username);
try std.testing.expectEqualStrings("password", password);
try std.testing.expectEqual(exp_community, community_id);
}
return uid;
}
};
const actors = struct {
fn get(_: *TestDb, id: Uuid, alloc: std.mem.Allocator) anyerror!pkg.Actor {
try std.testing.expectEqual(uid, id);
return try util.deepClone(alloc, std.mem.zeroInit(pkg.Actor, .{
.id = id,
.username = "root",
.host = "example.com",
.community_id = exp_community,
}));
}
};
const communities = struct {
fn transferOwnership(_: *TestDb, _: Uuid, _: Uuid) anyerror!void {}
};
};
fn case(args: Args) !void {
try std.testing.checkAllAllocationFailures(std.testing.allocator, runCaseOnce, .{args});
}
}.case;
var db = TestDb{};
// regular registration
try testCase(.{});
_ = MockSvc;
util.deepFree(std.testing.allocator, try register(.{
.db = &db,
.allocator = std.testing.allocator,
.community = .{
.id = exp_community,
.kind = .local,
},
}, "root", "password", .{}));
try std.testing.expectEqual(false, db.rolled_back);
try std.testing.expectEqual(true, db.committed);
try std.testing.expectEqual(@as(usize, 0), db.tx_level);
// registration with invite
try testCase(.{ .use_invite = true });
// registration with invite for a different community
try testCase(.{
.invite_community_id = Uuid.parse("11111111-1111-1111-1111-111111111111") catch unreachable,
.use_invite = true,
.expect_error = error.WrongCommunity,
});
// registration as a new community owner
try testCase(.{
.use_invite = true,
.invite_kind = .community_owner,
.expect_transferred = true,
});
// invite with expiration info
try testCase(.{
.use_invite = true,
.invite_max_uses = 100,
.invite_current_uses = 10,
.invite_expires_at = DateTime{ .seconds_since_epoch = DateTime.test_now_timestamp + 3600 },
});
// missing invite
try testCase(.{
.use_invite = true,
.get_invite_error = error.NotFound,
.expect_error = error.InvalidInvite,
});
// expired invite
try testCase(.{
.use_invite = true,
.invite_expires_at = DateTime{ .seconds_since_epoch = DateTime.test_now_timestamp - 3600 },
.expect_error = error.InvalidInvite,
});
// used invite
try testCase(.{
.use_invite = true,
.invite_max_uses = 100,
.invite_current_uses = 110,
.expect_error = error.InvalidInvite,
});
}

View File

@ -52,11 +52,9 @@ pub fn Services(comptime Db: type) type {
pub fn createAccount(
self: Self,
alloc: std.mem.Allocator,
actor: Uuid,
password_hash: []const u8,
options: types.accounts.CreateOptions,
args: types.accounts.CreateArgs,
) !void {
return try impl.accounts.create(self.db, actor, password_hash, options, alloc);
return try impl.accounts.create(self.db, args, alloc);
}
pub fn getCredentialsByUsername(

View File

@ -5,29 +5,27 @@ const types = @import("./types.zig");
const Uuid = util.Uuid;
const DateTime = util.DateTime;
const CreateOptions = types.accounts.CreateOptions;
const CreateArgs = types.accounts.CreateArgs;
const Credentials = types.accounts.Credentials;
/// Creates a local account with the given information
pub fn create(
db: anytype,
for_actor: Uuid,
password_hash: []const u8,
options: CreateOptions,
args: CreateArgs,
alloc: std.mem.Allocator,
) !void {
const tx = try db.beginOrSavepoint();
errdefer tx.rollback();
tx.insert("account", .{
.id = for_actor,
.invite_id = options.invite_id,
.email = options.email,
.kind = options.role,
.id = args.for_actor,
.invite_id = args.invite_id,
.email = args.email,
.kind = args.role,
}, alloc) catch return error.DatabaseFailure;
tx.insert("password", .{
.account_id = for_actor,
.hash = password_hash,
.account_id = args.for_actor,
.hash = args.password_hash,
.changed_at = DateTime.now(),
}, alloc) catch return error.DatabaseFailure;

View File

@ -34,7 +34,9 @@ pub const accounts = struct {
admin,
};
pub const CreateOptions = struct {
pub const CreateArgs = struct {
for_actor: Uuid,
password_hash: []const u8,
invite_id: ?Uuid = null,
email: ?[]const u8 = null,
role: Role = .user,

View File

@ -15,6 +15,8 @@ fn QueryResult(comptime R: type, comptime A: type) type {
pub const auth = struct {
pub const RegistrationOptions = struct {
username: []const u8,
password: []const u8,
invite_code: ?[]const u8 = null,
email: ?[]const u8 = null,
};

View File

@ -14,10 +14,12 @@ pub const create = struct {
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const options = .{
.username = req.body.username,
.password = req.body.password,
.invite_code = req.body.invite_code,
.email = req.body.email,
};
const user = srv.register(req.body.username, req.body.password, options) catch |err| switch (err) {
const user = srv.register(options) catch |err| switch (err) {
error.UsernameTaken => return res.err(.unprocessable_entity, "Username Unavailable", {}),
else => return err,
};