Controller refactor
This commit is contained in:
parent
2aa9569050
commit
83ee7efba0
5 changed files with 17 additions and 430 deletions
|
@ -1,7 +1,6 @@
|
|||
const std = @import("std");
|
||||
const ciutf8 = @import("util").ciutf8;
|
||||
|
||||
const routing = @import("./routing.zig");
|
||||
const request = @import("./request.zig");
|
||||
|
||||
pub const server = @import("./server.zig");
|
||||
|
@ -12,12 +11,6 @@ pub const Status = std.http.Status;
|
|||
pub const Request = request.Request;
|
||||
pub const Server = server.Server;
|
||||
|
||||
// TODO: rework routing
|
||||
pub fn Router(comptime ServerContext: type) type {
|
||||
return routing.Router(ServerContext, *server.Context);
|
||||
}
|
||||
pub const RouteArgs = routing.RouteArgs;
|
||||
|
||||
pub const Headers = std.HashMap([]const u8, []const u8, struct {
|
||||
pub fn eql(_: @This(), a: []const u8, b: []const u8) bool {
|
||||
return ciutf8.eql(a, b);
|
||||
|
@ -29,7 +22,6 @@ pub const Headers = std.HashMap([]const u8, []const u8, struct {
|
|||
}, std.hash_map.default_max_load_percentage);
|
||||
|
||||
test {
|
||||
_ = routing;
|
||||
_ = server;
|
||||
_ = request;
|
||||
}
|
||||
|
|
|
@ -1,372 +0,0 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const http = @import("./lib.zig");
|
||||
|
||||
pub const RouteArgs = struct {
|
||||
const max_args = 16;
|
||||
|
||||
names: [max_args]?[]const u8 = [_]?[]const u8{null} ** max_args,
|
||||
values: [max_args]?[]const u8 = [_]?[]const u8{null} ** max_args,
|
||||
|
||||
pub fn get(self: RouteArgs, name: []const u8) ?[]const u8 {
|
||||
for (self.names) |arg_name, i| {
|
||||
if (arg_name == null) return null;
|
||||
|
||||
if (util.ciutf8.eql(name, arg_name.?)) return self.values[i];
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
pub fn Router(comptime ServerContext: type, comptime RequestContext: type) type {
|
||||
return struct {
|
||||
const Self = @This();
|
||||
|
||||
routes: []const Route,
|
||||
|
||||
pub const Handler = *const fn (ServerContext, RequestContext, RouteArgs) anyerror!void;
|
||||
pub const Route = struct {
|
||||
pub const Args = RouteArgs;
|
||||
|
||||
method: http.Method,
|
||||
path: []const u8,
|
||||
path_segments: []const RouteSegment,
|
||||
handler: Handler,
|
||||
|
||||
pub fn new(method: http.Method, comptime path: []const u8, handler: Handler) Route {
|
||||
const result = .{
|
||||
.method = method,
|
||||
.path = path,
|
||||
.path_segments = splitPathSegments(path),
|
||||
.handler = handler,
|
||||
};
|
||||
|
||||
var param_count: usize = 0;
|
||||
for (result.path_segments) |seg| {
|
||||
if (seg == .param) param_count += 1;
|
||||
}
|
||||
|
||||
if (param_count > RouteArgs.max_args) @panic("Too many params");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
fn dispatch(route: Route, sctx: ServerContext, rctx: RequestContext, req_method: std.http.Method, req_path: []const u8) anyerror!void {
|
||||
if (req_method != route.method) return error.RouteNotApplicable;
|
||||
|
||||
var args = RouteArgs{};
|
||||
var arg_count: usize = 0;
|
||||
var req_segments = util.PathIter.from(req_path);
|
||||
for (route.path_segments) |seg| {
|
||||
const req_seg = req_segments.next() orelse return error.RouteNotApplicable;
|
||||
switch (seg) {
|
||||
.literal => |literal| {
|
||||
if (!util.ciutf8.eql(literal, req_seg)) return error.RouteNotApplicable;
|
||||
},
|
||||
.param => |param| {
|
||||
args.names[arg_count] = param;
|
||||
args.values[arg_count] = req_seg;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if (req_segments.next() != null) return error.RouteNotApplicable;
|
||||
|
||||
std.log.debug("selected route {s} {s}", .{ @tagName(route.method), route.path });
|
||||
if (builtin.zig_backend == .stage1) {
|
||||
return route.handler.*(sctx, rctx, args);
|
||||
}
|
||||
|
||||
return route.handler(sctx, rctx, args);
|
||||
}
|
||||
};
|
||||
|
||||
pub fn dispatch(self: Self, sctx: ServerContext, rctx: RequestContext, method: std.http.Method, path: []const u8) anyerror!void {
|
||||
const eff_path = std.mem.sliceTo(std.mem.sliceTo(path, '#'), '?');
|
||||
for (self.routes) |r| {
|
||||
r.dispatch(sctx, rctx, method, eff_path) catch |err| switch (err) {
|
||||
error.RouteNotApplicable => continue,
|
||||
else => return err,
|
||||
};
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return error.RouteNotApplicable;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const RouteSegment = union(enum) {
|
||||
literal: []const u8,
|
||||
param: []const u8,
|
||||
};
|
||||
|
||||
fn paramNameFrom(segment: []const u8) []const u8 {
|
||||
return segment[1..];
|
||||
}
|
||||
|
||||
fn isParamSegment(segment: []const u8) bool {
|
||||
return segment[0] == ':';
|
||||
}
|
||||
|
||||
fn splitPathSegments(comptime path: []const u8) []const RouteSegment {
|
||||
comptime {
|
||||
var segments: [path.len]RouteSegment = undefined;
|
||||
var segment_count: usize = 0;
|
||||
|
||||
var iter = util.PathIter.from(path);
|
||||
while (iter.next()) |segment| {
|
||||
if (isParamSegment(segment)) {
|
||||
segments[segment_count] = .{
|
||||
.param = paramNameFrom(segment),
|
||||
};
|
||||
} else {
|
||||
segments[segment_count] = .{ .literal = segment };
|
||||
}
|
||||
|
||||
segment_count += 1;
|
||||
}
|
||||
|
||||
return segments[0..segment_count];
|
||||
}
|
||||
}
|
||||
|
||||
const _test = struct {
|
||||
fn CallTracker(comptime _uniq: anytype, comptime next: anytype) type {
|
||||
_ = _uniq;
|
||||
|
||||
var ctx_type: type = undefined;
|
||||
var args_type: type = undefined;
|
||||
switch (@typeInfo(@TypeOf(next))) {
|
||||
.Fn => |func| {
|
||||
if (func.args.len != 3) @compileError("next() must take 3 arguments");
|
||||
|
||||
ctx_type = func.args[0].arg_type.?;
|
||||
args_type = func.args[2].arg_type.?;
|
||||
//if (@typeInfo(RouteArgs) != .Struct) @compileError("second argument to next() must be struct");
|
||||
},
|
||||
else => @compileError("next must be function"),
|
||||
}
|
||||
|
||||
const Context = ctx_type;
|
||||
|
||||
return struct {
|
||||
var calls: u32 = 0;
|
||||
|
||||
var last_rctx: ?Context = null;
|
||||
var last_sctx: ?Context = null;
|
||||
var last_args: ?RouteArgs = null;
|
||||
|
||||
fn func(sctx: Context, rctx: Context, args: RouteArgs) !void {
|
||||
calls += 1;
|
||||
last_sctx = sctx;
|
||||
last_rctx = rctx;
|
||||
last_args = args;
|
||||
return next(sctx, rctx, args);
|
||||
}
|
||||
|
||||
fn expectCalledOnceWith(exp_sctx: Context, exp_rctx: Context, exp_args: RouteArgs) !void {
|
||||
try std.testing.expectEqual(@as(u32, 1), calls);
|
||||
try std.testing.expectEqual(exp_sctx, last_sctx.?);
|
||||
try std.testing.expectEqual(exp_rctx, last_rctx.?);
|
||||
for (exp_args.names) |exp_name, i| {
|
||||
if (exp_name == null) {
|
||||
try std.testing.expectEqual(@as(?[]const u8, null), last_args.?.names[i]);
|
||||
try std.testing.expectEqual(@as(?[]const u8, null), last_args.?.values[i]);
|
||||
} else {
|
||||
try std.testing.expectEqualStrings(exp_name.?, last_args.?.names[i].?);
|
||||
try std.testing.expectEqualStrings(exp_args.values[i].?, last_args.?.values[i].?);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn expectNotCalled() !void {
|
||||
try std.testing.expectEqual(@as(u32, 0), calls);
|
||||
}
|
||||
|
||||
fn reset() void {
|
||||
calls = 0;
|
||||
last_sctx = null;
|
||||
last_rctx = null;
|
||||
last_args = null;
|
||||
}
|
||||
};
|
||||
}
|
||||
const TestContext = u32;
|
||||
fn dummyHandler(_: TestContext, _: TestContext, _: RouteArgs) anyerror!void {}
|
||||
};
|
||||
|
||||
test "route(T, ...) basic" {
|
||||
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{
|
||||
MyRoute.new(.GET, "/a", mock_a.func),
|
||||
} };
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/a");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
mock_a.reset();
|
||||
}
|
||||
|
||||
test "Route(T) matches correct route from multiple routes" {
|
||||
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{
|
||||
MyRoute.new(.GET, "/a", mock_a.func),
|
||||
MyRoute.new(.GET, "/b", mock_b.func),
|
||||
} };
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/a");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
try mock_b.expectNotCalled();
|
||||
mock_a.reset();
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/b");
|
||||
try mock_a.expectNotCalled();
|
||||
try mock_b.expectCalledOnceWith(10, 100, .{});
|
||||
mock_b.reset();
|
||||
}
|
||||
|
||||
test "Route(T) passes correct context" {
|
||||
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{
|
||||
MyRoute.new(.GET, "/a", mock_a.func),
|
||||
} };
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/a");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
mock_a.reset();
|
||||
|
||||
_ = try my_router.dispatch(16, 32, .GET, "/a");
|
||||
try mock_a.expectCalledOnceWith(16, 32, .{});
|
||||
mock_a.reset();
|
||||
}
|
||||
|
||||
test "Route(T) errors on no matching route" {
|
||||
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{
|
||||
MyRoute.new(.GET, "/a", mock_a.func),
|
||||
MyRoute.new(.GET, "/b", mock_b.func),
|
||||
} };
|
||||
|
||||
try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(10, 100, .GET, "/c"));
|
||||
try mock_a.expectNotCalled();
|
||||
try mock_b.expectNotCalled();
|
||||
}
|
||||
|
||||
test "route(T) with no routes" {
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{} };
|
||||
|
||||
try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(10, 100, .GET, "/c"));
|
||||
}
|
||||
|
||||
test "route(T, ...) same path different methods" {
|
||||
const mock_get = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
const mock_post = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{
|
||||
MyRoute.new(.GET, "/a", mock_get.func),
|
||||
MyRoute.new(.POST, "/a", mock_post.func),
|
||||
} };
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/a");
|
||||
try mock_get.expectCalledOnceWith(10, 100, .{});
|
||||
try mock_post.expectNotCalled();
|
||||
mock_get.reset();
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .POST, "/a");
|
||||
try mock_get.expectNotCalled();
|
||||
try mock_post.expectCalledOnceWith(10, 100, .{});
|
||||
}
|
||||
|
||||
test "route(T, ...) route under subpath" {
|
||||
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{
|
||||
MyRoute.new(.GET, "/a", mock_a.func),
|
||||
MyRoute.new(.GET, "/a/b", mock_b.func),
|
||||
} };
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/a");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
try mock_b.expectNotCalled();
|
||||
mock_a.reset();
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/a/b");
|
||||
try mock_a.expectNotCalled();
|
||||
try mock_b.expectCalledOnceWith(10, 100, .{});
|
||||
}
|
||||
|
||||
test "route(T, ...) case-insensitive route" {
|
||||
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{
|
||||
MyRoute.new(.GET, "/test/a", mock_a.func),
|
||||
} };
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/TEST/A");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
mock_a.reset();
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/TesT/a");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
}
|
||||
|
||||
test "route(T, ...) redundant /" {
|
||||
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
|
||||
|
||||
const MyRouter = Router(_test.TestContext, _test.TestContext);
|
||||
const MyRoute = MyRouter.Route;
|
||||
|
||||
const my_router = MyRouter{ .routes = &[_]MyRoute{
|
||||
MyRoute.new(.GET, "/test/a", mock_a.func),
|
||||
} };
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "/test//a");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
mock_a.reset();
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "//test///////////a////");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
mock_a.reset();
|
||||
|
||||
_ = try my_router.dispatch(10, 100, .GET, "test/a");
|
||||
try mock_a.expectCalledOnceWith(10, 100, .{});
|
||||
mock_a.reset();
|
||||
|
||||
try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(10, 100, .GET, "/te/st/a"));
|
||||
try mock_a.expectNotCalled();
|
||||
}
|
|
@ -5,6 +5,7 @@ const http = @import("http");
|
|||
const api = @import("api");
|
||||
const util = @import("util");
|
||||
const query_utils = @import("./query.zig");
|
||||
const json_utils = @import("./json.zig");
|
||||
|
||||
pub const auth = @import("./controllers/auth.zig");
|
||||
pub const communities = @import("./controllers/communities.zig");
|
||||
|
@ -12,6 +13,17 @@ pub const invites = @import("./controllers/invites.zig");
|
|||
pub const users = @import("./controllers/users.zig");
|
||||
pub const notes = @import("./controllers/notes.zig");
|
||||
|
||||
pub fn routeRequest(api_source: anytype, request: http.Request, response: http.Response, alloc: std.mem.Allocator) void {
|
||||
// TODO: hashmaps?
|
||||
inline for (routes) |route| {
|
||||
if (Context(route).matchAndHandle(api_source, request, response, alloc)) return;
|
||||
}
|
||||
|
||||
// todo 404
|
||||
}
|
||||
|
||||
const routes = .{sample_api};
|
||||
|
||||
pub const sample_api = struct {
|
||||
const Self = @This();
|
||||
|
||||
|
@ -24,7 +36,7 @@ pub const sample_api = struct {
|
|||
};
|
||||
|
||||
pub const Body = struct {
|
||||
content: []const u8,
|
||||
content: util.Uuid,
|
||||
};
|
||||
|
||||
pub const Query = struct {
|
||||
|
@ -32,6 +44,7 @@ pub const sample_api = struct {
|
|||
};
|
||||
|
||||
pub fn handler(ctx: Context(Self), response: *Response, _: api.ApiSource.Conn) !void {
|
||||
std.log.debug("{}", .{ctx.body.content});
|
||||
try response.writeJson(.created, ctx.query);
|
||||
}
|
||||
};
|
||||
|
@ -115,14 +128,13 @@ pub fn Context(comptime Route: type) type {
|
|||
fn parseBody(self: *Self, req: http.Request) !void {
|
||||
if (Body != void) {
|
||||
const body = req.body orelse return error.NoBody;
|
||||
var tokens = std.json.TokenStream.init(body);
|
||||
self.body = try std.json.parse(Body, &tokens, .{ .allocator = self.allocator });
|
||||
self.body = try json_utils.parse(Body, body, self.allocator);
|
||||
}
|
||||
}
|
||||
|
||||
fn freeBody(self: *Self) void {
|
||||
if (Body != void) {
|
||||
std.json.parseFree(Body, self.body, .{ .allocator = self.allocator });
|
||||
json_utils.parseFree(self.body, self.allocator);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,46 +6,8 @@ const util = @import("util");
|
|||
|
||||
const api = @import("api");
|
||||
const migrations = @import("./migrations.zig");
|
||||
const Uuid = util.Uuid;
|
||||
const c = @import("./controllers.zig");
|
||||
|
||||
// this thing is overcomplicated and weird. stop this
|
||||
const Router = http.Router(*RequestServer);
|
||||
const Route = Router.Route;
|
||||
const RouteArgs = http.RouteArgs;
|
||||
const router = Router{
|
||||
.routes = &[_]Route{
|
||||
//Route.new(.GET, "/healthcheck", &c.healthcheck),
|
||||
|
||||
//prepare(c.auth.login),
|
||||
//prepare(c.auth.verify_login),
|
||||
|
||||
//prepare(c.communities.create),
|
||||
|
||||
//prepare(c.invites.create),
|
||||
|
||||
//prepare(c.users.create),
|
||||
|
||||
//prepare(c.notes.create),
|
||||
//prepare(c.notes.get),
|
||||
|
||||
//prepare(c.communities.query),
|
||||
|
||||
//Route.new(.GET, "/notes/:id/reacts", &c.notes.reacts.list),
|
||||
//Route.new(.POST, "/notes/:id/reacts", &c.notes.reacts.create),
|
||||
|
||||
//Route.new(.GET, "/actors/:id", &c.actors.get),
|
||||
|
||||
//Route.new(.GET, "/admin/invites/:id", &c.admin.invites.get),
|
||||
|
||||
//Route.new(.GET, "/admin/communities/:host", &c.admin.communities.get),
|
||||
},
|
||||
};
|
||||
|
||||
fn prepare(comptime route_desc: type) Route {
|
||||
return Route.new(route_desc.method, route_desc.path, &route_desc.handler);
|
||||
}
|
||||
|
||||
pub const RequestServer = struct {
|
||||
alloc: std.mem.Allocator,
|
||||
api: *api.ApiSource,
|
||||
|
@ -74,14 +36,6 @@ pub const RequestServer = struct {
|
|||
|
||||
_ = c.Context(c.sample_api).matchAndHandle(self.api, ctx, self.alloc);
|
||||
if (true) continue;
|
||||
|
||||
router.dispatch(self, &ctx, ctx.request.method, ctx.request.path) catch |err| switch (err) {
|
||||
error.NotFound, error.RouteNotApplicable => c.notFound(self, &ctx),
|
||||
else => {
|
||||
std.log.err("Unhandled error in controller ({s}): {}\nStack Trace\n{?}", .{ ctx.request.path, err, @errorReturnTrace() });
|
||||
c.internalServerError(self, &ctx);
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -48,6 +48,7 @@ pub fn format(value: Uuid, comptime _: []const u8, _: std.fmt.FormatOptions, wri
|
|||
});
|
||||
}
|
||||
|
||||
pub const JsonParseAs = []const u8;
|
||||
pub fn jsonStringify(value: Uuid, _: std.json.StringifyOptions, writer: anytype) !void {
|
||||
try std.fmt.format(writer, "\"{}\"", .{value});
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue