diff --git a/build.zig b/build.zig index 7cc183b..26d44c5 100644 --- a/build.zig +++ b/build.zig @@ -116,7 +116,12 @@ pub fn build(b: *std.build.Builder) !void { const unittest_template_cmd = b.step("unit:template", "Run tests for template package"); const unittest_template = b.addTest("src/template/lib.zig"); unittest_template_cmd.dependOn(&unittest_template.step); - //unittest_template.addPackage(pkgs.util); + + const unittest_api_cmd = b.step("unit:api", "Run tests for api package"); + const unittest_api = b.addTest("src/api/lib.zig"); + unittest_api_cmd.dependOn(&unittest_api.step); + unittest_api.addPackage(pkgs.util); + unittest_api.addPackage(pkgs.sql); //const util_tests = b.addTest("src/util/lib.zig"); //const sql_tests = b.addTest("src/sql/lib.zig"); @@ -129,6 +134,7 @@ pub fn build(b: *std.build.Builder) !void { unittest_all.dependOn(unittest_util_cmd); unittest_all.dependOn(unittest_sql_cmd); unittest_all.dependOn(unittest_template_cmd); + unittest_all.dependOn(unittest_api_cmd); const api_integration = b.addTest("./tests/api_integration/lib.zig"); api_integration.addPackage(pkgs.opts); diff --git a/src/api/lib.zig b/src/api/lib.zig index 9da0674..80ebef3 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -8,17 +8,21 @@ const Uuid = util.Uuid; const default_avatar = "static/default_avi.png"; const services = struct { - const communities = @import("./services/communities.zig"); - const actors = @import("./services/actors.zig"); - const auth = @import("./services/auth.zig"); - const drive = @import("./services/drive.zig"); - const files = @import("./services/files.zig"); - const invites = @import("./services/invites.zig"); - const notes = @import("./services/notes.zig"); - const follows = @import("./services/follows.zig"); + pub const communities = @import("./services/communities.zig"); + pub const actors = @import("./services/actors.zig"); + pub const auth = @import("./services/auth.zig"); + pub const drive = @import("./services/drive.zig"); + pub const files = @import("./services/files.zig"); + pub const invites = @import("./services/invites.zig"); + pub const notes = @import("./services/notes.zig"); + pub const follows = @import("./services/follows.zig"); }; -const types = @import("./services/types.zig"); +test { + _ = @import("./methods/auth.zig"); +} + +const types = @import("./types.zig"); pub const QueryResult = types.QueryResult; @@ -460,52 +464,53 @@ fn ApiConn(comptime DbConn: type, comptime models: anytype) type { return true; } - pub fn register(self: *Self, username: []const u8, password: []const u8, opt: RegistrationOptions) !UserResponse { - const tx = try self.db.beginOrSavepoint(); - const maybe_invite = if (opt.invite_code) |code| - try models.invites.getByCode(tx, code, self.community.id, self.allocator) - else - null; - defer if (maybe_invite) |inv| util.deepFree(self.allocator, inv); + pub usingnamespace @import("./methods/auth.zig").methods(models); + // pub fn register(self: *Self, username: []const u8, password: []const u8, opt: RegistrationOptions) !UserResponse { + // const tx = try self.db.beginOrSavepoint(); + // const maybe_invite = if (opt.invite_code) |code| + // try models.invites.getByCode(tx, code, self.community.id, self.allocator) + // else + // null; + // defer if (maybe_invite) |inv| util.deepFree(self.allocator, inv); - if (maybe_invite) |invite| { - if (!Uuid.eql(invite.community_id, self.community.id)) return error.WrongCommunity; - if (!isInviteValid(invite)) return error.InvalidInvite; - } + // if (maybe_invite) |invite| { + // if (!Uuid.eql(invite.community_id, self.community.id)) return error.WrongCommunity; + // if (!isInviteValid(invite)) return error.InvalidInvite; + // } - const invite_kind = if (maybe_invite) |inv| inv.kind else .user; + // const invite_kind = if (maybe_invite) |inv| inv.kind else .user; - if (self.community.kind == .admin) @panic("Unimplmented"); + // if (self.community.kind == .admin) @panic("Unimplmented"); - const user_id = try models.auth.register( - tx, - username, - password, - self.community.id, - .{ - .invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null, - .email = opt.email, - }, - self.allocator, - ); + // const user_id = try models.auth.register( + // tx, + // username, + // password, + // self.community.id, + // .{ + // .invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null, + // .email = opt.email, + // }, + // self.allocator, + // ); - switch (invite_kind) { - .user => {}, - .system => @panic("System user invites unimplemented"), - .community_owner => { - try models.communities.transferOwnership(tx, self.community.id, user_id); - }, - } + // switch (invite_kind) { + // .user => {}, + // .system => @panic("System user invites unimplemented"), + // .community_owner => { + // try models.communities.transferOwnership(tx, self.community.id, user_id); + // }, + // } - const user = self.getUserUnchecked(tx, user_id) catch |err| switch (err) { - error.NotFound => return error.Unexpected, - else => |e| return e, - }; - errdefer util.deepFree(self.allocator, user); + // const user = self.getUserUnchecked(tx, user_id) catch |err| switch (err) { + // error.NotFound => return error.Unexpected, + // else => |e| return e, + // }; + // errdefer util.deepFree(self.allocator, user); - try tx.commit(); - return user; - } + // try tx.commit(); + // return user; + // } fn getUserUnchecked(self: *Self, db: anytype, user_id: Uuid) !UserResponse { const user = try models.actors.get(db, user_id, self.allocator); @@ -720,7 +725,7 @@ fn ApiConn(comptime DbConn: type, comptime models: anytype) type { errdefer self.allocator.free(result); for (children) |child, i| { - result[i] = try self.backendDriveEntryToFrontend(child, false); + result[i] = try backendDriveEntryToFrontend(self, child, false); count += 1; } @@ -879,53 +884,3 @@ fn ApiConn(comptime DbConn: type, comptime models: anytype) type { } }; } - -// test "register" { -// const TestDb = void; -// const exp_code = "abcd"; -// const exp_community = Uuid.parse("a210c035-c9e1-4361-82a2-aaeac8e40dc6") catch unreachable; -// var conn = ApiConn(TestDb, struct { -// const invites = struct { -// fn getByCode(_: TestDb, code: []const u8, community_id: Uuid, alloc: std.mem.Allocator) !services.invites.Invite { -// try std.testing.expectEqualStrings(exp_code, code); -// try std.testing.expectEqual(exp_community, community_id); - -// return try util.deepClone(alloc, services.invites.Invite{ -// .id = Uuid.parse("eac18f43-4dcc-489f-9fb5-4c1633e7b4e0") catch unreachable, - -// .created_by = Uuid.parse("6d951fcc-1c9f-497b-9c96-31dfb9873708") catch unreachable, -// .community_id = exp_community, -// .name = "test invite", -// .code = exp_code, - -// .created_at = DateTime.parse("2022-12-21T09:05:50Z") catch unreachable, -// .times_used = 0, - -// .expires_at = null, -// .max_uses = null, -// }); -// } -// }; -// const auth = struct { -// fn register( -// _: TestDb, -// username: []const u8, -// password: []const u8, -// community_id: Uuid, -// _: RegistrationOptions, -// _: std.mem.Allocator, -// ) !Uuid { -// try std.testing.expectEqualStrings("root", username); -// try std.testing.expectEqualStrings("password", password); -// try std.testing.expectEqual(exp_community, community_id); - -// return Uuid.parse("6d951fcc-1c9f-497b-9c96-31dfb9873708") catch unreachable; -// } -// }; -// }){}; -// defer conn.close(); - -// const result = try conn.register("root", "password", .{}); -// try std.allocator. - -// } diff --git a/src/api/methods/auth.zig b/src/api/methods/auth.zig new file mode 100644 index 0000000..3829851 --- /dev/null +++ b/src/api/methods/auth.zig @@ -0,0 +1,162 @@ +const std = @import("std"); +const util = @import("util"); +const types = @import("../types.zig"); + +const Uuid = util.Uuid; +const DateTime = util.DateTime; +const RegistrationOptions = @import("../lib.zig").RegistrationOptions; +const UserResponse = @import("../lib.zig").UserResponse; +const Invite = @import("../lib.zig").Invite; + +pub fn methods(comptime models: type) type { + return struct { + fn isInviteValid(invite: Invite) bool { + if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return false; + if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return false; + return true; + } + pub fn register(self: anytype, username: []const u8, password: []const u8, opt: RegistrationOptions) !types.Actor { + const tx = try self.db.beginOrSavepoint(); + const maybe_invite = if (opt.invite_code) |code| + try models.invites.getByCode(tx, code, self.community.id, self.allocator) + else + null; + defer if (maybe_invite) |inv| util.deepFree(self.allocator, inv); + + if (maybe_invite) |invite| { + if (!Uuid.eql(invite.community_id, self.community.id)) return error.WrongCommunity; + if (!isInviteValid(invite)) return error.InvalidInvite; + } + + const invite_kind = if (maybe_invite) |inv| inv.kind else .user; + + if (self.community.kind == .admin) @panic("Unimplmented"); + + const user_id = try models.auth.register( + tx, + username, + password, + self.community.id, + .{ + .invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null, + .email = opt.email, + }, + self.allocator, + ); + + switch (invite_kind) { + .user => {}, + .system => @panic("System user invites unimplemented"), + .community_owner => { + try models.communities.transferOwnership(tx, self.community.id, user_id); + }, + } + + const user = models.actors.get(tx, user_id, self.allocator) catch |err| switch (err) { + error.NotFound => return error.Unexpected, + else => |e| return e, + }; + errdefer util.deepFree(self.allocator, user); + + try tx.commitOrRelease(); + return user; + } + }; +} + +const TestDb = struct { + tx_level: usize = 0, + rolled_back: bool = false, + committed: bool = false, + fn beginOrSavepoint(self: *TestDb) !*TestDb { + self.tx_level += 1; + return self; + } + + fn rollback(self: *TestDb) void { + self.rolled_back = true; + self.tx_level -= 1; + } + + fn commitOrRelease(self: *TestDb) !void { + self.committed = true; + self.tx_level -= 1; + } +}; + +test "register" { + comptime var exp_code = "code"; + comptime var exp_community = Uuid.parse("a210c035-c9e1-4361-82a2-aaeac8e40dc6") catch unreachable; + comptime var uid = Uuid.parse("6d951fcc-1c9f-497b-9c96-31dfb9873708") catch unreachable; + + const MockSvc = struct { + const invites = struct { + fn getByCode(db: *TestDb, code: []const u8, community_id: Uuid, alloc: std.mem.Allocator) !Invite { + try std.testing.expectEqual(db.tx_level, 1); + try std.testing.expectEqualStrings(exp_code, code); + try std.testing.expectEqual(exp_community, community_id); + + return try util.deepClone(alloc, Invite{ + .id = Uuid.parse("eac18f43-4dcc-489f-9fb5-4c1633e7b4e0") catch unreachable, + + .created_by = Uuid.parse("6d951fcc-1c9f-497b-9c96-31dfb9873708") catch unreachable, + .community_id = exp_community, + .name = "test invite", + .code = exp_code, + + .kind = .user, + + .created_at = DateTime.parse("2022-12-21T09:05:50Z") catch unreachable, + .times_used = 0, + + .expires_at = null, + .max_uses = null, + }); + } + }; + const auth = struct { + fn register( + db: *TestDb, + username: []const u8, + password: []const u8, + community_id: Uuid, + _: @import("../services/auth.zig").RegistrationOptions, + _: std.mem.Allocator, + ) anyerror!Uuid { + try std.testing.expectEqual(db.tx_level, 1); + try std.testing.expectEqualStrings("root", username); + try std.testing.expectEqualStrings("password", password); + try std.testing.expectEqual(exp_community, community_id); + + return uid; + } + }; + const actors = struct { + fn get(_: *TestDb, id: Uuid, alloc: std.mem.Allocator) anyerror!types.Actor { + try std.testing.expectEqual(uid, id); + return try util.deepClone(alloc, std.mem.zeroInit(types.Actor, .{ + .id = id, + .username = "root", + .host = "example.com", + .community_id = exp_community, + })); + } + }; + const communities = struct { + fn transferOwnership(_: *TestDb, _: Uuid, _: Uuid) anyerror!void {} + }; + }; + + var db = TestDb{}; + util.deepFree(std.testing.allocator, try methods(MockSvc).register(.{ + .db = &db, + .allocator = std.testing.allocator, + .community = .{ + .id = exp_community, + .kind = .local, + }, + }, "root", "password", .{})); + try std.testing.expectEqual(false, db.rolled_back); + try std.testing.expectEqual(true, db.committed); + try std.testing.expectEqual(@as(usize, 0), db.tx_level); +} diff --git a/src/api/services/actors.zig b/src/api/services/actors.zig index 48567da..95bfe10 100644 --- a/src/api/services/actors.zig +++ b/src/api/services/actors.zig @@ -3,7 +3,7 @@ const util = @import("util"); const sql = @import("sql"); const common = @import("./common.zig"); const files = @import("./files.zig"); -const types = @import("./types.zig"); +const types = @import("../types.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; diff --git a/src/api/services/auth.zig b/src/api/services/auth.zig index 9cf1734..b0f8200 100644 --- a/src/api/services/auth.zig +++ b/src/api/services/auth.zig @@ -1,7 +1,7 @@ const std = @import("std"); const util = @import("util"); const actors = @import("./actors.zig"); -const types = @import("./types.zig"); +const types = @import("../types.zig"); const Token = types.Token; const Uuid = util.Uuid; diff --git a/src/api/services/communities.zig b/src/api/services/communities.zig index 8bc3869..6b575d5 100644 --- a/src/api/services/communities.zig +++ b/src/api/services/communities.zig @@ -3,7 +3,7 @@ const builtin = @import("builtin"); const util = @import("util"); const sql = @import("sql"); const actors = @import("./actors.zig"); -const types = @import("./types.zig"); +const types = @import("../types.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; diff --git a/src/api/services/files.zig b/src/api/services/files.zig index ca8361d..9b4412d 100644 --- a/src/api/services/files.zig +++ b/src/api/services/files.zig @@ -1,7 +1,7 @@ const std = @import("std"); const sql = @import("sql"); const util = @import("util"); -const types = @import("./types.zig"); +const types = @import("../types.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; diff --git a/src/api/services/invites.zig b/src/api/services/invites.zig index dff5401..fe0ce6f 100644 --- a/src/api/services/invites.zig +++ b/src/api/services/invites.zig @@ -1,7 +1,7 @@ const std = @import("std"); const builtin = @import("builtin"); const util = @import("util"); -const types = @import("./types.zig"); +const types = @import("../types.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; diff --git a/src/api/services/notes.zig b/src/api/services/notes.zig index 59782d4..0cce202 100644 --- a/src/api/services/notes.zig +++ b/src/api/services/notes.zig @@ -2,7 +2,7 @@ const std = @import("std"); const util = @import("util"); const sql = @import("sql"); const common = @import("./common.zig"); -const types = @import("./types.zig"); +const types = @import("../types.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; diff --git a/src/api/services/types.zig b/src/api/types.zig similarity index 100% rename from src/api/services/types.zig rename to src/api/types.zig