diff --git a/src/api/lib.zig b/src/api/lib.zig
index a3f4cd6..00a16eb 100644
--- a/src/api/lib.zig
+++ b/src/api/lib.zig
@@ -204,6 +204,13 @@ pub const FileResult = struct {
data: []const u8,
};
+pub const ValidInvite = struct {
+ code: []const u8,
+ kind: services.invites.Kind,
+ name: []const u8,
+ creator: UserResponse,
+};
+
pub fn isAdminSetup(db: sql.Db) !bool {
_ = services.communities.adminCommunityId(db) catch |err| switch (err) {
error.NotFound => return false,
@@ -396,6 +403,12 @@ fn ApiConn(comptime DbConn: type) type {
return try services.invites.get(self.db, invite_id, self.allocator);
}
+ fn isInviteValid(invite: services.invites.Invite) bool {
+ if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return false;
+ if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return false;
+ return true;
+ }
+
pub fn register(self: *Self, username: []const u8, password: []const u8, opt: RegistrationOptions) !UserResponse {
const tx = try self.db.beginOrSavepoint();
const maybe_invite = if (opt.invite_code) |code|
@@ -406,8 +419,7 @@ fn ApiConn(comptime DbConn: type) type {
if (maybe_invite) |invite| {
if (!Uuid.eql(invite.community_id, self.community.id)) return error.WrongCommunity;
- if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired;
- if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired;
+ if (!isInviteValid(invite)) return error.InvalidInvite;
}
const invite_kind = if (maybe_invite) |inv| inv.kind else .user;
@@ -434,19 +446,18 @@ fn ApiConn(comptime DbConn: type) type {
},
}
- return self.getUser(user_id) catch |err| switch (err) {
- error.NotFound => error.Unexpected,
- else => err,
+ const user = self.getUserUnchecked(tx, user_id) catch |err| switch (err) {
+ error.NotFound => return error.Unexpected,
+ else => |e| return e,
};
- }
-
- pub fn getUser(self: *Self, user_id: Uuid) !UserResponse {
- const user = try services.actors.get(self.db, user_id, self.allocator);
errdefer util.deepFree(self.allocator, user);
- if (self.user_id == null) {
- if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound;
- }
+ try tx.commit();
+ return user;
+ }
+
+ fn getUserUnchecked(self: *Self, db: anytype, user_id: Uuid) !UserResponse {
+ const user = try services.actors.get(db, user_id, self.allocator);
return UserResponse{
.id = user.id,
@@ -469,6 +480,17 @@ fn ApiConn(comptime DbConn: type) type {
};
}
+ pub fn getUser(self: *Self, user_id: Uuid) !UserResponse {
+ const user = try self.getUserUnchecked(self.db, user_id);
+ errdefer util.deepFree(self.allocator, user);
+
+ if (self.user_id == null) {
+ if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound;
+ }
+
+ return user;
+ }
+
pub fn createNote(self: *Self, content: []const u8) !NoteResponse {
// You cannot post on admin accounts
if (self.community.kind == .admin) return error.WrongCommunity;
@@ -747,5 +769,33 @@ fn ApiConn(comptime DbConn: type) type {
if (!Uuid.eql(id, self.user_id orelse return error.NoToken)) return error.AccessDenied;
try services.actors.updateProfile(self.db, id, data, self.allocator);
}
+
+ pub fn validateInvite(self: *Self, code: []const u8) !ValidInvite {
+ const invite = services.invites.getByCode(
+ self.db,
+ code,
+ self.community.id,
+ self.allocator,
+ ) catch |err| switch (err) {
+ error.NotFound => return error.InvalidInvite,
+ else => return error.DatabaseFailure,
+ };
+ errdefer util.deepFree(self.allocator, invite);
+
+ if (!Uuid.eql(invite.community_id, self.community.id)) return error.InvalidInvite;
+ if (!isInviteValid(invite)) return error.InvalidInvite;
+
+ const creator = self.getUserUnchecked(self.db, invite.created_by) catch |err| switch (err) {
+ error.NotFound => return error.Unexpected,
+ else => return error.DatabaseFailure,
+ };
+
+ return ValidInvite{
+ .code = invite.code,
+ .name = invite.name,
+ .kind = invite.kind,
+ .creator = creator,
+ };
+ }
};
}
diff --git a/src/api/services/actors.zig b/src/api/services/actors.zig
index c51c23e..0244336 100644
--- a/src/api/services/actors.zig
+++ b/src/api/services/actors.zig
@@ -67,12 +67,12 @@ pub const UsernameValidationError = error{
/// - Be at least 1 character
/// - Be no more than 32 characters
/// - All characters are in [A-Za-z0-9_]
-pub fn validateUsername(username: []const u8) UsernameValidationError!void {
+pub fn validateUsername(username: []const u8, lax: bool) UsernameValidationError!void {
if (username.len == 0) return error.UsernameEmpty;
if (username.len > max_username_chars) return error.UsernameTooLong;
for (username) |ch| {
- const valid = std.ascii.isAlNum(ch) or ch == '_';
+ const valid = std.ascii.isAlNum(ch) or ch == '_' or (lax and ch == '.');
if (!valid) return error.UsernameContainsInvalidChar;
}
}
@@ -81,11 +81,12 @@ pub fn create(
db: anytype,
username: []const u8,
community_id: Uuid,
+ lax_username: bool,
alloc: std.mem.Allocator,
) CreateError!Uuid {
const id = Uuid.randV4(util.getThreadPrng());
- try validateUsername(username);
+ try validateUsername(username, lax_username);
db.insert("actor", .{
.id = id,
@@ -153,8 +154,11 @@ pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Actor {
.{id},
alloc,
) catch |err| switch (err) {
- error.NoRows => error.NotFound,
- else => error.DatabaseFailure,
+ error.NoRows => return error.NotFound,
+ else => |e| {
+ std.log.err("{}, {?}", .{ e, @errorReturnTrace() });
+ return error.DatabaseFailure;
+ },
};
}
diff --git a/src/api/services/auth.zig b/src/api/services/auth.zig
index 426c734..88e0f20 100644
--- a/src/api/services/auth.zig
+++ b/src/api/services/auth.zig
@@ -36,14 +36,14 @@ pub fn register(
if (password.len < min_password_chars) return error.PasswordTooShort;
// perform pre-validation to avoid having to hash the password if it fails
- try actors.validateUsername(username);
+ try actors.validateUsername(username, false);
const hash = try hashPassword(password, alloc);
defer alloc.free(hash);
const tx = db.beginOrSavepoint() catch return error.DatabaseFailure;
errdefer tx.rollback();
- const id = try actors.create(tx, username, community_id, alloc);
+ const id = try actors.create(tx, username, community_id, false, alloc);
tx.insert("account", .{
.id = id,
.invite_id = options.invite_id,
diff --git a/src/api/services/communities.zig b/src/api/services/communities.zig
index 824957e..81ae8b4 100644
--- a/src/api/services/communities.zig
+++ b/src/api/services/communities.zig
@@ -97,7 +97,7 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st
}, alloc);
if (options.kind == .local) {
- const actor_id = actors.create(tx, "community.actor", id, alloc) catch |err| switch (err) {
+ const actor_id = actors.create(tx, "community.actor", id, true, alloc) catch |err| switch (err) {
error.UsernameContainsInvalidChar,
error.UsernameTooLong,
error.UsernameEmpty,
@@ -109,7 +109,6 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st
\\UPDATE community
\\SET community_actor_id = $1
\\WHERE id = $2
- \\LIMIT 1
, .{ actor_id, id }, alloc);
}
diff --git a/src/http/middleware.zig b/src/http/middleware.zig
index ddb88c3..825f3c7 100644
--- a/src/http/middleware.zig
+++ b/src/http/middleware.zig
@@ -201,7 +201,7 @@ pub fn CatchErrors(comptime ErrorHandler: type) type {
return self.error_handler.handle(
req,
res,
- addField(ctx, "err", err),
+ addField(addField(ctx, "err", err), "err_trace", @errorReturnTrace()),
next,
);
};
@@ -218,7 +218,10 @@ pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) {
pub const default_error_handler = struct {
fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, _: anytype) !void {
const should_log = !@import("builtin").is_test;
- if (should_log) std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri });
+ if (should_log) {
+ std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri });
+ std.log.debug("Additional details: {?}", .{ctx.err_trace});
+ }
// Tell the server to close the connection after this request
res.should_close = true;
@@ -335,12 +338,12 @@ pub fn Router(comptime Routes: type) type {
_ = next;
inline for (self.routes) |r| {
- if (r.handle(req, res, ctx, {})) |_|
+ if (r.handle(req, res, ctx, {}))
// success
return
else |err| switch (err) {
error.RouteMismatch => {},
- else => return err,
+ else => |e| return e,
}
}
@@ -406,10 +409,10 @@ pub const Route = struct {
}
pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
- return if (self.applies(req, ctx))
- next.handle(req, res, ctx, {})
+ if (self.applies(req, ctx))
+ return next.handle(req, res, ctx, {})
else
- error.RouteMismatch;
+ return error.RouteMismatch;
}
};
diff --git a/src/main/controllers/web.zig b/src/main/controllers/web.zig
index 0a05d33..7910d71 100644
--- a/src/main/controllers/web.zig
+++ b/src/main/controllers/web.zig
@@ -1,5 +1,6 @@
const std = @import("std");
const util = @import("util");
+const http = @import("http");
const controllers = @import("../controllers.zig");
pub const routes = .{
@@ -10,6 +11,9 @@ pub const routes = .{
controllers.apiEndpoint(cluster.overview),
controllers.apiEndpoint(media),
controllers.apiEndpoint(static),
+ controllers.apiEndpoint(signup.page),
+ controllers.apiEndpoint(signup.with_invite),
+ controllers.apiEndpoint(signup.submit),
};
const static = struct {
@@ -94,6 +98,101 @@ const login = struct {
}
};
+const signup = struct {
+ const tmpl = @embedFile("./web/signup.tmpl.html");
+
+ fn servePage(
+ invite_code: ?[]const u8,
+ error_msg: ?[]const u8,
+ status: http.Status,
+ res: anytype,
+ srv: anytype,
+ ) !void {
+ const invite = if (invite_code) |code| srv.validateInvite(code) catch |err| switch (err) {
+ error.InvalidInvite => return servePage(null, "Invite is not valid", .bad_request, res, srv),
+ else => |e| return e,
+ } else null;
+ defer util.deepFree(srv.allocator, invite);
+
+ try res.template(status, srv, tmpl, .{
+ .error_msg = error_msg,
+ .invite = invite,
+ });
+ }
+
+ const page = struct {
+ pub const path = "/signup";
+ pub const method = .GET;
+
+ pub fn handler(_: anytype, res: anytype, srv: anytype) !void {
+ try servePage(null, null, .ok, res, srv);
+ }
+ };
+
+ const with_invite = struct {
+ pub const path = "/invite/:code";
+ pub const method = .GET;
+
+ pub const Args = struct {
+ code: []const u8,
+ };
+
+ pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
+ std.log.debug("{s}", .{req.args.code});
+ try servePage(req.args.code, null, .ok, res, srv);
+ }
+ };
+
+ const submit = struct {
+ pub const path = "/signup";
+ pub const method = .POST;
+
+ pub const Body = struct {
+ username: []const u8,
+ password: []const u8,
+ email: ?[]const u8 = null,
+ invite_code: ?[]const u8 = null,
+ };
+
+ pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
+ const user = srv.register(req.body.username, req.body.password, .{
+ .email = req.body.email,
+ .invite_code = req.body.invite_code,
+ }) catch |err| {
+ var status: http.Status = .bad_request;
+ const err_msg = switch (err) {
+ error.UsernameEmpty => "Username cannot be empty",
+ error.UsernameContainsInvalidChar => "Username must be composed of alphanumeric characters and underscore",
+ error.UsernameTooLong => "Username too long",
+ error.PasswordTooShort => "Password too short, must be at least 12 chars",
+
+ error.UsernameTaken => blk: {
+ status = .unprocessable_entity;
+ break :blk "Username is already registered";
+ },
+ else => blk: {
+ status = .internal_server_error;
+ break :blk "an internal error occurred";
+ },
+ };
+
+ return servePage(req.body.invite_code, err_msg, status, res, srv);
+ };
+ defer util.deepFree(srv.allocator, user);
+
+ const token = try srv.login(req.body.username, req.body.password);
+
+ try res.headers.put("Location", index.path);
+ var buf: [64]u8 = undefined;
+ const cookie_name = try std.fmt.bufPrint(&buf, "token.{s}", .{req.body.username});
+ try res.headers.setCookie(cookie_name, token.token, .{});
+ try res.headers.setCookie("active_account", req.body.username, .{ .HttpOnly = false });
+
+ try res.status(.see_other);
+ }
+ };
+};
+
const global_timeline = struct {
pub const path = "/timelines/global";
pub const method = .GET;
diff --git a/src/main/controllers/web/signup.tmpl.html b/src/main/controllers/web/signup.tmpl.html
new file mode 100644
index 0000000..f9a9254
--- /dev/null
+++ b/src/main/controllers/web/signup.tmpl.html
@@ -0,0 +1,54 @@
+{ %community.name }
+