diff --git a/src/http/middleware.zig b/src/http/middleware.zig index b0eab15..ce4d307 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -407,10 +407,14 @@ fn pathMatches(route: []const u8, path: []const u8) bool { var path_iter = PathIter.from(path); var route_iter = PathIter.from(route); while (route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return false; + const path_segment = path_iter.next() orelse ""; + if (route_segment.len > 0 and route_segment[0] == ':') { // Route Argument - if (path_segment.len == 0) return false; + if (route_segment[route_segment.len - 1] == '*') { + // consume rest of path segments + while (path_iter.next()) |_| {} + } else if (path_segment.len == 0) return false; } else { if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false; } @@ -481,6 +485,10 @@ test "route" { try testCase(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh"); try testCase(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh"); try testCase(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh/xyz"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd"); try testCase(false, .{ .method = .POST, .path = "/" }, .GET, "/"); try testCase(false, .{ .method = .GET, .path = "/abcd" }, .GET, ""); @@ -489,32 +497,21 @@ test "route" { try testCase(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/"); try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/"); try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo"); + try testCase(false, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "defg/abcd"); } /// Mounts a router subtree under a given path. Middlewares further down on the list /// are called with the path prefix specified by `route` removed from the path. /// Must be below `split_uri` on the middleware list. pub fn Mount(comptime route: []const u8) type { + if (std.mem.indexOfScalar(u8, route, ':') != null) @compileError("Route args cannot be mounted"); return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - var path_iter = PathIter.from(ctx.path); - comptime var route_iter = PathIter.from(route); - var path_unused: []const u8 = ctx.path; - - inline while (comptime route_iter.next()) |route_segment| { - if (comptime route_segment.len == 0) continue; - const path_segment = path_iter.next() orelse return error.RouteMismatch; - path_unused = path_iter.rest(); - if (comptime route_segment[0] == ':') { - @compileLog("Argument segments cannot be mounted"); - // Route Argument - } else { - if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; - } - } + const args = try parseArgsFromPath(route ++ "/:path*", struct { path: []const u8 }, ctx.path); var new_ctx = ctx; - new_ctx.path = path_unused; + new_ctx.path = args.path; + return next.handle(req, res, new_ctx, {}); } }; @@ -546,16 +543,31 @@ fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []co var args: Args = undefined; var path_iter = PathIter.from(path); comptime var route_iter = PathIter.from(route); + var path_unused: []const u8 = path; + inline while (comptime route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return error.RouteMismatch; - if (route_segment.len > 0 and route_segment[0] == ':') { + const path_segment = path_iter.next() orelse ""; + if (route_segment[0] == ':') { + comptime var name: []const u8 = route_segment[1..]; + var value: []const u8 = path_segment; + // route segment is an argument segment - if (path_segment.len == 0) return error.RouteMismatch; - const A = @TypeOf(@field(args, route_segment[1..])); - @field(args, route_segment[1..]) = try parseArgFromPath(A, path_segment); + if (comptime route_segment[route_segment.len - 1] == '*') { + // waste remaining args + while (path_iter.next()) |_| {} + name = route_segment[1 .. route_segment.len - 1]; + value = path_unused; + } else { + if (path_segment.len == 0) return error.RouteMismatch; + } + + const A = @TypeOf(@field(args, name)); + @field(args, name) = try parseArgFromPath(A, value); } else { + // route segment is a literal segment if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; } + path_unused = path_iter.rest(); } if (path_iter.next() != null) return error.RouteMismatch; @@ -630,6 +642,21 @@ test "ParsePathArgs" { try testCase("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" }); try testCase("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil }); + try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc", .{ .arg = "abc" }); + try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc/def", .{ .arg = "abc/def" }); + try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/", .{ .arg = "" }); + + // Compiler crashes if i keep the args named the same as above. + // TODO: Debug this and try to fix it + try testCase("/xyz/:bar*", struct { bar: []const u8 }, "/xyz", .{ .bar = "" }); + + // It's a quirk that the initial / is left in for these cases. However, it results in a path + // that's semantically equivalent so i didn't bother fixing it + try testCase("/:foo*", struct { foo: []const u8 }, "/abc", .{ .foo = "/abc" }); + try testCase("/:foo*", struct { foo: []const u8 }, "/abc/def", .{ .foo = "/abc/def" }); + try testCase("/:foo*", struct { foo: []const u8 }, "/", .{ .foo = "/" }); + try testCase("/:foo*", struct { foo: []const u8 }, "", .{ .foo = "" }); + try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/", .{})); try std.testing.expectError(error.RouteMismatch, testCase("/abcd/:id", struct { id: usize }, "/123", .{})); try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 }));