Major refactor in router.zig

This commit is contained in:
jaina heartles 2022-05-22 12:58:03 -07:00
parent 789e9062b2
commit b2430b333c

View file

@ -70,54 +70,117 @@ const RouteSegment = union(enum) {
param: []const u8, param: []const u8,
}; };
fn RouteWithContext(comptime Context: type) type { // convention: return HttpError IFF a status you can't handle happens.
return struct { // If status line/headers were written, always return void
const Self = @This(); const HttpError = error{Http404};
pub const Handler = fn (Context) void;
path_segments: []const RouteSegment, fn Route(comptime Context: type) type {
method: http.Method, return fn (Context, http.Method, []const u8) HttpError!void;
handler: Handler,
pub fn bind(method: http.Method, comptime path: []const u8, handler: Handler) Self {
return .{ .method = method, .path_segments = splitRoutePath(path), .handler = handler };
}
fn matchesPath(self: *const Self, request_path: []const u8) bool {
var request_segments = PathIter.from(request_path);
for (self.path_segments) |route_seg| {
const request_seg = request_segments.next() orelse return false;
switch (route_seg) {
.literal => |lit| {
if (!ciutf8.eql(lit, request_seg)) return false;
},
.param => {},
}
}
return request_segments.next() == null;
}
};
} }
pub fn Router(comptime Context: type) type { /// `makeRoute` takes a route definition and a handler of the form `fn(<Context>, <Params>) HttpError`
/// where `Params` is a struct containing one field of type `[]const u8` for each path parameter
///
/// Arguments:
/// method: The HTTP method to match
/// path: The path spec to match against. Path segments beginning with a `:` will cause the rest of
/// the segment to be treated as the name of a path parameter
/// handler: The code to execute on route match. This must be a function of form `fn(<Context>, <Params>) HttpError!void`
///
/// Implicit Arguments:
/// Context: the type of a user-supplied Context that is passed through the route. typically `http.Context` but
/// made generic for ease of testing. There are no restrictions on this type
/// Params: the type of a struct representing the path parameters expressed in `<path>`. This must be
/// a struct, with a one-one map between fields and path parameters. Each field must be of type
/// `[]const u8` and it must have the same name as a single path parameter.
///
/// Returns:
/// A new route function of type `fn(<Context>, http.Method, []const u8) ?HttpError`. When called,
/// this function will test the provided values against its specification. If they match, then
/// this function will parse path parameters and <handler> will be called with the supplied
/// context and params. If they do not match, this function will return null
///
/// Example:
/// route(.GET, "/user/:id/followers", struct{
/// fn getFollowers(ctx: http.Context, params: struct{ id: []const u8 } HttpError { ... }
/// ).getFollowers)
///
fn makeRoute(
comptime method: http.Method,
comptime path: []const u8,
comptime handler: anytype,
) return_type: {
const handler_info = @typeInfo(@TypeOf(handler));
if (handler_info != .Fn) @compileError("Route expects a function");
break :return_type Route(@typeInfo(@TypeOf(handler)).Fn.args[0].arg_type.?);
} {
const handler_args = @typeInfo(@TypeOf(handler)).Fn.args;
if (handler_args.len != 2) @compileError("handler function must have signature fn(Context, Params) HttpError");
if (@typeInfo(handler_args[1].arg_type.?) != .Struct) @compileError("Params in handler(Context, Params) must be struct");
const Context = handler_args[0].arg_type.?;
const Params = handler_args[1].arg_type.?;
const params_fields = std.meta.fields(Params);
var params_field_used = [_]bool{false} ** std.meta.fields(Params).len;
const segments = splitRoutePath(path);
for (segments) |seg| {
if (seg == .param) {
const found = for (params_fields) |f, i| {
if (std.mem.eql(u8, seg.param, f.name)) {
params_field_used[i] = true;
break true;
}
} else false;
if (!found) @compileError("Params does not contain " ++ seg.param ++ " field");
}
}
for (params_fields) |f, i| {
if (f.field_type != []const u8) @compileError("Params fields must be []const u8");
if (!params_field_used[i]) @compileError("Params field " ++ f.name ++ " not found in path");
}
return struct {
fn func(ctx: Context, req_method: http.Method, req_path: []const u8) HttpError!void {
if (req_method != method) return error.Http404;
var params: Params = undefined;
var req_segments = PathIter.from(req_path);
inline for (segments) |seg| {
const req_seg = req_segments.next() orelse return error.Http404;
var match = switch (seg) {
.literal => |literal| ciutf8.eql(literal, req_seg),
.param => |param| blk: {
@field(params, param) = req_seg;
break :blk true;
},
};
if (!match) return error.Http404;
}
if (req_segments.next() != null) return error.Http404;
return handler(ctx, params);
}
}.func;
}
pub fn Router(comptime Context: type, comptime routes: []const Route(Context)) type {
return struct { return struct {
const Self = @This(); const Self = @This();
pub const Route = RouteWithContext(Context); pub fn dispatch(_: *const Self, method: http.Method, path: []const u8, ctx: Context) HttpError!void {
for (routes) |r| {
routes: []const Route, return r(ctx, method, path) catch |err| switch (err) {
route_404: Route.Handler, error.Http404 => continue,
else => err,
pub fn dispatch(self: *const Self, method: http.Method, path: []const u8, ctx: Context) void { };
for (self.routes) |r| {
if (method == r.method and r.matchesPath(path)) {
return r.handler(ctx);
}
} }
return self.route_404(ctx); return error.Http404;
} }
}; };
} }
@ -175,34 +238,47 @@ const _tests = struct {
}, segments); }, segments);
} }
test "RouteWithContext(T).matchesPath" { fn CallTracker(comptime _uniq: anytype, comptime next: anytype) type {
const R = RouteWithContext(Context);
const r = R.bind(.GET, "/ab/cd", dummyHandler);
try std.testing.expectEqual(true, r.matchesPath("ab///cd////"));
try std.testing.expectEqual(true, r.matchesPath("//ab///cd"));
try std.testing.expectEqual(true, r.matchesPath("ab/cd"));
try std.testing.expectEqual(true, r.matchesPath("/ab/cd"));
try std.testing.expectEqual(false, r.matchesPath("/a/b/c/d"));
try std.testing.expectEqual(false, r.matchesPath("/aa/aa"));
try std.testing.expectEqual(false, r.matchesPath(""));
}
fn CallTracker(comptime _uniq: anytype, comptime next: fn (Context) void) type {
_ = _uniq; _ = _uniq;
var ctx_type: type = undefined;
var args_type: type = undefined;
switch (@typeInfo(@TypeOf(next))) {
.Fn => |func| {
if (func.args.len != 2) @compileError("next() must take 2 arguments");
ctx_type = func.args[0].arg_type.?;
args_type = func.args[1].arg_type.?;
//if (@typeInfo(Args) != .Struct) @compileError("second argument to next() must be struct");
},
else => @compileError("next must be function"),
}
const Context = ctx_type;
const Args = args_type;
return struct { return struct {
var calls: u32 = 0; var calls: u32 = 0;
var last_arg: ?Context = null;
fn func(ctx: Context) void { var last_ctx: ?Context = null;
var last_args: ?Args = null;
fn func(ctx: Context, args: Args) !void {
calls += 1; calls += 1;
last_arg = ctx; last_ctx = ctx;
return next(ctx); last_args = args;
return next(ctx, args);
} }
fn expectCalledOnceWith(expected: Context) !void { fn expectCalledOnceWith(exp_ctx: Context, exp_args: Args) !void {
try std.testing.expectEqual(@as(u32, 1), calls); try std.testing.expectEqual(@as(u32, 1), calls);
try std.testing.expectEqual(expected, last_arg.?); try std.testing.expectEqual(exp_ctx, last_ctx.?);
inline for (std.meta.fields(Args)) |f| {
try std.testing.expectEqualStrings(
@field(exp_args, f.name),
@field(last_args.?, f.name),
);
}
} }
fn expectNotCalled() !void { fn expectNotCalled() !void {
@ -211,172 +287,142 @@ const _tests = struct {
fn reset() void { fn reset() void {
calls = 0; calls = 0;
last_arg = null; last_ctx = null;
last_args = null;
} }
}; };
} }
const Context = u32; const TestContext = u32;
fn dummyHandler(_: Context) void {} const DummyArgs = struct {};
fn dummyHandler(comptime Args: type) type {
test "Router(T).dispatch" { comptime {
const mock_a = CallTracker(.{}, dummyHandler); return struct {
const mock_b = CallTracker(.{}, dummyHandler); fn func(_: TestContext, _: Args) HttpError!void {}
const mock_404 = CallTracker(.{}, dummyHandler); };
}
const R = Router(Context).Route;
const routes = [_]R{
R.bind(.GET, "/a", mock_a.func),
R.bind(.GET, "/b", mock_b.func),
};
const router = Router(Context){ .routes = &routes, .route_404 = mock_404.func };
router.dispatch(.GET, "/a", 10);
try mock_a.expectCalledOnceWith(10);
try mock_b.expectNotCalled();
try mock_404.expectNotCalled();
mock_a.reset();
router.dispatch(.GET, "/b", 0);
try mock_a.expectNotCalled();
try mock_b.expectCalledOnceWith(0);
try mock_404.expectNotCalled();
} }
test "Router(T).dispatch 404" { test "Router(T).dispatch" {
const mock_a = CallTracker(.{}, dummyHandler); const mock_a = CallTracker(.{}, dummyHandler(DummyArgs).func);
const mock_b = CallTracker(.{}, dummyHandler); const mock_b = CallTracker(.{}, dummyHandler(DummyArgs).func);
const mock_404 = CallTracker(.{}, dummyHandler);
const R = Router(Context).Route; const routes = comptime [_]Route(TestContext){
const routes = [_]R{ makeRoute(.GET, "/a", mock_a.func),
R.bind(.GET, "/a", mock_a.func), makeRoute(.GET, "/b", mock_b.func),
R.bind(.GET, "/b", mock_b.func),
}; };
const router = Router(Context){ .routes = &routes, .route_404 = mock_404.func }; const router = Router(TestContext, &routes){};
router.dispatch(.GET, "/c", 10); _ = try router.dispatch(.GET, "/a", 10);
try mock_a.expectNotCalled(); try mock_a.expectCalledOnceWith(10, .{});
try mock_b.expectNotCalled(); try mock_b.expectNotCalled();
try mock_404.expectCalledOnceWith(10); mock_a.reset();
mock_404.reset();
router.dispatch(.POST, "/a", 10); _ = try router.dispatch(.GET, "/b", 0);
try mock_a.expectNotCalled(); try mock_a.expectNotCalled();
try mock_b.expectNotCalled(); try mock_b.expectCalledOnceWith(0, .{});
try mock_404.expectCalledOnceWith(10); mock_b.reset();
try std.testing.expectError(error.Http404, router.dispatch(.GET, "/c", 0));
} }
test "Router(T).dispatch same path different methods" { test "Router(T).dispatch same path different methods" {
const mock_get = CallTracker(.{}, dummyHandler); const mock_get = CallTracker(.{}, dummyHandler(DummyArgs).func);
const mock_post = CallTracker(.{}, dummyHandler); const mock_post = CallTracker(.{}, dummyHandler(DummyArgs).func);
const mock_404 = CallTracker(.{}, dummyHandler);
const R = Router(Context).Route; const routes = comptime [_]Route(TestContext){
const routes = [_]R{ makeRoute(.GET, "/a", mock_get.func),
R.bind(.GET, "/a", mock_get.func), makeRoute(.POST, "/a", mock_post.func),
R.bind(.POST, "/a", mock_post.func),
}; };
const router = Router(Context){ .routes = &routes, .route_404 = mock_404.func }; const router = Router(TestContext, &routes){};
router.dispatch(.GET, "/a", 10); _ = try router.dispatch(.GET, "/a", 10);
try mock_get.expectCalledOnceWith(10); try mock_get.expectCalledOnceWith(10, .{});
try mock_post.expectNotCalled(); try mock_post.expectNotCalled();
try mock_404.expectNotCalled();
mock_get.reset(); mock_get.reset();
router.dispatch(.POST, "/a", 10); _ = try router.dispatch(.POST, "/a", 10);
try mock_get.expectNotCalled(); try mock_get.expectNotCalled();
try mock_post.expectCalledOnceWith(10); try mock_post.expectCalledOnceWith(10, .{});
try mock_404.expectNotCalled();
} }
test "Router(T).dispatch route under subpath" { test "Router(T).dispatch route under subpath" {
const mock_a = CallTracker(.{}, dummyHandler); const mock_a = CallTracker(.{}, dummyHandler(DummyArgs).func);
const mock_b = CallTracker(.{}, dummyHandler); const mock_b = CallTracker(.{}, dummyHandler(DummyArgs).func);
const mock_404 = CallTracker(.{}, dummyHandler);
const R = Router(Context).Route; const routes = comptime [_]Route(TestContext){
const routes = [_]R{ makeRoute(.GET, "/a", mock_a.func),
R.bind(.GET, "/a", mock_a.func), makeRoute(.GET, "/a/b", mock_b.func),
R.bind(.GET, "/a/b", mock_b.func),
}; };
const router = Router(Context){ .routes = &routes, .route_404 = mock_404.func }; const router = Router(TestContext, &routes){};
router.dispatch(.GET, "/a", 10); _ = try router.dispatch(.GET, "/a", 10);
try mock_a.expectCalledOnceWith(10); try mock_a.expectCalledOnceWith(10, .{});
try mock_b.expectNotCalled(); try mock_b.expectNotCalled();
try mock_404.expectNotCalled();
mock_a.reset(); mock_a.reset();
router.dispatch(.GET, "/a/b", 11); _ = try router.dispatch(.GET, "/a/b", 11);
try mock_a.expectNotCalled(); try mock_a.expectNotCalled();
try mock_b.expectCalledOnceWith(11); try mock_b.expectCalledOnceWith(11, .{});
try mock_404.expectNotCalled();
} }
test "Router(T).dispatch case-insensitive route" { test "Router(T).dispatch case-insensitive route" {
const mock_a = CallTracker(.{}, dummyHandler); const mock_a = CallTracker(.{}, dummyHandler(DummyArgs).func);
const mock_404 = CallTracker(.{}, dummyHandler);
const R = Router(Context).Route; const routes = comptime [_]Route(TestContext){
const routes = [_]R{ makeRoute(.GET, "/test/a", mock_a.func),
R.bind(.GET, "/test/a", mock_a.func),
}; };
const router = Router(Context){ .routes = &routes, .route_404 = mock_404.func }; const router = Router(TestContext, &routes){};
router.dispatch(.GET, "/TEST/A", 10); _ = try router.dispatch(.GET, "/TEST/A", 10);
try mock_a.expectCalledOnceWith(10); try mock_a.expectCalledOnceWith(10, .{});
try mock_404.expectNotCalled();
mock_a.reset(); mock_a.reset();
router.dispatch(.GET, "/TesT/a", 11); _ = try router.dispatch(.GET, "/TesT/a", 11);
try mock_a.expectCalledOnceWith(11); try mock_a.expectCalledOnceWith(11, .{});
try mock_404.expectNotCalled();
} }
test "Router(T).dispatch redundant /" { test "Router(T).dispatch redundant /" {
const mock_a = CallTracker(.{}, dummyHandler); const mock_a = CallTracker(.{}, dummyHandler(DummyArgs).func);
const mock_404 = CallTracker(.{}, dummyHandler);
const R = Router(Context).Route; const routes = comptime [_]Route(TestContext){
const routes = [_]R{ makeRoute(.GET, "/test/a", mock_a.func),
R.bind(.GET, "/test/a", mock_a.func),
}; };
const router = Router(Context){ .routes = &routes, .route_404 = mock_404.func }; const router = Router(TestContext, &routes){};
router.dispatch(.GET, "/test//a", 10); _ = try router.dispatch(.GET, "/test//a", 10);
try mock_a.expectCalledOnceWith(10); try mock_a.expectCalledOnceWith(10, .{});
try mock_404.expectNotCalled();
mock_a.reset(); mock_a.reset();
router.dispatch(.GET, "/test/a/", 11); _ = try router.dispatch(.GET, "//test///////////a////", 11);
try mock_a.expectCalledOnceWith(11); try mock_a.expectCalledOnceWith(11, .{});
try mock_404.expectNotCalled();
mock_a.reset(); mock_a.reset();
router.dispatch(.GET, "test/a/", 12); _ = try router.dispatch(.GET, "test/a", 12);
try mock_a.expectCalledOnceWith(12); try mock_a.expectCalledOnceWith(12, .{});
try mock_404.expectNotCalled();
} }
test "Router(T).dispatch with variables" { test "Router(T).dispatch with variables" {
const mock_a = CallTracker(.{}, dummyHandler); const mock_a = CallTracker(.{}, dummyHandler(struct { id: []const u8 }).func);
const mock_404 = CallTracker(.{}, dummyHandler); const mock_b = CallTracker(.{}, dummyHandler(struct { a_id: []const u8, b_id: []const u8 }).func);
const R = Router(Context).Route; const routes = comptime [_]Route(TestContext){
const routes = [_]R{ makeRoute(.GET, "/test/:id/abcd", mock_a.func),
R.bind(.GET, "/test/:id/abcd", mock_a.func), makeRoute(.GET, "/test/:a_id/abcd/:b_id", mock_b.func),
}; };
const router = Router(Context){ .routes = &routes, .route_404 = mock_404.func }; const router = Router(TestContext, &routes){};
router.dispatch(.GET, "/test/lskdjflsdjfksld/abcd", 10); _ = try router.dispatch(.GET, "/test/xyz/abcd", 10);
try mock_a.expectCalledOnceWith(10); try mock_a.expectCalledOnceWith(10, .{ .id = "xyz" });
try mock_404.expectNotCalled(); try mock_b.expectNotCalled();
mock_a.reset(); mock_a.reset();
router.dispatch(.GET, "/test//abcd", 10); try std.testing.expectError(error.Http404, router.dispatch(.GET, "/test//abcd", 10));
try mock_a.expectNotCalled(); try mock_a.expectNotCalled();
try mock_404.expectCalledOnceWith(10); try mock_b.expectNotCalled();
_ = try router.dispatch(.GET, "/test/xyz/abcd/zyx", 10);
try mock_a.expectNotCalled();
try mock_b.expectCalledOnceWith(10, .{ .a_id = "xyz", .b_id = "zyx" });
} }
}; };