diff --git a/src/http/headers.zig b/src/http/headers.zig index 4f45029..1b91865 100644 --- a/src/http/headers.zig +++ b/src/http/headers.zig @@ -89,8 +89,8 @@ pub const Fields = struct { } }; - pub fn getList(self: Fields, key: []const u8) ?ListIterator { - return if (self.unmanaged.get(key)) |hdr| ListIterator{ .remaining = hdr } else null; + pub fn getList(self: Fields, key: []const u8) ListIterator { + return if (self.unmanaged.get(key)) |hdr| ListIterator{ .remaining = hdr } else ListIterator{ .remaining = "" }; } pub fn put(self: *Fields, key: []const u8, val: []const u8) !void { @@ -103,7 +103,8 @@ pub const Fields = struct { errdefer self.allocator.free(val_clone); if (try self.unmanaged.fetchPut(self.allocator, key_clone, val_clone)) |entry| { - self.allocator.free(entry.key); + self.allocator.free(key_clone); + //self.allocator.free(entry.key); self.allocator.free(entry.value); } } @@ -121,4 +122,50 @@ pub const Fields = struct { pub fn count(self: Fields) usize { return self.unmanaged.count(); } + + pub const CookieOptions = struct { + Secure: bool = true, + HttpOnly: bool = true, + SameSite: ?enum { + Strict, + Lax, + None, + } = null, + }; + + // TODO: Escape cookie values + pub fn setCookie(self: *Fields, name: []const u8, value: []const u8, opt: CookieOptions) !void { + const cookie = try std.fmt.allocPrint( + self.allocator, + "{s}={s}{s}{s}{s}{s}", + .{ + name, + value, + if (opt.Secure) "; Secure" else "", + if (opt.HttpOnly) "; HttpOnly" else "", + if (opt.SameSite) |_| "; SameSite=" else "", + if (opt.SameSite) |same_site| @tagName(same_site) else "", + }, + ); + defer self.allocator.free(cookie); + + // TODO: reduce unnecessary allocations + try self.append("Set-Cookie", cookie); + } + + // TODO: perform validation at request parse time? + pub fn getCookie(self: *Fields, name: []const u8) !?[]const u8 { + const hdr = self.get("Cookie") orelse return null; + var iter = std.mem.split(u8, hdr, ";"); + while (iter.next()) |cookie| { + const trimmed = std.mem.trimLeft(u8, cookie, " "); + const cookie_name = std.mem.sliceTo(trimmed, '='); + if (std.mem.eql(u8, name, cookie_name)) { + const rest = trimmed[cookie_name.len..]; + if (rest.len == 0) return error.InvalidCookie; + return rest[1..]; + } + } + return null; + } }; diff --git a/src/http/server.zig b/src/http/server.zig index 21a83be..b24ae75 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -38,7 +38,7 @@ pub const Stream = struct { os.closeSocket(self.socket); } - pub const ReadError = os.RecvFromError; + pub const ReadError = os.ReadError; pub const WriteError = os.SendError; pub const Reader = std.io.Reader(Stream, ReadError, read); @@ -48,7 +48,7 @@ pub const Stream = struct { if (std.io.is_async) @compileError("TODO: async not supported"); if (self.kind != .tcp) @panic("TODO"); - return os.recv(self.socket, buffer, 0); + return os.read(self.socket, buffer); } pub fn write(self: Stream, buffer: []const u8) WriteError!usize { diff --git a/src/http/server/response.zig b/src/http/server/response.zig index 296382d..0da7ab0 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -43,8 +43,15 @@ fn writeFields(writer: anytype, headers: *const Fields) !void { if (ch == '\r' or ch == '\n') @panic("newlines not yet supported in headers"); } + if (std.ascii.eqlIgnoreCase("Set-Cookie", header.key_ptr.*)) continue; + try writer.print("{s}: {s}\r\n", .{ header.key_ptr.*, header.value_ptr.* }); } + + var cookie_iter = headers.getList("Set-Cookie"); + while (cookie_iter.next()) |cookie| { + try writer.print("Set-Cookie: {s}\r\n", .{cookie}); + } } fn writeChunk(writer: anytype, contents: []const u8) @TypeOf(writer).Error!void { diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 3fcfd6a..eeba618 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -16,9 +16,7 @@ pub const notes = @import("./controllers/api/notes.zig"); pub const streaming = @import("./controllers/api/streaming.zig"); pub const timelines = @import("./controllers/api/timelines.zig"); -const web = struct { - const index = @import("./controllers/web/index.zig"); -}; +const web = @import("./controllers/web.zig"); pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? @@ -54,9 +52,7 @@ const routes = .{ follows.create, follows.query_followers, follows.query_following, - - web.index, -}; +} ++ web.routes; fn parseRouteArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { var args: Args = undefined; @@ -202,6 +198,16 @@ pub fn Context(comptime Route: type) type { if (token) |t| break :conn try api_source.connectToken(host, t, alloc); + if (req.headers.getCookie("active_account") catch return error.BadRequest) |account| { + if (account.len + ("token.").len <= 64) { + var buf: [64]u8 = undefined; + const cookie_name = std.fmt.bufPrint(&buf, "token.{s}", .{account}) catch unreachable; + if (try req.headers.getCookie(cookie_name)) |token_hdr| { + break :conn try api_source.connectToken(host, token_hdr, alloc); + } + } else return error.InvalidToken; + } + break :conn try api_source.connectUnauthorized(host, alloc); }; defer api_conn.close(); @@ -284,6 +290,18 @@ pub const Response = struct { self.opened = true; return self.res; } + + pub fn template(self: *Self, status_code: http.Status, comptime templ: []const u8, data: anytype) !void { + try self.headers.put("Content-Type", "text/html"); + + var stream = try self.open(status_code); + defer stream.close(); + + const writer = stream.writer(); + try @import("template").execute(writer, templ, data); + + try stream.finish(); + } }; const json_options = if (builtin.mode == .Debug) diff --git a/src/main/controllers/web.zig b/src/main/controllers/web.zig new file mode 100644 index 0000000..5a6d52f --- /dev/null +++ b/src/main/controllers/web.zig @@ -0,0 +1,58 @@ +const std = @import("std"); + +pub const routes = .{ + index, + about, + login, +}; + +const index = struct { + pub const path = "/"; + pub const method = .GET; + + pub fn handler(_: anytype, res: anytype, srv: anytype) !void { + if (srv.user_id == null) { + try res.headers.put("Location", about.path); + return res.status(.see_other); + } + + try res.template(.ok, "Hello", .{}); + } +}; + +const about = struct { + pub const path = "/about"; + pub const method = .GET; + + pub fn handler(_: anytype, res: anytype, srv: anytype) !void { + try res.headers.put("Content-Type", "text/html"); + + try res.template(.ok, tmpl, .{ + .community = srv.community, + }); + } + + const tmpl = @embedFile("./web/index.tmpl.html"); +}; + +const login = struct { + pub const path = "/login"; + pub const method = .POST; + + pub const Body = struct { + username: []const u8, + password: []const u8, + }; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const token = try srv.login(req.body.username, req.body.password); + + try res.headers.put("Location", index.path); + var buf: [64]u8 = undefined; + const cookie_name = try std.fmt.bufPrint(&buf, "token.{s}", .{req.body.username}); + try res.headers.setCookie(cookie_name, token.token, .{}); + try res.headers.setCookie("active_account", req.body.username, .{ .HttpOnly = false }); + + try res.status(.see_other); + } +}; diff --git a/src/main/controllers/web/index.zig b/src/main/controllers/web/index.zig deleted file mode 100644 index d1c6821..0000000 --- a/src/main/controllers/web/index.zig +++ /dev/null @@ -1,19 +0,0 @@ -const template = @import("template"); - -pub const path = "/"; -pub const method = .GET; - -pub fn handler(_: anytype, res: anytype, srv: anytype) !void { - try res.headers.put("Content-Type", "text/html"); - - var stream = try res.open(.ok); - defer stream.close(); - - try template.execute(stream.writer(), tmpl, .{ - .community = srv.community, - }); - - try stream.finish(); -} - -const tmpl = @embedFile("./index.tmpl.html");