From 5b0505b35572b18adb95bfcf03f17b2bbc5346fc Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sat, 26 Nov 2022 22:56:16 -0800 Subject: [PATCH 1/9] Add tests for middleware.zig --- src/http/middleware.zig | 418 +++++++++++++++++++++++++++++++++------- 1 file changed, 344 insertions(+), 74 deletions(-) diff --git a/src/http/middleware.zig b/src/http/middleware.zig index 12ddcf3..e34f671 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -1,10 +1,103 @@ +/// Middlewares are types with a method of type: +/// fn handle( +/// self: @This(), +/// request: *http.Request(< some type >), +/// response: *http.Response(< some type >), +/// context: anytype, +/// next_handler: anytype, +/// ) !void +/// +/// If a middleware returns error.RouteMismatch, then it is assumed that the handler +/// did not apply to the request, and this is used by routing implementations to +/// determine when to stop attempting to match a route. +/// +/// Terminal middlewares that are not implemented using other middlewares should +/// only accept a `void` value for `next_handler`. const std = @import("std"); -const root = @import("root"); -const builtin = @import("builtin"); const http = @import("./lib.zig"); const util = @import("util"); const query_utils = @import("./query.zig"); const json_utils = @import("./json.zig"); + +/// Takes an iterable of middlewares and chains them together. +pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) { + return applyInternal(middlewares, std.meta.fields(@TypeOf(middlewares))); +} + +/// Helper function for the return type of `apply()` +pub fn Apply(comptime Middlewares: type) type { + return ApplyInternal(std.meta.fields(Middlewares)); +} + +fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type { + if (fields.len == 0) return void; + + return HandlerList( + fields[0].field_type, + ApplyInternal(fields[1..]), + ); +} + +fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) { + if (fields.len == 0) return {}; + return .{ + .first = @field(middlewares, fields[0].name), + .next = applyInternal(middlewares, fields[1..]), + }; +} + +pub fn HandlerList(comptime First: type, comptime Next: type) type { + return struct { + first: First, + next: Next, + + pub fn handle( + self: @This(), + req: anytype, + res: anytype, + ctx: anytype, + next: void, + ) !void { + _ = next; + return self.first.handle(req, res, ctx, self.next); + } + }; +} + +test "apply" { + var count: usize = 0; + const NoOp = struct { + ptr: *usize, + fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { + self.ptr.* += 1; + if (@TypeOf(next) != void) return next.handle(req, res, ctx, {}); + } + }; + + const middlewares = .{ + NoOp{ .ptr = &count }, + NoOp{ .ptr = &count }, + NoOp{ .ptr = &count }, + NoOp{ .ptr = &count }, + }; + try std.testing.expectEqual( + Apply(@TypeOf(middlewares)), + HandlerList(NoOp, HandlerList(NoOp, HandlerList(NoOp, HandlerList(NoOp, void)))), + ); + + try apply(middlewares).handle(.{}, .{}, .{}, {}); + try std.testing.expectEqual(count, 4); +} + +test "injectContextValue - chained" { + try apply(.{ + injectContextValue("abcd", @as(usize, 5)), + injectContextValue("efgh", @as(usize, 10)), + injectContextValue("ijkl", @as(usize, 15)), + ExpectContext(.{ .abcd = 5, .efgh = 10, .ijkl = 15 }){}, + }).handle(.{}, .{}, .{}, {}); +} + fn AddUniqueField(comptime Lhs: type, comptime N: usize, comptime name: [N]u8, comptime Val: type) type { const Ctx = @Type(.{ .Struct = .{ .layout = .Auto, @@ -34,44 +127,19 @@ fn addField(lhs: anytype, comptime name: []const u8, val: anytype) AddField(@Typ return result; } -test { - // apply is a plumbing function that applies a tuple of middlewares in order - const base = apply(.{ - split_uri, - mount("/abc"), - }); +test "addField" { + const expect = std.testing.expect; + const eql = std.meta.eql; - const request = .{ .uri = "/abc/defg/hijkl?some_query=true#section" }; - const response = .{}; - const initial_context = .{}; - try base.handle(request, response, initial_context, {}); -} - -fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type { - if (fields.len == 0) return void; - - return NextHandler( - fields[0].field_type, - ApplyInternal(fields[1..]), - ); -} - -fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) { - if (fields.len == 0) return {}; - return .{ - .first = @field(middlewares, fields[0].name), - .next = applyInternal(middlewares, fields[1..]), - }; -} - -pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) { - return applyInternal(middlewares, std.meta.fields(@TypeOf(middlewares))); -} - -pub fn Apply(comptime Middlewares: type) type { - return ApplyInternal(std.meta.fields(Middlewares)); + try expect(eql(addField(.{}, "abcd", 5), .{ .abcd = 5 })); + try expect(eql(addField(.{ .abcd = 5 }, "efgh", 10), .{ .abcd = 5, .efgh = 10 })); + try expect(eql( + addField(addField(.{}, "abcd", 5), "efgh", 10), + .{ .abcd = 5, .efgh = 10 }, + )); } +/// Adds a single value to the context object pub fn InjectContextValue(comptime name: []const u8, comptime V: type) type { return struct { val: V, @@ -85,24 +153,43 @@ pub fn injectContextValue(comptime name: []const u8, val: anytype) InjectContext return .{ .val = val }; } -pub fn NextHandler(comptime First: type, comptime Next: type) type { - return struct { - first: First, - next: Next, +test "InjectContextValue" { + try injectContextValue("abcd", @as(usize, 5)) + .handle(.{}, .{}, .{}, ExpectContext(.{ .abcd = 5 }){}); + try injectContextValue("abcd", @as(usize, 5)) + .handle(.{}, .{}, .{ .efgh = @as(usize, 10) }, ExpectContext(.{ .abcd = 5, .efgh = 10 }){}); +} - pub fn handle( - self: @This(), - req: anytype, - res: anytype, - ctx: anytype, - next: void, - ) !void { - _ = next; - return self.first.handle(req, res, ctx, self.next); +fn expectDeepEquals(expected: anytype, actual: anytype) !void { + const E = @TypeOf(expected); + const A = @TypeOf(actual); + if (E == void) return std.testing.expect(A == void); + try std.testing.expect(std.meta.fields(E).len == std.meta.fields(A).len); + inline for (std.meta.fields(E)) |f| { + const e = @field(expected, f.name); + const a = @field(actual, f.name); + if (comptime std.meta.trait.isZigString(f.field_type)) { + try std.testing.expectEqualStrings(a, e); + } else { + try std.testing.expectEqual(a, e); + } + } +} + +// Helper for testing purposes +fn ExpectContext(comptime val: anytype) type { + return struct { + pub fn handle(_: @This(), _: anytype, _: anytype, ctx: anytype, _: void) !void { + try expectDeepEquals(val, ctx); } }; } +fn expectContext(comptime val: anytype) ExpectContext(val) { + return .{}; +} +/// Catches any errors returned by the `next` chain, and passes them via context +/// to an error handler if one occurs pub fn CatchErrors(comptime ErrorHandler: type) type { return struct { error_handler: ErrorHandler, @@ -118,14 +205,17 @@ pub fn CatchErrors(comptime ErrorHandler: type) type { } }; } + pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) { return .{ .error_handler = error_handler }; } +/// Default error handler for CatchErrors, logs the error and outputs responds with a 500 if +/// a response has not been written yet pub const default_error_handler = struct { - fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - _ = next; - std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri }); + fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, _: anytype) !void { + const should_log = !@import("builtin").is_test; + if (should_log) std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri }); // Tell the server to close the connection after this request res.should_close = true; @@ -141,7 +231,47 @@ pub const default_error_handler = struct { } }{}; -pub const split_uri = struct { +test "CatchErrors" { + const TestResponse = struct { + should_close: bool = false, + was_opened: bool = false, + + test_should_open: bool, + const TestStream = struct { + fn close(_: *@This()) void {} + fn finish(_: *@This()) !void {} + }; + + fn open(self: *@This(), status: http.Status, _: *http.Fields) !TestStream { + self.was_opened = true; + if (!self.test_should_open) return error.ResponseOpenedTwice; + try std.testing.expectEqual(status, .internal_server_error); + return .{}; + } + }; + + const middleware_list = apply(.{ + catchErrors(default_error_handler), + struct { + fn handle(_: @This(), _: anytype, _: anytype, _: anytype, _: anytype) !void { + return error.SomeError; + } + }{}, + }); + + var response = TestResponse{ .test_should_open = true }; + try middleware_list.handle(.{ .uri = "abcd" }, &response, .{}, {}); + try std.testing.expect(response.should_close); + + // Test that it doesn't open a response if one was already opened + response = TestResponse{ .test_should_open = false, .was_opened = true }; + try middleware_list.handle(.{ .uri = "abcd" }, &response, .{}, {}); + try std.testing.expect(response.should_close); +} + +/// Takes the request uri provided and splits it into "path", "query_string", and "fragment_string" +/// parts, which are placed into context. +const SplitUri = struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { var frag_split = std.mem.split(u8, req.uri, "#"); const without_fragment = frag_split.first(); @@ -168,9 +298,32 @@ pub const split_uri = struct { {}, ); } -}{}; +}; +pub const split_uri = SplitUri{}; -// routes a request to the correct handler based on declared HTTP method and path +test "split_uri" { + const testCase = struct { + fn func(uri: []const u8, ctx: anytype, expected: anytype) !void { + const v = apply(.{ + split_uri, + expectContext(expected), + }); + try v.handle(.{ .uri = uri }, .{}, ctx, {}); + } + }.func; + + try testCase("/", .{}, .{ .path = "/", .query_string = "", .fragment_string = "" }); + try testCase("", .{}, .{ .path = "", .query_string = "", .fragment_string = "" }); + try testCase("/path", .{}, .{ .path = "/path", .query_string = "", .fragment_string = "" }); + try testCase("?abcd=1234", .{}, .{ .path = "", .query_string = "abcd=1234", .fragment_string = "" }); + try testCase("#abcd", .{}, .{ .path = "", .query_string = "", .fragment_string = "abcd" }); + try testCase("/abcd/efgh?query=no#frag", .{}, .{ .path = "/abcd/efgh", .query_string = "query=no", .fragment_string = "frag" }); +} + +/// Routes a request between the provided routes. +/// +/// CURRENTLY: Does not do this intelligently, all routing is handled by the routes themselves. +/// TODO: Consider implementing this with a hashmap? pub fn Router(comptime Routes: type) type { return struct { routes: Routes, @@ -204,6 +357,7 @@ fn pathMatches(route: []const u8, path: []const u8) bool { const path_segment = path_iter.next() orelse return false; if (route_segment.len > 0 and route_segment[0] == ':') { // Route Argument + if (path_segment.len == 0) return error.RouteMismatch; } else { if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false; } @@ -212,6 +366,19 @@ fn pathMatches(route: []const u8, path: []const u8) bool { return true; } + +/// Handler that either calls its next middleware parameter or returns error.RouteMismatch +/// depending on if the request matches the described route. +/// Must be below `split_uri` on the middleware list. +/// +/// Format: +/// Each route segment can be either a literal string or an argument. Literal strings +/// must match exactly in order to constitute a matching route. Arguments must begin with +/// the character ':', with the remainer of the segment referring to the name of the argument. +/// Argument values must be nonempty. +/// +/// For example, the route "/abc/:foo/def" would match "/abc/x/def" or "/abc/blahblah/def" but +/// not "/abc//def". pub const Route = struct { pub const Desc = struct { path: []const u8, @@ -232,7 +399,6 @@ pub const Route = struct { } pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - std.log.debug("Testing path {s} against {s}", .{ ctx.path, self.desc.path }); return if (self.applies(req, ctx)) next.handle(req, res, ctx, {}) else @@ -240,12 +406,47 @@ pub const Route = struct { } }; +test "route" { + const testCase = struct { + fn func(should_match: bool, route: Route.Desc, method: http.Method, path: []const u8) !void { + const no_op = struct { + fn handle(_: anytype, _: anytype, _: anytype, _: anytype, _: anytype) !void {} + }{}; + const result = (Route{ .desc = route }).handle(.{ .method = method }, .{}, .{ .path = path }, no_op); + try if (should_match) result else std.testing.expectError(error.RouteMismatch, result); + } + }.func; + + try testCase(true, .{ .method = .GET, .path = "/" }, .GET, "/"); + try testCase(true, .{ .method = .GET, .path = "/" }, .GET, ""); + try testCase(true, .{ .method = .GET, .path = "/abcd" }, .GET, "/abcd"); + try testCase(true, .{ .method = .GET, .path = "/abcd" }, .GET, "abcd"); + try testCase(true, .{ .method = .POST, .path = "/" }, .POST, "/"); + try testCase(true, .{ .method = .POST, .path = "/" }, .POST, ""); + try testCase(true, .{ .method = .POST, .path = "/abcd" }, .POST, "/abcd"); + try testCase(true, .{ .method = .POST, .path = "/abcd" }, .POST, "abcd"); + try testCase(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz"); + + try testCase(false, .{ .method = .POST, .path = "/" }, .GET, "/"); + try testCase(false, .{ .method = .GET, .path = "/abcd" }, .GET, ""); + try testCase(false, .{ .method = .GET, .path = "/" }, .GET, "/abcd"); + try testCase(false, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "efgh"); + try testCase(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/"); + try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/"); + try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo"); +} + +/// Mounts a router subtree under a given path. Middlewares further down on the list +/// are called with the path prefix specified by `route` removed from the path. +/// Must be below `split_uri` on the middleware list. pub fn Mount(comptime route: []const u8) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { var path_iter = util.PathIter.from(ctx.path); comptime var route_iter = util.PathIter.from(route); - var path_unused = ctx.path; + var path_unused: []const u8 = ctx.path; inline while (comptime route_iter.next()) |route_segment| { if (comptime route_segment.len == 0) continue; @@ -269,20 +470,26 @@ pub fn mount(comptime route: []const u8) Mount(route) { return .{}; } -pub fn HandleNotFound(comptime NotFoundHandler: type) type { - return struct { - not_found: NotFoundHandler, - - pub fn handler(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - return next.handler(req, res, ctx, {}) catch |err| switch (err) { - error.RouteMismatch => return self.not_found.handler(req, res, ctx, {}), - else => return err, - }; +test "mount" { + const testCase = struct { + fn func(comptime base: []const u8, request: []const u8, comptime expected: ?[]const u8) !void { + const result = mount(base).handle(.{}, .{}, addField(.{}, "path", request), expectContext(.{ .path = expected orelse "" })); + try if (expected != null) result else std.testing.expectError(error.RouteMismatch, result); } - }; + }.func; + try testCase("/api/", "/api/", ""); + try testCase("/api/", "/api/abcd", "abcd"); + try testCase("/api/", "/api/abcd/efgh", "abcd/efgh"); + try testCase("/api/", "/api/abcd/efgh/", "abcd/efgh/"); + try testCase("/api/v0", "/api/v0/call", "call"); + + try testCase("/api/", "/web/abcd/efgh/", null); + try testCase("/api/", "/", null); + try testCase("/api/", "/ap", null); + try testCase("/api/v0", "/api/v1/", null); } -fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { +fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { var args: Args = undefined; var path_iter = util.PathIter.from(path); comptime var route_iter = util.PathIter.from(route); @@ -290,8 +497,9 @@ fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const const path_segment = path_iter.next() orelse return error.RouteMismatch; if (route_segment.len > 0 and route_segment[0] == ':') { // route segment is an argument segment + if (path_segment.len == 0) return error.RouteMismatch; const A = @TypeOf(@field(args, route_segment[1..])); - @field(args, route_segment[1..]) = try parsePathArg(A, path_segment); + @field(args, route_segment[1..]) = try parseArgFromPath(A, path_segment); } else { if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; } @@ -302,13 +510,35 @@ fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const return args; } -fn parsePathArg(comptime T: type, segment: []const u8) !T { +fn parseArgFromPath(comptime T: type, segment: []const u8) !T { if (T == []const u8) return segment; if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment); + if (comptime std.meta.trait.is(.Int)(T)) return std.fmt.parseInt(T, segment, 0); @compileError("Unsupported Type " ++ @typeName(T)); } +/// Parse arguments directly the request path. +/// Must be placed after a `split_uri` middleware in order to get `path` from context. +/// +/// Route arguments are specified in the same format as for Route. The name of the argument +/// refers to the field name in Args that the argument will be parsed to. +/// +/// This currently works with arguments of 3 different types: +/// - integers +/// - []const u8, +/// - anything with a function of the form: +/// * T.parse([]const u8) Error!T +/// * This function cannot hold a reference to the passed string once it appears +/// +/// Example: +/// ParsePathArgs("/:id/foo/:name/byrank/:rank", struct { +/// id: util.Uuid, +/// name: []const u8, +/// rank: u32, +/// }) +/// Would parse a path of "/00000000-0000-0000-0000-000000000000/foo/jaina/byrank/3" into +/// .{ .id = try Uuid.parse("00000000-0000-0000-0000-000000000000"), .name = "jaina", .rank = 3 } pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { @@ -316,12 +546,42 @@ pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type { return next.handle( req, res, - addField(ctx, "args", try parsePathArgs(route, Args, ctx.path)), + addField(ctx, "args", try parseArgsFromPath(route, Args, ctx.path)), {}, ); } }; } +pub fn parsePathArgs(comptime route: []const u8, comptime Args: type) ParsePathArgs(route, Args) { + return .{}; +} + +test "ParsePathArgs" { + const testCase = struct { + fn func(comptime route: []const u8, comptime Args: type, path: []const u8, expected: anytype) !void { + const check = struct { + expected: @TypeOf(expected), + path: []const u8, + fn handle(self: @This(), _: anytype, _: anytype, ctx: anytype, _: void) !void { + try expectDeepEquals(self.expected, ctx.args); + try std.testing.expectEqualStrings(self.path, ctx.path); + } + }{ .expected = expected, .path = path }; + try parsePathArgs(route, Args).handle(.{}, .{}, .{ .path = path }, check); + } + }.func; + + try testCase("/", void, "/", {}); + try testCase("/:id", struct { id: usize }, "/3", .{ .id = 3 }); + try testCase("/:str", struct { str: []const u8 }, "/abcd", .{ .str = "abcd" }); + try testCase("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" }); + try testCase("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil }); + + try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/", .{})); + try std.testing.expectError(error.RouteMismatch, testCase("/abcd/:id", struct { id: usize }, "/123", .{})); + try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 })); + try std.testing.expectError(error.InvalidCharacter, testCase("/:id", struct { id: usize }, "/xyz", .{})); +} const BaseContentType = enum { json, @@ -331,7 +591,7 @@ const BaseContentType = enum { other, }; -fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { +fn parseBodyFromRequest(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { //@compileLog(T); const buf = try reader.readAllAlloc(alloc, 1 << 16); defer alloc.free(buf); @@ -351,6 +611,7 @@ fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, a } } +// figure out what base parser to use fn matchContentType(hdr: ?[]const u8) ?BaseContentType { if (hdr) |h| { if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; @@ -363,6 +624,11 @@ fn matchContentType(hdr: ?[]const u8) ?BaseContentType { return null; } +/// Parses a set of body arguments from the request body based on the request's Content-Type +/// header. +/// +/// The exact method for parsing depends partially on the Content-Type. json types are preferred +/// TODO: Need tests for this, including various Content-Type values pub fn ParseBody(comptime Body: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { @@ -377,7 +643,7 @@ pub fn ParseBody(comptime Body: type) type { const base_content_type = matchContentType(content_type); var stream = req.body orelse return error.NoBody; - const body = try parseBody(Body, base_content_type orelse .json, stream.reader(), ctx.allocator); + const body = try parseBodyFromRequest(Body, base_content_type orelse .json, stream.reader(), ctx.allocator); defer util.deepFree(ctx.allocator, body); return next.handle( @@ -389,7 +655,11 @@ pub fn ParseBody(comptime Body: type) type { } }; } +pub fn parseBody(comptime Body: type) ParseBody(Body) { + return .{}; +} +/// Parses query parameters as defined in query.zig pub fn ParseQueryParams(comptime QueryParams: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { From de19083cd981d11fd4d8f5a42eddbd6106929b74 Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sat, 26 Nov 2022 23:11:34 -0800 Subject: [PATCH 2/9] Update build.zig for unit tests --- build.zig | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/build.zig b/build.zig index f68bbd5..cf43658 100644 --- a/build.zig +++ b/build.zig @@ -98,16 +98,24 @@ pub fn build(b: *std.build.Builder) !void { exe.linkLibC(); exe.addSystemIncludePath("/usr/include/"); + const unittest_http_cmd = b.step("unit:http", "Run tests for http package"); + const unittest_http = b.addTest("src/http/test.zig"); + unittest_http_cmd.dependOn(&unittest_http.step); + unittest_http.addPackage(pkgs.util); + + //const unittest_util_cmd = b.step("unit:util", "Run tests for util package"); + //const unittest_util = b.addTest("src/util/Uuid.zig"); + //unittest_util_cmd.dependOn(&unittest_util.step); + //const util_tests = b.addTest("src/util/lib.zig"); - const http_tests = b.addTest("src/http/test.zig"); //const sql_tests = b.addTest("src/sql/lib.zig"); - http_tests.addPackage(pkgs.util); + //http_tests.addPackage(pkgs.util); //sql_tests.addPackage(pkgs.util); - const unit_tests = b.step("unit-tests", "Run tests"); - //unit_tests.dependOn(&util_tests.step); - unit_tests.dependOn(&http_tests.step); - //unit_tests.dependOn(&sql_tests.step); + //const unit_tests = b.step("unit-tests", "Run tests"); + const unittest_all = b.step("unit", "Run unit tests"); + unittest_all.dependOn(unittest_http_cmd); + //unittest_all.dependOn(unittest_util_cmd); const api_integration = b.addTest("./tests/api_integration/lib.zig"); api_integration.addPackage(pkgs.opts); From 29a38240d933136d88a57bcd62d9f5d1521a7a5b Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sat, 26 Nov 2022 23:14:19 -0800 Subject: [PATCH 3/9] Fix test --- src/http/middleware.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/http/middleware.zig b/src/http/middleware.zig index e34f671..97855b1 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -357,7 +357,7 @@ fn pathMatches(route: []const u8, path: []const u8) bool { const path_segment = path_iter.next() orelse return false; if (route_segment.len > 0 and route_segment[0] == ':') { // Route Argument - if (path_segment.len == 0) return error.RouteMismatch; + if (path_segment.len == 0) return false; } else { if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false; } From d0e08e4b04570afb3ae550391f9e8e0e35f6859e Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sat, 26 Nov 2022 23:28:00 -0800 Subject: [PATCH 4/9] Fix parsing for enum and boolean values --- src/http/query.zig | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/http/query.zig b/src/http/query.zig index 1933429..35438f5 100644 --- a/src/http/query.zig +++ b/src/http/query.zig @@ -66,10 +66,6 @@ const QueryIter = @import("util").QueryIter; /// Would be used to parse a query string like /// `?foo.baz=12345` /// -/// Compound types cannot currently be nullable, and must be structs. -/// -/// TODO: values are currently case-sensitive, and are not url-decoded properly. -/// This should be fixed. pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct"); var iter = QueryIter.from(query); @@ -88,7 +84,7 @@ pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery; } -fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]const u8 { +fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]u8 { var list = try std.ArrayList(u8).initCapacity(alloc, val.len); errdefer list.deinit(); @@ -266,20 +262,23 @@ fn parseQueryValueNotNull(alloc: std.mem.Allocator, comptime T: type, value: []c if (comptime std.meta.trait.isZigString(T)) return decoded; + defer alloc.free(decoded); + const result = if (comptime std.meta.trait.isIntegral(T)) try std.fmt.parseInt(T, decoded, 0) else if (comptime std.meta.trait.isFloat(T)) try std.fmt.parseFloat(T, decoded) - else if (comptime std.meta.trait.is(.Enum)(T)) - std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue - else if (T == bool) - bool_map.get(value) orelse return error.InvalidBool - else if (comptime std.meta.trait.hasFn("parse")(T)) + else if (comptime std.meta.trait.is(.Enum)(T)) blk: { + _ = std.ascii.lowerString(decoded, decoded); + break :blk std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue; + } else if (T == bool) blk: { + _ = std.ascii.lowerString(decoded, decoded); + break :blk bool_map.get(decoded) orelse return error.InvalidBool; + } else if (comptime std.meta.trait.hasFn("parse")(T)) try T.parse(value) else @compileError("Invalid type " ++ @typeName(T)); - alloc.free(decoded); return result; } @@ -359,7 +358,7 @@ fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytyp } } -test { +test "parseQuery" { const TestQuery = struct { int: usize = 3, boolean: bool = false, @@ -370,11 +369,11 @@ test { .int = 3, .boolean = false, .str_enum = null, - }, try parseQuery(TestQuery, "")); + }, try parseQuery(std.testing.allocator, TestQuery, "")); try std.testing.expectEqual(TestQuery{ .int = 5, .boolean = true, .str_enum = .foo, - }, try parseQuery(TestQuery, "?int=5&boolean=yes&str_enum=foo")); + }, try parseQuery(std.testing.allocator, TestQuery, "?int=5&boolean=yes&str_enum=foo")); } From f7f84f051629384b4c88485c5c34e6864e04ae61 Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sun, 27 Nov 2022 00:21:50 -0800 Subject: [PATCH 5/9] Add util.testing.expectDeepEqual --- src/util/lib.zig | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/util/lib.zig b/src/util/lib.zig index 9a66dc8..acc29f3 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -194,3 +194,44 @@ pub fn seedThreadPrng() !void { prng = std.rand.DefaultPrng.init(@bitCast(u64, buf)); } + +pub const testing = struct { + pub fn expectDeepEqual(expected: anytype, actual: @TypeOf(expected)) !void { + const T = @TypeOf(expected); + switch (@typeInfo(T)) { + .Null, .Void => return, + .Int, .Float, .Bool, .Enum => try std.testing.expectEqual(expected, actual), + .Struct => { + inline for (comptime std.meta.fieldNames(T)) |f| { + try expectDeepEqual(@field(expected, f), @field(actual, f)); + } + }, + .Union => { + inline for (comptime std.meta.fieldNames(T)) |f| { + if (std.meta.isTag(expected, f)) { + try std.testing.expect(std.std.meta.isTag(actual, f)); + try expectDeepEqual(@field(expected, f), @field(actual, f)); + } + } + }, + .Pointer, .Array => { + if (comptime std.meta.trait.isIndexable(T)) { + try std.testing.expectEqual(expected.len, actual.len); + for (expected) |_, i| { + try expectDeepEqual(expected[i], actual[i]); + } + } else if (comptime std.meta.trait.isSingleItemPtr(T)) { + try expectDeepEqual(expected.*, actual.*); + } + }, + .Optional => { + if (expected) |e| { + try expectDeepEqual(e, actual orelse return error.TestExpectedEqual); + } else { + try std.testing.expect(actual == null); + } + }, + else => @compileError("Unsupported Type " ++ @typeName(T)), + } + } +}; From ce40448dc8ae866502163baf78f9d821ddca76ae Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sun, 27 Nov 2022 00:58:56 -0800 Subject: [PATCH 6/9] Revamp QueryString parser test --- src/http/query.zig | 137 ++++++++++++++++++++++++++++++++++----------- src/http/test.zig | 2 + src/util/lib.zig | 7 +++ 3 files changed, 113 insertions(+), 33 deletions(-) diff --git a/src/http/query.zig b/src/http/query.zig index 35438f5..1396696 100644 --- a/src/http/query.zig +++ b/src/http/query.zig @@ -1,6 +1,7 @@ const std = @import("std"); +const util = @import("util"); -const QueryIter = @import("util").QueryIter; +const QueryIter = util.QueryIter; /// Parses a set of query parameters described by the struct `T`. /// @@ -84,6 +85,10 @@ pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery; } +pub fn parseQueryFree(alloc: std.mem.Allocator, val: anytype) void { + util.deepFree(alloc, val); +} + fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]u8 { var list = try std.ArrayList(u8).initCapacity(alloc, val.len); errdefer list.deinit(); @@ -142,6 +147,9 @@ fn parse( .Struct => |info| { var result: T = undefined; var fields_specified: usize = 0; + errdefer inline for (info.fields) |field, i| { + if (fields_specified < i) util.deepFree(alloc, @field(result, field.name)); + }; inline for (info.fields) |field| { const F = field.field_type; @@ -151,7 +159,7 @@ fn parse( maybe_value = v; } else if (field.default_value) |default| { if (comptime @sizeOf(F) != 0) { - maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*; + maybe_value = try util.deepClone(alloc, @ptrCast(*const F, @alignCast(@alignOf(F), default)).*); } else { maybe_value = std.mem.zeroes(F); } @@ -227,10 +235,38 @@ fn Intermediary(comptime T: type) type { } }); } -fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u8) !T { +fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, maybe_value: ?[]const u8) !T { const is_optional = comptime std.meta.trait.is(.Optional)(T); - // If param is present, but without an associated value - if (value == null) { + if (maybe_value) |value| { + const Eff = if (is_optional) std.meta.Child(T) else T; + + if (value.len == 0 and is_optional) return null; + + const decoded = try decodeString(alloc, value); + errdefer alloc.free(decoded); + + if (comptime std.meta.trait.isZigString(Eff)) return decoded; + + defer alloc.free(decoded); + + const result = if (comptime std.meta.trait.isIntegral(Eff)) + try std.fmt.parseInt(Eff, decoded, 0) + else if (comptime std.meta.trait.isFloat(Eff)) + try std.fmt.parseFloat(Eff, decoded) + else if (comptime std.meta.trait.is(.Enum)(Eff)) blk: { + _ = std.ascii.lowerString(decoded, decoded); + break :blk std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue; + } else if (Eff == bool) blk: { + _ = std.ascii.lowerString(decoded, decoded); + break :blk bool_map.get(decoded) orelse return error.InvalidBool; + } else if (comptime std.meta.trait.hasFn("parse")(Eff)) + try Eff.parse(value) + else + @compileError("Invalid type " ++ @typeName(T)); + + return result; + } else { + // If param is present, but without an associated value return if (is_optional) null else if (T == bool) @@ -238,8 +274,6 @@ fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u else error.InvalidValue; } - - return try parseQueryValueNotNull(alloc, if (is_optional) std.meta.Child(T) else T, value.?); } const bool_map = std.ComptimeStringMap(bool, .{ @@ -256,32 +290,6 @@ const bool_map = std.ComptimeStringMap(bool, .{ .{ "0", false }, }); -fn parseQueryValueNotNull(alloc: std.mem.Allocator, comptime T: type, value: []const u8) !T { - const decoded = try decodeString(alloc, value); - errdefer alloc.free(decoded); - - if (comptime std.meta.trait.isZigString(T)) return decoded; - - defer alloc.free(decoded); - - const result = if (comptime std.meta.trait.isIntegral(T)) - try std.fmt.parseInt(T, decoded, 0) - else if (comptime std.meta.trait.isFloat(T)) - try std.fmt.parseFloat(T, decoded) - else if (comptime std.meta.trait.is(.Enum)(T)) blk: { - _ = std.ascii.lowerString(decoded, decoded); - break :blk std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue; - } else if (T == bool) blk: { - _ = std.ascii.lowerString(decoded, decoded); - break :blk bool_map.get(decoded) orelse return error.InvalidBool; - } else if (comptime std.meta.trait.hasFn("parse")(T)) - try T.parse(value) - else - @compileError("Invalid type " ++ @typeName(T)); - - return result; -} - fn isScalar(comptime T: type) bool { if (comptime std.meta.trait.isZigString(T)) return true; if (comptime std.meta.trait.isIntegral(T)) return true; @@ -359,6 +367,69 @@ fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytyp } test "parseQuery" { + const testCase = struct { + fn case(comptime T: type, expected: T, query_string: []const u8) !void { + const result = try parseQuery(std.testing.allocator, T, query_string); + defer parseQueryFree(std.testing.allocator, result); + try util.testing.expectDeepEqual(expected, result); + } + }.case; + + try testCase(struct { int: usize = 3 }, .{ .int = 3 }, ""); + try testCase(struct { int: usize = 3 }, .{ .int = 2 }, "int=2"); + try testCase(struct { int: usize = 3 }, .{ .int = 2 }, "int=2&"); + try testCase(struct { boolean: bool = false }, .{ .boolean = false }, ""); + try testCase(struct { boolean: bool = false }, .{ .boolean = true }, "boolean"); + try testCase(struct { boolean: bool = false }, .{ .boolean = true }, "boolean=true"); + try testCase(struct { boolean: bool = false }, .{ .boolean = true }, "boolean=y"); + try testCase(struct { boolean: bool = false }, .{ .boolean = false }, "boolean=f"); + try testCase(struct { boolean: bool = false }, .{ .boolean = false }, "boolean=no"); + try testCase(struct { str_enum: ?enum { foo, bar } = null }, .{ .str_enum = null }, ""); + try testCase(struct { str_enum: ?enum { foo, bar } = null }, .{ .str_enum = .foo }, "str_enum=foo"); + try testCase(struct { str_enum: ?enum { foo, bar } = null }, .{ .str_enum = .bar }, "str_enum=bar"); + try testCase(struct { str_enum: ?enum { foo, bar } = .foo }, .{ .str_enum = .foo }, ""); + try testCase(struct { str_enum: ?enum { foo, bar } = .foo }, .{ .str_enum = null }, "str_enum"); + try testCase(struct { n1: usize = 5, n2: usize = 5 }, .{ .n1 = 1, .n2 = 2 }, "n1=1&n2=2"); + try testCase(struct { n1: usize = 5, n2: usize = 5 }, .{ .n1 = 1, .n2 = 2 }, "n1=1&n2=2&"); + try testCase(struct { n1: usize = 5, n2: usize = 5 }, .{ .n1 = 1, .n2 = 2 }, "n1=1&&n2=2&"); + + try testCase(struct { str: ?[]const u8 = null }, .{ .str = null }, ""); + try testCase(struct { str: ?[]const u8 = null }, .{ .str = null }, "str"); + try testCase(struct { str: ?[]const u8 = null }, .{ .str = null }, "str="); + try testCase(struct { str: ?[]const u8 = null }, .{ .str = "foo" }, "str=foo"); + try testCase(struct { str: ?[]const u8 = "foo" }, .{ .str = "foo" }, "str=foo"); + try testCase(struct { str: ?[]const u8 = "foo" }, .{ .str = "foo" }, ""); + try testCase(struct { str: ?[]const u8 = "foo" }, .{ .str = null }, "str"); + try testCase(struct { str: ?[]const u8 = "foo" }, .{ .str = null }, "str="); + + const rand_uuid = comptime util.Uuid.parse("c1fb6578-4d0c-4eb9-9f67-d56da3ae6f5d") catch unreachable; + try testCase(struct { id: ?util.Uuid = null }, .{ .id = null }, ""); + try testCase(struct { id: ?util.Uuid = null }, .{ .id = null }, "id="); + try testCase(struct { id: ?util.Uuid = null }, .{ .id = null }, "id"); + try testCase(struct { id: ?util.Uuid = null }, .{ .id = rand_uuid }, "id=" ++ rand_uuid.toCharArray()); + try testCase(struct { id: ?util.Uuid = rand_uuid }, .{ .id = rand_uuid }, ""); + try testCase(struct { id: ?util.Uuid = rand_uuid }, .{ .id = null }, "id="); + try testCase(struct { id: ?util.Uuid = rand_uuid }, .{ .id = null }, "id"); + try testCase(struct { id: ?util.Uuid = rand_uuid }, .{ .id = rand_uuid }, "id=" ++ rand_uuid.toCharArray()); + + const SubStruct = struct { + sub: struct { + foo: usize = 1, + bar: usize = 2, + } = .{}, + }; + try testCase(SubStruct, .{ .sub = .{ .foo = 1, .bar = 2 } }, ""); + try testCase(SubStruct, .{ .sub = .{ .foo = 3, .bar = 3 } }, "sub.foo=3&sub.bar=3"); + try testCase(SubStruct, .{ .sub = .{ .foo = 3, .bar = 2 } }, "sub.foo=3"); + + // TODO: Semantics are ill-defined here + // const SubStruct2 = struct { + // sub: ?struct { + // foo: usize = 1, + // } = null, + // }; + // try testCase(SubStruct2, .{ .sub = null }, ""); + const TestQuery = struct { int: usize = 3, boolean: bool = false, diff --git a/src/http/test.zig b/src/http/test.zig index 1441ec2..c142f68 100644 --- a/src/http/test.zig +++ b/src/http/test.zig @@ -1,3 +1,5 @@ test { _ = @import("./request/test_parser.zig"); + _ = @import("./middleware.zig"); + _ = @import("./query.zig"); } diff --git a/src/util/lib.zig b/src/util/lib.zig index acc29f3..84b2122 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -160,6 +160,13 @@ pub fn deepClone(alloc: std.mem.Allocator, val: anytype) !@TypeOf(val) { count += 1; } }, + .Union => { + inline for (comptime std.meta.fieldNames(T)) |f| { + if (std.meta.isTag(val, f)) { + return @unionInit(T, f, try deepClone(alloc, @field(val, f))); + } + } else unreachable; + }, .Array => { var count: usize = 0; errdefer for (result[0..count]) |v| deepFree(alloc, v); From abf31ea33c4bdf69d0134c3e12a919d38a5a67ff Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sun, 27 Nov 2022 01:07:45 -0800 Subject: [PATCH 7/9] Flesh out test cases for unions --- src/http/query.zig | 41 ++++++++++++++++++++++++----------------- src/util/lib.zig | 2 +- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/http/query.zig b/src/http/query.zig index 1396696..7bbdcf1 100644 --- a/src/http/query.zig +++ b/src/http/query.zig @@ -422,29 +422,36 @@ test "parseQuery" { try testCase(SubStruct, .{ .sub = .{ .foo = 3, .bar = 3 } }, "sub.foo=3&sub.bar=3"); try testCase(SubStruct, .{ .sub = .{ .foo = 3, .bar = 2 } }, "sub.foo=3"); - // TODO: Semantics are ill-defined here + // TODO: Semantics are ill-defined here. What happens if the substruct doesn't have + // default values? // const SubStruct2 = struct { // sub: ?struct { // foo: usize = 1, // } = null, // }; // try testCase(SubStruct2, .{ .sub = null }, ""); + // try testCase(SubStruct2, .{ .sub = null }, "sub="); - const TestQuery = struct { - int: usize = 3, - boolean: bool = false, - str_enum: ?enum { foo, bar } = null, + // TODO: also here (semantics are well defined it just breaks tests) + // const SubUnion = struct { + // sub: ?union(enum) { + // foo: usize, + // bar: usize, + // } = null, + // }; + // try testCase(SubUnion, .{ .sub = null }, ""); + // try testCase(SubUnion, .{ .sub = null }, "sub="); + + const SubUnion2 = struct { + sub: ?struct { + foo: usize, + val: union(enum) { + bar: []const u8, + baz: []const u8, + }, + } = null, }; - - try std.testing.expectEqual(TestQuery{ - .int = 3, - .boolean = false, - .str_enum = null, - }, try parseQuery(std.testing.allocator, TestQuery, "")); - - try std.testing.expectEqual(TestQuery{ - .int = 5, - .boolean = true, - .str_enum = .foo, - }, try parseQuery(std.testing.allocator, TestQuery, "?int=5&boolean=yes&str_enum=foo")); + try testCase(SubUnion2, .{ .sub = null }, ""); + try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .bar = "abc" } } }, "sub.foo=1&sub.bar=abc"); + try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .baz = "abc" } } }, "sub.foo=1&sub.baz=abc"); } diff --git a/src/util/lib.zig b/src/util/lib.zig index 84b2122..9829ae3 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -216,7 +216,7 @@ pub const testing = struct { .Union => { inline for (comptime std.meta.fieldNames(T)) |f| { if (std.meta.isTag(expected, f)) { - try std.testing.expect(std.std.meta.isTag(actual, f)); + try std.testing.expect(std.meta.isTag(actual, f)); try expectDeepEqual(@field(expected, f), @field(actual, f)); } } From b2c69c2df826efcab73402aed2e3a956f08d366e Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sun, 27 Nov 2022 01:47:21 -0800 Subject: [PATCH 8/9] queryStringify cleanup --- src/http/lib.zig | 1 + src/http/query.zig | 35 +- src/main/controllers.zig | 14 +- src/main/json.zig | 677 --------------------------------------- src/main/query.zig | 380 ---------------------- 5 files changed, 31 insertions(+), 1076 deletions(-) delete mode 100644 src/main/json.zig delete mode 100644 src/main/query.zig diff --git a/src/http/lib.zig b/src/http/lib.zig index 9a4d8c9..9fc5a61 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -16,6 +16,7 @@ pub const Handler = server.Handler; pub const Server = server.Server; pub const middleware = @import("./middleware.zig"); +pub const queryStringify = @import("./query.zig").queryStringify; pub const Fields = @import("./headers.zig").Fields; diff --git a/src/http/query.zig b/src/http/query.zig index 7bbdcf1..36b5d33 100644 --- a/src/http/query.zig +++ b/src/http/query.zig @@ -295,6 +295,7 @@ fn isScalar(comptime T: type) bool { if (comptime std.meta.trait.isIntegral(T)) return true; if (comptime std.meta.trait.isFloat(T)) return true; if (comptime std.meta.trait.is(.Enum)(T)) return true; + if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true; if (T == bool) return true; if (comptime std.meta.trait.hasFn("parse")(T)) return true; @@ -303,8 +304,16 @@ fn isScalar(comptime T: type) bool { return false; } -pub fn formatQuery(params: anytype, writer: anytype) !void { - try format("", "", params, writer); +pub fn QueryStringify(comptime Params: type) type { + return struct { + params: Params, + pub fn format(v: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + try formatQuery("", "", v.params, writer); + } + }; +} +pub fn queryStringify(val: anytype) QueryStringify(@TypeOf(val)) { + return QueryStringify(@TypeOf(val)){ .params = val }; } fn urlFormatString(writer: anytype, val: []const u8) !void { @@ -330,14 +339,14 @@ fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void if (comptime std.meta.trait.isZigString(T)) { try urlFormatString(writer, val); } else try switch (@typeInfo(T)) { - .Enum => urlFormatString(writer, @tagName(val)), + .EnumLiteral, .Enum => urlFormatString(writer, @tagName(val)), else => std.fmt.format(writer, "{}", .{val}), }; try writer.writeByte('&'); } -fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void { +fn formatQuery(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void { const T = @TypeOf(params); const eff_prefix = if (prefix.len == 0) "" else prefix ++ "."; if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer); @@ -346,7 +355,7 @@ fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytyp .Struct => { inline for (std.meta.fields(T)) |field| { const val = @field(params, field.name); - try format(eff_prefix ++ name, field.name, val, writer); + try formatQuery(eff_prefix ++ name, field.name, val, writer); } }, .Union => { @@ -355,12 +364,12 @@ fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytyp const tag_name = field.name; if (@as(std.meta.Tag(T), params) == tag) { const val = @field(params, tag_name); - try format(prefix, tag_name, val, writer); + try formatQuery(prefix, tag_name, val, writer); } } }, .Optional => { - if (params) |p| try format(prefix, name, p, writer); + if (params) |p| try formatQuery(prefix, name, p, writer); }, else => @compileError("Unsupported query type"), } @@ -455,3 +464,15 @@ test "parseQuery" { try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .bar = "abc" } } }, "sub.foo=1&sub.bar=abc"); try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .baz = "abc" } } }, "sub.foo=1&sub.baz=abc"); } + +test "formatQuery" { + try std.testing.expectFmt("", "{}", .{queryStringify(.{})}); + try std.testing.expectFmt("id=3&", "{}", .{queryStringify(.{ .id = 3 })}); + try std.testing.expectFmt("id=3&id2=4&", "{}", .{queryStringify(.{ .id = 3, .id2 = 4 })}); + + try std.testing.expectFmt("str=foo&", "{}", .{queryStringify(.{ .str = "foo" })}); + try std.testing.expectFmt("enum_str=foo&", "{}", .{queryStringify(.{ .enum_str = .foo })}); + + try std.testing.expectFmt("boolean=false&", "{}", .{queryStringify(.{ .boolean = false })}); + try std.testing.expectFmt("boolean=true&", "{}", .{queryStringify(.{ .boolean = true })}); +} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 3263cfa..ea7fa45 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -4,8 +4,6 @@ const builtin = @import("builtin"); const http = @import("http"); const api = @import("api"); const util = @import("util"); -const query_utils = @import("./query.zig"); -const json_utils = @import("./json.zig"); const web_endpoints = @import("./controllers/web.zig").routes; const api_endpoints = @import("./controllers/api.zig").routes; @@ -268,16 +266,8 @@ pub const helpers = struct { // TODO: percent-encode try std.fmt.format( writer, - "<{s}://{s}/{s}?", - .{ @tagName(community.scheme), community.host, path }, - ); - - try query_utils.formatQuery(params, writer); - - try std.fmt.format( - writer, - ">; rel=\"{s}\"", - .{rel}, + "<{s}://{s}/{s}?{}>; rel=\"{s}\"", + .{ @tagName(community.scheme), community.host, path, http.queryStringify(params), rel }, ); } }; diff --git a/src/main/json.zig b/src/main/json.zig deleted file mode 100644 index 21474cc..0000000 --- a/src/main/json.zig +++ /dev/null @@ -1,677 +0,0 @@ -const std = @import("std"); -const mem = std.mem; -const Allocator = std.mem.Allocator; -const assert = std.debug.assert; - -// This file is largely a copy of std.json - -const StreamingParser = std.json.StreamingParser; -const Token = std.json.Token; -const unescapeValidString = std.json.unescapeValidString; -const UnescapeValidStringError = std.json.UnescapeValidStringError; - -pub fn parse(comptime T: type, body: []const u8, alloc: std.mem.Allocator) !T { - var tokens = TokenStream.init(body); - - const options = ParseOptions{ .allocator = alloc }; - - const token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - const r = try parseInternal(T, token, &tokens, options); - errdefer parseFreeInternal(T, r, options); - if (!options.allow_trailing_data) { - if ((try tokens.next()) != null) unreachable; - assert(tokens.i >= tokens.slice.len); - } - return r; -} - -pub fn parseFree(value: anytype, alloc: std.mem.Allocator) void { - parseFreeInternal(@TypeOf(value), value, .{ .allocator = alloc }); -} - -// WARNING: the objects "parse" method must not contain a reference to the original value -fn hasCustomParse(comptime T: type) bool { - if (!std.meta.trait.hasFn("parse")(T)) return false; - if (!@hasDecl(T, "JsonParseAs")) return false; - - return true; -} - -///// The rest is (modified) from std.json - -/// A small wrapper over a StreamingParser for full slices. Returns a stream of json Tokens. -pub const TokenStream = struct { - i: usize, - slice: []const u8, - parser: StreamingParser, - token: ?Token, - - pub const Error = StreamingParser.Error || error{UnexpectedEndOfJson}; - - pub fn init(slice: []const u8) TokenStream { - return TokenStream{ - .i = 0, - .slice = slice, - .parser = StreamingParser.init(), - .token = null, - }; - } - - fn stackUsed(self: *TokenStream) usize { - return self.parser.stack.len + if (self.token != null) @as(usize, 1) else 0; - } - - pub fn next(self: *TokenStream) Error!?Token { - if (self.token) |token| { - self.token = null; - return token; - } - - var t1: ?Token = undefined; - var t2: ?Token = undefined; - - while (self.i < self.slice.len) { - try self.parser.feed(self.slice[self.i], &t1, &t2); - self.i += 1; - - if (t1) |token| { - self.token = t2; - return token; - } - } - - // Without this a bare number fails, the streaming parser doesn't know the input ended - try self.parser.feed(' ', &t1, &t2); - self.i += 1; - - if (t1) |token| { - return token; - } else if (self.parser.complete) { - return null; - } else { - return error.UnexpectedEndOfJson; - } - } -}; - -/// Checks to see if a string matches what it would be as a json-encoded string -/// Assumes that `encoded` is a well-formed json string -fn encodesTo(decoded: []const u8, encoded: []const u8) bool { - var i: usize = 0; - var j: usize = 0; - while (i < decoded.len) { - if (j >= encoded.len) return false; - if (encoded[j] != '\\') { - if (decoded[i] != encoded[j]) return false; - j += 1; - i += 1; - } else { - const escape_type = encoded[j + 1]; - if (escape_type != 'u') { - const t: u8 = switch (escape_type) { - '\\' => '\\', - '/' => '/', - 'n' => '\n', - 'r' => '\r', - 't' => '\t', - 'f' => 12, - 'b' => 8, - '"' => '"', - else => unreachable, - }; - if (decoded[i] != t) return false; - j += 2; - i += 1; - } else { - var codepoint = std.fmt.parseInt(u21, encoded[j + 2 .. j + 6], 16) catch unreachable; - j += 6; - if (codepoint >= 0xD800 and codepoint < 0xDC00) { - // surrogate pair - assert(encoded[j] == '\\'); - assert(encoded[j + 1] == 'u'); - const low_surrogate = std.fmt.parseInt(u21, encoded[j + 2 .. j + 6], 16) catch unreachable; - codepoint = 0x10000 + (((codepoint & 0x03ff) << 10) | (low_surrogate & 0x03ff)); - j += 6; - } - var buf: [4]u8 = undefined; - const len = std.unicode.utf8Encode(codepoint, &buf) catch unreachable; - if (i + len > decoded.len) return false; - if (!mem.eql(u8, decoded[i .. i + len], buf[0..len])) return false; - i += len; - } - } - } - assert(i == decoded.len); - assert(j == encoded.len); - return true; -} - -/// parse tokens from a stream, returning `false` if they do not decode to `value` -fn parsesTo(comptime T: type, value: T, tokens: *TokenStream, options: ParseOptions) !bool { - // TODO: should be able to write this function to not require an allocator - const tmp = try parse(T, tokens, options); - defer parseFree(T, tmp, options); - - return parsedEqual(tmp, value); -} - -/// Returns if a value returned by `parse` is deep-equal to another value -fn parsedEqual(a: anytype, b: @TypeOf(a)) bool { - switch (@typeInfo(@TypeOf(a))) { - .Optional => { - if (a == null and b == null) return true; - if (a == null or b == null) return false; - return parsedEqual(a.?, b.?); - }, - .Union => |info| { - if (info.tag_type) |UnionTag| { - const tag_a = std.meta.activeTag(a); - const tag_b = std.meta.activeTag(b); - if (tag_a != tag_b) return false; - - inline for (info.fields) |field_info| { - if (@field(UnionTag, field_info.name) == tag_a) { - return parsedEqual(@field(a, field_info.name), @field(b, field_info.name)); - } - } - return false; - } else { - unreachable; - } - }, - .Array => { - for (a) |e, i| - if (!parsedEqual(e, b[i])) return false; - return true; - }, - .Struct => |info| { - inline for (info.fields) |field_info| { - if (!parsedEqual(@field(a, field_info.name), @field(b, field_info.name))) return false; - } - return true; - }, - .Pointer => |ptrInfo| switch (ptrInfo.size) { - .One => return parsedEqual(a.*, b.*), - .Slice => { - if (a.len != b.len) return false; - for (a) |e, i| - if (!parsedEqual(e, b[i])) return false; - return true; - }, - .Many, .C => unreachable, - }, - else => return a == b, - } - unreachable; -} - -const ParseOptions = struct { - allocator: ?Allocator = null, - - /// Behaviour when a duplicate field is encountered. - duplicate_field_behavior: enum { - UseFirst, - Error, - UseLast, - } = .Error, - - /// If false, finding an unknown field returns an error. - ignore_unknown_fields: bool = false, - - allow_trailing_data: bool = false, -}; - -const SkipValueError = error{UnexpectedJsonDepth} || TokenStream.Error; - -fn skipValue(tokens: *TokenStream) SkipValueError!void { - const original_depth = tokens.stackUsed(); - - // Return an error if no value is found - _ = try tokens.next(); - if (tokens.stackUsed() < original_depth) return error.UnexpectedJsonDepth; - if (tokens.stackUsed() == original_depth) return; - - while (try tokens.next()) |_| { - if (tokens.stackUsed() == original_depth) return; - } -} - -fn ParseInternalError(comptime T: type) type { - // `inferred_types` is used to avoid infinite recursion for recursive type definitions. - const inferred_types = [_]type{}; - return ParseInternalErrorImpl(T, &inferred_types); -} - -fn ParseInternalErrorImpl(comptime T: type, comptime inferred_types: []const type) type { - if (hasCustomParse(T)) { - return ParseInternalError(T.JsonParseAs) || T.ParseError; - } - for (inferred_types) |ty| { - if (T == ty) return error{}; - } - - switch (@typeInfo(T)) { - .Bool => return error{UnexpectedToken}, - .Float, .ComptimeFloat => return error{UnexpectedToken} || std.fmt.ParseFloatError, - .Int, .ComptimeInt => { - return error{ UnexpectedToken, InvalidNumber, Overflow } || - std.fmt.ParseIntError || std.fmt.ParseFloatError; - }, - .Optional => |optionalInfo| { - return ParseInternalErrorImpl(optionalInfo.child, inferred_types ++ [_]type{T}); - }, - .Enum => return error{ UnexpectedToken, InvalidEnumTag } || std.fmt.ParseIntError || - std.meta.IntToEnumError || std.meta.IntToEnumError, - .Union => |unionInfo| { - if (unionInfo.tag_type) |_| { - var errors = error{NoUnionMembersMatched}; - for (unionInfo.fields) |u_field| { - errors = errors || ParseInternalErrorImpl(u_field.field_type, inferred_types ++ [_]type{T}); - } - return errors; - } else { - @compileError("Unable to parse into untagged union '" ++ @typeName(T) ++ "'"); - } - }, - .Struct => |structInfo| { - var errors = error{ - DuplicateJSONField, - UnexpectedEndOfJson, - UnexpectedToken, - UnexpectedValue, - UnknownField, - MissingField, - } || SkipValueError || TokenStream.Error; - for (structInfo.fields) |field| { - errors = errors || ParseInternalErrorImpl(field.field_type, inferred_types ++ [_]type{T}); - } - return errors; - }, - .Array => |arrayInfo| { - return error{ UnexpectedEndOfJson, UnexpectedToken } || TokenStream.Error || - UnescapeValidStringError || - ParseInternalErrorImpl(arrayInfo.child, inferred_types ++ [_]type{T}); - }, - .Pointer => |ptrInfo| { - var errors = error{AllocatorRequired} || std.mem.Allocator.Error; - switch (ptrInfo.size) { - .One => { - return errors || ParseInternalErrorImpl(ptrInfo.child, inferred_types ++ [_]type{T}); - }, - .Slice => { - return errors || error{ UnexpectedEndOfJson, UnexpectedToken } || - ParseInternalErrorImpl(ptrInfo.child, inferred_types ++ [_]type{T}) || - UnescapeValidStringError || TokenStream.Error; - }, - else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"), - } - }, - else => return error{}, - } - unreachable; -} - -fn parseInternal( - comptime T: type, - token: Token, - tokens: *TokenStream, - options: ParseOptions, -) ParseInternalError(T)!T { - if (comptime hasCustomParse(T)) { - const val = try parseInternal(T.JsonParseAs, token, tokens, options); - defer parseFreeInternal(T.JsonParseAs, val, options); - return try T.parse(val); - } - - switch (@typeInfo(T)) { - .Bool => { - return switch (token) { - .True => true, - .False => false, - else => error.UnexpectedToken, - }; - }, - .Float, .ComptimeFloat => { - switch (token) { - .Number => |numberToken| return try std.fmt.parseFloat(T, numberToken.slice(tokens.slice, tokens.i - 1)), - .String => |stringToken| return try std.fmt.parseFloat(T, stringToken.slice(tokens.slice, tokens.i - 1)), - else => return error.UnexpectedToken, - } - }, - .Int, .ComptimeInt => { - switch (token) { - .Number => |numberToken| { - if (numberToken.is_integer) - return try std.fmt.parseInt(T, numberToken.slice(tokens.slice, tokens.i - 1), 10); - const float = try std.fmt.parseFloat(f128, numberToken.slice(tokens.slice, tokens.i - 1)); - if (@round(float) != float) return error.InvalidNumber; - if (float > std.math.maxInt(T) or float < std.math.minInt(T)) return error.Overflow; - return @floatToInt(T, float); - }, - .String => |stringToken| { - return std.fmt.parseInt(T, stringToken.slice(tokens.slice, tokens.i - 1), 10) catch |err| { - switch (err) { - error.Overflow => return err, - error.InvalidCharacter => { - const float = try std.fmt.parseFloat(f128, stringToken.slice(tokens.slice, tokens.i - 1)); - if (@round(float) != float) return error.InvalidNumber; - if (float > std.math.maxInt(T) or float < std.math.minInt(T)) return error.Overflow; - return @floatToInt(T, float); - }, - } - }; - }, - else => return error.UnexpectedToken, - } - }, - .Optional => |optionalInfo| { - if (token == .Null) { - return null; - } else { - return try parseInternal(optionalInfo.child, token, tokens, options); - } - }, - .Enum => |enumInfo| { - switch (token) { - .Number => |numberToken| { - if (!numberToken.is_integer) return error.UnexpectedToken; - const n = try std.fmt.parseInt(enumInfo.tag_type, numberToken.slice(tokens.slice, tokens.i - 1), 10); - return try std.meta.intToEnum(T, n); - }, - .String => |stringToken| { - const source_slice = stringToken.slice(tokens.slice, tokens.i - 1); - switch (stringToken.escapes) { - .None => return std.meta.stringToEnum(T, source_slice) orelse return error.InvalidEnumTag, - .Some => { - inline for (enumInfo.fields) |field| { - if (field.name.len == stringToken.decodedLength() and encodesTo(field.name, source_slice)) { - return @field(T, field.name); - } - } - return error.InvalidEnumTag; - }, - } - }, - else => return error.UnexpectedToken, - } - }, - .Union => |unionInfo| { - if (unionInfo.tag_type) |_| { - // try each of the union fields until we find one that matches - inline for (unionInfo.fields) |u_field| { - // take a copy of tokens so we can withhold mutations until success - var tokens_copy = tokens.*; - if (parseInternal(u_field.field_type, token, &tokens_copy, options)) |value| { - tokens.* = tokens_copy; - return @unionInit(T, u_field.name, value); - } else |err| { - // Bubble up error.OutOfMemory - // Parsing some types won't have OutOfMemory in their - // error-sets, for the condition to be valid, merge it in. - if (@as(@TypeOf(err) || error{OutOfMemory}, err) == error.OutOfMemory) return err; - // Bubble up AllocatorRequired, as it indicates missing option - if (@as(@TypeOf(err) || error{AllocatorRequired}, err) == error.AllocatorRequired) return err; - // otherwise continue through the `inline for` - } - } - return error.NoUnionMembersMatched; - } else { - @compileError("Unable to parse into untagged union '" ++ @typeName(T) ++ "'"); - } - }, - .Struct => |structInfo| { - switch (token) { - .ObjectBegin => {}, - else => return error.UnexpectedToken, - } - var r: T = undefined; - var fields_seen = [_]bool{false} ** structInfo.fields.len; - errdefer { - inline for (structInfo.fields) |field, i| { - if (fields_seen[i] and !field.is_comptime) { - parseFreeInternal(field.field_type, @field(r, field.name), options); - } - } - } - - while (true) { - switch ((try tokens.next()) orelse return error.UnexpectedEndOfJson) { - .ObjectEnd => break, - .String => |stringToken| { - const key_source_slice = stringToken.slice(tokens.slice, tokens.i - 1); - var child_options = options; - child_options.allow_trailing_data = true; - var found = false; - inline for (structInfo.fields) |field, i| { - // TODO: using switches here segfault the compiler (#2727?) - if ((stringToken.escapes == .None and mem.eql(u8, field.name, key_source_slice)) or (stringToken.escapes == .Some and (field.name.len == stringToken.decodedLength() and encodesTo(field.name, key_source_slice)))) { - // if (switch (stringToken.escapes) { - // .None => mem.eql(u8, field.name, key_source_slice), - // .Some => (field.name.len == stringToken.decodedLength() and encodesTo(field.name, key_source_slice)), - // }) { - if (fields_seen[i]) { - // switch (options.duplicate_field_behavior) { - // .UseFirst => {}, - // .Error => {}, - // .UseLast => {}, - // } - if (options.duplicate_field_behavior == .UseFirst) { - // unconditonally ignore value. for comptime fields, this skips check against default_value - const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - parseFreeInternal(field.field_type, try parseInternal(field.field_type, next_token, tokens, child_options), child_options); - found = true; - break; - } else if (options.duplicate_field_behavior == .Error) { - return error.DuplicateJSONField; - } else if (options.duplicate_field_behavior == .UseLast) { - if (!field.is_comptime) { - parseFreeInternal(field.field_type, @field(r, field.name), child_options); - } - fields_seen[i] = false; - } - } - if (field.is_comptime) { - if (!try parsesTo(field.field_type, @ptrCast(*const field.field_type, field.default_value.?).*, tokens, child_options)) { - return error.UnexpectedValue; - } - } else { - const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - @field(r, field.name) = try parseInternal(field.field_type, next_token, tokens, child_options); - } - fields_seen[i] = true; - found = true; - break; - } - } - if (!found) { - if (options.ignore_unknown_fields) { - try skipValue(tokens); - continue; - } else { - return error.UnknownField; - } - } - }, - else => return error.UnexpectedToken, - } - } - inline for (structInfo.fields) |field, i| { - if (!fields_seen[i]) { - if (field.default_value) |default_ptr| { - if (!field.is_comptime) { - const default = @ptrCast(*align(1) const field.field_type, default_ptr).*; - @field(r, field.name) = default; - } - } else { - return error.MissingField; - } - } - } - return r; - }, - .Array => |arrayInfo| { - switch (token) { - .ArrayBegin => { - var r: T = undefined; - var i: usize = 0; - var child_options = options; - child_options.allow_trailing_data = true; - errdefer { - // Without the r.len check `r[i]` is not allowed - if (r.len > 0) while (true) : (i -= 1) { - parseFreeInternal(arrayInfo.child, r[i], options); - if (i == 0) break; - }; - } - while (i < r.len) : (i += 1) { - const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - r[i] = try parseInternal(arrayInfo.child, next_token, tokens, child_options); - } - const tok = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - switch (tok) { - .ArrayEnd => {}, - else => return error.UnexpectedToken, - } - return r; - }, - .String => |stringToken| { - if (arrayInfo.child != u8) return error.UnexpectedToken; - var r: T = undefined; - const source_slice = stringToken.slice(tokens.slice, tokens.i - 1); - switch (stringToken.escapes) { - .None => mem.copy(u8, &r, source_slice), - .Some => try unescapeValidString(&r, source_slice), - } - return r; - }, - else => return error.UnexpectedToken, - } - }, - .Pointer => |ptrInfo| { - const allocator = options.allocator orelse return error.AllocatorRequired; - switch (ptrInfo.size) { - .One => { - const r: T = try allocator.create(ptrInfo.child); - errdefer allocator.destroy(r); - r.* = try parseInternal(ptrInfo.child, token, tokens, options); - return r; - }, - .Slice => { - switch (token) { - .ArrayBegin => { - var arraylist = std.ArrayList(ptrInfo.child).init(allocator); - errdefer { - while (arraylist.popOrNull()) |v| { - parseFreeInternal(ptrInfo.child, v, options); - } - arraylist.deinit(); - } - - while (true) { - const tok = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - switch (tok) { - .ArrayEnd => break, - else => {}, - } - - try arraylist.ensureUnusedCapacity(1); - const v = try parseInternal(ptrInfo.child, tok, tokens, options); - arraylist.appendAssumeCapacity(v); - } - - if (ptrInfo.sentinel) |some| { - const sentinel_value = @ptrCast(*const ptrInfo.child, some).*; - try arraylist.append(sentinel_value); - const output = arraylist.toOwnedSlice(); - return output[0 .. output.len - 1 :sentinel_value]; - } - - return arraylist.toOwnedSlice(); - }, - .String => |stringToken| { - if (ptrInfo.child != u8) return error.UnexpectedToken; - const source_slice = stringToken.slice(tokens.slice, tokens.i - 1); - const len = stringToken.decodedLength(); - const output = try allocator.alloc(u8, len + @boolToInt(ptrInfo.sentinel != null)); - errdefer allocator.free(output); - switch (stringToken.escapes) { - .None => mem.copy(u8, output, source_slice), - .Some => try unescapeValidString(output, source_slice), - } - - if (ptrInfo.sentinel) |some| { - const char = @ptrCast(*const u8, some).*; - output[len] = char; - return output[0..len :char]; - } - - return output; - }, - else => return error.UnexpectedToken, - } - }, - else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"), - } - }, - else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"), - } - unreachable; -} - -fn ParseError(comptime T: type) type { - return ParseInternalError(T) || error{UnexpectedEndOfJson} || TokenStream.Error; -} - -/// Releases resources created by `parse`. -/// Should be called with the same type and `ParseOptions` that were passed to `parse` -fn parseFreeInternal(comptime T: type, value: T, options: ParseOptions) void { - switch (@typeInfo(T)) { - .Bool, .Float, .ComptimeFloat, .Int, .ComptimeInt, .Enum => {}, - .Optional => { - if (value) |v| { - return parseFreeInternal(@TypeOf(v), v, options); - } - }, - .Union => |unionInfo| { - if (unionInfo.tag_type) |UnionTagType| { - inline for (unionInfo.fields) |u_field| { - if (value == @field(UnionTagType, u_field.name)) { - parseFreeInternal(u_field.field_type, @field(value, u_field.name), options); - break; - } - } - } else { - unreachable; - } - }, - .Struct => |structInfo| { - inline for (structInfo.fields) |field| { - if (!field.is_comptime) { - parseFreeInternal(field.field_type, @field(value, field.name), options); - } - } - }, - .Array => |arrayInfo| { - for (value) |v| { - parseFreeInternal(arrayInfo.child, v, options); - } - }, - .Pointer => |ptrInfo| { - const allocator = options.allocator orelse unreachable; - switch (ptrInfo.size) { - .One => { - parseFreeInternal(ptrInfo.child, value.*, options); - allocator.destroy(value); - }, - .Slice => { - for (value) |v| { - parseFreeInternal(ptrInfo.child, v, options); - } - allocator.free(value); - }, - else => unreachable, - } - }, - else => unreachable, - } -} diff --git a/src/main/query.zig b/src/main/query.zig deleted file mode 100644 index 1933429..0000000 --- a/src/main/query.zig +++ /dev/null @@ -1,380 +0,0 @@ -const std = @import("std"); - -const QueryIter = @import("util").QueryIter; - -/// Parses a set of query parameters described by the struct `T`. -/// -/// To specify query parameters, provide a struct similar to the following: -/// ``` -/// struct { -/// foo: bool = false, -/// bar: ?[]const u8 = null, -/// baz: usize = 10, -/// qux: enum { quux, snap } = .quux, -/// } -/// ``` -/// -/// This will allow it to parse a query string like the following: -/// `?foo&bar=abc&qux=snap` -/// -/// Every parameter must have a default value that will be used when the -/// parameter is not provided, and parameter keys. -/// Numbers are parsed from their string representations, and a parameter -/// provided in the query string without a value is parsed either as a bool -/// `true` flag or as `null` depending on the type of its param. -/// -/// Parameter types supported: -/// - []const u8 -/// - numbers (both integer and float) -/// + Numbers are parsed in base 10 -/// - bool -/// + See below for detals -/// - exhaustive enums -/// + Enums are treated as strings with values equal to the enum fields -/// - ?F (where isScalar(F) and F != bool) -/// - Any type that implements: -/// + pub fn parse([]const u8) !F -/// -/// Boolean Parameters: -/// The following query strings will all parse a `true` value for the -/// parameter `foo: bool = false`: -/// - `?foo` -/// - `?foo=true` -/// - `?foo=t` -/// - `?foo=yes` -/// - `?foo=y` -/// - `?foo=1` -/// And the following query strings all parse a `false` value: -/// - `?` -/// - `?foo=false` -/// - `?foo=f` -/// - `?foo=no` -/// - `?foo=n` -/// - `?foo=0` -/// -/// Compound Types: -/// Compound (struct) types are also supported, with the parameter key -/// for its parameters consisting of the struct's field + '.' + parameter -/// field. For example: -/// ``` -/// struct { -/// foo: struct { -/// baz: usize = 0, -/// } = .{}, -/// } -/// ``` -/// Would be used to parse a query string like -/// `?foo.baz=12345` -/// -/// Compound types cannot currently be nullable, and must be structs. -/// -/// TODO: values are currently case-sensitive, and are not url-decoded properly. -/// This should be fixed. -pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { - if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct"); - var iter = QueryIter.from(query); - - var fields = Intermediary(T){}; - while (iter.next()) |pair| { - // TODO: Hash map - inline for (std.meta.fields(Intermediary(T))) |field| { - if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) { - @field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} }; - break; - } - } else std.log.debug("unknown param {s}", .{pair.key}); - } - - return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery; -} - -fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]const u8 { - var list = try std.ArrayList(u8).initCapacity(alloc, val.len); - errdefer list.deinit(); - - var idx: usize = 0; - while (idx < val.len) : (idx += 1) { - if (val[idx] != '%') { - try list.append(val[idx]); - } else { - if (val.len < idx + 2) return error.InvalidEscape; - const buf = [2]u8{ val[idx + 1], val[idx + 2] }; - idx += 2; - - const ch = try std.fmt.parseInt(u8, &buf, 16); - try list.append(ch); - } - } - - return list.toOwnedSlice(); -} - -fn parseScalar(alloc: std.mem.Allocator, comptime T: type, comptime name: []const u8, fields: anytype) !?T { - const param = @field(fields, name); - return switch (param) { - .not_specified => null, - .no_value => try parseQueryValue(alloc, T, null), - .value => |v| try parseQueryValue(alloc, T, v), - }; -} - -fn parse( - alloc: std.mem.Allocator, - comptime T: type, - comptime prefix: []const u8, - comptime name: []const u8, - fields: anytype, -) !?T { - if (comptime isScalar(T)) return parseScalar(alloc, T, prefix ++ "." ++ name, fields); - switch (@typeInfo(T)) { - .Union => |info| { - var result: ?T = null; - inline for (info.fields) |field| { - const F = field.field_type; - - const maybe_value = try parse(alloc, F, prefix, field.name, fields); - if (maybe_value) |value| { - if (result != null) return error.DuplicateUnionField; - - result = @unionInit(T, field.name, value); - } - } - std.log.debug("{any}", .{result}); - return result; - }, - - .Struct => |info| { - var result: T = undefined; - var fields_specified: usize = 0; - - inline for (info.fields) |field| { - const F = field.field_type; - - var maybe_value: ?F = null; - if (try parse(alloc, F, prefix ++ "." ++ name, field.name, fields)) |v| { - maybe_value = v; - } else if (field.default_value) |default| { - if (comptime @sizeOf(F) != 0) { - maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*; - } else { - maybe_value = std.mem.zeroes(F); - } - } - - if (maybe_value) |v| { - fields_specified += 1; - @field(result, field.name) = v; - } - } - - if (fields_specified == 0) { - return null; - } else if (fields_specified != info.fields.len) { - std.log.debug("{} {s} {s}", .{ T, prefix, name }); - return error.PartiallySpecifiedStruct; - } else { - return result; - } - }, - - // Only applies to non-scalar optionals - .Optional => |info| return try parse(alloc, info.child, prefix, name, fields), - - else => @compileError("tmp"), - } -} - -fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 { - comptime { - if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix); - - var fields: []const []const u8 = &.{}; - - for (std.meta.fields(T)) |f| { - const full_name = prefix ++ f.name; - - if (isScalar(f.field_type)) { - fields = fields ++ @as([]const []const u8, &.{full_name}); - } else { - const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ "."; - fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix); - } - } - - return fields; - } -} - -const QueryParam = union(enum) { - not_specified: void, - no_value: void, - value: []const u8, -}; - -fn Intermediary(comptime T: type) type { - const field_names = recursiveFieldPaths(T, ".."); - - var fields: [field_names.len]std.builtin.Type.StructField = undefined; - for (field_names) |name, i| fields[i] = .{ - .name = name, - .field_type = QueryParam, - .default_value = &QueryParam{ .not_specified = {} }, - .is_comptime = false, - .alignment = @alignOf(QueryParam), - }; - - return @Type(.{ .Struct = .{ - .layout = .Auto, - .fields = &fields, - .decls = &.{}, - .is_tuple = false, - } }); -} - -fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u8) !T { - const is_optional = comptime std.meta.trait.is(.Optional)(T); - // If param is present, but without an associated value - if (value == null) { - return if (is_optional) - null - else if (T == bool) - true - else - error.InvalidValue; - } - - return try parseQueryValueNotNull(alloc, if (is_optional) std.meta.Child(T) else T, value.?); -} - -const bool_map = std.ComptimeStringMap(bool, .{ - .{ "true", true }, - .{ "t", true }, - .{ "yes", true }, - .{ "y", true }, - .{ "1", true }, - - .{ "false", false }, - .{ "f", false }, - .{ "no", false }, - .{ "n", false }, - .{ "0", false }, -}); - -fn parseQueryValueNotNull(alloc: std.mem.Allocator, comptime T: type, value: []const u8) !T { - const decoded = try decodeString(alloc, value); - errdefer alloc.free(decoded); - - if (comptime std.meta.trait.isZigString(T)) return decoded; - - const result = if (comptime std.meta.trait.isIntegral(T)) - try std.fmt.parseInt(T, decoded, 0) - else if (comptime std.meta.trait.isFloat(T)) - try std.fmt.parseFloat(T, decoded) - else if (comptime std.meta.trait.is(.Enum)(T)) - std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue - else if (T == bool) - bool_map.get(value) orelse return error.InvalidBool - else if (comptime std.meta.trait.hasFn("parse")(T)) - try T.parse(value) - else - @compileError("Invalid type " ++ @typeName(T)); - - alloc.free(decoded); - return result; -} - -fn isScalar(comptime T: type) bool { - if (comptime std.meta.trait.isZigString(T)) return true; - if (comptime std.meta.trait.isIntegral(T)) return true; - if (comptime std.meta.trait.isFloat(T)) return true; - if (comptime std.meta.trait.is(.Enum)(T)) return true; - if (T == bool) return true; - if (comptime std.meta.trait.hasFn("parse")(T)) return true; - - if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true; - - return false; -} - -pub fn formatQuery(params: anytype, writer: anytype) !void { - try format("", "", params, writer); -} - -fn urlFormatString(writer: anytype, val: []const u8) !void { - for (val) |ch| { - const printable = switch (ch) { - '0'...'9', 'a'...'z', 'A'...'Z' => true, - '-', '.', '_', '~', ':', '@', '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=' => true, - else => false, - }; - - try if (printable) writer.writeByte(ch) else std.fmt.format(writer, "%{x:0>2}", .{ch}); - } -} - -fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void { - const T = @TypeOf(val); - if (comptime std.meta.trait.is(.Optional)(T)) { - return if (val) |v| formatScalar(name, v, writer) else {}; - } - - try urlFormatString(writer, name); - try writer.writeByte('='); - if (comptime std.meta.trait.isZigString(T)) { - try urlFormatString(writer, val); - } else try switch (@typeInfo(T)) { - .Enum => urlFormatString(writer, @tagName(val)), - else => std.fmt.format(writer, "{}", .{val}), - }; - - try writer.writeByte('&'); -} - -fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void { - const T = @TypeOf(params); - const eff_prefix = if (prefix.len == 0) "" else prefix ++ "."; - if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer); - - switch (@typeInfo(T)) { - .Struct => { - inline for (std.meta.fields(T)) |field| { - const val = @field(params, field.name); - try format(eff_prefix ++ name, field.name, val, writer); - } - }, - .Union => { - inline for (std.meta.fields(T)) |field| { - const tag = @field(std.meta.Tag(T), field.name); - const tag_name = field.name; - if (@as(std.meta.Tag(T), params) == tag) { - const val = @field(params, tag_name); - try format(prefix, tag_name, val, writer); - } - } - }, - .Optional => { - if (params) |p| try format(prefix, name, p, writer); - }, - else => @compileError("Unsupported query type"), - } -} - -test { - const TestQuery = struct { - int: usize = 3, - boolean: bool = false, - str_enum: ?enum { foo, bar } = null, - }; - - try std.testing.expectEqual(TestQuery{ - .int = 3, - .boolean = false, - .str_enum = null, - }, try parseQuery(TestQuery, "")); - - try std.testing.expectEqual(TestQuery{ - .int = 5, - .boolean = true, - .str_enum = .foo, - }, try parseQuery(TestQuery, "?int=5&boolean=yes&str_enum=foo")); -} From 8aa4f900f6289880c3a122e7ca36773420a462cd Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sun, 27 Nov 2022 01:59:37 -0800 Subject: [PATCH 9/9] Use relative links on pagination --- src/main/controllers.zig | 26 ++++++++++++++-------- src/main/controllers/api/communities.zig | 2 +- src/main/controllers/api/timelines.zig | 6 ++--- src/main/controllers/api/users/follows.zig | 4 ++-- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/main/controllers.zig b/src/main/controllers.zig index ea7fa45..10ecdb7 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -242,14 +242,14 @@ const json_options = if (builtin.mode == .Debug) }; pub const helpers = struct { - pub fn paginate(community: api.Community, path: []const u8, results: anytype, res: *Response, alloc: std.mem.Allocator) !void { + pub fn paginate(results: anytype, res: *Response, alloc: std.mem.Allocator) !void { var link = std.ArrayList(u8).init(alloc); const link_writer = link.writer(); defer link.deinit(); - try writeLink(link_writer, community, path, results.next_page, "next"); + try writeLink(link_writer, null, "", results.next_page, "next"); try link_writer.writeByte(','); - try writeLink(link_writer, community, path, results.prev_page, "prev"); + try writeLink(link_writer, null, "", results.prev_page, "prev"); try res.headers.put("Link", link.items); @@ -258,16 +258,24 @@ pub const helpers = struct { fn writeLink( writer: anytype, - community: api.Community, + community: ?api.Community, path: []const u8, params: anytype, rel: []const u8, ) !void { + if (community) |c| { + try std.fmt.format( + writer, + "<{s}://{s}/{s}?{}>; rel=\"{s}\"", + .{ @tagName(c.scheme), c.host, path, http.queryStringify(params), rel }, + ); + } else { + try std.fmt.format( + writer, + "<{s}?{}>; rel=\"{s}\"", + .{ path, http.queryStringify(params), rel }, + ); + } // TODO: percent-encode - try std.fmt.format( - writer, - "<{s}://{s}/{s}?{}>; rel=\"{s}\"", - .{ @tagName(community.scheme), community.host, path, http.queryStringify(params), rel }, - ); } }; diff --git a/src/main/controllers/api/communities.zig b/src/main/controllers/api/communities.zig index f6f475d..7b61d86 100644 --- a/src/main/controllers/api/communities.zig +++ b/src/main/controllers/api/communities.zig @@ -27,6 +27,6 @@ pub const query = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.queryCommunities(req.query); - try controller_utils.paginate(srv.community, path, results, res, req.allocator); + try controller_utils.paginate(results, res, req.allocator); } }; diff --git a/src/main/controllers/api/timelines.zig b/src/main/controllers/api/timelines.zig index 8c30cc1..df41599 100644 --- a/src/main/controllers/api/timelines.zig +++ b/src/main/controllers/api/timelines.zig @@ -10,7 +10,7 @@ pub const global = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.globalTimeline(req.query); - try controller_utils.paginate(srv.community, path, results, res, req.allocator); + try controller_utils.paginate(results, res, req.allocator); } }; @@ -22,7 +22,7 @@ pub const local = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.localTimeline(req.query); - try controller_utils.paginate(srv.community, path, results, res, req.allocator); + try controller_utils.paginate(results, res, req.allocator); } }; @@ -34,6 +34,6 @@ pub const home = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.homeTimeline(req.query); - try controller_utils.paginate(srv.community, path, results, res, req.allocator); + try controller_utils.paginate(results, res, req.allocator); } }; diff --git a/src/main/controllers/api/users/follows.zig b/src/main/controllers/api/users/follows.zig index 765e7a0..4ac3799 100644 --- a/src/main/controllers/api/users/follows.zig +++ b/src/main/controllers/api/users/follows.zig @@ -47,7 +47,7 @@ pub const query_followers = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.queryFollowers(req.args.id, req.query); - try controller_utils.paginate(srv.community, path, results, res, req.allocator); + try controller_utils.paginate(results, res, req.allocator); } }; @@ -64,6 +64,6 @@ pub const query_following = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.queryFollowing(req.args.id, req.query); - try controller_utils.paginate(srv.community, path, results, res, req.allocator); + try controller_utils.paginate(results, res, req.allocator); } };