const std = @import("std");
const util = @import("util");
const http = @import("./lib.zig");

/// Named path parameters captured while matching a route.
///
/// Captured parameters are stored contiguously from index 0 (see
/// `Route.dispatch`); the first `null` name marks the end of the captured set.
const Args = struct {
    const max_args = 16;

    names: [max_args]?[]const u8 = [_]?[]const u8{null} ** max_args,
    values: [max_args]?[]const u8 = [_]?[]const u8{null} ** max_args,

    /// Returns the value captured for parameter `name`, or null if no such
    /// parameter was captured. Names are compared with `util.ciutf8.eql`
    /// (the same comparison route literals use, which the tests below rely on
    /// being case-insensitive).
    pub fn get(self: Args, name: []const u8) ?[]const u8 {
        for (self.names) |maybe_name, i| {
            // Parameters are packed from the front, so the first empty slot
            // means there is nothing further to search.
            const arg_name = maybe_name orelse return null;
            if (util.ciutf8.eql(name, arg_name)) return self.values[i];
        }
        return null;
    }
};

/// Returns a router type whose handlers receive a `Context` value plus the
/// path parameters (`Args`) captured from the request path.
pub fn Router(comptime Context: type) type {
    return struct {
        const Self = @This();

        /// Routes to try, in order; the first applicable route handles the request.
        routes: []const Route,

        pub const Handler = fn (Context, Args) anyerror!void;

        /// A single method + path pattern bound to a handler.
        ///
        /// Pattern segments that start with ':' are parameters (e.g. "/users/:id");
        /// all other segments are compared with `util.ciutf8.eql`.
        pub const Route = struct {
            method: http.Method,
            path: []const u8,
            path_segments: []const RouteSegment,
            handler: Handler,

            /// Builds a route for `method` and the comptime `path` pattern.
            /// A pattern with more than `Args.max_args` parameters is rejected at
            /// compile time, because `Args` has no room to store them.
            pub fn new(method: http.Method, comptime path: []const u8, handler: Handler) Route {
                const segments = comptime splitPathSegments(path);
                comptime {
                    var param_count: usize = 0;
                    for (segments) |seg| {
                        if (seg == .param) param_count += 1;
                    }
                    if (param_count > Args.max_args) @compileError("Too many params");
                }
                return .{
                    .method = method,
                    .path = path,
                    .path_segments = segments,
                    .handler = handler,
                };
            }

            /// Calls the handler if `req_method` and `req_path` match this route;
            /// otherwise returns `error.RouteNotApplicable`. Errors returned by the
            /// handler are propagated unchanged.
            fn dispatch(self: Route, ctx: Context, req_method: std.http.Method, req_path: []const u8) anyerror!void {
                if (req_method != self.method) return error.RouteNotApplicable;

                var args = Args{};
                var arg_count: usize = 0;
                var req_segments = util.PathIter.from(req_path);
                for (self.path_segments) |seg| {
                    const req_seg = req_segments.next() orelse return error.RouteNotApplicable;
                    switch (seg) {
                        .literal => |literal| {
                            if (!util.ciutf8.eql(literal, req_seg)) return error.RouteNotApplicable;
                        },
                        .param => |param| {
                            // Each parameter takes the next free slot; `Route.new`
                            // guarantees there are at most `Args.max_args` of them.
                            args.names[arg_count] = param;
                            args.values[arg_count] = req_seg;
                            arg_count += 1;
                        },
                    }
                }
                // The request path must not have more segments than the pattern.
                if (req_segments.next() != null) return error.RouteNotApplicable;

                return self.handler(ctx, args);
            }
        };

        /// Dispatches to the first route that matches `method` and `path`.
        /// Returns `error.RouteNotApplicable` when no route matches; any other
        /// error from the selected handler is returned to the caller.
        /// NOTE: a handler that itself returns `error.RouteNotApplicable` is
        /// treated as a non-match and the search continues with later routes.
        pub fn dispatch(self: Self, ctx: Context, method: std.http.Method, path: []const u8) anyerror!void {
            for (self.routes) |r| {
                r.dispatch(ctx, method, path) catch |err| switch (err) {
                    error.RouteNotApplicable => continue,
                    else => return err,
                };
                return;
            }
            return error.RouteNotApplicable;
        }
    };
}

/// One segment of a route pattern: either matched literally or captured as a
/// named parameter.
const RouteSegment = union(enum) {
    literal: []const u8,
    param: []const u8,
};

/// Strips the leading ':' marker from a parameter segment.
fn paramNameFrom(segment: []const u8) []const u8 {
    return segment[1..];
}

/// Reports whether a pattern segment is a parameter (starts with ':').
/// Empty segments are never parameters.
fn isParamSegment(segment: []const u8) bool {
    return segment.len > 0 and segment[0] == ':';
}

/// Splits a comptime route pattern into literal and parameter segments using
/// `util.PathIter`.
fn splitPathSegments(comptime path: []const u8) []const RouteSegment {
    comptime {
        // A path cannot yield more segments than it has bytes.
        var segments: [path.len]RouteSegment = undefined;
        var segment_count: usize = 0;
        var iter = util.PathIter.from(path);
        while (iter.next()) |segment| {
            if (isParamSegment(segment)) {
                segments[segment_count] = .{ .param = paramNameFrom(segment) };
            } else {
                segments[segment_count] = .{ .literal = segment };
            }
            segment_count += 1;
        }
        return segments[0..segment_count];
    }
}

const _test = struct {
    /// Wraps `next` in a handler (`func`) that records how it was called.
    /// `_uniq` is otherwise unused; presumably each call site passes a fresh
    /// `.{}` so that every tracker gets its own generic instantiation (and thus
    /// its own static counters). TODO confirm against the targeted Zig version.
    fn CallTracker(comptime _uniq: anytype, comptime next: anytype) type {
        _ = _uniq;
        const fn_info = switch (@typeInfo(@TypeOf(next))) {
            .Fn => |info| info,
            else => @compileError("next must be function"),
        };
        if (fn_info.args.len != 2) @compileError("next() must take 2 arguments");
        if (fn_info.args[1].arg_type.? != Args) @compileError("second argument to next() must be Args");
        const Context = fn_info.args[0].arg_type.?;

        return struct {
            var calls: u32 = 0;
            var last_ctx: ?Context = null;
            var last_args: ?Args = null;

            fn func(ctx: Context, args: Args) !void {
                calls += 1;
                last_ctx = ctx;
                last_args = args;
                return next(ctx, args);
            }

            /// Asserts exactly one call happened, with `exp_ctx` and `exp_args`.
            fn expectCalledOnceWith(exp_ctx: Context, exp_args: Args) !void {
                try std.testing.expectEqual(@as(u32, 1), calls);
                try std.testing.expectEqual(exp_ctx, last_ctx.?);
                const actual = last_args.?;
                for (exp_args.names) |exp_name, i| {
                    if (exp_name) |name| {
                        // Report a missing parameter as a test failure, not a panic.
                        try std.testing.expectEqualStrings(name, actual.names[i] orelse return error.TestExpectedEqual);
                        try std.testing.expectEqualStrings(exp_args.values[i].?, actual.values[i] orelse return error.TestExpectedEqual);
                    } else {
                        try std.testing.expectEqual(@as(?[]const u8, null), actual.names[i]);
                        try std.testing.expectEqual(@as(?[]const u8, null), actual.values[i]);
                    }
                }
            }

            fn expectNotCalled() !void {
                try std.testing.expectEqual(@as(u32, 0), calls);
            }

            fn reset() void {
                calls = 0;
                last_ctx = null;
                last_args = null;
            }
        };
    }

    const TestContext = u32;

    fn dummyHandler(_: TestContext, _: Args) anyerror!void {}
};

test "route(T, ...) basic" {
    const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/a", mock_a.func),
    } };

    _ = try my_router.dispatch(10, .GET, "/a");
    try mock_a.expectCalledOnceWith(10, .{});
    mock_a.reset();
}

test "Route(T) matches correct route from multiple routes" {
    const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
    const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/a", mock_a.func),
        MyRoute.new(.GET, "/b", mock_b.func),
    } };

    _ = try my_router.dispatch(10, .GET, "/a");
    try mock_a.expectCalledOnceWith(10, .{});
    try mock_b.expectNotCalled();
    mock_a.reset();

    _ = try my_router.dispatch(10, .GET, "/b");
    try mock_a.expectNotCalled();
    try mock_b.expectCalledOnceWith(10, .{});
    mock_b.reset();
}

test "Route(T) passes correct context" {
    const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/a", mock_a.func),
    } };

    _ = try my_router.dispatch(10, .GET, "/a");
    try mock_a.expectCalledOnceWith(10, .{});
    mock_a.reset();

    _ = try my_router.dispatch(16, .GET, "/a");
    try mock_a.expectCalledOnceWith(16, .{});
    mock_a.reset();
}

test "Route(T) errors on no matching route" {
    const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
    const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/a", mock_a.func),
        MyRoute.new(.GET, "/b", mock_b.func),
    } };

    try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(0, .GET, "/c"));
    try mock_a.expectNotCalled();
    try mock_b.expectNotCalled();
}

test "route(T) with no routes" {
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{} };

    try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(0, .GET, "/c"));
}

test "route(T, ...) same path different methods" {
    const mock_get = _test.CallTracker(.{}, _test.dummyHandler);
    const mock_post = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/a", mock_get.func),
        MyRoute.new(.POST, "/a", mock_post.func),
    } };

    _ = try my_router.dispatch(10, .GET, "/a");
    try mock_get.expectCalledOnceWith(10, .{});
    try mock_post.expectNotCalled();
    mock_get.reset();

    _ = try my_router.dispatch(10, .POST, "/a");
    try mock_get.expectNotCalled();
    try mock_post.expectCalledOnceWith(10, .{});
}

test "route(T, ...) route under subpath" {
    const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
    const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/a", mock_a.func),
        MyRoute.new(.GET, "/a/b", mock_b.func),
    } };

    _ = try my_router.dispatch(10, .GET, "/a");
    try mock_a.expectCalledOnceWith(10, .{});
    try mock_b.expectNotCalled();
    mock_a.reset();

    _ = try my_router.dispatch(10, .GET, "/a/b");
    try mock_a.expectNotCalled();
    try mock_b.expectCalledOnceWith(10, .{});
}

test "route(T, ...) case-insensitive route" {
    const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/test/a", mock_a.func),
    } };

    _ = try my_router.dispatch(10, .GET, "/TEST/A");
    try mock_a.expectCalledOnceWith(10, .{});
    mock_a.reset();

    _ = try my_router.dispatch(10, .GET, "/TesT/a");
    try mock_a.expectCalledOnceWith(10, .{});
}

test "route(T, ...) redundant /" {
    const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/test/a", mock_a.func),
    } };

    _ = try my_router.dispatch(10, .GET, "/test//a");
    try mock_a.expectCalledOnceWith(10, .{});
    mock_a.reset();

    _ = try my_router.dispatch(10, .GET, "//test///////////a////");
    try mock_a.expectCalledOnceWith(10, .{});
    mock_a.reset();

    _ = try my_router.dispatch(10, .GET, "test/a");
    try mock_a.expectCalledOnceWith(10, .{});
    mock_a.reset();

    try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(10, .GET, "/te/st/a"));
    try mock_a.expectNotCalled();
}

test "route(T, ...) multiple params" {
    const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
    const MyRouter = Router(_test.TestContext);
    const MyRoute = MyRouter.Route;
    const my_router = MyRouter{ .routes = &[_]MyRoute{
        MyRoute.new(.GET, "/a/:id/b/:name", mock_a.func),
    } };

    try my_router.dispatch(10, .GET, "/a/42/b/bob");

    var expected = Args{};
    expected.names[0] = "id";
    expected.values[0] = "42";
    expected.names[1] = "name";
    expected.values[1] = "bob";
    try mock_a.expectCalledOnceWith(10, expected);
    try std.testing.expectEqualStrings("42", mock_a.last_args.?.get("id").?);
    try std.testing.expectEqualStrings("bob", mock_a.last_args.?.get("name").?);
    mock_a.reset();
}

test "Args.get returns captured values and null for unknown names" {
    var args = Args{};
    args.names[0] = "id";
    args.values[0] = "42";
    args.names[1] = "name";
    args.values[1] = "bob";

    try std.testing.expectEqualStrings("42", args.get("id").?);
    try std.testing.expectEqualStrings("bob", args.get("NAME").?);
    try std.testing.expectEqual(@as(?[]const u8, null), args.get("missing"));
}