diff --git a/src/api/lib.zig b/src/api/lib.zig index a3f4cd6..00a16eb 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -204,6 +204,13 @@ pub const FileResult = struct { data: []const u8, }; +pub const ValidInvite = struct { + code: []const u8, + kind: services.invites.Kind, + name: []const u8, + creator: UserResponse, +}; + pub fn isAdminSetup(db: sql.Db) !bool { _ = services.communities.adminCommunityId(db) catch |err| switch (err) { error.NotFound => return false, @@ -396,6 +403,12 @@ fn ApiConn(comptime DbConn: type) type { return try services.invites.get(self.db, invite_id, self.allocator); } + fn isInviteValid(invite: services.invites.Invite) bool { + if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return false; + if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return false; + return true; + } + pub fn register(self: *Self, username: []const u8, password: []const u8, opt: RegistrationOptions) !UserResponse { const tx = try self.db.beginOrSavepoint(); const maybe_invite = if (opt.invite_code) |code| @@ -406,8 +419,7 @@ fn ApiConn(comptime DbConn: type) type { if (maybe_invite) |invite| { if (!Uuid.eql(invite.community_id, self.community.id)) return error.WrongCommunity; - if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired; - if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired; + if (!isInviteValid(invite)) return error.InvalidInvite; } const invite_kind = if (maybe_invite) |inv| inv.kind else .user; @@ -434,19 +446,18 @@ fn ApiConn(comptime DbConn: type) type { }, } - return self.getUser(user_id) catch |err| switch (err) { - error.NotFound => error.Unexpected, - else => err, + const user = self.getUserUnchecked(tx, user_id) catch |err| switch (err) { + error.NotFound => return error.Unexpected, + else => |e| return e, }; - } - - pub fn getUser(self: *Self, user_id: Uuid) !UserResponse { - const user = try services.actors.get(self.db, user_id, self.allocator); errdefer util.deepFree(self.allocator, user); - if (self.user_id == null) { - if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound; - } + try tx.commit(); + return user; + } + + fn getUserUnchecked(self: *Self, db: anytype, user_id: Uuid) !UserResponse { + const user = try services.actors.get(db, user_id, self.allocator); return UserResponse{ .id = user.id, @@ -469,6 +480,17 @@ fn ApiConn(comptime DbConn: type) type { }; } + pub fn getUser(self: *Self, user_id: Uuid) !UserResponse { + const user = try self.getUserUnchecked(self.db, user_id); + errdefer util.deepFree(self.allocator, user); + + if (self.user_id == null) { + if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound; + } + + return user; + } + pub fn createNote(self: *Self, content: []const u8) !NoteResponse { // You cannot post on admin accounts if (self.community.kind == .admin) return error.WrongCommunity; @@ -747,5 +769,33 @@ fn ApiConn(comptime DbConn: type) type { if (!Uuid.eql(id, self.user_id orelse return error.NoToken)) return error.AccessDenied; try services.actors.updateProfile(self.db, id, data, self.allocator); } + + pub fn validateInvite(self: *Self, code: []const u8) !ValidInvite { + const invite = services.invites.getByCode( + self.db, + code, + self.community.id, + self.allocator, + ) catch |err| switch (err) { + error.NotFound => return error.InvalidInvite, + else => return error.DatabaseFailure, + }; + errdefer util.deepFree(self.allocator, invite); + + if (!Uuid.eql(invite.community_id, self.community.id)) return error.InvalidInvite; + if (!isInviteValid(invite)) return error.InvalidInvite; + + const creator = self.getUserUnchecked(self.db, invite.created_by) catch |err| switch (err) { + error.NotFound => return error.Unexpected, + else => return error.DatabaseFailure, + }; + + return ValidInvite{ + .code = invite.code, + .name = invite.name, + .kind = invite.kind, + .creator = creator, + }; + } }; } diff --git a/src/api/services/actors.zig b/src/api/services/actors.zig index c51c23e..0244336 100644 --- a/src/api/services/actors.zig +++ b/src/api/services/actors.zig @@ -67,12 +67,12 @@ pub const UsernameValidationError = error{ /// - Be at least 1 character /// - Be no more than 32 characters /// - All characters are in [A-Za-z0-9_] -pub fn validateUsername(username: []const u8) UsernameValidationError!void { +pub fn validateUsername(username: []const u8, lax: bool) UsernameValidationError!void { if (username.len == 0) return error.UsernameEmpty; if (username.len > max_username_chars) return error.UsernameTooLong; for (username) |ch| { - const valid = std.ascii.isAlNum(ch) or ch == '_'; + const valid = std.ascii.isAlNum(ch) or ch == '_' or (lax and ch == '.'); if (!valid) return error.UsernameContainsInvalidChar; } } @@ -81,11 +81,12 @@ pub fn create( db: anytype, username: []const u8, community_id: Uuid, + lax_username: bool, alloc: std.mem.Allocator, ) CreateError!Uuid { const id = Uuid.randV4(util.getThreadPrng()); - try validateUsername(username); + try validateUsername(username, lax_username); db.insert("actor", .{ .id = id, @@ -153,8 +154,11 @@ pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Actor { .{id}, alloc, ) catch |err| switch (err) { - error.NoRows => error.NotFound, - else => error.DatabaseFailure, + error.NoRows => return error.NotFound, + else => |e| { + std.log.err("{}, {?}", .{ e, @errorReturnTrace() }); + return error.DatabaseFailure; + }, }; } diff --git a/src/api/services/auth.zig b/src/api/services/auth.zig index 426c734..88e0f20 100644 --- a/src/api/services/auth.zig +++ b/src/api/services/auth.zig @@ -36,14 +36,14 @@ pub fn register( if (password.len < min_password_chars) return error.PasswordTooShort; // perform pre-validation to avoid having to hash the password if it fails - try actors.validateUsername(username); + try actors.validateUsername(username, false); const hash = try hashPassword(password, alloc); defer alloc.free(hash); const tx = db.beginOrSavepoint() catch return error.DatabaseFailure; errdefer tx.rollback(); - const id = try actors.create(tx, username, community_id, alloc); + const id = try actors.create(tx, username, community_id, false, alloc); tx.insert("account", .{ .id = id, .invite_id = options.invite_id, diff --git a/src/api/services/communities.zig b/src/api/services/communities.zig index 824957e..81ae8b4 100644 --- a/src/api/services/communities.zig +++ b/src/api/services/communities.zig @@ -97,7 +97,7 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st }, alloc); if (options.kind == .local) { - const actor_id = actors.create(tx, "community.actor", id, alloc) catch |err| switch (err) { + const actor_id = actors.create(tx, "community.actor", id, true, alloc) catch |err| switch (err) { error.UsernameContainsInvalidChar, error.UsernameTooLong, error.UsernameEmpty, @@ -109,7 +109,6 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st \\UPDATE community \\SET community_actor_id = $1 \\WHERE id = $2 - \\LIMIT 1 , .{ actor_id, id }, alloc); } diff --git a/src/http/middleware.zig b/src/http/middleware.zig index ddb88c3..825f3c7 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -201,7 +201,7 @@ pub fn CatchErrors(comptime ErrorHandler: type) type { return self.error_handler.handle( req, res, - addField(ctx, "err", err), + addField(addField(ctx, "err", err), "err_trace", @errorReturnTrace()), next, ); }; @@ -218,7 +218,10 @@ pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) { pub const default_error_handler = struct { fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, _: anytype) !void { const should_log = !@import("builtin").is_test; - if (should_log) std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri }); + if (should_log) { + std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri }); + std.log.debug("Additional details: {?}", .{ctx.err_trace}); + } // Tell the server to close the connection after this request res.should_close = true; @@ -335,12 +338,12 @@ pub fn Router(comptime Routes: type) type { _ = next; inline for (self.routes) |r| { - if (r.handle(req, res, ctx, {})) |_| + if (r.handle(req, res, ctx, {})) // success return else |err| switch (err) { error.RouteMismatch => {}, - else => return err, + else => |e| return e, } } @@ -406,10 +409,10 @@ pub const Route = struct { } pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - return if (self.applies(req, ctx)) - next.handle(req, res, ctx, {}) + if (self.applies(req, ctx)) + return next.handle(req, res, ctx, {}) else - error.RouteMismatch; + return error.RouteMismatch; } }; diff --git a/src/main/controllers/web.zig b/src/main/controllers/web.zig index 0a05d33..7910d71 100644 --- a/src/main/controllers/web.zig +++ b/src/main/controllers/web.zig @@ -1,5 +1,6 @@ const std = @import("std"); const util = @import("util"); +const http = @import("http"); const controllers = @import("../controllers.zig"); pub const routes = .{ @@ -10,6 +11,9 @@ pub const routes = .{ controllers.apiEndpoint(cluster.overview), controllers.apiEndpoint(media), controllers.apiEndpoint(static), + controllers.apiEndpoint(signup.page), + controllers.apiEndpoint(signup.with_invite), + controllers.apiEndpoint(signup.submit), }; const static = struct { @@ -94,6 +98,101 @@ const login = struct { } }; +const signup = struct { + const tmpl = @embedFile("./web/signup.tmpl.html"); + + fn servePage( + invite_code: ?[]const u8, + error_msg: ?[]const u8, + status: http.Status, + res: anytype, + srv: anytype, + ) !void { + const invite = if (invite_code) |code| srv.validateInvite(code) catch |err| switch (err) { + error.InvalidInvite => return servePage(null, "Invite is not valid", .bad_request, res, srv), + else => |e| return e, + } else null; + defer util.deepFree(srv.allocator, invite); + + try res.template(status, srv, tmpl, .{ + .error_msg = error_msg, + .invite = invite, + }); + } + + const page = struct { + pub const path = "/signup"; + pub const method = .GET; + + pub fn handler(_: anytype, res: anytype, srv: anytype) !void { + try servePage(null, null, .ok, res, srv); + } + }; + + const with_invite = struct { + pub const path = "/invite/:code"; + pub const method = .GET; + + pub const Args = struct { + code: []const u8, + }; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + std.log.debug("{s}", .{req.args.code}); + try servePage(req.args.code, null, .ok, res, srv); + } + }; + + const submit = struct { + pub const path = "/signup"; + pub const method = .POST; + + pub const Body = struct { + username: []const u8, + password: []const u8, + email: ?[]const u8 = null, + invite_code: ?[]const u8 = null, + }; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const user = srv.register(req.body.username, req.body.password, .{ + .email = req.body.email, + .invite_code = req.body.invite_code, + }) catch |err| { + var status: http.Status = .bad_request; + const err_msg = switch (err) { + error.UsernameEmpty => "Username cannot be empty", + error.UsernameContainsInvalidChar => "Username must be composed of alphanumeric characters and underscore", + error.UsernameTooLong => "Username too long", + error.PasswordTooShort => "Password too short, must be at least 12 chars", + + error.UsernameTaken => blk: { + status = .unprocessable_entity; + break :blk "Username is already registered"; + }, + else => blk: { + status = .internal_server_error; + break :blk "an internal error occurred"; + }, + }; + + return servePage(req.body.invite_code, err_msg, status, res, srv); + }; + defer util.deepFree(srv.allocator, user); + + const token = try srv.login(req.body.username, req.body.password); + + try res.headers.put("Location", index.path); + var buf: [64]u8 = undefined; + const cookie_name = try std.fmt.bufPrint(&buf, "token.{s}", .{req.body.username}); + try res.headers.setCookie(cookie_name, token.token, .{}); + try res.headers.setCookie("active_account", req.body.username, .{ .HttpOnly = false }); + + try res.status(.see_other); + } + }; +}; + const global_timeline = struct { pub const path = "/timelines/global"; pub const method = .GET; diff --git a/src/main/controllers/web/signup.tmpl.html b/src/main/controllers/web/signup.tmpl.html new file mode 100644 index 0000000..f9a9254 --- /dev/null +++ b/src/main/controllers/web/signup.tmpl.html @@ -0,0 +1,54 @@ +
+

{ %community.name }

+
+
+

Sign Up

+ {#if .error_msg |$msg| =} +
Error: {$msg}
+ {= /if} + {#if .invite |$invite| =} +
+
You are about to accept an invite from:
+
+
+ {=#if $invite.creator.display_name |$name|=} + {$name} + {= #else =} + {$invite.creator.username} + {= /if =} +
+
@{$invite.creator.username}@{$invite.creator.host}
+
+ {#if @isTag($invite.kind, community_owner) =} +
This act will make your new account the owner of { %community.name }
+ {/if =} +
+ {=/if} + + + + {#if .invite |$invite| =} + + {/if =} + +
diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index a69de93..be94c7e 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -313,6 +313,11 @@ fn getColumn(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocato c.SQLITE_FLOAT => getColumnFloat(stmt, T, idx), c.SQLITE_TEXT => getColumnText(stmt, T, idx, alloc), c.SQLITE_NULL => { + if (T == DateTime) { + std.log.warn("SQLite: Treating NULL as DateTime epoch", .{}); + return std.mem.zeroes(DateTime); + } + if (@typeInfo(T) != .Optional) { std.log.err("SQLite column {}: Expected value of type {}, got (null)", .{ idx, T }); return error.ResultTypeMismatch; diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 0eaf206..776b159 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -304,8 +304,8 @@ const Row = union(Engine) { fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { if (T == void) return; return switch (self) { - .postgres => |pg| pg.get(T, idx, alloc), - .sqlite => |lite| lite.get(T, idx, alloc), + .postgres => |pg| try pg.get(T, idx, alloc), + .sqlite => |lite| try lite.get(T, idx, alloc), }; } }; @@ -420,7 +420,7 @@ fn Tx(comptime tx_level: u8) type { if (tx_level != 0) @compileError("close must be called on root db"); if (self.conn.current_tx_level != 0) { std.log.warn("Database released while transaction in progress!", .{}); - self.rollbackUnchecked() catch {}; + self.rollbackUnchecked() catch {}; // TODO: Burn database connection } if (!self.conn.in_use.swap(false, .AcqRel)) @panic("Double close on db conection"); @@ -719,6 +719,7 @@ fn Tx(comptime tx_level: u8) type { fn rollbackUnchecked(self: Self) !void { try self.execInternal("ROLLBACK", {}, .{}, false); + self.conn.current_tx_level = 0; } }; } diff --git a/src/template/lib.zig b/src/template/lib.zig index 8fbcbe1..33b55d3 100644 --- a/src/template/lib.zig +++ b/src/template/lib.zig @@ -1,6 +1,7 @@ const std = @import("std"); pub fn main() !void { + const Enum = enum { foo, bar, baz }; try execute( std.io.getStdOut().writer(), .{ .test_tmpl = "{.x} {%context_foo}" }, @@ -15,8 +16,15 @@ pub fn main() !void { .bar = .{ .x = "x" }, .qux = false, .quxx = true, + .quxx2 = true, .maybe_foo = @as(?[]const u8, "foo"), .maybe_bar = @as(?[]const u8, null), + .snap = Enum.bar, + .crackle = union(Enum) { + foo: []const u8, + bar: []const u8, + baz: []const u8, + }{ .foo = "abcd" }, .x = "y", }, .{ @@ -119,6 +127,49 @@ fn executeStatement( }; } }, + .@"switch" => |switch_stmt| { + const expr = evaluateExpression(switch_stmt.expression, args, captures, context); + + const exhaustive = switch_stmt.cases.len == std.meta.fields(@TypeOf(expr)).len; + + if (exhaustive and switch_stmt.else_branch != null) @compileError("Unused else branch in switch"); + if (!exhaustive and switch_stmt.else_branch == null) @compileError("Not all switch cases covered"); + + var found = false; + inline for (switch_stmt.cases) |case| { + if (std.meta.isTag(expr, case.header.tag)) { + found = true; + if (case.header.capture) |capture| { + try executeTemplate( + writer, + templates, + case.subtemplate, + args, + addCapture(captures, capture, @field(expr, case.header.tag)), + context, + ); + } else { + try executeTemplate( + writer, + templates, + case.subtemplate, + args, + captures, + context, + ); + } + } + } else if (!found) if (switch_stmt.else_branch) |subtemplate| { + try executeTemplate( + writer, + templates, + subtemplate, + args, + captures, + context, + ); + }; + }, .call_template => |call| { const new_template = @field(templates, call.template_name); try execute( @@ -168,6 +219,10 @@ fn EvaluateExpression( .arg_deref => |names| Deref(Args, names), .capture_deref => |names| Deref(Captures, names), .context_deref => |names| Deref(Context, names), + .equals => bool, + .builtin => |call| switch (call.*) { + .isTag => bool, + }, }; } @@ -181,6 +236,22 @@ fn evaluateExpression( .arg_deref => |names| deref(args, names), .capture_deref => |names| deref(captures, names), .context_deref => |names| deref(context, names), + .equals => |eql| { + const lhs = evaluateExpression(eql.lhs, args, captures, context); + const rhs = evaluateExpression(eql.rhs, args, captures, context); + const T = @TypeOf(lhs, rhs); + if (comptime std.meta.trait.isZigString(T)) { + return std.mem.eql(u8, lhs, rhs); + } else if (comptime std.meta.trait.isContainer(T) and @hasDecl(T, "eql")) { + return T.eql(lhs, rhs); + } else return lhs == rhs; + }, + .builtin => |call| switch (call.*) { + .isTag => |hdr| { + const val = evaluateExpression(hdr.expression, args, captures, context); + return std.meta.isTag(val, hdr.tag); + }, + }, }; } @@ -214,6 +285,7 @@ const TemplateType = enum { for_block, if_block, if_else_block, + switch_block, }; const TemplateParseResult = struct { @@ -270,6 +342,7 @@ fn parseTemplate( comptime template_type: TemplateType, ) TemplateParseResult { comptime { + @setEvalBranchQuota(tokens.len * 100); var i: usize = start; var current_text: []const u8 = ""; var items: []const TemplateItem = &.{}; @@ -313,6 +386,41 @@ fn parseTemplate( }}; i = subtemplate.new_idx; }, + .switch_header => |header| { + var cases: []const Case = &.{}; + var else_branch: ?[]const TemplateItem = null; + var last_header: CaseHeader = header.first_case; + var is_else = false; + while (true) { + const case = parseTemplate(tokens, i + 1, .switch_block); + i = case.new_idx; + + if (!is_else) { + cases = cases ++ [_]Case{.{ + .header = last_header, + .subtemplate = case.items, + }}; + } else { + else_branch = case.items; + } + switch (case.closing_block.?.block) { + .end_switch => break, + .@"else" => is_else = true, + .case_header => |case_header| last_header = case_header, + else => @compileError("Unexpected token"), + } + } + + items = items ++ [_]TemplateItem{.{ + .statement = .{ + .@"switch" = .{ + .expression = header.expression, + .cases = cases, + .else_branch = else_branch, + }, + }, + }}; + }, .end_for => if (template_type == .for_block) break cb else @@ -325,13 +433,17 @@ fn parseTemplate( break cb else @compileError("Unexpected #elif tag"), - .@"else" => if (template_type == .if_block) + .@"else" => if (template_type == .if_block or template_type == .switch_block) break cb else @compileError("Unexpected #else tag"), .call_template => |call| items = items ++ [_]TemplateItem{.{ .statement = .{ .call_template = call }, }}, + .end_switch, .case_header => if (template_type == .switch_block) + break cb + else + @compileError("Unexpected /switch tag"), } }, } @@ -390,6 +502,8 @@ fn parseTemplateTokens(comptime tokens: ControlTokenIter) []const TemplateToken .at => items = items ++ [_]TemplateToken{.{ .text = "@" }}, .comma => items = items ++ [_]TemplateToken{.{ .text = "," }}, .percent => items = items ++ [_]TemplateToken{.{ .text = "%" }}, + .open_paren => items = items ++ [_]TemplateToken{.{ .text = "(" }}, + .close_paren => items = items ++ [_]TemplateToken{.{ .text = ")" }}, }; return items; @@ -400,29 +514,108 @@ fn parseExpression(comptime tokens: ControlTokenIter) ParseResult(ControlTokenIt comptime { var iter = tokens; - var expr: Expression = while (iter.next()) |token| switch (token) { + var last_valid_iter: ?ControlTokenIter = null; + var expr: ?Expression = null; + while (iter.next()) |token| switch (token) { .whitespace => {}, .period => { const names = parseDeref(iter); iter = names.new_iter; - break .{ .arg_deref = names.item }; + if (expr != null) break; + expr = .{ .arg_deref = names.item }; + last_valid_iter = iter; }, .dollar => { const names = parseDeref(iter); iter = names.new_iter; - break .{ .capture_deref = names.item }; + if (expr != null) break; + expr = .{ .capture_deref = names.item }; + last_valid_iter = iter; }, .percent => { const names = parseDeref(iter); iter = names.new_iter; - break .{ .context_deref = names.item }; + if (expr != null) break; + expr = .{ .context_deref = names.item }; + last_valid_iter = iter; }, - else => @compileError("TODO"), + .equals => { + const next = iter.next() orelse break; + if (next == .equals) { + const lhs = expr orelse break; + const rhs = parseExpression(iter); + iter = rhs.new_iter; + + expr = .{ + .equals = &.{ + .lhs = lhs, + .rhs = rhs.item, + }, + }; + last_valid_iter = iter; + } else break; + }, + .at => { + if (expr != null) break; + const builtin = parseBuiltin(iter); + iter = builtin.new_iter; + expr = .{ .builtin = &builtin.item }; + last_valid_iter = iter; + }, + else => break, }; + return .{ + .new_iter = last_valid_iter orelse @compileError("Invalid Expression"), + .item = expr orelse @compileError("Invalid Expression"), + }; + } +} + +fn expectToken(comptime token: ?ControlToken, comptime exp: std.meta.Tag(ControlToken)) void { + comptime { + if (token == null) @compileError("Unexpected End Of Template"); + const token_tag = std.meta.activeTag(token.?); + + if (token_tag != exp) + @compileError("Expected " ++ @tagName(exp) ++ ", got " ++ @tagName(token_tag)); + } +} + +fn parseBuiltin(comptime tokens: ControlTokenIter) ParseResult(ControlTokenIter, BuiltinCall) { + comptime { + var iter = tokens; + const builtin = blk: { + const next = iter.next() orelse @compileError("Invalid Builtin"); + if (next != .text) @compileError("Invalid Builtin"); + break :blk std.meta.stringToEnum(Builtin, next.text) orelse @compileError("Invalid Builtin"); + }; + + iter = skipWhitespace(iter); + expectToken(iter.next(), .open_paren); + iter = skipWhitespace(iter); + const call = switch (builtin) { + .isTag => blk: { + const expr = parseExpression(iter); + iter = expr.new_iter; + expectToken(iter.next(), .comma); + iter = skipWhitespace(iter); + const tag = iter.next(); + expectToken(tag, .text); + break :blk .{ + .isTag = .{ + .tag = tag.?.text, + .expression = expr.item, + }, + }; + }, + }; + iter = skipWhitespace(iter); + expectToken(iter.next(), .close_paren); + return .{ .new_iter = iter, - .item = expr, + .item = call, }; } } @@ -471,6 +664,16 @@ fn parseControlBlock(comptime tokens: ControlTokenIter) ParseResult(ControlToken iter = result.new_iter; break .{ .call_template = result.item }; }, + .@"switch" => { + const result = parseSwitchHeader(iter); + iter = result.new_iter; + break .{ .switch_header = result.item }; + }, + .case => { + const result = parseCaseHeader(iter); + iter = result.new_iter; + break .{ .case_header = result.item }; + }, //else => @compileError("TODO"), } @@ -484,6 +687,7 @@ fn parseControlBlock(comptime tokens: ControlTokenIter) ParseResult(ControlToken switch (keyword) { .@"for" => break .{ .end_for = {} }, .@"if" => break .{ .end_if = {} }, + .@"switch" => break .{ .end_switch = {} }, } }, .period, .dollar, .percent => { @@ -518,7 +722,7 @@ fn parseControlBlock(comptime tokens: ControlTokenIter) ParseResult(ControlToken }, else => { @compileLog(iter.row); - @compileError("TODO" ++ @tagName(token)); + @compileError("TODO " ++ @tagName(token) ++ " " ++ token.text); }, }; @@ -642,6 +846,57 @@ fn parseIfHeader(comptime tokens: ControlTokenIter) ParseResult(ControlTokenIter } } +fn parseCaseHeader(comptime tokens: ControlTokenIter) ParseResult(ControlTokenIter, CaseHeader) { + comptime { + var iter = skipWhitespace(tokens); + const tag = iter.next(); + expectToken(tag, .text); + + const captures = tryParseCapture(iter); + if (captures) |cap| { + if (cap.item.len == 1) { + return .{ + .new_iter = cap.new_iter, + .item = CaseHeader{ + .tag = tag.?.text, + .capture = cap.item[0], + }, + }; + } else @compileError("Only one capture allowed for case statements"); + } + + return .{ + .new_iter = iter, + .item = .{ + .tag = tag.?.text, + .capture = null, + }, + }; + } +} + +fn parseSwitchHeader(comptime tokens: ControlTokenIter) ParseResult(ControlTokenIter, SwitchHeader) { + comptime { + const condition = parseExpression(tokens); + var iter = skipWhitespace(condition.new_iter); + + const next = iter.next(); + expectToken(next, .text); + if (!std.mem.eql(u8, next.?.text, "case")) @compileError("Expected case following switch condition"); + + iter = skipWhitespace(iter); + const first = parseCaseHeader(iter); + + return .{ + .new_iter = first.new_iter, + .item = .{ + .expression = condition.item, + .first_case = first.item, + }, + }; + } +} + fn parseDeref(comptime tokens: ControlTokenIter) ParseResult(ControlTokenIter, []const []const u8) { comptime { var iter = tokens; @@ -651,7 +906,7 @@ fn parseDeref(comptime tokens: ControlTokenIter) ParseResult(ControlTokenIter, [ switch (token) { .whitespace => {}, .text => |text| { - if (wants != .text) @compileError("Unexpected token \"" ++ text ++ "\""); + if (wants == .period) break; fields = fields ++ [1][]const u8{text}; wants = .period; }, @@ -659,13 +914,15 @@ fn parseDeref(comptime tokens: ControlTokenIter) ParseResult(ControlTokenIter, [ if (wants != .period) @compileError("Unexpected token \".\""); wants = .text; }, - else => if (wants == .period or fields.len == 0) return .{ - .new_iter = iter, - .item = fields, - } else @compileError("Unexpected token"), + else => if (wants == .period or fields.len == 0) break else @compileError("Unexpected token"), } _ = iter.next(); } + + return .{ + .new_iter = iter, + .item = fields, + }; } } @@ -702,10 +959,17 @@ const TemplateItem = union(enum) { statement: Statement, }; +const EqualsExpr = struct { + lhs: Expression, + rhs: Expression, +}; + const Expression = union(enum) { arg_deref: []const []const u8, capture_deref: []const []const u8, context_deref: []const []const u8, + equals: *const EqualsExpr, + builtin: *const BuiltinCall, }; const For = struct { @@ -727,6 +991,27 @@ const If = struct { }, }; +const Case = struct { + header: CaseHeader, + subtemplate: []const TemplateItem, +}; + +const SwitchHeader = struct { + expression: Expression, + first_case: CaseHeader, +}; + +const CaseHeader = struct { + tag: []const u8, + capture: ?[]const u8, +}; + +const Switch = struct { + expression: Expression, + cases: []const Case, + else_branch: ?[]const TemplateItem, +}; + const CallTemplate = struct { template_name: []const u8, args: Expression, @@ -741,6 +1026,7 @@ const Statement = union(enum) { expression: Expression, @"for": For, @"if": If, + @"switch": Switch, call_template: CallTemplate, }; @@ -754,6 +1040,9 @@ const ControlBlock = struct { @"else": void, elif_header: IfHeader, call_template: CallTemplate, + switch_header: SwitchHeader, + case_header: CaseHeader, + end_switch: void, }; block: Data, strip_before: bool, @@ -766,11 +1055,25 @@ const Keyword = enum { @"else", @"elif", @"template", + @"switch", + @"case", }; const EndKeyword = enum { @"for", @"if", + @"switch", +}; + +const Builtin = enum { + isTag, +}; + +const BuiltinCall = union(Builtin) { + isTag: struct { + tag: []const u8, + expression: Expression, + }, }; const ControlToken = union(enum) { @@ -787,6 +1090,8 @@ const ControlToken = union(enum) { at: void, comma: void, percent: void, + open_paren: void, + close_paren: void, }; const ControlTokenIter = struct { @@ -819,6 +1124,8 @@ const ControlTokenIter = struct { '@' => return .{ .at = {} }, ',' => return .{ .comma = {} }, '%' => return .{ .percent = {} }, + '(' => return .{ .open_paren = {} }, + ')' => return .{ .close_paren = {} }, ' ', '\t', '\n', '\r' => { var idx: usize = 0; while (idx < remaining.len and std.mem.indexOfScalar(u8, " \t\n\r", remaining[idx]) != null) : (idx += 1) {} @@ -830,7 +1137,7 @@ const ControlTokenIter = struct { }, else => { var idx: usize = 0; - while (idx < remaining.len and std.mem.indexOfScalar(u8, "{}.#|$/=@ \t\n\r", remaining[idx]) == null) : (idx += 1) {} + while (idx < remaining.len and std.mem.indexOfScalar(u8, "{}.#|$/=@,%() \t\n\r", remaining[idx]) == null) : (idx += 1) {} self.start += idx - 1; return .{ .text = remaining[0..idx] }; diff --git a/src/template/test.tmp.html b/src/template/test.tmp.html index 29930ed..de2620d 100644 --- a/src/template/test.tmp.html +++ b/src/template/test.tmp.html @@ -14,6 +14,10 @@ {$b}: {= /for =} {= /for} + {#if .quxx == .quxx2}eql{#else}neq{/if} + {#if .quxx == .qux}eql{#else}neq{/if} + {#if @isTag(.snap, foo)}foo{/if} + {#if @isTag(.snap, bar)}bar{/if} {#if .qux=} qux {=#elif .quxx=} @@ -22,6 +26,22 @@ neither {=/if} + {#switch .snap case foo =} + foo + {= #case bar =} + bar + {= #else =} + other + {= /switch} + + crackle: {#switch .crackle case foo |$foo|=} + foo:{$foo} + {= #case bar |$bar|=} + bar:{$bar} + {= #else =} + other + {= /switch} + {#if .maybe_foo |$v|}{$v}{#else}null{/if} {#if .maybe_bar |$v|}{$v}{#else}null{/if} {#if .maybe_foo |$_|}abcd{#else}null{/if} diff --git a/src/util/serialize.zig b/src/util/serialize.zig index b28eb93..bd2e2a4 100644 --- a/src/util/serialize.zig +++ b/src/util/serialize.zig @@ -349,10 +349,14 @@ pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime any_missing = true; } } - if (any_missing) { - return if (any_explicit) error.MissingField else null; - } + if (any_missing and any_explicit) return error.MissingField; + if (!any_explicit) { + inline for (info.fields) |field, i| { + if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name)); + } + return null; + } return result; }, diff --git a/static/site.css b/static/site.css index fea8141..2eff495 100644 --- a/static/site.css +++ b/static/site.css @@ -92,6 +92,10 @@ form[action*=login] .textinput span.suffix { outline: none; } +.form-helpinfo { + font-size: small; +} + button, a.button { padding: 5px; border-radius: 10px;