From 5b0505b35572b18adb95bfcf03f17b2bbc5346fc Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sat, 26 Nov 2022 22:56:16 -0800 Subject: [PATCH] Add tests for middleware.zig --- src/http/middleware.zig | 418 +++++++++++++++++++++++++++++++++------- 1 file changed, 344 insertions(+), 74 deletions(-) diff --git a/src/http/middleware.zig b/src/http/middleware.zig index 12ddcf3..e34f671 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -1,10 +1,103 @@ +/// Middlewares are types with a method of type: +/// fn handle( +/// self: @This(), +/// request: *http.Request(< some type >), +/// response: *http.Response(< some type >), +/// context: anytype, +/// next_handler: anytype, +/// ) !void +/// +/// If a middleware returns error.RouteMismatch, then it is assumed that the handler +/// did not apply to the request, and this is used by routing implementations to +/// determine when to stop attempting to match a route. +/// +/// Terminal middlewares that are not implemented using other middlewares should +/// only accept a `void` value for `next_handler`. const std = @import("std"); -const root = @import("root"); -const builtin = @import("builtin"); const http = @import("./lib.zig"); const util = @import("util"); const query_utils = @import("./query.zig"); const json_utils = @import("./json.zig"); + +/// Takes an iterable of middlewares and chains them together. +pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) { + return applyInternal(middlewares, std.meta.fields(@TypeOf(middlewares))); +} + +/// Helper function for the return type of `apply()` +pub fn Apply(comptime Middlewares: type) type { + return ApplyInternal(std.meta.fields(Middlewares)); +} + +fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type { + if (fields.len == 0) return void; + + return HandlerList( + fields[0].field_type, + ApplyInternal(fields[1..]), + ); +} + +fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) { + if (fields.len == 0) return {}; + return .{ + .first = @field(middlewares, fields[0].name), + .next = applyInternal(middlewares, fields[1..]), + }; +} + +pub fn HandlerList(comptime First: type, comptime Next: type) type { + return struct { + first: First, + next: Next, + + pub fn handle( + self: @This(), + req: anytype, + res: anytype, + ctx: anytype, + next: void, + ) !void { + _ = next; + return self.first.handle(req, res, ctx, self.next); + } + }; +} + +test "apply" { + var count: usize = 0; + const NoOp = struct { + ptr: *usize, + fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + self.ptr.* += 1; + if (@TypeOf(next) != void) return next.handle(req, res, ctx, {}); + } + }; + + const middlewares = .{ + NoOp{ .ptr = &count }, + NoOp{ .ptr = &count }, + NoOp{ .ptr = &count }, + NoOp{ .ptr = &count }, + }; + try std.testing.expectEqual( + Apply(@TypeOf(middlewares)), + HandlerList(NoOp, HandlerList(NoOp, HandlerList(NoOp, HandlerList(NoOp, void)))), + ); + + try apply(middlewares).handle(.{}, .{}, .{}, {}); + try std.testing.expectEqual(count, 4); +} + +test "injectContextValue - chained" { + try apply(.{ + injectContextValue("abcd", @as(usize, 5)), + injectContextValue("efgh", @as(usize, 10)), + injectContextValue("ijkl", @as(usize, 15)), + ExpectContext(.{ .abcd = 5, .efgh = 10, .ijkl = 15 }){}, + }).handle(.{}, .{}, .{}, {}); +} + fn AddUniqueField(comptime Lhs: type, comptime N: usize, comptime name: [N]u8, comptime Val: type) type { const Ctx = @Type(.{ .Struct = .{ .layout = .Auto, @@ -34,44 +127,19 @@ fn addField(lhs: anytype, comptime name: []const u8, val: anytype) AddField(@Typ return result; } -test { - // apply is a plumbing function that applies a tuple of middlewares in order - const base = apply(.{ - split_uri, - mount("/abc"), - }); +test "addField" { + const expect = std.testing.expect; + const eql = std.meta.eql; - const request = .{ .uri = "/abc/defg/hijkl?some_query=true#section" }; - const response = .{}; - const initial_context = .{}; - try base.handle(request, response, initial_context, {}); -} - -fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type { - if (fields.len == 0) return void; - - return NextHandler( - fields[0].field_type, - ApplyInternal(fields[1..]), - ); -} - -fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) { - if (fields.len == 0) return {}; - return .{ - .first = @field(middlewares, fields[0].name), - .next = applyInternal(middlewares, fields[1..]), - }; -} - -pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) { - return applyInternal(middlewares, std.meta.fields(@TypeOf(middlewares))); -} - -pub fn Apply(comptime Middlewares: type) type { - return ApplyInternal(std.meta.fields(Middlewares)); + try expect(eql(addField(.{}, "abcd", 5), .{ .abcd = 5 })); + try expect(eql(addField(.{ .abcd = 5 }, "efgh", 10), .{ .abcd = 5, .efgh = 10 })); + try expect(eql( + addField(addField(.{}, "abcd", 5), "efgh", 10), + .{ .abcd = 5, .efgh = 10 }, + )); } +/// Adds a single value to the context object pub fn InjectContextValue(comptime name: []const u8, comptime V: type) type { return struct { val: V, @@ -85,24 +153,43 @@ pub fn injectContextValue(comptime name: []const u8, val: anytype) InjectContext return .{ .val = val }; } -pub fn NextHandler(comptime First: type, comptime Next: type) type { - return struct { - first: First, - next: Next, +test "InjectContextValue" { + try injectContextValue("abcd", @as(usize, 5)) + .handle(.{}, .{}, .{}, ExpectContext(.{ .abcd = 5 }){}); + try injectContextValue("abcd", @as(usize, 5)) + .handle(.{}, .{}, .{ .efgh = @as(usize, 10) }, ExpectContext(.{ .abcd = 5, .efgh = 10 }){}); +} - pub fn handle( - self: @This(), - req: anytype, - res: anytype, - ctx: anytype, - next: void, - ) !void { - _ = next; - return self.first.handle(req, res, ctx, self.next); +fn expectDeepEquals(expected: anytype, actual: anytype) !void { + const E = @TypeOf(expected); + const A = @TypeOf(actual); + if (E == void) return std.testing.expect(A == void); + try std.testing.expect(std.meta.fields(E).len == std.meta.fields(A).len); + inline for (std.meta.fields(E)) |f| { + const e = @field(expected, f.name); + const a = @field(actual, f.name); + if (comptime std.meta.trait.isZigString(f.field_type)) { + try std.testing.expectEqualStrings(a, e); + } else { + try std.testing.expectEqual(a, e); + } + } +} + +// Helper for testing purposes +fn ExpectContext(comptime val: anytype) type { + return struct { + pub fn handle(_: @This(), _: anytype, _: anytype, ctx: anytype, _: void) !void { + try expectDeepEquals(val, ctx); } }; } +fn expectContext(comptime val: anytype) ExpectContext(val) { + return .{}; +} +/// Catches any errors returned by the `next` chain, and passes them via context +/// to an error handler if one occurs pub fn CatchErrors(comptime ErrorHandler: type) type { return struct { error_handler: ErrorHandler, @@ -118,14 +205,17 @@ pub fn CatchErrors(comptime ErrorHandler: type) type { } }; } + pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) { return .{ .error_handler = error_handler }; } +/// Default error handler for CatchErrors, logs the error and outputs responds with a 500 if +/// a response has not been written yet pub const default_error_handler = struct { - fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - _ = next; - std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri }); + fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, _: anytype) !void { + const should_log = !@import("builtin").is_test; + if (should_log) std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri }); // Tell the server to close the connection after this request res.should_close = true; @@ -141,7 +231,47 @@ pub const default_error_handler = struct { } }{}; -pub const split_uri = struct { +test "CatchErrors" { + const TestResponse = struct { + should_close: bool = false, + was_opened: bool = false, + + test_should_open: bool, + const TestStream = struct { + fn close(_: *@This()) void {} + fn finish(_: *@This()) !void {} + }; + + fn open(self: *@This(), status: http.Status, _: *http.Fields) !TestStream { + self.was_opened = true; + if (!self.test_should_open) return error.ResponseOpenedTwice; + try std.testing.expectEqual(status, .internal_server_error); + return .{}; + } + }; + + const middleware_list = apply(.{ + catchErrors(default_error_handler), + struct { + fn handle(_: @This(), _: anytype, _: anytype, _: anytype, _: anytype) !void { + return error.SomeError; + } + }{}, + }); + + var response = TestResponse{ .test_should_open = true }; + try middleware_list.handle(.{ .uri = "abcd" }, &response, .{}, {}); + try std.testing.expect(response.should_close); + + // Test that it doesn't open a response if one was already opened + response = TestResponse{ .test_should_open = false, .was_opened = true }; + try middleware_list.handle(.{ .uri = "abcd" }, &response, .{}, {}); + try std.testing.expect(response.should_close); +} + +/// Takes the request uri provided and splits it into "path", "query_string", and "fragment_string" +/// parts, which are placed into context. +const SplitUri = struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { var frag_split = std.mem.split(u8, req.uri, "#"); const without_fragment = frag_split.first(); @@ -168,9 +298,32 @@ pub const split_uri = struct { {}, ); } -}{}; +}; +pub const split_uri = SplitUri{}; -// routes a request to the correct handler based on declared HTTP method and path +test "split_uri" { + const testCase = struct { + fn func(uri: []const u8, ctx: anytype, expected: anytype) !void { + const v = apply(.{ + split_uri, + expectContext(expected), + }); + try v.handle(.{ .uri = uri }, .{}, ctx, {}); + } + }.func; + + try testCase("/", .{}, .{ .path = "/", .query_string = "", .fragment_string = "" }); + try testCase("", .{}, .{ .path = "", .query_string = "", .fragment_string = "" }); + try testCase("/path", .{}, .{ .path = "/path", .query_string = "", .fragment_string = "" }); + try testCase("?abcd=1234", .{}, .{ .path = "", .query_string = "abcd=1234", .fragment_string = "" }); + try testCase("#abcd", .{}, .{ .path = "", .query_string = "", .fragment_string = "abcd" }); + try testCase("/abcd/efgh?query=no#frag", .{}, .{ .path = "/abcd/efgh", .query_string = "query=no", .fragment_string = "frag" }); +} + +/// Routes a request between the provided routes. +/// +/// CURRENTLY: Does not do this intelligently, all routing is handled by the routes themselves. +/// TODO: Consider implementing this with a hashmap? pub fn Router(comptime Routes: type) type { return struct { routes: Routes, @@ -204,6 +357,7 @@ fn pathMatches(route: []const u8, path: []const u8) bool { const path_segment = path_iter.next() orelse return false; if (route_segment.len > 0 and route_segment[0] == ':') { // Route Argument + if (path_segment.len == 0) return error.RouteMismatch; } else { if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false; } @@ -212,6 +366,19 @@ fn pathMatches(route: []const u8, path: []const u8) bool { return true; } + +/// Handler that either calls its next middleware parameter or returns error.RouteMismatch +/// depending on if the request matches the described route. +/// Must be below `split_uri` on the middleware list. +/// +/// Format: +/// Each route segment can be either a literal string or an argument. Literal strings +/// must match exactly in order to constitute a matching route. Arguments must begin with +/// the character ':', with the remainer of the segment referring to the name of the argument. +/// Argument values must be nonempty. +/// +/// For example, the route "/abc/:foo/def" would match "/abc/x/def" or "/abc/blahblah/def" but +/// not "/abc//def". pub const Route = struct { pub const Desc = struct { path: []const u8, @@ -232,7 +399,6 @@ pub const Route = struct { } pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - std.log.debug("Testing path {s} against {s}", .{ ctx.path, self.desc.path }); return if (self.applies(req, ctx)) next.handle(req, res, ctx, {}) else @@ -240,12 +406,47 @@ pub const Route = struct { } }; +test "route" { + const testCase = struct { + fn func(should_match: bool, route: Route.Desc, method: http.Method, path: []const u8) !void { + const no_op = struct { + fn handle(_: anytype, _: anytype, _: anytype, _: anytype, _: anytype) !void {} + }{}; + const result = (Route{ .desc = route }).handle(.{ .method = method }, .{}, .{ .path = path }, no_op); + try if (should_match) result else std.testing.expectError(error.RouteMismatch, result); + } + }.func; + + try testCase(true, .{ .method = .GET, .path = "/" }, .GET, "/"); + try testCase(true, .{ .method = .GET, .path = "/" }, .GET, ""); + try testCase(true, .{ .method = .GET, .path = "/abcd" }, .GET, "/abcd"); + try testCase(true, .{ .method = .GET, .path = "/abcd" }, .GET, "abcd"); + try testCase(true, .{ .method = .POST, .path = "/" }, .POST, "/"); + try testCase(true, .{ .method = .POST, .path = "/" }, .POST, ""); + try testCase(true, .{ .method = .POST, .path = "/abcd" }, .POST, "/abcd"); + try testCase(true, .{ .method = .POST, .path = "/abcd" }, .POST, "abcd"); + try testCase(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz"); + + try testCase(false, .{ .method = .POST, .path = "/" }, .GET, "/"); + try testCase(false, .{ .method = .GET, .path = "/abcd" }, .GET, ""); + try testCase(false, .{ .method = .GET, .path = "/" }, .GET, "/abcd"); + try testCase(false, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "efgh"); + try testCase(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/"); + try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/"); + try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo"); +} + +/// Mounts a router subtree under a given path. Middlewares further down on the list +/// are called with the path prefix specified by `route` removed from the path. +/// Must be below `split_uri` on the middleware list. pub fn Mount(comptime route: []const u8) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { var path_iter = util.PathIter.from(ctx.path); comptime var route_iter = util.PathIter.from(route); - var path_unused = ctx.path; + var path_unused: []const u8 = ctx.path; inline while (comptime route_iter.next()) |route_segment| { if (comptime route_segment.len == 0) continue; @@ -269,20 +470,26 @@ pub fn mount(comptime route: []const u8) Mount(route) { return .{}; } -pub fn HandleNotFound(comptime NotFoundHandler: type) type { - return struct { - not_found: NotFoundHandler, - - pub fn handler(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - return next.handler(req, res, ctx, {}) catch |err| switch (err) { - error.RouteMismatch => return self.not_found.handler(req, res, ctx, {}), - else => return err, - }; +test "mount" { + const testCase = struct { + fn func(comptime base: []const u8, request: []const u8, comptime expected: ?[]const u8) !void { + const result = mount(base).handle(.{}, .{}, addField(.{}, "path", request), expectContext(.{ .path = expected orelse "" })); + try if (expected != null) result else std.testing.expectError(error.RouteMismatch, result); } - }; + }.func; + try testCase("/api/", "/api/", ""); + try testCase("/api/", "/api/abcd", "abcd"); + try testCase("/api/", "/api/abcd/efgh", "abcd/efgh"); + try testCase("/api/", "/api/abcd/efgh/", "abcd/efgh/"); + try testCase("/api/v0", "/api/v0/call", "call"); + + try testCase("/api/", "/web/abcd/efgh/", null); + try testCase("/api/", "/", null); + try testCase("/api/", "/ap", null); + try testCase("/api/v0", "/api/v1/", null); } -fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { +fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { var args: Args = undefined; var path_iter = util.PathIter.from(path); comptime var route_iter = util.PathIter.from(route); @@ -290,8 +497,9 @@ fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const const path_segment = path_iter.next() orelse return error.RouteMismatch; if (route_segment.len > 0 and route_segment[0] == ':') { // route segment is an argument segment + if (path_segment.len == 0) return error.RouteMismatch; const A = @TypeOf(@field(args, route_segment[1..])); - @field(args, route_segment[1..]) = try parsePathArg(A, path_segment); + @field(args, route_segment[1..]) = try parseArgFromPath(A, path_segment); } else { if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; } @@ -302,13 +510,35 @@ fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const return args; } -fn parsePathArg(comptime T: type, segment: []const u8) !T { +fn parseArgFromPath(comptime T: type, segment: []const u8) !T { if (T == []const u8) return segment; if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment); + if (comptime std.meta.trait.is(.Int)(T)) return std.fmt.parseInt(T, segment, 0); @compileError("Unsupported Type " ++ @typeName(T)); } +/// Parse arguments directly the request path. +/// Must be placed after a `split_uri` middleware in order to get `path` from context. +/// +/// Route arguments are specified in the same format as for Route. The name of the argument +/// refers to the field name in Args that the argument will be parsed to. +/// +/// This currently works with arguments of 3 different types: +/// - integers +/// - []const u8, +/// - anything with a function of the form: +/// * T.parse([]const u8) Error!T +/// * This function cannot hold a reference to the passed string once it appears +/// +/// Example: +/// ParsePathArgs("/:id/foo/:name/byrank/:rank", struct { +/// id: util.Uuid, +/// name: []const u8, +/// rank: u32, +/// }) +/// Would parse a path of "/00000000-0000-0000-0000-000000000000/foo/jaina/byrank/3" into +/// .{ .id = try Uuid.parse("00000000-0000-0000-0000-000000000000"), .name = "jaina", .rank = 3 } pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { @@ -316,12 +546,42 @@ pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type { return next.handle( req, res, - addField(ctx, "args", try parsePathArgs(route, Args, ctx.path)), + addField(ctx, "args", try parseArgsFromPath(route, Args, ctx.path)), {}, ); } }; } +pub fn parsePathArgs(comptime route: []const u8, comptime Args: type) ParsePathArgs(route, Args) { + return .{}; +} + +test "ParsePathArgs" { + const testCase = struct { + fn func(comptime route: []const u8, comptime Args: type, path: []const u8, expected: anytype) !void { + const check = struct { + expected: @TypeOf(expected), + path: []const u8, + fn handle(self: @This(), _: anytype, _: anytype, ctx: anytype, _: void) !void { + try expectDeepEquals(self.expected, ctx.args); + try std.testing.expectEqualStrings(self.path, ctx.path); + } + }{ .expected = expected, .path = path }; + try parsePathArgs(route, Args).handle(.{}, .{}, .{ .path = path }, check); + } + }.func; + + try testCase("/", void, "/", {}); + try testCase("/:id", struct { id: usize }, "/3", .{ .id = 3 }); + try testCase("/:str", struct { str: []const u8 }, "/abcd", .{ .str = "abcd" }); + try testCase("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" }); + try testCase("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil }); + + try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/", .{})); + try std.testing.expectError(error.RouteMismatch, testCase("/abcd/:id", struct { id: usize }, "/123", .{})); + try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 })); + try std.testing.expectError(error.InvalidCharacter, testCase("/:id", struct { id: usize }, "/xyz", .{})); +} const BaseContentType = enum { json, @@ -331,7 +591,7 @@ const BaseContentType = enum { other, }; -fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { +fn parseBodyFromRequest(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { //@compileLog(T); const buf = try reader.readAllAlloc(alloc, 1 << 16); defer alloc.free(buf); @@ -351,6 +611,7 @@ fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, a } } +// figure out what base parser to use fn matchContentType(hdr: ?[]const u8) ?BaseContentType { if (hdr) |h| { if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; @@ -363,6 +624,11 @@ fn matchContentType(hdr: ?[]const u8) ?BaseContentType { return null; } +/// Parses a set of body arguments from the request body based on the request's Content-Type +/// header. +/// +/// The exact method for parsing depends partially on the Content-Type. json types are preferred +/// TODO: Need tests for this, including various Content-Type values pub fn ParseBody(comptime Body: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { @@ -377,7 +643,7 @@ pub fn ParseBody(comptime Body: type) type { const base_content_type = matchContentType(content_type); var stream = req.body orelse return error.NoBody; - const body = try parseBody(Body, base_content_type orelse .json, stream.reader(), ctx.allocator); + const body = try parseBodyFromRequest(Body, base_content_type orelse .json, stream.reader(), ctx.allocator); defer util.deepFree(ctx.allocator, body); return next.handle( @@ -389,7 +655,11 @@ pub fn ParseBody(comptime Body: type) type { } }; } +pub fn parseBody(comptime Body: type) ParseBody(Body) { + return .{}; +} +/// Parses query parameters as defined in query.zig pub fn ParseQueryParams(comptime QueryParams: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {