Controller refactor

This commit is contained in:
jaina heartles 2022-10-09 19:06:11 -07:00
parent 2aa9569050
commit 83ee7efba0
5 changed files with 17 additions and 430 deletions

View file

@ -1,7 +1,6 @@
const std = @import("std"); const std = @import("std");
const ciutf8 = @import("util").ciutf8; const ciutf8 = @import("util").ciutf8;
const routing = @import("./routing.zig");
const request = @import("./request.zig"); const request = @import("./request.zig");
pub const server = @import("./server.zig"); pub const server = @import("./server.zig");
@ -12,12 +11,6 @@ pub const Status = std.http.Status;
pub const Request = request.Request; pub const Request = request.Request;
pub const Server = server.Server; pub const Server = server.Server;
// TODO: rework routing
pub fn Router(comptime ServerContext: type) type {
return routing.Router(ServerContext, *server.Context);
}
pub const RouteArgs = routing.RouteArgs;
pub const Headers = std.HashMap([]const u8, []const u8, struct { pub const Headers = std.HashMap([]const u8, []const u8, struct {
pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { pub fn eql(_: @This(), a: []const u8, b: []const u8) bool {
return ciutf8.eql(a, b); return ciutf8.eql(a, b);
@ -29,7 +22,6 @@ pub const Headers = std.HashMap([]const u8, []const u8, struct {
}, std.hash_map.default_max_load_percentage); }, std.hash_map.default_max_load_percentage);
test { test {
_ = routing;
_ = server; _ = server;
_ = request; _ = request;
} }

View file

@ -1,372 +0,0 @@
const std = @import("std");
const util = @import("util");
const builtin = @import("builtin");
const http = @import("./lib.zig");
pub const RouteArgs = struct {
const max_args = 16;
names: [max_args]?[]const u8 = [_]?[]const u8{null} ** max_args,
values: [max_args]?[]const u8 = [_]?[]const u8{null} ** max_args,
pub fn get(self: RouteArgs, name: []const u8) ?[]const u8 {
for (self.names) |arg_name, i| {
if (arg_name == null) return null;
if (util.ciutf8.eql(name, arg_name.?)) return self.values[i];
}
return null;
}
};
pub fn Router(comptime ServerContext: type, comptime RequestContext: type) type {
return struct {
const Self = @This();
routes: []const Route,
pub const Handler = *const fn (ServerContext, RequestContext, RouteArgs) anyerror!void;
pub const Route = struct {
pub const Args = RouteArgs;
method: http.Method,
path: []const u8,
path_segments: []const RouteSegment,
handler: Handler,
pub fn new(method: http.Method, comptime path: []const u8, handler: Handler) Route {
const result = .{
.method = method,
.path = path,
.path_segments = splitPathSegments(path),
.handler = handler,
};
var param_count: usize = 0;
for (result.path_segments) |seg| {
if (seg == .param) param_count += 1;
}
if (param_count > RouteArgs.max_args) @panic("Too many params");
return result;
}
fn dispatch(route: Route, sctx: ServerContext, rctx: RequestContext, req_method: std.http.Method, req_path: []const u8) anyerror!void {
if (req_method != route.method) return error.RouteNotApplicable;
var args = RouteArgs{};
var arg_count: usize = 0;
var req_segments = util.PathIter.from(req_path);
for (route.path_segments) |seg| {
const req_seg = req_segments.next() orelse return error.RouteNotApplicable;
switch (seg) {
.literal => |literal| {
if (!util.ciutf8.eql(literal, req_seg)) return error.RouteNotApplicable;
},
.param => |param| {
args.names[arg_count] = param;
args.values[arg_count] = req_seg;
},
}
}
if (req_segments.next() != null) return error.RouteNotApplicable;
std.log.debug("selected route {s} {s}", .{ @tagName(route.method), route.path });
if (builtin.zig_backend == .stage1) {
return route.handler.*(sctx, rctx, args);
}
return route.handler(sctx, rctx, args);
}
};
pub fn dispatch(self: Self, sctx: ServerContext, rctx: RequestContext, method: std.http.Method, path: []const u8) anyerror!void {
const eff_path = std.mem.sliceTo(std.mem.sliceTo(path, '#'), '?');
for (self.routes) |r| {
r.dispatch(sctx, rctx, method, eff_path) catch |err| switch (err) {
error.RouteNotApplicable => continue,
else => return err,
};
return;
}
return error.RouteNotApplicable;
}
};
}
const RouteSegment = union(enum) {
literal: []const u8,
param: []const u8,
};
fn paramNameFrom(segment: []const u8) []const u8 {
return segment[1..];
}
fn isParamSegment(segment: []const u8) bool {
return segment[0] == ':';
}
fn splitPathSegments(comptime path: []const u8) []const RouteSegment {
comptime {
var segments: [path.len]RouteSegment = undefined;
var segment_count: usize = 0;
var iter = util.PathIter.from(path);
while (iter.next()) |segment| {
if (isParamSegment(segment)) {
segments[segment_count] = .{
.param = paramNameFrom(segment),
};
} else {
segments[segment_count] = .{ .literal = segment };
}
segment_count += 1;
}
return segments[0..segment_count];
}
}
const _test = struct {
fn CallTracker(comptime _uniq: anytype, comptime next: anytype) type {
_ = _uniq;
var ctx_type: type = undefined;
var args_type: type = undefined;
switch (@typeInfo(@TypeOf(next))) {
.Fn => |func| {
if (func.args.len != 3) @compileError("next() must take 3 arguments");
ctx_type = func.args[0].arg_type.?;
args_type = func.args[2].arg_type.?;
//if (@typeInfo(RouteArgs) != .Struct) @compileError("second argument to next() must be struct");
},
else => @compileError("next must be function"),
}
const Context = ctx_type;
return struct {
var calls: u32 = 0;
var last_rctx: ?Context = null;
var last_sctx: ?Context = null;
var last_args: ?RouteArgs = null;
fn func(sctx: Context, rctx: Context, args: RouteArgs) !void {
calls += 1;
last_sctx = sctx;
last_rctx = rctx;
last_args = args;
return next(sctx, rctx, args);
}
fn expectCalledOnceWith(exp_sctx: Context, exp_rctx: Context, exp_args: RouteArgs) !void {
try std.testing.expectEqual(@as(u32, 1), calls);
try std.testing.expectEqual(exp_sctx, last_sctx.?);
try std.testing.expectEqual(exp_rctx, last_rctx.?);
for (exp_args.names) |exp_name, i| {
if (exp_name == null) {
try std.testing.expectEqual(@as(?[]const u8, null), last_args.?.names[i]);
try std.testing.expectEqual(@as(?[]const u8, null), last_args.?.values[i]);
} else {
try std.testing.expectEqualStrings(exp_name.?, last_args.?.names[i].?);
try std.testing.expectEqualStrings(exp_args.values[i].?, last_args.?.values[i].?);
}
}
}
fn expectNotCalled() !void {
try std.testing.expectEqual(@as(u32, 0), calls);
}
fn reset() void {
calls = 0;
last_sctx = null;
last_rctx = null;
last_args = null;
}
};
}
const TestContext = u32;
fn dummyHandler(_: TestContext, _: TestContext, _: RouteArgs) anyerror!void {}
};
test "route(T, ...) basic" {
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{
MyRoute.new(.GET, "/a", mock_a.func),
} };
_ = try my_router.dispatch(10, 100, .GET, "/a");
try mock_a.expectCalledOnceWith(10, 100, .{});
mock_a.reset();
}
test "Route(T) matches correct route from multiple routes" {
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{
MyRoute.new(.GET, "/a", mock_a.func),
MyRoute.new(.GET, "/b", mock_b.func),
} };
_ = try my_router.dispatch(10, 100, .GET, "/a");
try mock_a.expectCalledOnceWith(10, 100, .{});
try mock_b.expectNotCalled();
mock_a.reset();
_ = try my_router.dispatch(10, 100, .GET, "/b");
try mock_a.expectNotCalled();
try mock_b.expectCalledOnceWith(10, 100, .{});
mock_b.reset();
}
test "Route(T) passes correct context" {
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{
MyRoute.new(.GET, "/a", mock_a.func),
} };
_ = try my_router.dispatch(10, 100, .GET, "/a");
try mock_a.expectCalledOnceWith(10, 100, .{});
mock_a.reset();
_ = try my_router.dispatch(16, 32, .GET, "/a");
try mock_a.expectCalledOnceWith(16, 32, .{});
mock_a.reset();
}
test "Route(T) errors on no matching route" {
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{
MyRoute.new(.GET, "/a", mock_a.func),
MyRoute.new(.GET, "/b", mock_b.func),
} };
try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(10, 100, .GET, "/c"));
try mock_a.expectNotCalled();
try mock_b.expectNotCalled();
}
test "route(T) with no routes" {
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{} };
try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(10, 100, .GET, "/c"));
}
test "route(T, ...) same path different methods" {
const mock_get = _test.CallTracker(.{}, _test.dummyHandler);
const mock_post = _test.CallTracker(.{}, _test.dummyHandler);
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{
MyRoute.new(.GET, "/a", mock_get.func),
MyRoute.new(.POST, "/a", mock_post.func),
} };
_ = try my_router.dispatch(10, 100, .GET, "/a");
try mock_get.expectCalledOnceWith(10, 100, .{});
try mock_post.expectNotCalled();
mock_get.reset();
_ = try my_router.dispatch(10, 100, .POST, "/a");
try mock_get.expectNotCalled();
try mock_post.expectCalledOnceWith(10, 100, .{});
}
test "route(T, ...) route under subpath" {
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
const mock_b = _test.CallTracker(.{}, _test.dummyHandler);
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{
MyRoute.new(.GET, "/a", mock_a.func),
MyRoute.new(.GET, "/a/b", mock_b.func),
} };
_ = try my_router.dispatch(10, 100, .GET, "/a");
try mock_a.expectCalledOnceWith(10, 100, .{});
try mock_b.expectNotCalled();
mock_a.reset();
_ = try my_router.dispatch(10, 100, .GET, "/a/b");
try mock_a.expectNotCalled();
try mock_b.expectCalledOnceWith(10, 100, .{});
}
test "route(T, ...) case-insensitive route" {
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{
MyRoute.new(.GET, "/test/a", mock_a.func),
} };
_ = try my_router.dispatch(10, 100, .GET, "/TEST/A");
try mock_a.expectCalledOnceWith(10, 100, .{});
mock_a.reset();
_ = try my_router.dispatch(10, 100, .GET, "/TesT/a");
try mock_a.expectCalledOnceWith(10, 100, .{});
}
test "route(T, ...) redundant /" {
const mock_a = _test.CallTracker(.{}, _test.dummyHandler);
const MyRouter = Router(_test.TestContext, _test.TestContext);
const MyRoute = MyRouter.Route;
const my_router = MyRouter{ .routes = &[_]MyRoute{
MyRoute.new(.GET, "/test/a", mock_a.func),
} };
_ = try my_router.dispatch(10, 100, .GET, "/test//a");
try mock_a.expectCalledOnceWith(10, 100, .{});
mock_a.reset();
_ = try my_router.dispatch(10, 100, .GET, "//test///////////a////");
try mock_a.expectCalledOnceWith(10, 100, .{});
mock_a.reset();
_ = try my_router.dispatch(10, 100, .GET, "test/a");
try mock_a.expectCalledOnceWith(10, 100, .{});
mock_a.reset();
try std.testing.expectError(error.RouteNotApplicable, my_router.dispatch(10, 100, .GET, "/te/st/a"));
try mock_a.expectNotCalled();
}

View file

@ -5,6 +5,7 @@ const http = @import("http");
const api = @import("api"); const api = @import("api");
const util = @import("util"); const util = @import("util");
const query_utils = @import("./query.zig"); const query_utils = @import("./query.zig");
const json_utils = @import("./json.zig");
pub const auth = @import("./controllers/auth.zig"); pub const auth = @import("./controllers/auth.zig");
pub const communities = @import("./controllers/communities.zig"); pub const communities = @import("./controllers/communities.zig");
@ -12,6 +13,17 @@ pub const invites = @import("./controllers/invites.zig");
pub const users = @import("./controllers/users.zig"); pub const users = @import("./controllers/users.zig");
pub const notes = @import("./controllers/notes.zig"); pub const notes = @import("./controllers/notes.zig");
pub fn routeRequest(api_source: anytype, request: http.Request, response: http.Response, alloc: std.mem.Allocator) void {
// TODO: hashmaps?
inline for (routes) |route| {
if (Context(route).matchAndHandle(api_source, request, response, alloc)) return;
}
// todo 404
}
const routes = .{sample_api};
pub const sample_api = struct { pub const sample_api = struct {
const Self = @This(); const Self = @This();
@ -24,7 +36,7 @@ pub const sample_api = struct {
}; };
pub const Body = struct { pub const Body = struct {
content: []const u8, content: util.Uuid,
}; };
pub const Query = struct { pub const Query = struct {
@ -32,6 +44,7 @@ pub const sample_api = struct {
}; };
pub fn handler(ctx: Context(Self), response: *Response, _: api.ApiSource.Conn) !void { pub fn handler(ctx: Context(Self), response: *Response, _: api.ApiSource.Conn) !void {
std.log.debug("{}", .{ctx.body.content});
try response.writeJson(.created, ctx.query); try response.writeJson(.created, ctx.query);
} }
}; };
@ -115,14 +128,13 @@ pub fn Context(comptime Route: type) type {
fn parseBody(self: *Self, req: http.Request) !void { fn parseBody(self: *Self, req: http.Request) !void {
if (Body != void) { if (Body != void) {
const body = req.body orelse return error.NoBody; const body = req.body orelse return error.NoBody;
var tokens = std.json.TokenStream.init(body); self.body = try json_utils.parse(Body, body, self.allocator);
self.body = try std.json.parse(Body, &tokens, .{ .allocator = self.allocator });
} }
} }
fn freeBody(self: *Self) void { fn freeBody(self: *Self) void {
if (Body != void) { if (Body != void) {
std.json.parseFree(Body, self.body, .{ .allocator = self.allocator }); json_utils.parseFree(self.body, self.allocator);
} }
} }

View file

@ -6,46 +6,8 @@ const util = @import("util");
const api = @import("api"); const api = @import("api");
const migrations = @import("./migrations.zig"); const migrations = @import("./migrations.zig");
const Uuid = util.Uuid;
const c = @import("./controllers.zig"); const c = @import("./controllers.zig");
// this thing is overcomplicated and weird. stop this
const Router = http.Router(*RequestServer);
const Route = Router.Route;
const RouteArgs = http.RouteArgs;
const router = Router{
.routes = &[_]Route{
//Route.new(.GET, "/healthcheck", &c.healthcheck),
//prepare(c.auth.login),
//prepare(c.auth.verify_login),
//prepare(c.communities.create),
//prepare(c.invites.create),
//prepare(c.users.create),
//prepare(c.notes.create),
//prepare(c.notes.get),
//prepare(c.communities.query),
//Route.new(.GET, "/notes/:id/reacts", &c.notes.reacts.list),
//Route.new(.POST, "/notes/:id/reacts", &c.notes.reacts.create),
//Route.new(.GET, "/actors/:id", &c.actors.get),
//Route.new(.GET, "/admin/invites/:id", &c.admin.invites.get),
//Route.new(.GET, "/admin/communities/:host", &c.admin.communities.get),
},
};
fn prepare(comptime route_desc: type) Route {
return Route.new(route_desc.method, route_desc.path, &route_desc.handler);
}
pub const RequestServer = struct { pub const RequestServer = struct {
alloc: std.mem.Allocator, alloc: std.mem.Allocator,
api: *api.ApiSource, api: *api.ApiSource,
@ -74,14 +36,6 @@ pub const RequestServer = struct {
_ = c.Context(c.sample_api).matchAndHandle(self.api, ctx, self.alloc); _ = c.Context(c.sample_api).matchAndHandle(self.api, ctx, self.alloc);
if (true) continue; if (true) continue;
router.dispatch(self, &ctx, ctx.request.method, ctx.request.path) catch |err| switch (err) {
error.NotFound, error.RouteNotApplicable => c.notFound(self, &ctx),
else => {
std.log.err("Unhandled error in controller ({s}): {}\nStack Trace\n{?}", .{ ctx.request.path, err, @errorReturnTrace() });
c.internalServerError(self, &ctx);
},
};
} }
} }
}; };

View file

@ -48,6 +48,7 @@ pub fn format(value: Uuid, comptime _: []const u8, _: std.fmt.FormatOptions, wri
}); });
} }
pub const JsonParseAs = []const u8;
pub fn jsonStringify(value: Uuid, _: std.json.StringifyOptions, writer: anytype) !void { pub fn jsonStringify(value: Uuid, _: std.json.StringifyOptions, writer: anytype) !void {
try std.fmt.format(writer, "\"{}\"", .{value}); try std.fmt.format(writer, "\"{}\"", .{value});
} }