diff --git a/src/http/lib.zig b/src/http/lib.zig index eeebf35..e2d0052 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -1,9 +1,13 @@ const std = @import("std"); const ciutf8 = @import("util").ciutf8; +const routing = @import("./routing.zig"); +const response_stream = @import("./response_stream.zig"); + pub const Status = std.http.Status; pub const Method = std.http.Method; -pub const ResponseStream = @import("./response_stream.zig").ResponseStream; +pub const ResponseStream = response_stream.ResponseStream; +//pub const Router = routing.Router(Context); pub const Headers = std.HashMap([]const u8, []const u8, struct { pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { @@ -17,4 +21,5 @@ pub const Headers = std.HashMap([]const u8, []const u8, struct { test { _ = ResponseStream; + _ = routing; } diff --git a/src/http/routing.zig b/src/http/routing.zig new file mode 100644 index 0000000..c6824b9 --- /dev/null +++ b/src/http/routing.zig @@ -0,0 +1,359 @@ +const std = @import("std"); +const util = @import("util"); + +const http = @import("./lib.zig"); + +const Args = struct { + const max_args = 16; + + names: [max_args]?[]const u8 = [_]?[]const u8{null} ** max_args, + values: [max_args]?[]const u8 = [_]?[]const u8{null} ** max_args, + + pub fn get(self: Args, name: []const u8) ?[]const u8 { + for (self.names) |arg_name, i| { + if (arg_name == null) return null; + + if (util.ciutf8.eql(name, arg_name.?)) return self.values[i]; + } + + return null; + } +}; + +pub fn Router(comptime Context: type) type { + return struct { + const Self = @This(); + + routes: []const Route, + + pub const Handler = fn (Context, Args) anyerror!void; + pub const Route = struct { + method: http.Method, + path: []const u8, + path_segments: []const RouteSegment, + handler: Handler, + + pub fn new(method: http.Method, comptime path: []const u8, handler: Handler) Route { + const result = .{ + .method = method, + .path = path, + .path_segments = splitPathSegments(path), + .handler = handler, + }; + + var param_count: usize = 0; + for (result.path_segments) |seg| { + if (seg == .param) param_count += 1; + } + + if (param_count > Args.max_args) @panic("Too many params"); + + return result; + } + + fn dispatch(self: Route, ctx: Context, req_method: std.http.Method, req_path: []const u8) anyerror!void { + if (req_method != self.method) return error.RouteNotApplicable; + + var args = Args{}; + var arg_count: usize = 0; + var req_segments = util.PathIter.from(req_path); + for (self.path_segments) |seg| { + const req_seg = req_segments.next() orelse return error.RouteNotApplicable; + switch (seg) { + .literal => |literal| { + if (!util.ciutf8.eql(literal, req_seg)) return error.RouteNotApplicable; + }, + .param => |param| { + args.names[arg_count] = param; + args.values[arg_count] = req_seg; + }, + } + } + + if (req_segments.next() != null) return error.RouteNotApplicable; + + return self.handler(ctx, args); + } + }; + + pub fn dispatch(self: Self, ctx: Context, method: std.http.Method, path: []const u8) anyerror!void { + for (self.routes) |r| { + r.dispatch(ctx, method, path) catch |err| switch (err) { + error.RouteNotApplicable => continue, + else => return err, + }; + + return; + } + + return error.RouteNotApplicable; + } + }; +} + +const RouteSegment = union(enum) { + literal: []const u8, + param: []const u8, +}; + +fn paramNameFrom(segment: []const u8) []const u8 { + return segment[1..]; +} + +fn isParamSegment(segment: []const u8) bool { + return segment[0] == ':'; +} + +fn splitPathSegments(comptime path: []const u8) []const RouteSegment { + comptime { + var segments: [path.len]RouteSegment = undefined; + var segment_count: usize = 0; + + var iter = util.PathIter.from(path); + while (iter.next()) |segment| { + if (isParamSegment(segment)) { + segments[segment_count] = .{ + .param = paramNameFrom(segment), + }; + } else { + segments[segment_count] = .{ .literal = segment }; + } + + segment_count += 1; + } + + return segments[0..segment_count]; + } +} + +const _test = struct { + fn CallTracker(comptime _uniq: anytype, comptime next: anytype) type { + _ = _uniq; + + var ctx_type: type = undefined; + var args_type: type = undefined; + switch (@typeInfo(@TypeOf(next))) { + .Fn => |func| { + if (func.args.len != 2) @compileError("next() must take 2 arguments"); + + ctx_type = func.args[0].arg_type.?; + args_type = func.args[1].arg_type.?; + //if (@typeInfo(Args) != .Struct) @compileError("second argument to next() must be struct"); + }, + else => @compileError("next must be function"), + } + + const Context = ctx_type; + + return struct { + var calls: u32 = 0; + + var last_ctx: ?Context = null; + var last_args: ?Args = null; + + fn func(ctx: Context, args: Args) !void { + calls += 1; + last_ctx = ctx; + last_args = args; + return next(ctx, args); + } + + fn expectCalledOnceWith(exp_ctx: Context, exp_args: Args) !void { + try std.testing.expectEqual(@as(u32, 1), calls); + try std.testing.expectEqual(exp_ctx, last_ctx.?); + for (exp_args.names) |exp_name, i| { + if (exp_name == null) { + try std.testing.expectEqual(@as(?[]const u8, null), last_args.?.names[i]); + try std.testing.expectEqual(@as(?[]const u8, null), last_args.?.values[i]); + } else { + try std.testing.expectEqualStrings(exp_name.?, last_args.?.names[i].?); + try std.testing.expectEqualStrings(exp_args.values[i].?, last_args.?.values[i].?); + } + } + } + + fn expectNotCalled() !void { + try std.testing.expectEqual(@as(u32, 0), calls); + } + + fn reset() void { + calls = 0; + last_ctx = null; + last_args = null; + } + }; + } + const TestContext = u32; + fn dummyHandler(_: TestContext, _: Args) anyerror!void {} +}; + +test "route(T, ...) basic" { + const mock_a = _test.CallTracker(.{}, _test.dummyHandler); + + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{ + MyRoute.new(.GET, "/a", mock_a.func), + } }; + + _ = try my_router.dispatch(10, .GET, "/a"); + try mock_a.expectCalledOnceWith(10, .{}); + mock_a.reset(); +} + +test "Route(T) matches correct route from multiple routes" { + const mock_a = _test.CallTracker(.{}, _test.dummyHandler); + const mock_b = _test.CallTracker(.{}, _test.dummyHandler); + + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{ + MyRoute.new(.GET, "/a", mock_a.func), + MyRoute.new(.GET, "/b", mock_b.func), + } }; + + _ = try my_router.dispatch(10, .GET, "/a"); + try mock_a.expectCalledOnceWith(10, .{}); + try mock_b.expectNotCalled(); + mock_a.reset(); + + _ = try my_router.dispatch(10, .GET, "/b"); + try mock_a.expectNotCalled(); + try mock_b.expectCalledOnceWith(10, .{}); + mock_b.reset(); +} + +test "Route(T) passes correct context" { + const mock_a = _test.CallTracker(.{}, _test.dummyHandler); + + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{ + MyRoute.new(.GET, "/a", mock_a.func), + } }; + + _ = try my_router.dispatch(10, .GET, "/a"); + try mock_a.expectCalledOnceWith(10, .{}); + mock_a.reset(); + + _ = try my_router.dispatch(16, .GET, "/a"); + try mock_a.expectCalledOnceWith(16, .{}); + mock_a.reset(); +} + +test "Route(T) errors on no matching route" { + const mock_a = _test.CallTracker(.{}, _test.dummyHandler); + const mock_b = _test.CallTracker(.{}, _test.dummyHandler); + + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{ + MyRoute.new(.GET, "/a", mock_a.func), + MyRoute.new(.GET, "/b", mock_b.func), + } }; + + try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(0, .GET, "/c")); + try mock_a.expectNotCalled(); + try mock_b.expectNotCalled(); +} + +test "route(T) with no routes" { + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{} }; + + try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(0, .GET, "/c")); +} + +test "route(T, ...) same path different methods" { + const mock_get = _test.CallTracker(.{}, _test.dummyHandler); + const mock_post = _test.CallTracker(.{}, _test.dummyHandler); + + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{ + MyRoute.new(.GET, "/a", mock_get.func), + MyRoute.new(.POST, "/a", mock_post.func), + } }; + + _ = try my_router.dispatch(10, .GET, "/a"); + try mock_get.expectCalledOnceWith(10, .{}); + try mock_post.expectNotCalled(); + mock_get.reset(); + + _ = try my_router.dispatch(10, .POST, "/a"); + try mock_get.expectNotCalled(); + try mock_post.expectCalledOnceWith(10, .{}); +} + +test "route(T, ...) route under subpath" { + const mock_a = _test.CallTracker(.{}, _test.dummyHandler); + const mock_b = _test.CallTracker(.{}, _test.dummyHandler); + + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{ + MyRoute.new(.GET, "/a", mock_a.func), + MyRoute.new(.GET, "/a/b", mock_b.func), + } }; + + _ = try my_router.dispatch(10, .GET, "/a"); + try mock_a.expectCalledOnceWith(10, .{}); + try mock_b.expectNotCalled(); + mock_a.reset(); + + _ = try my_router.dispatch(10, .GET, "/a/b"); + try mock_a.expectNotCalled(); + try mock_b.expectCalledOnceWith(10, .{}); +} + +test "route(T, ...) case-insensitive route" { + const mock_a = _test.CallTracker(.{}, _test.dummyHandler); + + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{ + MyRoute.new(.GET, "/test/a", mock_a.func), + } }; + + _ = try my_router.dispatch(10, .GET, "/TEST/A"); + try mock_a.expectCalledOnceWith(10, .{}); + mock_a.reset(); + + _ = try my_router.dispatch(10, .GET, "/TesT/a"); + try mock_a.expectCalledOnceWith(10, .{}); +} + +test "route(T, ...) redundant /" { + const mock_a = _test.CallTracker(.{}, _test.dummyHandler); + + const MyRouter = Router(_test.TestContext); + const MyRoute = MyRouter.Route; + + const my_router = MyRouter{ .routes = &[_]MyRoute{ + MyRoute.new(.GET, "/test/a", mock_a.func), + } }; + + _ = try my_router.dispatch(10, .GET, "/test//a"); + try mock_a.expectCalledOnceWith(10, .{}); + mock_a.reset(); + + _ = try my_router.dispatch(10, .GET, "//test///////////a////"); + try mock_a.expectCalledOnceWith(10, .{}); + mock_a.reset(); + + _ = try my_router.dispatch(10, .GET, "test/a"); + try mock_a.expectCalledOnceWith(10, .{}); + mock_a.reset(); + + try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(10, .GET, "/te/st/a")); + try mock_a.expectNotCalled(); +}