From 73f7022d3602c512284e75d35b6e8dc549ca79c3 Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Wed, 23 Nov 2022 20:51:30 -0800 Subject: [PATCH] Start work on middleware api --- src/http/lib.zig | 2 + src/http/middleware.zig | 345 +++++++++++++++++++++++++++++++++++++++ src/main/controllers.zig | 30 +++- src/util/iters.zig | 25 ++- 4 files changed, 391 insertions(+), 11 deletions(-) create mode 100644 src/http/middleware.zig diff --git a/src/http/lib.zig b/src/http/lib.zig index 26f7756..9a4d8c9 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -15,6 +15,8 @@ pub const Response = server.Response; pub const Handler = server.Handler; pub const Server = server.Server; +pub const middleware = @import("./middleware.zig"); + pub const Fields = @import("./headers.zig").Fields; pub const Protocol = enum { diff --git a/src/http/middleware.zig b/src/http/middleware.zig new file mode 100644 index 0000000..428d52f --- /dev/null +++ b/src/http/middleware.zig @@ -0,0 +1,345 @@ +const std = @import("std"); +const root = @import("root"); +const builtin = @import("builtin"); +const http = @import("./lib.zig"); +const util = @import("util"); +//const query_utils = @import("./query.zig"); +//const json_utils = @import("./json.zig"); +const json_utils = util; +const query_utils = util; + +fn AddFields(comptime lhs: type, comptime rhs: type) type { + const Ctx = @Type(.{ .Struct = .{ + .layout = .Auto, + .fields = std.meta.fields(lhs) ++ std.meta.fields(rhs), + .decls = &.{}, + .is_tuple = false, + } }); + return Ctx; +} + +fn addFields(lhs: anytype, rhs: anytype) AddFields(@TypeOf(lhs), @TypeOf(rhs)) { + var result: AddFields(@TypeOf(lhs), @TypeOf(rhs)) = undefined; + inline for (comptime std.meta.fieldNames(@TypeOf(lhs))) |f| @field(result, f) = @field(lhs, f); + inline for (comptime std.meta.fieldNames(@TypeOf(rhs))) |f| @field(result, f) = @field(rhs, f); + return result; +} + +test { + // apply is a plumbing function that applies a tuple of middlewares in order + const base = apply(.{ + split_uri, + mount("/abc"), + }); + + const request = .{ .uri = "/abc/defg/hijkl?some_query=true#section" }; + const response = .{}; + const initial_context = .{}; + try base.handle(request, response, initial_context, {}); +} + +fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type { + if (fields.len == 0) return void; + + return NextHandler( + fields[0].field_type, + ApplyInternal(fields[1..]), + ); +} + +fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) { + if (fields.len == 0) return {}; + return .{ + .first = @field(middlewares, fields[0].name), + .next = applyInternal(middlewares, fields[1..]), + }; +} + +pub fn apply(middlewares: anytype) ApplyInternal(std.meta.fields(@TypeOf(middlewares))) { + return applyInternal(middlewares, std.meta.fields(@TypeOf(middlewares))); +} + +pub fn AddContext(comptime Rhs: type) type { + return struct { + values: Rhs, + pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + return next.handle(req, res, addFields(ctx, self.values), {}); + } + }; +} + +pub fn NextHandler(comptime First: type, comptime Next: type) type { + return struct { + first: First, + next: Next, + + pub fn handle( + self: @This(), + req: anytype, + res: anytype, + ctx: anytype, + next: void, + ) !void { + _ = next; + return self.first.handle(req, res, ctx, self.next); + } + }; +} + +pub fn CatchErrors(comptime ErrorHandler: type) type { + return struct { + error_handler: ErrorHandler, + pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + return next.handle(req, res, ctx, {}) catch |err| { + return self.error_handler.handle( + req, + res, + addFields(ctx, .{ .err = err }), + next, + ); + }; + } + }; +} +pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) { + return .{ .error_handler = error_handler }; +} + +pub const default_error_handler = struct { + fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + _ = next; + std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri }); + if (!res.was_opened) { + if (res.open(.internal_server_error)) |stream| { + defer stream.close(); + stream.finish() catch {}; + } + } + + // Tell the server to close the connection after this request + res.should_close = true; + } +}{}; + +pub const split_uri = struct { + pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + var frag_split = std.mem.split(u8, req.uri, "#"); + const without_fragment = frag_split.first(); + const fragment = frag_split.rest(); + + var query_split = std.mem.split(u8, without_fragment, "?"); + const path = query_split.first(); + const query = query_split.rest(); + + const added_ctx = .{ + .path = path, + .query_string = query, + .fragment_string = fragment, + }; + + return next.handle( + req, + res, + addFields(ctx, added_ctx), + {}, + ); + } +}{}; + +// helper function for doing route analysis +fn routeApplies(comptime R: type, req: anytype) bool { + if (R.method != req.method) return false; + + var path_iter = util.PathIter.from(req.path); + comptime var route_iter = util.PathIter.from(R.path); + inline while (comptime route_iter.next()) |route_segment| { + const path_segment = path_iter.next() orelse return false; + if (route_segment.len > 0 and route_segment[0] == ':') { + // Route Argument + } else { + if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false; + } + } + if (path_iter.next() != null) return false; + + return true; +} + +// routes a request to the correct handler based on declared HTTP method and path +pub fn Router(comptime Routes: []const type) type { + return struct { + routes: std.meta.Tuple(Routes), + + pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: void) !void { + _ = next; + + inline for (self.routes) |r| if (routeApplies(@TypeOf(r), req, ctx)) { + if (r.handle(req, res, ctx, {})) |_| { + // success! + return; + } else |err| switch (err) { + error.RouteMismatch => {}, + else => return err, + } + }; + + return error.RouteMismatch; + } + }; +} + +pub fn Mount(comptime route: []const u8) type { + return struct { + pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + var path_iter = util.PathIter.from(ctx.path); + comptime var route_iter = util.PathIter.from(route); + var path_unused = ctx.path; + + inline while (comptime route_iter.next()) |route_segment| { + if (comptime route_segment.len == 0) continue; + const path_segment = path_iter.next() orelse return error.RouteMismatch; + path_unused = path_iter.rest(); + if (comptime route_segment[0] == ':') { + @compileLog("Argument segments cannot be mounted"); + // Route Argument + } else { + if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; + } + } + + var new_ctx = ctx; + new_ctx.path = path_unused; + return next.handle(req, res, new_ctx, {}); + } + }; +} +pub fn mount(comptime route: []const u8) Mount(route) { + return .{}; +} + +pub fn HandleNotFound(comptime NotFoundHandler: type) type { + return struct { + not_found: NotFoundHandler, + + pub fn handler(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + return next.handler(req, res, ctx, {}) catch |err| switch (err) { + error.RouteMismatch => return self.not_found.handler(req, res, ctx, {}), + else => return err, + }; + } + }; +} + +fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { + var args: Args = undefined; + var path_iter = util.PathIter.from(path); + comptime var route_iter = util.PathIter.from(route); + inline while (comptime route_iter.next()) |route_segment| { + const path_segment = path_iter.next() orelse return error.RouteMismatch; + if (route_segment.len > 0 and route_segment[0] == ':') { + // route segment is an argument segment + const A = @TypeOf(@field(args, route_segment[1..])); + @field(args, route_segment[1..]) = try parsePathArg(A, path_segment); + } else { + if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; + } + } + + if (path_iter.next() != null) return error.RouteMismatch; + + return args; +} + +fn parsePathArg(comptime T: type, segment: []const u8) !T { + if (T == []const u8) return segment; + if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment); + + @compileError("Unsupported Type " ++ @typeName(T)); +} + +pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type { + return struct { + pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + return next.handle( + req, + res, + addFields(ctx, .{ .args = parsePathArgs(route, Args, req.path) }), + {}, + ); + } + }; +} + +const BaseContentType = enum { + json, + url_encoded, + octet_stream, + + other, +}; + +fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { + const buf = try reader.readAllAlloc(alloc, 1 << 16); + defer alloc.free(buf); + + switch (content_type) { + .octet_stream, .json => { + const body = try json_utils.parse(T, buf, alloc); + defer json_utils.parseFree(body, alloc); + + return try util.deepClone(alloc, body); + }, + .url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) { + error.NoQuery => error.NoBody, + else => err, + }, + else => return error.UnsupportedMediaType, + } +} + +fn matchContentType(hdr: ?[]const u8) ?BaseContentType { + if (hdr) |h| { + if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; + if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json; + if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream; + + return .other; + } + + return null; +} + +pub fn ParseBody(comptime Body: type) type { + return struct { + pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + const base_content_type = matchContentType(req.headers.get("Content-Type")); + + var stream = req.body orelse return error.NoBody; + const body = try parseBody(Body, base_content_type orelse .json, stream.reader(), ctx.allocator); + defer ctx.allocator.free(body); + + return next.handler( + req, + res, + addFields(ctx, .{ .body = body }), + ); + } + }; +} + +pub fn ParseQueryParams(comptime Next: type, comptime QueryParams: type) type { + return struct { + next: Next, + + pub fn handler(self: @This(), req: anytype, res: anytype, ctx: anytype) !void { + const query = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string); + defer ctx.allocator.free(query); + + return self.next.handler( + req, + res, + addFields(ctx, .{ .query = query }), + ); + } + }; +} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 007e909..3c22e66 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -18,14 +18,36 @@ pub const timelines = @import("./controllers/api/timelines.zig"); const web = @import("./controllers/web.zig"); +const mdw = http.middleware; + +const router = Router(&.{}); + +const not_found = struct { + pub fn handler(self: @This(), _: anytype, res: anytype, ctx: anytype) !void { + var headers = http.Fields.init(ctx.allocator); + defer headers.deinit(); + + var stream = try res.open(.not_found, &headers); + defer stream.close(); + try stream.finish(); + } +}; + +const base_handler = mdw.SplitUri(mdw.CatchErrors(not_found, mdw.DefaultErrorHandler)); + +fn ApiCall(comptime Route: type) type { + return mdw. +} + pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? - var response = Response{ .headers = http.Fields.init(alloc), .res = res }; - defer response.headers.deinit(); + base_handler + //var response = Response{ .headers = http.Fields.init(alloc), .res = res }; + //defer response.headers.deinit(); - const found = routeRequestInternal(api_source, req, &response, alloc); + //const found = routeRequestInternal(api_source, req, &response, alloc); - if (!found) response.status(.not_found) catch {}; + //if (!found) response.status(.not_found) catch {}; } fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { diff --git a/src/util/iters.zig b/src/util/iters.zig index b19c5bd..5ad2258 100644 --- a/src/util/iters.zig +++ b/src/util/iters.zig @@ -49,19 +49,30 @@ pub const QueryIter = struct { pub const PathIter = struct { is_first: bool, - iter: Separator('/'), + iter: std.mem.SplitIterator(u8), pub fn from(path: []const u8) PathIter { - return .{ .is_first = true, .iter = Separator('/').from(path) }; + return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") }; } pub fn next(self: *PathIter) ?[]const u8 { - if (self.is_first) { - self.is_first = false; - return self.iter.next() orelse ""; - } + defer self.is_first = false; + while (self.iter.next()) |it| if (it.len != 0) { + return it; + }; - return self.iter.next(); + if (self.is_first) return self.iter.rest(); + + return null; + } + + pub fn first(self: *PathIter) []const u8 { + std.debug.assert(self.is_first); + return self.next().?; + } + + pub fn rest(self: *PathIter) []const u8 { + return self.iter.rest(); } };