From e0fd7097eb1c85e3993e0b6ef92b934840cce014 Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Wed, 7 Sep 2022 23:56:29 -0700 Subject: [PATCH] User creation --- src/main/api.zig | 40 ++++++++++++--- src/main/api/communities.zig | 15 ++++-- src/main/api/invites.zig | 73 +++++++++++++++++++++++++--- src/main/api/users.zig | 41 ++++++++++++++-- src/main/controllers.zig | 1 + src/main/controllers/communities.zig | 2 +- src/main/controllers/users.zig | 43 ++++++---------- src/main/db.zig | 4 +- src/main/db/migrations.zig | 7 ++- src/main/main.zig | 2 + src/util/DateTime.zig | 4 ++ src/util/Uuid.zig | 2 +- src/util/lib.zig | 6 +++ 13 files changed, 186 insertions(+), 54 deletions(-) diff --git a/src/main/api.zig b/src/main/api.zig index c26f033..fe899a2 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -26,6 +26,13 @@ const services = struct { const invites = @import("./api/invites.zig"); }; +pub const RegistrationRequest = struct { + username: []const u8, + password: []const u8, + invite_code: []const u8, + email: ?[]const u8, +}; + pub const InviteRequest = struct { pub const Type = services.invites.InviteType; @@ -69,13 +76,6 @@ pub fn firstIndexOf(str: []const u8, ch: u8) ?usize { pub const Scheme = models.Community.Scheme; -pub const RegistrationInfo = struct { - username: []const u8, - password: []const u8, - email: ?[]const u8, - invite_code: ?[]const u8, -}; - pub const LoginResult = struct { user_id: Uuid, token: [token_str_len]u8, @@ -257,5 +257,31 @@ fn ApiConn(comptime DbConn: type) type { .invite_type = options.invite_type, }, self.arena.allocator()); } + + pub fn register(self: *Self, request: RegistrationRequest) !services.users.User { + std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code }); + const invite = try services.invites.getByCode(&self.db, request.invite_code, self.arena.allocator()); + + if (!Uuid.eql(invite.to_community, self.community_id)) return error.NotFound; + if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired; + if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired; + + if (self.community_id == null) @panic("Unimplmented"); + + const user_id = try services.users.create(&self.db, request.username, request.password, self.community_id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc); + + switch (invite.invite_type) { + .user => {}, + .system => @panic("System user invites unimplemented"), + .community_owner => { + try services.communities.transferOwnership(&self.db, self.community_id.?, user_id); + }, + } + + return services.users.get(&self.db, user_id, self.arena.allocator()) catch |err| switch (err) { + error.NotFound => error.Unexpected, + else => err, + }; + } }; } diff --git a/src/main/api/communities.zig b/src/main/api/communities.zig index 2f6d416..069e023 100644 --- a/src/main/api/communities.zig +++ b/src/main/api/communities.zig @@ -27,6 +27,7 @@ pub const Scheme = enum { pub const Community = struct { id: Uuid, + owner_id: ?Uuid, host: []const u8, name: []const u8, @@ -61,6 +62,7 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co const community = Community{ .id = id, + .owner_id = null, .host = host, .name = name orelse host, .scheme = scheme, @@ -84,12 +86,17 @@ fn firstIndexOf(str: []const u8, ch: u8) ?usize { } pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community { - const result = (try db.execRow(&.{ Uuid, []const u8, []const u8, Scheme }, "SELECT id, host, name, scheme FROM community WHERE host = ?", .{host}, alloc)) orelse return error.NotFound; + const result = (try db.execRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme }, "SELECT id, owner_id, host, name, scheme FROM community WHERE host = ?", .{host}, alloc)) orelse return error.NotFound; return Community{ .id = result[0], - .host = result[1], - .name = result[2], - .scheme = result[3], + .owner_id = result[1], + .host = result[2], + .name = result[3], + .scheme = result[4], }; } + +pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void { + _ = try db.execRow(&.{i64}, "UPDATE community SET owner_id = ? WHERE id = ?", .{ new_owner, community_id }, null); +} diff --git a/src/main/api/invites.zig b/src/main/api/invites.zig index 27da401..6e1f9c4 100644 --- a/src/main/api/invites.zig +++ b/src/main/api/invites.zig @@ -48,6 +48,21 @@ pub const Invite = struct { invite_type: InviteType, }; +const DbModel = struct { + id: Uuid, + + created_by: Uuid, // User ID + to_community: ?Uuid, + name: []const u8, + code: []const u8, + + created_at: DateTime, + expires_at: ?DateTime, + + max_uses: ?usize, + + @"type": InviteType, +}; fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 { const new = try alloc.alloc(u8, str.len); std.mem.copy(u8, new, str); @@ -74,11 +89,12 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit try cloneStr(name, alloc) else try cloneStr(code, alloc); + errdefer alloc.free(name); const id = Uuid.randV4(getRandom()); const created_at = DateTime.now(); - const invite = Invite{ + try db.insert("invite", DbModel{ .id = id, .created_by = created_by, @@ -87,15 +103,58 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit .code = code, .created_at = created_at, - .times_used = 0, - .expires_at = options.expires_at, + + .max_uses = options.max_uses, + + .@"type" = options.invite_type, + }); + + return Invite{ + .id = id, + + .created_by = created_by, + .to_community = to_community, + .name = name, + .code = code, + + .created_at = created_at, + .expires_at = options.expires_at, + + .times_used = 0, .max_uses = options.max_uses, .invite_type = options.invite_type, }; - - try db.insert("invite", invite); - - return invite; +} + +pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite { + const code_clone = try cloneStr(code, alloc); + const info = (try db.execRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, usize, ?usize, InviteType }, + \\SELECT + \\ invite.id, invite.created_by, invite.to_community, invite.name, + \\ invite.created_at, invite.expires_at, + \\ COUNT(local_user.user_id) as uses, invite.max_uses, + \\ invite.type + \\FROM invite LEFT OUTER JOIN local_user ON invite.id = local_user.invite_id + \\WHERE invite.code = ? + \\GROUP BY invite.id + , .{code}, alloc)) orelse return error.NotFound; + + return Invite{ + .id = info[0], + + .created_by = info[1], + .to_community = info[2], + .name = info[3], + .code = code_clone, + + .created_at = info[4], + .expires_at = info[5], + + .times_used = info[6], + .max_uses = info[7], + + .invite_type = info[8], + }; } diff --git a/src/main/api/users.zig b/src/main/api/users.zig index ff0cf15..0a03efa 100644 --- a/src/main/api/users.zig +++ b/src/main/api/users.zig @@ -3,6 +3,7 @@ const util = @import("util"); const auth = @import("./auth.zig"); const Uuid = util.Uuid; +const DateTime = util.DateTime; const getRandom = @import("../api.zig").getRandom; const UserAuthInfo = struct { @@ -16,14 +17,14 @@ pub const CreateError = error{ DbError, }; -const User = struct { +const DbUser = struct { id: Uuid, username: []const u8, community_id: ?Uuid, }; -const LocalUser = struct { +const DbLocalUser = struct { user_id: Uuid, invite_id: ?Uuid, @@ -72,7 +73,7 @@ pub fn create( password: []const u8, community_id: ?Uuid, options: CreateOptions, - alloc: std.mem.Allocator, + password_alloc: std.mem.Allocator, ) CreateError!Uuid { const id = Uuid.randV4(getRandom()); if ((try lookupByUsername(db, username, community_id)) != null) { @@ -84,7 +85,7 @@ pub fn create( .username = username, .community_id = community_id, }) catch return error.DbError; - try auth.passwords.create(db, id, password, alloc); + try auth.passwords.create(db, id, password, password_alloc); db.insert("local_user", .{ .user_id = id, .invite_id = options.invite_id, @@ -93,3 +94,35 @@ pub fn create( return id; } + +pub const User = struct { + id: Uuid, + + username: []const u8, + host: []const u8, + + community_id: Uuid, + + created_at: DateTime, +}; + +pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User { + const result = (try db.execRow( + &.{ []const u8, []const u8, Uuid, DateTime }, + \\SELECT user.username, community.host, community.id, user.created_at + \\FROM user JOIN community ON user.community_id = community.id + \\WHERE user.id = ? + \\LIMIT 1 + , + .{id}, + alloc, + )) orelse return error.NotFound; + + return User{ + .id = id, + .username = result[0], + .host = result[1], + .community_id = result[2], + .created_at = result[3], + }; +} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 316c84d..844918a 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -8,6 +8,7 @@ const Uuid = @import("util").Uuid; pub const auth = @import("./controllers/auth.zig"); pub const communities = @import("./controllers/communities.zig"); pub const invites = @import("./controllers/invites.zig"); +pub const users = @import("./controllers/users.zig"); pub const utils = struct { const json_options = if (builtin.mode == .Debug) .{ diff --git a/src/main/controllers/communities.zig b/src/main/controllers/communities.zig index 777aa56..952fd90 100644 --- a/src/main/controllers/communities.zig +++ b/src/main/controllers/communities.zig @@ -7,7 +7,7 @@ const RequestServer = root.RequestServer; const RouteArgs = http.RouteArgs; pub const create = struct { - pub const method = .GET; + pub const method = .POST; pub const path = "/communities"; pub fn handler(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { const opt = try utils.parseRequestBody(struct { origin: []const u8 }, ctx); diff --git a/src/main/controllers/users.zig b/src/main/controllers/users.zig index bba4229..1be934d 100644 --- a/src/main/controllers/users.zig +++ b/src/main/controllers/users.zig @@ -4,38 +4,27 @@ const builtin = @import("builtin"); const http = @import("http"); const Uuid = @import("util").Uuid; -const RegistrationInfo = @import("../api.zig").RegistrationInfo; +const RegistrationRequest = @import("../api.zig").RegistrationRequest; const utils = @import("../controllers.zig").utils; const RequestServer = root.RequestServer; const RouteArgs = http.RouteArgs; -pub fn register(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { - const info = try utils.parseRequestBody(RegistrationInfo, ctx); - defer utils.freeRequestBody(info, ctx.alloc); +pub const create = struct { + pub const method = .POST; + pub const path = "/users"; + pub fn handler(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { + const info = try utils.parseRequestBody(RegistrationRequest, ctx); + defer utils.freeRequestBody(info, ctx.alloc); - var api = try utils.getApiConn(srv, ctx); - defer api.close(); + var api = try utils.getApiConn(srv, ctx); + defer api.close(); - const user = api.register(info) catch |err| switch (err) { - error.UsernameUnavailable => return utils.respondError(ctx, .bad_request, "Username Unavailable"), - else => return err, - }; + const user = api.register(info) catch |err| switch (err) { + error.UsernameTaken => return utils.respondError(ctx, .bad_request, "Username Unavailable"), + else => return err, + }; - try utils.respondJson(ctx, .created, user); -} - -pub fn login(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { - const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx); - defer utils.freeRequestBody(credentials, ctx.alloc); - - var api = try utils.getApiConn(srv, ctx); - defer api.close(); - - const token = api.login(credentials.username, credentials.password) catch |err| switch (err) { - error.PasswordVerificationFailed => return utils.respondError(ctx, .bad_request, "Invalid Login"), - else => return err, - }; - - try utils.respondJson(ctx, .ok, token); -} + try utils.respondJson(ctx, .created, user); + } +}; diff --git a/src/main/db.zig b/src/main/db.zig index e36d4bf..fa7aa46 100644 --- a/src/main/db.zig +++ b/src/main/db.zig @@ -86,7 +86,7 @@ fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) DateTime => row.getDateTime(idx), else => switch (@typeInfo(T)) { - .Optional => if (row.isNull(idx)) + .Optional => if (try row.isNull(idx)) null else try getAlloc(row, std.meta.Child(T), idx, alloc), @@ -98,6 +98,8 @@ fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) .Enum => try getEnum(row, T, idx), + .Int => @intCast(T, try row.getI64(idx)), + //else => unreachable, else => @compileError("unknown type " ++ @typeName(T)), }, diff --git a/src/main/db/migrations.zig b/src/main/db/migrations.zig index e7e40a6..8b993c5 100644 --- a/src/main/db/migrations.zig +++ b/src/main/db/migrations.zig @@ -153,13 +153,15 @@ const migrations: []const Migration = &.{ \\ id TEXT NOT NULL PRIMARY KEY, \\ \\ name TEXT NOT NULL, - \\ invite_code TEXT NOT NULL UNIQUE, + \\ code TEXT NOT NULL UNIQUE, \\ created_by TEXT NOT NULL REFERENCES local_user(id), \\ \\ max_uses INTEGER, \\ \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - \\ expires_at DATETIME + \\ expires_at DATETIME, + \\ + \\ type TEXT NOT NULL CHECK (type in ('system', 'community_owner', 'user')) \\); \\ALTER TABLE local_user ADD COLUMN invite_id TEXT REFERENCES invite(id); , @@ -174,6 +176,7 @@ const migrations: []const Migration = &.{ \\CREATE TABLE community( \\ id TEXT NOT NULL PRIMARY KEY, \\ + \\ owner_id TEXT REFERENCES user(id), \\ name TEXT NOT NULL, \\ host TEXT NOT NULL UNIQUE, \\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')), diff --git a/src/main/main.zig b/src/main/main.zig index 00bfaca..c8ef53f 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -26,6 +26,8 @@ const router = Router{ prepare(c.invites.create), + prepare(c.users.create), + //Route.new(.POST, "/notes", &c.notes.create), //Route.new(.GET, "/notes/:id", &c.notes.get), diff --git a/src/util/DateTime.zig b/src/util/DateTime.zig index 8a52173..072191a 100644 --- a/src/util/DateTime.zig +++ b/src/util/DateTime.zig @@ -8,6 +8,10 @@ pub fn now() DateTime { return .{ .seconds_since_epoch = std.time.timestamp() }; } +pub fn isAfter(lhs: DateTime, rhs: DateTime) bool { + return lhs.seconds_since_epoch > rhs.seconds_since_epoch; +} + pub fn epochSeconds(value: DateTime) std.time.epoch.EpochSeconds { return .{ .secs = @intCast(u64, value.seconds_since_epoch) }; } diff --git a/src/util/Uuid.zig b/src/util/Uuid.zig index 72fda04..1aa95a0 100644 --- a/src/util/Uuid.zig +++ b/src/util/Uuid.zig @@ -11,7 +11,7 @@ pub fn eql(lhs: ?Uuid, rhs: ?Uuid) bool { if (lhs == null and rhs == null) return true; if (lhs == null or rhs == null) return false; - return lhs.data == rhs.data; + return lhs.?.data == rhs.?.data; } pub fn toCharArray(value: Uuid) [string_len]u8 { diff --git a/src/util/lib.zig b/src/util/lib.zig index ac57039..3ab69c8 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -5,6 +5,12 @@ pub const Uuid = @import("./Uuid.zig"); pub const DateTime = @import("./DateTime.zig"); pub const PathIter = @import("./PathIter.zig"); +pub fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 { + var new = try alloc.alloc(u8, str.len); + std.mem.copy(u8, new, str); + return new; +} + pub const case = struct { // returns the number of capital letters in a string. // only works with ascii characters