From 051217cdaf97b52a87d42a707773b0924e03cadf Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Thu, 24 Nov 2022 03:31:24 -0800 Subject: [PATCH] Use middlewares in controller endpoint helper --- src/main/controllers.zig | 343 +++++++++++++++------------------------ 1 file changed, 132 insertions(+), 211 deletions(-) diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 3c22e66..ede27c3 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -20,10 +20,15 @@ const web = @import("./controllers/web.zig"); const mdw = http.middleware; -const router = Router(&.{}); +const router = mdw.Router(&.{}); const not_found = struct { - pub fn handler(self: @This(), _: anytype, res: anytype, ctx: anytype) !void { + pub fn handler( + _: @This(), + _: anytype, + res: anytype, + ctx: anytype, + ) !void { var headers = http.Fields.init(ctx.allocator); defer headers.deinit(); @@ -35,13 +40,135 @@ const not_found = struct { const base_handler = mdw.SplitUri(mdw.CatchErrors(not_found, mdw.DefaultErrorHandler)); -fn ApiCall(comptime Route: type) type { - return mdw. +fn InjectApiConn(comptime ApiSource: type) type { + return struct { + api_source: ApiSource, + fn getApiConn(self: @This(), alloc: std.mem.Allocator, req: anytype) !ApiSource.Conn { + const host = req.headers.get("Host") orelse return error.NoHost; + const auth_header = req.headers.get("Authorization"); + const token = if (auth_header) |header| blk: { + const prefix = "bearer "; + if (header.len < prefix.len) break :blk null; + if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; + break :blk header[prefix.len..]; + } else null; + + if (token) |t| return try self.api_source.connectToken(host, t, alloc); + + if (req.headers.getCookie("active_account") catch return error.BadRequest) |account| { + if (account.len + ("token.").len <= 64) { + var buf: [64]u8 = undefined; + const cookie_name = std.fmt.bufPrint(&buf, "token.{s}", .{account}) catch unreachable; + if (try req.headers.getCookie(cookie_name)) |token_hdr| { + return try self.api_source.connectToken(host, token_hdr, alloc); + } + } else return error.InvalidCookie; + } + + return try self.api_source.connectUnauthorized(host, alloc); + } + + fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + var api_conn = try self.getApiConn(ctx.allocator, req); + defer api_conn.close(); + + return next.handle( + req, + res, + mdw.injectContext(.{ .api_conn = &api_conn }), + {}, + ); + } + }; +} + +pub fn EndpointRequest(comptime Endpoint: type) type { + return struct { + pub const Args = if (@hasDecl(Endpoint, "Args")) Endpoint.Args else void; + pub const Body = if (@hasDecl(Endpoint, "Body")) Endpoint.Body else void; + pub const Query = if (@hasDecl(Endpoint, "Query")) Endpoint.Query else void; + + allocator: std.mem.Allocator, + + method: http.Method, + uri: []const u8, + headers: http.Fields, + + args: Args, + body: Body, + query: Query, + + const args_middleware = if (Args == void) + mdw.injectContext(.{ .args = {} }) + else + mdw.ParsePathArgs(Args){}; + + const body_middleware = if (Body == void) + mdw.injectContext(.{ .body = {} }) + else + mdw.ParseBody(Body){}; + + const query_middleware = if (Query == void) + mdw.injectContext(.{ .query = {} }) + else + mdw.ParseQueryParams(Query){}; + }; +} + +fn CallApiEndpoint(comptime Endpoint: type) type { + return struct { + fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, _: void) !void { + const request = EndpointRequest(Endpoint){ + .allocator = ctx.allocator, + + .method = req.method, + .uri = req.uri, + .headers = req.headers, + + .args = ctx.args, + .body = ctx.body, + .query = ctx.query, + }; + + var response = Response{ .headers = http.Fields.init(ctx.allocator), .res = res }; + defer response.headers.deinit(); + + return Endpoint.handler(request, &response, ctx.api_conn); + } + }; +} + +pub fn apiEndpoint( + comptime Endpoint: type, + api_source: anytype, +) return_type: { + const RequestType = EndpointRequest(Endpoint); + break :return_type mdw.Apply(std.meta.Tuple(.{ + mdw.Route, + @TypeOf(RequestType.args_middleware), + @TypeOf(RequestType.query_middleware), + @TypeOf(RequestType.body_middleware), + // TODO: allocation strategy + InjectApiConn(@TypeOf(api_source)), + CallApiEndpoint(Endpoint), + })); +} { + const RequestType = EndpointRequest(Endpoint); + return mdw.apply(.{ + mdw.Route{ .desc = .{ .path = Endpoint.path, .method = Endpoint.method } }, + RequestType.args_middleware, + RequestType.query_middleware, + RequestType.body_middleware, + // TODO: allocation strategy + InjectApiConn(@TypeOf(api_source)){ .api_source = api_source }, + CallApiEndpoint(Endpoint){}, + }); } pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? - base_handler + _ = .{ api_source, req, res, alloc }; + unreachable; //var response = Response{ .headers = http.Fields.init(alloc), .res = res }; //defer response.headers.deinit(); @@ -50,14 +177,6 @@ pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response //if (!found) response.status(.not_found) catch {}; } -fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { - inline for (routes) |route| { - if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true; - } - - return false; -} - const routes = .{ auth.login, auth.verify_login, @@ -77,209 +196,11 @@ const routes = .{ follows.query_following, } ++ web.routes; -fn parseRouteArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { - var args: Args = undefined; - var path_iter = util.PathIter.from(path); - comptime var route_iter = util.PathIter.from(route); - inline while (comptime route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return error.RouteMismatch; - if (route_segment.len > 0 and route_segment[0] == ':') { - const A = @TypeOf(@field(args, route_segment[1..])); - @field(args, route_segment[1..]) = try parseRouteArg(A, path_segment); - } else { - if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; - } - } - - if (path_iter.next() != null) return error.RouteMismatch; - - return args; -} - -fn parseRouteArg(comptime T: type, segment: []const u8) !T { - if (T == []const u8) return segment; - if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment); - - @compileError("Unsupported Type " ++ @typeName(T)); -} - -const BaseContentType = enum { - json, - url_encoded, - octet_stream, - - other, -}; - -fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { - const buf = try reader.readAllAlloc(alloc, 1 << 16); - defer alloc.free(buf); - - switch (content_type) { - .octet_stream, .json => { - const body = try json_utils.parse(T, buf, alloc); - defer json_utils.parseFree(body, alloc); - - return try util.deepClone(alloc, body); - }, - .url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) { - error.NoQuery => error.NoBody, - else => err, - }, - else => return error.UnsupportedMediaType, - } -} - -fn matchContentType(hdr: ?[]const u8) ?BaseContentType { - if (hdr) |h| { - if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; - if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json; - if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream; - - return .other; - } - - return null; -} - pub const AllocationStrategy = enum { arena, normal, }; -pub fn Context(comptime Route: type) type { - return struct { - const Self = @This(); - - pub const Args = if (@hasDecl(Route, "Args")) Route.Args else void; - - // TODO: if controller does not provide a body type, maybe we should - // leave it as a simple reader instead of void - pub const Body = if (@hasDecl(Route, "Body")) Route.Body else void; - - // TODO: if controller does not provide a query type, maybe we should - // leave it as a simple string instead of void - pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void; - - const allocation_strategy: AllocationStrategy = if (@hasDecl(Route, "allocation_strategy")) - Route.AllocationStrategy - else - .arena; - - base_request: *http.Request, - - allocator: std.mem.Allocator, - - method: http.Method, - uri: []const u8, - headers: http.Fields, - - args: Args, - body: Body, - query: Query, - - // TODO - body_buf: ?[]const u8 = null, - - pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { - if (req.method != Route.method) return false; - var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); - var args = parseRouteArgs(Route.path, Args, path) catch return false; - - std.log.debug("Matched route {s}", .{Route.path}); - - handle(api_source, req, res, alloc, args) catch |err| { - std.log.err("{}", .{err}); - if (!res.opened) res.err(.internal_server_error, "", {}) catch {}; - }; - - return true; - } - - fn handle( - api_source: *api.ApiSource, - req: *http.Request, - res: *Response, - base_allocator: std.mem.Allocator, - args: Args, - ) !void { - const base_content_type = matchContentType(req.headers.get("Content-Type")); - - var arena = if (allocation_strategy == .arena) - std.heap.ArenaAllocator.init(base_allocator) - else {}; - const alloc = if (allocation_strategy == .arena) arena.allocator() else base_allocator; - - const body = if (Body != void) blk: { - var stream = req.body orelse return error.NoBody; - break :blk try parseBody(Body, base_content_type orelse .json, stream.reader(), alloc); - } else {}; - defer if (Body != void) util.deepFree(alloc, body); - - const query = if (Query != void) blk: { - const path = std.mem.sliceTo(req.uri, '?'); - const q = req.uri[path.len..]; - - break :blk try query_utils.parseQuery(alloc, Query, q); - }; - defer if (Query != void) util.deepFree(alloc, query); - - var api_conn = conn: { - const host = req.headers.get("Host") orelse return error.NoHost; - const auth_header = req.headers.get("Authorization"); - const token = if (auth_header) |header| blk: { - const prefix = "bearer "; - if (header.len < prefix.len) break :blk null; - if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; - break :blk header[prefix.len..]; - } else null; - - if (token) |t| break :conn try api_source.connectToken(host, t, alloc); - - if (req.headers.getCookie("active_account") catch return error.BadRequest) |account| { - if (account.len + ("token.").len <= 64) { - var buf: [64]u8 = undefined; - const cookie_name = std.fmt.bufPrint(&buf, "token.{s}", .{account}) catch unreachable; - if (try req.headers.getCookie(cookie_name)) |token_hdr| { - break :conn try api_source.connectToken(host, token_hdr, alloc); - } - } else return error.InvalidCookie; - } - - break :conn try api_source.connectUnauthorized(host, alloc); - }; - defer api_conn.close(); - - const self = Self{ - .allocator = alloc, - .base_request = req, - - .method = req.method, - .uri = req.uri, - .headers = req.headers, - - .args = args, - .body = body, - .query = query, - }; - - try Route.handler(self, res, &api_conn); - } - - fn errorHandler(response: *Response, status: http.Status, err: anytype) void { - std.log.err("Error occured on handler {s} {s}", .{ @tagName(Route.method), Route.path }); - std.log.err("{}", .{err}); - const result = if (builtin.mode == .Debug) - response.err(status, @errorName(err), {}) - else - response.status(status); - _ = result catch |err2| { - std.log.err("Error printing response: {}", .{err2}); - }; - } - }; -} - pub const Response = struct { const Self = @This(); headers: http.Fields,