diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 4470d75..f162fc0 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -7,14 +7,18 @@ const util = @import("util"); const query_utils = @import("./query.zig"); const json_utils = @import("./json.zig"); -pub const auth = @import("./controllers/auth.zig"); -pub const communities = @import("./controllers/communities.zig"); -pub const invites = @import("./controllers/invites.zig"); -pub const users = @import("./controllers/users.zig"); -pub const follows = @import("./controllers/users/follows.zig"); -pub const notes = @import("./controllers/notes.zig"); -pub const streaming = @import("./controllers/streaming.zig"); -pub const timelines = @import("./controllers/timelines.zig"); +pub const auth = @import("./controllers/api/auth.zig"); +pub const communities = @import("./controllers/api/communities.zig"); +pub const invites = @import("./controllers/api/invites.zig"); +pub const users = @import("./controllers/api/users.zig"); +pub const follows = @import("./controllers/api/users/follows.zig"); +pub const notes = @import("./controllers/api/notes.zig"); +pub const streaming = @import("./controllers/api/streaming.zig"); +pub const timelines = @import("./controllers/api/timelines.zig"); + +const web = struct { + const index = @import("./controllers/web/index.zig"); +}; pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? @@ -50,8 +54,75 @@ const routes = .{ follows.create, follows.query_followers, follows.query_following, + + web.index, }; +fn parseRouteArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { + var args: Args = undefined; + var path_iter = util.PathIter.from(path); + comptime var route_iter = util.PathIter.from(route); + inline while (comptime route_iter.next()) |route_segment| { + const path_segment = path_iter.next() orelse return error.RouteMismatch; + if (route_segment.len > 0 and route_segment[0] == ':') { + const A = @TypeOf(@field(args, route_segment[1..])); + @field(args, route_segment[1..]) = try parseRouteArg(A, path_segment); + } else { + if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; + } + } + + if (path_iter.next() != null) return error.RouteMismatch; + + return args; +} + +fn parseRouteArg(comptime T: type, segment: []const u8) !T { + if (T == []const u8) return segment; + if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment); + + @compileError("Unsupported Type " ++ @typeName(T)); +} + +const BaseContentType = enum { + json, + url_encoded, + octet_stream, + + other, +}; + +fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { + const buf = try reader.readAllAlloc(alloc, 1 << 16); + defer alloc.free(buf); + + switch (content_type) { + .octet_stream, .json => { + const body = try json_utils.parse(T, buf, alloc); + defer json_utils.parseFree(body, alloc); + + return try util.deepClone(alloc, body); + }, + .url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) { + error.NoQuery => error.NoBody, + else => err, + }, + else => return error.UnsupportedMediaType, + } +} + +fn matchContentType(hdr: ?[]const u8) ?BaseContentType { + if (hdr) |h| { + if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; + if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json; + if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream; + + return .other; + } + + return null; +} + pub fn Context(comptime Route: type) type { return struct { const Self = @This(); @@ -81,38 +152,61 @@ pub fn Context(comptime Route: type) type { // TODO body_buf: ?[]const u8 = null, - fn parseArgs(path: []const u8) ?Args { - var args: Args = undefined; - var path_iter = util.PathIter.from(path); - comptime var route_iter = util.PathIter.from(Route.path); - inline while (comptime route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return null; - if (route_segment[0] == ':') { - const A = @TypeOf(@field(args, route_segment[1..])); - @field(args, route_segment[1..]) = parseArg(A, path_segment) catch return null; - } else { - if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return null; - } - } - - if (path_iter.next() != null) return null; - - return args; - } - - fn parseArg(comptime T: type, segment: []const u8) !T { - if (T == []const u8) return segment; - if (comptime std.meta.trait.hasFn("parse")(T)) return T.parse(segment); - - @compileError("Unsupported Type " ++ @typeName(T)); - } - pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { if (req.method != Route.method) return false; var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); - var args: Args = parseArgs(path) orelse return false; + var args = parseRouteArgs(Route.path, Args, path) catch return false; - var self = Self{ + std.log.debug("Matched route {s}", .{Route.path}); + + handle(api_source, req, res, alloc, args) catch |err| { + std.log.err("{}", .{err}); + if (!res.opened) res.err(.internal_server_error, "", {}) catch {}; + }; + + return true; + } + + fn handle( + api_source: *api.ApiSource, + req: *http.Request, + res: *Response, + alloc: std.mem.Allocator, + args: Args, + ) !void { + const base_content_type = matchContentType(req.headers.get("Content-Type")); + + const body = if (Body != void) blk: { + var stream = req.body orelse return error.NoBody; + break :blk try parseBody(Body, base_content_type orelse .json, stream.reader(), alloc); + } else {}; + defer if (Body != void) util.deepFree(alloc, body); + + const query = if (Query != void) blk: { + const path = std.mem.sliceTo(req.uri, '?'); + const q = req.uri[path.len..]; + + break :blk try query_utils.parseQuery(alloc, Query, q); + }; + defer if (Query != void) util.deepFree(alloc, query); + + var api_conn = conn: { + const host = req.headers.get("Host") orelse return error.NoHost; + const auth_header = req.headers.get("Authorization"); + const token = if (auth_header) |header| blk: { + const prefix = "bearer "; + if (header.len < prefix.len) break :blk null; + if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; + break :blk header[prefix.len..]; + } else null; + + if (token) |t| break :conn try api_source.connectToken(host, t, alloc); + + break :conn try api_source.connectUnauthorized(host, alloc); + }; + defer api_conn.close(); + + const self = Self{ .allocator = alloc, .base_request = req, @@ -121,15 +215,11 @@ pub fn Context(comptime Route: type) type { .headers = req.headers, .args = args, - .body = undefined, - .query = undefined, + .body = body, + .query = query, }; - std.log.debug("Matched route {s}", .{path}); - - self.prepareAndHandle(api_source, req, res); - - return true; + try Route.handler(self, res, &api_conn); } fn errorHandler(response: *Response, status: http.Status, err: anytype) void { @@ -143,68 +233,6 @@ pub fn Context(comptime Route: type) type { std.log.err("Error printing response: {}", .{err2}); }; } - - fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void { - self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err); - defer self.freeBody(); - - self.parseQuery() catch |err| return errorHandler(response, .bad_request, err); - - var api_conn = self.getApiConn(api_source) catch |err| return errorHandler(response, .internal_server_error, err); - defer api_conn.close(); - - self.handle(response, &api_conn); - } - - fn parseBody(self: *Self, req: *http.Request) !void { - if (Body != void) { - var stream = req.body orelse return error.NoBody; - const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16); - errdefer self.allocator.free(body); - self.body = try json_utils.parse(Body, body, self.allocator); - self.body_buf = body; - } - } - - fn freeBody(self: *Self) void { - if (Body != void) { - json_utils.parseFree(self.body, self.allocator); - self.allocator.free(self.body_buf.?); - } - } - - fn parseQuery(self: *Self) !void { - if (Query != void) { - const path = std.mem.sliceTo(self.uri, '?'); - const q = std.mem.sliceTo(self.uri[path.len..], '#'); - - self.query = try query_utils.parseQuery(Query, q); - } - } - - fn handle(self: Self, response: *Response, api_conn: anytype) void { - Route.handler(self, response, api_conn) catch |err| switch (err) { - else => { - std.log.err("{}", .{err}); - if (!response.opened) response.err(.internal_server_error, "", {}) catch {}; - }, - }; - } - - fn getApiConn(self: *Self, api_source: anytype) !api.ApiSource.Conn { - const host = self.headers.get("Host") orelse return error.NoHost; - const auth_header = self.headers.get("Authorization"); - const token = if (auth_header) |header| blk: { - const prefix = "bearer "; - if (header.len < prefix.len) break :blk null; - if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; - break :blk header[prefix.len..]; - } else null; - - if (token) |t| return try api_source.connectToken(host, t, self.allocator); - - return try api_source.connectUnauthorized(host, self.allocator); - } }; } @@ -216,22 +244,16 @@ pub const Response = struct { /// Write a response with no body, only a given status pub fn status(self: *Self, status_code: http.Status) !void { - std.debug.assert(!self.opened); - self.opened = true; - - var stream = try self.res.open(status_code, &self.headers); + var stream = try self.open(status_code); defer stream.close(); try stream.finish(); } /// Write a request body as json pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void { - std.debug.assert(!self.opened); - self.opened = true; - try self.headers.put("Content-Type", "application/json"); - var stream = try self.res.open(status_code, &self.headers); + var stream = try self.open(status_code); defer stream.close(); const writer = stream.writer(); @@ -240,6 +262,13 @@ pub const Response = struct { try stream.finish(); } + pub fn open(self: *Self, status_code: http.Status) !http.Response.Stream { + std.debug.assert(!self.opened); + self.opened = true; + + return try self.res.open(status_code, &self.headers); + } + /// Prints the given error as json pub fn err(self: *Self, status_code: http.Status, message: []const u8, details: anytype) !void { return self.json(status_code, .{ @@ -251,6 +280,7 @@ pub const Response = struct { /// Signals that the HTTP connection should be hijacked without writing a /// response beforehand. pub fn hijack(self: *Self) *http.Response { + std.debug.assert(!self.opened); self.opened = true; return self.res; } diff --git a/src/main/controllers/auth.zig b/src/main/controllers/api/auth.zig similarity index 100% rename from src/main/controllers/auth.zig rename to src/main/controllers/api/auth.zig diff --git a/src/main/controllers/communities.zig b/src/main/controllers/api/communities.zig similarity index 77% rename from src/main/controllers/communities.zig rename to src/main/controllers/api/communities.zig index 89ed8a5..f6f475d 100644 --- a/src/main/controllers/communities.zig +++ b/src/main/controllers/api/communities.zig @@ -1,12 +1,7 @@ -const std = @import("std"); const api = @import("api"); -const util = @import("util"); -const query_utils = @import("../query.zig"); -const controller_utils = @import("../controllers.zig").helpers; +const controller_utils = @import("../../controllers.zig").helpers; const QueryArgs = api.CommunityQueryArgs; -const Uuid = util.Uuid; -const DateTime = util.DateTime; pub const create = struct { pub const method = .POST; diff --git a/src/main/controllers/invites.zig b/src/main/controllers/api/invites.zig similarity index 100% rename from src/main/controllers/invites.zig rename to src/main/controllers/api/invites.zig diff --git a/src/main/controllers/notes.zig b/src/main/controllers/api/notes.zig similarity index 100% rename from src/main/controllers/notes.zig rename to src/main/controllers/api/notes.zig diff --git a/src/main/controllers/streaming.zig b/src/main/controllers/api/streaming.zig similarity index 100% rename from src/main/controllers/streaming.zig rename to src/main/controllers/api/streaming.zig diff --git a/src/main/controllers/timelines.zig b/src/main/controllers/api/timelines.zig similarity index 91% rename from src/main/controllers/timelines.zig rename to src/main/controllers/api/timelines.zig index 2d5ace6..8c30cc1 100644 --- a/src/main/controllers/timelines.zig +++ b/src/main/controllers/api/timelines.zig @@ -1,7 +1,6 @@ const std = @import("std"); const api = @import("api"); -const query_utils = @import("../query.zig"); -const controller_utils = @import("../controllers.zig").helpers; +const controller_utils = @import("../../controllers.zig").helpers; pub const global = struct { pub const method = .GET; diff --git a/src/main/controllers/users.zig b/src/main/controllers/api/users.zig similarity index 100% rename from src/main/controllers/users.zig rename to src/main/controllers/api/users.zig diff --git a/src/main/controllers/users/follows.zig b/src/main/controllers/api/users/follows.zig similarity index 94% rename from src/main/controllers/users/follows.zig rename to src/main/controllers/api/users/follows.zig index dcb36b9..5ee4699 100644 --- a/src/main/controllers/users/follows.zig +++ b/src/main/controllers/api/users/follows.zig @@ -1,6 +1,6 @@ const api = @import("api"); const util = @import("util"); -const controller_utils = @import("../../controllers.zig").helpers; +const controller_utils = @import("../../../controllers.zig").helpers; const Uuid = util.Uuid; diff --git a/src/main/controllers/web/index.fmt.html b/src/main/controllers/web/index.fmt.html new file mode 100644 index 0000000..15b263f --- /dev/null +++ b/src/main/controllers/web/index.fmt.html @@ -0,0 +1,25 @@ + + +
+ +